You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

如何在Matplotlib单个画布中实现多变量动态曲线图?

嘿,很高兴看到你已经搞定了单变量的动态损失曲线!要在同一画布实现多变量动态图其实不难,咱们可以从核心逻辑扩展入手,我给你拆解一下具体方案:

实现多变量动态图的具体方案

本质上就是为每个要监控的变量创建独立的线条对象,然后在更新逻辑里批量处理这些线条的数据,最后统一刷新画布就行。下面是适配你的神经网络训练场景的完整示例:

1. 初始化多线条与画布基础设置

import numpy as np
import matplotlib.pyplot as plt

# 先设置画布和基础坐标轴
plt.figure(figsize=(8, 6))
plt.xlabel('iter')
plt.ylabel('Metrics')
# 根据你的变量范围调整坐标轴,比如损失0-10、准确率0-1,这里可以设一个兼容的范围
plt.axis([0, set_size, 0, 10])

# 初始化多条线:比如损失(蓝色)、准确率(红色),记得加图例区分
cost_plot, = plt.plot([], [], 'b-', label='Training Cost')
acc_plot, = plt.plot([], [], 'r-', label='Training Accuracy')
plt.legend(loc='upper right')  # 显示图例方便区分变量

2. 编写支持多变量的更新函数

你可以把原来的单线条更新函数扩展成批量版本,这样复用性更强:

def update_multiple_lines(line_objects, new_data_list):
    # line_objects是所有线条对象的列表,new_data_list是对应的数据组,每组格式为[iter, value]
    for line, new_data in zip(line_objects, new_data_list):
        # 追加新的x、y数据
        line.set_xdata(np.append(line.get_xdata(), new_data[0]))
        line.set_ydata(np.append(line.get_ydata(), new_data[1]))
    plt.draw()  # 统一刷新画布

3. 训练循环中调用更新

在你的训练迭代里,每次拿到itercost和其他变量(比如accuracy)后,这样调用即可:

# 假设循环中每次迭代能获取到iter、cost、accuracy
...
# 把线条对象和对应的数据组传入更新函数
update_multiple_lines([cost_plot, acc_plot], [[iter, cost], [iter, accuracy]])
plt.pause(0.001)  # 短暂暂停实现动态效果

额外优化:双Y轴适配差异大的变量

如果你的变量数值范围差异很大(比如损失在0-10,准确率在0-1),单Y轴会导致其中一个变量的曲线被压缩,这时候可以用双Y轴来优化展示效果:

fig, ax1 = plt.subplots(figsize=(8, 6))

# 第一个Y轴:负责展示损失
ax1.set_xlabel('iter')
ax1.set_ylabel('Cost', color='b')
ax1.set_ylim(0, 10)
cost_plot, = ax1.plot([], [], 'b-')

# 第二个Y轴:负责展示准确率,与ax1共享X轴
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy', color='r')
ax2.set_ylim(0, 1)
acc_plot, = ax2.plot([], [], 'r-')

# 适配双Y轴的更新函数
def update_dual_axis(line_cost, line_acc, iter, cost_val, acc_val):
    line_cost.set_xdata(np.append(line_cost.get_xdata(), iter))
    line_cost.set_ydata(np.append(line_cost.get_ydata(), cost_val))
    line_acc.set_xdata(np.append(line_acc.get_xdata(), iter))
    line_acc.set_ydata(np.append(line_acc.get_ydata(), acc_val))
    fig.canvas.draw()

# 循环中调用
...
update_dual_axis(cost_plot, acc_plot, iter, cost, accuracy)
plt.pause(0.001)

这样就能在同一画布上清晰展示多个动态变化的训练指标啦,不管是同Y轴还是双Y轴场景都能覆盖!

内容的提问来源于stack exchange,提问作者Konrad Rzońca

火山引擎 最新活动