You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

加载Keras H5模型时出现'Unrecognized keyword arguments: ['batch_shape']' TypeError的问题咨询

加载Keras H5模型时出现'Unrecognized keyword arguments: ['batch_shape']' TypeError的问题咨询

看起来你遇到的是TensorFlow/Keras模型序列化与反序列化时的参数兼容性问题——虽然你用了相同版本的TF/Keras(2.13),但模型保存时生成的InputLayer配置里包含了batch_shape这个当前版本不再支持的关键字参数,导致加载失败。

问题原因

你构建模型时直接复用了base_model.input作为模型的输入,而VGG16的底层输入层在序列化时会把输入信息以batch_shape的形式保存,但当前版本的InputLayer.from_config()方法不再接受这个参数,只支持shape参数,因此触发了这个TypeError。

解决方法

这里有两种可行的方案,你可以根据需求选择:

方案1:修改模型构建代码,显式定义输入层

重新构建模型时手动创建Input层,而非直接复用base model的input,这样保存的模型输入层会用shape参数而非batch_shape

import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense

def build_model():
    base_model = VGG16(include_top=False, input_shape=(224, 224, 3))
    # 显式定义输入层
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    x = base_model(inputs)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    model = Model(inputs=inputs, outputs=x)
    return model

# 重新训练并保存模型
model = build_model()
model.save('face_recog_vggface.h5')

方案2:加载模型时自定义InputLayer的配置处理

如果你不想重新训练模型,可以在加载时通过custom_objects处理batch_shape参数,将其转换为当前版本支持的shape参数:

from tensorflow.keras.models import load_model
from tensorflow.keras.layers import InputLayer

def custom_input_layer_from_config(config):
    # 移除batch_shape参数,转换为shape(去掉batch维度的None)
    if 'batch_shape' in config:
        config['shape'] = config.pop('batch_shape')[1:]
    return InputLayer.from_config(config)

# 加载模型时传入自定义的InputLayer处理逻辑
model = load_model('face_recog_vggface.h5', 
                   custom_objects={'triplet_loss': triplet_loss,
                                   'InputLayer': custom_input_layer_from_config})

额外建议

TensorFlow 2.x版本更推荐使用SavedModel格式(保存为文件夹而非.h5文件),这种格式的序列化更稳定,不容易出现这类参数兼容性问题。你可以这样保存和加载:

# 保存为SavedModel格式
model.save('face_recog_vggface')

# 加载SavedModel
model = tf.keras.models.load_model('face_recog_vggface', custom_objects={'triplet_loss': triplet_loss})

备注:内容来源于stack exchange,提问作者Igwe Ugochukwu

火山引擎 最新活动