Keras Weight Decay Hack


Weight decay, or L2 regularization, is a common regularization method used in training neural networks. The idea is to add a term to the loss that penalizes the magnitude of the weight values in the network, thereby encouraging the weights to shrink during training.
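In pseudo-code, for a weight decay coefficient alpha, the total loss is just the task loss plus alpha times the sum of squared weight values. A minimal sketch (plain NumPy; task_loss and weights are placeholders, not part of any library):

import numpy as np

def total_loss(task_loss, weights, alpha):
    # L2 penalty: alpha times the sum of squared weight values,
    # added on top of the task loss.
    l2_penalty = sum(np.sum(w ** 2) for w in weights)
    return task_loss + alpha * l2_penalty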

Intuitively, there are different ways to think about the benefit of weight decay. One way to think about it is that only weights which consistently affect the main loss term throughout the data iterations will be tuned accordingly. On the other hand, the weights which are just responding to noise and which have little to no consistent effect on the main loss term will instead be overcome by the weight decay term and will steadily decrease toward zero.

Another way to think about it is as a sort of prior. We assume that the model which can perform the task at hand can be represented with small weights, and that the presence of overly large weights signifies an attempt to fit to outliers or noise in the dataset - the dreaded “overfitting” scenario.

Anyhow, Keras has a built-in Regularizer class, and common regularizers, like L1 and L2, can be added to each layer independently. This means that if you want weight decay with coefficient alpha for all the weights in your network, you need to pass an instance of regularizers.l2(alpha) to each layer with weights (typically Conv2D and Dense layers) as you initialize them. See the examples in the Keras docs.
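For example, attaching L2 regularization at construction time looks something like this (a minimal sketch using the standard kernel_regularizer argument; the tiny architecture is just for illustration):

import keras
from keras.layers import Input, Conv2D, Flatten, Dense
from keras.regularizers import l2

alpha = 1e-4  # weight decay coefficient

inputs = Input(shape=(32, 32, 3))
x = Conv2D(64, 3, padding='same', activation='relu',
           kernel_regularizer=l2(alpha))(inputs)  # L2 on the conv kernel
x = Flatten()(x)
outputs = Dense(10, activation='softmax',
                kernel_regularizer=l2(alpha))(x)  # L2 on the dense kernel
model = keras.models.Model(inputs, outputs)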

The way this is set up, however, can be annoying. Firstly, it’s tedious to be adding a regularizer every time you initialize a new layer, especially if your code contains many layer initializations. This can be circumvented somewhat by creating wrapper functions for common “blocks” of the network (e.g. a residual convolutional block), so that the regularizers are hidden inside the wrappers and don’t litter your code too much.
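For instance, a wrapper for a regularized conv block might look like this (a sketch; the function name and block structure are just one possibility):

from keras.layers import Activation, BatchNormalization, Conv2D
from keras.regularizers import l2

ALPHA = 1e-4  # weight decay coefficient shared across the network

def conv_bn_relu(x, filters, kernel_size=3):
    # Convolution + batch norm + ReLU, with the L2 regularizer applied
    # inside the wrapper so it doesn't clutter the model-building code.
    x = Conv2D(filters, kernel_size, padding='same', use_bias=False,
               kernel_regularizer=l2(ALPHA))(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)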

But another common problem arising from this setup comes when you use an out-of-the-box Keras model from another code base, or load a pre-trained model file. As an example, let’s say I want to use a ResNet50 architecture to fit to my data. Luckily, Keras Applications has a function which will return a ResNet50 as a Keras model. But now, what if I want to train that model with weight decay? The layers already exist; they’re initialized inside the keras.applications.ResNet50 function. So how do I add regularizers after the fact?

This calls for a bit of a hack. What we have to do is access the layers we want to regularize and attach the regularizers to them after the fact. We'll loop through model.layers, and if the layer type is Conv2D or Dense, we'll manually call a regularizer on the layer's weights and add the result to that layer's losses. This is essentially the same thing that happens when you initialize a layer with a regularizer to begin with.

import keras

model = keras.applications.ResNet50(include_top=True, weights='imagenet')
alpha = 0.00002  # weight decay coefficient

for layer in model.layers:
    if isinstance(layer, keras.layers.Conv2D) or isinstance(layer, keras.layers.Dense):
        # Attach an L2 penalty on the kernel, just as kernel_regularizer
        # would have done if it had been passed at construction time.
        layer.add_loss(keras.regularizers.l2(alpha)(layer.kernel))
    if hasattr(layer, 'bias_regularizer') and layer.use_bias:
        layer.add_loss(keras.regularizers.l2(alpha)(layer.bias))

Make sure to do this before compiling the model. After that, we’re ready to train!
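For example (a sketch; the optimizer, loss, and training data are placeholders):

from keras.optimizers import SGD

# The penalties added via add_loss() get collected when the model is compiled.
model.compile(optimizer=SGD(lr=0.01, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)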

One more thing, though. If the model you're doing this with contains depthwise convolution layers, and you want to apply weight decay to those layers as well, you need an extra if statement in the above loop, since the variable holding a DepthwiseConv2D layer's conv weights has a different name. You also have to be careful: DepthwiseConv2D inherits from Conv2D, so isinstance(layer, keras.layers.Conv2D) will return True even in the case of a depthwise convolution. Thus, if you want to use a MobileNet, for example, which is also available in Keras Applications, you've got to add the following:

model = keras.applications.MobileNet(include_top=True, weights='imagenet',
                                     alpha=1., depth_multiplier=1)
alpha = 0.00002  # weight decay coefficient (distinct from MobileNet's width multiplier above)

for layer in model.layers:
    if isinstance(layer, keras.layers.DepthwiseConv2D):
        # Depthwise conv weights live in `depthwise_kernel`, not `kernel`.
        layer.add_loss(keras.regularizers.l2(alpha)(layer.depthwise_kernel))
    elif isinstance(layer, keras.layers.Conv2D) or isinstance(layer, keras.layers.Dense):
        layer.add_loss(keras.regularizers.l2(alpha)(layer.kernel))
    if hasattr(layer, 'bias_regularizer') and layer.use_bias:
        layer.add_loss(keras.regularizers.l2(alpha)(layer.bias))

(Note: In the original MobileNet paper, the authors note that “Additionally, we found that it was important to put very little or no weight decay (l2 regularization) on the depthwise filters since their [sic] are so few parameters in them.” Many open-source implementations of MobileNet, though, apply the same weight decay to the depthwise filters as to everything else. More experimentation is required to see how much effect this has in practice.)
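If you do want to follow the paper's suggestion, one option is to use a separate, smaller (or zero) coefficient for the depthwise kernels, e.g. (a sketch building on the loop above; depthwise_alpha is a made-up name):

depthwise_alpha = 0.  # little or no decay on the depthwise filters

for layer in model.layers:
    if isinstance(layer, keras.layers.DepthwiseConv2D):
        if depthwise_alpha > 0:
            layer.add_loss(keras.regularizers.l2(depthwise_alpha)(layer.depthwise_kernel))
    elif isinstance(layer, keras.layers.Conv2D) or isinstance(layer, keras.layers.Dense):
        layer.add_loss(keras.regularizers.l2(alpha)(layer.kernel))
    if hasattr(layer, 'bias_regularizer') and layer.use_bias:
        layer.add_loss(keras.regularizers.l2(alpha)(layer.bias))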
