You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何获取PyTorch网络的输入输出及输入输出节点信息?

获取PyTorch网络的输入输出及节点信息的方法

一、获取网络的输入与最终输出

最直接的方式是构造一个dummy输入张量(模拟真实输入的形状),将其传入模型即可得到输出,同时输入的信息也一目了然:

import torch
from your_model import YourNet  # 导入你的自定义模型

# 初始化模型并切换到评估模式
model = YourNet()
model.eval()

# 构造dummy输入(根据你的模型输入形状调整,示例为单张RGB 224x224图片)
dummy_input = torch.randn(1, 3, 224, 224)

# 获取输入和输出
input_tensor = dummy_input
output_tensor = model(input_tensor)

# 打印关键信息
print(f"输入尺寸: {input_tensor.shape}")
print(f"最终输出尺寸: {output_tensor.shape}")

如果模型支持可变输入尺寸,也可以通过调整dummy输入的形状来测试不同场景下的输出结果。

二、获取输入输出节点的名称与尺寸信息的API

你提到的model.features()只适用于特定预训练模型(比如torchvision中的VGG系列,把特征提取部分封装成了features属性),并非所有PyTorch模型都有这个属性,所以通用场景下推荐以下几种方法:

1. 使用torchinfo快速生成模型摘要

这是最便捷的工具,能一键输出每层的名称、输入输出尺寸、参数数量等完整信息:

先安装依赖(如果没装的话):

pip install torchinfo

然后使用:

from torchinfo import summary

# 传入模型和输入尺寸(batch维度可根据需求调整)
summary(model, input_size=(1, 3, 224, 224))

输出会清晰列出每一层的Input ShapeOutput Shape和层类型/名称,完全满足你的解析需求。

2. 自定义Forward Hook追踪节点信息

如果你需要更灵活的控制(比如只记录Conv2d、Linear等特定层的信息),可以用PyTorch原生的register_forward_hookAPI手动捕获每层的输入输出:

# 定义存储层信息的列表
layer_info = []

def hook_fn(module, input, output):
    # 提取层类型和输入输出尺寸
    layer_type = str(module.__class__.__name__)
    input_shape = input[0].shape if isinstance(input, tuple) else input.shape
    output_shape = output.shape
    layer_info.append({
        "layer_type": layer_type,
        "input_shape": input_shape,
        "output_shape": output_shape
    })

# 遍历模型子模块,给目标层注册hook
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.MaxPool2d)):
        module.register_forward_hook(hook_fn)

# 运行dummy输入触发hook
model(dummy_input)

# 打印收集到的信息
for idx, info in enumerate(layer_info):
    print(f"[{idx+1}] 层类型: {info['layer_type']}, 输入尺寸: {info['input_shape']}, 输出尺寸: {info['output_shape']}")

如果想保留模块的完整层级名称(比如features.0这种路径式名称),可以修改hook函数传入模块名称:

def hook_fn_with_name(layer_full_name):
    def fn(module, input, output):
        input_shape = input[0].shape if isinstance(input, tuple) else input.shape
        output_shape = output.shape
        layer_info.append({
            "layer_full_name": layer_full_name,
            "layer_type": str(module.__class__.__name__),
            "input_shape": input_shape,
            "output_shape": output_shape
        })
    return fn

# 注册时传入模块的完整名称
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.MaxPool2d)):
        module.register_forward_hook(hook_fn_with_name(name))

3. 导出为ONNX结合Netron查看

你已经在使用Netron,那可以把PyTorch模型导出为ONNX格式,Netron会更直观地展示所有节点的名称、形状和连接关系:

# 导出ONNX模型
torch.onnx.export(
    model,
    dummy_input,
    "your_model.onnx",
    opset_version=11,  # 根据你的PyTorch版本选择合适的opset
    input_names=["input"],  # 自定义输入节点名称
    output_names=["output"]  # 自定义输出节点名称
)

用Netron打开导出的.onnx文件后,就能清晰看到每个Conv2d、MaxPool2d、Linear层的输入输出节点名称和尺寸,和你之前查看的结构完全对应。

三、关于model.features()的补充

model.features()是torchvision中部分预训练模型(如VGG、AlexNet)的专属属性,用于封装特征提取的前半部分网络。如果你的自定义模型没有定义这个属性,调用它自然会报错。通用场景下建议用上面提到的torchinfo或Forward Hook方法。

内容的提问来源于stack exchange,提问作者lee

火山引擎 最新活动