使用Keras训练CIFAR-10模型筛选类别时触发IndexError问题求助
解决Keras CIFAR-10部分类别训练的标签索引越界问题
我一眼就看出问题所在了:你用keras.utils.to_categorical()的时候犯了一个常见的小错误——这个函数要求输入的类别标签必须是从0开始的连续整数(也就是范围在0到num_classes-1之间),但你现在保留的原始标签是[2,3,5,6,7],这些数值明显大于num_classes=5对应的最大索引4,自然会触发索引越界的报错。
具体解决方案
你需要把原始的类别标签映射成0到4的连续整数,步骤很简单:
- 先创建一个标签映射字典,给每个选中的原始类别分配新的连续索引
- 对筛选后的训练集和测试集标签进行替换
- 再调用
to_categorical转换为独热编码
下面是修改后的完整代码:
import numpy as np from tensorflow import keras selected_classes = [2, 3, 5, 6, 7] # 核心:创建原始类别到新连续索引的映射 label_map = {original_cls: new_idx for new_idx, original_cls in enumerate(selected_classes)} # 处理训练集 print('train\n', x_train.shape, y_train.shape) x_train_filtered = [ex for ex, ey in zip(x_train, y_train) if ey in selected_classes] # 替换标签为新的连续索引 y_train_filtered = [label_map[ey] for ex, ey in zip(x_train, y_train) if ey in selected_classes] x_train = np.stack(x_train_filtered) y_train = np.stack(y_train_filtered).reshape(-1, 1) print(x_train.shape, y_train.shape) # 处理测试集 print('test\n', x_test.shape, y_test.shape) x_test_filtered = [ex for ex, ey in zip(x_test, y_test) if ey in selected_classes] y_test_filtered = [label_map[ey] for ex, ey in zip(x_test, y_test) if ey in selected_classes] x_test = np.stack(x_test_filtered) y_test = np.stack(y_test_filtered).reshape(-1, 1) print(x_test.shape, y_test.shape) num_classes = len(selected_classes) # 现在可以正常转换为独热编码了 y_train = keras.utils.to_categorical(y_train, num_classes) y_test = keras.utils.to_categorical(y_test, num_classes)
为什么这样能解决问题?
我们通过enumerate给每个选中的原始类别分配了从0开始的索引:
- 原始类别2 → 新索引0
- 原始类别3 → 新索引1
- 原始类别5 → 新索引2
- 原始类别6 → 新索引3
- 原始类别7 → 新索引4
处理后的标签范围正好是0到4,完美匹配num_classes=5的要求,调用to_categorical时就不会再出现索引越界的错误了。
内容的提问来源于stack exchange,提问作者S.Haviv




