You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

MobileSAM.pt解码模块转ONNX格式失败,寻求浏览器端部署该模型的方案

MobileSAM.pt解码模块转ONNX格式失败,寻求浏览器端部署该模型的方案

我太懂你这种转模型卡壳的烦躁了!之前帮朋友踩过MobileSAM浏览器部署的坑,给你梳理几个可行的解决方向:

一、先搞定ONNX解码模块的导出问题

你提到编码模块导出成功了,解码模块卡壳,大概率是输入对齐或者export参数没配对,试试这些调整:

  1. 先理清楚环境和模型状态
    • 确保用的是torch2.0以上的稳定版,别用nightly测试版,很多export的兼容bug都是新版本带的;
    • 导出前一定要把解码模块设为eval模式:model.decoder.eval(),然后在torch.no_grad()上下文里执行导出操作,避免梯度相关的报错。
  2. 严格对齐dummy input的形状和类型
    解码模块的输入比编码模块复杂得多,不能随便给个空张量,得和实际推理时的输入完全匹配:
    比如你要构造这些dummy输入(形状要和编码模块输出的特征图对应):
    import torch
    # 假设编码模块输出的embedding形状是(1, 256, 64, 64)
    image_embeddings = torch.randn(1, 256, 64, 64)
    # 模拟2个点提示
    point_coords = torch.randn(1, 2, 2)
    point_labels = torch.tensor([[1, 0]])
    # 初始mask输入,和embedding的空间维度一致
    mask_input = torch.randn(1, 1, 64, 64)
    has_mask_input = torch.tensor([1])
    dummy_inputs = (image_embeddings, point_coords, point_labels, mask_input, has_mask_input)
    
  3. 调整export的关键参数
    解码模块的输入有动态维度(比如不同数量的点提示),必须加上dynamic_axes来指定动态维度,不然export会因为形状不固定报错,试试这个导出代码:
    dynamic_axes = {
        "image_embeddings": {0: "batch", 2: "height", 3: "width"},
        "point_coords": {0: "batch", 1: "num_points"},
        "point_labels": {0: "batch", 1: "num_points"},
        "mask_input": {0: "batch", 2: "height", 3: "width"},
        "has_mask_input": {0: "batch"},
    }
    torch.onnx.export(
        model.decoder,
        dummy_inputs,
        "mobile_sam_decoder.onnx",
        opset_version=16,  # 用16版本兼容大部分浏览器端算子
        dynamic_axes=dynamic_axes,
        input_names=["image_embeddings", "point_coords", "point_labels", "mask_input", "has_mask_input"],
        output_names=["masks", "iou_predictions"],
        verbose=True  # 打开这个能看到具体哪个环节出错
    )
    
  4. 用draft_export排查具体错误
    按照提示把export()换成draft_export(),它会输出更详细的错误日志,比如是哪个算子不支持ONNX导出,你可以根据日志把不兼容的自定义算子替换成PyTorch原生算子,或者调整opset版本试试(比如降到15)。

二、如果ONNX实在搞不定,试试浏览器部署的替代方案

  1. 转TorchScript用PyTorch.js部署
    这是另一个很成熟的浏览器端部署路径:
    • 分别把编码和解码模块转成TorchScript:
      # 转编码模块
      model.encoder.eval()
      dummy_image = torch.randn(1, 3, 1024, 1024)  # 输入图像形状
      traced_encoder = torch.jit.trace(model.encoder, dummy_image)
      traced_encoder.save("mobile_sam_encoder.pt")
      # 转解码模块
      model.decoder.eval()
      traced_decoder = torch.jit.trace(model.decoder, dummy_inputs)  # dummy_inputs就是上面构造的那个
      traced_decoder.save("mobile_sam_decoder.pt")
      
    • 然后用PyTorch.js把这两个.pt文件转成浏览器能加载的格式,之后在前端用PyTorch.js的API加载模型,分两步推理:先跑编码模块得到embedding,再传进解码模块得到mask。
  2. 用WebAssembly推理框架直接加载预转换权重
    有些开源项目已经把MobileSAM预转成了浏览器兼容的权重格式,你可以在官方MobileSAM仓库或者相关的开源部署代码里找现成的转换脚本,直接用这些脚本生成浏览器端能跑的权重,省得自己踩坑。

最后提醒一下,浏览器端部署要注意模型大小,MobileSAM已经很小了,但还是可以做一些量化,比如转成int8量化的ONNX或者TorchScript,能大幅提升浏览器端的推理速度,也减少内存占用。

火山引擎 最新活动