You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

存在多个正确标签时,如何训练机器学习分类模型?

哈哈,这个场景我太熟了——尤其是标签存在层级关联或者标注时允许多个正确结果的情况,结合你给的鸟类数据集例子,我给你几个落地性强的实现思路:

1. 最简方案:修改评估逻辑的多类别分类

如果你的数据集标签没有强层级关系,这个方法最快上手:

  • 训练阶段:把每个样本的多个正确标签,每次迭代随机选一个作为训练的「目标标签」,用常规的多类别交叉熵损失(比如PyTorch的CrossEntropyLoss,TensorFlow的SparseCategoricalCrossentropy)训练模型。要是想充分利用所有标注信息,也可以把每个样本复制多份,每份对应一个正确标签,再批量训练。
  • 推理阶段:模型输出所有8个标签的概率后,取概率最高的标签,只要这个标签在该样本的正确标签集合里(比如你的例子里,样本正确标签是SparrowBird,预测其中任意一个都算对),就判定预测成功。
  • 优点:完全不用改模型结构,代码改动极小,适合快速验证效果。
2. 更精准的方案:多标签分类+宽松评估

这个方法能充分利用所有标注的正确标签信息,让模型学习更全面:

  • 训练阶段:把每个样本的多个正确标签转换成one-hot向量(比如样本正确标签是SparrowBird,那one-hot向量中这两个位置设为1,其余为0),用**二元交叉熵损失(BCEWithLogitsLoss)**训练多标签分类模型。
  • 推理阶段:模型输出每个标签的概率后,你有两种判断方式:
    • 取概率最高的标签,检查是否在正确集合中;
    • 设置一个概率阈值(比如0.5),只要有一个正确标签的概率超过阈值,就算预测正确。
  • 你的例子适配:模型输出Sparrow概率0.6、Bird概率0.7,其他标签概率0.1,不管是取最高概率的Bird,还是触发阈值判断,都符合要求。
3. 进阶方案:基于标签层级的分类

如果你的8个标签存在明确的层级关系(比如BirdSparrowEagle等子类的父类),可以针对性构建层级分类逻辑:

  • 训练阶段:让模型同时学习父类和子类标签,比如先判断样本是否属于Bird,再判断具体是哪种鸟类。可以用分层损失(父类标签损失+子类标签损失)来训练,让模型理解标签间的从属关系。
  • 推理阶段:只要预测的标签属于正确标签的层级链(比如样本正确标签是Sparrow,预测Bird或者Sparrow都算对),就判定正确。

代码示例(PyTorch,多标签方案)

假设你的数据已经处理为特征张量X(形状(batch_size, 20))和one-hot标签张量y(形状(batch_size, 8)):

import torch
import torch.nn as nn
import torch.optim as optim

# 定义简单的全连接模型
class BirdClassifier(nn.Module):
    def __init__(self, num_features=20, num_labels=8):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(num_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_labels)  # 输出logits,后续用sigmoid转概率
        )
    
    def forward(self, x):
        return self.layers(x)

# 初始化组件
model = BirdClassifier()
criterion = nn.BCEWithLogitsLoss()  # 多标签分类专用损失
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 训练循环示例
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train.float())
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# 推理与宽松评估
model.eval()
with torch.no_grad():
    outputs = torch.sigmoid(model(X_test))  # 把logits转为0-1的概率
    pred_labels = outputs.argmax(dim=1)  # 取概率最高的标签
    
    # 假设test_true_labels是每个样本的正确标签集合(比如[[1,3], [0], ...],数字对应标签索引)
    correct_count = 0
    for pred, true_ids in zip(pred_labels, test_true_labels):
        if pred.item() in true_ids:
            correct_count += 1
    loose_accuracy = correct_count / len(test_true_labels)
    print(f"宽松评估准确率: {loose_accuracy:.4f}")

关键注意点

  • 评估指标自定义:别用常规的多类别准确率,一定要用「宽松准确率」——只要预测标签在正确集合中就算对。
  • 标签权重调整:如果某些标签出现频率极低,可以给它们的损失加权重,避免模型偏向高频标签。
  • 数据复用:样本量小时,把每个样本按正确标签数量复制多份,每份对应一个正确标签,能有效提升模型的学习效果。

内容的提问来源于stack exchange,提问作者SingularRiver

火山引擎 最新活动