Language models need to understand relationships between words in a sequence, regardless of their distance. This post explores how attention mechanisms enable this capability and their various implementations in modern language models.
Let’s get started.

A Gentle Introduction to Multi-Head Attention and Grouped-Query Attention
Photo by Ye Min Htet. Some rights reserved.
Overview
This post is divided into four parts; they are:
- Why Attention is Needed
- The Attention Operation
- Multi-Head Attention (MHA)
- Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
Why Attention is Needed
Traditional neural networks struggle with long-range dependencies in sequences. Consider the sentence:
“The animal didn’t cross the road because it was too tired.”
To understand what “it” refers to, the model needs to look back at “animal”, a relationship that spans multiple words.
Another example is translating the sentence, “I want to try on a suit that I saw in a shop that’s across the street from the hotel”. This is a well-known sentence that demonstrates the differences between languages.
If you translate this sentence from English to French, you can roughly match the words one by one in their original order. However, it’s not entirely straightforward, because the word “want” in English can become “veux”, “voulons”, or “voulez” depending on whether the subject is “I”, “we”, or “you”. Therefore, a translation model needs to attend to both “Je” (the French equivalent of “I”, to determine the verb form) and “want” (to determine the verb itself) to produce the correct translation.
Translating into Japanese is even more challenging because Japanese uses subject-object-verb (SOV) word order. When the model sees “I want…”, it needs to wait until near the end of the sentence to determine the object. For example, after the model has produced “私は” (the Japanese equivalent of “I”), the next word to generate is “ホテル” (the Japanese equivalent of “hotel”). The immediately preceding Japanese word doesn’t determine this translation; instead, it corresponds to “hotel”, the final word of the English sentence.
The purpose of the attention mechanism is to help the model focus on the relevant parts of the sequence while ignoring the rest.
The Attention Operation
The attention mechanism was invented to solve the problem of long-range dependencies in translation models. It’s easiest to understand in the translation context.
Let’s consider that we’ve processed the English sentence from the previous section. Now the model is producing French words one by one. After the first word “Je”, the model needs to decide what the second word should be.
First, we define the French words produced so far as the “query” sequence and the processed English sentence as the “key” sequence. The attention operation first computes the attention scores:
$$
\frac{QK^T}{\sqrt{d}}
$$
This is a matrix where element $(i,j)$ is the score of alignment between the $i$-th French word and the $j$-th English word. The higher the score, the more closely the two words are aligned. Alignment doesn’t mean equivalence, but rather indicates what should be the next focus. In this example, the alignment should focus on “want” in the English sentence, which should have the highest attention score. The $\sqrt{d}$ part in the formula above is a constant that scales the attention score. You can ignore its effect for now.
After that, the attention score is normalized by the softmax function:
$$
\text{softmax}\Big(\frac{QK^T}{\sqrt{d}}\Big)
$$
The softmax function normalizes the attention scores so that each row sums to 1. The reason for this becomes clear in the next step, which is to compute the weighted sum of the values:
$$
O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
$$
$V$ is the “value” sequence. In this example, you can think of it as the French translation of each English word in the original sentence. This is just a simple translation without considering any context. In the case of “want”, it might be “vouloir”, not “veux”.
The equation above computes a weighted sum of the values. Because we want the output $O$ as a word, the sum of weights should be 1 to keep the output from becoming too large or too small. Why use a weighted sum instead of picking a single word translation? Because one word in French may be related to multiple words in the English sentence. In this example, you may need 90% of “want” (“vouloir”) and 10% of “I” (“je”) to create the correct verb form “veux” as the next output word in French.
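To make this concrete, here is a tiny numeric sketch of the weighted sum, using made-up value vectors and made-up attention weights (the 90%/10% split above is only for illustration):

```python
import torch

# Hypothetical "value" vectors standing in for the context-free translations
v_want = torch.tensor([1.0, 0.0, 0.0])   # stands in for "vouloir"
v_i    = torch.tensor([0.0, 1.0, 0.0])   # stands in for "je"
values = torch.stack([v_want, v_i])      # shape (2, 3)

# Illustrative attention weights from the softmax; they must sum to 1
weights = torch.tensor([0.9, 0.1])

# The attention output is the weighted sum of the value vectors
output = weights @ values
print(output)                            # tensor([0.9000, 0.1000, 0.0000])
```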
One missing piece in the above explanation is how the words in the English sentence become French words so that you can use them in the “value” sequence. This is indeed produced by a “projection matrix”. In full, the equations are:
$$
\begin{aligned}
Q &= F W^Q \\
K &= E W^K \\
V &= E W^V \\
X_O &= \Big(\text{softmax}\big(\frac{QK^T}{\sqrt{d}}\big)V\Big)W^O
\end{aligned}
$$
where $E$ and $F$ are the sequence of English words and the partial sequence of French words, respectively. $X_O$ is the output, which is the next word in French. $W^Q$, $W^K$, and $W^V$ are the projection matrices that transform a sequence into a different space. $W^O$ is the projection matrix that transforms the output $O$ back into the original space, giving $X_O$.
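To see how these equations fit together, below is a minimal sketch of one attention step with random tensors. The dimensions and the randomly initialized projection matrices are assumptions for illustration only; in a real model they are learned.

```python
import math
import torch

d = 64                       # embedding dimension (illustrative)
E = torch.randn(10, d)       # embeddings of 10 English words (key/value side)
F = torch.randn(2, d)        # embeddings of the 2 French words produced so far (query side)

# Randomly initialized projection matrices; learned in a real model
W_Q, W_K, W_V, W_O = (torch.randn(d, d) for _ in range(4))

Q, K, V = F @ W_Q, E @ W_K, E @ W_V
scores = Q @ K.T / math.sqrt(d)            # shape (2, 10): one row per French word
weights = torch.softmax(scores, dim=-1)    # each row sums to 1
X_O = (weights @ V) @ W_O                  # shape (2, 64)
print(X_O.shape)
```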
Multi-Head Attention (MHA)
The description in the previous section is just a high-level view of the attention operation. In a translation model, the sequence fed into attention is actually not a sequence of words, but a sequence of word embedding vectors. The projection matrices transform the embedding vectors to a different space, and the attention operation is applied in this transformed space.
How do we determine the projection matrices? This is actually difficult. The reason is that each word can be transformed into multiple different spaces. For example, there is a “meaning space” to represent the meaning of the word. There can also be a part-of-speech space to indicate whether a word is a noun, a verb, or an adjective. You don’t need to pick just one. Nothing prevents you from using multiple spaces in parallel.
This is why “multi-head attention” (MHA) is introduced. You use not one, but many sets of projection matrices, and each set performs its own attention operation. Then the outputs are concatenated to produce the final output.
Because you have multiple attention heads that are independent of each other, you can run them in parallel. The original Transformer architecture uses 8 attention heads and has been found to perform well in translation tasks.
In equation form, MHA can be represented as:
$$
\begin{aligned}
\text{Attention}(Q, K, V) &= \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \\
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_i, X_VW^V_i) \\
\text{MultiHead}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
\end{aligned}
$$
In PyTorch, you can create your own MHA layer by implementing the above equations:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        seq_length = x.size(1)

        # Project queries, keys, and values
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores, optionally add attention mask to the score
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)
        # optional: attn_weights = F.dropout(attn_weights, p=0.2)

        # Apply attention weights to values
        context = torch.matmul(attn_weights, v).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_length, self.d_model)

        return self.out_proj(context)
```
One notable characteristic of attention is that we usually keep the vectors in the input and output sequences in the same dimension. This is not a problem for the simple attention operation. But in MHA, there are multiple attention heads whose outputs are concatenated along the vector dimension. Therefore, each head should operate in a reduced dimension, `head_dim = d_model // num_heads`, so that the concatenation recovers the original size.
In the constructor above, the input projection matrices are defined as `q_proj`, `k_proj`, and `v_proj`. The output projection matrix is defined as `out_proj`.
In the `forward()` function, the input `x` is projected to `q`, `k`, and `v` by the input projection matrices. Input `x` has shape `(batch_size, seq_length, d_model)`. The same shape is retained after the projection, but it is then reshaped and transposed to `(batch_size, num_heads, seq_length, head_dim)`. The `matmul()` that computes `scores` aligns the `head_dim` dimension (the last axis), and the resulting attention score has shape `(batch_size, num_heads, seq_length, seq_length)`. Softmax is then applied along the last axis so that the sum along this axis is 1.
If you need to apply a mask to the attention, such as the causal mask usually used in decoder-only models, you should add it to the `scores` tensor before applying softmax. Some models apply dropout to the attention weights after the softmax; this is believed to help the model become more robust.
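For example, one common way to build and apply a causal mask to the scores before the softmax is sketched below (the tensor shapes here are arbitrary and only for illustration):

```python
import torch
import torch.nn.functional as F

seq_length = 4
scores = torch.randn(1, 1, seq_length, seq_length)   # (batch, heads, seq, seq)

# Causal mask: position i may only attend to positions j <= i
causal_mask = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))

attn_weights = F.softmax(scores, dim=-1)
print(attn_weights[0, 0])   # the upper triangle becomes zero after softmax
```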
The attention weights are then multiplied with `v`, and the result is transposed back to the shape `(batch_size, seq_length, num_heads, head_dim)`. The `contiguous()` call is used to make the result contiguous in memory so that the vectors from each head can be concatenated using `view()` back to the original shape `(batch_size, seq_length, d_model)`. This tensor is then projected by `out_proj` and serves as the output of the attention operation.
The above implementation is self-attention because the `forward()` function in the class uses the same input `x` to create `q`, `k`, and `v`. In cross-attention, one input sequence is used for `q` and another one for `k` and `v`.
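As a quick sanity check, you can run the class above on a random input; the dimensions below are arbitrary:

```python
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 16, 512)    # (batch_size, seq_length, d_model)
print(mha(x).shape)            # torch.Size([2, 16, 512]), same shape as the input
```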
Note that in PyTorch, `view()` is a faster way to change the shape of a tensor than `reshape()`, but it requires the tensor to be contiguous in memory. A tensor will not be contiguous if you have transposed some of its axes. You should call `contiguous()` on the tensor to move the memory around and make it contiguous again. This is the case in the line creating the `context` tensor.
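A small sketch of this behavior:

```python
import torch

t = torch.randn(2, 3, 4).transpose(1, 2)   # transposing makes the tensor non-contiguous
print(t.is_contiguous())                    # False
# t.view(2, 12) would raise a RuntimeError here
t = t.contiguous()                          # copy the data into a contiguous layout
print(t.view(2, 12).shape)                  # torch.Size([2, 12])
```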
Indeed, in PyTorch, the above class is already implemented as `torch.nn.MultiheadAttention`. You should use it instead.
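For reference, a roughly equivalent self-attention call with the built-in module might look like the sketch below. Note that it takes separate query, key, and value arguments, and `batch_first=True` keeps the `(batch, seq, d_model)` layout used in this post:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 16, 512)
out, attn_weights = mha(x, x, x)   # self-attention: same tensor as query, key, and value
print(out.shape)                   # torch.Size([2, 16, 512])
```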
Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
Multi-Head Attention (MHA) is the most powerful attention mechanism, but it involves a lot of computation. There are multiple ways to reduce the computational cost. Grouped-Query Attention (GQA) is the most popular one.
GQA reduces the computational cost of MHA by sharing key and value projections across groups of query heads:
$$
\begin{aligned}
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_{g(i)}, X_VW^V_{g(i)}) \\
\text{GQA}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
\end{aligned}
$$
Compared to MHA, GQA uses the same projection matrices $W^K_{g(i)}$ and $W^V_{g(i)}$ for all the query heads in the same group $g(i)$. A usual grouping is to split the heads evenly, such as:
$$
\begin{aligned}
g(i) &= \left\lfloor \frac{i}{m} \right\rfloor \\
\therefore\; 0 &= g(0) = g(1) = \cdots = g(m-1) \\
1 &= g(m) = g(m+1) = \cdots = g(2m-1) \\
\vdots \\
\end{aligned}
$$
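For example, with 8 query heads split into 4 groups (so $m = 2$), the mapping can be checked with a few lines of Python:

```python
num_heads, num_groups = 8, 4
m = num_heads // num_groups        # group size

g = [i // m for i in range(num_heads)]
print(g)                           # [0, 0, 1, 1, 2, 2, 3, 3]
```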
It’s easy to modify the above code example into GQA:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads    # number of query heads
        self.num_groups = num_groups
        self.group_size = num_heads // num_groups
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(d_model, self.num_groups * self.head_dim)
        self.v_proj = nn.Linear(d_model, self.num_groups * self.head_dim)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        seq_length = x.size(1)

        # Project queries, keys, and values
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)

        # Expand k and v to match the number of query heads
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)
        # optional: attn_weights = F.dropout(attn_weights, p=0.2)

        # Apply attention weights to values
        context = torch.matmul(attn_weights, v).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_length, self.d_model)

        return self.out_proj(context)
```
Compared to MHA, the projection matrices `k_proj` and `v_proj` are different. Specifically, these projection matrices are smaller, so the matrix multiplication is faster to compute.
Because `k_proj` and `v_proj` produce tensors with fewer heads than `q_proj`, the matrix multiplication between `q` and `k` cannot be performed directly. Hence you need to use `repeat_interleave()` to expand `k` to the same shape as `q`. This works only if `num_heads` is divisible by `num_groups`. For the same reason, you need to use `repeat_interleave()` to expand `v` as well.
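A small sketch of how `repeat_interleave()` expands the key tensor along the head dimension (the shapes are illustrative):

```python
import torch

batch_size, seq_length, head_dim = 2, 16, 64
num_heads, num_groups = 8, 4
group_size = num_heads // num_groups

k = torch.randn(batch_size, num_groups, seq_length, head_dim)
k_expanded = k.repeat_interleave(group_size, dim=1)
print(k_expanded.shape)            # torch.Size([2, 8, 16, 64]), now matches q
```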
Alternatively, you can use PyTorch’s built-in `scaled_dot_product_attention()` function to simplify the implementation above:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads    # number of query heads
        self.num_groups = num_groups
        self.group_size = num_heads // num_groups
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(d_model, self.num_groups * self.head_dim)
        self.v_proj = nn.Linear(d_model, self.num_groups * self.head_dim)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        seq_length = x.size(1)

        # Project queries, keys, and values
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)

        # Compute attention using PyTorch's built-in function
        attn_output = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

        # Project for output
        context = attn_output.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_length, self.d_model)
        return self.out_proj(context)
```
Note that the single function call to `scaled_dot_product_attention()` with `enable_gqa=True` replaces the calls to `repeat_interleave()`, `matmul()`, and `softmax()`.
It has been found that GQA can reduce memory usage and computation time with minimal impact on model quality.
If you set the number of groups to 1 in GQA, it becomes Multi-Query Attention (MQA). If instead you set the number of groups equal to the number of query heads, it falls back to standard multi-head attention.
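In code, both special cases fall out of the same `GroupedQueryAttention` class defined above:

```python
# Multi-Query Attention: all query heads share a single key/value head
mqa = GroupedQueryAttention(d_model=512, num_heads=8, num_groups=1)

# Equivalent to standard Multi-Head Attention: one key/value head per query head
mha = GroupedQueryAttention(d_model=512, num_heads=8, num_groups=8)
```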
Further Readings
Below are some papers that are related to the topic:
Summary
In this post, you learned about attention mechanisms in language models. In particular, you learned about:
- Why attention is crucial for capturing relationships in sequences
- How Multi-Head Attention lets the model capture different types of relationships in parallel
- How Grouped-Query Attention balances efficiency and performance