Normalization layers are crucial components in transformer models that help stabilize training. Without normalization, models often fail to converge or behave poorly. This post explores LayerNorm, RMS Norm, and their variations, explaining how they work and their implementations in modern language models.
Let’s get started.
Overview
This post is divided into five parts; they are:
- Why Normalization is Needed in Transformers
- LayerNorm and Its Implementation
- Adaptive LayerNorm
- RMS Norm and Its Implementation
- Using PyTorch’s Built-in Normalization
Why Normalization is Needed in Transformers
Normalization layers improve model quality in deep learning. Convolutional models typically use batch normalization after convolution layers, while transformer models interleave normalization with attention and feed-forward components.
Normalization is important for several reasons:
- Internal Covariate Shift: As data flows through the network, activation distributions change significantly between training steps, making training unstable and requiring careful learning rate tuning. Normalization realigns activation distributions so that updates to one layer don’t drastically affect the next layer’s function.
- Gradient Issues: Deep networks suffer from vanishing gradients because activation functions vary greatly near zero but remain flat at extreme values, resulting in zero gradients in those regions. Vanishing gradients prevent further training, making it essential to shift activations back toward zero.
- Faster Convergence: Normalization keeps gradients within reasonable bounds, making gradient descent more effective and enabling faster convergence. Additionally, normalized values cluster around zero, creating a smaller search space that accelerates finding optimal parameters during training.
Transformer models typically have many layers. For example, the Llama 3 8B model has 32 decoder blocks, each containing one attention layer and three feed-forward layers connected sequentially. This structure makes good gradient flow essential, achieved by strategically placing normalization layers.
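To make the placement concrete, below is a minimal sketch of a pre-norm decoder block, written with generic PyTorch building blocks purely for illustration (the head count, feed-forward size, and class names here are assumptions, not any specific model's configuration). Each sub-layer normalizes its input first and then adds the result back through a residual connection, which helps gradients flow through many stacked blocks:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Illustrative pre-norm decoder block: norm -> sub-layer -> residual add."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        # Normalize before the attention sub-layer, then add the residual
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # Normalize before the feed-forward sub-layer, then add the residual
        x = x + self.ffn(self.norm2(x))
        return x

x = torch.randn(2, 5, 64)
print(PreNormBlock(64)(x).shape)  # torch.Size([2, 5, 64])
```

Actual models differ in the details. Llama 3, for instance, uses RMS Norm rather than LayerNorm and a gated feed-forward network, but the pre-norm placement of the normalization layers follows the same pattern.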
LayerNorm and RMSNorm are the two most common normalization techniques in modern transformers. They differ in how they compute normalization statistics. The sections below describe them in detail.
LayerNorm and Its Implementation
Layer norm, like batch norm, instance norm, or group norm, performs shift and scale operations on input tensors:
$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$
The small quantity $\epsilon$ prevents division by zero. Mean $\mu$ and variance $\sigma^2$ are computed from input data across the feature dimension. Here’s the implementation:
```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Calculate mean and variance across the last dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return x_norm

# Example usage
batch_size, seq_len, hidden_dim = 2, 5, 128
x = torch.randn(batch_size, seq_len, hidden_dim)
layer_norm = LayerNorm()
output = layer_norm(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output mean:\n{output.mean(axis=2)}")
print(f"Output std:\n{output.std(axis=2, correction=0)}")
```
LayerNorm computes variance without bias correction: $\sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2$. While you could use the unbiased estimate, this is the conventional implementation. The simple implementation above has no learnable parameters: it only shifts and scales the input tensor. Running this code produces output with a mean close to zero and a variance of 1, indicating proper normalization.
When you run this code, you may get the following output:
```
Input shape: torch.Size([2, 5, 128])
Output shape: torch.Size([2, 5, 128])
Output mean:
tensor([[-1.8626e-09,  2.4214e-08, -3.7253e-09, -9.3132e-09,  1.4901e-08],
        [-1.4901e-08, -1.2107e-08,  1.4901e-08, -7.4506e-09,  2.2352e-08]])
Output std:
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]])
```
The output tensor retains all information but distributes values in a range more suitable for neural network operations. LayerNorm applies independently to each element in the sequence, normalizing over the entire feature vector.
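If you want to confirm that `unbiased=False` corresponds to the divide-by-$n$ variance formula given above, a quick check on a random tensor (purely illustrative) is:

```python
import torch

x = torch.randn(2, 5, 128)
mu = x.mean(dim=-1, keepdim=True)
# Biased variance: divide by n rather than n - 1
var_manual = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
var_torch = x.var(dim=-1, keepdim=True, unbiased=False)
print(torch.allclose(var_manual, var_torch))  # True
```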
You might wonder why we want zero mean and unit variance output. The answer is: not necessarily. Most LayerNorm implementations perform this:
$$
y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$
where $\gamma$ and $\beta$ are learnable parameters applied independently to each vector element. Here’s the modified implementation:
```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return x_norm * self.weight + self.bias

# Example usage
batch_size, seq_len, hidden_dim = 2, 5, 128
x = torch.randn(batch_size, seq_len, hidden_dim)
layer_norm = LayerNorm(hidden_dim)
output = layer_norm(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output mean:\n{output.mean(axis=2)}")
print(f"Output std:\n{output.std(axis=2, correction=0)}")
```
Since $\gamma$ and $\beta$ apply to each vector, they must match the vector shape. You specify the vector length when creating the LayerNorm module, with parameters initialized to 1 and 0, respectively. During training, these parameters adjust to optimize output for the next layer.
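As a quick sanity check, you can verify that the parameters match the feature dimension and that, at initialization ($\gamma = 1$, $\beta = 0$), the module still produces zero-mean, unit-variance output. This sketch assumes the `LayerNorm` class defined above is in scope:

```python
layer_norm = LayerNorm(128)     # the class defined above
print(layer_norm.weight.shape)  # torch.Size([128])
print(layer_norm.bias.shape)    # torch.Size([128])

x = torch.randn(2, 5, 128)
out = layer_norm(x)
print(out.mean(dim=-1).abs().max())          # close to 0
print(out.std(dim=-1, correction=0).mean())  # close to 1
```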
Adaptive LayerNorm
The $\gamma$ and $\beta$ parameters in the previous section are learnable, but sometimes you want them to be adaptive to the input $x$ instead of using the same value for all inputs. Adaptive LayerNorm, introduced by Xu et al. in 2019, implements this idea. While not common in language models, it’s popular in other architectures like diffusion models.
In equation form, the adaptive layer norm from the original paper is:
$$
y = C (1 - k\hat{x}) \odot \hat{x}, \qquad \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$
where $C$ is a hyperparameter and $k$ is fixed at 0.1. The $(1 - k\hat{x})$ multiplication is elementwise, so the scaling applied to each element depends on the normalized input itself. Other variations exist, but the core idea is making the scale and shift parameters functions of the input data. A popular implementation uses linear layers to compute these parameters:
```python
import torch
import torch.nn as nn

class AdaptiveLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps

        # Adaptive parameters
        self.ada_weight = nn.Linear(dim, dim)
        self.ada_bias = nn.Linear(dim, dim)

    def forward(self, x):
        # Standard LayerNorm
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # Adaptive scaling and shifting
        ada_w = self.ada_weight(x)
        ada_b = self.ada_bias(x)

        return x_norm * ada_w + ada_b

# Example usage
batch_size, seq_len, hidden_dim = 2, 5, 8
x = torch.randn(batch_size, seq_len, hidden_dim)

ada_ln = AdaptiveLayerNorm(hidden_dim)
output = ada_ln(x)
```
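For comparison, here is a minimal sketch of the formula from the original paper rather than the linear-layer variant. The class name `AdaNorm` and the default values of `C` and `k` are chosen here just for illustration:

```python
import torch
import torch.nn as nn

class AdaNorm(nn.Module):
    """Sketch of the AdaNorm formula: y = C * (1 - k * x_hat) * x_hat."""
    def __init__(self, C=1.0, k=0.1, eps=1e-5):
        super().__init__()
        self.C = C
        self.k = k
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        # Scale each normalized element by C * (1 - k * x_hat)
        return self.C * (1 - self.k * x_hat) * x_hat

x = torch.randn(2, 5, 8)
print(AdaNorm()(x).shape)  # torch.Size([2, 5, 8])
```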
RMS Norm and Its Implementation
Most recent transformer models use RMS Norm instead of LayerNorm. The key difference is that RMS Norm only scales the input without shifting it. The mathematical formulation is:
$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}}$$
where $x$ is a vector of dimension $d$. The denominator computes the root mean squared value of vector elements. The small quantity $\epsilon$ prevents division by zero, and $\gamma$ is a learnable vector for elementwise multiplication.
Compared to LayerNorm, RMS Norm requires fewer calculations and has a smaller memory footprint. Here’s the implementation:
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Reciprocal of the RMS across the last dimension
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        # Normalize and scale
        x_norm = x * rms * self.weight
        return x_norm

# Example usage
batch_size, seq_len, hidden_dim = 2, 5, 8
x = torch.randn(batch_size, seq_len, hidden_dim)
rms_norm = RMSNorm(hidden_dim)
output = rms_norm(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output RMS: {torch.sqrt((output**2).mean(axis=2))}")
```
RMS Norm may not perform as well as LayerNorm in some cases because it doesn’t center activations around zero. However, it’s less sensitive to outliers since it doesn’t subtract the mean. Choosing between RMS Norm and LayerNorm is ultimately a design decision for transformer models.
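The practical difference is easy to see on inputs with a large mean offset: LayerNorm removes the offset entirely, while RMS Norm only rescales it. Here is a small illustrative check, assuming the `LayerNorm` and `RMSNorm` classes defined earlier are in scope:

```python
import torch

# Inputs with a large positive offset
x = torch.randn(2, 5, 8) + 10.0

ln_out = LayerNorm(8)(x)   # the learnable LayerNorm defined earlier
rms_out = RMSNorm(8)(x)    # the RMSNorm defined earlier

print(ln_out.mean(dim=-1))   # close to 0: the offset is removed
print(rms_out.mean(dim=-1))  # close to 1: the offset is only rescaled
```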
Using PyTorch’s Built-in Normalization
While understanding how to implement normalization from scratch is valuable, you should use PyTorch’s built-in modules for better performance in practice.
PyTorch’s LayerNorm includes scale and shift parameters, while RMSNorm has only the scale parameter. Here’s how to use them:
```python
import torch
import torch.nn as nn

# PyTorch's LayerNorm
batch_size, seq_len, hidden_dim = 2, 5, 8
x = torch.randn(batch_size, seq_len, hidden_dim)

# LayerNorm normalizes over the last dimension
layer_norm = nn.LayerNorm(hidden_dim)
output_ln = layer_norm(x)

# RMSNorm normalizes over the last dimension
rms_norm = nn.RMSNorm(hidden_dim)
output_rms = rms_norm(x)
```
You can verify that each module has learnable parameters:
```python
...

print(layer_norm.weight)  # nn.Parameter
print(layer_norm.bias)    # nn.Parameter
print(rms_norm.weight)    # nn.Parameter
```
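You can also check that the from-scratch implementations in this post agree with the built-in modules. At initialization the learnable parameters are identical (weights of ones, bias of zeros), so the outputs should match up to floating-point tolerance, provided the same `eps` is used. This assumes the custom `LayerNorm` and `RMSNorm` classes defined earlier are in scope:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 5, 8)

# Custom classes defined earlier vs. PyTorch built-ins (eps must match)
custom_ln = LayerNorm(8)             # from-scratch version with weight and bias
builtin_ln = nn.LayerNorm(8)
print(torch.allclose(custom_ln(x), builtin_ln(x), atol=1e-6))  # True

custom_rms = RMSNorm(8)              # from-scratch version, eps=1e-6
builtin_rms = nn.RMSNorm(8, eps=1e-6)
print(torch.allclose(custom_rms(x), builtin_rms(x), atol=1e-6))  # True
```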
Summary
In this post, you learned about normalization techniques in transformer models. Specifically, you learned about:
- Why normalization is necessary for training stability in deep networks
- How LayerNorm and RMS Norm work and their mathematical formulations
- How to implement these normalization techniques from scratch
- How to use PyTorch’s built-in normalization layers
Normalization is a fundamental component that enables the training of deep transformer models. Understanding these techniques helps in designing more stable and efficient architectures.