In Keras, you can save the best model during training with the ModelCheckpoint callback. This callback monitors a specific metric (e.g., validation accuracy or validation loss) and saves the model whenever that metric reaches a new best value. Here’s a working example with randomly generated demo input and labels:
```python
# Import necessary libraries
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint

# Generate demo input and labels
X = np.random.rand(10000, 10)            # 10,000 samples with 10 random features each
y = np.random.randint(2, size=(10000,))  # Random binary labels (0 or 1)

# Create a Keras model
model = Sequential()
model.add(keras.Input(shape=(10,)))      # Declare the 10-feature input explicitly
model.add(Dense(units=64, activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Define a ModelCheckpoint callback to save the best model
checkpoint = ModelCheckpoint(
    "best_model.h5",         # Filepath for the best model (Keras 3 recommends ".keras"; ".h5" is legacy)
    monitor="val_accuracy",  # Metric to monitor
    save_best_only=True,     # Write the file only when the monitored metric improves
    mode="max",              # "max" for accuracy; use "min" for a loss
    verbose=1                # Report at the end of each epoch whether the model was saved
)

# Train the model with the ModelCheckpoint callback
model.fit(X, y, epochs=10, validation_split=0.2, callbacks=[checkpoint])
```
Output:
Epoch 1/10
228/250 [==========================>...] - ETA: 0s - loss: 0.6966 - accuracy: 0.4938
Epoch 1: val_accuracy improved from -inf to 0.47900, saving model to best_model.h5
250/250 [==============================] - 1s 3ms/step - loss: 0.6965 - accuracy: 0.4961 - val_loss: 0.6949 - val_accuracy: 0.4790
Epoch 2/10
241/250 [===========================>..] - ETA: 0s - loss: 0.6943 - accuracy: 0.5008
Epoch 2: val_accuracy improved from 0.47900 to 0.50000, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6942 - accuracy: 0.5016 - val_loss: 0.6940 - val_accuracy: 0.5000
Epoch 3/10
249/250 [============================>.] - ETA: 0s - loss: 0.6933 - accuracy: 0.5048
Epoch 3: val_accuracy improved from 0.50000 to 0.50200, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6933 - accuracy: 0.5050 - val_loss: 0.6937 - val_accuracy: 0.5020
Epoch 4/10
247/250 [============================>.] - ETA: 0s - loss: 0.6930 - accuracy: 0.5095
Epoch 4: val_accuracy did not improve from 0.50200
250/250 [==============================] - 0s 2ms/step - loss: 0.6930 - accuracy: 0.5099 - val_loss: 0.6942 - val_accuracy: 0.5005
Epoch 5/10
239/250 [===========================>..] - ETA: 0s - loss: 0.6923 - accuracy: 0.5141
Epoch 5: val_accuracy did not improve from 0.50200
250/250 [==============================] - 0s 2ms/step - loss: 0.6923 - accuracy: 0.5142 - val_loss: 0.6960 - val_accuracy: 0.4910
Epoch 6/10
247/250 [============================>.] - ETA: 0s - loss: 0.6922 - accuracy: 0.5195
Epoch 6: val_accuracy did not improve from 0.50200
250/250 [==============================] - 0s 2ms/step - loss: 0.6922 - accuracy: 0.5200 - val_loss: 0.6945 - val_accuracy: 0.4975
Epoch 7/10
232/250 [==========================>...] - ETA: 0s - loss: 0.6913 - accuracy: 0.5290
Epoch 7: val_accuracy improved from 0.50200 to 0.50850, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6917 - accuracy: 0.5260 - val_loss: 0.6935 - val_accuracy: 0.5085
Epoch 8/10
247/250 [============================>.] - ETA: 0s - loss: 0.6916 - accuracy: 0.5201
Epoch 8: val_accuracy did not improve from 0.50850
250/250 [==============================] - 0s 2ms/step - loss: 0.6916 - accuracy: 0.5214 - val_loss: 0.6934 - val_accuracy: 0.5075
Epoch 9/10
237/250 [===========================>..] - ETA: 0s - loss: 0.6915 - accuracy: 0.5224
Epoch 9: val_accuracy did not improve from 0.50850
250/250 [==============================] - 1s 2ms/step - loss: 0.6916 - accuracy: 0.5215 - val_loss: 0.6935 - val_accuracy: 0.5085
Epoch 10/10
222/250 [=========================>....] - ETA: 0s - loss: 0.6912 - accuracy: 0.5312
Epoch 10: val_accuracy improved from 0.50850 to 0.51950, saving model to best_model.h5
250/250 [==============================] - 0s 2ms/step - loss: 0.6912 - accuracy: 0.5309 - val_loss: 0.6933 - val_accuracy: 0.5195
<keras.callbacks.History at 0x7a52874c84c0>
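The output shows how the ModelCheckpoint callback behaves during training:
- In Epoch 1, the validation accuracy improves from -inf (the initial best value) to 0.4790, so the model is saved.
- In Epochs 2 and 3, it improves to 0.5000 and then 0.5020, and the model is saved each time.
- In Epochs 4, 5, and 6, the validation accuracy stays below 0.5020, so the saved model is left unchanged.
- In Epoch 7, the validation accuracy improves to 0.5085, so the model is saved.
- In Epochs 8 and 9, it does not exceed 0.5085 (Epoch 9 only ties it), so nothing is saved.
- Finally, in Epoch 10, the validation accuracy improves to 0.5195. The model is saved again, so best_model.h5 ends up holding the Epoch 10 model.
Because the labels are random, accuracy hovers around 50%. The example demonstrates how the callback works, not how to build a useful classifier.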
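The save/skip decisions above follow a simple rule. With mode="max", a file is written only when the monitored value is strictly greater than the best value seen so far. With mode="min", the value must be strictly smaller. The pure-Python sketch below replays the logged values through that rule.

replay_checkpoint is a hypothetical helper, not part of Keras, and it ignores the callback's other options. The values are copied from the log, where they are rounded to four decimals, so this is an approximation of what Keras compares internally.

```python
import operator

# Values copied from the training log above, one per epoch (1-10).
val_accuracy = [0.4790, 0.5000, 0.5020, 0.5005, 0.4910,
                0.4975, 0.5085, 0.5075, 0.5085, 0.5195]
val_loss = [0.6949, 0.6940, 0.6937, 0.6942, 0.6960,
            0.6945, 0.6935, 0.6934, 0.6935, 0.6933]


def replay_checkpoint(values, mode="max"):
    """Return the 1-based epochs at which save_best_only=True would write a file.

    Hypothetical helper (not part of Keras): a save happens only when the
    current value is strictly better than the best value seen so far.
    """
    if mode == "max":
        best, is_better = float("-inf"), operator.gt
    elif mode == "min":
        best, is_better = float("inf"), operator.lt
    else:
        raise ValueError("mode must be 'max' or 'min'")

    saved_epochs = []
    for epoch, current in enumerate(values, start=1):
        if is_better(current, best):
            best = current
            saved_epochs.append(epoch)
    return saved_epochs


print(replay_checkpoint(val_accuracy, mode="max"))  # [1, 2, 3, 7, 10]
print(replay_checkpoint(val_loss, mode="min"))      # [1, 2, 3, 7, 8, 10]
```

The first result matches the "saving model" lines in the log. The second result shows that monitoring val_loss on the same run would also have saved after Epoch 8, so the choice of metric can change which intermediate models get written.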
In this coding example:
- We create a simple Keras model for binary classification.
- We compile the model with the Adam optimizer, binary cross-entropy loss, and accuracy as a metric.
- We use the ModelCheckpoint callback to save the best model during training. Its constructor arguments are set as follows:
  - "best_model.h5": the filepath where the best model is saved.
  - monitor="val_accuracy": the validation accuracy decides which model is best.
  - save_best_only=True: the file is written only when the monitored metric improves, so it always holds the best model seen so far.
  - mode="max": higher validation accuracy is better. Use "min" when monitoring validation loss.
  - verbose=1: prints a message at the end of each epoch saying whether the model was saved.
- We train the model on the demo input and labels with validation_split=0.2, which holds out the last 20% of the samples for validation. The ModelCheckpoint callback saves the best model based on the validation accuracy measured on that split. You can change the monitored metric and the other parameters to suit your task.
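As an example of changing the monitored metric, the sketch below tracks validation loss instead. Because lower loss is better, mode becomes "min". The saved model is then reloaded with keras.models.load_model.

This sketch assumes a recent Keras version that supports the native ".keras" format, such as Keras 3. The filename best_model_by_loss.keras, the random seed, the smaller dataset, and the epoch count are illustrative choices, not part of the original example.

```python
import numpy as np
import keras
from keras.callbacks import ModelCheckpoint

# Illustrative random data: the labels are random, so the model cannot learn much.
rng = np.random.default_rng(seed=0)
X = rng.random((1000, 10)).astype("float32")
y = rng.integers(0, 2, size=(1000,)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(10,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Track validation loss: lower is better, so mode="min".
checkpoint = ModelCheckpoint(
    "best_model_by_loss.keras",  # hypothetical filename in the native Keras format
    monitor="val_loss",
    save_best_only=True,
    mode="min",
    verbose=0,
)
history = model.fit(
    X, y,
    epochs=5,
    validation_split=0.2,  # the last 20% of samples (X[800:]) form the validation set
    callbacks=[checkpoint],
    verbose=0,
)

# Reload the checkpoint. It holds the model from the epoch with the lowest
# val_loss, which is not necessarily the final epoch.
best_model = keras.models.load_model("best_model_by_loss.keras")
best_loss = best_model.evaluate(X[800:], y[800:], verbose=0, return_dict=True)["loss"]
print(f"lowest val_loss during training: {min(history.history['val_loss']):.4f}")
print(f"val_loss of the reloaded model:  {best_loss:.4f}")
```

The two printed losses should agree closely. The checkpoint is written at the end of the epoch, using the same weights that produced that epoch's val_loss.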