How to Save Best Model in Keras


In Keras, you can save the best model during training using the ModelCheckpoint callback. This callback lets you monitor a specific metric (e.g., validation accuracy or validation loss) and save the model whenever that metric reaches its best value so far. Here’s a working example with demo input and labels:

# Import necessary libraries
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint

# Generate demo input and labels
import numpy as np
X = np.random.rand(10000, 10)  # Example features
y = np.random.randint(2, size=(10000,))  # Example binary labels

# Create a Keras model
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=10))
model.add(Dense(units=1, activation='sigmoid'))

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Define a ModelCheckpoint callback to save the best model
checkpoint = ModelCheckpoint(
    "best_model.h5",  # Filepath to save the best model
    monitor="val_accuracy",  # Metric to monitor (e.g., validation accuracy)
    save_best_only=True,  # Save only the best model
    mode="max",  # "max" if monitoring validation accuracy, "min" for loss
    verbose=1  # Display messages when saving
)

# Train the model with the ModelCheckpoint callback
model.fit(X, y, epochs=10, validation_split=0.2, callbacks=[checkpoint])

Output:

Epoch 1/10
228/250 [==========================>...] - ETA: 0s - loss: 0.6966 - accuracy: 0.4938
Epoch 1: val_accuracy improved from -inf to 0.47900, saving model to best_model.h5
250/250 [==============================] - 1s 3ms/step - loss: 0.6965 - accuracy: 0.4961 - val_loss: 0.6949 - val_accuracy: 0.4790
Epoch 2/10
241/250 [===========================>..] - ETA: 0s - loss: 0.6943 - accuracy: 0.5008
Epoch 2: val_accuracy improved from 0.47900 to 0.50000, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6942 - accuracy: 0.5016 - val_loss: 0.6940 - val_accuracy: 0.5000
Epoch 3/10
249/250 [============================>.] - ETA: 0s - loss: 0.6933 - accuracy: 0.5048
Epoch 3: val_accuracy improved from 0.50000 to 0.50200, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6933 - accuracy: 0.5050 - val_loss: 0.6937 - val_accuracy: 0.5020
Epoch 4/10
247/250 [============================>.] - ETA: 0s - loss: 0.6930 - accuracy: 0.5095
Epoch 4: val_accuracy did not improve from 0.50200
250/250 [==============================] - 0s 2ms/step - loss: 0.6930 - accuracy: 0.5099 - val_loss: 0.6942 - val_accuracy: 0.5005
Epoch 5/10
239/250 [===========================>..] - ETA: 0s - loss: 0.6923 - accuracy: 0.5141
Epoch 5: val_accuracy did not improve from 0.50200
250/250 [==============================] - 0s 2ms/step - loss: 0.6923 - accuracy: 0.5142 - val_loss: 0.6960 - val_accuracy: 0.4910
Epoch 6/10
247/250 [============================>.] - ETA: 0s - loss: 0.6922 - accuracy: 0.5195
Epoch 6: val_accuracy did not improve from 0.50200
250/250 [==============================] - 0s 2ms/step - loss: 0.6922 - accuracy: 0.5200 - val_loss: 0.6945 - val_accuracy: 0.4975
Epoch 7/10
232/250 [==========================>...] - ETA: 0s - loss: 0.6913 - accuracy: 0.5290
Epoch 7: val_accuracy improved from 0.50200 to 0.50850, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6917 - accuracy: 0.5260 - val_loss: 0.6935 - val_accuracy: 0.5085
Epoch 8/10
247/250 [============================>.] - ETA: 0s - loss: 0.6916 - accuracy: 0.5201
Epoch 8: val_accuracy did not improve from 0.50850
250/250 [==============================] - 0s 2ms/step - loss: 0.6916 - accuracy: 0.5214 - val_loss: 0.6934 - val_accuracy: 0.5075
Epoch 9/10
237/250 [===========================>..] - ETA: 0s - loss: 0.6915 - accuracy: 0.5224
Epoch 9: val_accuracy did not improve from 0.50850
250/250 [==============================] - 1s 2ms/step - loss: 0.6916 - accuracy: 0.5215 - val_loss: 0.6935 - val_accuracy: 0.5085
Epoch 10/10
222/250 [=========================>....] - ETA: 0s - loss: 0.6912 - accuracy: 0.5312
Epoch 10: val_accuracy improved from 0.50850 to 0.51950, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6912 - accuracy: 0.5309 - val_loss: 0.6933 - val_accuracy: 0.5195
<keras.callbacks.History at 0x7a52874c84c0>

The output demonstrates the training process of a Keras model with the ModelCheckpoint callback. Here’s what’s happening:

  • In Epoch 1, the validation accuracy improves from -inf (the initial best value) to 0.479, so the callback saves the model.
  • In Epoch 4, the validation accuracy does not improve on the best value so far, so nothing is saved.
  • In Epoch 7, the validation accuracy improves to 0.5085, so the callback saves the model again.
  • Finally, in Epoch 10, the validation accuracy improves to 0.5195 and the model is saved one more time, so best_model.h5 ends up holding the Epoch 10 weights (see the loading sketch below).
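Once training finishes, you can load the saved checkpoint back with load_model for evaluation or inference. A minimal sketch, assuming the file was saved to the path used above (here it is simply evaluated on the same demo data):

# Load the best model saved by ModelCheckpoint
from keras.models import load_model

best_model = load_model("best_model.h5")

# Reuse it, e.g. evaluate on held-out data or make predictions
loss, acc = best_model.evaluate(X, y, verbose=0)
print(f"Restored model - loss: {loss:.4f}, accuracy: {acc:.4f}")
preds = best_model.predict(X[:5])  # probabilities for the first five demo samples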

To recap the main coding example:

  1. We create a simple Keras model for binary classification.
  2. We compile the model with the necessary settings.
  3. We use the ModelCheckpoint callback to save the best model during training. The parameters in the ModelCheckpoint constructor are set as follows:
    • "best_model.h5": The filepath where the best model will be saved.
    • monitor="val_accuracy": We monitor the validation accuracy to determine the best model.
    • save_best_only=True: It ensures that only the best model will be saved.
    • mode="max": We want to maximize the validation accuracy. Use “min” if monitoring validation loss.
    • verbose=1: This displays a message when a new best model is saved.
  4. We train the model using demo input and labels. The ModelCheckpoint callback will save the best model based on the validation accuracy. You can adjust the monitoring metric and other parameters as needed for your specific task.
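If you would rather track validation loss, as mentioned in step 3, the same callback works with monitor="val_loss" and mode="min". A minimal sketch reusing the model, X, and y from above (the filename best_model_loss.h5 is just an illustrative choice):

# Checkpoint that keeps the model with the lowest validation loss instead
checkpoint_loss = ModelCheckpoint(
    "best_model_loss.h5",  # separate file so the accuracy-based checkpoint is not overwritten
    monitor="val_loss",    # metric to monitor
    save_best_only=True,   # keep only the best model
    mode="min",            # lower loss is better
    verbose=1
)

model.fit(X, y, epochs=10, validation_split=0.2, callbacks=[checkpoint_loss])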