You need to enable JavaScript to run this app.
机器学习平台

机器学习平台

复制全文
实验管理
使用SDK进行实验记录
复制全文
使用SDK进行实验记录

1 安装
wget https://ml-platform-public-examples-cn-beijing.tos-cn-beijing.volces.com/python_sdk_installer/volcengine_ml_platform-1.1.14-py3-none-any.whl && pip install volcengine_ml_platform-1.1.14-py3-none-any.whl -i https://pypi.tuna.tsinghua.edu.cn/simple

2 鉴权配置

在正式使用 SDK 之前需要先完成火山引擎账号的 AK / SK 的本地配置,用以在使用 SDK 访问机器学习平台时的身份校验。

  1. 登录火山引擎控制台并前往【密钥管理】查看当前账号的 AK / SK。
    1. 若当前账号为子账号,需要具备 AccessKeyFullAccess 的 IAM 策略。

请使用真实的 AK/SK 替换下列方法中的

方法一(通过配置文件配置):

mkdir -p $HOME/.volc

cat <<EOF > $HOME/.volc/credentials
[default]
access_key_id     = <your access key>
secret_access_key = <your secret access key>
EOF

cat <<EOF > $HOME/.volc/config
[default]
region       = cn-beijing  # 填写所在地域,目前仅支持 cn-beijing
EOF

方法二(通过代码配置):

import volcengine_ml_platform as vemlp
vemlp.init(
    ak='<your access key>',
    sk='<your secret access key>',
    region='cn-beijing',
)

方法三(通过环境变量配置):

export VOLC_ACCESSKEY='<your access key>'
export VOLC_SECRETKEY='<your secret access key>'
export VOLC_REGION=cn-beijing

其他配置

对于cn-shanghai, ap-southeast-1 region需要额外配置以下环境变量

#cn-shanghai
export TK_HOST=https://tracking.ml-platform-cn-shanghai.volces.com
export TOS_ENDPOINT_URL=http://tos-cn-shanghai.volces.com

#ap-southeast-1
export TK_HOST=https://tracking.ml-platform-ap-southeast-1.volces.com
export TOS_ENDPOINT_URL=http://tos-ap-southeast-1.volces.com

3 实验记录

指定实验项目和实验名称

通过init()定义当前训练的实验名称(name)以及希望被托管的实验项目(project),开始运行后即可通过「实验管理」模块在对应的项目内查看该次实验的数据和信息。

import volcengine_ml_platform
from volcengine_ml_platform import wandb 

volcengine_ml_platform.init()
wandb.init(
    project="${experiment_name)", 
    name="$(trial_name)",
    notes="$(trial_description)",
    tags=["baseline"],
)

配置项:

  • project,必要参数。为实验项目的名称。长度上限128,支持中英文、数字及-_./@。
  • name,非必要参数。为当前实验的名称,长度上限128,支持中英文、数字及-_./@;如不指定,系统会随机生成
  • notes,非必要参数。为当前实验的描述,默认为空字符串
  • Tags, 非必要参数。为当前实验的标签信息,后续可用于分组归类和快速筛

选终端会输出 Tracking 链接

wandb: ⭐️ View project at https://console.volcengine.com/ml-platform/region:ml-platform+cn-shanghai/tracking/detail?Id=project_20260101_xxxxxxxx
wandb: 🚀 View run at https://console.volcengine.com/ml-platform/region:ml-platform+cn-shanghai/tracking/detail?Id=project_20260101_xxxxxxxx&selectedTrial=run_20260101_xxxxxxxx

记录超参配置(config)

支持在 wandb.init 内指定 config,后续也可以更新 config。
平台会展示项目内所有运行记录的config

wandb.init(project="sandbox", config={"epochs": 4, "batch_size": 32})
# later
wandb.config.update({"lr": 0.1, "channels": 16})
# or
wandb.config.lr = 0.1
wandb.config.channels = 16

支持自动解析 argparse.Namespace。

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=8)
args = parser.parse_args()
wandb.config.update(args) # adds all of the arguments as config variables

支持自动解析 absl.FLAGS 以及 tf.app.flags。

from absl import flags
flags.DEFINE_string("model", None, "model to run") # name, default, help
wandb.config.update(flags.FLAGS) # adds all absl flags to config

记录为config的超参数数据,可在概述页面进行查看,并和其他run进行对比。

记录评估指标(summary)

可以使用 wandb.summary 记录训练的最终指标,与 config 类似支持跨 run 对比。
使用 wandb.log 记录的最后一个数据点将自动作为 summary。
平台会展示项目内所有运行记录的summary

wandb.summary.best_acc = 0.99
wandb.summary["best_loss"] = 0.01
wandb.summary.update({"final_acc": 0.88, "final_loss": 0.11})

记录图表 (Metrics)

用户可通过log方法记录不同类型的数据,已支持普通数值型数据
log方法提供了三个参数,如下所示:

  • data: Dict[str, Any]格式,value为int/float(scalar)或tracking定义的其他类型。
  • step: tracking log存在全局唯一的step概念,step强制递增(会过滤掉非递增的数据)
    • 如果指定step,则以指定step为准
    • 如果不指定step,则本次step等于global step + 1
  • commit: 如果本次step的数据分多次上传,可指定commit=False。commit=True后,global step自增1
def log(self, data: Dict[str, Any], step: Optional[int] = None, commit=True)

曲线图 Scalar

使用 wandb.log 记录指标关于 step 的变化情况。

for step in range(100):
    wandb.log({
        "train/loss": 3 - step / 50 + random.random(),
        "train/acc": step / 125 + random.random() * 0.2,
    }, step=step)

同一个 project 下所有 run 的图表支持聚合展示:

自定义X轴

# define our custom x axis metric
wandb.define_metric("custom_step")
# define which metrics will be plotted against it
wandb.define_metric("validation_loss", step_metric="custom_step")

for i in range(10):
  log_dict = {
      "train_loss": 1/(i+1),
      "custom_step": i**2,
      "validation_loss": 1/(i+1)   
  }
  wandb.log(log_dict)

直方图 Histogram

支持记录关于 step 的直方图:

for step in range(10):
    scores = [random.gauss(step, 1) for _ in range(200)]
    # table = wandb.Table(data=[[x] for x in scores], columns=["scores"])
    wandb.log({"my_histogram": wandb.Histogram(scores)}, step=step)

前端展示:

自定义表格 Table

支持记录表格,可内嵌图片:

my_data = []
EXTS = ("gif","mp4","webm")
for i in range(4):
    h,w,c=64,64,3
    image_array1=np.random.randint(
        low=0,high=256,size=(100,100,3),dtype=np.uint8
        )
    frames=np.random.randint(low=0,high=256,size=(10,3,100,100),dtype=np.uint8)
    video=wandb.Video(frames,fps=4,format=EXTS[i % len(EXTS)])
    my_data.append(
       [
        i,
        wandb.Image(image_array1),
        wandb.Audio(
        np.random.randn(1000, 2),
        sample_rate=44100,
        caption="test_caption",
        ),
        video,
        ],
     )

columns=["id", "image", "audio", "video"]
my_table=wandb.Table(data=my_data,columns=columns)
wandb.log({"table_key":my_table})

前端展示:

记录多媒体(Media)

图片 image

wandb.log(
    {
        "images": [
            wandb.Image(
                np.random.randint(
                    low=0,
                    high=256,
                    size=(100 * (5 + 1), 100 * (5 + 1), 3),
                    dtype=np.uint8,
                ),
                caption="test_caption",
            )
            for _ in range(7)
        ],
        "image": wandb.Image(
            np.random.randint(
                low=0,
                high=256,
                size=(100, 100, 3),
                dtype=np.uint8,
            ),
            caption="test_caption",
        ),
    }
)

音频 audio

wandb.log(
    {
        "audios": [
            wandb.Audio(
                np.random.randn(1000, 2),
                sample_rate=44100,
                caption="test_caption",
            )
            for _ in range(7)
        ],
        "audio": wandb.Audio(
            np.random.randn(1000, 2),
            sample_rate=44100,
            caption="test_caption",
        ),
    }
)

视频 video

wandb.log(
    {
        "videos": [
            wandb.Video(
                np.random.randint(
                    low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8
                ),
                fps=4,
                caption="test_caption",
            )
            for i in range(7)
        ],
        "video": wandb.Video(
            np.random.randint(
                low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8
            ),
            fps=4,
            caption="test_caption",
        ),
    }
)

HTML

url = '<img src="https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg">'

def new_html():
    return wandb.Html(data=url)

wandb.log({'html': new_html()})

table = wandb.Table(data=[[new_html() for _ in range(4)]], columns=['1', '2', '3', '4'])
wandb.log({'table': table})

记录自定义图表

用法参考https://docs.wandb.ai/guides/track/log/plots#docusaurus_skipToContent_fallback
除wandb原生支持的图表外,还支持了其他类型

箱型图

# define table data
table = wandb.Table(data=[
    ["a", 1.0, 0.8, 0.9], 
    ["b", 0.5, 0.6, 0.9], 
    ["c", 0.1, 0.6, 0.3], 
    ["d", 0.3, 0.5, 1.5], 
    ["e", 0.4, 0.4, 0.1],
], columns=["round", "gpt-4", "gpt-3.5", "llama"])
# define box columns and title
box = wandb.plot.box(
    table, 
    columns=["gpt-4", "gpt-3.5", "llama"],  # 选择table中的哪些列作为数据输入
    title="box demo",
)
wandb.log({"box": box})

Resume功能对中断实验进行补充

说明

最佳实践法则:

  • 没有明确希望进行中断续连先前实验数据需求时,选择默认(不指定ID 和 resume)
  • 不希望自己数据因为重跑,fork等被污染,严格设置 resume="never"
  • 不希望生成垃圾运行记录,希望覆盖数据的,设置wandb.init(id="abc", resume=None)
  • 其他场景,参考不同的resume模式

场景

SDK记录方式

预期

默认场景

wandb.init(),不指定id/resume

每次运行都会生成新的实验

开发机/本地调试场景:
不希望产生垃圾运行记录

wandb.init(id="abc", resume=None),会覆盖之前的运行记录

开发机/本地调试场景:
希望resume上次的运行记录

方法一:wandb.init(id="abc", resume=True)。明确指定id,来继续track,适用于开发机在同时跑多项不同的实验,希望针对性的进行resume,不要记错乱了。
方法二:wandb.init(resume=True)。没有明确指定id,但系统会自行帮助resume。(首次会在当前目录下记录meta文件,后续会从meta文件获取unique_id)

可抢占训练场景(跑到一半被kill了,排到资源后自动从上次继续训练)

用户指定unique id by wandb.init(id="xyz", resume="allow"),再次运行时会load之前的一些变量,继续track

同上,resume进行实验数据拼接

运行失败场景(和可抢占的区别是意外被kill了,没有提前指定id,还是希望能从上次继续训练)

运行失败后,可以在前端查看该run的unique_id。后续指定wandb.init(id="$unique_id", resume="must"),继续track
Image

resume=must时,系统会强制寻找对应实验进行拼接,无法找到时,会记为异常

  • 指定WANDB_RUN_ID环境变量等同于wandb.init(id="") 对自定义训练更加友好

训练复制场景(多人复用相同的训练配置,可能每个人对实验是否resume的需要不一样,尤其需要注意防止污染或覆盖)

  1. 对于原始训练的owner,如不希望resume污染先前数据,wandb.init(id="$unique_id", resume="never")。如果id相同,新的实验会crash来保护以前的实验数据
  2. 对于复制他人实验的用户,确保id不冲突,避免污染先前数据
  • 指定WANDB_RESUME环境变量,默认never也可保障不被污染

Tensorboard 自动同步

实时数据同步

对于即将提交训练,代码内已经包含tensorboard实验打点的用户,可通过指定在 wandb.init 指定 sync_tensorboard=True进行数据同步,减少代码改动。
目前仅针对折线图和指标记录进行同步。其他图表类型和超参数数据仍需按照本文档中的SDK语句进行补充。
注意:wandb.init要早于SummaryWriter对象初始化

import wandb
from torch.utils.tensorboard import SummaryWriter

wandb.init(project="demo-sync-tb", sync_tensorboard=True)

with SummaryWriter("./board") as writer:
    max_step = 100
    for step in range(max_step):
        writer.add_scalar("train/acc", step / max_step, global_step=step)
        writer.add_scalar("train/loss", 1 - step / max_step, global_step=step)

历史数据同步

方法一:python代码

wandb.TrackingApi().sync_tensorboard(
    "/home/user/repos/reckon/wandb/event/", # tf_event root dir
    project="ci", # project_name
    name="tfevent_file_name", # run_name
)

结束任务

调用后,任务会结束打点,状态转为FINISHED。
不调用时,会在代码运行结束后停止。

wandb.finish()
最近更新时间:2026.01.13 11:50:39
这个页面对您有帮助吗?
有用
有用
无用
无用