How can I save the best-performing model during Keras training?
Save the Best Model During Keras Training
Absolutely! What you're looking for is Keras' ModelCheckpoint callback. It is designed for exactly this problem: saving the best version of your model as it trains, even when your validation loss rises and falls from epoch to epoch.
Here’s how to implement it step by step:
Step 1: Import the Required Callbacks
First, make sure you import both EarlyStopping (which you’re already using) and ModelCheckpoint:
```python
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import Sequential
from keras.layers import Dense
```
Step 2: Configure Your Callbacks
Define your EarlyStopping monitor as before, then set up ModelCheckpoint to track your validation loss and save only the best model:
```python
# Define early stopping (you likely already have this)
early_stopping_monitor = EarlyStopping(patience=5, monitor='val_loss')

# Configure ModelCheckpoint to save the best model
model_checkpoint = ModelCheckpoint(
    filepath='best_model.h5',  # Path where the best model will be saved
    monitor='val_loss',        # Metric to monitor (use 'val_accuracy' if you care more about accuracy)
    save_best_only=True,       # Only save the model when the monitored metric improves
    mode='min',                # 'min' because we want to minimize val_loss; use 'max' for accuracy
    verbose=1                  # Print a message when a better model is saved
)
```
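If you would rather keep one file per improvement instead of overwriting a single file, ModelCheckpoint's `filepath` also accepts format placeholders such as `{epoch}` and the name of any logged metric, e.g. `{val_loss}`. The sketch below only mimics that string formatting in plain Python, assuming Keras fills the template with the 1-based epoch number plus that epoch's logs; the template and the log values are made up for illustration.

```python
# Hypothetical template; Keras fills it with format(epoch=..., **logs)
template = "best_model-{epoch:02d}-{val_loss:.4f}.h5"


def checkpoint_name(template, epoch_index, logs):
    """Mimic how a checkpoint filepath template is filled in.

    epoch_index is 0-based (as inside a training loop), while the
    resulting file name uses the 1-based epoch number.
    """
    return template.format(epoch=epoch_index + 1, **logs)


# Made-up logs for the third epoch (index 2)
logs = {"loss": 0.5, "val_loss": 0.1234}
print(checkpoint_name(template, 2, logs))  # best_model-03-0.1234.h5
```

Combined with `save_best_only=True`, a templated path like this writes a new file each time the monitored metric improves, so you end up with a trail of improving checkpoints rather than a single overwritten file.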
Step 3: Update Your Training Code
Fix the small typo in your original code (model_2.compile should be model.compile) and add both callbacks to the fit method:
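```python
import numpy as np
from keras.models import load_model

# Placeholder data so the snippet runs; replace with your own arrays
X = np.random.rand(500, 10)      # your training features
y = np.random.rand(500)          # your training targets
X_test = np.random.rand(50, 10)  # your test features
input_shape = (X.shape[1],)      # number of input features

model = Sequential()
model.add(Dense(100, activation='relu', input_shape=input_shape))
model.add(Dense(1))

# Compile the model (fixed the typo here).
# Note: 'accuracy' is not a meaningful metric for a regression loss like MSE.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])

# Train with both callbacks
model.fit(
    X, y,
    epochs=15,
    validation_split=0.4,
    callbacks=[early_stopping_monitor, model_checkpoint],
    verbose=False
)

# Later, when you want to use the best model
best_model = load_model('best_model.h5')
predictions = best_model.predict(X_test)
```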
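`model.fit` returns a `History` object whose `history` attribute is a dictionary mapping each metric name to a list of per-epoch values. If you capture it (for example `history = model.fit(...)`), you can check which epoch the saved checkpoint came from. In this sketch a hand-written dictionary stands in for `history.history`; the numbers are invented.

```python
def best_epoch(history_dict, monitor="val_loss", mode="min"):
    """Return (1-based epoch, value) of the best monitored value.

    list.index returns the first occurrence, which matches a checkpoint
    that only saves on strict improvement.
    """
    values = history_dict[monitor]
    pick = min if mode == "min" else max
    best_value = pick(values)
    return values.index(best_value) + 1, best_value


# Invented numbers standing in for history.history after fit()
fake_history = {
    "loss": [0.90, 0.55, 0.41, 0.37, 0.35],
    "val_loss": [0.80, 0.52, 0.47, 0.58, 0.49],
}
print(best_epoch(fake_history))  # (3, 0.47)
```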
How It Works
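- After every epoch, Keras evaluates the model on the validation split and passes the results to the callbacks; the `ModelCheckpoint` callback reads the monitored value (`val_loss` here) from those results.
- If the `val_loss` (or whatever metric you're monitoring) is better than the best value seen so far, it overwrites the saved model file with this new, better version.
- Once training finishes, `best_model.h5` contains the version of your model that achieved the lowest validation loss during training, so late-epoch performance drops no longer cost you your best model. The in-memory `model`, by contrast, still holds the weights from the last epoch that ran, which is why you reload the best model from the file.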
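The save-only-on-improvement rule described above can be sketched in plain Python. This is a simplified model of the bookkeeping under the assumption of a strict "better than the best so far" comparison, not Keras's actual source code; the validation values are invented to show a fluctuating curve.

```python
import math


def checkpoint_epochs(values, mode="min"):
    """Return the 1-based epochs at which a save_best_only checkpoint
    would overwrite the file, plus the final best value.

    mode="min" suits losses; mode="max" suits metrics such as accuracy.
    """
    best = math.inf if mode == "min" else -math.inf
    saved = []
    for epoch, value in enumerate(values, start=1):
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            saved.append(epoch)  # the file now holds this epoch's model
    return saved, best


# Invented, fluctuating validation losses
losses = [0.80, 0.52, 0.47, 0.58, 0.49, 0.45, 0.61]
print(checkpoint_epochs(losses))  # ([1, 2, 3, 6], 0.45)
```

Epochs 4, 5 and 7 are worse than the best value seen so far, so the file keeps the model from epoch 3 until epoch 6 beats it.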
Notes:
- If you care more about validation accuracy than loss, just change to `monitor='val_accuracy'` and `mode='max'` in the `ModelCheckpoint` config. (Keras versions before 2.3 log this metric as `val_acc` instead.)
- Make sure the filepath you choose is writable (you can use an absolute path if needed) to avoid permission issues.
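One way to avoid path problems is to resolve the checkpoint location to an absolute path and create its directory before calling `fit`. The helper and the directory name `checkpoints` below are hypothetical; you would pass the returned string as ModelCheckpoint's `filepath`.

```python
import os


def prepare_checkpoint_path(directory, filename="best_model.h5"):
    """Create `directory` if needed and return an absolute file path in it."""
    abs_dir = os.path.abspath(directory)
    os.makedirs(abs_dir, exist_ok=True)  # no error if it already exists
    if not os.access(abs_dir, os.W_OK):
        raise PermissionError(f"Cannot write checkpoints to {abs_dir}")
    return os.path.join(abs_dir, filename)


path = prepare_checkpoint_path("checkpoints")  # hypothetical directory
print(path)
```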
Question sourced from Stack Exchange; asked by dJOKER_dUMMY.




