PyTorch MLP训练MNIST:PyTorch与Keras数据集准确率差异问题
Keras MNIST数据集在PyTorch MLP中准确率极低的原因与修复方案
我仔细看了你的代码和问题描述,发现核心问题出在数据归一化的细节差异上,这也是导致两个数据集训练结果天差地别的关键:
问题根源
PyTorch自带的MNIST数据集使用了ptv.transforms.ToTensor(),这个转换操作会自动把原始MNIST数据中0-255的整数像素值,缩放到0.0-1.0的浮点数范围。而你从Keras加载的MNIST数据,只是做了形状调整和类型转换(转成float32),但没有做归一化处理,输入数据仍然是0-255的大数值。
你的MLP模型是针对0-1范围的输入设计的,当输入突然变成255倍的量级时,配合你设置的较大学习率(lr=1),模型的权重会在训练过程中出现梯度爆炸或无法有效收敛的情况,最终只能得到和随机猜测差不多的准确率(约10%,对应10个手写数字类别)。
修复步骤
只需要在加载Keras MNIST数据时,添加归一化操作即可,修改load_mnist()函数中的两行代码:
def load_mnist(): (x, y), (x_test, y_test) = mnist.load_data() # 新增除以255.0的归一化操作,将像素值缩放到0.0-1.0范围 x = x.reshape((-1, 1, 28, 28)).astype(np.float32) / 255.0 x_test = x_test.reshape((-1, 1, 28, 28)).astype(np.float32) / 255.0 y = y.astype(np.int64) y_test = y_test.astype(np.int64) print("x.shape", x.shape, "y.shape", y.shape, "\nx_test.shape", x_test.shape, "y_test.shape", y_test.shape, ) return x, y, x_test, y_test
修改后重新训练,你会发现使用Keras MNIST数据集训练的模型准确率,会和PyTorch数据集训练的结果持平(达到0.95以上)。
额外验证小技巧
你可以在代码里打印两种数据集的输入范围,直观确认归一化的差异:
- 对于PyTorch的数据集,取一个batch的输入,打印
inputs.min()和inputs.max(),会输出0.0和1.0; - 对于未归一化的Keras数据集,同样打印的话会得到
0.0和255.0,这就能清晰看到问题所在。
内容的提问来源于stack exchange,提问作者YF Yan




