Attention Mechanisms: Fundamentals

Chapter Overview

Attention mechanisms revolutionized sequence modeling by allowing models to focus on relevant parts of the input when producing each output. This chapter introduces attention from first principles, developing the query-key-value paradigm that underpins modern transformers.

Attention solves a fundamental limitation of RNN encoder-decoder models: compressing the entire input sequence into a single fixed-size vector. Instead, attention computes dynamic, context-dependent representations as weighted combinations of all input positions.

Learning Objectives

  1. Understand the motivation for attention in sequence-to-sequence models
  2. Master the query-key-value attention paradigm
  3. Implement additive (Bahdanau) and multiplicative (Luong) attention
  4. Understand scaled dot-product attention
  5. Compute attention weights and apply to values
  6. Visualize and interpret attention distributions

Motivation: The Seq2Seq Bottleneck

RNN Encoder-Decoder Architecture

The sequence-to-sequence (seq2seq) problem requires mapping an input sequence $\vx_1, \ldots, \vx_n$ to an output sequence $\vy_1, \ldots, \vy_m$ of potentially different length. This formulation encompasses machine translation, text summarization, question answering, and many other natural language processing tasks. Before attention mechanisms, the standard approach used recurrent neural networks in an encoder-decoder architecture that suffered from a fundamental information bottleneck.

The encoder RNN processes the input sequence sequentially, updating its hidden state at each time step:

$$ \vh_t^{\text{enc}} = \text{RNN}(\vx_t, \vh_{t-1}^{\text{enc}}) $$
After processing all $n$ input tokens, the final hidden state $\mathbf{c} = \vh_n^{\text{enc}}$ serves as the context vector—a fixed-size representation intended to capture the entire input sequence. This context vector, typically 512 or 1024 dimensions for LSTM-based systems, must encode all relevant information from the source sequence regardless of its length.

The decoder RNN then generates the output sequence conditioned on this context vector:

$$ \vh_t^{\text{dec}} = \text{RNN}([\vy_{t-1}, \mathbf{c}], \vh_{t-1}^{\text{dec}}) $$
where $[\vy_{t-1}, \mathbf{c}]$ denotes concatenation of the previous output token embedding and the context vector. The decoder must rely on this single fixed-size vector throughout the entire generation process, accessing the same $\mathbf{c}$ when producing the first output word and the last.
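A minimal sketch of the encoder side makes the fixed-size constraint concrete (an untrained toy tanh RNN with made-up dimensions, not the 1000-dimensional LSTM systems discussed in this section): the context $\mathbf{c}$ has the same shape whether the source has 3 tokens or 50.

```python
import numpy as np

# Toy, untrained tanh-RNN encoder (illustrative dimensions only).
rng = np.random.default_rng(0)
d_x, d_h = 8, 16                     # input and hidden dimensions

W_xh = rng.normal(0, 0.1, size=(d_h, d_x))
W_hh = rng.normal(0, 0.1, size=(d_h, d_h))

def encode(xs):
    """Run the encoder over xs; keep only the final state as the context c."""
    h = np.zeros(d_h)
    for x in xs:                     # strictly sequential: h_t needs h_{t-1}
        h = np.tanh(W_xh @ x + W_hh @ h)
    return h

c = encode(rng.normal(size=(50, d_x)))   # 50-token source sequence
print(c.shape)                           # (16,): fixed size, whatever the length
```

The decoder sees only this one vector, which is exactly the bottleneck analyzed below.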

The Information Bottleneck: Compressing an entire input sequence into a single fixed-size vector creates severe information loss, particularly for long sequences. Consider translating a 50-word English sentence to French. The encoder must compress 50 words of semantic content, syntactic structure, and contextual relationships into a 512-dimensional vector. This is fundamentally insufficient—the context vector becomes an information bottleneck that limits the model's capacity to handle complex or lengthy inputs.

Empirical evidence from early neural machine translation systems demonstrated this limitation quantitatively. For English-French translation using LSTM encoder-decoders with 1000-dimensional hidden states, translation quality (measured by BLEU score) remained stable for source sentences up to 20-25 words but degraded significantly beyond this length. Sentences of 30-40 words showed BLEU score drops of 5-10 points compared to shorter sentences, and sentences exceeding 50 words often produced nearly incomprehensible translations. The fixed-size context vector simply could not retain sufficient information about long, complex source sentences.

Memory and Computational Characteristics: The RNN encoder-decoder architecture requires $O(n+m)$ memory for storing hidden states during the forward pass, where $n$ is the source length and $m$ is the target length. For a typical translation task with $n=50$, $m=50$, and hidden dimension $d_h=1024$, this amounts to $(50+50) \times 1024 \times 4 = 400$ KB per sequence in FP32. However, the sequential nature of RNN processing prevents parallelization across time steps. Each hidden state $\vh_t$ depends on $\vh_{t-1}$, forcing strictly sequential computation. On a GPU capable of processing thousands of operations in parallel, this sequential constraint severely limits throughput.

For a batch of 32 sequences, the encoder processes $32 \times 50 = 1600$ time steps sequentially, even though the GPU could theoretically process all 1600 in parallel if the operations were independent. This sequential bottleneck means RNN encoder-decoders achieve only 5-10\% of peak GPU utilization during training, wasting the majority of available compute capacity.

Attention Solution

Attention mechanisms solve the information bottleneck by allowing the decoder to access all encoder hidden states directly, rather than relying on a single compressed representation. The key insight is that when generating each output word $\vy_t$, different input words have different relevance. When translating "The cat sat on the mat" to French, generating "chat" (cat) should focus primarily on the input word "cat," while generating "assis" (sat) should focus on "sat." The decoder's information needs change dynamically throughout generation.

Rather than computing a single context vector $\mathbf{c}$ for the entire sequence, attention computes a different context vector $\mathbf{c}_t$ for each output position $t$. This context vector is a weighted sum of all encoder hidden states:

$$ \mathbf{c}_t = \sum_{i=1}^{n} \alpha_{t,i} \vh_i^{\text{enc}} $$
where the attention weights $\alpha_{t,i}$ indicate how much the decoder should focus on input position $i$ when generating output position $t$. These weights form a probability distribution: $\alpha_{t,i} \geq 0$ and $\sum_{i=1}^n \alpha_{t,i} = 1$.

The attention weights are computed dynamically based on the current decoder state $\mathbf{s}_t$ and each encoder hidden state $\vh_i$. This allows the model to learn which input positions are relevant for each output position, adapting the context vector to the decoder's current needs. When generating the first word of a translation, the attention might focus on the beginning of the source sentence. When generating the last word, attention shifts to the end of the source.

Memory Trade-off: Attention increases memory requirements from $O(n+m)$ to $O(nm)$ because we must store attention weights $\alpha_{t,i}$ for all pairs of input and output positions. For translation with $n=50$ and $m=50$, this requires storing a $50 \times 50 = 2500$ element attention matrix. At 4 bytes per element (FP32), this is 10 KB per sequence—modest compared to the benefits. However, this quadratic scaling becomes significant for very long sequences. For document-level translation with $n=1000$ and $m=1000$, the attention matrix requires $1000^2 \times 4 = 4$ MB per sequence, or 128 MB for batch size 32.

Parallelization Benefit: The crucial advantage is that attention enables parallelization. Unlike RNN hidden states that must be computed sequentially, attention weights for all output positions can be computed simultaneously during training when the target sequence is known. This transforms the sequential $O(m)$ decoder steps into a single parallel operation, dramatically improving GPU utilization from 5-10\% to 60-80\% in practice.

Example: Consider translating the English sentence "The cat sat on the mat" to French: "Le chat Ă©tait assis sur le tapis." Without attention, the encoder compresses all six English words into a single 512-dimensional context vector, which the decoder uses to generate all seven French words. The context vector must simultaneously encode that "cat" translates to "chat," "sat" translates to "Ă©tait assis," and "mat" translates to "tapis"—a challenging compression task.

With attention, when generating "chat" (cat), the attention mechanism computes weights that heavily favor the input position containing "cat." The attention distribution might be $[0.05, 0.82, 0.03, 0.02, 0.03, 0.05]$, placing 82\% of the weight on position 2 (the word "cat"). The context vector $\mathbf{c}_2$ is then dominated by the encoder hidden state for "cat," providing the decoder with direct access to the relevant input information.

When generating "assis" (sat), the attention distribution shifts to $[0.03, 0.08, 0.75, 0.04, 0.05, 0.05]$, now focusing 75\% on position 3 (the word "sat"). The context vector $\mathbf{c}_4$ adapts to provide information about "sat" rather than "cat." This dynamic reweighting allows the decoder to access different parts of the input as needed, eliminating the information bottleneck of the fixed context vector.

Empirically, attention-based translation systems improved BLEU scores by 3-5 points on standard benchmarks and maintained consistent quality even for sentences exceeding 50 words—a regime where RNN encoder-decoders failed catastrophically.

Additive Attention (Bahdanau)

Bahdanau attention, introduced in 2015 for neural machine translation, was the first widely successful attention mechanism. It computes attention weights using an additive scoring function that combines the decoder state and encoder hidden states through learned transformations. While later superseded by more efficient mechanisms, understanding Bahdanau attention provides crucial insights into attention design and the evolution toward modern transformers.

Definition: Given encoder hidden states $\vh_1, \ldots, \vh_n \in \R^{d_h}$ and decoder hidden state $\mathbf{s}_t \in \R^{d_s}$ at time $t$, Bahdanau attention computes a context vector through four steps:

Step 1: Compute alignment scores

$$ e_{t,i} = \mathbf{v}\transpose \tanh(\mW_1 \mathbf{s}_t + \mW_2 \vh_i) $$
where $\mW_1 \in \R^{d_a \times d_s}$, $\mW_2 \in \R^{d_a \times d_h}$, $\mathbf{v} \in \R^{d_a}$, and $d_a$ is the attention dimension (typically 256-512).

Step 2: Compute attention weights (softmax)

$$ \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^{n} \exp(e_{t,j})} $$

Step 3: Compute context vector

$$ \mathbf{c}_t = \sum_{i=1}^{n} \alpha_{t,i} \vh_i $$

Step 4: Use in decoder

$$ \mathbf{s}_t = \text{RNN}([\vy_{t-1}, \mathbf{c}_t], \mathbf{s}_{t-1}) $$
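Steps 1-3 can be transcribed directly into NumPy (Step 4 simply feeds $\mathbf{c}_t$ into the decoder RNN and is omitted). All dimensions and weights here are toy and randomly initialized, purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_h, d_s, d_a = 6, 8, 8, 4

W1 = rng.normal(0, 0.5, size=(d_a, d_s))   # projects the decoder state
W2 = rng.normal(0, 0.5, size=(d_a, d_h))   # projects each encoder state
v  = rng.normal(0, 0.5, size=d_a)

def bahdanau_attention(s_t, H):
    """s_t: decoder state (d_s,);  H: encoder states (n, d_h)."""
    e = np.array([v @ np.tanh(W1 @ s_t + W2 @ h_i) for h_i in H])  # Step 1
    alpha = np.exp(e - e.max())            # max-shift for numerical stability
    alpha = alpha / alpha.sum()            # Step 2: softmax over positions
    c_t = alpha @ H                        # Step 3: weighted sum of states
    return alpha, c_t

H = rng.normal(size=(n, d_h))
s_t = rng.normal(size=d_s)
alpha, c_t = bahdanau_attention(s_t, H)
```

Note that the per-position loop in Step 1 is exactly the irregular computation pattern criticized in the cost analysis below.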

Computational Cost Analysis: The additive scoring function in Step 1 requires substantial computation for each query-key pair. For a single alignment score $e_{t,i}$, we must:

  1. Compute $\mW_1 \mathbf{s}_t$: $2d_a d_s$ FLOPs (matrix-vector multiplication)
  2. Compute $\mW_2 \vh_i$: $2d_a d_h$ FLOPs
  3. Add the results: $d_a$ FLOPs
  4. Apply $\tanh$: $\approx 3d_a$ FLOPs (exponentials and divisions)
  5. Compute $\mathbf{v}\transpose (\cdot)$: $2d_a$ FLOPs

Total per alignment score: approximately $2d_a(d_s + d_h + 3)$ FLOPs. For a translation task with source length $n$ and target length $m$, we compute $nm$ alignment scores, requiring:

$$ \text{Bahdanau alignment FLOPs} \approx 2nmd_a(d_s + d_h + 3) $$

For typical dimensions $n=50$, $m=50$, $d_a=256$, $d_s=d_h=512$:

$$ 2 \times 50 \times 50 \times 256 \times (512 + 512 + 3) \approx 1.3 \text{ billion FLOPs} $$

This is substantial, but the more critical issue is that these operations do not map efficiently to GPU hardware. The additive scoring function involves element-wise operations ($\tanh$), vector additions, and small matrix-vector products that achieve poor utilization on GPUs optimized for large matrix multiplications. In practice, Bahdanau attention achieves only 15-25\% of peak GPU throughput.
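As a quick sanity check of the 1.3-billion figure, the formula is easy to tabulate (the function name is ours):

```python
def bahdanau_alignment_flops(n, m, d_a, d_s, d_h):
    """Total FLOPs for all n*m additive alignment scores: 2*n*m*d_a*(d_s+d_h+3)."""
    return 2 * n * m * d_a * (d_s + d_h + 3)

print(f"{bahdanau_alignment_flops(50, 50, 256, 512, 512):,}")  # 1,314,560,000
```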

Memory Requirements: The attention mechanism requires storing the encoder states, the alignment scores, the attention weights, and the intermediate activations of the additive scoring function. For $n=50$, $m=50$, $d_h=512$, $d_a=256$ in FP32:

$$\begin{align} \text{Encoder states:} \quad &50 \times 512 \times 4 = 102 \text{ KB} \\ \text{Alignment scores:} \quad &50 \times 50 \times 4 = 10 \text{ KB} \\ \text{Attention weights:} \quad &50 \times 50 \times 4 = 10 \text{ KB} \\ \text{Intermediate:} \quad &50 \times 50 \times 256 \times 4 = 2.5 \text{ MB} \end{align}$$

The intermediate activations dominate memory usage, requiring 2.5 MB per sequence or 80 MB for batch size 32. This is manageable for short sequences but scales poorly to longer contexts.

Parameter Count: Bahdanau attention introduces $O(d_a(d_s + d_h))$ parameters:

$$\begin{align} \mW_1 &\in \R^{d_a \times d_s}: \quad d_a d_s \text{ parameters} \\ \mW_2 &\in \R^{d_a \times d_h}: \quad d_a d_h \text{ parameters} \\ \mathbf{v} &\in \R^{d_a}: \quad d_a \text{ parameters} \end{align}$$

For $d_a=256$, $d_s=d_h=512$: $(256 \times 512) + (256 \times 512) + 256 = 262{,}400$ parameters. While not enormous, these parameters must be learned specifically for the attention mechanism, adding to the model's overall capacity requirements.

Attention weights $\alpha_{t,i}$ form a probability distribution: $\alpha_{t,i} \geq 0$ and $\sum_{i=1}^n \alpha_{t,i} = 1$. This ensures the context vector $\mathbf{c}_t$ is a convex combination of encoder states, interpolating between them rather than extrapolating. The softmax normalization is crucial for training stability—without it, attention weights could grow unbounded, causing gradient explosion.

Example: Consider a small example with encoder hidden states $\vh_1, \vh_2, \vh_3 \in \R^{4}$, decoder state $\mathbf{s}_2 \in \R^{4}$, and attention dimension $d_a = 3$. We compute attention for the second decoder position.

Step 1: Compute alignment scores for each encoder position. Suppose after applying $\mW_1 \mathbf{s}_2 + \mW_2 \vh_i$ and passing through $\tanh$ and $\mathbf{v}\transpose$, we obtain:

$$\begin{align} e_{2,1} &= 0.8 \\ e_{2,2} &= 2.1 \\ e_{2,3} &= 0.5 \end{align}$$

These raw scores indicate that encoder position 2 has the highest compatibility with the current decoder state, but the scores are not yet normalized.

Step 2: Apply softmax to convert scores to a probability distribution:

$$\begin{align} \sum_j \exp(e_{2,j}) &= \exp(0.8) + \exp(2.1) + \exp(0.5) \\ &\approx 2.23 + 8.17 + 1.65 = 12.05 \end{align}$$

Computing each attention weight:

$$\begin{align} \alpha_{2,1} &= \frac{\exp(0.8)}{12.05} = \frac{2.23}{12.05} \approx 0.185 \\ \alpha_{2,2} &= \frac{\exp(2.1)}{12.05} = \frac{8.17}{12.05} \approx 0.678 \\ \alpha_{2,3} &= \frac{\exp(0.5)}{12.05} = \frac{1.65}{12.05} \approx 0.137 \end{align}$$

The decoder places 67.8\% of its attention on encoder position 2, with the remaining attention distributed between positions 1 and 3. This sharp distribution indicates high confidence about which input position is relevant.

Step 3: Compute the context vector as a weighted sum:

$$ \mathbf{c}_2 = 0.185 \vh_1 + 0.678 \vh_2 + 0.137 \vh_3 \in \R^{4} $$

If $\vh_1 = [1.0, 0.5, -0.3, 0.8]\transpose$, $\vh_2 = [0.3, 0.9, 0.6, -0.2]\transpose$, $\vh_3 = [-0.4, 0.2, 0.7, 0.5]\transpose$:

$$\begin{align} \mathbf{c}_2 &= 0.185 \begin{bmatrix} 1.0 \\ 0.5 \\ -0.3 \\ 0.8 \end{bmatrix} + 0.678 \begin{bmatrix} 0.3 \\ 0.9 \\ 0.6 \\ -0.2 \end{bmatrix} + 0.137 \begin{bmatrix} -0.4 \\ 0.2 \\ 0.7 \\ 0.5 \end{bmatrix} \\ &= \begin{bmatrix} 0.185 + 0.203 - 0.055 \\ 0.093 + 0.610 + 0.027 \\ -0.056 + 0.407 + 0.096 \\ 0.148 - 0.136 + 0.069 \end{bmatrix} = \begin{bmatrix} 0.333 \\ 0.730 \\ 0.447 \\ 0.081 \end{bmatrix} \end{align}$$

The context vector is dominated by $\vh_2$ due to the high attention weight $\alpha_{2,2} = 0.678$, but includes contributions from the other encoder states proportional to their attention weights.
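The hand computation above can be checked mechanically, starting from the same Step 1 scores and encoder states:

```python
import numpy as np

e = np.array([0.8, 2.1, 0.5])                 # Step 1: alignment scores
alpha = np.exp(e) / np.exp(e).sum()           # Step 2: softmax
H = np.array([[ 1.0, 0.5, -0.3,  0.8],
              [ 0.3, 0.9,  0.6, -0.2],
              [-0.4, 0.2,  0.7,  0.5]])
c2 = alpha @ H                                # Step 3: context vector
print(np.round(alpha, 3))                     # [0.185 0.678 0.137]
print(c2)                                     # close to [0.333, 0.730, 0.447, 0.081]
```

The tiny discrepancies in the last digit come from rounding the attention weights to three decimals in the hand computation.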

Scaled Dot-Product Attention

Scaled dot-product attention, introduced in the "Attention is All You Need" paper, represents a fundamental simplification and improvement over additive attention. By replacing the learned additive scoring function with a simple scaled dot product, this mechanism achieves superior computational efficiency while maintaining or improving model performance. This design choice enabled the transformer architecture to scale to billions of parameters and become the foundation of modern large language models.

Definition: Given queries $\mQ \in \R^{m \times d_k}$, keys $\mK \in \R^{n \times d_k}$, and values $\mV \in \R^{n \times d_v}$, scaled dot-product attention computes:

Step 1: Compute attention scores

$$ \mE = \mQ \mK\transpose \in \R^{m \times n} $$
where entry $e_{i,j} = \vq_i\transpose \vk_j$ measures the compatibility of query $i$ with key $j$.

Step 2: Scale by $\sqrt{d_k}$

$$ \mE_{\text{scaled}} = \frac{\mQ \mK\transpose}{\sqrt{d_k}} $$

Step 3: Softmax over keys (row-wise)

$$ \mA = \text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right) \in \R^{m \times n} $$

Step 4: Apply attention to values

$$ \text{Attention}(\mQ, \mK, \mV) = \mA \mV \in \R^{m \times d_v} $$

The complete formula in one line:

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right) \mV $$
\begin{tikzpicture}[
  matrix/.style={rectangle, draw, fill=blue!15, minimum width=1.5cm, minimum height=1cm, font=\small},
  operation/.style={rectangle, draw, fill=green!15, minimum width=1.8cm, minimum height=0.7cm, font=\small},
  arrow/.style={->, >=stealth, thick}
]
  \node[matrix] (Q) at (0,0) {$\mQ$};
  \node[matrix] (K) at (2.5,0) {$\mK$};
  \node[matrix] (V) at (5,0) {$\mV$};

  \node[operation] (QKT) at (1.25,-2.5) {$\mQ\mK^\top$};
  \node[operation] (scale) at (1.25,-4) {Scale};
  \node[operation] (softmax) at (1.25,-5.5) {Softmax};
  \node[operation] (AV) at (3.5,-7) {$\mA\mV$};

  \node[matrix] (output) at (3.5,-8.5) {Output};

  \draw[arrow] (Q) -- (QKT);
  \draw[arrow] (K) -- (QKT);
  \draw[arrow] (QKT) -- (scale);
  \draw[arrow] (scale) -- (softmax);
  \draw[arrow] (softmax) -- (AV);
  \draw[arrow] (V) -- (AV);
  \draw[arrow] (AV) -- (output);
\end{tikzpicture}

Scaled dot-product attention computational flow. Queries $\mQ$ and keys $\mK$ are combined via matrix multiplication to produce attention scores ($m \times n$ matrix). After scaling and softmax, the attention weights $\mA$ form a probability distribution over keys for each query. These weights are applied to values $\mV$ to produce the final output. Each of the $m$ queries attends to all $n$ keys, creating $m \times n$ attention connections.
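The same four-step flow is a few lines of NumPy (the `softmax` helper and toy sizes are ours):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # max-shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    E = Q @ K.T / np.sqrt(d_k)   # Steps 1-2: (m, n) scaled scores
    A = softmax(E, axis=-1)      # Step 3: row-wise weights over the n keys
    return A @ V, A              # Step 4: (m, d_v) outputs, plus the weights

rng = np.random.default_rng(0)
m, n, d_k, d_v = 4, 6, 8, 5      # toy sizes
Q, K = rng.normal(size=(m, d_k)), rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))
out, A = attention(Q, K, V)
```

Unlike the Bahdanau sketch earlier, the whole computation is two dense matrix multiplications plus a softmax, which is the source of the hardware efficiency discussed below.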

Why Scaling Matters: Variance Analysis

The scaling factor $1/\sqrt{d_k}$ is not merely a normalization convenience—it is essential for maintaining stable gradients during training. To understand why, we analyze the variance of dot products between queries and keys.

Assume the elements of the query and key vectors are independent with zero mean and unit variance: $\mathbb{E}[q_i] = \mathbb{E}[k_i] = 0$ and $\text{Var}(q_i) = \text{Var}(k_i) = 1$. The dot product between a query and key is:

$$ \vq\transpose \vk = \sum_{i=1}^{d_k} q_i k_i $$

Since the elements are independent, the variance of the sum equals the sum of variances:

$$ \text{Var}(\vq\transpose \vk) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] = \sum_{i=1}^{d_k} 1 \cdot 1 = d_k $$

Without scaling, the dot product has variance $d_k$, which grows linearly with the key dimension. For $d_k = 64$, typical dot products have standard deviation $\sqrt{64} = 8$. For $d_k = 512$, the standard deviation grows to $\sqrt{512} \approx 22.6$. These large magnitudes cause severe problems for the softmax function.

Softmax Saturation Problem: The softmax function is defined as:

$$ \text{softmax}(\vz)_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)} $$

When input magnitudes are large, softmax saturates—one element dominates and receives nearly all the probability mass, while others receive exponentially small probabilities. Consider a simple example with two elements:

$$ \text{softmax}([z, 0]) = \left[\frac{\exp(z)}{\exp(z) + 1}, \frac{1}{\exp(z) + 1}\right] $$

For $z = 10$: $\text{softmax}([10, 0]) \approx [0.9999, 0.0001]$. For $z = 20$: $\text{softmax}([20, 0]) \approx [1.0, 2 \times 10^{-9}]$. The distribution becomes a hard selection rather than a soft weighting.
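The saturation is easy to reproduce numerically with a stable two-element softmax:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())      # subtract the max for numerical stability
    return e / e.sum()

# As z grows, the distribution collapses to a hard selection.
for z in (1.0, 10.0, 20.0):
    print(z, softmax(np.array([z, 0.0])))
```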

Gradient Flow Analysis: The gradient of softmax with respect to its input is:

$$ \frac{\partial \text{softmax}(\vz)_i}{\partial z_j} = \text{softmax}(\vz)_i (\delta_{ij} - \text{softmax}(\vz)_j) $$

When softmax saturates with one element near 1 and others near 0, these gradients become tiny. For the dominant element $i$ where $\text{softmax}(\vz)_i \approx 1$:

$$ \frac{\partial \text{softmax}(\vz)_i}{\partial z_i} \approx 1 \cdot (1 - 1) = 0 $$

For non-dominant elements where $\text{softmax}(\vz)_j \approx 0$:

$$ \frac{\partial \text{softmax}(\vz)_i}{\partial z_j} \approx 1 \cdot (0 - 0) = 0 $$

All gradients vanish, preventing the model from learning. This is analogous to the vanishing gradient problem in deep networks, but occurring within a single attention layer.

Scaling Solution: Dividing by $\sqrt{d_k}$ normalizes the variance:

$$ \text{Var}\left(\frac{\vq\transpose \vk}{\sqrt{d_k}}\right) = \frac{1}{d_k} \text{Var}(\vq\transpose \vk) = \frac{1}{d_k} \cdot d_k = 1 $$

With unit variance, dot products typically range from $-3$ to $+3$ (within three standard deviations), keeping softmax in its sensitive region where gradients are substantial. This maintains effective gradient flow throughout training.
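A short Monte-Carlo simulation confirms the variance analysis: for unit-variance queries and keys, $\text{Var}(\vq\transpose\vk)$ grows as $d_k$, while the scaled version stays near 1.

```python
import numpy as np

rng = np.random.default_rng(0)
for d_k in (64, 512):
    q = rng.normal(size=(5000, d_k))
    k = rng.normal(size=(5000, d_k))
    dots = (q * k).sum(axis=1)            # 5000 sampled dot products
    # Empirical variance before and after dividing by sqrt(d_k)
    print(d_k, dots.var(), (dots / np.sqrt(d_k)).var())
```

The printed unscaled variances land near 64 and 512 respectively (up to sampling noise), while the scaled variances hover near 1 for both dimensions.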

Numerical Example: Consider $d_k = 64$ versus $d_k = 512$ with random unit-variance queries and keys. Without scaling, for $d_k = 64$, a typical attention score might be $\vq\transpose \vk = 12.3$. After softmax over 10 keys with similar magnitudes, the distribution might be $[0.45, 0.18, 0.12, 0.08, 0.06, 0.04, 0.03, 0.02, 0.01, 0.01]$—reasonably distributed. The gradient of the top element is approximately $0.45 \times (1 - 0.45) = 0.248$, which is healthy.

For $d_k = 512$ without scaling, the same query-key pair might produce $\vq\transpose \vk = 35.2$. After softmax, the distribution becomes $[0.9997, 0.0001, 0.0001, 0.0001, \ldots]$—completely saturated. The gradient is approximately $0.9997 \times (1 - 0.9997) = 0.0003$, which is 800 times smaller. Over many layers, these tiny gradients compound, making training extremely difficult or impossible.

With scaling by $\sqrt{512} \approx 22.6$, the score becomes $35.2 / 22.6 \approx 1.56$, producing a softmax distribution like $[0.38, 0.15, 0.12, 0.10, \ldots]$ with gradient $0.38 \times (1 - 0.38) = 0.236$—similar to the $d_k = 64$ case. The scaling makes attention behavior independent of the key dimension, enabling stable training across different model sizes.

Computational Efficiency

Scaled dot-product attention achieves dramatically better computational efficiency than additive attention, both in raw FLOP count and in hardware utilization. This efficiency difference is why transformers can scale to billions of parameters while additive attention models remained limited to hundreds of millions.

FLOP Count Comparison: For $m$ queries and $n$ keys with dimension $d_k$:

Scaled dot-product attention:

$$\begin{align} \mQ \mK\transpose &: \quad 2mnd_k \text{ FLOPs} \\ \text{Scaling} &: \quad mn \text{ FLOPs (division)} \\ \text{Softmax} &: \quad \approx 5mn \text{ FLOPs (exp, sum, divide)} \\ \mA \mV &: \quad 2mnd_v \text{ FLOPs} \\ \text{Total} &: \quad 2mn(d_k + d_v) + 6mn \approx 2mn(d_k + d_v) \end{align}$$

For $d_k = d_v = 64$, $m = n = 512$:

$$ 2 \times 512 \times 512 \times (64 + 64) = 67{,}108{,}864 \text{ FLOPs} \approx 67 \text{ MFLOPs} $$
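A one-line cross-check of this count (the function name is ours):

```python
def sdpa_flops(m, n, d_k, d_v):
    """Dominant cost of scaled dot-product attention: QK^T plus AV."""
    return 2 * m * n * (d_k + d_v)

print(f"{sdpa_flops(512, 512, 64, 64):,}")  # 67,108,864
```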

Bahdanau attention: As computed earlier, for $d_a = 256$, $d_s = d_h = 512$, $m = n = 512$:

$$ 2 \times 512 \times 512 \times 256 \times (512 + 512 + 3) \approx 138 \text{ billion FLOPs} $$

Scaled dot-product attention requires roughly 2000× fewer FLOPs than Bahdanau attention for this configuration! The gap comes from the scoring functions themselves: Bahdanau's cost per query-key pair scales with $d_a(d_s + d_h)$, while scaled dot-product attention costs only about $d_k + d_v$ per pair, so the advantage holds at every sequence length.

Hardware Efficiency: Beyond raw FLOP count, scaled dot-product attention maps naturally to highly optimized GPU operations. The core computation $\mQ \mK\transpose$ is a dense matrix multiplication (GEMM), which is the most optimized operation on modern GPUs. NVIDIA's cuBLAS library and Tensor Cores are specifically designed for GEMM, achieving 80-90\% of theoretical peak performance.

In contrast, Bahdanau attention requires element-wise operations ($\tanh$), vector additions, and many small matrix-vector products. These operations achieve only 15-25\% of peak GPU performance due to memory bandwidth limitations and poor parallelization. The $\tanh$ activation requires computing exponentials for each element, which is slow compared to the fused multiply-add operations in GEMM.

Memory Bandwidth Considerations: Modern GPUs are often memory-bandwidth limited rather than compute-limited. The NVIDIA A100 has 312 TFLOPS of FP16 compute but only 1.5 TB/s memory bandwidth. For operations to be compute-bound, they must perform many FLOPs per byte loaded from memory.

Matrix multiplication $\mQ \mK\transpose$ for $\mQ, \mK \in \R^{512 \times 64}$ loads $2 \times 512 \times 64 \times 2 = 131$ KB (FP16) and performs $2 \times 512 \times 512 \times 64 \approx 33.5$ MFLOPs, achieving $33{,}554{,}432 / 131{,}072 = 256$ FLOPs per byte. This high arithmetic intensity keeps the GPU compute units busy.

Bahdanau's element-wise operations load data, perform a few operations, and store results—achieving only 1-5 FLOPs per byte. The GPU spends most of its time waiting for memory rather than computing, wasting the available compute capacity.

Practical Performance: On an NVIDIA A100 GPU, computing attention for a batch of 32 sequences with $n = 512$ and $d_k = 64$, scaled dot-product attention runs roughly 19× faster than Bahdanau attention end to end.

This 19× speedup is what enables training GPT-3 scale models (175B parameters) in reasonable time. With Bahdanau attention, training would take 19× longer, making such models economically infeasible.

Example: Consider a single query attending to 3 keys with $d_k = 4$ and $d_v = 5$:
$$ \vq = \begin{bmatrix} 1.0 \\ 0.5 \\ -0.3 \\ 0.8 \end{bmatrix}, \quad \mK = \begin{bmatrix} 0.8 & 0.2 & -0.1 & 0.5 \\ 0.3 & 0.7 & 0.4 & -0.2 \\ -0.5 & 0.1 & 0.9 & 0.6 \end{bmatrix} $$

Step 1: Compute dot products between the query and each key:

$$\begin{align} \vq\transpose \vk_1 &= 1.0(0.8) + 0.5(0.2) + (-0.3)(-0.1) + 0.8(0.5) \\ &= 0.8 + 0.1 + 0.03 + 0.4 = 1.33 \\ \vq\transpose \vk_2 &= 1.0(0.3) + 0.5(0.7) + (-0.3)(0.4) + 0.8(-0.2) \\ &= 0.3 + 0.35 - 0.12 - 0.16 = 0.37 \\ \vq\transpose \vk_3 &= 1.0(-0.5) + 0.5(0.1) + (-0.3)(0.9) + 0.8(0.6) \\ &= -0.5 + 0.05 - 0.27 + 0.48 = -0.24 \end{align}$$

Step 2: Scale by $\sqrt{d_k} = \sqrt{4} = 2$:

$$ \text{scaled scores} = \left[\frac{1.33}{2}, \frac{0.37}{2}, \frac{-0.24}{2}\right] = [0.665, 0.185, -0.120] $$

Without scaling, the scores would be $[1.33, 0.37, -0.24]$. For this small example with $d_k = 4$, the difference is modest. But for $d_k = 64$, unscaled scores would be $\sqrt{64/4} = 4$ times larger, and for $d_k = 512$, they would be $\sqrt{512/4} \approx 11.3$ times larger, causing severe softmax saturation.

Step 3: Apply softmax to obtain attention weights:

$$\begin{align} \sum_j \exp(\text{score}_j) &= \exp(0.665) + \exp(0.185) + \exp(-0.120) \\ &\approx 1.945 + 1.203 + 0.887 = 4.035 \end{align}$$

Computing each weight:

$$\begin{align} \alpha_1 &= \frac{1.945}{4.035} \approx 0.482 \\ \alpha_2 &= \frac{1.203}{4.035} \approx 0.298 \\ \alpha_3 &= \frac{0.887}{4.035} \approx 0.220 \end{align}$$

The attention is distributed across all three keys, with the highest weight on key 1 (48.2\%) but substantial attention to keys 2 and 3 as well. This soft distribution allows the model to incorporate information from multiple positions.

Step 4: Apply attention weights to values. Suppose:

$$ \mV = \begin{bmatrix} 0.5 & 0.8 & -0.2 & 0.6 & 0.3 \\ 0.2 & -0.4 & 0.7 & 0.1 & 0.9 \\ -0.3 & 0.5 & 0.4 & -0.6 & 0.2 \end{bmatrix} \in \R^{3 \times 5} $$

The output is:

$$\begin{align} \text{output} &= 0.482 \vv_1 + 0.298 \vv_2 + 0.220 \vv_3 \\ &= 0.482 \begin{bmatrix} 0.5 \\ 0.8 \\ -0.2 \\ 0.6 \\ 0.3 \end{bmatrix} + 0.298 \begin{bmatrix} 0.2 \\ -0.4 \\ 0.7 \\ 0.1 \\ 0.9 \end{bmatrix} + 0.220 \begin{bmatrix} -0.3 \\ 0.5 \\ 0.4 \\ -0.6 \\ 0.2 \end{bmatrix} \\ &= \begin{bmatrix} 0.241 + 0.060 - 0.066 \\ 0.386 - 0.119 + 0.110 \\ -0.096 + 0.209 + 0.088 \\ 0.289 + 0.030 - 0.132 \\ 0.145 + 0.268 + 0.044 \end{bmatrix} = \begin{bmatrix} 0.235 \\ 0.377 \\ 0.201 \\ 0.187 \\ 0.457 \end{bmatrix} \in \R^5 \end{align}$$

The output vector is a weighted combination of the value vectors, with weights determined by the query-key similarities. This output can then be used by subsequent layers in the transformer.
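The worked example can be verified end to end with the same query, keys, and values:

```python
import numpy as np

q = np.array([1.0, 0.5, -0.3, 0.8])
K = np.array([[ 0.8, 0.2, -0.1,  0.5],
              [ 0.3, 0.7,  0.4, -0.2],
              [-0.5, 0.1,  0.9,  0.6]])
V = np.array([[ 0.5,  0.8, -0.2,  0.6, 0.3],
              [ 0.2, -0.4,  0.7,  0.1, 0.9],
              [-0.3,  0.5,  0.4, -0.6, 0.2]])

scores = K @ q / np.sqrt(4)                    # Steps 1-2: [0.665, 0.185, -0.12]
alpha = np.exp(scores) / np.exp(scores).sum()  # Step 3: ~[0.482, 0.298, 0.220]
out = alpha @ V                                # Step 4
print(out)                                     # close to the hand-computed result
```

Any last-digit differences from the hand computation stem from rounding the weights to three decimals above.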

Attention Score Computation Methods

The following table summarizes the four major attention scoring functions introduced in this chapter:

| Method              | Computation                                      | Parameters          | GPU Util. | Used In             |
|---------------------|--------------------------------------------------|---------------------|-----------|---------------------|
| Additive (Bahdanau) | $\mathbf{v}\transpose\tanh(\mW_1\vq + \mW_2\vk)$ | $O(d_a(d_q + d_k))$ | 15--25\%  | Early seq2seq       |
| Dot-product         | $\vq\transpose\vk$                               | 0                   | 80--90\%  | Not used (unstable) |
| Scaled dot-product  | $\vq\transpose\vk/\sqrt{d_k}$                    | 0                   | 80--90\%  | All transformers    |
| General (Luong)     | $\vq\transpose\mW\vk$                            | $O(d_q d_k)$        | 50--70\%  | Some seq2seq        |

Scaled dot-product attention dominates modern transformers due to its parameter-free nature, high GPU utilization from regular matrix multiplication structure, and strong empirical performance. The simplicity enables hardware-specific optimizations like FlashAttention (Chapter~16).
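The four scoring functions are compact enough to state side by side for one query-key pair (all sizes and weights below are illustrative, randomly initialized):

```python
import numpy as np

rng = np.random.default_rng(0)
d_q = d_k = 8
d_a = 4
q, k = rng.normal(size=d_q), rng.normal(size=d_k)
W1, W2 = rng.normal(size=(d_a, d_q)), rng.normal(size=(d_a, d_k))
v, W = rng.normal(size=d_a), rng.normal(size=(d_q, d_k))

additive   = v @ np.tanh(W1 @ q + W2 @ k)  # Bahdanau: learned W1, W2, v
dot        = q @ k                         # parameter-free, unscaled
scaled_dot = q @ k / np.sqrt(d_k)          # parameter-free, variance-controlled
general    = q @ W @ k                     # Luong: one bilinear matrix W
```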

Query-Key-Value Paradigm

Intuition

The query-key-value (QKV) framework provides an elegant abstraction for understanding attention mechanisms through the lens of information retrieval. This paradigm, borrowed from database systems and search engines, offers intuitive explanations for attention's behavior while precisely defining its mathematical operations.

Consider a database system where you want to retrieve relevant information. You provide a query describing what you're looking for, the system compares your query against keys (indexed descriptions of stored content), and returns the values (actual content) associated with the most relevant keys. Attention mechanisms operate identically: queries represent "what I'm looking for," keys represent "what information is available," and values represent "the actual information to retrieve."

In the context of neural networks, these three components serve distinct roles. The query $\vq$ encodes the current position's information needs—what aspects of the input are relevant for processing this position. The keys $\vk_i$ encode what information each input position offers—what content is available at that position. The values $\vv_i$ encode the actual information to be retrieved—the representations that will be combined to form the output.

This separation of concerns is crucial. By decoupling "what to look for" (queries) from "what is available" (keys) and "what to retrieve" (values), the attention mechanism gains flexibility. The same input can be queried in different ways by different positions, and the retrieved information can differ from the indexing representation. This three-way separation enables the model to learn rich, task-specific attention patterns.

Concrete Example: In machine translation, when generating the French word "chat" (cat) from the English sentence "The cat sat on the mat," the decoder's query encodes "I need information about the subject noun." The keys encode what each English word represents: "the" offers determiner information, "cat" offers subject noun information, "sat" offers verb information, etc. The attention mechanism computes high similarity between the query and the key for "cat," then retrieves the value associated with "cat"—a rich representation encoding its meaning, grammatical role, and context.

Importantly, the key and value for "cat" can differ. The key might emphasize grammatical features (noun, singular, animate) that help match queries, while the value emphasizes semantic features (animal, feline, pet) that are useful for generation. This separation allows the attention mechanism to index on one set of features while retrieving another.

Projecting to QKV

In transformers, queries, keys, and values are not provided directly but are computed from the input through learned linear projections. This design choice allows the model to learn task-specific representations for each role rather than using the raw input embeddings.

Given input $\mX \in \R^{n \times d_{\text{model}}}$ where $n$ is the sequence length and $d_{\text{model}}$ is the model dimension, we compute:

$$\begin{align} \mQ &= \mX \mW^Q && \mW^Q \in \R^{d_{\text{model}} \times d_k} \\ \mK &= \mX \mW^K && \mW^K \in \R^{d_{\text{model}} \times d_k} \\ \mV &= \mX \mW^V && \mW^V \in \R^{d_{\text{model}} \times d_v} \end{align}$$

Each projection matrix is a learned parameter that transforms the input into the appropriate representation space. The query and key projections map to the same dimension $d_k$ (typically $d_{\text{model}}/h$ where $h$ is the number of attention heads) because they must be compatible for dot products. The value projection maps to dimension $d_v$, which is often equal to $d_k$ but can differ.

Why Learn Separate Projections? One might ask: why not use the input $\mX$ directly as queries, keys, and values? The answer lies in representation learning. The raw input embeddings encode general semantic and syntactic information, but attention requires specialized representations. The query projection learns to emphasize features relevant for determining what to attend to. The key projection learns to emphasize features relevant for being attended to. The value projection learns to emphasize features relevant for the output representation.

These three projections can learn different aspects of the input. For example, in a language model, the query projection might emphasize the current word's part of speech and semantic category to determine what context is needed. The key projection might emphasize each word's grammatical role and position to help queries find relevant context. The value projection might emphasize semantic content and relationships to provide useful information for prediction.

Computational Cost: Each projection is a matrix multiplication requiring $2nd_{\text{model}}d_k$ FLOPs (for queries and keys) or $2nd_{\text{model}}d_v$ FLOPs (for values). With three projections and $d_k = d_v$:

$$ \text{QKV projection FLOPs} = 3 \times 2nd_{\text{model}}d_k = 6nd_{\text{model}}d_k $$

For typical transformer configurations where $d_k = d_{\text{model}}/h$ and we consider all $h$ heads together (so $hd_k = d_{\text{model}}$):

$$ \text{QKV projection FLOPs} = 6nd_{\text{model}}^2 $$

For BERT-base with $n = 512$ and $d_{\text{model}} = 768$:

$$ 6 \times 512 \times 768^2 = 1{,}811{,}939{,}328 \text{ FLOPs} \approx 1.8 \text{ GFLOPs} $$

This is substantial but represents only about 20\% of the total attention computation for typical sequence lengths. The attention score computation ($\mQ \mK\transpose$) and output computation ($\mA \mV$) dominate for longer sequences.

Parameter Count: The three projection matrices introduce $d_{\text{model}}(2d_k + d_v)$ parameters per attention head. For $h$ heads with $d_k = d_v = d_{\text{model}}/h$:

$$ \text{QKV parameters} = h \times d_{\text{model}} \times 3 \times \frac{d_{\text{model}}}{h} = 3d_{\text{model}}^2 $$

For BERT-base with $d_{\text{model}} = 768$: $3 \times 768^2 = 1{,}769{,}472$ parameters per attention layer. With 12 layers, the QKV projections account for $12 \times 1.77 = 21.2$ million parameters out of BERT's total 110 million—about 19\% of the model.
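Both figures are easy to sanity-check in a few lines, using the BERT-base numbers above:

```python
# Sanity check of the QKV FLOP and parameter counts for BERT-base.
n, d_model, layers = 512, 768, 12

qkv_flops = 6 * n * d_model ** 2       # 6 * n * d_model^2
print(qkv_flops)                       # 1811939328, i.e. ~1.8 GFLOPs

qkv_params_per_layer = 3 * d_model ** 2
print(qkv_params_per_layer)            # 1769472
print(layers * qkv_params_per_layer)   # 21233664, i.e. ~21.2 million
```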

Example: Consider a sequence of 5 tokens, each represented by a $d_{\text{model}} = 512$ dimensional vector:
$$ \mX \in \R^{5 \times 512} $$

We project to $d_k = d_v = 64$ (as in a single attention head of a model with $h = 8$ heads):

$$\begin{align} \mQ &= \mX \mW^Q \in \R^{5 \times 64} \quad (\mW^Q \in \R^{512 \times 64}) \\ \mK &= \mX \mW^K \in \R^{5 \times 64} \quad (\mW^K \in \R^{512 \times 64}) \\ \mV &= \mX \mW^V \in \R^{5 \times 64} \quad (\mW^V \in \R^{512 \times 64}) \end{align}$$

Each projection matrix has $512 \times 64 = 32{,}768$ parameters. Computing each projection requires $2 \times 5 \times 512 \times 64 = 327{,}680$ FLOPs, for a total of $3 \times 327{,}680 = 983{,}040$ FLOPs across all three projections.
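In PyTorch, this example looks as follows; random matrices stand in for the learned projections:

```python
import torch

torch.manual_seed(0)
n, d_model, d_k, d_v = 5, 512, 64, 64

X = torch.randn(n, d_model)  # 5 tokens, each a 512-dim vector

# Learned parameters in practice; random placeholders here.
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)  # torch.Size([5, 64]) for each
```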

Attention computation: After projection, we compute attention:

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{64}}\right) \mV $$

The attention matrix $\mA = \text{softmax}(\mQ \mK\transpose / \sqrt{64}) \in \R^{5 \times 5}$ has entry $a_{ij}$ representing how much position $i$ attends to position $j$. For example:

$$ \mA = \begin{bmatrix} 0.45 & 0.25 & 0.15 & 0.10 & 0.05 \\ 0.10 & 0.50 & 0.25 & 0.10 & 0.05 \\ 0.05 & 0.15 & 0.40 & 0.30 & 0.10 \\ 0.05 & 0.10 & 0.20 & 0.50 & 0.15 \\ 0.05 & 0.05 & 0.10 & 0.25 & 0.55 \end{bmatrix} $$

Position 1 attends most strongly to itself (45\%) and position 2 (25\%). Position 5 attends most strongly to itself (55\%) and position 4 (25\%). This pattern might emerge in a language model where each position attends to nearby context, with stronger attention to the current position and recent tokens.

The output $\mA \mV \in \R^{5 \times 64}$ provides an attended representation for each position, combining information from all positions according to the attention weights. This output can then be processed by subsequent layers.
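Continuing the example, the attention step itself is only a few lines; random tensors again stand in for projected values:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q, K, V = torch.randn(5, 64), torch.randn(5, 64), torch.randn(5, 64)

A = F.softmax(Q @ K.T / 64 ** 0.5, dim=-1)  # (5, 5) attention weights
output = A @ V                              # (5, 64) attended representations

print(A.sum(dim=-1))  # each row sums to 1, as in the matrix above
print(output.shape)   # torch.Size([5, 64])
```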

Hardware Implications of Attention

Attention mechanisms align well with modern GPU architectures for two key reasons. First, attention eliminates the sequential bottleneck of RNNs: all positions are processed simultaneously via batched matrix multiplications, achieving 75\% GPU utilization compared to $\sim$5\% for LSTMs (see Chapter~6 for the detailed comparison). Second, the core $\mQ\mK\transpose$ operation achieves high arithmetic intensity ($\sim$256 FLOPs/byte for typical dimensions), keeping compute units busy rather than waiting for memory transfers.

The attention matrix requires $b \times h \times n^2 \times 4$ bytes of memory (batch size $\times$ heads $\times$ sequence length squared $\times$ FP32). For BERT-base ($b=32$, $h=12$, $n=512$), this is $\sim$402~MB per layer, accounting for 57\% of per-layer activation memory (see Section~[ref] for the complete memory breakdown). This quadratic cost is the primary bottleneck for long sequences and motivates efficient attention methods (Chapters~9 and~16).
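The quoted figure follows directly from the formula:

```python
# Attention-matrix activation memory per layer for BERT-base in FP32.
b, h, n = 32, 12, 512
attn_bytes = b * h * n * n * 4  # batch x heads x n^2 x 4 bytes
print(attn_bytes)        # 402653184 bytes
print(attn_bytes / 1e6)  # ~402.7 MB
```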

Attention Variants

Self-Attention vs Cross-Attention

Self-Attention: $\mQ$, $\mK$, $\mV$ are all computed from the same source $\mX$, each with its own projection:

$$ \mQ = \mX \mW^Q, \quad \mK = \mX \mW^K, \quad \mV = \mX \mW^V $$
Used in: Transformer encoder, BERT

Cross-Attention: Queries from one source, keys and values from another

$$ \mQ = \mX_{\text{dec}} \mW^Q, \quad \mK = \mX_{\text{enc}} \mW^K, \quad \mV = \mX_{\text{enc}} \mW^V $$
Used in: Transformer decoder (attending to encoder output)

Masked Attention

For autoregressive models (GPT), prevent attending to future positions:

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}} + \mM\right) \mV $$
where the mask $\mM_{ij} = -\infty$ if $j > i$, else $\mM_{ij} = 0$.

Since $\exp(-\infty) = 0$, masked positions receive exactly zero weight after the softmax, so no position can attend to the future.
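A minimal sketch of causal masking in PyTorch. It applies $-\infty$ via `masked_fill` on the scaled scores, which is equivalent to adding the mask matrix $\mM$:

```python
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    d_k = Q.size(-1)
    n = Q.size(-2)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    # True above the diagonal: positions j > i are masked out
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
Q = K = V = torch.randn(6, 16)
out, w = causal_attention(Q, K, V)
print(w[0])  # first position can only attend to itself: [1, 0, 0, 0, 0, 0]
```

Each row of `w` still sums to 1; the masked entries are exactly zero, not merely small.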

Exercises

Exercise 1: Compute Bahdanau attention for sequence length 4, decoder state dim 3, attention dim 2. Given specific $\mW_1$, $\mW_2$, $\mathbf{v}$, encoder states, and decoder state, calculate all attention weights.
Exercise 2: For scaled dot-product attention with $\mQ \in \R^{10 \times 64}$, $\mK \in \R^{20 \times 64}$, $\mV \in \R^{20 \times 128}$: (1) What is output dimension? (2) What is attention matrix shape? (3) How many FLOPs for computing $\mQ \mK\transpose$?
Exercise 3: Show that without scaling, for $d_k = 64$ and unit variance elements, dot products have variance 64. Demonstrate numerically how this affects softmax gradients.
Exercise 4: Implement scaled dot-product attention in PyTorch. Test with sequences of length 5 and 10, dimensions $d_k = 32$, $d_v = 48$. Visualize attention weights as heatmap.

Solutions

Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.

Solution: For Bahdanau attention with sequence length 4, decoder state dim 3, attention dim 2:

Given:

$$ \mW_1 = \begin{bmatrix} 0.5 & -0.3 & 0.2 \\ 0.4 & 0.6 & -0.1 \end{bmatrix}, \quad \mW_2 = \begin{bmatrix} 0.3 & 0.5 & 0.2 \\ -0.2 & 0.4 & 0.6 \end{bmatrix}, \quad \mathbf{v} = \begin{bmatrix} 1.0 \\ 0.8 \end{bmatrix} $$

Encoder states: $\vh_1 = [1, 0, 1]\transpose$, $\vh_2 = [0, 1, 1]\transpose$, $\vh_3 = [1, 1, 0]\transpose$, $\vh_4 = [0, 0, 1]\transpose$

Decoder state: $\vs = [0.5, 0.5, 0.5]\transpose$

Step 1: Compute alignment scores

$$\begin{align} e_i &= \mathbf{v}\transpose \tanh(\mW_1 \vh_i + \mW_2 \vs) \end{align}$$

For $i=1$:

$$\begin{align} \mW_1 \vh_1 + \mW_2 \vs &= \begin{bmatrix} 0.5 + 0.2 \\ 0.4 - 0.1 \end{bmatrix} + \begin{bmatrix} 0.5 \\ 0.4 \end{bmatrix} = \begin{bmatrix} 1.2 \\ 0.7 \end{bmatrix} \\ e_1 &= [1.0, 0.8] \cdot \tanh([1.2, 0.7]\transpose) \approx 1.0(0.834) + 0.8(0.604) \approx 1.317 \end{align}$$

Similarly: $e_2 \approx 0.953$, $e_3 \approx 1.313$, $e_4 \approx 0.837$

Step 2: Apply softmax

$$ \alpha_i = \frac{\exp(e_i)}{\sum_{j=1}^4 \exp(e_j)} $$
$$ \boldsymbol{\alpha} \approx [0.302, 0.210, 0.301, 0.187] $$

These are the attention weights showing how much the decoder attends to each encoder state.
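A short numerical check of the alignment scores and weights, vectorized over all four encoder states:

```python
import torch

W1 = torch.tensor([[0.5, -0.3,  0.2],
                   [0.4,  0.6, -0.1]])
W2 = torch.tensor([[ 0.3, 0.5, 0.2],
                   [-0.2, 0.4, 0.6]])
v = torch.tensor([1.0, 0.8])

H = torch.tensor([[1., 0., 1.],   # h_1
                  [0., 1., 1.],   # h_2
                  [1., 1., 0.],   # h_3
                  [0., 0., 1.]])  # h_4
s = torch.tensor([0.5, 0.5, 0.5])  # decoder state

# e_i = v^T tanh(W1 h_i + W2 s), computed for all i at once
e = torch.tanh(H @ W1.T + W2 @ s) @ v
alpha = torch.softmax(e, dim=0)
print(e)      # ≈ [1.317, 0.953, 1.313, 0.837]
print(alpha)  # ≈ [0.302, 0.210, 0.301, 0.187]
```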

Solution: For scaled dot-product attention with $\mQ \in \R^{10 \times 64}$, $\mK \in \R^{20 \times 64}$, $\mV \in \R^{20 \times 128}$:

(1) Output dimension:

$$ \text{Output} = \text{softmax}\left(\frac{\mQ\mK\transpose}{\sqrt{d_k}}\right)\mV \in \R^{10 \times 128} $$

(2) Attention matrix shape:

$$ \mA = \text{softmax}\left(\frac{\mQ\mK\transpose}{\sqrt{d_k}}\right) \in \R^{10 \times 20} $$

(3) FLOPs for $\mQ\mK\transpose$:

$$ \text{FLOPs} = 2 \times 10 \times 64 \times 20 = 25{,}600 $$

Solution: Variance analysis without scaling:

For $d_k = 64$ with independent, zero-mean, unit-variance elements:

$$ \text{Var}(\vq\transpose \vk) = \sum_{i=1}^{64} \text{Var}(q_i k_i) = 64 \cdot \text{Var}(q_i) \cdot \text{Var}(k_i) = 64 $$

Standard deviation: $\sigma = \sqrt{64} = 8$

Effect on softmax gradients:

Without scaling, dot products range roughly $[-24, 24]$ (three standard deviations), pushing the softmax toward near-one-hot outputs where gradients are small.

Numerical demonstration:

$$\begin{align} \text{Unscaled: } &\mathbf{z} = [20, 18, -15, -18] \\ &\text{softmax}(\mathbf{z}) \approx [0.881, 0.119, 0, 0] \\ &\text{Gradient} \approx [0.105, 0.105, 0, 0] \text{ (vanishing)} \end{align}$$
$$\begin{align} \text{Scaled by } \sqrt{64}: &\mathbf{z}' = [2.5, 2.25, -1.875, -2.25] \\ &\text{softmax}(\mathbf{z}') \approx [0.556, 0.433, 0.007, 0.005] \\ &\text{Gradient} \approx [0.247, 0.245, 0.007, 0.005] \text{ (healthy for the competing entries)} \end{align}$$

Scaling by $\sqrt{d_k}$ keeps dot products in a range where softmax gradients are well-behaved.
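The variance claim can also be demonstrated empirically by sampling many random query-key pairs:

```python
import torch

torch.manual_seed(0)
d_k = 64
q = torch.randn(100000, d_k)  # zero-mean, unit-variance components
k = torch.randn(100000, d_k)

dots = (q * k).sum(dim=1)     # 100k sample dot products
print(dots.var().item())                 # close to d_k = 64
print((dots / d_k ** 0.5).var().item())  # close to 1 after scaling
```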

Solution: PyTorch implementation:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output, attention_weights

# Test with sequence lengths 5 and 10 (as the exercise specifies)
for seq_len in (5, 10):
    Q = torch.randn(1, seq_len, 32)  # (batch, seq_len, d_k)
    K = torch.randn(1, seq_len, 32)
    V = torch.randn(1, seq_len, 48)  # d_v = 48

    output, weights = scaled_dot_product_attention(Q, K, V)
    print(f"Output shape: {output.shape}")              # (1, seq_len, 48)
    print(f"Attention weights shape: {weights.shape}")  # (1, seq_len, seq_len)

# Visualize attention weights
plt.imshow(weights[0].detach().numpy(), cmap='viridis')
plt.colorbar()
plt.xlabel('Key position')
plt.ylabel('Query position')
plt.title('Attention Weights Heatmap')
plt.show()

The heatmap shows which positions each query attends to, with brighter colors indicating higher attention weights.
