You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch中ResNet残差模块代码理解求助

理解ResNet PyTorch实现中的残差块与网络结构

我对Python不太熟悉,在理解ResNet架构的以下代码部分时遇到困难,相关代码片段如下:

ResidualBlock 残差块类

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

ResNet 主网络类

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[0], 2)
        self.layer3 = self.make_layer(block, 64, layers[1], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)
    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            # 补全常见的downsample实现逻辑
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride),
                nn.BatchNorm2d(out_channels),
            )
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

代码逐段解析

1. ResidualBlock:残差连接的核心

这是ResNet的灵魂模块,实现了残差跳跃连接,解决了深层网络的退化问题:

  • __init__参数说明:
    • in_channels:输入特征图的通道数
    • out_channels:输出特征图的通道数
    • stride:第一个卷积的步长,用来缩小特征图尺寸
    • downsample:可选调整模块,用来让输入x的维度(通道数/尺寸)和主分支输出匹配,这样才能做加法运算
  • 前向传播逻辑:
    1. 先把输入x存为residual(残差分支)
    2. 主分支走两次3x3卷积+批量归一化+ReLU(注意第二次卷积后先不激活,等残差相加后再统一激活)
    3. 如果存在downsample,就对residual做维度调整(比如通道数从16变32,或者步长2缩小特征图尺寸)
    4. 主分支输出和残差相加,最后过ReLU返回结果

2. ResNet:整体网络结构

这是ResNet的主类,负责把多个残差块组装成完整的分类网络:

  • __init__参数说明:
    • block:刚才定义的ResidualBlock,作为重复使用的基础模块
    • layers:列表,每个元素表示对应阶段的残差块数量(比如[2,2,2]就是每个阶段包含2个残差块)
    • num_classes:最终分类的类别数(这里是10,对应CIFAR-10数据集)
  • 网络流程:
    1. 输入层:将3通道的RGB图像,通过3x3卷积转成16通道特征图,再做批量归一化和ReLU激活
    2. 三个残差块阶段:
      • layer1:保持16通道,步长1,特征图尺寸不变
      • layer2:升级到32通道,步长2,特征图尺寸减半
      • layer3:升级到64通道,步长2,特征图尺寸再减半
    3. 输出层:通过全局平均池化把64x8x8的特征图压缩成64维向量,再通过全连接层输出10类分类结果

3. make_layer:批量创建残差块

这个方法用来批量生成同一通道数下的多个残差块,简化网络搭建:

  • 首先判断是否需要downsample:如果步长不是1,或者当前输入通道数和目标输出通道数不一致,就创建一个卷积+BN的模块,用来调整残差分支的维度
  • 先创建第一个残差块(需要传入downsamplestride来做初始维度调整)
  • 然后循环创建剩下的残差块(步长默认1,不需要额外调整维度)
  • 最后把这些残差块打包成nn.Sequential容器返回

内容的提问来源于stack exchange,提问作者user570593

火山引擎 最新活动