wget https://ml-platform-public-examples-cn-beijing.tos-cn-beijing.volces.com/python_sdk_installer/volcengine_ml_platform-1.1.14-py3-none-any.whl && pip install volcengine_ml_platform-1.1.14-py3-none-any.whl -i https://pypi.tuna.tsinghua.edu.cn/simple
在正式使用 SDK 之前需要先完成火山引擎账号的 AK / SK 的本地配置,用以在使用 SDK 访问机器学习平台时的身份校验。
AccessKeyFullAccess 的 IAM 策略。请使用真实的 AK/SK 替换下列方法中的
mkdir -p $HOME/.volc cat <<EOF > $HOME/.volc/credentials [default] access_key_id = <your access key> secret_access_key = <your secret access key> EOF cat <<EOF > $HOME/.volc/config [default] region = cn-beijing # 填写所在地域,目前仅支持 cn-beijing EOF
import volcengine_ml_platform as vemlp vemlp.init( ak='<your access key>', sk='<your secret access key>', region='cn-beijing', )
export VOLC_ACCESSKEY='<your access key>' export VOLC_SECRETKEY='<your secret access key>' export VOLC_REGION=cn-beijing
对于cn-shanghai, ap-southeast-1 region需要额外配置以下环境变量
#cn-shanghai export TK_HOST=https://tracking.ml-platform-cn-shanghai.volces.com export TOS_ENDPOINT_URL=http://tos-cn-shanghai.volces.com #ap-southeast-1 export TK_HOST=https://tracking.ml-platform-ap-southeast-1.volces.com export TOS_ENDPOINT_URL=http://tos-ap-southeast-1.volces.com
通过init()定义当前训练的实验名称(name)以及希望被托管的实验项目(project),开始运行后即可通过「实验管理」模块在对应的项目内查看该次实验的数据和信息。
import volcengine_ml_platform from volcengine_ml_platform import wandb volcengine_ml_platform.init() wandb.init( project="${experiment_name)", name="$(trial_name)", notes="$(trial_description)", tags=["baseline"], )
配置项:
选终端会输出 Tracking 链接
wandb: ⭐️ View project at https://console.volcengine.com/ml-platform/region:ml-platform+cn-shanghai/tracking/detail?Id=project_20260101_xxxxxxxx wandb: 🚀 View run at https://console.volcengine.com/ml-platform/region:ml-platform+cn-shanghai/tracking/detail?Id=project_20260101_xxxxxxxx&selectedTrial=run_20260101_xxxxxxxx
支持在 wandb.init 内指定 config,后续也可以更新 config。
平台会展示项目内所有运行记录的config
wandb.init(project="sandbox", config={"epochs": 4, "batch_size": 32}) # later wandb.config.update({"lr": 0.1, "channels": 16}) # or wandb.config.lr = 0.1 wandb.config.channels = 16
支持自动解析 argparse.Namespace。
import argparse parser = argparse.ArgumentParser() parser.add_argument('--batch-size', type=int, default=8) args = parser.parse_args() wandb.config.update(args) # adds all of the arguments as config variables
支持自动解析 absl.FLAGS 以及 tf.app.flags。
from absl import flags flags.DEFINE_string("model", None, "model to run") # name, default, help wandb.config.update(flags.FLAGS) # adds all absl flags to config
记录为config的超参数数据,可在概述页面进行查看,并和其他run进行对比。
可以使用 wandb.summary 记录训练的最终指标,与 config 类似支持跨 run 对比。
使用 wandb.log 记录的最后一个数据点将自动作为 summary。
平台会展示项目内所有运行记录的summary
wandb.summary.best_acc = 0.99 wandb.summary["best_loss"] = 0.01 wandb.summary.update({"final_acc": 0.88, "final_loss": 0.11})
用户可通过log方法记录不同类型的数据,已支持普通数值型数据
log方法提供了三个参数,如下所示:
def log(self, data: Dict[str, Any], step: Optional[int] = None, commit=True)
使用 wandb.log 记录指标关于 step 的变化情况。
for step in range(100): wandb.log({ "train/loss": 3 - step / 50 + random.random(), "train/acc": step / 125 + random.random() * 0.2, }, step=step)
同一个 project 下所有 run 的图表支持聚合展示:
自定义X轴
# define our custom x axis metric wandb.define_metric("custom_step") # define which metrics will be plotted against it wandb.define_metric("validation_loss", step_metric="custom_step") for i in range(10): log_dict = { "train_loss": 1/(i+1), "custom_step": i**2, "validation_loss": 1/(i+1) } wandb.log(log_dict)
支持记录关于 step 的直方图:
for step in range(10): scores = [random.gauss(step, 1) for _ in range(200)] # table = wandb.Table(data=[[x] for x in scores], columns=["scores"]) wandb.log({"my_histogram": wandb.Histogram(scores)}, step=step)
前端展示:
支持记录表格,可内嵌图片:
my_data = [] EXTS = ("gif","mp4","webm") for i in range(4): h,w,c=64,64,3 image_array1=np.random.randint( low=0,high=256,size=(100,100,3),dtype=np.uint8 ) frames=np.random.randint(low=0,high=256,size=(10,3,100,100),dtype=np.uint8) video=wandb.Video(frames,fps=4,format=EXTS[i % len(EXTS)]) my_data.append( [ i, wandb.Image(image_array1), wandb.Audio( np.random.randn(1000, 2), sample_rate=44100, caption="test_caption", ), video, ], ) columns=["id", "image", "audio", "video"] my_table=wandb.Table(data=my_data,columns=columns) wandb.log({"table_key":my_table})
前端展示:
wandb.log( { "images": [ wandb.Image( np.random.randint( low=0, high=256, size=(100 * (5 + 1), 100 * (5 + 1), 3), dtype=np.uint8, ), caption="test_caption", ) for _ in range(7) ], "image": wandb.Image( np.random.randint( low=0, high=256, size=(100, 100, 3), dtype=np.uint8, ), caption="test_caption", ), } )
wandb.log( { "audios": [ wandb.Audio( np.random.randn(1000, 2), sample_rate=44100, caption="test_caption", ) for _ in range(7) ], "audio": wandb.Audio( np.random.randn(1000, 2), sample_rate=44100, caption="test_caption", ), } )
wandb.log( { "videos": [ wandb.Video( np.random.randint( low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8 ), fps=4, caption="test_caption", ) for i in range(7) ], "video": wandb.Video( np.random.randint( low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8 ), fps=4, caption="test_caption", ), } )
url = '<img src="https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg">' def new_html(): return wandb.Html(data=url) wandb.log({'html': new_html()}) table = wandb.Table(data=[[new_html() for _ in range(4)]], columns=['1', '2', '3', '4']) wandb.log({'table': table})
用法参考https://docs.wandb.ai/guides/track/log/plots#docusaurus_skipToContent_fallback
除wandb原生支持的图表外,还支持了其他类型
# define table data table = wandb.Table(data=[ ["a", 1.0, 0.8, 0.9], ["b", 0.5, 0.6, 0.9], ["c", 0.1, 0.6, 0.3], ["d", 0.3, 0.5, 1.5], ["e", 0.4, 0.4, 0.1], ], columns=["round", "gpt-4", "gpt-3.5", "llama"]) # define box columns and title box = wandb.plot.box( table, columns=["gpt-4", "gpt-3.5", "llama"], # 选择table中的哪些列作为数据输入 title="box demo", ) wandb.log({"box": box})
说明
最佳实践法则:
场景 | SDK记录方式 | 预期 |
|---|---|---|
默认场景 | wandb.init(),不指定id/resume | 每次运行都会生成新的实验 |
开发机/本地调试场景: | wandb.init(id="abc", resume=None),会覆盖之前的运行记录 | |
开发机/本地调试场景: | 方法一:wandb.init(id="abc", resume=True)。明确指定id,来继续track,适用于开发机在同时跑多项不同的实验,希望针对性的进行resume,不要记错乱了。 | |
可抢占训练场景(跑到一半被kill了,排到资源后自动从上次继续训练) | 用户指定unique id by wandb.init(id="xyz", resume="allow"),再次运行时会load之前的一些变量,继续track | 同上,resume进行实验数据拼接 |
运行失败场景(和可抢占的区别是意外被kill了,没有提前指定id,还是希望能从上次继续训练) | 运行失败后,可以在前端查看该run的unique_id。后续指定wandb.init(id="$unique_id", resume="must"),继续track | resume=must时,系统会强制寻找对应实验进行拼接,无法找到时,会记为异常
|
训练复制场景(多人复用相同的训练配置,可能每个人对实验是否resume的需要不一样,尤其需要注意防止污染或覆盖) |
|
|
对于即将提交训练,代码内已经包含tensorboard实验打点的用户,可通过指定在 wandb.init 指定 sync_tensorboard=True进行数据同步,减少代码改动。
目前仅针对折线图和指标记录进行同步。其他图表类型和超参数数据仍需按照本文档中的SDK语句进行补充。
注意:wandb.init要早于SummaryWriter对象初始化
import wandb from torch.utils.tensorboard import SummaryWriter wandb.init(project="demo-sync-tb", sync_tensorboard=True) with SummaryWriter("./board") as writer: max_step = 100 for step in range(max_step): writer.add_scalar("train/acc", step / max_step, global_step=step) writer.add_scalar("train/loss", 1 - step / max_step, global_step=step)
方法一:python代码
wandb.TrackingApi().sync_tensorboard( "/home/user/repos/reckon/wandb/event/", # tf_event root dir project="ci", # project_name name="tfevent_file_name", # run_name )
调用后,任务会结束打点,状态转为FINISHED。
不调用时,会在代码运行结束后停止。
wandb.finish()