How to Freeze Layers in Keras

Freezing layers in a Keras model means marking certain layers as non-trainable so their weights are not updated during training. This is useful when you have a pre-trained model and want to keep some parts fixed while adapting others. Here’s how to do it:

Working Example:

Before Freezing:

# Import necessary libraries
import keras
from keras.models import Sequential
from keras.layers import Dense

# Create a simple Keras model (for demonstration)
model = Sequential()

# Add layers to the model
model.add(Dense(units=64, activation='relu', input_dim=100))
model.add(Dense(units=32, activation='relu'))
model.add(Dense(units=10, activation='softmax'))

# Print the model summary before freezing layers
print("Model summary before freezing layers:")
model.summary()
Model summary before freezing layers:
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_18 (Dense)            (None, 64)                6464      
                                                                 
 dense_19 (Dense)            (None, 32)                2080      
                                                                 
 dense_20 (Dense)            (None, 10)                330       
                                                                 
=================================================================
Total params: 8,874
Trainable params: 8,874
Non-trainable params: 0
_________________________________________________________________

After Freezing:

# Choose the index of the last layer you want to freeze
last_frozen_layer_index = 1  # freeze layers 0 and 1, i.e. the first two layers

# Iterate through the layers and set them to be non-trainable (frozen)
for layer in model.layers[:last_frozen_layer_index + 1]:
    layer.trainable = False

# Compile the model (changes to trainable take effect when the model is compiled)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Print the model summary after freezing layers
print("\nModel summary after freezing layers:")
model.summary()
Model summary after freezing layers:
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_18 (Dense)            (None, 64)                6464      
                                                                 
 dense_19 (Dense)            (None, 32)                2080      
                                                                 
 dense_20 (Dense)            (None, 10)                330       
                                                                 
=================================================================
Total params: 8,874
Trainable params: 330
Non-trainable params: 8,544
_________________________________________________________________

In this example:

  1. We create a simple Keras Sequential model with three layers.
  2. Before freezing any layers, we print the model summary to see the initial architecture.
  3. We set the last_frozen_layer_index to 1, which means we’re going to freeze the first two layers.
  4. We iterate through the layers, setting layer.trainable to False for layers up to and including the second layer.
  5. After freezing the layers, we compile the model; changes to a layer’s trainable attribute only take effect once the model is (re)compiled.
  6. Finally, we print the model summary after freezing to confirm the change (a quick per-layer check is sketched right after this list).
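
The aggregate counts in the summary don’t show which individual layers are frozen (recent Keras releases also support model.summary(show_trainable=True)). A quick way to check is to loop over model.layers and print each layer’s trainable flag:

# Confirm which layers are frozen by inspecting each layer's trainable flag
for layer in model.layers:
    print(layer.name, "- trainable:", layer.trainable)

# With the settings above this prints (layer names will vary):
# dense_18 - trainable: False
# dense_19 - trainable: False
# dense_20 - trainable: True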

You’ll notice that the second summary reports only 330 trainable parameters (the output layer), while the 8,544 parameters of the first two Dense layers (6,464 + 2,080) are now non-trainable, confirming that the frozen layers won’t be updated during training.
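
The same approach applies to the pre-trained models mentioned at the start, which is where freezing is most often used in practice. The sketch below is illustrative rather than part of the example above: it assumes keras.applications.MobileNetV2 as the pre-trained base, a 224×224×3 input, and 10 target classes. Here the whole base is frozen in one step via its trainable attribute, and only the new head is trained:

# Load a pre-trained convolutional base without its classification head
from keras.applications import MobileNetV2
from keras.models import Sequential
from keras.layers import GlobalAveragePooling2D, Dense

base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the entire base so its pre-trained weights stay fixed
base_model.trainable = False

# Add a small trainable head on top of the frozen base
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(10, activation='softmax')  # number of classes is illustrative
])

# Compile after changing trainable so the freeze takes effect during training
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])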