600类BERT文本分类性能不佳,求分层模型具体实现方案
粗粒度-细粒度分层文本分类的两种实现方案
嗨,针对你遇到的600类文本分类BERT效果不佳的问题,那个粗粒度-细粒度分层的思路确实能有效降低单模型的分类难度,我来给你拆解两种具体实现方式,你可以根据自己的需求和资源情况选择:
一、分开训练模型+逻辑串联(轻量易调试)
这种方式是把任务拆成两步,先训粗分类模型,再给每个粗粒度类别单独训细分类模型,推理时通过逻辑判断串联起来,适合刚上手调试的场景。
具体步骤:
先给数据打粗粒度标签
把600个细粒度类别按照业务逻辑或语义聚类成N个粗粒度类别(比如20-50个,根据你的数据分布调整),比如“电子产品”“家居用品”“食品”这类大类,确保每个粗类下的细类数量相对均衡。训练粗粒度分类器
用你的全量数据训练一个BERT文本分类模型,输出是粗粒度类别。代码片段大概是这样(基于transformers库):from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments import torch # 加载预训练模型和tokenizer tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=20) # 20是粗粒度类别数 # 准备你的数据集(这里假设已经处理成Dataset对象) train_dataset = ... val_dataset = ... # 训练参数设置(Colab里可以调整batch_size和epochs) training_args = TrainingArguments( output_dir='./coarse_model', per_device_train_batch_size=16, per_device_eval_batch_size=16, num_train_epochs=3, logging_dir='./logs', evaluation_strategy="epoch" ) # 启动训练 trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset ) trainer.train()训练细粒度分类器
把数据按粗粒度类别拆分,每个粗类下单独训练一个细分类模型。这里可以复用粗分类模型的BERT层输出做特征(不用再从头训BERT,节省资源),只训练顶层的全连接层:# 加载训练好的粗分类模型,冻结BERT层 coarse_model = BertForSequenceClassification.from_pretrained('./coarse_model') for param in coarse_model.bert.parameters(): param.requires_grad = False # 替换分类头为对应粗类的细类别数(比如某个粗类下有30个细类) coarse_model.classifier = torch.nn.Linear(768, 30) # 用该粗类下的数据集训练 trainer = Trainer( model=coarse_model, args=training_args, train_dataset=fine_train_dataset, eval_dataset=fine_val_dataset ) trainer.train()推理时的逻辑串联
收到输入文本后,先过粗分类模型得到粗类别,再调用对应粗类的细分类模型得到最终结果:def predict(text): # 粗分类预测 inputs = tokenizer(text, return_tensors='pt').to('cuda') coarse_logits = coarse_model(**inputs).logits coarse_label = torch.argmax(coarse_logits, dim=1).item() # 根据粗类别调用对应的细分类模型 fine_model = load_fine_model_by_coarse_label(coarse_label) fine_logits = fine_model(**inputs).logits fine_label = torch.argmax(fine_logits, dim=1).item() return fine_label
二、统一端到端训练(高效且性能更稳定)
这种方式是在BERT之上同时接粗粒度和细粒度两个分类头,训练时同时优化两个分类任务的损失,模型可以共享BERT提取的特征,不用维护多个模型,适合追求整体性能的场景。
具体实现:
自定义多任务BERT模型
继承BertPreTrainedModel,添加两个分类头:from transformers import BertPreTrainedModel, BertModel class HierarchicalBERT(BertPreTrainedModel): def __init__(self, config, num_coarse_labels, num_fine_labels_per_coarse): super().__init__(config) self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) # 粗粒度分类头 self.coarse_classifier = torch.nn.Linear(config.hidden_size, num_coarse_labels) # 细粒度分类头:用字典存储每个粗类对应的分类层 self.fine_classifiers = torch.nn.ModuleDict({ str(i): torch.nn.Linear(config.hidden_size, num_fine_labels_per_coarse[i]) for i in range(num_coarse_labels) }) self.init_weights() def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, coarse_labels=None, fine_labels=None): outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) # 粗粒度预测 coarse_logits = self.coarse_classifier(pooled_output) loss = None # 计算损失:同时计算粗和细的损失 if coarse_labels is not None: # 粗粒度损失 coarse_loss = torch.nn.CrossEntropyLoss()(coarse_logits, coarse_labels) loss = coarse_loss # 细粒度损失:只计算对应粗类的细分类损失 fine_loss = 0.0 for i in torch.unique(coarse_labels): mask = (coarse_labels == i) if torch.any(mask): fine_logits = self.fine_classifiers[str(i.item())](pooled_output[mask]) fine_loss += torch.nn.CrossEntropyLoss()(fine_logits, fine_labels[mask]) # 加权求和两个损失,权重可自行调整 loss += 0.8 * fine_loss return (loss, coarse_logits, fine_logits) if loss is not None else (coarse_logits, fine_logits)训练端到端模型
准备好同时包含粗粒度和细粒度标签的数据集,然后启动训练:from sklearn.metrics import accuracy_score # 初始化模型:假设20个粗类,每个粗类的细类数存在num_fine_per_coarse字典里 model = HierarchicalBERT.from_pretrained( 'bert-base-uncased', num_coarse_labels=20, num_fine_labels_per_coarse=num_fine_per_coarse ) # 自定义评估指标,同时计算粗、细分类的准确率 def compute_metrics(eval_pred): logits, labels = eval_pred coarse_logits, fine_logits_batch = logits coarse_labels, fine_labels = labels coarse_preds = torch.argmax(torch.tensor(coarse_logits), dim=1) fine_preds = [] # 按粗类别匹配细分类预测 for idx, cl in enumerate(coarse_preds): fine_logit = fine_logits_batch[idx][cl.item()] fl = torch.argmax(torch.tensor(fine_logit), dim=0).item() fine_preds.append(fl) coarse_acc = accuracy_score(coarse_labels, coarse_preds) fine_acc = accuracy_score(fine_labels, fine_preds) return {'coarse_accuracy': coarse_acc, 'fine_accuracy': fine_acc} # 启动训练 training_args = TrainingArguments( output_dir='./hierarchical_model', per_device_train_batch_size=16, num_train_epochs=4, evaluation_strategy="epoch" ) trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, # 数据集需包含coarse_label和fine_label字段 eval_dataset=val_dataset, compute_metrics=compute_metrics ) trainer.train()推理时的操作
先预测粗类别,再用对应粗类的细分类头得到结果:def predict_hierarchical(text): inputs = tokenizer(text, return_tensors='pt').to('cuda') coarse_logits, fine_logits_batch = model(**inputs) coarse_label = torch.argmax(coarse_logits, dim=1).item() fine_logit = fine_logits_batch[0][coarse_label] fine_label = torch.argmax(fine_logit, dim=0).item() return fine_label
额外优化建议
- 粗粒度类别划分要合理:尽量把语义或业务相关的细类归到同一个粗类,让模型更容易学习到分层特征。
- 处理数据不平衡:600类肯定有长尾问题,分层后可以对每个粗类下的细类做过采样/欠采样,或者用加权损失。
- Colab资源优化:用TPU加速训练,或者选用轻量版BERT(比如DistilBERT),减少训练时间。
内容的提问来源于stack exchange,提问作者Zopui




