You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

自定义_one_hot_encode处理含未标记数据的标签时触发RuntimeError索引越界问题求助

问题分析与修复方案

首先,咱们来拆解这个错误的根源:

RuntimeError: index 1 is out of bounds for dimension 1 with size 1

这个错误出在scatter操作时,你试图用索引1去访问一个维度尺寸为1的张量。为什么会这样?看你提供的标签数组:[1,1,...9,1...],当你执行classes = classes[classes !=9]后,classes只剩下[1],所以self.n_classes=1,创建的self.one_hot_labels(n_nodes,1)的张量。但你直接用原始标签值1作为scatter的索引,而这个维度的合法索引只能是0,自然就越界了。

更深层的问题是:你的原始标签是-2、-1、0、1、2这类非连续从0开始的数值,不能直接当作独热编码的索引来用——独热编码的索引必须是从0n_classes-1的连续整数。

修复步骤

我们需要先把原始标签值(排除未标记的9)映射为合法的索引,再进行独热编码。修改后的_one_hot_encode函数如下:

def _one_hot_encode(self, labels):
    # 第一步:分离已标记和未标记数据
    unlabeled_mask = (labels == 9)
    labeled_labels = labels[~unlabeled_mask]
    
    # 第二步:获取所有已标记的唯一类别,并映射为0开始的连续索引
    unique_classes = torch.unique(labeled_labels)
    self.n_classes = unique_classes.size(0)
    
    # 创建类别到索引的映射字典
    class_to_idx = {cls.item(): idx for idx, cls in enumerate(unique_classes)}
    
    # 第三步:将原始标签转换为合法索引,未标记数据暂时设为0(后续会清零)
    labels = labels.clone()
    for cls, idx in class_to_idx.items():
        labels[labels == cls] = idx
    labels[unlabeled_mask] = 0  # 未标记数据先随便设个值,后面会清零
    
    # 第四步:创建独热编码张量并填充
    self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
    # 现在labels里的索引都是0到n_classes-1的合法值了
    self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
    
    # 第五步:清零未标记数据的独热编码行
    self.one_hot_labels[unlabeled_mask] = 0.0
    self.labeled_mask = ~unlabeled_mask

关键改进点

  • 类别到索引的映射:把原始的-2、-1、0、1、2这类值,转换成0、1、2、3、4(或当前数据存在的类别对应的连续索引),确保scatter时的索引不会越界。
  • 未标记数据的处理:先将未标记数据的标签临时设为0,完成scatter后再把整行清零,避免干扰已标记数据的编码。

额外优化(可选)

如果你确定数据集的有效类别固定是-2、-1、0、1、2,可以直接预先定义映射,不用每次从数据里取unique,这样更高效:

# 预先定义固定的类别映射
FIXED_CLASS_MAP = {-2:0, -1:1, 0:2, 1:3, 2:4}
self.n_classes = 5  # 固定5个类别

# 替换映射部分的代码
labels = labels.clone()
for cls, idx in FIXED_CLASS_MAP.items():
    labels[labels == cls] = idx
labels[unlabeled_mask] = 0

这样不管你的数据里包含哪些有效类别,都能保证索引的合法性,彻底避免索引越界的问题。

内容的提问来源于stack exchange,提问作者V_sqrt

火山引擎 最新活动