MobileSAM.pt解码模块转ONNX格式失败,寻求浏览器端部署该模型的方案
MobileSAM.pt解码模块转ONNX格式失败,寻求浏览器端部署该模型的方案
我太懂你这种转模型卡壳的烦躁了!之前帮朋友踩过MobileSAM浏览器部署的坑,给你梳理几个可行的解决方向:
一、先搞定ONNX解码模块的导出问题
你提到编码模块导出成功了,解码模块卡壳,大概率是输入对齐或者export参数没配对,试试这些调整:
- 先理清楚环境和模型状态
- 确保用的是torch2.0以上的稳定版,别用nightly测试版,很多export的兼容bug都是新版本带的;
- 导出前一定要把解码模块设为eval模式:
model.decoder.eval(),然后在torch.no_grad()上下文里执行导出操作,避免梯度相关的报错。
- 严格对齐dummy input的形状和类型
解码模块的输入比编码模块复杂得多,不能随便给个空张量,得和实际推理时的输入完全匹配:
比如你要构造这些dummy输入(形状要和编码模块输出的特征图对应):import torch # 假设编码模块输出的embedding形状是(1, 256, 64, 64) image_embeddings = torch.randn(1, 256, 64, 64) # 模拟2个点提示 point_coords = torch.randn(1, 2, 2) point_labels = torch.tensor([[1, 0]]) # 初始mask输入,和embedding的空间维度一致 mask_input = torch.randn(1, 1, 64, 64) has_mask_input = torch.tensor([1]) dummy_inputs = (image_embeddings, point_coords, point_labels, mask_input, has_mask_input) - 调整export的关键参数
解码模块的输入有动态维度(比如不同数量的点提示),必须加上dynamic_axes来指定动态维度,不然export会因为形状不固定报错,试试这个导出代码:dynamic_axes = { "image_embeddings": {0: "batch", 2: "height", 3: "width"}, "point_coords": {0: "batch", 1: "num_points"}, "point_labels": {0: "batch", 1: "num_points"}, "mask_input": {0: "batch", 2: "height", 3: "width"}, "has_mask_input": {0: "batch"}, } torch.onnx.export( model.decoder, dummy_inputs, "mobile_sam_decoder.onnx", opset_version=16, # 用16版本兼容大部分浏览器端算子 dynamic_axes=dynamic_axes, input_names=["image_embeddings", "point_coords", "point_labels", "mask_input", "has_mask_input"], output_names=["masks", "iou_predictions"], verbose=True # 打开这个能看到具体哪个环节出错 ) - 用draft_export排查具体错误
按照提示把export()换成draft_export(),它会输出更详细的错误日志,比如是哪个算子不支持ONNX导出,你可以根据日志把不兼容的自定义算子替换成PyTorch原生算子,或者调整opset版本试试(比如降到15)。
二、如果ONNX实在搞不定,试试浏览器部署的替代方案
- 转TorchScript用PyTorch.js部署
这是另一个很成熟的浏览器端部署路径:- 分别把编码和解码模块转成TorchScript:
# 转编码模块 model.encoder.eval() dummy_image = torch.randn(1, 3, 1024, 1024) # 输入图像形状 traced_encoder = torch.jit.trace(model.encoder, dummy_image) traced_encoder.save("mobile_sam_encoder.pt") # 转解码模块 model.decoder.eval() traced_decoder = torch.jit.trace(model.decoder, dummy_inputs) # dummy_inputs就是上面构造的那个 traced_decoder.save("mobile_sam_decoder.pt") - 然后用PyTorch.js把这两个.pt文件转成浏览器能加载的格式,之后在前端用PyTorch.js的API加载模型,分两步推理:先跑编码模块得到embedding,再传进解码模块得到mask。
- 分别把编码和解码模块转成TorchScript:
- 用WebAssembly推理框架直接加载预转换权重
有些开源项目已经把MobileSAM预转成了浏览器兼容的权重格式,你可以在官方MobileSAM仓库或者相关的开源部署代码里找现成的转换脚本,直接用这些脚本生成浏览器端能跑的权重,省得自己踩坑。
最后提醒一下,浏览器端部署要注意模型大小,MobileSAM已经很小了,但还是可以做一些量化,比如转成int8量化的ONNX或者TorchScript,能大幅提升浏览器端的推理速度,也减少内存占用。




