Help modifying the input shape of a pretrained model in Keras: errors and accuracy problems with 8-channel input
Hi there! Let's tackle your two main issues one by one—getting the pretrained InceptionResNetV2 to work with 8-channel inputs, and fixing the low accuracy caused by class imbalance when training from scratch.
1. Fixing the Pretrained Model Input Channel Error
The error happens because the ImageNet-pretrained InceptionResNetV2 is explicitly trained for 3-channel RGB images, so its first convolutional layer expects exactly 3 input channels. There are two reliable ways to adapt it to 8-channel inputs:
Option 1: Modify the First Convolutional Layer's Input Channels
This approach keeps the pretrained weights for the first 3 channels and initializes weights for the extra 5 channels, preserving as much of the pretrained feature extraction capability as possible:
    from tensorflow import keras
    from tensorflow.keras import applications

    # Load the pretrained 3-channel model; it is only used as a source of weights
    base_3ch = applications.InceptionResNetV2(
        include_top=False,
        weights='imagenet',
        input_shape=(512, 512, 3)
    )

    # Build the same architecture for 8-channel input, without pretrained weights
    # (Keras may warn about the unusual channel count, which is expected here)
    base = applications.InceptionResNetV2(
        include_top=False,
        weights=None,
        input_shape=(512, 512, 8)
    )

    # Both models have the same layer order, so copy weights layer by layer.
    # layers[0] is the InputLayer and has no weights. The first convolution is the
    # only layer whose kernel shape differs: (3, 3, 3, 32) vs. (3, 3, 8, 32).
    # It has no bias, because each convolution is followed by BatchNormalization.
    # (The network has parallel branches, so it cannot be rebuilt layer by layer
    # with keras.Sequential.)
    for src, dst in zip(base_3ch.layers, base.layers):
        src_weights = src.get_weights()
        if not src_weights:
            continue
        dst_weights = dst.get_weights()
        if src_weights[0].shape == dst_weights[0].shape:
            dst.set_weights(src_weights)
        else:
            # Keep the pretrained RGB filters for the first 3 channels; the extra
            # 5 channels keep the random initialization of the new model
            new_kernel = dst_weights[0].copy()
            new_kernel[:, :, :3, :] = src_weights[0]
            dst.set_weights([new_kernel] + src_weights[1:])

    # Build your full model as before
    model = keras.Sequential([
        base,
        keras.layers.BatchNormalization(renorm=True),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(256, activation='relu'),
        keras.layers.Dropout(0.3),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(7, activation='softmax')
    ])
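If you want to sanity-check the kernel expansion from Option 1 on its own, here is a minimal NumPy sketch. It uses a random array as a stand-in for the pretrained kernel (an assumption, so it runs without downloading any weights), and `expand_kernel` is a hypothetical helper name, not a Keras function. Instead of random values it fills the extra channels with the mean of the RGB filters, which is one common alternative and not necessarily the best choice for your data.

```python
import numpy as np

def expand_kernel(kernel_3ch, in_channels):
    """Expand a (kh, kw, 3, filters) kernel to (kh, kw, in_channels, filters).

    Hypothetical helper: the first 3 input channels keep the given weights,
    and every extra channel is filled with the mean over the RGB axis.
    """
    original_channels = kernel_3ch.shape[2]
    if in_channels < original_channels:
        raise ValueError("in_channels must be >= the original channel count")
    extra = in_channels - original_channels
    mean_filter = kernel_3ch.mean(axis=2, keepdims=True)   # (kh, kw, 1, filters)
    extra_kernel = np.repeat(mean_filter, extra, axis=2)   # (kh, kw, extra, filters)
    return np.concatenate([kernel_3ch, extra_kernel], axis=2)

# Stand-in for the pretrained first-layer kernel (assumed shape: 3x3, RGB, 32 filters)
rng = np.random.default_rng(0)
pretrained_like = rng.normal(scale=0.1, size=(3, 3, 3, 32)).astype("float32")

kernel_8ch = expand_kernel(pretrained_like, in_channels=8)
print(kernel_8ch.shape)  # (3, 3, 8, 32)
```

Whichever initialization you pick, the point is that the first three input channels start from the pretrained filters, so channel order in your 8-channel images matters.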
Option 2: Add a 1x1 Convolution to Map 8 Channels to 3
This is a simpler workaround—we add a 1x1 convolution layer to convert 8 channels to 3, then feed that into the pretrained model. Note that this might lose some channel-specific information, but it's quick to implement:
    from tensorflow import keras
    from tensorflow.keras import applications

    # Define input layer for 8 channels
    input_layer = keras.Input(shape=(512, 512, 8))

    # 1x1 convolution to map 8 channels to 3 (trainable, learned from your data)
    conv_to_3ch = keras.layers.Conv2D(3, (1, 1), padding='same')(input_layer)

    # Load the pretrained model with a standard 3-channel input. Passing the conv
    # output as input_tensor would pull the 1x1 conv into the base model, which can
    # break ImageNet weight loading and would freeze the conv along with the base.
    base = applications.InceptionResNetV2(
        include_top=False,
        weights='imagenet',
        input_shape=(512, 512, 3)
    )

    # Optional: Freeze pretrained layers first for transfer learning
    base.trainable = False

    # Call the base as a layer on the projected tensor, then build the rest of the model
    x = base(conv_to_3ch, training=False)  # training=False keeps BatchNorm in inference mode
    x = keras.layers.BatchNormalization(renorm=True)(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(512, activation='relu')(x)
    x = keras.layers.Dropout(0.5)(x)
    x = keras.layers.Dense(256, activation='relu')(x)
    x = keras.layers.Dropout(0.3)(x)
    x = keras.layers.Dense(128, activation='relu')(x)
    output = keras.layers.Dense(7, activation='softmax')(x)

    model = keras.Model(inputs=input_layer, outputs=output)
After training the top layers, you can unfreeze some of the pretrained layers and fine-tune them with a low learning rate for better performance.
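To see what the 1x1 convolution in Option 2 actually does, it helps to note that it applies the same 8-to-3 linear map at every pixel. The NumPy sketch below shows this on a tiny made-up image with random weights (both are assumptions for illustration; in the real model the weights are learned during training).

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.normal(size=(4, 4, 8))   # tiny stand-in for a 512x512x8 image
weights = rng.normal(size=(8, 3))    # 1x1 conv kernel (1, 1, 8, 3) with the spatial dims squeezed
bias = np.zeros(3)

# A 1x1 convolution is a per-pixel matrix multiply over the channel axis
projected = np.einsum("hwc,cd->hwd", image, weights) + bias

# The same result for a single pixel, computed by hand
pixel = image[2, 1] @ weights + bias

print(projected.shape)
print(np.allclose(projected[2, 1], pixel))
```

Because each output channel is just a weighted sum of the 8 input channels, any information that cannot be expressed in 3 such sums is lost before the pretrained layers see it. That is the trade-off compared with Option 1.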
2. Fixing Low Accuracy & Class Imbalance Issue
Your model's predictions concentrate on the largest class (90 images) because your dataset is severely imbalanced (the other classes have 20-50 images each). Here are actionable fixes:
Data-Level Fixes
- Oversample minority classes: Apply data augmentation (flips, rotations, zooms, etc.) to the small classes to increase their sample count. Keras' ImageDataGenerator or tf.data can handle this easily.
- Class weights: Calculate a weight for each class inversely proportional to its size, then pass the weights to model.fit():

    # Assume class indices 0-6, with class 6 having 90 images, others 20-50
    class_counts = [20, 30, 40, 25, 50, 35, 90]
    total_samples = sum(class_counts)
    class_weights = {i: total_samples / count for i, count in enumerate(class_counts)}

    # Use in training
    model.fit(train_data, epochs=50, class_weight=class_weights)
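One variation you might consider: the weights in the answer are all well above 1, which scales up the overall loss. A common alternative is the "balanced" heuristic, which also divides by the number of classes so that the weighted sample count equals the real one. Here is a small sketch using the same assumed class counts (they are placeholders, so substitute your real counts).

```python
import numpy as np

# Assumed class counts (placeholders): class 6 is the large class
class_counts = np.array([20, 30, 40, 25, 50, 35, 90])
n_classes = len(class_counts)
total_samples = class_counts.sum()

# "Balanced" heuristic: weight_i = total / (n_classes * count_i)
weights = total_samples / (n_classes * class_counts)
class_weights = {i: float(w) for i, w in enumerate(weights)}

for i, w in class_weights.items():
    print(f"class {i}: count={class_counts[i]:3d}  weight={w:.3f}")
```

The resulting dictionary can be passed to `model.fit(..., class_weight=class_weights)` in the same way. The relative weighting between classes is identical to the unnormalized version; only the overall scale changes.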
Model-Level Fixes
- Focal Loss: Replace standard cross-entropy with Focal Loss, which down-weights easy-to-classify samples (the large class) and focuses on hard cases:
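    import tensorflow as tf
    import tensorflow.keras.backend as K

    def focal_loss(gamma=2.0, alpha=0.25):
        # Expects one-hot encoded y_true and softmax probabilities in y_pred
        def focal_loss_fn(y_true, y_pred):
            # pt_1: predicted probability where the label is 1 (1.0 elsewhere, so that term vanishes)
            pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
            # pt_0: predicted probability where the label is 0 (0.0 elsewhere, so that term vanishes)
            pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
            return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon() + pt_1)) \
                   - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(K.epsilon() + 1. - pt_0))
        return focal_loss_fn

    # Compile model with Focal Loss
    model.compile(optimizer='adam', loss=focal_loss(gamma=2, alpha=0.25), metrics=['accuracy'])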
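To get a feel for how much Focal Loss down-weights easy samples, here is a small NumPy sketch. It uses the simplified single-term form, -(1 - p)^gamma * log(p) on the true-class probability p, with alpha left out. That is an assumption made for illustration, not the exact two-term formula above, and the three probabilities are made up.

```python
import numpy as np

def focal_term(p_true, gamma=2.0):
    """Per-sample focal term -(1 - p)^gamma * log(p) for the true-class probability p."""
    p_true = np.clip(p_true, 1e-7, 1.0)
    return -((1.0 - p_true) ** gamma) * np.log(p_true)

# Made-up true-class probabilities: an easy, a medium, and a hard sample
p = np.array([0.95, 0.6, 0.1])

cross_entropy = -np.log(p)
focal = focal_term(p, gamma=2.0)

# Ratio focal / cross-entropy equals the modulating factor (1 - p)^gamma
ratio = focal / cross_entropy
print(ratio)  # approximately 0.0025, 0.16, 0.81
```

With gamma=2 the confidently classified sample contributes about 0.25% of its cross-entropy loss, while the hard one keeps about 81%, so the gradient is dominated by the samples the model still gets wrong.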
Training Strategy
- Monitor metrics like F1-score or confusion matrix instead of just accuracy—accuracy is misleading with imbalanced data.
- Use stratified cross-validation to ensure each training fold has the same class distribution as your full dataset.
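To illustrate why accuracy alone is misleading here, the sketch below computes a confusion matrix and macro F1 with plain NumPy on made-up toy labels, where the model predicts the majority class for every sample. The helper names `confusion_matrix` and `macro_f1` are my own for this example. In practice you would probably use a library implementation such as the ones in scikit-learn.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

def macro_f1(cm):
    """Unweighted mean of per-class F1; classes with no TP, FP, or FN count as 0."""
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = 2 * tp + fp + fn
    f1 = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    return f1.mean()

# Toy data (made up): 3 classes, and the model always predicts the majority class 2
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 2, 2, 2])
y_pred = np.full_like(y_true, 2)

cm = confusion_matrix(y_true, y_pred, n_classes=3)
accuracy = np.trace(cm) / cm.sum()
score = macro_f1(cm)

print(cm)
print(f"accuracy={accuracy:.2f}  macro F1={score:.2f}")  # accuracy=0.60  macro F1=0.25
```

The accuracy of 0.60 looks tolerable, but the macro F1 of 0.25 and the empty columns in the confusion matrix show that two of the three classes are never predicted, which is the same failure mode you described.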
Note: This content comes from Stack Exchange; the question was asked by Syuuuu.




