Seeking help: Transformers' Seq2SeqTrainer throws a flood of errors during training under Python 3.12.3

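Hi everyone! On an Ubuntu server I set up Python 3.12.3 in a dedicated virtual environment to isolate dependencies, and I want to train a Cantonese-to-Chinese Seq2Seq translation model. As soon as training starts it throws a pile of errors. I've spent a long time on this without fixing it, so I'm asking for help!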

Environment details

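  • OS: Ubuntu server
  • Python version: 3.12.3
  • Dependency management: virtual environment (so system-wide packages are unaffected)
  • Core dependency versions:
    • accelerate==1.11.0
    • transformers==4.57.1
    • tokenizers==0.22.0
    • datasets, evaluate, numpy and the other supporting packages are also installed without problems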
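To my knowledge, recent Transformers releases support Python 3.12, so the interpreter version alone is an unlikely culprit. The list above does not include PyTorch, though. Trainer depends on it, and the installed build must also support Python 3.12 and match the CUDA setup. A quick way to report exactly what the virtual environment resolves is to query the installed distribution metadata with the standard library. This is a minimal sketch: the package list is my own assumption, so adjust it as needed.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Distributions relevant to this training script (assumed list, not exhaustive).
PACKAGES = ["torch", "transformers", "tokenizers", "accelerate",
            "datasets", "evaluate", "numpy", "sacrebleu"]

def collect_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

for name, ver in collect_versions().items():
    print(f"{name}: {ver if ver is not None else 'NOT INSTALLED'}")
```

Run it with the virtual environment's own interpreter, so the output reflects what the training script actually imports.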

Reproduction code

Below is the complete test code I put together (the original was cut off at the end, so I filled in the key Trainer arguments):

```python
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BartForConditionalGeneration,
    BertTokenizer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    EncoderDecoderModel,
    IntervalStrategy,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Text2TextGenerationPipeline
)
from datasets import Dataset, DatasetDict, load_dataset, disable_progress_bar
import evaluate
import numpy as np

default_dataset = "raptorkwok/cantonese-traditional-chinese-parallel-corpus-gen3"
base_model_name = "fnlp/bart-base-chinese"

# Load the dataset
canton_ds = load_dataset(default_dataset)
yuezh_train = canton_ds["train"]
yuezh_test = canton_ds["test"]
yuezh_val = canton_ds["validation"]

print("Train Dataset Count: ", len(yuezh_train))
print("Test Dataset Count: ", len(yuezh_test))
print("Validation Dataset Count: ", len(yuezh_val))

yuezh_master = DatasetDict({
    "train": yuezh_train,
    "test": yuezh_test,
    "val": yuezh_val
})

# Load the model and tokenizer
base_tokenizer = BertTokenizer.from_pretrained(base_model_name)
base_model = BartForConditionalGeneration.from_pretrained(base_model_name, output_hidden_states=True)

# Data preprocessing functions
def _filter_valid_examples(example):
    return (
        isinstance(example["yue"], str) and example["yue"].strip() and
        isinstance(example["zh"], str) and example["zh"].strip()
    )

def _preprocess_dataset(examples):
    inputs = [text for text in examples["yue"]]
    targets = [text for text in examples["zh"]]
    model_inputs = base_tokenizer(inputs, text_target=targets, max_length=550, truncation=True)
    return model_inputs

def _postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

# Process the datasets
filtered_yuezh_master = yuezh_master.filter(_filter_valid_examples)
tokenized_yuezh_master = filtered_yuezh_master.map(_preprocess_dataset, batched=True)
tokenized_yuezh_master = tokenized_yuezh_master.remove_columns(yuezh_train.column_names)

# Load the evaluation metrics and the data collator
metric_bleu = evaluate.load("sacrebleu")
metric_chrf = evaluate.load("chrf")
data_collator = DataCollatorForSeq2Seq(tokenizer=base_tokenizer, model=base_model)

def _compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = base_tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, base_tokenizer.pad_token_id)
    decoded_labels = base_tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = _postprocess_text(decoded_preds, decoded_labels)

    result_bleu = metric_bleu.compute(predictions=decoded_preds, references=decoded_labels, tokenize='zh')
    result_chrf = metric_chrf.compute(predictions=decoded_preds, references=decoded_labels, word_order=2)

    results = {"bleu": result_bleu["score"], "chrf": result_chrf["score"]}
    prediction_lens = [np.count_nonzero(pred != base_tokenizer.pad_token_id) for pred in preds]
    results["gen_len"] = np.mean(prediction_lens)
    results = {k: round(v, 4) for k, v in results.items()}
    return results

# Training configuration
model_path = "test_minimal"
batch_size = 8
num_epochs = 1

training_args = Seq2SeqTrainingArguments(
    output_dir = model_path,
    evaluation_strategy = IntervalStrategy.STEPS,
    logging_strategy = "no",
    optim = "adamw_torch",
    eval_steps = 10000,
    save_steps = 10000,
    learning_rate = 2e-5,
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    weight_decay = 0.01,
    save_total_limit = 1,
    num_train_epochs = num_epochs,
    predict_with_generate=True,
    remove_unused_columns=True,
    fp16 = True,
    push_to_hub = False,
    metric_for_best_model = "bleu",
    load_best_model_at_end = True,
    report_to = "wandb"
)

# Initialize the Trainer and start training
trainer = Seq2SeqTrainer(
    model = base_model,
    args = training_args,
    train_dataset = tokenized_yuezh_master['train'],
    eval_dataset = tokenized_yuezh_master['val'],
    tokenizer = base_tokenizer,
    data_collator = data_collator,
    compute_metrics = _compute_metrics
)

trainer.train()
```
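One problem is visible before training even starts. As far as I know, `Seq2SeqTrainingArguments` in transformers 4.57.1 no longer accepts `evaluation_strategy`. The argument was renamed to `eval_strategy` (deprecated around 4.41, with the old name removed by 4.46), so the script as posted should raise a TypeError about an unexpected keyword argument when `training_args` is built. If training actually starts on your machine, the script that ran probably differs from the one posted, so it is worth double-checking. Separately, the `tokenizer=` argument of Trainer was superseded by `processing_class=` (from 4.46). In 4.57.1 I expect the old name to still work with only a deprecation warning. The simplest fix is to rename both keywords in the calls directly. The hypothetical helper below just makes the two renames explicit. The mapping covers only these two arguments and is not a complete list from the library.

```python
# Old keyword -> current keyword, limited to the two renames relevant to this script.
TRAINING_ARGS_RENAMES = {"evaluation_strategy": "eval_strategy"}
TRAINER_RENAMES = {"tokenizer": "processing_class"}

def migrate_kwargs(kwargs, renames):
    """Return a copy of kwargs with deprecated keyword names replaced (hypothetical helper)."""
    migrated = dict(kwargs)
    for old, new in renames.items():
        if old in migrated:
            if new in migrated:
                # Passing both the old and the new name is ambiguous, so refuse it.
                raise ValueError(f"both {old!r} and {new!r} were given")
            migrated[new] = migrated.pop(old)
    return migrated
```

Usage, assuming the arguments are first collected in a dict: `Seq2SeqTrainingArguments(**migrate_kwargs(args_kwargs, TRAINING_ARGS_RENAMES))`. In practice, writing `eval_strategy=IntervalStrategy.STEPS` and `processing_class=base_tokenizer` by hand is clearer.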

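A second issue would only appear at the first evaluation, which with `eval_steps = 10000` comes late in the run. `_compute_metrics` replaces -100 in `labels` but not in `preds`. When Trainer concatenates generated predictions from evaluation batches of different lengths, it pads them with -100 as well. Depending on the tokenizer class, decoding that id either raises an error (fast tokenizers typically raise an OverflowError) or silently decodes to unknown tokens, which skews BLEU, chrF and `gen_len`. The usual fix, also used in Hugging Face's translation example scripts, is to map -100 back to the pad id before decoding. Below is a small sketch; the helper names are hypothetical.

```python
import numpy as np

IGNORE_INDEX = -100  # value Trainer uses to pad labels and gathered predictions

def sanitize_token_ids(ids, pad_token_id, ignore_index=IGNORE_INDEX):
    """Return a copy of ids with every ignore_index replaced by pad_token_id."""
    ids = np.asarray(ids)
    return np.where(ids != ignore_index, ids, pad_token_id)

def generated_lengths(preds, pad_token_id):
    """Number of non-pad tokens in each generated sequence (what gen_len averages)."""
    return [int(np.count_nonzero(row != pad_token_id)) for row in preds]
```

Inside `_compute_metrics`, right after unwrapping the tuple, add `preds = sanitize_token_ids(preds, base_tokenizer.pad_token_id)` before `batch_decode`. The existing `gen_len` computation then no longer counts -100 as generated tokens.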
What I've already checked

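  1. Confirmed the virtual environment activates correctly and every package installed successfully; there is no ImportError
  2. The dataset loads fine, and the printed train/test/val sample counts are as expected
  3. Tested the tokenizer preprocessing function: a single example correctly produces input_ids, attention_mask and labels
  4. Tried turning off mixed precision (fp16=False), but it still fails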
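Check 3 shows that single examples tokenize, but not that every token id fits the model. `max_length=550` is hard-coded. If the downloaded revision of fnlp/bart-base-chinese has a `config.max_position_embeddings` smaller than the longest tokenized sequence, the position-embedding lookup goes out of range. The same happens if the tokenizer can emit ids at or above `config.vocab_size` (for example, a tokenizer and model from mismatched revisions). On CPU this usually surfaces as `IndexError: index out of range in self`. On GPU it usually surfaces as `CUDA error: device-side assert triggered`, followed by many cascading errors. I am not certain of the exact limits for each revision of this checkpoint, so read them from the loaded objects. The sketch below uses plain integers and hypothetical helper names.

```python
def safe_max_length(requested, max_position_embeddings, tokenizer_max_length=None):
    """Largest truncation length the position embeddings can index (hypothetical helper).

    tokenizer_max_length may be a huge sentinel when the tokenizer config leaves it unset,
    so it can only ever lower the result.
    """
    limit = min(requested, max_position_embeddings)
    if tokenizer_max_length is not None:
        limit = min(limit, tokenizer_max_length)
    return limit

def check_vocab(tokenizer_size, model_vocab_size):
    """Every id the tokenizer can produce must be a valid row of the embedding matrix."""
    if tokenizer_size > model_vocab_size:
        raise ValueError(
            f"tokenizer has {tokenizer_size} tokens but the model embeds only {model_vocab_size}"
        )
```

With the objects from the script, call `safe_max_length(550, base_model.config.max_position_embeddings, base_tokenizer.model_max_length)` and `check_vocab(len(base_tokenizer), base_model.config.vocab_size)`, then pass the returned length to `_preprocess_dataset` instead of 550.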

Questions

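After calling trainer.train(), a whole series of errors is raised (for example, tensor shape mismatches and errors in the training loop's computations). There are too many to paste in full, so first I'd like to ask: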

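  • Looking at the code and environment, is there an obvious problem? For example, a compatibility issue between Python 3.12.3 and Transformers 4.57.1?
  • Could pairing BertTokenizer with BartForConditionalGeneration cause a hidden incompatibility?
  • Are any of the training arguments misconfigured?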
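On the third question, two more settings can get in the way while debugging. First, `report_to = "wandb"` needs the wandb package and a logged-in account, and can fail or block on a headless server otherwise. Second, with `eval_steps = 10000`, `_compute_metrics` is not exercised until late in the epoch, and `logging_strategy = "no"` hides the loss. The sketch below is a hypothetical helper that overrides a few real `Seq2SeqTrainingArguments` fields for a short smoke run. The specific values are arbitrary choices.

```python
def smoke_test_overrides(training_kwargs, steps=20):
    """Return a copy of the Seq2SeqTrainingArguments keyword arguments adjusted for a
    short debugging run (hypothetical helper; the concrete values are arbitrary)."""
    half = max(1, steps // 2)
    debug = dict(training_kwargs)
    debug.update(
        report_to="none",          # no wandb dependency or login while debugging
        max_steps=steps,           # a positive max_steps takes precedence over num_train_epochs
        eval_steps=half,           # evaluation and _compute_metrics run early
        save_steps=half,           # load_best_model_at_end expects save_steps to be a multiple of eval_steps
        logging_strategy="steps",  # show the training loss instead of logging nothing
        logging_steps=1,
    )
    return debug
```

Collect the original arguments in a dict and build `Seq2SeqTrainingArguments(**smoke_test_overrides(args_kwargs))`. This way, errors in both the training step and the evaluation path appear within a few seconds.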

If you need the specific error logs or any other information, I'll post them right away! Thanks a lot for taking a look!
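When a run produces this many errors, usually only the first traceback matters; later ones are often cascading failures (especially after a CUDA device-side assert). One way to get something postable is to redirect the output to a file, for example `python train.py > train.log 2>&1` (the file names here are hypothetical), and extract the first traceback block. A minimal sketch follows. It assumes traceback lines are not prefixed by a logger or a distributed launcher.

```python
TRACEBACK_HEADER = "Traceback (most recent call last):"

def first_traceback(log_text):
    """Return the first Python traceback in log_text, or None if there is none.

    A block starts at the header line and ends at the first non-indented line after it,
    which is the 'ExceptionType: message' line.
    """
    lines = log_text.splitlines()
    for start, line in enumerate(lines):
        if line.startswith(TRACEBACK_HEADER):
            for end in range(start + 1, len(lines)):
                if lines[end] and not lines[end][0].isspace():
                    return "\n".join(lines[start:end + 1])
            return "\n".join(lines[start:])  # log was cut off mid-traceback
    return None
```

For example, `print(first_traceback(open("train.log", encoding="utf-8", errors="replace").read()))` prints just the block worth posting.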
