You need to enable JavaScript to run this app.
导航

使用SDK进行数据导出

最近更新时间2023.09.05 10:56:29

首次发布时间2023.09.05 10:56:29

初始化

import wandb
import pandas as pd

project = "ci"                   # 项目名称
id = "run_20230714_bb4b99f4"     # run_id
api = wandb.TrackingApi()        
run = api.run(project=project, run_id=id)

导出概览(超参数、指标)数据

>>> config = run.config
>>> pd.DataFrame(config)

   init_conf  is_cpu  is_gpu     lr  ...  optim  update_nested.batch_sizes  update_nested.epoch update_nested.schedulers.cosine_annealing
0          1   False    True  0.001  ...   adam                         22                   11                                      True
1          1   False    True  0.001  ...   adam                         33                   11                                      True

[2 rows x 11 columns]
>>> run.summary
{'best_loss': 0.12345, 'eval/acc': 0.99, 'inf': 'inf', 'nan': 'nan', 'train/acc': 0.9756182518521165, 'train/loss': 1.7398966523873338}

导出训练数据

# 导出所有图表
>>> h = run.history()
>>> pd.DataFrame(h)
    train/loss  eval.imagenet.loss.v2t  step  train/acc  eval.acc  eval.imagenet.loss.t2v
0     3.430109                       0     0   0.115875  0.123173                       1
1     3.595317                       1     1   0.017146  0.087839                       0
2     3.257887                       2     2   0.035479  0.137657                      -1
3     3.349832                       3     3   0.066500  0.064585                      -2
4     3.872932                       4     4   0.076133  0.198770                      -3
..         ...                     ...   ...        ...       ...                     ...
95    1.181176                      95    95   0.909349  0.948310                     -94
96    1.311653                      96    96   0.885969  0.892896                     -95
97    2.038165                      97    97   0.927819  0.814461                     -96
98    1.370099                      98    98   0.916609  0.919585                     -97
99    1.739897                      99    99   0.975618  0.800897                     -98

# 指定图表名称
>>> names = run.list_entity_names()
>>> h = run.history(name=[names[0], names[1]])
    eval.imagenet.loss.v2t  eval.acc  step
0                        0  0.028165     0
1                        1  0.105383     1
2                        2  0.081822     2
3                        3  0.071960     3
4                        4  0.069980     4
..                     ...       ...   ...
95                      95  0.959726    95
96                      96  0.904129    96
97                      97  0.930942    97
98                      98  0.934407    98
99                      99  0.905092    99

run.history()方法返回的数据与平台界面展示的数据完全一致,但是平台界面为了兼顾前端性能,返回的是经过采样的数据。如果需要看全量数据,需要使用run.scan_history()方法

导出自定义表格数据

>>> table_names = run.list_table_names() # 获取所有表格的名称
>>> t = run.get_table(table_names[0])    # 指定其中一个表格,获取数据
>>> pd.DataFrame(t)
   int     float  ...                                              audio                                              video
0   66  0.226396  ...  {'_type': 'audio-file', 'caption': 'test_capti...  {'_type': 'video-file', 'height': 100, 'path':...
1   58  0.324750  ...  {'_type': 'audio-file', 'caption': 'test_capti...  {'_type': 'video-file', 'height': 100, 'path':...
2   14  0.581490  ...  {'_type': 'audio-file', 'caption': 'test_capti...  {'_type': 'video-file', 'height': 100, 'path':...
3   73  0.431927  ...  {'_type': 'audio-file', 'caption': 'test_capti...  {'_type': 'video-file', 'height': 100, 'path':...