A Gentle Introduction to Multi-Head Attention and Grouped-Query Attention


Language models need to understand relationships between words in a sequence, regardless of their distance. This post explores how attention mechanisms enable this capability and their various implementations in modern language models.

Let’s get started.

Photo by Ye Min Htet. Some rights reserved.

Overview

This post is divided into four parts; they are:

  • Why Attention is Needed
  • The Attention Operation
  • Multi-Head Attention (MHA)
  • Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)

Why Attention is Needed

Traditional neural networks struggle with long-range dependencies in sequences. Consider the sentence:

“The animal didn’t cross the road because it was too tired.”

To understand what “it” refers to, the model needs to look back at “animal”, a relationship that spans multiple words.

Another example is translating the sentence, “I want to try on a suit that I saw in a shop that’s across the street from the hotel”. This is a well-known sentence that demonstrates the differences between languages, as shown in the illustration below:

(Illustration: how the words of the example sentence align across languages.)

If you translate this sentence from English to French, you can mostly match the words one by one in their original order. However, it’s not entirely straightforward: the English word “want” can become “veux”, “voulons”, or “voulez” depending on whether the subject is “I”, “we”, or “you”. Therefore, a translation model needs to attend to both “Je” (the French equivalent of “I”, which determines the verb form) and “want” (which determines the verb itself) to produce the correct translation.

Translating into Japanese is even more challenging because Japanese uses subject-object-verb (SOV) word order. When the model sees “I want…”, it needs to wait until the end of the sentence to determine the object. After the model has produced “私は” (the Japanese equivalent of “I”), the next word to generate is “ホテル” (the Japanese equivalent of “hotel”). The immediately preceding Japanese word doesn’t help here; instead, the output must align with “hotel”, the final word of the English sentence.

The purpose of the attention mechanism is to help the model focus on the relevant parts of the sequence while ignoring the rest.

The Attention Operation

The attention mechanism was invented to solve the problem of long-range dependencies in translation models. It’s easiest to understand in the translation context.

Let’s consider that we’ve processed the English sentence from the previous section. Now the model is producing French words one by one. After the first word “Je”, the model needs to decide what the second word should be.

First, we define the French words produced so far as the “query” sequence and the processed English sentence as the “key” sequence. The attention operation first computes the attention scores:

$$
\frac{QK^T}{\sqrt{d}}
$$

This is a matrix where element $(i,j)$ is the score of alignment between the $i$-th French word and the $j$-th English word. The higher the score, the more closely the two words are aligned. Alignment doesn’t mean equivalence, but rather indicates what should be the next focus. In this example, the alignment should focus on “want” in the English sentence, which should have the highest attention score. The $\sqrt{d}$ in the formula above is a scaling constant, where $d$ is the dimension of the query and key vectors. You can ignore its effect for now.

After that, the attention score is normalized by the softmax function:

$$
\text{softmax}\Big(\frac{QK^T}{\sqrt{d}}\Big)
$$

The softmax function normalizes the attention scores so that each row sums to 1. The reason for this becomes clear in the next step, which is to compute the weighted sum of the values:

$$
O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
$$

$V$ is the “value” sequence. In this example, you can think of it as the French translation of each English word in the original sentence. This is just a simple translation without considering any context. In the case of “want”, it might be “vouloir”, not “veux”.

The equation above computes a weighted sum of the values. Because we want the output $O$ as a word, the sum of weights should be 1 to keep the output from becoming too large or too small. Why use a weighted sum instead of picking a single word translation? Because one word in French may be related to multiple words in the English sentence. In this example, you may need 90% of “want” (“vouloir”) and 10% of “I” (“je”) to create the correct verb form “veux” as the next output word in French.
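
To make the shapes concrete, here is a small toy example of the operation above in PyTorch. The sequence lengths and the dimension are arbitrary and the tensors are random, so it only illustrates the mechanics:

```python
import torch
import torch.nn.functional as F

d = 4                     # vector dimension (arbitrary for this example)
Q = torch.randn(2, d)     # "query": 2 French words produced so far
K = torch.randn(5, d)     # "key": 5 English words in the source sentence
V = torch.randn(5, d)     # "value": one vector per English word

scores = Q @ K.T / d ** 0.5          # attention scores, shape (2, 5)
weights = F.softmax(scores, dim=-1)  # each row sums to 1
print(weights.sum(dim=-1))           # tensor([1.0000, 1.0000])
O = weights @ V                      # weighted sum of values, shape (2, 4)
```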

One missing piece in the above explanation is how the words in the English sentence become French words so that you can use them in the “value” sequence. This is indeed produced by a “projection matrix”. In full, the equations are:

$$
\begin{aligned}
Q &= F W^Q \\
K &= E W^K \\
V &= E W^V \\
X_O &= \Big(\text{softmax}\big(\frac{QK^T}{\sqrt{d}}\big)V\Big)W^O
\end{aligned}
$$

where $E$ is the sequence of English words and $F$ is the partial sequence of French words produced so far. $X_O$ is the output, which represents the next word in French. $W^Q$, $W^K$, and $W^V$ are the projection matrices that transform each sequence into a different space, and $W^O$ is the projection matrix that transforms the attention output back to the original space, giving $X_O$.
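
Written out in PyTorch, the full set of equations looks like the following sketch. The projection matrices are random placeholders for what a real model would learn, and the French sequence is named F_seq here only to avoid clashing with torch.nn.functional:

```python
import torch
import torch.nn.functional as F

d_model, d = 8, 4
E = torch.randn(5, d_model)      # embeddings of the English sentence
F_seq = torch.randn(2, d_model)  # embeddings of the French words produced so far

# projection matrices (random here; learned in a real model)
W_Q = torch.randn(d_model, d)
W_K = torch.randn(d_model, d)
W_V = torch.randn(d_model, d)
W_O = torch.randn(d, d_model)

Q, K, V = F_seq @ W_Q, E @ W_K, E @ W_V
X_O = F.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V @ W_O   # shape (2, d_model)
```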

Multi-Head Attention (MHA)

The description in the previous section is just a high-level view of the attention operation. In a translation model, the sequence fed into attention is actually not a sequence of words, but a sequence of word embedding vectors. The projection matrices transform the embedding vectors to a different space, and the attention operation is applied in this transformed space.

How do we determine the projection matrices? This is actually difficult. The reason is that each word can be transformed into multiple different spaces. For example, there is a “meaning space” to represent the meaning of the word. There can also be a part-of-speech space to indicate whether a word is a noun, a verb, or an adjective. You don’t need to pick just one. Nothing prevents you from using multiple spaces in parallel.

This is why “multi-head attention” (MHA) is introduced. You use not one, but many sets of projection matrices, and each set performs its own attention operation. Then the outputs are concatenated to produce the final output.

Because you have multiple attention heads that are independent of each other, you can run them in parallel. The original Transformer architecture uses 8 attention heads and has been found to perform well in translation tasks.

In equation form, MHA can be represented as:

$$
\begin{aligned}
\text{Attention}(Q, K, V) &= \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \\
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_i, X_VW^V_i) \\
\text{MultiHead}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
\end{aligned}
$$

In PyTorch, you can create your own MHA layer by implementing the above equations:
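
Below is a minimal sketch of such a layer. It is simplified for illustration (no masking or dropout, which are discussed afterwards), but it follows the shapes and variable names described in the rest of this section:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # input projection matrices
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        # output projection matrix
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_length, d_model = x.shape
        # project, then split into heads:
        # (batch_size, seq_length, d_model) -> (batch_size, num_heads, seq_length, head_dim)
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        # attention scores: (batch_size, num_heads, seq_length, seq_length)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim ** 0.5
        weights = F.softmax(scores, dim=-1)
        # weighted sum of values, then merge the heads back together
        context = torch.matmul(weights, v).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_length, d_model)
        return self.out_proj(context)
```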

One notable characteristic of attention is that we usually keep the vectors in the input and output sequences in the same dimension. This is usually not a problem for the simple attention operation. But for MHA, there are multiple attention heads and the output will be concatenated along the vector dimension. Therefore, each head should operate in a reduced dimension, head_dim = d_model // num_heads, to make the concatenation possible.

In the constructor above, the input projection matrices are defined as q_proj, k_proj, and v_proj. The output projection matrix is defined as out_proj.

In the forward() function, the input x is projected to q, k, and v by the input projection matrices. Input x has shape (batch_size, seq_length, d_model). The same shape is retained after the projection, but then it is reshaped and transposed to (batch_size, num_heads, seq_length, head_dim). The matmul() to compute scores aligns the head_dim dimension (last axis), and the resulting attention score has the shape (batch_size, num_heads, seq_length, seq_length). Then softmax is applied along the last axis so that the sum along this axis is 1.

If you need to apply a mask to the attention, such as the causal mask usually used in the decoder-only models, you should apply this to the scores tensor before applying softmax. Some models implement dropout on the attention weights after the softmax. It is believed that this can help the model become more robust.
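
As a sketch of where these two steps would go, here is the relevant fragment in isolation with made-up shapes; in the layer above, it would sit between the computation of scores and the multiplication with v:

```python
import torch
import torch.nn.functional as F

batch_size, num_heads, seq_length = 1, 2, 4
scores = torch.randn(batch_size, num_heads, seq_length, seq_length)

# Causal mask: position i must not attend to positions j > i
causal_mask = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))

weights = F.softmax(scores, dim=-1)   # masked positions receive zero weight
weights = F.dropout(weights, p=0.1)   # optional dropout on the attention weights
```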

The attention weights are then multiplied with v, and the result is transposed back to the shape (batch_size, seq_length, num_heads, head_dim). The contiguous() is used to make the result contiguous in memory so that the vectors from each head can be concatenated using view() back to the original shape (batch_size, seq_length, d_model). This tensor is then projected by out_proj and serves as the output of the attention operation.

The above implementation is self-attention because the forward() function in the class uses the same input x to create q, k, and v. In cross-attention, one input sequence is used for q and another one for k and v.
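
For example, a cross-attention version of the forward() method might accept two inputs, roughly as follows (a fragment only; the rest of the computation is unchanged):

```python
def forward(self, x_q, x_kv):
    # x_q provides the queries; x_kv provides the keys and values
    batch_size, q_len, _ = x_q.shape
    kv_len = x_kv.shape[1]
    q = self.q_proj(x_q).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(x_kv).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(x_kv).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
    # scores, softmax, weighted sum, and out_proj proceed exactly as in self-attention
```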

Note that in PyTorch, view() is a faster way to change the shape of a tensor than reshape(), but it requires the tensor to be contiguous in memory. A tensor will not be contiguous after you transpose some of its axes. Calling contiguous() on a tensor copies the memory around to make it contiguous again, which is why it is needed in the line that creates the context tensor.
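
A quick way to see this behavior in a standalone toy example:

```python
import torch

t = torch.randn(2, 3, 4).transpose(1, 2)  # transposing axes makes the tensor non-contiguous
print(t.is_contiguous())                  # False
# t.view(2, 12)                           # this would raise a RuntimeError
print(t.contiguous().view(2, 12).shape)   # works: torch.Size([2, 12])
```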

Indeed, PyTorch already provides this functionality as torch.nn.MultiheadAttention, and you should generally use the built-in version instead of writing your own.
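
For instance, a self-attention call with the built-in layer looks roughly like this (the dimensions are arbitrary):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)          # (batch_size, seq_length, d_model)
output, attn_weights = mha(x, x, x)  # self-attention: same tensor for query, key, and value
print(output.shape)                  # torch.Size([2, 10, 512])
```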

Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)

Multi-Head Attention (MHA) is a powerful attention mechanism, but it involves a lot of computation and memory, since every head has its own key and value projections. There are multiple ways to reduce this cost, and Grouped-Query Attention (GQA) is the most popular one.

GQA reduces the computational cost of MHA by sharing key and value projections across groups of query heads:

$$
\begin{aligned}
\text{head}_i &= \text{Attention}(X_QW^Q_i, X_KW^K_{g(i)}, X_VW^V_{g(i)}) \\
\text{GQA}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
\end{aligned}
$$

Compared to MHA, GQA uses the same projection matrices $W^K_{g(i)}$ and $W^V_{g(i)}$ for all the query heads in the same group $g(i)$. A usual grouping is to split the heads evenly, such as:

$$
\begin{aligned}
g(i) &= \left\lfloor \frac{i}{m} \right\rfloor \\
\therefore\; 0 &= g(0) = g(1) = \cdots = g(m-1) \\
1 &= g(m) = g(m+1) = \cdots = g(2m-1) \\
\vdots \\
\end{aligned}
$$
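
For example, with 8 query heads split into 4 groups (so $m=2$ query heads per group), the mapping works out as follows:

```python
num_heads, num_groups = 8, 4
m = num_heads // num_groups   # query heads per key/value group
for i in range(num_heads):
    print(f"query head {i} -> key/value group {i // m}")
# heads 0 and 1 share group 0, heads 2 and 3 share group 1, and so on
```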

It’s easy to modify the above code example into GQA:
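
The sketch below is one way to do it; compared to the MHA sketch earlier, only the key/value projections and the two repeat_interleave() calls change:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_groups == 0
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        # key/value projections produce only num_groups heads, so they are smaller
        self.k_proj = nn.Linear(d_model, num_groups * self.head_dim)
        self.v_proj = nn.Linear(d_model, num_groups * self.head_dim)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_length, d_model = x.shape
        q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)
        # expand k and v so each key/value group is shared by num_heads // num_groups query heads
        k = k.repeat_interleave(self.num_heads // self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_heads // self.num_groups, dim=1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim ** 0.5
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_length, d_model)
        return self.out_proj(context)
```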

Compared to MHA, the projection matrices k_proj and v_proj are different. Specifically, they produce only num_groups sets of keys and values instead of num_heads, so they are smaller and their matrix multiplications are faster to compute.

Because k_proj and q_proj have different shapes, the tensors q and k end up with different numbers of heads and cannot be multiplied directly. Hence you need repeat_interleave() to expand k along the head dimension to the same shape as q, which works only if num_heads is divisible by num_groups. For the same reason, you need repeat_interleave() to expand v as well.

Alternatively, you can use PyTorch’s built-in scaled_dot_product_attention() function to simplify the implementation above:
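
For example, the forward() method of the GQA sketch above could be reduced to something like this (enable_gqa is only available in recent PyTorch releases):

```python
import torch.nn.functional as F

def forward(self, x):
    batch_size, seq_length, d_model = x.shape
    q = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)
    v = self.v_proj(x).view(batch_size, seq_length, self.num_groups, self.head_dim).transpose(1, 2)
    # scaling, softmax, and the expansion of the grouped k/v heads are all handled inside this call
    context = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, d_model)
    return self.out_proj(context)
```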

Note that the single function call to scaled_dot_product_attention() with enable_gqa=True replaces the calls to repeat_interleave(), matmul(), and softmax().

It has been found that GQA can reduce memory usage and computation time with minimal impact on model quality.

If you set the number of groups to 1 in GQA, it becomes Multi-Query Attention (MQA). But if you set the number of groups to the same as the number of query heads, it falls back to multi-head attention.
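
In terms of the GroupedQueryAttention sketch above, the two extremes would be configured like this (the dimensions are arbitrary):

```python
mqa = GroupedQueryAttention(d_model=512, num_heads=8, num_groups=1)  # Multi-Query Attention
mha = GroupedQueryAttention(d_model=512, num_heads=8, num_groups=8)  # equivalent to Multi-Head Attention
```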


Summary

In this post, you learned about attention mechanisms in language models. In particular, you learned about:

  • Why attention is crucial for capturing relationships in sequences
  • How Multi-Head Attention enables different types of relationships
  • How Grouped-Query Attention balances efficiency and performance
