自定义_one_hot_encode处理含未标记数据的标签时触发RuntimeError索引越界问题求助
问题分析与修复方案
首先,咱们来拆解这个错误的根源:
RuntimeError: index 1 is out of bounds for dimension 1 with size 1
这个错误出在scatter操作时,你试图用索引1去访问一个维度尺寸为1的张量。为什么会这样?看你提供的标签数组:[1,1,...9,1...],当你执行classes = classes[classes !=9]后,classes只剩下[1],所以self.n_classes=1,创建的self.one_hot_labels是(n_nodes,1)的张量。但你直接用原始标签值1作为scatter的索引,而这个维度的合法索引只能是0,自然就越界了。
更深层的问题是:你的原始标签是-2、-1、0、1、2这类非连续从0开始的数值,不能直接当作独热编码的索引来用——独热编码的索引必须是从0到n_classes-1的连续整数。
修复步骤
我们需要先把原始标签值(排除未标记的9)映射为合法的索引,再进行独热编码。修改后的_one_hot_encode函数如下:
def _one_hot_encode(self, labels): # 第一步:分离已标记和未标记数据 unlabeled_mask = (labels == 9) labeled_labels = labels[~unlabeled_mask] # 第二步:获取所有已标记的唯一类别,并映射为0开始的连续索引 unique_classes = torch.unique(labeled_labels) self.n_classes = unique_classes.size(0) # 创建类别到索引的映射字典 class_to_idx = {cls.item(): idx for idx, cls in enumerate(unique_classes)} # 第三步:将原始标签转换为合法索引,未标记数据暂时设为0(后续会清零) labels = labels.clone() for cls, idx in class_to_idx.items(): labels[labels == cls] = idx labels[unlabeled_mask] = 0 # 未标记数据先随便设个值,后面会清零 # 第四步:创建独热编码张量并填充 self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float) # 现在labels里的索引都是0到n_classes-1的合法值了 self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1) # 第五步:清零未标记数据的独热编码行 self.one_hot_labels[unlabeled_mask] = 0.0 self.labeled_mask = ~unlabeled_mask
关键改进点
- 类别到索引的映射:把原始的
-2、-1、0、1、2这类值,转换成0、1、2、3、4(或当前数据存在的类别对应的连续索引),确保scatter时的索引不会越界。 - 未标记数据的处理:先将未标记数据的标签临时设为0,完成
scatter后再把整行清零,避免干扰已标记数据的编码。
额外优化(可选)
如果你确定数据集的有效类别固定是-2、-1、0、1、2,可以直接预先定义映射,不用每次从数据里取unique,这样更高效:
# 预先定义固定的类别映射 FIXED_CLASS_MAP = {-2:0, -1:1, 0:2, 1:3, 2:4} self.n_classes = 5 # 固定5个类别 # 替换映射部分的代码 labels = labels.clone() for cls, idx in FIXED_CLASS_MAP.items(): labels[labels == cls] = idx labels[unlabeled_mask] = 0
这样不管你的数据里包含哪些有效类别,都能保证索引的合法性,彻底避免索引越界的问题。
内容的提问来源于stack exchange,提问作者V_sqrt




