You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch转ONNX报错:输入名数量超输入数,含Dropout时失败

问题原因分析

你遇到的RuntimeError: number of input names provided (9) exceeded number of inputs (7)错误,核心问题出在手动构建input_names的逻辑完全错误,再加上没正确处理Dropout的运行模式导致的:

  • 你初始设置了input_names = ['input_1'],之后又遍历模型的每个模块,给每个模块都添加了一个输入名,最终input_names的长度变成1+8=9(你的Sequential模型一共有8个模块)。但实际上,你的PyTorch模型只有1个输入张量,ONNX导出时只需要指定这个输入的名称,完全不需要给每个内部模块额外添加输入名。
  • 移除Dropout后能运行只是巧合:去掉两个Dropout后模块数变成6个,input_names长度变成1+6=7,刚好和错误提示里的“7个输入”数量匹配,但本质上你的input_names构建逻辑还是错的。另外,Dropout在训练模式下会引入随机分支,ONNX需要确定的可复现计算图,这也是导致导出异常的隐性原因。

保留Dropout层的正确导出步骤

要在保留Dropout的前提下成功导出ONNX,你需要做两个关键修正:

  1. 将模型切换到评估模式:Dropout在train()模式下有随机失活逻辑,ONNX需要稳定的计算图,所以必须先调用model.eval(),让Dropout固定失活比例,不再随机丢弃神经元。
  2. 正确设置input_names:只需要指定模型的实际输入张量名称,不需要给每个模块添加额外的输入名。

修正后的完整代码如下:

import torch
# 假设D_in, H, D_out已定义,df_X也已加载完成

# Define the model
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(H, H),
    torch.nn.LeakyReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(H, D_out),
    torch.nn.Sigmoid()
)

checkpoint = torch.load("./saved_pytorch_model.pth")  # 加载PyTorch模型
model.load_state_dict(checkpoint['state_dict'])

# 关键步骤1:切换模型到评估模式
model.eval()

# 准备输入张量(建议添加batch维度,ONNX通常期望批量输入)
# 原代码的features是1D张量,改成2D适配大多数部署场景
features = torch.Tensor(df_X.values[0]).unsqueeze(0)  # 变成(1, D_in)的batch格式

# Convert PyTorch model to ONNX
input_names = ['input_1']  # 只指定实际输入的名称即可
output_names = ['output_1']

torch_out = torch.onnx.export(
    model, 
    features, 
    "onnx_model.onnx", 
    export_params=True, 
    verbose=True, 
    input_names=input_names, 
    output_names=output_names,
    # 可选:如果需要兼容旧版ONNX工具,可指定opset_version,比如opset_version=11
)

额外说明

  • 为什么要加batch维度?大多数ONNX模型都是为批量输入设计的,直接用1D张量导出可能在后续部署时出现兼容性问题。如果你的场景确实不需要批量,也可以保留1D,但优先推荐用2D张量。
  • 如果你需要给模型内部节点命名,可使用torch.onnx.exportdynamic_axes参数指定动态维度,无需手动给每个模块加输入名——ONNX会自动处理内部节点结构。

内容的提问来源于stack exchange,提问作者Shawn Zhang

火山引擎 最新活动