如何基于Hugging Face Transformers按等长Token拆分长文本并拼接摘要结果
解决方案
要实现基于Token级别的精准拆分,核心是用模型对应的Tokenizer来计算和分割文本——毕竟不同模型的Token化规则不一样(比如BART会把长词拆成子词),只有用模型自带的Tokenizer才能得到和模型处理逻辑完全一致的Token计数,解决空格拆分和Token数不符的问题。
下面是调整后的完整代码,我会逐段解释关键改动:
import logging from transformers import pipeline, AutoTokenizer # 加载模型及配套Tokenizer,确保Token计数和模型逻辑对齐 model_name = "facebook/bart-large-cnn" summarizer = pipeline("summarization", model=model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) # 读取原始长文本 with open("TextFile1.txt", "r") as f: ARTICLE = f.read() # 定义每个片段的最大Token容量 MAX_TOKEN_LENGTH = 1024 def split_text_by_tokens(text, max_tokens): # 将文本编码为Token列表(不生成张量,直接得到可操作的列表) tokens = tokenizer.encode(text, add_special_tokens=False) split_token_chunks = [] # 按max_tokens长度拆分Token列表 for i in range(0, len(tokens), max_tokens): chunk = tokens[i:i+max_tokens] split_token_chunks.append(chunk) # 将每个Token片段解码回可读文本 return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in split_token_chunks] def summarize_long_text(text): # 拆分文本为Token级片段 text_chunks = split_text_by_tokens(text, MAX_TOKEN_LENGTH) all_summaries = [] counter = 1 for chunk in text_chunks: try: # 对单个片段生成摘要 summary = summarizer(chunk, min_length=30, do_sample=False) all_summaries.append(summary[0]["summary_text"]) # 保留你原来的片段保存逻辑,方便验证拆分结果 with open(f'parsed_{counter}.txt', 'w') as f: f.write(chunk) counter += 1 except Exception as ex: logging.warning(f"处理第{counter}个片段时出错: {str(ex)}") all_summaries.append(f"[第{counter}个片段处理失败]") # 按原始顺序拼接所有摘要 return " ".join(all_summaries) # 执行长文本摘要并写入结果文件 final_summary = summarize_long_text(ARTICLE) with open('summarized.txt', 'w') as f: f.write(final_summary)
关键改动说明:
- 引入配套Tokenizer:用
AutoTokenizer加载和BART模型匹配的Tokenizer,从根源上解决了空格拆分和Token计数不一致的问题。 - Token级拆分逻辑:先把文本编码成Token列表,再按1024长度拆分,最后解码回文本——确保每个片段的Token数严格控制在模型可处理的范围内。
- 替代递归拆分:用循环遍历每个片段生成摘要,逻辑更直观,也避免了递归可能带来的栈溢出问题。
- 保留原有功能:依然保留了保存每个原始片段到文件的逻辑,方便你验证拆分是否符合预期。
可选优化(按需添加):
如果担心片段之间的上下文断裂,可以让片段之间保留少量重叠(比如每个片段和前一个重叠100个Token),提升摘要的连贯性。只需修改split_text_by_tokens里的循环部分:
# 带有100个Token重叠的拆分逻辑 OVERLAP_TOKENS = 100 for i in range(0, len(tokens), max_tokens - OVERLAP_TOKENS): chunk = tokens[i:i+max_tokens] split_token_chunks.append(chunk) # 避免最后一次循环重复处理剩余Token if i + max_tokens >= len(tokens): break
内容的提问来源于stack exchange,提问作者Furkan Gözükara




