PyTorch中ResNet残差模块代码理解求助
理解ResNet PyTorch实现中的残差块与网络结构
我对Python不太熟悉,在理解ResNet架构的以下代码部分时遇到困难,相关代码片段如下:
ResidualBlock 残差块类
class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1, downsample=None): super(ResidualBlock, self).__init__() self.conv1 = conv3x3(in_channels, out_channels, stride) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(out_channels, out_channels) self.bn2 = nn.BatchNorm2d(out_channels) self.downsample = downsample def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample: residual = self.downsample(x) out += residual out = self.relu(out) return out
ResNet 主网络类
class ResNet(nn.Module): def __init__(self, block, layers, num_classes=10): super(ResNet, self).__init__() self.in_channels = 16 self.conv = conv3x3(3, 16) self.bn = nn.BatchNorm2d(16) self.relu = nn.ReLU(inplace=True) self.layer1 = self.make_layer(block, 16, layers[0]) self.layer2 = self.make_layer(block, 32, layers[0], 2) self.layer3 = self.make_layer(block, 64, layers[1], 2) self.avg_pool = nn.AvgPool2d(8) self.fc = nn.Linear(64, num_classes) def make_layer(self, block, out_channels, blocks, stride=1): downsample = None if (stride != 1) or (self.in_channels != out_channels): # 补全常见的downsample实现逻辑 downsample = nn.Sequential( conv3x3(self.in_channels, out_channels, stride), nn.BatchNorm2d(out_channels), ) layers = [] layers.append(block(self.in_channels, out_channels, stride, downsample)) self.in_channels = out_channels for _ in range(1, blocks): layers.append(block(self.in_channels, out_channels)) return nn.Sequential(*layers)
代码逐段解析
1. ResidualBlock:残差连接的核心
这是ResNet的灵魂模块,实现了残差跳跃连接,解决了深层网络的退化问题:
__init__参数说明:in_channels:输入特征图的通道数out_channels:输出特征图的通道数stride:第一个卷积的步长,用来缩小特征图尺寸downsample:可选调整模块,用来让输入x的维度(通道数/尺寸)和主分支输出匹配,这样才能做加法运算
- 前向传播逻辑:
- 先把输入
x存为residual(残差分支) - 主分支走两次3x3卷积+批量归一化+ReLU(注意第二次卷积后先不激活,等残差相加后再统一激活)
- 如果存在
downsample,就对residual做维度调整(比如通道数从16变32,或者步长2缩小特征图尺寸) - 主分支输出和残差相加,最后过ReLU返回结果
- 先把输入
2. ResNet:整体网络结构
这是ResNet的主类,负责把多个残差块组装成完整的分类网络:
__init__参数说明:block:刚才定义的ResidualBlock,作为重复使用的基础模块layers:列表,每个元素表示对应阶段的残差块数量(比如[2,2,2]就是每个阶段包含2个残差块)num_classes:最终分类的类别数(这里是10,对应CIFAR-10数据集)
- 网络流程:
- 输入层:将3通道的RGB图像,通过3x3卷积转成16通道特征图,再做批量归一化和ReLU激活
- 三个残差块阶段:
layer1:保持16通道,步长1,特征图尺寸不变layer2:升级到32通道,步长2,特征图尺寸减半layer3:升级到64通道,步长2,特征图尺寸再减半
- 输出层:通过全局平均池化把64x8x8的特征图压缩成64维向量,再通过全连接层输出10类分类结果
3. make_layer:批量创建残差块
这个方法用来批量生成同一通道数下的多个残差块,简化网络搭建:
- 首先判断是否需要
downsample:如果步长不是1,或者当前输入通道数和目标输出通道数不一致,就创建一个卷积+BN的模块,用来调整残差分支的维度 - 先创建第一个残差块(需要传入
downsample和stride来做初始维度调整) - 然后循环创建剩下的残差块(步长默认1,不需要额外调整维度)
- 最后把这些残差块打包成
nn.Sequential容器返回
内容的提问来源于stack exchange,提问作者user570593




