You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

微调BERT QA模型导出ONNX后精度下降,如何排查原因?

BERT-based QA模型转ONNX后精度下降的排查方案

问题背景

将微调后的BERT-based QA模型导出为ONNX格式以加速推理时,发现ONNX模型预测精度低于原PyTorch模型,预测结果常相差数个token,且已确认输入张量与分词处理完全一致。导出代码如下:

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

model_name = "deepset/bert-base-cased-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

dummy = tokenizer("Where is the Eiffel Tower?", "The Eiffel Tower is in Paris.", return_tensors="pt")

torch.onnx.export(
    model,
    (dummy["input_ids"], dummy["attention_mask"]),
    "qa_model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["start_logits", "end_logits"],
    opset_version=14
)

可能的诱因

  • 输入参数遗漏:BERT QA模型需要token_type_ids区分问题与上下文,但导出代码未传入该参数,ONNX模型会使用默认全0值,与PyTorch推理时的输入逻辑不一致,这是最可能的核心原因。
  • OPSET版本兼容性:不同opset版本对BERT关键操作(如Attention掩码处理、LayerNorm)的实现存在差异,opset14对部分BERT操作的支持不如高版本完善,易引发数值偏差。
  • 导出精度与常量折叠:默认导出时的常量折叠可能固化部分静态计算,若模型存在混合精度训练残留,或导出过程中隐式精度转换,会累积数值误差影响最终logits排序。
  • 推理执行器差异:ONNX Runtime使用不同执行provider(如TensorRT)时,可能自动启用低精度推理,与PyTorch的FP32计算逻辑产生偏差。

诊断与解决步骤

1. 补全输入参数并重导出

修改导出代码,加入token_type_ids作为输入,同时设置动态轴支持可变序列长度:

torch.onnx.export(
    model,
    (dummy["input_ids"], dummy["attention_mask"], dummy["token_type_ids"]),
    "qa_model.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["start_logits", "end_logits"],
    opset_version=16,  # 升级到更兼容BERT的opset版本
    do_constant_folding=False,  # 关闭常量折叠便于逐层排查
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "token_type_ids": {0: "batch_size", 1: "seq_len"}
    },
    export_params=True
)

2. 逐层对比输出定位差异

  • 将PyTorch模型与ONNX模型拆分为逐层执行,对比每一层的输出张量(如Embedding层、Transformer层输出、最终logits)。
  • 示例:先提取PyTorch的Embedding输出:
    with torch.no_grad():
        pt_emb = model.bert.embeddings(dummy["input_ids"], dummy["token_type_ids"], dummy["attention_mask"])
    
  • 再用ONNX Runtime单独运行对应层(可通过ONNX工具截取子模型),对比两者张量的差值,若差值过大则定位到对应层。

3. 统一推理精度与执行器

  • 使用ONNX Runtime的CPU执行器排除GPU加速的精度影响:
    import onnxruntime as ort
    
    sess = ort.InferenceSession("qa_model.onnx", providers=["CPUExecutionProvider"])
    onnx_outputs = sess.run(
        ["start_logits", "end_logits"],
        {
            "input_ids": dummy["input_ids"].numpy(),
            "attention_mask": dummy["attention_mask"].numpy(),
            "token_type_ids": dummy["token_type_ids"].numpy()
        }
    )
    
  • 对比PyTorch输出:
    with torch.no_grad():
        pt_outputs = model(**dummy)
    
  • 计算两者logits的L2差值,若差值在1e-5以内则为正常数值误差,否则需进一步排查操作层。

4. 排查模型精度状态

若模型采用混合精度训练,导出前需将模型转为FP32:

model = model.float()

内容的提问来源于stack exchange,提问作者vinoth

火山引擎 最新活动