How to Resume Training from a Checkpoint in Keras

To resume training from a checkpoint in Keras, you can use the ModelCheckpoint callback to save the model during training, then load the saved model with load_model and continue training from where it left off. Here’s how to do it:

Training the Model and Saving Checkpoints:

# Import necessary libraries
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint

# Create a Keras model (for demonstration)
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=10))
model.add(Dense(units=32, activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Define the ModelCheckpoint callback to save the best model
checkpoint = ModelCheckpoint(
    "best_model.h5",  # Filepath to save the best model
    monitor='val_accuracy',  # Metric to monitor (e.g., validation accuracy)
    save_best_only=True,  # Save only the best model
    mode='max',  # 'max' if monitoring validation accuracy, 'min' for loss
    verbose=1  # Display messages when saving
)

# Dummy training data for demonstration (replace with your own dataset)
X = np.random.rand(1000, 10)
y = np.random.randint(0, 2, size=(1000, 1))

# Train the model with the ModelCheckpoint callback
model.fit(X, y, epochs=10, validation_split=0.2, callbacks=[checkpoint])
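
Note that with save_best_only=True only the best-scoring epoch is kept, so best_model.h5 lets you resume from the best epoch rather than the last one. If you also want to be able to restart from exactly where training stopped, you can add a second ModelCheckpoint that overwrites a file at the end of every epoch. A minimal sketch, where latest_model.h5 is just an example filename:

# Also keep the most recent model state, overwritten at the end of every epoch,
# so training can be resumed from the last completed epoch.
latest_checkpoint = ModelCheckpoint(
    "latest_model.h5",     # Example filename, overwritten each epoch
    save_best_only=False,  # Save every epoch, not just the best one
    verbose=1
)

# Pass both callbacks: one keeps the best model, the other the latest state
model.fit(X, y, epochs=10, validation_split=0.2,
          callbacks=[checkpoint, latest_checkpoint])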

Resuming Training from the Checkpoint:

# Import necessary libraries
import keras
from keras.models import load_model

# Load the best model checkpoint saved during training
model = load_model("best_model.h5")

# The loaded model is already compiled, with its optimizer state restored,
# so recompiling is not required to resume training. Recompile only if you
# want to change the optimizer, loss, or metrics (this resets the optimizer state).
# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Continue training on new data or for additional epochs
model.fit(new_X, new_y, epochs=5)  # Replace new_X and new_y with your new data
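
If you want the epoch counter to keep counting from where the first run stopped (useful for logs and anything keyed to the epoch number), you can pass initial_epoch to fit. A minimal sketch, assuming X, y, and the checkpoint callback from the first script are still in scope and that the first run completed 10 epochs:

# Continue the same run: epochs is the target total number of epochs,
# initial_epoch is the number of epochs already completed.
model.fit(X, y,
          epochs=15,               # Train up to epoch 15 in total
          initial_epoch=10,        # The first run already completed 10 epochs
          validation_split=0.2,
          callbacks=[checkpoint])  # Reattach the checkpoint callback if desired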

In this example:

  1. First, you train the model with the ModelCheckpoint callback, which saves the best model based on validation accuracy.
  2. After training, you load that checkpoint with load_model. The loaded model has the architecture, weights, and optimizer state of the best-performing epoch.
  3. Because load_model returns an already compiled model, you do not need to compile it again; recompile only if you want to change the optimizer, loss, or metrics, keeping in mind that this resets the optimizer state.
  4. You can then continue training the model from that point. This is useful for further fine-tuning, additional training epochs, or working with new data.

By saving and loading the best model using checkpoints, you can ensure that you continue training from a point of good performance, even if you need to stop and resume the training process.
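
A common variant is to checkpoint only the weights by passing save_weights_only=True, which produces smaller files but does not store the architecture or optimizer state. To resume, you rebuild the same model, compile it, and call load_weights. A rough sketch, reusing the model definition from above (best_model.weights.h5 is just an example filename):

from keras.callbacks import ModelCheckpoint

# Save only the weights of the best-performing model
weights_checkpoint = ModelCheckpoint(
    "best_model.weights.h5",
    monitor='val_accuracy',
    save_best_only=True,
    save_weights_only=True,  # No architecture or optimizer state is saved
    mode='max'
)

# ... train with callbacks=[weights_checkpoint] ...

# To resume: rebuild the same architecture, compile it, then load the weights
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.load_weights("best_model.weights.h5")
model.fit(X, y, epochs=5, validation_split=0.2)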