You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

使用Keras训练CIFAR-10模型筛选类别时触发IndexError问题求助

解决Keras CIFAR-10部分类别训练的标签索引越界问题

我一眼就看出问题所在了:你用keras.utils.to_categorical()的时候犯了一个常见的小错误——这个函数要求输入的类别标签必须是从0开始的连续整数(也就是范围在0到num_classes-1之间),但你现在保留的原始标签是[2,3,5,6,7],这些数值明显大于num_classes=5对应的最大索引4,自然会触发索引越界的报错。

具体解决方案

你需要把原始的类别标签映射成0到4的连续整数,步骤很简单:

  1. 先创建一个标签映射字典,给每个选中的原始类别分配新的连续索引
  2. 对筛选后的训练集和测试集标签进行替换
  3. 再调用to_categorical转换为独热编码

下面是修改后的完整代码:

import numpy as np
from tensorflow import keras

selected_classes = [2, 3, 5, 6, 7]
# 核心:创建原始类别到新连续索引的映射
label_map = {original_cls: new_idx for new_idx, original_cls in enumerate(selected_classes)}

# 处理训练集
print('train\n', x_train.shape, y_train.shape)
x_train_filtered = [ex for ex, ey in zip(x_train, y_train) if ey in selected_classes]
# 替换标签为新的连续索引
y_train_filtered = [label_map[ey] for ex, ey in zip(x_train, y_train) if ey in selected_classes]
x_train = np.stack(x_train_filtered)
y_train = np.stack(y_train_filtered).reshape(-1, 1)
print(x_train.shape, y_train.shape)

# 处理测试集
print('test\n', x_test.shape, y_test.shape)
x_test_filtered = [ex for ex, ey in zip(x_test, y_test) if ey in selected_classes]
y_test_filtered = [label_map[ey] for ex, ey in zip(x_test, y_test) if ey in selected_classes]
x_test = np.stack(x_test_filtered)
y_test = np.stack(y_test_filtered).reshape(-1, 1)
print(x_test.shape, y_test.shape)

num_classes = len(selected_classes)

# 现在可以正常转换为独热编码了
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

为什么这样能解决问题?

我们通过enumerate给每个选中的原始类别分配了从0开始的索引:

  • 原始类别2 → 新索引0
  • 原始类别3 → 新索引1
  • 原始类别5 → 新索引2
  • 原始类别6 → 新索引3
  • 原始类别7 → 新索引4

处理后的标签范围正好是0到4,完美匹配num_classes=5的要求,调用to_categorical时就不会再出现索引越界的错误了。

内容的提问来源于stack exchange,提问作者S.Haviv

火山引擎 最新活动