如何正确保存加载含Resize Token Embeddings的PEFT训练Unsloth模型?
解决Unsloth FastVisionModel添加新Token后加载模型尺寸不匹配问题
问题背景
基于unsloth/qwen2-VL-2B-Instruct训练含大量特殊字符的数据集时,添加新Token并完成PEFT训练后,直接加载保存的模型会触发RuntimeError,提示嵌入层与lm_head权重尺寸不匹配。现有加载方法要么无法正常启动,要么无法保留新Token的训练嵌入,导致模型性能下降。
核心原因
- PEFT训练默认不更新新增Token的嵌入权重(该部分属于基础模型的非适配器参数)
- 直接保存PEFT模型时,仅会存储适配器权重,不会包含调整后的完整嵌入层权重,加载时会与原始基础模型的嵌入尺寸冲突
正确解决方案流程
步骤1:训练完成后完整保存必要组件
除了保存PEFT适配器和分词器,需单独存储更新后的嵌入层与lm_head权重:
# 保存PEFT适配器与分词器 model.save_pretrained("outputs/checkpoint-60") tokenizer.tokenizer.save_pretrained("outputs/checkpoint-60") # 单独保存嵌入层和lm_head的权重(训练后已更新) torch.save( { "embed_tokens": model.base_model.model.model.embed_tokens.weight, "lm_head": model.base_model.model.lm_head.weight }, "outputs/checkpoint-60/embedding_weights.pt" )
步骤2:加载模型时恢复完整参数
按以下顺序加载,确保嵌入层尺寸匹配且恢复训练后的权重:
import torch from peft import PeftModel from unsloth import FastVisionModel from transformers import AutoTokenizer # 1. 加载原始基础模型 model, tokenizer = FastVisionModel.from_pretrained("unsloth/qwen2-VL-2B-Instruct") # 2. 加载训练时保存的分词器(自动包含新增Token) tokenizer = AutoTokenizer.from_pretrained("outputs/checkpoint-60") # 3. 调整嵌入层尺寸以匹配新增Token后的词汇表 model.resize_token_embeddings(len(tokenizer.tokenizer)) # 4. 加载PEFT适配器 model = PeftModel.from_pretrained(model, "outputs/checkpoint-60") # 5. 恢复训练后的嵌入层与lm_head权重 embedding_weights = torch.load("outputs/checkpoint-60/embedding_weights.pt") model.base_model.model.model.embed_tokens.weight.data = embedding_weights["embed_tokens"] model.base_model.model.lm_head.weight.data = embedding_weights["lm_head"]
步骤3:可选:合并模型实现无缝加载
若需要后续直接通过FastVisionModel.from_pretrained加载,可合并PEFT适配器与基础模型后保存:
# 合并适配器与基础模型(需确保已恢复嵌入层权重) merged_model = model.merge_and_unload() # 保存完整模型与分词器 merged_model.save_pretrained("outputs/full_finetuned_model") tokenizer.save_pretrained("outputs/full_finetuned_model") # 后续直接加载 model, tokenizer = FastVisionModel.from_pretrained("outputs/full_finetuned_model")
关键补充说明
- 新增Token的初始嵌入是随机生成的,若要让模型真正学习其语义,建议在PEFT训练前,先单独微调嵌入层1-2轮(冻结其他参数),再开启完整PEFT训练
- 合并模型会占用更多磁盘空间,但加载流程更简洁,适合部署场景
内容的提问来源于stack exchange,提问作者GauravGiri




