You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch MLP训练MNIST:PyTorch与Keras数据集准确率差异问题

Keras MNIST数据集在PyTorch MLP中准确率极低的原因与修复方案

我仔细看了你的代码和问题描述,发现核心问题出在数据归一化的细节差异上,这也是导致两个数据集训练结果天差地别的关键:

问题根源

PyTorch自带的MNIST数据集使用了ptv.transforms.ToTensor(),这个转换操作会自动把原始MNIST数据中0-255的整数像素值,缩放到0.0-1.0的浮点数范围。而你从Keras加载的MNIST数据,只是做了形状调整和类型转换(转成float32),但没有做归一化处理,输入数据仍然是0-255的大数值。

你的MLP模型是针对0-1范围的输入设计的,当输入突然变成255倍的量级时,配合你设置的较大学习率(lr=1),模型的权重会在训练过程中出现梯度爆炸或无法有效收敛的情况,最终只能得到和随机猜测差不多的准确率(约10%,对应10个手写数字类别)。

修复步骤

只需要在加载Keras MNIST数据时,添加归一化操作即可,修改load_mnist()函数中的两行代码:

def load_mnist():
    (x, y), (x_test, y_test) = mnist.load_data()
    # 新增除以255.0的归一化操作,将像素值缩放到0.0-1.0范围
    x = x.reshape((-1, 1, 28, 28)).astype(np.float32) / 255.0
    x_test = x_test.reshape((-1, 1, 28, 28)).astype(np.float32) / 255.0
    y = y.astype(np.int64)
    y_test = y_test.astype(np.int64)
    print("x.shape", x.shape, "y.shape", y.shape, "\nx_test.shape", x_test.shape, "y_test.shape", y_test.shape, )
    return x, y, x_test, y_test

修改后重新训练,你会发现使用Keras MNIST数据集训练的模型准确率,会和PyTorch数据集训练的结果持平(达到0.95以上)。

额外验证小技巧

你可以在代码里打印两种数据集的输入范围,直观确认归一化的差异:

  • 对于PyTorch的数据集,取一个batch的输入,打印inputs.min()inputs.max(),会输出0.01.0
  • 对于未归一化的Keras数据集,同样打印的话会得到0.0255.0,这就能清晰看到问题所在。

内容的提问来源于stack exchange,提问作者YF Yan

火山引擎 最新活动