如何获取PyTorch网络的输入输出及输入输出节点信息?
一、获取网络的输入与最终输出
最直接的方式是构造一个dummy输入张量(模拟真实输入的形状),将其传入模型即可得到输出,同时输入的信息也一目了然:
import torch from your_model import YourNet # 导入你的自定义模型 # 初始化模型并切换到评估模式 model = YourNet() model.eval() # 构造dummy输入(根据你的模型输入形状调整,示例为单张RGB 224x224图片) dummy_input = torch.randn(1, 3, 224, 224) # 获取输入和输出 input_tensor = dummy_input output_tensor = model(input_tensor) # 打印关键信息 print(f"输入尺寸: {input_tensor.shape}") print(f"最终输出尺寸: {output_tensor.shape}")
如果模型支持可变输入尺寸,也可以通过调整dummy输入的形状来测试不同场景下的输出结果。
二、获取输入输出节点的名称与尺寸信息的API
你提到的model.features()只适用于特定预训练模型(比如torchvision中的VGG系列,把特征提取部分封装成了features属性),并非所有PyTorch模型都有这个属性,所以通用场景下推荐以下几种方法:
1. 使用torchinfo快速生成模型摘要
这是最便捷的工具,能一键输出每层的名称、输入输出尺寸、参数数量等完整信息:
先安装依赖(如果没装的话):
pip install torchinfo
然后使用:
from torchinfo import summary # 传入模型和输入尺寸(batch维度可根据需求调整) summary(model, input_size=(1, 3, 224, 224))
输出会清晰列出每一层的Input Shape、Output Shape和层类型/名称,完全满足你的解析需求。
2. 自定义Forward Hook追踪节点信息
如果你需要更灵活的控制(比如只记录Conv2d、Linear等特定层的信息),可以用PyTorch原生的register_forward_hookAPI手动捕获每层的输入输出:
# 定义存储层信息的列表 layer_info = [] def hook_fn(module, input, output): # 提取层类型和输入输出尺寸 layer_type = str(module.__class__.__name__) input_shape = input[0].shape if isinstance(input, tuple) else input.shape output_shape = output.shape layer_info.append({ "layer_type": layer_type, "input_shape": input_shape, "output_shape": output_shape }) # 遍历模型子模块,给目标层注册hook for name, module in model.named_modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.MaxPool2d)): module.register_forward_hook(hook_fn) # 运行dummy输入触发hook model(dummy_input) # 打印收集到的信息 for idx, info in enumerate(layer_info): print(f"[{idx+1}] 层类型: {info['layer_type']}, 输入尺寸: {info['input_shape']}, 输出尺寸: {info['output_shape']}")
如果想保留模块的完整层级名称(比如features.0这种路径式名称),可以修改hook函数传入模块名称:
def hook_fn_with_name(layer_full_name): def fn(module, input, output): input_shape = input[0].shape if isinstance(input, tuple) else input.shape output_shape = output.shape layer_info.append({ "layer_full_name": layer_full_name, "layer_type": str(module.__class__.__name__), "input_shape": input_shape, "output_shape": output_shape }) return fn # 注册时传入模块的完整名称 for name, module in model.named_modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.MaxPool2d)): module.register_forward_hook(hook_fn_with_name(name))
3. 导出为ONNX结合Netron查看
你已经在使用Netron,那可以把PyTorch模型导出为ONNX格式,Netron会更直观地展示所有节点的名称、形状和连接关系:
# 导出ONNX模型 torch.onnx.export( model, dummy_input, "your_model.onnx", opset_version=11, # 根据你的PyTorch版本选择合适的opset input_names=["input"], # 自定义输入节点名称 output_names=["output"] # 自定义输出节点名称 )
用Netron打开导出的.onnx文件后,就能清晰看到每个Conv2d、MaxPool2d、Linear层的输入输出节点名称和尺寸,和你之前查看的结构完全对应。
三、关于model.features()的补充
model.features()是torchvision中部分预训练模型(如VGG、AlexNet)的专属属性,用于封装特征提取的前半部分网络。如果你的自定义模型没有定义这个属性,调用它自然会报错。通用场景下建议用上面提到的torchinfo或Forward Hook方法。
内容的提问来源于stack exchange,提问作者lee




