Help: changing the input shape of a pretrained Keras model (errors and accuracy problems with 8-channel input)

Hi there! Let's tackle your two main issues one at a time: getting the pretrained InceptionResNetV2 to accept 8-channel inputs, and fixing the low accuracy caused by class imbalance when training from scratch.


1. Fixing the Pretrained Model Input Channel Error

The error happens because the ImageNet-pretrained InceptionResNetV2 was trained on 3-channel RGB images: its first convolution's kernel has shape (3, 3, 3, 32), so Keras refuses to load weights='imagenet' for any input_shape whose last dimension is not 3. There are two common ways to adapt it to 8-channel inputs:

Option 1: Modify the First Convolutional Layer's Input Channels

This approach keeps the pretrained weights for the first 3 input channels and initializes weights for the extra 5 channels, preserving as much of the pretrained feature extraction capability as possible. Because InceptionResNetV2 has parallel branches, it cannot be rebuilt by stacking its layers into a keras.Sequential; instead, build an 8-channel copy of the architecture with weights=None and copy the pretrained weights into it layer by layer:

import numpy as np
from tensorflow import keras
from tensorflow.keras import applications

# Load the pretrained 3-channel model
base_3ch = applications.InceptionResNetV2(
    include_top=False,
    weights='imagenet',
    input_shape=(512, 512, 3)
)

# Build the same architecture for 8-channel input, without pretrained weights
base = applications.InceptionResNetV2(
    include_top=False,
    weights=None,
    input_shape=(512, 512, 8)
)

# Both models come from the same builder, so their layer lists line up one-to-one
first_conv_done = False
for src, dst in zip(base_3ch.layers, base.layers):
    weights = src.get_weights()
    if not weights:
        continue  # InputLayer, pooling, activation, concatenate: nothing to copy
    if not first_conv_done and isinstance(src, keras.layers.Conv2D):
        # The first conv has no bias (it is followed by BatchNormalization),
        # so get_weights() returns only the kernel, shape (3, 3, 3, 32)
        kernel = weights[0]
        # Keep the 3 pretrained channels; give the 5 new ones small random values
        # on the same scale as the pretrained kernel -> shape (3, 3, 8, 32)
        extra = np.random.normal(scale=kernel.std(), size=(3, 3, 5, 32))
        dst.set_weights([np.concatenate([kernel, extra.astype(kernel.dtype)], axis=2)])
        first_conv_done = True
    else:
        dst.set_weights(weights)  # every other layer has identical weight shapes
# Build your full model as before
model = keras.Sequential([
    base,
    keras.layers.BatchNormalization(renorm=True),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(7, activation='softmax')
])
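The weight surgery in Option 1 can also be isolated into a small NumPy helper, which makes it easy to try a different initialization for the 5 new channels. The sketch below is illustrative: the function name expand_first_kernel is hypothetical, and the "mean" mode is a common heuristic rather than something from the original answer. It fills each new channel with the mean of the RGB filters and rescales the whole kernel by 3/8, so an input whose 8 channels all carry the same signal gets the same first-layer response as an RGB input with 3 identical channels.

```python
import numpy as np


def expand_first_kernel(kernel, n_in_new, mode="mean", rng=None):
    """Widen a conv kernel of shape (kh, kw, c_in, c_out) to n_in_new input channels.

    mode="random": keep the pretrained channels unchanged and fill the new ones
                   with Gaussian noise on the same scale as the pretrained kernel.
    mode="mean":   fill the new channels with the mean pretrained filter, then scale
                   the whole kernel by c_in / n_in_new so the per-filter channel sum
                   (the response to channel-constant input) is unchanged.
    """
    kh, kw, c_in, c_out = kernel.shape
    n_extra = n_in_new - c_in
    if n_extra < 0:
        raise ValueError("n_in_new must be >= the existing number of input channels")
    if mode == "random":
        rng = np.random.default_rng() if rng is None else rng
        extra = rng.normal(scale=kernel.std(), size=(kh, kw, n_extra, c_out))
        return np.concatenate([kernel, extra.astype(kernel.dtype)], axis=2)
    if mode == "mean":
        mean_filter = kernel.mean(axis=2, keepdims=True)  # (kh, kw, 1, c_out)
        extra = np.repeat(mean_filter, n_extra, axis=2)   # (kh, kw, n_extra, c_out)
        return np.concatenate([kernel, extra], axis=2) * (c_in / n_in_new)
    raise ValueError(f"unknown mode: {mode!r}")


# Stand-in for the pretrained (3, 3, 3, 32) kernel; in Option 1 this would be
# the kernel read from the first Conv2D of the 3-channel model
rgb_kernel = np.random.default_rng(0).normal(size=(3, 3, 3, 32)).astype(np.float32)

k8_random = expand_first_kernel(rgb_kernel, 8, mode="random", rng=np.random.default_rng(1))
k8_mean = expand_first_kernel(rgb_kernel, 8, mode="mean")
print(k8_random.shape, k8_mean.shape)  # (3, 3, 8, 32) (3, 3, 8, 32)
```

The helper's return value would take the place of the hand-built 8-channel kernel in Option 1. Which initialization trains better on your 8-channel data is an empirical question.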

Option 2: Add a 1x1 Convolution to Map 8 Channels to 3

This is a simpler workaround: add a trainable 1x1 convolution that maps the 8 channels to 3, then feed its output into the pretrained model. Squeezing 8 channels into 3 may discard channel-specific information, but it is quick to implement. Load the pretrained model as a standalone 3-channel model and call it on the 1x1 output; if you pass the tensor via input_tensor instead, Keras makes the 1x1 layer part of the base model, so freezing the base would freeze the 1x1 mapping too:

from tensorflow import keras
from tensorflow.keras import applications

# Define input layer for 8 channels
input_layer = keras.Input(shape=(512, 512, 8))
# Trainable 1x1 convolution that learns a mapping from 8 channels to 3
conv_to_3ch = keras.layers.Conv2D(3, (1, 1), padding='same')(input_layer)

# Load the pretrained model on its own, with a standard 3-channel input
base = applications.InceptionResNetV2(
    include_top=False,
    weights='imagenet',
    input_shape=(512, 512, 3)
)
# Freeze the pretrained layers for the first training phase (the 1x1 conv stays trainable)
base.trainable = False

# training=False runs the base's BatchNormalization layers in inference mode,
# which keeps their pretrained statistics intact even after unfreezing
x = base(conv_to_3ch, training=False)
x = keras.layers.BatchNormalization(renorm=True)(x)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(512, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(256, activation='relu')(x)
x = keras.layers.Dropout(0.3)(x)
x = keras.layers.Dense(128, activation='relu')(x)
output = keras.layers.Dense(7, activation='softmax')(x)

model = keras.Model(inputs=input_layer, outputs=output)

After training the new layers, you can unfreeze some of the top pretrained layers, recompile with a low learning rate, and fine-tune for better performance.


2. Fixing Low Accuracy & Class Imbalance Issue

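Your model's predictions are concentrated on the largest class (90 images), most likely because the dataset is heavily imbalanced: the other classes have only 20-50 images each, so predicting the largest class is an easy way for the model to reduce its training loss. Here are actionable fixes: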

Data-Level Fixes

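  • Oversample minority classes: repeat samples from the small classes and apply random data augmentation (flips, rotations, zooms, etc.) so the repeats are not identical. Keras' ImageDataGenerator (deprecated in recent TensorFlow versions), preprocessing layers such as keras.layers.RandomFlip and keras.layers.RandomRotation, or a tf.data pipeline can handle this. Oversample only the training split, so duplicated images do not leak into validation.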
  • Class weights: Calculate a weight for each class inversely proportional to its size, then pass the weights to model.fit():
    # Illustrative counts for classes 0-6 (class 6 has 90 images, the others 20-50)
    class_counts = [20, 30, 40, 25, 50, 35, 90]
    total_samples = sum(class_counts)
    n_classes = len(class_counts)
    # "Balanced" weights: n_samples / (n_classes * count), so the average per-sample weight is 1
    class_weights = {i: total_samples / (n_classes * count)
                     for i, count in enumerate(class_counts)}
    # Use in training; the keys must be the integer class indices
    model.fit(train_data, epochs=50, class_weight=class_weights)
    
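The oversampling bullet can be sketched in NumPy as an index-resampling step. The helper name oversample_indices is hypothetical, and the class counts are the same illustrative numbers as in the class-weight example. Each class is topped up with random duplicates until it matches the largest class; the random augmentation applied during training is what makes those duplicates differ.

```python
import numpy as np


def oversample_indices(labels, rng=None):
    """Return shuffled indices in which every class appears as often as the largest one.

    All original samples are kept; minority classes are topped up with random
    duplicates drawn with replacement.
    """
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    target = counts.max()
    picked = []
    for c, n in zip(classes, counts):
        idx = np.flatnonzero(labels == c)
        extra = rng.choice(idx, size=target - n, replace=True)  # empty for the largest class
        picked.append(np.concatenate([idx, extra]))
    out = np.concatenate(picked)
    rng.shuffle(out)
    return out


counts = [20, 30, 40, 25, 50, 35, 90]        # illustrative class sizes
labels = np.repeat(np.arange(7), counts)      # stand-in for the training labels
idx = oversample_indices(labels, np.random.default_rng(0))
print(np.bincount(labels[idx]))               # [90 90 90 90 90 90 90]
```

You would then build the training set from images[idx] and labels[idx] (for example with tf.data.Dataset.from_tensor_slices) and augment each batch. Oversampling and class weights address the same imbalance, so applying both at full strength usually over-corrects toward the small classes.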

Model-Level Fixes

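  • Focal Loss: Replace standard cross-entropy with Focal Loss, which down-weights samples the model already classifies confidently (with imbalanced data these often come from the large class) and focuses training on hard cases. The version below is the categorical (multi-class) form, so it expects one-hot labels to match the 7-way softmax output; recent Keras versions also ship keras.losses.CategoricalFocalCrossentropy:
    import tensorflow as tf
    
    def focal_loss(gamma=2.0, alpha=0.25):
        def focal_loss_fn(y_true, y_pred):
            # y_true: one-hot labels; y_pred: softmax probabilities, both (batch, n_classes)
            y_true = tf.cast(y_true, y_pred.dtype)
            y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
            cross_entropy = -y_true * tf.math.log(y_pred)
            # (1 - p)^gamma shrinks the loss of confidently correct predictions
            weight = alpha * tf.pow(1.0 - y_pred, gamma)
            # One loss value per sample; Keras averages over the batch
            return tf.reduce_sum(weight * cross_entropy, axis=-1)
        return focal_loss_fn
    
    # Compile model with Focal Loss
    model.compile(optimizer='adam', loss=focal_loss(gamma=2.0, alpha=0.25), metrics=['accuracy'])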
    
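To see what the focal term does, here is a NumPy re-implementation of the categorical focal loss formula, FL = -sum_c alpha * (1 - p_c)^gamma * y_c * log(p_c), as a worked check. The helper names are hypothetical, and alpha is set to 1 to isolate the effect of gamma. With gamma=2, a confidently correct prediction (p=0.9) keeps only 1% of its cross-entropy loss, while an uncertain one (p=0.5) keeps 25%; with gamma=0 and alpha=1 the formula reduces to ordinary cross-entropy.

```python
import numpy as np


def focal_loss_np(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Per-sample categorical focal loss: -sum_c alpha * (1 - p_c)**gamma * y_c * log(p_c)."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(alpha * (1.0 - p) ** gamma * y_true * np.log(p), axis=-1)


def cross_entropy_np(y_true, y_pred, eps=1e-7):
    """Per-sample categorical cross-entropy, for comparison."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(y_true * np.log(p), axis=-1)


y_true = np.array([[0.0, 1.0, 0.0],
                   [0.0, 1.0, 0.0]])      # both samples belong to class 1
y_pred = np.array([[0.05, 0.90, 0.05],    # easy: confidently correct
                   [0.25, 0.50, 0.25]])   # hard: uncertain
ce = cross_entropy_np(y_true, y_pred)
fl = focal_loss_np(y_true, y_pred, gamma=2.0, alpha=1.0)
print(ce)  # approximately [0.105 0.693]
print(fl)  # approximately [0.00105 0.173]
```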

Training Strategy

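  • Monitor per-class metrics such as the macro-averaged F1-score and the confusion matrix, not just accuracy: with imbalanced data, overall accuracy can look acceptable while the small classes are almost never predicted correctly.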
  • Use stratified cross-validation (for example scikit-learn's StratifiedKFold) so that every fold keeps the same class distribution as your full dataset.
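Both points can be checked quickly with scikit-learn (an assumption: the original answer names no library). Using the same illustrative class counts as above, a degenerate model that always predicts the largest class reaches about 31% accuracy but a macro-F1 of under 0.07, and StratifiedKFold gives every validation fold the same class proportions as the full dataset.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from sklearn.model_selection import StratifiedKFold

counts = [20, 30, 40, 25, 50, 35, 90]  # illustrative sizes; class 6 is the largest
y = np.repeat(np.arange(7), counts)

# A degenerate "model" that always predicts the largest class
y_pred = np.full_like(y, 6)
acc = accuracy_score(y, y_pred)
macro_f1 = f1_score(y, y_pred, average="macro", zero_division=0)
print(f"accuracy={acc:.3f}  macro-F1={macro_f1:.3f}")  # accuracy=0.310  macro-F1=0.068
print(confusion_matrix(y, y_pred))  # every row puts all of its count in the last column

# Stratified 5-fold split: only the labels matter, so the features are placeholders
X = np.zeros((len(y), 1))
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
fold_counts = []
for train_idx, val_idx in skf.split(X, y):
    fold_counts.append(np.bincount(y[val_idx], minlength=7))
# Each validation fold holds 4, 6, 8, 5, 10, 7 and 18 images of classes 0-6
print(fold_counts[0])
```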

Note: content sourced from Stack Exchange; question author: Syuuuu