This error means a CUDA device-side assertion was triggered while running inference with the BART model. One way to narrow it down is to enable PyTorch's anomaly detection with torch.autograd.set_detect_anomaly(True), or to wrap the suspect code in a with torch.autograd.detect_anomaly(): block, so that when an error is detected the offending operation is reported for debugging.
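What anomaly mode actually does can be shown without a GPU or the BART model. The sketch below is my own toy construction: a custom autograd Function that deliberately injects a NaN into its gradient. Under detect_anomaly(), the backward pass raises a RuntimeError naming the function that produced the NaN instead of silently propagating it.

```python
import torch

# A gradient that contains NaN, injected via a custom Function
# (purely illustrative -- not part of the BART example above):
class BadGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        return grad * float("nan")  # deliberately corrupt the gradient

x = torch.ones(2, requires_grad=True)
caught = None
with torch.autograd.detect_anomaly():
    out = BadGrad.apply(x).sum()
    try:
        out.backward()  # anomaly mode raises here, naming BadGradBackward
    except RuntimeError as e:
        caught = e

print(type(caught).__name__)
```

Without the detect_anomaly() context, the same backward call would simply fill x.grad with NaN and continue, which is why the mode is useful for locating where bad values first appear. Note it adds overhead, so enable it only while debugging.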
Example code:
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

# Enable anomaly detection globally
torch.autograd.set_detect_anomaly(True)

# Or wrap just the suspect code in a with block
with torch.autograd.detect_anomaly():
    input_text = "summarize: it is a beautiful day outside, with the sun shining and the birds singing."
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
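Two further debugging steps are worth knowing for device-side asserts specifically. Setting CUDA_LAUNCH_BLOCKING=1 before CUDA initializes makes kernel launches synchronous, so the assert is reported at the line that caused it rather than at some later call. And a very common cause is an out-of-range index, such as a token id larger than the embedding table; on CPU the same bug surfaces as an ordinary, readable Python exception. The vocabulary size and id values below are made up for illustration.

```python
import os
# Must be set before any CUDA work happens in the process:
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

import torch

# A frequent cause of device-side asserts: an out-of-range index.
# On CPU the identical bug raises a plain IndexError instead:
emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
bad_ids = torch.tensor([[3, 12]])  # 12 is outside the vocabulary of size 10
err = None
try:
    emb(bad_ids)
except IndexError as e:
    err = e

print(type(err).__name__)
```

So if the assert persists, try moving the model and inputs to CPU once: the resulting stack trace usually points directly at the bad index or shape that the GPU kernel only reports as an opaque assertion.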