LayerNorm and RMS Norm in Transformer Models


Normalization layers are crucial components in transformer models that help stabilize training. Without normalization, models often fail to converge or behave poorly. This post explores LayerNorm, RMS Norm, and their variations, explaining how they work and their implementations in modern language models.

Let’s get started.

Photo by Redd Francisco. Some rights reserved.

Overview

This post is divided into five parts; they are:

  • Why Normalization is Needed in Transformers
  • LayerNorm and Its Implementation
  • Adaptive LayerNorm
  • RMS Norm and Its Implementation
  • Using PyTorch’s Built-in Normalization

Why Normalization is Needed in Transformers

Normalization layers improve model quality in deep learning. Convolutional models typically use batch normalization after convolution layers, while transformer models interleave normalization with attention and feed-forward components.

Normalization is important for several reasons:

  1. Internal Covariate Shift: As data flows through the network, activation distributions change significantly between training steps, making training unstable and requiring careful learning rate tuning. Normalization realigns activation distributions so that updates to one layer don’t drastically affect the next layer’s function.
  2. Gradient Issues: Deep networks suffer from vanishing gradients because saturating activation functions have useful slopes near zero but flatten out at extreme values, producing near-zero gradients in those regions. Vanishing gradients stall training, so it helps to shift activations back toward zero.
  3. Faster Convergence: Normalization keeps gradients within reasonable bounds, making gradient descent more effective and enabling faster convergence. Additionally, normalized values cluster around zero, creating a smaller search space that accelerates finding optimal parameters during training.

Transformer models typically have many layers. For example, the Llama 3 8B model has 32 decoder blocks, each containing one attention layer and three feed-forward layers connected sequentially. This structure makes good gradient flow essential, achieved by strategically placing normalization layers.

LayerNorm and RMSNorm are the two most common normalization techniques in modern transformers. They differ in how they compute normalization statistics. The sections below describe them in detail.

LayerNorm and Its Implementation

Layer norm, like batch norm, instance norm, or group norm, performs shift and scale operations on input tensors:

$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

The small quantity $\epsilon$ prevents division by zero. Mean $\mu$ and variance $\sigma^2$ are computed from input data across the feature dimension. Here’s the implementation:
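What follows is a minimal sketch rather than a canonical version: the class name `SimpleLayerNorm`, the tensor shapes, and the printed checks are illustrative choices.

```python
import torch
import torch.nn as nn

class SimpleLayerNorm(nn.Module):
    """Layer normalization with no learnable parameters."""
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # mean and biased variance over the last (feature) dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps)

# quick check on a random (batch, seq_len, features) tensor
x = torch.randn(2, 4, 8) * 3 + 5        # arbitrary scale and offset
y = SimpleLayerNorm()(x)
print(y.mean(dim=-1))                   # close to 0 for every position
print(y.var(dim=-1, unbiased=False))    # close to 1 for every position
```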

LayerNorm computes the variance without bias correction: $\sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2$. You could use the unbiased estimate instead, but the biased one is the conventional choice. This simple implementation has no learnable parameters: it only shifts and scales the input tensor.

When you run this code, you should see that the mean of each output vector is close to zero and its variance is close to one, indicating proper normalization; the exact numbers depend on the random input.

The output tensor retains all information but distributes values in a range more suitable for neural network operations. LayerNorm applies independently to each element in the sequence, normalizing over the entire feature vector.

You might wonder whether zero mean and unit variance is really what you want in the output. The answer is: not necessarily. Most LayerNorm implementations actually compute the following:

$$
y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$

where $\gamma$ and $\beta$ are learnable parameters applied independently to each vector element. Here’s the modified implementation:
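Again, this is a minimal sketch under the same assumptions as before; the scale and shift are stored as `nn.Parameter` tensors so that the optimizer can update them during training.

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """Layer normalization with learnable scale (gamma) and shift (beta)."""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))    # scale, initialized to 1
        self.beta = nn.Parameter(torch.zeros(dim))    # shift, initialized to 0

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

x = torch.randn(2, 4, 8)
layer = LayerNorm(8)       # dim must match the size of the feature dimension
print(layer(x).shape)      # torch.Size([2, 4, 8])
```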

Since $\gamma$ and $\beta$ apply to each vector, they must match the vector shape. You specify the vector length when creating the LayerNorm module, with parameters initialized to 1 and 0, respectively. During training, these parameters adjust to optimize output for the next layer.

Adaptive LayerNorm

The $\gamma$ and $\beta$ parameters in the previous section are learnable, but sometimes you want them to be adaptive to the input $x$ instead of using the same value for all inputs. Adaptive LayerNorm, introduced by Xu et al. in 2019, implements this idea. While not common in language models, it’s popular in other architectures like diffusion models.

In equation form, the adaptive layer norm from the original paper is:

$$
y = C (1 - kx) \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

where $C$ is a hyperparameter and $k$ is fixed at 0.1. The $(1-kx)$ multiplication is elementwise. Other variations exist, but the core idea is making scale and shift parameters functions of input data. A popular implementation uses linear layers to compute these parameters:
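The details differ between codebases, so the sketch below is only one illustrative possibility: a single zero-initialized linear layer produces the scale and shift (so the module starts out as a plain layer norm, a convention borrowed from diffusion transformers), conditioned here on the input itself through an optional `cond` argument. The class name and the `(1 + gamma)` scaling are assumptions for illustration, not the paper's exact formulation.

```python
import torch
import torch.nn as nn

class AdaptiveLayerNorm(nn.Module):
    """Layer norm whose scale and shift are computed from the input by a linear layer."""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.to_scale_shift = nn.Linear(dim, 2 * dim)
        # zero-init so the module starts out as a plain layer norm
        nn.init.zeros_(self.to_scale_shift.weight)
        nn.init.zeros_(self.to_scale_shift.bias)

    def forward(self, x, cond=None):
        # condition on the input itself if no separate conditioning vector is given
        cond = x if cond is None else cond
        gamma, beta = self.to_scale_shift(cond).chunk(2, dim=-1)
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return (1 + gamma) * x_hat + beta

x = torch.randn(2, 4, 8)
print(AdaptiveLayerNorm(8)(x).shape)   # torch.Size([2, 4, 8])
```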

RMS Norm and Its Implementation

Most recent transformer models use RMS Norm instead of LayerNorm. The key difference is that RMS Norm only scales the input without shifting it. The mathematical formulation is:

$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}}$$

where $x$ is a vector of dimension $d$. The denominator computes the root mean squared value of vector elements. The small quantity $\epsilon$ prevents division by zero, and $\gamma$ is a learnable vector for elementwise multiplication.

Compared to LayerNorm, RMS Norm requires fewer calculations and has a smaller memory footprint. Here’s the implementation:
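As before, this is a minimal sketch; the class name and the choice of `eps` are illustrative. Note that there is no mean subtraction and no shift parameter.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMS normalization: scale only, no mean subtraction."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))   # learnable scale

    def forward(self, x):
        # root mean square over the feature dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * x / rms

x = torch.randn(2, 4, 8)
print(RMSNorm(8)(x).shape)   # torch.Size([2, 4, 8])
```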

RMS Norm may not perform as well as LayerNorm in some cases because it doesn’t center activations around zero. However, it’s less sensitive to outliers since it doesn’t subtract the mean. Choosing between RMS Norm and LayerNorm is ultimately a design decision for transformer models.

Using PyTorch’s Built-in Normalization

While understanding how to implement normalization from scratch is valuable, you should use PyTorch’s built-in modules for better performance in practice.

PyTorch’s LayerNorm includes scale and shift parameters, while RMSNorm has only the scale parameter. Here’s how to use them:
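The snippet below is a small usage example; the tensor shapes are arbitrary, and `nn.RMSNorm` requires a reasonably recent PyTorch release (2.4 or later).

```python
import torch
import torch.nn as nn

x = torch.randn(2, 4, 8)       # (batch, sequence, features)

layer_norm = nn.LayerNorm(8)   # normalizes over the last dimension of size 8
rms_norm = nn.RMSNorm(8)       # available in PyTorch 2.4 and later

print(layer_norm(x).shape)     # torch.Size([2, 4, 8])
print(rms_norm(x).shape)       # torch.Size([2, 4, 8])
```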

You can verify that each module has learnable parameters:
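For instance, listing the named parameters of each module shows `weight` (the scale $\gamma$) and `bias` (the shift $\beta$) for LayerNorm, but only `weight` for RMSNorm:

```python
import torch.nn as nn

layer_norm = nn.LayerNorm(8)   # same modules as above
rms_norm = nn.RMSNorm(8)

# nn.LayerNorm exposes weight (gamma) and bias (beta); nn.RMSNorm exposes only weight
for name, param in layer_norm.named_parameters():
    print("LayerNorm:", name, tuple(param.shape))
for name, param in rms_norm.named_parameters():
    print("RMSNorm:", name, tuple(param.shape))
```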

Further Readings

Below are some resources that you may find useful:

  • Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. “Layer Normalization” (2016)
  • Biao Zhang and Rico Sennrich. “Root Mean Square Layer Normalization” (NeurIPS 2019)
  • Jingjing Xu, Xu Sun, Zhiyuan Zhang, Guangxiang Zhao, and Junyang Lin. “Understanding and Improving Layer Normalization” (NeurIPS 2019)

Summary

In this post, you learned about normalization techniques in transformer models. Specifically, you learned about:

  • Why normalization is necessary for training stability in deep networks
  • How LayerNorm and RMS Norm work and their mathematical formulations
  • How to implement these normalization techniques from scratch
  • How to use PyTorch’s built-in normalization layers

Normalization is a fundamental component that enables the training of deep transformer models. Understanding these techniques helps in designing more stable and efficient architectures.
