超512词大文本下BERT问答模型运行问题排查与解决
解决BERT问答模型处理长文本(超512词)返回[CLS]的问题
问题根源
你遇到的核心问题是BERT系列模型的最大输入长度限制为512个token——当你把长文本和问题一起编码时,超过512的部分会被直接截断。如果答案恰好位于被截断的文本片段里,模型就找不到有效答案,只能返回默认的[CLS]标记。你之前设置的max_length=512只是做了截断,并没有解决答案不在保留片段里的问题。
解决方案:迭代处理文本片段
正确的思路是把长文本分割成多个重叠的小片段(每个片段的长度要预留出问题的token空间),然后让模型对每个片段单独进行问答,最后从所有片段的结果中选出置信度最高的答案。这样就能覆盖整个长文本,确保答案所在的片段被模型处理到。
修改后的完整代码
from transformers import AutoTokenizer, AutoModelForQuestionAnswering import torch max_seq_length = 512 tokenizer = AutoTokenizer.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2") model = AutoModelForQuestionAnswering.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2") # 读取长文本 with open("test.txt", "r") as f: text = f.read() questions = [ "Wat is de hoofdstad van Nederland?", "Van welk automerk is een Cayenne?", "In welk jaar is pindakaas geproduceerd?", ] for question in questions: # 先计算问题的token数(包含特殊标记) question_tokens = tokenizer.encode(question, add_special_tokens=True) question_len = len(question_tokens) # 文本片段的最大长度:总长度 - 问题长度 - 预留的特殊标记空间 max_text_chunk_len = max_seq_length - question_len - 2 # 减去[CLS]和[SEP]的位置 # 把文本分割成多个重叠的片段(重叠部分确保答案不会被切在中间) text_tokens = tokenizer.encode(text, add_special_tokens=False) chunks = [] start = 0 overlap = 100 # 重叠100个token,避免答案被分割在两个片段之间 while start < len(text_tokens): end = start + max_text_chunk_len chunks.append(text_tokens[start:end]) start = end - overlap best_answer = "" best_score = -float("inf") # 遍历每个片段进行问答 for chunk in chunks: # 拼接问题和当前文本片段 inputs = tokenizer.encode_plus( question, tokenizer.decode(chunk), add_special_tokens=True, max_length=max_seq_length, truncation=True, return_tensors="pt" ) answer_start_scores, answer_end_scores = model(**inputs, return_dict=False) # 获取当前片段的最高分 start_score = torch.max(answer_start_scores).item() end_score = torch.max(answer_end_scores).item() total_score = start_score + end_score # 提取答案 answer_start = torch.argmax(answer_start_scores) answer_end = torch.argmax(answer_end_scores) + 1 input_ids = inputs["input_ids"].tolist()[0] answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])) # 跳过无效的[CLS]答案 if answer.strip() == "[CLS]": continue # 更新最优答案 if total_score > best_score: best_score = total_score best_answer = answer # 输出结果 print(f"Question: {question}") print(f"Answer: {best_answer if best_answer else 'No valid answer found'}\n")
代码关键说明
- 文本分割逻辑:先计算问题的token长度,确保每个文本片段加上问题后不超过512token;同时设置重叠部分,避免答案被切割在两个片段之间。
- 置信度筛选:记录每个片段答案的总得分(start_score + end_score),最终选择得分最高的答案,确保最准确的结果被保留。
- 无效答案过滤:跳过返回
[CLS]的片段,避免干扰最终结果。
内容的提问来源于stack exchange,提问作者Liza Darwesh




