微调BERT QA模型导出ONNX后精度下降,如何排查原因?
BERT-based QA模型转ONNX后精度下降的排查方案
问题背景
将微调后的BERT-based QA模型导出为ONNX格式以加速推理时,发现ONNX模型预测精度低于原PyTorch模型,预测结果常相差数个token,且已确认输入张量与分词处理完全一致。导出代码如下:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer import torch model_name = "deepset/bert-base-cased-squad2" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForQuestionAnswering.from_pretrained(model_name) dummy = tokenizer("Where is the Eiffel Tower?", "The Eiffel Tower is in Paris.", return_tensors="pt") torch.onnx.export( model, (dummy["input_ids"], dummy["attention_mask"]), "qa_model.onnx", input_names=["input_ids", "attention_mask"], output_names=["start_logits", "end_logits"], opset_version=14 )
可能的诱因
- 输入参数遗漏:BERT QA模型需要
token_type_ids区分问题与上下文,但导出代码未传入该参数,ONNX模型会使用默认全0值,与PyTorch推理时的输入逻辑不一致,这是最可能的核心原因。 - OPSET版本兼容性:不同opset版本对BERT关键操作(如Attention掩码处理、LayerNorm)的实现存在差异,opset14对部分BERT操作的支持不如高版本完善,易引发数值偏差。
- 导出精度与常量折叠:默认导出时的常量折叠可能固化部分静态计算,若模型存在混合精度训练残留,或导出过程中隐式精度转换,会累积数值误差影响最终logits排序。
- 推理执行器差异:ONNX Runtime使用不同执行provider(如TensorRT)时,可能自动启用低精度推理,与PyTorch的FP32计算逻辑产生偏差。
诊断与解决步骤
1. 补全输入参数并重导出
修改导出代码,加入token_type_ids作为输入,同时设置动态轴支持可变序列长度:
torch.onnx.export( model, (dummy["input_ids"], dummy["attention_mask"], dummy["token_type_ids"]), "qa_model.onnx", input_names=["input_ids", "attention_mask", "token_type_ids"], output_names=["start_logits", "end_logits"], opset_version=16, # 升级到更兼容BERT的opset版本 do_constant_folding=False, # 关闭常量折叠便于逐层排查 dynamic_axes={ "input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}, "token_type_ids": {0: "batch_size", 1: "seq_len"} }, export_params=True )
2. 逐层对比输出定位差异
- 将PyTorch模型与ONNX模型拆分为逐层执行,对比每一层的输出张量(如Embedding层、Transformer层输出、最终logits)。
- 示例:先提取PyTorch的Embedding输出:
with torch.no_grad(): pt_emb = model.bert.embeddings(dummy["input_ids"], dummy["token_type_ids"], dummy["attention_mask"]) - 再用ONNX Runtime单独运行对应层(可通过ONNX工具截取子模型),对比两者张量的差值,若差值过大则定位到对应层。
3. 统一推理精度与执行器
- 使用ONNX Runtime的CPU执行器排除GPU加速的精度影响:
import onnxruntime as ort sess = ort.InferenceSession("qa_model.onnx", providers=["CPUExecutionProvider"]) onnx_outputs = sess.run( ["start_logits", "end_logits"], { "input_ids": dummy["input_ids"].numpy(), "attention_mask": dummy["attention_mask"].numpy(), "token_type_ids": dummy["token_type_ids"].numpy() } ) - 对比PyTorch输出:
with torch.no_grad(): pt_outputs = model(**dummy) - 计算两者logits的L2差值,若差值在1e-5以内则为正常数值误差,否则需进一步排查操作层。
4. 排查模型精度状态
若模型采用混合精度训练,导出前需将模型转为FP32:
model = model.float()
内容的提问来源于stack exchange,提问作者vinoth




