Weight Decay is Not L2 Regularization

When training neural networks, the choice and configuration of optimizers can make or break your results. A particularly subtle pitfall is that PyTorch’s weight_decay parameter on many adaptive optimizers—like Adam or RMSprop—actually applies L2 regularization rather than true weight decay. With vanilla stochastic gradient descent (SGD) the distinction is largely academic, but when you’re using adaptive methods it can lead to noticeably worse generalization if you’re not careful.

We once treated weight decay and L2 regularization as interchangeable terms in machine learning. In fact, this is exactly what I was taught over 10 years ago. Even today, Google’s AI search results still equate one with the other, saying, “weight decay, also known as L2 regularization…” In this post I’ll unpack why weight decay and L2 regularization were historically considered the same, explain why the two diverge under adaptive optimizers, and show how to get true weight decay behavior in PyTorch.

Neural Network Optimization Basics

Before we get into weight decay and L2 regularization, let’s briefly review how neural networks are trained and how weights get updated.

At its core, training a neural network is all about adjusting the model’s weights (or parameters, which I’ll use interchangeably) to minimize a loss function. The loss function tells us how badly our model is performing on the training data. We use an optimizer, like SGD, to find the weights that result in the lowest loss. In its simplest form, the weight update rule for SGD at each time step t looks like this:

\[\Theta_{t+1} = \Theta_t - \alpha \cdot \nabla L(\Theta_t)\]

Here, $\Theta$ is a vector representing the weights of our model. At each step, we calculate the gradient of the loss function with respect to the weights, $\nabla L(\Theta_t)$, and take a small step in the opposite direction. The scalar $\alpha$ (learning rate) controls how big of a step we take. This process is repeated thousands or millions of times until the model converges.
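
As a concrete sketch, here is what a single SGD step looks like in PyTorch when performed by hand on a toy loss (in practice torch.optim.SGD does this for you):

import torch

# One SGD step on a toy quadratic loss, purely for illustration
alpha = 0.1                                  # learning rate
theta = torch.randn(5, requires_grad=True)   # weights, Theta_t

loss = (theta ** 2).sum()                    # stand-in for L(Theta_t)
loss.backward()                              # computes grad L(Theta_t) into theta.grad

with torch.no_grad():
    theta -= alpha * theta.grad              # Theta_{t+1} = Theta_t - alpha * grad
    theta.grad.zero_()                       # clear the gradient for the next step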

Understanding L2 Regularization

With that context, let’s talk about L2 regularization. L2 regularization is a technique to prevent overfitting, which is when a model learns the training data too well and fails to generalize to new, unseen data. It works by adding a penalty term to the loss function. This new term is the sum of the squares of all the weights, multiplied by a small scalar hyperparameter $\lambda$:

\[L_{total}(\Theta) = L_{data}(\Theta) + \frac{\lambda}{2} \|\Theta\|_2^2\]

The $\lVert\Theta\rVert_2^2$ term is the squared L2 norm of the weight vector, $\Theta$, which is simply the sum of the squares of all its elements. The $\lambda$ parameter controls the strength of the regularization. A larger $\lambda$ means a stronger penalty for large weights, which encourages the model to keep its weights small. The intuition behind L2 regularization is that by keeping model weights small, no individual weight can exert disproportionate influence on the output, thereby reducing the model’s tendency to overfit.
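
In code, the penalty is simply added to whatever data loss you already compute. Here is a minimal sketch with a toy data loss standing in for the real one:

import torch

lam = 0.01                                   # regularization strength, lambda
theta = torch.randn(5, requires_grad=True)   # model weights

data_loss = (theta - 1.0).pow(2).mean()      # stand-in for L_data(Theta)
l2_penalty = 0.5 * lam * theta.pow(2).sum()  # (lambda / 2) * ||Theta||_2^2
total_loss = data_loss + l2_penalty          # L_total(Theta)
total_loss.backward()                        # the gradient now includes lambda * Theta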

What is Weight Decay?

Weight decay, on the other hand, seeks to reduce the risk of overfitting by directly nudging weights toward zero. It does this by applying a small multiplicative shrinkage factor to the weights at each update step. Here is what it looks like with SGD:

Decay step: multiply each weight by the factor $(1-\alpha\lambda)$ (which will be a bit less than 1):

\[\Theta_{t}' = (1-\alpha\lambda)\Theta_t\]

Optimizer step: take the normal optimizer update (SGD in this case) using the original gradient:

\[\Theta_{t+1} = \Theta_{t}' - \alpha \nabla L_{\text{data}}(\Theta_t)\]

Here $\alpha$ is the learning rate and $\lambda$ tunes how aggressively the weights shrink (a larger $\lambda$ means faster decay).
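
Here is a minimal sketch of those two steps performed by hand with SGD (toy loss; normally the optimizer handles the decay for you):

import torch

alpha, lam = 0.1, 0.01                       # learning rate and decay strength
theta = torch.randn(5, requires_grad=True)   # weights, Theta_t

loss = (theta - 1.0).pow(2).mean()           # stand-in for L_data(Theta_t)
loss.backward()                              # gradient is taken at the original weights

with torch.no_grad():
    theta *= (1 - alpha * lam)               # decay step: uniform multiplicative shrinkage
    theta -= alpha * theta.grad              # optimizer step: plain SGD on the data gradient
    theta.grad.zero_()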

Why Were They Considered the Same?

So, why did we use L2 regularization and weight decay interchangeably in the past? Let’s see how L2 regularization impacts the weight update step when using SGD. First, we need to find the gradient of our new total loss with respect to the weights:

\[\begin{align*} L_{total}(\Theta) &= L_{data}(\Theta) + \frac{\lambda}{2} \|\Theta\|_2^2 \\ \nabla L_{total}(\Theta) &= \nabla L_{data}(\Theta) + \lambda \Theta \end{align*}\]

Now, let’s plug this into our SGD update rule:

\[\begin{align*} \Theta_{t+1} & = \Theta_t - \alpha \cdot (\nabla L_{data}(\Theta_t) + \lambda \Theta_t) \\ & = \Theta_t - \alpha \cdot \nabla L_{data}(\Theta_t) - \alpha \lambda \Theta_t \\ & = (1 - \alpha \lambda) \Theta_t - \alpha \cdot \nabla L_{data}(\Theta_t) \end{align*}\]

Look at that: the update steps are identical! This is why L2 regularization and weight decay were historically treated as the same thing, at least in the context of plain SGD. In fact, many deep learning libraries implemented weight decay as L2 regularization in their optimizers, which further reinforced the misconception.
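
You can check the equivalence numerically with a quick sketch; both formulations produce the same updated weights for a single SGD step:

import torch

alpha, lam = 0.1, 0.01
theta = torch.randn(5)
grad_data = torch.randn(5)                   # stand-in for grad L_data(Theta_t)

# L2 regularization: add lambda * Theta to the gradient, then take an SGD step
theta_l2 = theta - alpha * (grad_data + lam * theta)

# Weight decay: shrink the weights first, then take an SGD step on the data gradient
theta_wd = (1 - alpha * lam) * theta - alpha * grad_data

print(torch.allclose(theta_l2, theta_wd))    # True: identical under plain SGD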

If the math works out the same, then what’s the problem? Issues arise when we introduce adaptive optimizers like RMSprop and Adam. These optimizers adjust the learning rate for each parameter based on its historical gradients. This is where the equivalence breaks down. The paper Decoupled Weight Decay Regularization (Loshchilov & Hutter, 2017) explains this in detail. They focus on Adam, but I’ll use RMSprop as an example since it’s a bit simpler to grasp.

In SGD, weights are updated by subtracting the gradient of the loss function scaled by a fixed learning rate. However, gradients can be noisy, and using the same learning rate for all parameters may lead to inefficient or unstable updates. RMSprop improves upon this by maintaining an exponentially decaying average of the squared gradients for each parameter. This moving average is used to scale the gradient, effectively adapting the learning rate per parameter. As a result, RMSprop helps stabilize training and improve convergence, especially in settings with noisy or sparse gradients. The update step for RMSprop looks like this:

\[\begin{align*} V_t &= \beta V_{t-1} + (1 - \beta) (\nabla L(\Theta_t))^2 \\ \Theta_{t+1} &= \Theta_t - \frac{\alpha}{\sqrt{V_t} + \epsilon} \nabla L(\Theta_t) \end{align*}\]

The vector $V_t$ keeps a moving average of the squared gradients, where $\nabla L(\Theta_t)$ is the gradient at the current step t. The scalars $\beta$ and $\epsilon$ are hyperparameters, with $\epsilon$ being a small number to prevent division by zero. The key part is that the learning rate $\alpha$ is divided by the square root of $V_t$, effectively giving each parameter its own learning rate.
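
As a sketch, a single RMSprop step written out by hand looks like this (toy gradient standing in for the real one):

import torch

alpha, beta, eps = 0.01, 0.9, 1e-8           # learning rate and RMSprop hyperparameters
theta = torch.randn(5)
V = torch.zeros(5)                           # moving average of squared gradients

grad = torch.randn(5)                        # stand-in for grad L(Theta_t)
V = beta * V + (1 - beta) * grad ** 2        # update the moving average
theta = theta - alpha / (V.sqrt() + eps) * grad   # per-parameter scaled step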

Now let’s look at what happens if we add L2 regularization to our cost function and use RMSprop:

\[\begin{align*} L_{total}(\Theta_t) & = L_{data}(\Theta_t) + \frac{\lambda}{2} \|\Theta_t\|_2^2 \\ \nabla L_{total}(\Theta_t) & = \nabla L_{data}(\Theta_t) + \lambda \Theta_t \end{align*}\]

When we plug this into the RMSprop update rule, we get:

\[\begin{align*} V_t &= \beta V_{t-1} + (1 - \beta) (\nabla L_{total}(\Theta_t))^2 \\ \Theta_{t+1} &= \Theta_t - \frac{\alpha}{\sqrt{V_t} + \epsilon} (\nabla L_{data}(\Theta_t) + \lambda \Theta_t) \\ & = \Theta_t - \frac{\alpha}{\sqrt{V_t} + \epsilon} \nabla L_{data}(\Theta_t) - \frac{\alpha }{\sqrt{V_t} + \epsilon}\lambda \Theta_t \\ & = (1 - \frac{\alpha \lambda}{\sqrt{V_t} + \epsilon}) \Theta_t - \frac{\alpha}{\sqrt{V_t} + \epsilon} \nabla L_{data}(\Theta_t) \end{align*}\]

Notice the problem? The uniform shrinkage factor $(1 - \alpha \lambda)$ we had under SGD has become $(1 - \frac{\alpha \lambda}{\sqrt{V_t} + \epsilon})$, so the effective weight decay now differs per parameter and changes over time. Parameters with large historical gradients receive less decay, while parameters with small historical gradients receive more. That is not the uniform shrinkage we expect from true weight decay, and in practice it turns out not to be very helpful either.
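
A quick sketch makes the non-uniformity concrete: with toy values for $V_t$, a parameter with a small gradient history is shrunk far harder than one with a large history:

import torch

alpha, lam, eps = 0.01, 0.1, 1e-8
V = torch.tensor([1e-4, 1.0, 100.0])         # per-parameter squared-gradient averages

decay_factor = 1 - alpha * lam / (V.sqrt() + eps)
print(decay_factor)                          # ~0.90, ~0.999, ~0.9999: anything but uniform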

The Solution: Decoupled Weight Decay

The solution proposed by Loshchilov and Hutter is pretty straightforward. Instead of using L2 regularization in the loss function to get weight decay, just apply the weight decay directly to the weights in the update step. This way, we ensure that the weight decay is applied uniformly across all parameters, regardless of the optimizer used. The update step for RMSprop with decoupled weight decay looks like this:

\[\begin{align*} \Theta_{t}' &= (1-\alpha\lambda)\Theta_t \\ V_t &= \beta V_{t-1} + (1 - \beta) (\nabla L_{data}(\Theta_t))^2 \\ \Theta_{t+1} &= \Theta_{t}' - \frac{\alpha}{\sqrt{V_t} + \epsilon} \nabla L_{data}(\Theta_t) \end{align*}\]
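
Written out by hand, the decoupled version simply moves the shrinkage outside the adaptive scaling (toy gradient again):

import torch

alpha, lam, beta, eps = 0.01, 0.1, 0.9, 1e-8
theta = torch.randn(5)
V = torch.zeros(5)                           # moving average of squared gradients

grad = torch.randn(5)                        # stand-in for grad L_data(Theta_t)
theta = (1 - alpha * lam) * theta            # decay step: the same shrinkage for every parameter
V = beta * V + (1 - beta) * grad ** 2
theta = theta - alpha / (V.sqrt() + eps) * grad   # adaptive step on the data gradient only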

Proper Weight Decay in PyTorch

Great! Now that we know weight decay and L2 regularization are not the same thing, and weight decay is probably the thing we actually want, how do we use it in PyTorch? You might think that using the weight_decay parameter in PyTorch’s optimizers like Adam or RMSprop would do the trick, but that’s not always the case. Even though this issue was thoroughly worked out back in 2017, PyTorch still preserves this bug in many instances. Luckily, the docs for each optimizer will tell you how the weight_decay parameter is implemented. For example, in the Adam optimizer, weight_decay is implemented as L2 regularization (probably not what you want). If you want to use decoupled weight decay, you need to use the AdamW optimizer or pass the decoupled_weight_decay argument to the Adam optimizer:

import torch
import torch.optim as optim

# Using AdamW with decoupled weight decay
model = ...  # Your model here
optimizer = optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01
)

# Alternatively, using Adam with decoupled weight decay
model = ...  # Your model here
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01,
    decoupled_weight_decay=True
)

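Once constructed, the optimizer is used like any other PyTorch optimizer. Here is a minimal training-step sketch with a toy model and random data standing in for the real thing:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)                     # toy model, purely for illustration
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()                             # applies the decoupled decay and the Adam update
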
As far as I’m aware, PyTorch’s built-in RMSprop optimizer does not support decoupled weight decay. If you have your heart set on RMSprop, you can instead use AdamW and set the momentum coefficient (the first element of the betas parameter tuple) to 0, which essentially gives you RMSprop with decoupled weight decay:

import torch
import torch.optim as optim

# Using AdamW with momentum set to 0 for RMSprop-like behavior
model = ...  # Your model here
optimizer = optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01,
    betas=(0, 0.99)
)
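
Note that this is only an approximation: AdamW bias-corrects its moving average of squared gradients while PyTorch’s RMSprop does not, so the two optimizers won’t match step for step, but you still get per-parameter adaptive scaling together with properly decoupled weight decay.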

Conclusion: Getting Weight Decay Right Matters

In the world of deep learning, subtle implementation details can make a significant difference in model performance. The difference between weight decay and L2 regularization is one such detail that’s easy to overlook but can be crucial when using adaptive optimizers.

To summarize the key points:

  1. With standard SGD, weight decay and L2 regularization are mathematically equivalent.
  2. With adaptive optimizers like Adam and RMSprop, they diverge significantly.
  3. True weight decay applies a uniform shrinkage to all parameters, while L2 regularization with adaptive optimizers leads to parameter-specific decay that varies with gradient history. See Decoupled Weight Decay Regularization (Loshchilov & Hutter, 2017) for more details.
  4. In PyTorch, use AdamW or set decoupled_weight_decay=True with Adam to use proper weight decay.

When using adaptive optimizers like Adam, prefer decoupled weight decay (e.g. AdamW) over L2 regularization to ensure consistent parameter shrinkage.

This post is licensed under CC BY 4.0 by the author.