You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何计算卷积神经网络(CNN)的总参数量?附示例代码

如何计算卷积神经网络(CNN)的总参数量?

我来一步步拆解怎么计算CNN的可训练参数量,结合你给出的示例代码来解释会更直观~

首先要明确:只有卷积层、全连接层(Dense)这类带可训练权重的层才会贡献参数量,像激活层(Activation)、池化层(MaxPooling2D)、Dropout这些层是没有可训练参数的,直接跳过就行。

一、卷积层的参数计算方法

卷积层的参数量由两部分组成:卷积核的权重 + 每个输出通道的偏置项,公式是:
(卷积核高度 × 卷积核宽度 × 输入通道数) × 输出通道数 + 输出通道数

我们来逐个分析示例里的卷积层:

  1. 第一个Conv2D层:

    cnn_model.add(Conv2D(32, (3, 3), padding='same', input_shape=input_shape))
    

    输入通道数是input_shape的最后一位,也就是1;输出通道数是32,卷积核尺寸是3×3。代入公式:
    (3×3×1)×32 + 32 = 288 + 32 = 320 个参数

  2. 第二个Conv2D层:

    cnn_model.add(Conv2D(64, (3, 3)))
    

    这里的输入通道数是上一层的输出通道数32,输出通道数64,卷积核还是3×3:
    (3×3×32)×64 + 64 = 18432 + 64 = 18496 个参数

二、全连接层(Dense)的参数计算方法

如果模型后面接了全连接层,需要先算出Flatten后的特征图尺寸,再用公式:
(输入特征数) × 输出神经元数 + 输出神经元数

先看示例里的特征图变化:

  • 输入是32×32×1,第一个卷积层用了padding='same',所以卷积后尺寸还是32×32;经过MaxPooling2D(2,2)后,尺寸变成16×16
  • 第二个卷积层没有padding='same',输入16×16的特征图,卷积后尺寸是16-3+1=14×14;再经过MaxPooling2D(2,2),尺寸变成7×7
  • 所以Flatten后的总特征数是 7×7×64 = 3136

假设后面补全全连接层代码:

cnn_model.add(Flatten())
cnn_model.add(Dense(num_classes))
cnn_model.add(Activation('softmax'))

代入公式计算全连接层参数:
3136×4 +4 = 12544 +4 =12548 个参数

三、总参数量求和

把所有带可训练参数的层加起来:
320 + 18496 +12548 = 31364

最后给你补全完整的示例代码方便验证:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense

input_shape = (32, 32, 1)
num_classes = 4

cnn_model = Sequential()
cnn_model.add(Conv2D(32, (3, 3), padding='same', input_shape=input_shape))
cnn_model.add(Activation('relu'))
cnn_model.add(MaxPooling2D(pool_size=(2, 2)))
cnn_model.add(Conv2D(64, (3, 3)))
cnn_model.add(Activation('relu'))
cnn_model.add(MaxPooling2D(pool_size=(2, 2)))
cnn_model.add(Dropout(0.25))
cnn_model.add(Flatten())
cnn_model.add(Dense(num_classes))
cnn_model.add(Activation('softmax'))

# 可以直接用模型的summary()方法验证参数数量
cnn_model.summary()

运行summary()后,你会看到各层的参数数量和总参数数和我们手动计算的一致~

内容的提问来源于stack exchange,提问作者FlyingBurger

火山引擎 最新活动