You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

超512词大文本下BERT问答模型运行问题排查与解决

解决BERT问答模型处理长文本(超512词)返回[CLS]的问题

问题根源

你遇到的核心问题是BERT系列模型的最大输入长度限制为512个token——当你把长文本和问题一起编码时,超过512的部分会被直接截断。如果答案恰好位于被截断的文本片段里,模型就找不到有效答案,只能返回默认的[CLS]标记。你之前设置的max_length=512只是做了截断,并没有解决答案不在保留片段里的问题。

解决方案:迭代处理文本片段

正确的思路是把长文本分割成多个重叠的小片段(每个片段的长度要预留出问题的token空间),然后让模型对每个片段单独进行问答,最后从所有片段的结果中选出置信度最高的答案。这样就能覆盖整个长文本,确保答案所在的片段被模型处理到。

修改后的完整代码

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

max_seq_length = 512
tokenizer = AutoTokenizer.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")

# 读取长文本
with open("test.txt", "r") as f:
    text = f.read()

questions = [
    "Wat is de hoofdstad van Nederland?",
    "Van welk automerk is een Cayenne?",
    "In welk jaar is pindakaas geproduceerd?",
]

for question in questions:
    # 先计算问题的token数(包含特殊标记)
    question_tokens = tokenizer.encode(question, add_special_tokens=True)
    question_len = len(question_tokens)
    # 文本片段的最大长度:总长度 - 问题长度 - 预留的特殊标记空间
    max_text_chunk_len = max_seq_length - question_len - 2  # 减去[CLS]和[SEP]的位置

    # 把文本分割成多个重叠的片段(重叠部分确保答案不会被切在中间)
    text_tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    start = 0
    overlap = 100  # 重叠100个token,避免答案被分割在两个片段之间
    while start < len(text_tokens):
        end = start + max_text_chunk_len
        chunks.append(text_tokens[start:end])
        start = end - overlap

    best_answer = ""
    best_score = -float("inf")

    # 遍历每个片段进行问答
    for chunk in chunks:
        # 拼接问题和当前文本片段
        inputs = tokenizer.encode_plus(
            question,
            tokenizer.decode(chunk),
            add_special_tokens=True,
            max_length=max_seq_length,
            truncation=True,
            return_tensors="pt"
        )

        answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)
        # 获取当前片段的最高分
        start_score = torch.max(answer_start_scores).item()
        end_score = torch.max(answer_end_scores).item()
        total_score = start_score + end_score

        # 提取答案
        answer_start = torch.argmax(answer_start_scores)
        answer_end = torch.argmax(answer_end_scores) + 1
        input_ids = inputs["input_ids"].tolist()[0]
        answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

        # 跳过无效的[CLS]答案
        if answer.strip() == "[CLS]":
            continue

        # 更新最优答案
        if total_score > best_score:
            best_score = total_score
            best_answer = answer

    # 输出结果
    print(f"Question: {question}")
    print(f"Answer: {best_answer if best_answer else 'No valid answer found'}\n")

代码关键说明

  • 文本分割逻辑:先计算问题的token长度,确保每个文本片段加上问题后不超过512token;同时设置重叠部分,避免答案被切割在两个片段之间。
  • 置信度筛选:记录每个片段答案的总得分(start_score + end_score),最终选择得分最高的答案,确保最准确的结果被保留。
  • 无效答案过滤:跳过返回[CLS]的片段,避免干扰最终结果。

内容的提问来源于stack exchange,提问作者Liza Darwesh

火山引擎 最新活动