How to Use Early Stopping in Keras

Early stopping is a technique for limiting overfitting: a chosen metric, usually a validation metric, is monitored during training, and training halts once that metric stops improving. In Keras, the EarlyStopping callback implements this. Here’s how to use it:

# Import necessary libraries
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping

# Generate demo input and labels
import numpy as np
X = np.random.rand(10000, 10)  # Example features
y = np.random.randint(2, size=(10000,))  # Example binary labels


# Create a Keras model (for demonstration)
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=10))
model.add(Dense(units=32, activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Define the EarlyStopping callback
early_stopping = EarlyStopping(
    monitor='val_loss',  # Metric to monitor (e.g., validation loss)
    patience=5,  # Number of epochs with no improvement after which training will be stopped
    verbose=1,  # Display a message when early stopping is applied
    restore_best_weights=True  # Restore model weights to the best achieved during training
)

# Train the model with the EarlyStopping callback
model.fit(X, y, epochs=100, validation_split=0.2, callbacks=[early_stopping])

In this example:

  1. We create a simple Keras model for binary classification and compile it.
  2. We define the EarlyStopping callback, specifying the metric to monitor (here val_loss, the validation loss), the patience parameter (the number of epochs with no improvement before stopping), verbose=1 to print a message when training stops, and restore_best_weights=True to roll the model back to its best weights.
  3. We pass the EarlyStopping callback to the callbacks parameter of the model.fit method.
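
To make step 3 more concrete, here is a simplified, framework-free sketch of how a training loop might consult a callback after each epoch. It is not Keras's actual implementation: `ToyModel`, `ToyEarlyStopping`, and the scripted loss values are hypothetical names and data invented for illustration. What it mirrors from Keras is only the general pattern, in which the callback receives the epoch's logs in `on_epoch_end` and requests a stop by setting `stop_training` on the model.

```python
class ToyEarlyStopping:
    """Hypothetical stand-in for a callback; not the Keras class."""

    def __init__(self, monitor="val_loss", patience=2):
        self.monitor = monitor
        self.patience = patience
        self.best = float("inf")  # lower is better for a loss
        self.wait = 0             # epochs since the last improvement
        self.model = None         # attached by the training loop

    def on_epoch_end(self, epoch, logs):
        current = logs[self.monitor]
        if current < self.best:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True


class ToyModel:
    """Hypothetical model whose 'training' just replays scripted losses."""

    def __init__(self, scripted_val_losses):
        self.scripted_val_losses = scripted_val_losses
        self.stop_training = False

    def fit(self, epochs, callbacks):
        for cb in callbacks:
            cb.model = self
        epochs_run = 0
        for epoch in range(epochs):
            logs = {"val_loss": self.scripted_val_losses[epoch]}
            epochs_run += 1
            for cb in callbacks:
                cb.on_epoch_end(epoch, logs)
            if self.stop_training:  # a callback asked to stop
                break
        return epochs_run


toy = ToyModel([0.9, 0.7, 0.8, 0.75, 0.6, 0.5])
epochs_run = toy.fit(epochs=6, callbacks=[ToyEarlyStopping(patience=2)])
print(epochs_run)
```

With a patience of 2, the toy loop stops after the second consecutive epoch that fails to beat the best scripted loss (0.7), so the later, lower values are never reached. That is the trade-off the patience parameter controls: a small value stops sooner but can miss late improvements.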

Output of the model.fit call:

Epoch 1/100
250/250 [==============================] - 1s 3ms/step - loss: 0.6943 - accuracy: 0.4924 - val_loss: 0.6938 - val_accuracy: 0.5105
Epoch 2/100
250/250 [==============================] - 1s 2ms/step - loss: 0.6932 - accuracy: 0.5099 - val_loss: 0.6926 - val_accuracy: 0.5090
Epoch 3/100
250/250 [==============================] - 1s 2ms/step - loss: 0.6928 - accuracy: 0.5092 - val_loss: 0.6932 - val_accuracy: 0.5090
Epoch 4/100
250/250 [==============================] - 1s 2ms/step - loss: 0.6928 - accuracy: 0.5185 - val_loss: 0.6935 - val_accuracy: 0.4985
Epoch 5/100
250/250 [==============================] - 1s 2ms/step - loss: 0.6921 - accuracy: 0.5190 - val_loss: 0.6933 - val_accuracy: 0.5080
Epoch 6/100
250/250 [==============================] - 1s 2ms/step - loss: 0.6922 - accuracy: 0.5164 - val_loss: 0.6931 - val_accuracy: 0.5080
Epoch 7/100
243/250 [============================>.] - ETA: 0s - loss: 0.6916 - accuracy: 0.5231
Restoring model weights from the end of the best epoch: 2.
250/250 [==============================] - 1s 3ms/step - loss: 0.6916 - accuracy: 0.5236 - val_loss: 0.6946 - val_accuracy: 0.4955
Epoch 7: early stopping
<keras.callbacks.History at 0x7a526e781840>

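This output shows early stopping in action. Because the demo labels are random, the model has nothing real to learn, so the validation loss hovers near 0.693 (ln 2, the chance level for binary cross-entropy) and early stopping triggers quickly. Here’s what’s happening in the output: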

  • Training is allowed to run for up to 100 epochs, as set by epochs=100 in the model.fit call.
  • For each epoch, you see updates on the training and validation accuracy and loss.
  • The EarlyStopping callback is monitoring the validation loss (val_loss), with a patience of 5 epochs.
  • By the end of Epoch 7, the validation loss has failed to improve on its Epoch 2 value (0.6926) for five consecutive epochs (Epochs 3 through 7), which reaches the patience limit of 5.
  • As a result, the callback stops training early and, because restore_best_weights=True, restores the model weights from the end of the best epoch (Epoch 2), when the validation loss was at its lowest.
  • The message “Epoch 7: early stopping” is displayed to indicate that early stopping was applied.
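
The stopping point can be checked by hand by replaying the logged val_loss values through the patience rule. The sketch below is a simplified re-creation of that rule, assuming the default `min_delta` of 0 and a metric where lower is better. `replay_early_stopping` is a hypothetical helper written for this check, not part of Keras.

```python
def replay_early_stopping(val_losses, patience):
    """Replay a list of per-epoch validation losses through a patience rule.

    Returns (stop_epoch, best_epoch), both 1-based. stop_epoch is None
    if the patience limit is never reached.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# val_loss values copied from the training output above (Epochs 1 to 7)
logged_val_loss = [0.6938, 0.6926, 0.6932, 0.6935, 0.6933, 0.6931, 0.6946]

stop_epoch, best_epoch = replay_early_stopping(logged_val_loss, patience=5)
print(stop_epoch, best_epoch)  # 7 2
```

The replay agrees with the log: the best value occurs at Epoch 2, and the fifth epoch without improvement is Epoch 7. With patience=3, the same sequence would have stopped at Epoch 5 instead.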