5.1 Attention Mechanism

An attention mechanism is a machine learning technique that directs deep learning models to prioritize (or attend to) the most relevant parts of input data. Innovation in attention mechanisms enabled the transformer architecture that yielded the modern large language models (LLMs) that power popular applications like ChatGPT.


As their name suggests, attention mechanisms are inspired by the ability of humans (and other animals) to selectively pay more attention to salient details and ignore details that are less important in the moment. Having access to all information but focusing on only the most relevant information helps to ensure that no meaningful details are lost while enabling efficient use of limited memory and time.

Mathematically speaking, an attention mechanism computes attention weights that reflect the relative importance of each part of an input sequence to the task at hand. It then applies those attention weights to increase (or decrease) the influence of each part of the input, in accordance with its respective importance. An attention model (that is, an artificial intelligence model that employs an attention mechanism) is trained to assign accurate attention weights through supervised learning or self-supervised learning on a large dataset of examples.

[Figure: an example of the attention mechanism following long-distance dependencies in the encoder self-attention in layer 5 of 6. Many of the attention heads attend to a distant dependency of the verb 'making', completing the phrase 'making ... more difficult'. Attention is shown only for the word 'making'; different colors represent different heads.]

Attention mechanisms were originally introduced by Bahdanau et al. (2014) as a technique to address the shortcomings of what were then state-of-the-art recurrent neural network (RNN) models used for machine translation. Subsequent research integrated attention mechanisms into the convolutional neural networks (CNNs) used for tasks such as image captioning and visual question answering.

In 2017, the seminal paper Attention is All You Need introduced the transformer model, which eschews recurrence and convolutions altogether in favor of only attention layers and standard feedforward layers. The transformer architecture has since become the backbone of the cutting-edge models powering the ongoing era of generative AI.

While attention mechanisms are primarily associated with LLMs used for natural language processing (NLP) tasks, such as summarization, question answering, text generation and sentiment analysis, attention-based models are also used widely in other domains. Leading diffusion models used for image generation often incorporate an attention mechanism. In the field of computer vision, vision transformers (ViTs) have achieved superior results on tasks including object detection, image segmentation and visual question answering.

5.1.1 Queries, Keys, and Values

So far all the networks we have reviewed crucially relied on the input being of a well-defined size. For instance, the images in ImageNet are of size \(224 \times 224\) pixels and CNNs are specifically tuned to this size. Even in natural language processing the input size for RNNs is well defined and fixed.

Variable size is addressed by sequentially processing one token at a time, or by specially designed convolution kernels. This approach can lead to significant problems when the input is truly of varying size with varying information content. In particular, for long sequences it becomes quite difficult to keep track of everything that has already been generated or even viewed by the network.

Compare this to databases. In their simplest form they are collections of keys (\(k\)) and values (\(v\)). For instance, our database \(D\) might consist of tuples {("Zhang", "Aston"), ("Lipton", "Zachary"), ("Li", "Mu"), ("Smola", "Alex"), ("Hu", "Rachel"), ("Werness", "Brent")} with the last name being the key and the first name being the value.

We can operate on \(D\), for instance with the exact query for "Li" which would return the value "Mu". If ("Li", "Mu") was not a record in \(D\), there would be no valid answer. If we also allowed for approximate matches, we would retrieve ("Lipton", "Zachary") instead.
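To make this concrete, here is a minimal sketch of exact and approximate lookup on the toy database above; the helper names and the use of difflib string similarity for the approximate match are arbitrary choices made for illustration.

    import difflib

    # Toy database: last names are keys, first names are values.
    D = {"Zhang": "Aston", "Lipton": "Zachary", "Li": "Mu",
         "Smola": "Alex", "Hu": "Rachel", "Werness": "Brent"}

    def exact_query(db, key):
        # Returns the value if the key is present, otherwise None (no valid answer).
        return db.get(key)

    def approximate_query(db, key):
        # Returns the value of the most similar key by string similarity.
        closest = difflib.get_close_matches(key, db.keys(), n=1, cutoff=0.0)[0]
        return db[closest]

    print(exact_query(D, "Li"))                   # "Mu"
    D_without_li = {k: v for k, v in D.items() if k != "Li"}
    print(exact_query(D_without_li, "Li"))        # None: no valid answer
    print(approximate_query(D_without_li, "Li"))  # "Zachary" (closest key is "Lipton")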

This quite simple and trivial example nonetheless teaches us a number of useful things:

  • We can design queries \(q\) that operate on \((k, v)\) pairs in such a manner as to be valid regardless of the database size.

  • The same query can receive different answers, according to the contents of the database.

  • The "code" being executed for operating on a large state space (the database) can be quite simple (e.g., exact match, approximate match, top-k).

  • There is no need to compress or simplify the database to make the operations effective.

Clearly we would not have introduced a simple database here if it were not for the purpose of explaining deep learning. Indeed, this leads to one of the most exciting concepts introduced in deep learning in the past decade: the attention mechanism.

We will cover the specifics of its application to machine translation later. For now, simply consider the following: denote by \(D = \{(k_1, v_1), \ldots, (k_m, v_m)\}\) a database of \(m\) tuples of keys and values. Moreover, denote by \(q\) a query. Then we can define the attention over \(D\) as:

\(\mathrm{Attention}(q, D) = \sum_{i=1}^{m} \alpha(q, k_i) v_i\)

where \(\alpha(q, k_i) \in \mathbb{R}\) (for \(i = 1, \ldots, m\)) are scalar attention weights. The operation itself is typically referred to as attention pooling. The name attention derives from the fact that the operation pays particular attention to the terms for which the weight \(\alpha\) is significant (i.e., large).

As such, the attention over \(D\) generates a linear combination of values contained in the database. In fact, this contains the above example as a special case where all but one weight is zero. We have a number of special cases:

  • The weights \(\alpha(q, k_i)\) are nonnegative. In this case the output of the attention mechanism is contained in the convex cone spanned by the values \(v_i\).

  • The weights \(\alpha(q, k_i)\) form a convex combination, i.e. \(\sum_i \alpha(q, k_i) = 1\) and \(\alpha(q, k_i) \ge 0\) for all \(i\). This is the most common setting in deep learning.

  • Exactly one of the weights \(\alpha(q, k_i)\) is 1, while all others are 0. This is akin to a traditional database query.

  • All weights are equal, i.e. \(\alpha(q, k_i) = \frac{1}{m}\) for all \(i\). This amounts to averaging across the entire database, also called average pooling in deep learning.
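As a minimal numerical sketch (the values and the weight vectors below are made up for illustration), the special cases above correspond to different choices of the weights in the same weighted sum:

    import torch

    # Attention(q, D) = sum_i alpha(q, k_i) * v_i, i.e. a weighted sum of values.
    values = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])  # m = 3 values v_i

    def attention_pool(weights, values):
        return weights @ values

    one_hot = torch.tensor([0.0, 1.0, 0.0])    # exact database lookup
    uniform = torch.full((3,), 1.0 / 3)        # average pooling
    generic = torch.tensor([0.1, 0.7, 0.2])    # a generic convex combination

    for w in (one_hot, uniform, generic):
        print(attention_pool(w, values))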

A common strategy for ensuring that the weights sum up to 1 is to normalize them via:

\(\alpha(q, k_i) = \frac{\alpha(q, k_i)}{\sum_j \alpha(q, k_j)}\)

In particular, to ensure that the weights are also nonnegative, one can resort to exponentiation. This means that we can now pick any scoring function \(a(q, k)\) and then apply the softmax operation used for multinomial models to it via:

\(\alpha(q, k_i) = \frac{\exp(a(q, k_i))}{\sum_j \exp(a(q, k_j))}\)
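As a quick sketch with made-up scores, the exponentiate-and-normalize recipe coincides with the standard softmax:

    import torch

    scores = torch.tensor([1.5, -0.3, 0.8])             # made-up scores a(q, k_i)
    weights_manual = scores.exp() / scores.exp().sum()  # exponentiate, then normalize
    weights_builtin = torch.softmax(scores, dim=0)
    print(torch.allclose(weights_manual, weights_builtin))  # True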

This operation is readily available in all deep learning frameworks. It is differentiable and its gradient never vanishes, all of which are desirable properties in a model. Note though, the attention mechanism introduced above is not the only option.

For instance, we can design a non-differentiable attention model that can be trained using reinforcement learning methods. As one would expect, training such a model is quite complex. Consequently, the bulk of modern attention research follows the framework outlined in the figure below. We thus focus our exposition on this family of differentiable mechanisms.

[Figure: attention as a query operating on a collection of (keys, values) pairs]

5.1.2 Attention Pooling

Now that we have introduced the primary components of the attention mechanism, let’s use them in a rather classical setting, namely regression and classification via kernel density estimation.

This detour simply provides additional background: it is entirely optional and can be skipped if needed. At their core, Nadaraya–Watson estimators rely on some similarity kernel \(\alpha(q, k)\) relating queries \(q\) to keys \(k\). Some common kernels are:

\(\alpha(q, k) = \exp\left(-\frac{1}{2} \|q - k\|^2\right)\) (Gaussian)

\(\alpha(q, k) = 1 \text{ if } \|q - k\| \le 1\) (boxcar)

\(\alpha(q, k) = \max\left(0, 1 - \|q - k\|\right)\) (Epanechnikov)

\(\alpha(q, k) = 1\) (constant)
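A sketch of these kernels as simple functions of the scalar difference \(q - k\), with the width fixed at 1, written as Python functions so they can be reused below:

    import torch

    # Kernels as functions of the difference d = q - k (scalar case, width 1).
    def gaussian(d):
        return torch.exp(-d**2 / 2)

    def boxcar(d):
        return (d.abs() <= 1.0).float()

    def constant(d):
        return torch.ones_like(d)

    def epanechnikov(d):
        return torch.clamp(1.0 - d.abs(), min=0.0)

    kernels = {"Gaussian": gaussian, "boxcar": boxcar,
               "constant": constant, "Epanechnikov": epanechnikov}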

There are many more choices that we could pick. All of the kernels are heuristic and can be tuned. For instance, we can adjust the width, not only on a global basis but even on a per-coordinate basis. Regardless, all of them lead to the following equation for regression and classification alike:

\(f(q) = \sum_i v_i \frac{\alpha(q, k_i)}{\sum_j \alpha(q, k_j)}\)

In the case of a (scalar) regression with observations \(x_i, y_i\) for features and labels respectively, \(v_i = y_i\) are scalars, \(k_i = x_i\) are vectors, and the query \(q\) denotes the new location where \(f\) should be evaluated. In the case of (multiclass) classification, we use one-hot-encoding of \(y_i\) to obtain \(v_i\). One of the convenient properties of this estimator is that it requires no training.

All the kernels \(\alpha(q, k)\) defined in this section are translation and rotation invariant; that is, if we shift and rotate \(k\) and \(q\) in the same manner, the value of \(\alpha\) remains unchanged. For simplicity we thus pick scalar arguments \(k, q \in \mathbb{R}\) and pick the key \(k = 0\) as the origin. This yields:

[Figure: the kernel functions plotted as functions of the query \(q\) with the key fixed at \(k = 0\)]

Different kernels correspond to different notions of range and smoothness. For instance, the boxcar kernel only attends to observations within a distance of 1 (or some otherwise defined hyperparameter) and does so indiscriminately.

To see Nadaraya–Watson estimation in action, let's define some training data. In the following we use the dependency:

\(y_i = 2 \sin(x_i) + x_i + \epsilon\)

where \(\epsilon\) is drawn from a normal distribution with zero mean and unit variance. We draw 40 training examples.
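A sketch of generating such data; the interval \([0, 5]\) for the features and the spacing of the validation grid are arbitrary choices made here for illustration:

    import torch

    n = 40
    x_train, _ = torch.sort(torch.rand(n) * 5)                   # training features (keys)
    y_train = 2 * torch.sin(x_train) + x_train + torch.randn(n)  # noisy labels (values)
    x_val = torch.arange(0, 5, 0.1)                              # validation features (queries)
    y_val = 2 * torch.sin(x_val) + x_val                         # noiseless ground truth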

Attention Pooling via Nadaraya–Watson Regression

Now that we have data and kernels, all we need is a function that computes the kernel regression estimates. Note that we also want to obtain the relative kernel weights in order to perform some minor diagnostics.

Hence we first compute the kernel between all training features (covariates) x_train and all validation features x_val. This yields a matrix, which we subsequently normalize. When multiplied with the training labels y_train we obtain the estimates.

Recall the attention pooling formula: let each validation feature be a query, and each training feature–label pair be a key–value pair. As a result, the normalized relative kernel weights (attention_w in the code below) are the attention weights.
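A sketch of this computation, reusing the kernel functions and the x_train, y_train, and x_val tensors sketched above:

    def nadaraya_watson(x_train, y_train, x_val, kernel):
        # One row per validation query, one column per training key.
        dists = x_val.reshape(-1, 1) - x_train.reshape(1, -1)
        k = kernel(dists)
        attention_w = k / k.sum(dim=1, keepdim=True)  # normalize each row to sum to 1
        y_hat = attention_w @ y_train                 # weighted sum of training labels
        return y_hat, attention_w

    estimates = {name: nadaraya_watson(x_train, y_train, x_val, kernel)
                 for name, kernel in kernels.items()}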

The first thing that stands out is that all three nontrivial kernels (Gaussian, boxcar, and Epanechnikov) produce fairly workable estimates that are not too far from the true function.

Only the constant kernel that leads to the trivial estimate \(f(x) = \frac{1}{n} \sum_i y_i\) produces a rather unrealistic result. Let’s inspect the attention weighting a bit more closely:

Adapting Attention Pooling

We could replace the Gaussian kernel with one of a different width. That is, we could use \(\alpha(\mathbf{q}, \mathbf{k}) = \exp\left(-\frac{1}{2 \sigma^2} \|\mathbf{q} - \mathbf{k}\|^2 \right)\) where \(\sigma^2\) determines the width of the kernel. Let’s see whether this affects the outcomes.
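A sketch with a tunable width, assuming the nadaraya_watson helper from above; the particular values of \(\sigma\) are arbitrary:

    def gaussian_with_width(sigma):
        def kernel(d):
            return torch.exp(-d**2 / (2 * sigma**2))
        return kernel

    for sigma in (0.1, 0.2, 0.5, 1.0):
        y_hat, attention_w = nadaraya_watson(x_train, y_train, x_val,
                                             gaussian_with_width(sigma))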

Clearly, the narrower the kernel, the less smooth the estimate. At the same time, it adapts better to the local variations. Let’s look at the corresponding attention weights.

As we would expect, the narrower the kernel, the narrower the range of large attention weights. It is also clear that picking the same width might not be ideal.

Nadaraya–Watson kernel regression is an early precursor of the current attention mechanisms. It can be used directly with little to no training or tuning, either for classification or regression. The attention weight is assigned according to the similarity (or distance) between query and key, and according to how many similar observations are available.

5.1.3 Scoring Functions

In the previous section we used a number of different distance-based kernels, including a Gaussian kernel, to model interactions between queries and keys. As it turns out, distance functions are slightly more expensive to compute than dot products.

[Figure: the attention output is computed as a weighted average of the values, where the weights are computed from the scoring function \(a\) and the softmax operation]

Dot Product Attention

Let’s review the attention function (without exponentiation) from the Gaussian kernel for a moment:

\(a(\mathbf{q}, \mathbf{k}_i) = -\frac{1}{2} \|\mathbf{q} - \mathbf{k}_i\|^2 = \mathbf{q}^\top \mathbf{k}_i -\frac{1}{2} \|\mathbf{k}_i\|^2 -\frac{1}{2} \|\mathbf{q}\|^2\)

First, note that the final term depends on \(\mathbf{q}\) only. As such it is identical for all \((\mathbf{q}, \mathbf{k}_i)\) pairs. Normalizing the attention weights so that they sum to 1 ensures that this term disappears entirely.

Second, note that both batch and layer normalization (to be discussed later) lead to activations that have well-bounded, and often constant, norms \(\|\mathbf{k}_i\|\). This is the case, for instance, whenever the keys \(\mathbf{k}_i\) were generated by a layer norm. As such, we can drop it from the definition of \(a\) without any major change in the outcome.

Last, we need to keep the order of magnitude of the arguments in the exponential function under control. Assume that all the elements of the query \(\mathbf{q} \in \mathbb{R}^d\) and the key \(\mathbf{k}_i \in \mathbb{R}^d\) are independent and identically drawn random variables with zero mean and unit variance.

The dot product between both vectors has zero mean and a variance of \(d\). To ensure that the variance of the dot product still remains \(1\) regardless of vector length, we use the scaled dot product attention scoring function. That is, we rescale the dot product by \(1/\sqrt{d}\). We thus arrive at the first commonly used attention function that is used, e.g., in Transformers:

\(a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i / \sqrt{d}\)

Note that attention weights \(\alpha\) still need normalizing. We can simplify this further by using the softmax operation:

\(\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(\mathbf{q}^\top \mathbf{k}_i / \sqrt{d})}{\sum_{j=1}^{m} \exp(\mathbf{q}^\top \mathbf{k}_j / \sqrt{d})}\)

As it turns out, all popular attention mechanisms use the softmax, hence we will limit ourselves to that in the remainder of this chapter.
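As a quick sketch (with randomly generated queries and keys), the scaled score and its softmax normalization look as follows, together with an empirical check that the \(1/\sqrt{d}\) factor keeps the score variance near 1:

    import torch

    d, m = 256, 10
    q = torch.randn(d)                     # a single query
    keys = torch.randn(m, d)               # m keys
    scores = keys @ q / d**0.5             # a(q, k_i) = q^T k_i / sqrt(d)
    alpha = torch.softmax(scores, dim=0)   # attention weights, sum to 1
    print(alpha.sum())

    # Unscaled dot products of i.i.d. unit-variance vectors have variance ~ d;
    # rescaling by 1/sqrt(d) brings the variance back to ~ 1.
    dots = (torch.randn(10000, d) * torch.randn(10000, d)).sum(dim=1)
    print(dots.var(), (dots / d**0.5).var())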

Convenience Functions

We need a few functions to make the attention mechanism efficient to deploy. This includes tools for dealing with strings of variable lengths (common for natural language processing) and tools for efficient evaluation on minibatches (batch matrix multiplication).

One of the most popular applications of the attention mechanism is to sequence models. Hence we need to be able to deal with sequences of different lengths. In some cases, such sequences may end up in the same minibatch, necessitating padding with dummy tokens for shorter sequences. These special tokens do not carry meaning. For instance, assume that we have the following three sentences:

Dive  into  Deep    Learning
Learn to    code    <blank>
Hello world <blank> <blank>

Since we do not want blanks in our attention model we simply need to limit \(\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i\) to \(\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i\), where \(l \leq n\) is the actual sentence length. Since this is such a common problem, it has a name: the masked softmax operation.

Let’s implement it. The implementation cheats ever so slightly: rather than removing the padded positions, it sets their unnormalized attention scores to a large negative number, such as \(-10^{6}\), which drives the corresponding attention weights, and hence the contribution of the values \(\mathbf{v}_i\) for \(i > l\) to the output and to gradients, to essentially zero.

This is done since linear algebra kernels and operators are heavily optimized for GPUs and it is faster to be slightly wasteful in computation rather than to have code with conditional (if then else) statements.
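A sketch along these lines, assuming the scores are arranged as a 3-D tensor of shape (batch, number of queries, number of keys):

    import torch

    def masked_softmax(X, valid_lens):
        # X: (batch, queries, keys); valid_lens: 1-D (per example) or 2-D (per query).
        if valid_lens is None:
            return torch.softmax(X, dim=-1)
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Positions at or beyond the valid length get a large negative score,
        # so their softmax weight is effectively zero.
        X = X.reshape(-1, shape[-1])
        mask = torch.arange(shape[-1], device=X.device)[None, :] >= valid_lens[:, None]
        X = X.masked_fill(mask, -1e6)
        return torch.softmax(X.reshape(shape), dim=-1)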

To illustrate how this function works, consider a minibatch of two examples of size \(2 \times 4\), where their valid lengths are \(2\) and \(3\), respectively. As a result of the masked softmax operation, values beyond the valid lengths for each pair of vectors are all masked as zero.
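For instance, assuming the masked_softmax sketch above:

    scores = torch.rand(2, 2, 4)                        # two examples, each 2 x 4
    print(masked_softmax(scores, torch.tensor([2, 3])))
    # In the first example only the first 2 columns of each row are nonzero,
    # in the second only the first 3; each row still sums to 1.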

Scaled Dot Product Attention

Let’s return to the dot product attention. In general, it requires that both the query and the key have the same vector length, say \(d\), even though this can be addressed easily by replacing \(\mathbf{q}^\top \mathbf{k}\) with \(\mathbf{q}^\top \mathbf{M} \mathbf{k}\) where \(\mathbf{M}\) is a matrix suitably chosen for translating between both spaces. For now assume that the dimensions match.

In practice, we often think of minibatches for efficiency, such as computing attention for \(n\) queries and \(m\) key-value pairs, where queries and keys are of length \(d\) and values are of length \(v\). The scaled dot product attention of queries \(\mathbf Q\in\mathbb R^{n\times d}\), keys \(\mathbf K\in\mathbb R^{m\times d}\), and values \(\mathbf V\in\mathbb R^{m\times v}\) thus can be written as:

\(\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}\)

Note that when applying this to a minibatch, we need the batch matrix multiplication. In the following implementation of the scaled dot product attention, we use dropout for model regularization.
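A sketch of such an implementation, reusing the masked_softmax helper from above; the dropout rate is a user-chosen hyperparameter:

    import torch
    from torch import nn

    class DotProductAttention(nn.Module):
        """Scaled dot product attention over minibatches."""
        def __init__(self, dropout):
            super().__init__()
            self.dropout = nn.Dropout(dropout)

        def forward(self, queries, keys, values, valid_lens=None):
            # queries: (batch, n, d); keys: (batch, m, d); values: (batch, m, v)
            d = queries.shape[-1]
            scores = torch.bmm(queries, keys.transpose(1, 2)) / d**0.5
            self.attention_weights = masked_softmax(scores, valid_lens)
            # Dropout on the attention weights acts as regularization.
            return torch.bmm(self.dropout(self.attention_weights), values)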

Additive Attention

When queries \(\mathbf{q}\) and keys \(\mathbf{k}\) are vectors of different dimension, we can either use a matrix to address the mismatch via \(\mathbf{q}^\top \mathbf{M} \mathbf{k}\), or we can use additive attention as the scoring function. Another benefit is that, as its name indicates, the attention is additive. This can lead to some minor computational savings.

Given a query \(\mathbf{q} \in \mathbb{R}^q\) and a key \(\mathbf{k} \in \mathbb{R}^k\), the additive attention scoring function (Bahdanau et al., 2014) is given by:

\(a(\mathbf q, \mathbf k) = \mathbf w_v^\top \textrm{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R}\)

where \(\mathbf W_q\in\mathbb R^{h\times q}\), \(\mathbf W_k\in\mathbb R^{h\times k}\), and \(\mathbf w_v\in\mathbb R^{h}\) are the learnable parameters.

This term is then fed into a softmax to ensure both nonnegativity and normalization. An equivalent interpretation of this scoring function is that the query and key are concatenated and fed into an MLP with a single hidden layer. Using \(\tanh\) as the activation function and disabling bias terms, we implement additive attention as follows:
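A sketch, again reusing the masked_softmax helper from above; the constructor argument names (query_size, key_size, num_hiddens) are illustrative:

    import torch
    from torch import nn

    class AdditiveAttention(nn.Module):
        """Additive attention: project queries and keys to a common hidden size."""
        def __init__(self, query_size, key_size, num_hiddens, dropout):
            super().__init__()
            self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
            self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
            self.w_v = nn.Linear(num_hiddens, 1, bias=False)
            self.dropout = nn.Dropout(dropout)

        def forward(self, queries, keys, values, valid_lens=None):
            # queries: (batch, n, query_size); keys: (batch, m, key_size)
            q, k = self.W_q(queries), self.W_k(keys)
            # Broadcast to (batch, n, m, num_hiddens): one feature vector per pair.
            features = torch.tanh(q.unsqueeze(2) + k.unsqueeze(1))
            scores = self.w_v(features).squeeze(-1)           # (batch, n, m)
            self.attention_weights = masked_softmax(scores, valid_lens)
            return torch.bmm(self.dropout(self.attention_weights), values)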