Self-Attention and Multi-Head Attention
Chapter Overview
Self-attention is the core innovation enabling transformers. This chapter develops self-attention from first principles, then introduces multi-head attention, the mechanism that allows transformers to attend to multiple types of relationships simultaneously.
Learning Objectives
- Understand self-attention and its advantages over RNNs
- Implement multi-head attention from scratch
- Compute output dimensions and parameter counts
- Understand positional encodings for sequence order
- Analyze computational complexity of attention
- Apply masking for causal (autoregressive) attention
Self-Attention Mechanism
Self-attention exhibits several fundamental properties that distinguish it from recurrent architectures. The mechanism is permutation equivariant: permuting the input positions permutes the outputs in the same way, so there is no inherent notion of sequence order without positional encodings. Every position in the sequence attends to every other position through all-to-all connections, creating direct paths between any pair of tokens regardless of their distance in the sequence. This contrasts sharply with RNNs, where information must propagate sequentially through intermediate hidden states, potentially degrading over long distances.
[Figure: self-attention connects every input $\vx_1, \vx_2, \vx_3$ directly to every output $\vo_1, \vo_2, \vo_3$ (left), whereas an RNN passes information along the chain $\vh_0 \to \vh_1 \to \vh_2 \to \vh_3$ (right).]
The parallel computation property is perhaps the most significant advantage for modern hardware. Unlike RNNs which process sequences sequentially due to the recurrence relation $\vh_t = f(\vh_{t-1}, \vx_t)$, self-attention computes all output positions simultaneously. This enables full utilization of GPU parallelism, where thousands of cores can work concurrently on different positions and attention heads. The long-range dependency modeling is direct rather than transitive: position 0 can attend to position 1000 with a single attention operation, whereas an RNN requires 1000 sequential steps, each potentially losing information through the recurrent bottleneck.
Projection matrices with $d_k = d_v = 3$:
Step 1: Project to QKV
Step 2: Compute attention scores
Entry $(i,j)$ measures how much position $i$ attends to position $j$.
Step 3: Scale and softmax
Each row sums to 1 (probability distribution over positions to attend to).
Step 4: Apply to values
Each output position is a weighted combination of all input value vectors.
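The four steps above can be sketched in a few lines of NumPy. This is a minimal illustration rather than a reference implementation: the input and projection matrices are random placeholders, the sizes $n = 3$ and $d_{\text{model}} = 4$ are assumptions (only $d_k = d_v = 3$ comes from the example), and the names `softmax` and `self_attention` are chosen here for illustration.

```python
import numpy as np

def softmax(scores, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention for X of shape (n, d_model)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v               # Step 1: project to Q, K, V
    scores = Q @ K.T                                  # Step 2: (n, n) raw scores
    weights = softmax(scores / np.sqrt(K.shape[-1]))  # Step 3: scale, softmax per row
    return weights @ V, weights                       # Step 4: weighted sum of values

# Placeholder data: n = 3 tokens, d_model = 4, d_k = d_v = 3.
rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))
W_q, W_k, W_v = (rng.normal(size=(4, 3)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape)            # (3, 3)
print(weights.sum(axis=1))  # every row of the attention matrix sums to 1
```

Permuting the rows of `X` permutes the rows of `out` in the same way, which is the permutation equivariance noted earlier.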
Hardware Considerations and Memory Layout
The memory layout of attention matrices in GPU memory significantly impacts performance. When computing self-attention for a batch of sequences, the attention matrix $\mA \in \R^{n \times n}$ for each head must be materialized in GPU global memory. For BERT-base with 12 attention heads and maximum sequence length 512, each attention matrix contains $512 \times 512 = 262{,}144$ elements. Storing these in FP32 format requires $262{,}144 \times 4 = 1{,}048{,}576$ bytes, or approximately 1 MB per head per sequence. With 12 heads, this amounts to 12 MB per sequence just for the attention weights themselves, not including the query, key, and value matrices.
The memory requirements scale dramatically with batch size. For a batch of 32 sequences (a typical training batch size), the attention matrices alone consume $12 \times 32 = 384$ MB of GPU memory in each layer. This explains why training transformers on long sequences quickly exhausts available GPU memory. For sequence length 2048, the attention matrices grow to $2048^2 \times 4 = 16{,}777{,}216$ bytes per head, or approximately 16 MB. With 12 heads and batch size 32, this becomes $16 \times 12 \times 32 = 6{,}144$ MB, or roughly 6 GB just for attention weights. An NVIDIA A100 with 40 GB of memory can accommodate this, but longer sequences of 4096 tokens would require $4096^2 \times 4 \times 12 \times 32 / (1024^3) \approx 24$ GB for attention matrices alone, leaving little room for activations, gradients, and model parameters.
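The arithmetic above is easy to reproduce. The helper below is a small sketch (the function name `attention_matrix_bytes` is hypothetical) that evaluates $B \cdot h \cdot n^2$ elements at 4 bytes each for one layer, using the BERT-base settings from the text.

```python
def attention_matrix_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes needed to materialize one layer's attention weights (B*h*n*n elements)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

MIB = 1024 ** 2
for n in (512, 2048, 4096):
    mib = attention_matrix_bytes(batch=32, heads=12, seq_len=n) / MIB
    print(f"n={n}: {mib:,.0f} MiB")
# n=512: 384 MiB
# n=2048: 6,144 MiB
# n=4096: 24,576 MiB
```

Passing `bytes_per_elem=2` gives the corresponding FP16 figures, which are half as large.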
The memory access patterns during attention computation determine whether the operation is compute-bound or memory-bound. Computing the attention scores $\mQ \mK\transpose$ involves reading the query matrix $\mQ \in \R^{n \times d_k}$ and key matrix $\mK \in \R^{n \times d_k}$ from global memory, performing $O(n^2 d_k)$ floating-point operations, and writing the result $\mS \in \R^{n \times n}$ back to memory. For small batch sizes, the computation is fast but the memory transfers dominate. An NVIDIA A100 has memory bandwidth of approximately 1.5 TB/s and peak FP16 compute throughput of 312 TFLOPS. For BERT-base with $n=512$ and $d_k=64$, computing $\mQ \mK\transpose$ requires reading $2 \times 512 \times 64 \times 2 = 131{,}072$ bytes (in FP16) and performing $512^2 \times 64 \times 2 = 33{,}554{,}432$ FLOPs. Counting only these reads, the arithmetic intensity is $33{,}554{,}432 / 131{,}072 = 256$ FLOPs per byte, which is reasonably high. Including the $512^2 \times 2 = 524{,}288$-byte write of $\mS$ lowers it to about 51 FLOPs per byte, well below the A100's ratio of peak compute to bandwidth ($312 / 1.5 \approx 208$ FLOPs per byte). The subsequent softmax is purely elementwise, and the multiplication by $\mV$ must read the full $n \times n$ matrix back, so both have similarly low arithmetic intensity, making attention memory-bound for small batches.
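As a rough check of these figures, the sketch below computes the arithmetic intensity of $\mQ \mK\transpose$ with and without the write of $\mS$. The function name and the `count_output_write` flag are illustrative assumptions; the model is deliberately simple and ignores caching and kernel fusion.

```python
def qk_arithmetic_intensity(n, d_k, bytes_per_elem=2, count_output_write=True):
    """FLOPs per byte moved for S = Q K^T with Q and K of shape (n, d_k)."""
    flops = 2 * n * n * d_k                     # one multiply and one add per term
    bytes_moved = 2 * n * d_k * bytes_per_elem  # read Q and K
    if count_output_write:
        bytes_moved += n * n * bytes_per_elem   # write the (n, n) score matrix
    return flops / bytes_moved

print(qk_arithmetic_intensity(512, 64, count_output_write=False))  # 256.0
print(qk_arithmetic_intensity(512, 64))                            # 51.2

# Peak compute divided by bandwidth, using the A100 figures quoted in the text.
print(312e12 / 1.5e12)                                             # 208.0
```

A kernel whose intensity falls below the last number cannot keep the arithmetic units busy, which is the sense in which unfused attention is memory-bound.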
Cache locality plays a crucial role in attention performance. Modern GPUs have a memory hierarchy with small but fast on-chip SRAM (shared memory) and large but slower off-chip DRAM (global memory). The attention computation as typically implemented requires multiple passes over the data: first computing $\mQ \mK\transpose$, then applying softmax, then multiplying by $\mV$. Each pass reads data from global memory, processes it, and writes results back. For long sequences, the attention matrix $\mA \in \R^{n \times n}$ is too large to fit in SRAM, forcing repeated global memory accesses. Flash Attention (Section~[ref]) addresses this by tiling the computation to fit in SRAM, dramatically reducing memory traffic and improving performance by 2-4$\times$ for long sequences.
Multi-Head Attention
Single-head attention with a single set of query, key, and value projections may capture only one type of relationship between tokens. In natural language, tokens relate to each other in multiple ways simultaneously: syntactically (subject-verb agreement, dependency structure), semantically (synonymy, antonymy, topic coherence), and positionally (proximity, relative ordering). A single attention head must compress all these relationship types into a single attention distribution, potentially losing important information. Multi-head attention addresses this limitation by computing multiple attention functions in parallel, each with its own learned projection matrices, allowing the model to attend to different types of relationships simultaneously.
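For each head $i = 1, \ldots, h$:
\[
\text{head}_i = \text{Attention}\big(\mX \mW^{Q(i)},\; \mX \mW^{K(i)},\; \mX \mW^{V(i)}\big)
\]
Concatenate and project:
\[
\text{MultiHead}(\mX) = [\text{head}_1; \ldots; \text{head}_h]\, \mW^O
\]
where $\mW^{Q(i)}, \mW^{K(i)}, \mW^{V(i)} \in \R^{d_{\text{model}} \times d_k}$ and $\mW^O \in \R^{hd_k \times d_{\text{model}}}$.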
[Figure: multi-head attention. The input $\mX$ feeds $h$ parallel heads (Head 1, Head 2, $\ldots$, Head $h$); each head applies its own projections followed by attention; the head outputs are concatenated and multiplied by $\mW^O$ to produce the output.]
Concatenating all 12 heads yields $\R^{512 \times 768}$, which the output projection $\mW^O \in \R^{768 \times 768}$ maps back to $\R^{512 \times 768}$. The total is $4 \times 768^2 = 2{,}359{,}296$ attention parameters per layer (see Section~[ref] for the full model analysis).
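A compact NumPy sketch of multi-head attention follows. It assumes the common implementation convention of stacking the per-head matrices $\mW^{Q(i)}$ column-wise into a single $d_{\text{model}} \times d_{\text{model}}$ matrix (likewise for keys and values), omits biases, and uses random weights and a short sequence ($n = 16$) purely for illustration. The function and variable names are not from the text.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """X: (n, d_model); W_q, W_k, W_v, W_o: (d_model, d_model), heads stacked by column."""
    n, d_model = X.shape
    d_k = d_model // num_heads

    def split_heads(M):
        # (n, d_model) -> (num_heads, n, d_k); head i gets columns i*d_k:(i+1)*d_k
        return M.reshape(n, num_heads, d_k).transpose(1, 0, 2)

    Q, K, V = (split_heads(X @ W) for W in (W_q, W_k, W_v))
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)       # (h, n, n)
    heads = softmax(scores) @ V                            # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)  # concatenate heads
    return concat @ W_o                                    # output projection

rng = np.random.default_rng(0)
n, d_model, h = 16, 768, 12
X = rng.normal(size=(n, d_model))
W_q, W_k, W_v, W_o = (0.02 * rng.normal(size=(d_model, d_model)) for _ in range(4))

out = multi_head_attention(X, W_q, W_k, W_v, W_o, h)
print(out.shape)                                     # (16, 768)
print(sum(W.size for W in (W_q, W_k, W_v, W_o)))     # 2359296
```

The reshape and transpose in `split_heads` is what lets all heads run as one batched matrix multiplication instead of a Python loop over heads.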
Parallel Computation and Memory Layout
Multiple attention heads can be computed in parallel on modern GPUs, with each head assigned to different streaming multiprocessors or computed concurrently through batched matrix operations. The key design decision is the memory layout: should the heads be stored in an interleaved fashion where all heads for a given position are contiguous, or should each head's data be stored separately? The interleaved layout $[\text{head}_1(\text{pos}_1), \text{head}_2(\text{pos}_1), \ldots, \text{head}_h(\text{pos}_1), \text{head}_1(\text{pos}_2), \ldots]$ provides better cache locality when concatenating heads for the output projection, since all data for a position is contiguous. The separated layout $[\text{head}_1(\text{pos}_1), \text{head}_1(\text{pos}_2), \ldots, \text{head}_2(\text{pos}_1), \ldots]$ allows each head to be processed independently with better memory coalescing within a head. Most implementations use the separated layout during attention computation and transpose to interleaved layout before the output projection.
The standard choice of $d_k = d_{\text{model}}/h$ ensures that the total number of parameters remains constant regardless of the number of heads. With $h$ heads each of dimension $d_k$, the total dimension after concatenation is $h \cdot d_k = d_{\text{model}}$, matching the input dimension. This design choice means that using more heads does not increase the parameter count; it simply partitions the representation space into more subspaces. For BERT-base with $d_{\text{model}} = 768$ and $h = 12$, each head has dimension $d_k = 64$. The QKV projection matrices have shape $768 \times 64$ per head, for a total of $3 \times 12 \times 768 \times 64 = 1{,}769{,}472$ parameters. If instead a single head with $d_k = 768$ were used, the QKV projections would have shape $768 \times 768$ each, for a total of $3 \times 768^2 = 1{,}769{,}472$ parameters, exactly the same. The difference lies not in parameter count but in representational capacity: multiple heads can learn diverse attention patterns, while a single large head must compress all patterns into one.
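The invariance of the parameter count is quick to confirm numerically. The helper name `qkv_params` is hypothetical, and biases are ignored as in the text.

```python
def qkv_params(d_model, num_heads):
    """Weights in the Q, K and V projections when d_k = d_model / num_heads."""
    d_k = d_model // num_heads
    return 3 * num_heads * d_model * d_k

for h in (1, 4, 12):
    print(h, qkv_params(768, h))  # 1769472 for every h
```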
Load balancing across heads is generally not a concern during training, as all heads are computed in parallel through batched matrix operations. However, during inference with dynamic batching or when pruning less important heads, load imbalance can occur. Some heads may be more important than others for the task at hand, and recent work has shown that many heads can be pruned without significant performance degradation. For example, in BERT-base, pruning 40\% of attention heads (keeping only 7-8 heads per layer) typically reduces accuracy by less than 1\% on downstream tasks, while reducing inference time by approximately 20\%. This suggests that the 12 heads provide redundancy and that the model could function with fewer heads, though training with more heads may help optimization by providing multiple gradient pathways.
Tensor Core Utilization
Modern NVIDIA GPUs include specialized Tensor Cores that accelerate matrix multiplication for reduced-precision formats. Tensor Cores on A100 GPUs can perform FP16 matrix multiplication at 312 TFLOPS, compared to 19.5 TFLOPS for standard FP32 CUDA cores, a 16$\times$ difference. However, Tensor Cores have alignment requirements: matrix dimensions should be multiples of 8 for FP16 or multiples of 16 for INT8 to achieve peak throughput. This hardware constraint influences architecture design choices.
For BERT-base with $d_k = 64$, the dimension is a multiple of 8, enabling efficient Tensor Core utilization. The query-key multiplication $\mQ \mK\transpose$ has dimensions $(n \times 64) \times (64 \times n)$, where $n = 512$ is also a multiple of 8. The attention-value multiplication $\mA \mV$ has dimensions $(n \times n) \times (n \times 64)$, again with aligned dimensions. In practice, implementations pad dimensions to the nearest multiple of 8 if necessary. For example, if $d_k = 63$, it would be padded to 64, wasting 1.6\% of computation but gaining the 16$\times$ Tensor Core speedup, a worthwhile trade-off.
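The padding rule and its overhead can be written down directly. This is a sketch of the arithmetic only; `pad_to_multiple` is a hypothetical helper, not a library function.

```python
def pad_to_multiple(dim, multiple=8):
    """Round dim up to the next multiple (unchanged if already aligned)."""
    return -(-dim // multiple) * multiple

padded = pad_to_multiple(63)
print(padded)                           # 64
print(f"{(padded - 63) / padded:.1%}")  # 1.6%
print(pad_to_multiple(64))              # 64
```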
The memory bandwidth requirements for multi-head attention depend on the batch size and sequence length. For BERT-base with batch size 32 and sequence length 512, the QKV projections read $32 \times 512 \times 768 \times 2 = 25{,}165{,}824$ bytes (in FP16) and write $3 \times 32 \times 512 \times 64 \times 12 \times 2 = 75{,}497{,}472$ bytes for all heads. The attention computation reads these QKV matrices and writes attention outputs, totaling approximately 100 MB of memory traffic per layer. With 12 layers in BERT-base, this amounts to 1.2 GB of memory traffic per forward pass, which takes approximately $1.2\ \text{GB} / (1.5\ \text{TB/s}) \approx 0.8$ ms on an A100. This figure is a lower bound from memory traffic alone. The actual time is higher due to kernel launch overhead, non-coalesced accesses, and compute time, typically around 2-3 ms per forward pass for BERT-base on an A100.
Comparing one head with $d_k = 768$ versus 12 heads with $d_k = 64$ reveals why multiple heads are better for hardware. The single large head would compute attention scores $\mQ \mK\transpose$ with dimensions $(512 \times 768) \times (768 \times 512)$, requiring $512^2 \times 768 \times 2 = 402{,}653{,}184$ FLOPs. The 12 smaller heads each compute $(512 \times 64) \times (64 \times 512)$, requiring $512^2 \times 64 \times 2 = 33{,}554{,}432$ FLOPs per head, or $12 \times 33{,}554{,}432 = 402{,}653{,}184$ FLOPs total, exactly the same. However, the 12 heads can be computed in parallel across different streaming multiprocessors, achieving better GPU utilization. Additionally, the smaller matrices fit better in cache, reducing memory traffic. The single large head would produce an attention matrix of size $512 \times 512 \times 4 = 1{,}048{,}576$ bytes, while the 12 smaller heads produce 12 matrices of the same size, totaling 12 MB. The memory usage is higher for multiple heads, but the parallelism and cache benefits outweigh this cost.
Positional Encoding
Self-attention is inherently permutation equivariant, meaning it treats the input as an unordered set rather than a sequence. If we shuffle the input tokens, the attention mechanism produces correspondingly shuffled outputs, with no awareness that the order has changed. For sequence modeling tasks like language understanding and generation, word order is crucial: "dog bites man" has a very different meaning from "man bites dog." To inject positional information into the model, we add positional encodings to the input embeddings before the first attention layer.
The sinusoidal positional encoding has several desirable properties. Each position receives a unique encoding, ensuring that the model can distinguish between different positions. The use of periodic functions with different frequencies allows the model to potentially extrapolate to longer sequences than seen during training: if the model learns to interpret the sinusoidal patterns, it can apply this understanding to positions beyond the training maximum. Different dimensions use different frequencies, with lower dimensions oscillating rapidly (high frequency) and higher dimensions oscillating slowly (low frequency). This multi-scale representation allows the model to capture both fine-grained local position information and coarse-grained global position information. Finally, for any fixed offset $k$, the encoding of position $pos + k$ is a linear transformation (a rotation within each sine-cosine pair) of the encoding of position $pos$, which may help the model learn relative position relationships.
The usage is straightforward: the positional encoding matrix $\text{PE} \in \R^{n_{\max} \times d_{\text{model}}}$ is precomputed for the maximum sequence length $n_{\max}$, and for each input sequence of length $n \leq n_{\max}$, we add the first $n$ rows of $\text{PE}$ to the token embeddings: $\mX_{\text{input}} = \mX_{\text{embed}} + \text{PE}_{1:n}$. This addition happens before the first transformer layer, and the positional information propagates through the network via the residual connections.
Position 0:
Position 1:
Higher dimension indices have lower frequencies (longer periods).
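For concreteness, here is a minimal NumPy sketch of the sinusoidal encoding in its usual form from the original Transformer, $\text{PE}_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})$ and $\text{PE}_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})$. The small $d_{\text{model}} = 8$ is an assumption made only to keep the printed rows short.

```python
import numpy as np

def sinusoidal_positional_encoding(n_max, d_model, base=10000.0):
    """Return PE of shape (n_max, d_model); d_model is assumed to be even."""
    positions = np.arange(n_max)[:, None]           # (n_max, 1)
    pair_index = np.arange(0, d_model, 2)[None, :]  # the values 2i, shape (1, d_model/2)
    angles = positions / base ** (pair_index / d_model)
    pe = np.zeros((n_max, d_model))
    pe[:, 0::2] = np.sin(angles)                    # even dimensions
    pe[:, 1::2] = np.cos(angles)                    # odd dimensions
    return pe

pe = sinusoidal_positional_encoding(n_max=512, d_model=8)
print(pe[0])               # position 0: alternating sin(0) = 0 and cos(0) = 1
print(np.round(pe[1], 4))  # position 1: the angles shrink as the dimension index grows
```

Adding `pe[:n]` to an `(n, d_model)` embedding matrix gives the input to the first layer, as described above.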
Positional Encoding Variants
While sinusoidal positional encoding was used in the original Transformer, several alternative approaches have been developed, each with different trade-offs in terms of memory usage, extrapolation capability, and performance.
Learned positional embeddings treat position encodings as trainable parameters rather than fixed functions. A learnable embedding matrix $\mE_{\text{pos}} \in \R^{n_{\max} \times d_{\text{model}}}$ is initialized randomly and optimized during training alongside other model parameters. This approach is used in BERT and GPT-2. The advantage is that the model can learn position representations optimized for the specific task and data distribution, potentially capturing patterns that sinusoidal encodings cannot express. The disadvantage is memory cost: for BERT with $n_{\max} = 512$ and $d_{\text{model}} = 768$, the positional embeddings require $512 \times 768 \times 4 = 1{,}572{,}864$ bytes (1.5 MB) in FP32. More critically, learned positional embeddings do not extrapolate well to longer sequences: if the model is trained on sequences up to length 512, it has never seen positional embeddings for positions 513 and beyond, and these positions must be either extrapolated (often poorly) or the model must be fine-tuned on longer sequences.
Relative positional encoding, used in T5 and Transformer-XL, encodes the relative distance between positions rather than their absolute positions. Instead of adding positional information to the input embeddings, relative position information is incorporated directly into the attention computation. For positions $i$ and $j$, a learned bias $b_{i-j}$ is added to the attention score, where the bias depends only on the relative distance $i - j$. This requires learning biases for relative distances up to some maximum, typically $\pm 128$ or $\pm 256$. The memory cost is $O(d_{\text{rel}})$ where $d_{\text{rel}}$ is the maximum relative distance, much smaller than the $O(n_{\max} \times d_{\text{model}})$ cost of learned absolute positional embeddings. Relative positional encoding extrapolates well to longer sequences because the model learns to interpret relative distances, which remain meaningful regardless of absolute sequence length. T5 uses a simplified form in which each bias is a single scalar per attention head, the biases are shared across layers, and relative distances are bucketed into logarithmically-spaced bins, further reducing memory requirements.
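Rotary Positional Encoding (RoPE), introduced in RoFormer and used in LLaMA and GPT-NeoX, applies rotation matrices to the query and key vectors based on their positions. For position $m$, the query and key vectors are rotated by angle $m\theta_i$ where $\theta_i$ depends on the dimension. Mathematically, for each pair of dimensions $(2i, 2i+1)$, the rotation is:
\[
\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix},
\qquad \theta_i = 10000^{-2i/d_k},
\]
and the same rotation is applied to the key vectors. Because queries and keys are both rotated, the dot product between a query at position $m$ and a key at position $n$ depends on the positions only through the offset $m - n$.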
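In code, a sketch of this pairwise rotation might look as follows. It assumes the common frequency choice $\theta_i = 10000^{-2i/d}$ with $d$ the (even) head dimension and operates on a single 1-D vector; the function name `apply_rope` is illustrative. The check at the end shows that the query-key score depends only on the relative offset between the two positions.

```python
import numpy as np

def apply_rope(x, position, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) of a 1-D vector by the angle position * theta_i."""
    d = x.shape[-1]                            # assumed even
    theta = base ** (-np.arange(0, d, 2) / d)  # theta_i = base^(-2i/d)
    cos, sin = np.cos(position * theta), np.sin(position * theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)

# The same offset (3) at two different absolute positions gives the same score.
s1 = apply_rope(q, 5) @ apply_rope(k, 2)
s2 = apply_rope(q, 105) @ apply_rope(k, 102)
print(np.isclose(s1, s2))  # True
```

Since each pair is only rotated, the norm of the vector is unchanged, so RoPE adds positional information without rescaling queries or keys.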
ALiBi (Attention with Linear Biases), used in BLOOM, adds a simple linear bias to attention scores based on position distance. For query position $i$ attending to key position $j$, a bias $-m \cdot |i - j|$ is added to the attention score, where $m$ is a head-specific slope. Different heads use different slopes, typically $m = 2^{-8/h}, 2^{-16/h}, \ldots, 2^{-8}$ for $h$ heads. This penalizes attention to distant positions, with the penalty strength varying across heads. ALiBi requires no parameters and no additional computation beyond adding the bias. It extrapolates remarkably well: BLOOM was trained on sequences of length 2048 but can generate coherent text at lengths exceeding 8000 tokens. The linear bias naturally extends to any sequence length, and the model learns to work within this inductive bias.
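The bias is simple enough to construct in a few lines. The sketch below follows the slope schedule and the $-m \cdot |i - j|$ form given above; the function names are illustrative, and the schedule shown is the one quoted in the text.

```python
import numpy as np

def alibi_slopes(num_heads):
    """Geometric slopes 2^(-8/h), 2^(-16/h), ..., 2^(-8) for h heads."""
    return np.array([2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])

def alibi_bias(num_heads, seq_len):
    """Bias of shape (num_heads, seq_len, seq_len) to add to the attention scores."""
    pos = np.arange(seq_len)
    distance = np.abs(pos[:, None] - pos[None, :])  # |i - j|
    return -alibi_slopes(num_heads)[:, None, None] * distance

print(alibi_slopes(8))  # 2^-1, 2^-2, ..., 2^-8
bias = alibi_bias(num_heads=8, seq_len=4)
print(bias[0])          # head 0 has slope 0.5, so entry (0, 3) is -1.5
```

Because the bias depends only on `seq_len`, it can be rebuilt for any length at inference time, which is the mechanical reason the method extends past the training length.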
The following table summarizes the trade-offs between positional encoding methods:
| Type | Memory | Extrapolation | Used In |
|---|---|---|---|
| Sinusoidal | None | Good | Original Transformer |
| Learned | $n_{\max} \times d$ | Poor | BERT, GPT-2 |
| Relative | $O(d_{\text{rel}})$ | Good | T5, Transformer-XL |
| RoPE | None | Excellent | LLaMA, GPT-NeoX |
| ALiBi | None | Excellent | BLOOM |
The trend in recent large language models has been toward parameter-free methods with strong extrapolation: RoPE and ALiBi dominate current architectures. These methods avoid the memory cost of learned positional embeddings while providing better length generalization than sinusoidal encodings. For practitioners, the choice depends on the application: if sequences will always be shorter than the training maximum, learned embeddings may provide slightly better performance. If length generalization is important, RoPE or ALiBi are superior choices.
Computational Complexity
Memory Complexity Analysis
The memory requirements of self-attention are dominated by the attention matrices, which scale quadratically with sequence length. For a batch of $B$ sequences, each of length $n$, with $h$ attention heads, the attention matrices $\mA^{(i)} \in \R^{n \times n}$ for $i = 1, \ldots, h$ require $O(Bhn^2)$ memory. In FP32, this amounts to $Bhn^2 \times 4$ bytes. For BERT-base with $B = 32$, $h = 12$, and $n = 512$, the attention matrices consume $32 \times 12 \times 512^2 \times 4 = 402{,}653{,}184$ bytes, or approximately 384 MB. This quadratic scaling means that doubling the sequence length quadruples the memory requirement: for $n = 1024$, the attention matrices would require 1.5 GB, and for $n = 2048$, they would require 6 GB.
In contrast, the QKV projection matrices and their outputs scale linearly with sequence length. The query, key, and value matrices each have shape $B \times n \times d_k$ for each head, requiring $3Bhnd_k \times 4$ bytes total across all heads. For BERT-base with $d_k = 64$, this amounts to $3 \times 32 \times 12 \times 512 \times 64 \times 4 = 150{,}994{,}944$ bytes, or approximately 144 MB. The linear scaling means that doubling the sequence length only doubles this memory requirement.
The crossover point where attention matrices dominate total memory usage depends on the model dimensions. Attention matrices require $Bhn^2$ elements, while QKV matrices require $3Bhnd_k$ elements. Attention dominates when $Bhn^2 > 3Bhnd_k$, which simplifies to $n > 3d_k$. For BERT-base with $d_k = 64$, attention dominates when $n > 192$, which is essentially always the case, since typical sequence lengths are 512. For models with larger $d_k$, the crossover occurs at longer sequences. However, since $d_k = d_{\text{model}}/h$ and typical architectures use $h = 12$ to $h = 96$, the value of $d_k$ is usually in the range 64 to 128, meaning attention matrices dominate for sequences longer than a few hundred tokens.
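The crossover can also be found by brute force, which is a useful sanity check on the algebra. The helper names below are hypothetical.

```python
def attention_elems(B, h, n):
    """Elements in the h attention matrices for a batch of B sequences."""
    return B * h * n * n

def qkv_elems(B, h, n, d_k):
    """Elements in the Q, K and V matrices across all heads."""
    return 3 * B * h * n * d_k

B, h, d_k = 32, 12, 64
crossover = next(
    n for n in range(1, 10_000)
    if attention_elems(B, h, n) > qkv_elems(B, h, n, d_k)
)
print(crossover)  # 193, the first n with n > 3 * d_k
```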
Time Complexity Breakdown
The time complexity of self-attention can be decomposed into several operations, each with different scaling properties. The QKV projections involve three matrix multiplications $\mX \mW^Q$, $\mX \mW^K$, and $\mX \mW^V$, where $\mX \in \R^{Bn \times d_{\text{model}}}$ and each weight matrix has shape $d_{\text{model}} \times d_k$. For $h$ heads, the total complexity is $O(3Bhnd_{\text{model}}d_k) = O(Bnd_{\text{model}}^2)$ since $hd_k = d_{\text{model}}$. This is linear in sequence length $n$ but quadratic in model dimension $d_{\text{model}}$.
Computing the attention scores $\mQ \mK\transpose$ requires a batch matrix multiplication with dimensions $(Bh \times n \times d_k) \times (Bh \times d_k \times n)$, resulting in complexity $O(Bhn^2d_k)$. This is quadratic in sequence length and linear in head dimension. The softmax operation over the attention scores has complexity $O(Bhn^2)$, dominated by the exponential and normalization computations. Finally, applying the attention weights to the values $\mA \mV$ has complexity $O(Bhn^2d_v)$, again quadratic in sequence length. The output projection $[\text{head}_1; \ldots; \text{head}_h] \mW^O$ has complexity $O(Bnd_{\text{model}}^2)$, linear in sequence length.
Summing these components, the total complexity is:
\[
O\big(Bnd_{\text{model}}^2 + Bhn^2d_k\big).
\]
Since $hd_k = d_{\text{model}}$, this simplifies to $O(Bnd_{\text{model}}^2 + Bn^2d_{\text{model}})$. The relative importance of these terms depends on the ratio $n / d_{\text{model}}$. When $n < d_{\text{model}}$, the $O(Bnd_{\text{model}}^2)$ term from the linear projections dominates, and the feed-forward network (which also has $O(Bnd_{\text{model}}^2)$ complexity) is the computational bottleneck. When $n > d_{\text{model}}$, the $O(Bn^2d_{\text{model}})$ term from attention dominates, and attention becomes the bottleneck.
For BERT-base with $d_{\text{model}} = 768$, attention dominates when $n > 768$. Since BERT uses maximum sequence length 512, the model is in the regime where linear projections and feed-forward networks dominate. For GPT-3 with $d_{\text{model}} = 12{,}288$, attention would only dominate for sequences longer than 12,288 tokens, far beyond the typical context length of 2048 tokens. This explains why efficient attention mechanisms (Chapter 16) focus on reducing the $O(n^2)$ term: for very long sequences, this term becomes prohibitive, but for typical sequence lengths in large models, the linear terms are actually more expensive.
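A rough FLOP count makes the two regimes concrete. The sketch below counts 2 FLOPs per multiply-add, covers only the attention sublayer (no feed-forward network, softmax, or biases), and uses a hypothetical helper name. With these constants the two terms are equal at $n = 2d_{\text{model}}$; the $n > d_{\text{model}}$ rule of thumb above comes from the big-O form, which drops constant factors.

```python
def attention_layer_flops(n, d_model, batch=1):
    """Approximate FLOPs for one attention sublayer, split into its two terms."""
    projections = 4 * 2 * batch * n * d_model * d_model  # Q, K, V and output projections
    attention = 2 * 2 * batch * n * n * d_model          # Q K^T and A V, summed over heads
    return projections, attention

for n in (128, 512, 1536, 4096):
    proj, attn = attention_layer_flops(n, d_model=768)
    print(f"n={n}: attention/projection = {attn / proj:.2f}")
# n=128: attention/projection = 0.08
# n=512: attention/projection = 0.33
# n=1536: attention/projection = 1.00
# n=4096: attention/projection = 2.67
```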
Scaling Experiments
To illustrate the scaling behavior empirically, consider BERT-base with batch size 1 on an NVIDIA A100 GPU. The following measurements show forward pass time for different sequence lengths:
| Sequence Length | Time (ms) | Bottleneck |
|---|---|---|
| 128 | 2.1 | FFN dominates |
| 256 | 3.8 | FFN dominates |
| 512 | 8.5 | Balanced |
| 1024 | 28.3 | Attention dominates |
| 2048 | 98.7 | Attention dominates |
| 4096 | 367 | Attention dominates |
For short sequences (128, 256), the time scales approximately linearly, indicating that the $O(nd_{\text{model}}^2)$ terms dominate. At sequence length 512, the scaling begins to show quadratic behavior. For long sequences (1024, 2048, 4096), the time scales quadratically: doubling from 1024 to 2048 increases time by $98.7 / 28.3 \approx 3.5\times$, and doubling again to 4096 increases time by $367 / 98.7 \approx 3.7\times$. The slight deviation from exactly $4\times$ is due to the linear terms and memory bandwidth effects, but the quadratic scaling is clearly visible.
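The ratios quoted above can be recomputed from the table, together with an effective scaling exponent $p$ in $\text{time} \propto n^p$ for each doubling ($p = \log_2$ of the ratio, so $p = 1$ is linear and $p = 2$ is quadratic). The timings are copied from the table; only the exponent calculation is added here.

```python
import math

# Forward-pass times in milliseconds, copied from the table above.
times_ms = {128: 2.1, 256: 3.8, 512: 8.5, 1024: 28.3, 2048: 98.7, 4096: 367.0}

lengths = sorted(times_ms)
for a, b in zip(lengths, lengths[1:]):
    ratio = times_ms[b] / times_ms[a]
    # One doubling of n multiplies time by 2^p, so p = log2(ratio).
    print(f"{a} -> {b}: x{ratio:.2f}, exponent {math.log2(ratio):.2f}")
```

The exponent rises from below 1 at the shortest lengths to nearly 2 at the longest, matching the transition described in the text.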
These measurements demonstrate why long-context transformers require specialized attention mechanisms. Extending BERT-base to sequence length 8192 would require approximately $367 \times 4 \approx 1{,}468$ ms per forward pass, or 1.5 seconds, which is prohibitively slow for interactive applications. The memory requirement would be $32 \times 12 \times 8192^2 \times 4 / (1024^3) \approx 96$ GB for attention matrices alone with batch size 32, exceeding the capacity of even the largest single GPUs. This fundamental scaling limitation motivates the development of sparse attention, linear attention, and other efficient variants discussed in Chapter 16.
Causal (Masked) Self-Attention
Autoregressive language models like GPT generate text sequentially, predicting each token based only on previous tokens. During training, the entire sequence is provided as input, but the model must not be allowed to "see" future tokens when predicting each position. Doing so would constitute cheating, as the model would have access to information unavailable during generation. Causal masking enforces this constraint by preventing each position from attending to subsequent positions in the sequence.
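Define the causal mask $\mM \in \R^{n \times n}$ with $M_{ij} = 0$ if $j \leq i$ and $M_{ij} = -\infty$ if $j > i$. Apply before softmax:
\[
\mA = \text{softmax}\!\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}} + \mM\right)
\]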
After softmax, $\exp(-\infty) = 0$, so position $i$ cannot attend to positions $j > i$.
Position 0 attends only to itself. Position 1 attends to positions 0, 1. Position 3 attends to all positions 0, 1, 2, 3.
This ensures the autoregressive property required for language modeling.
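A minimal NumPy sketch of the masking step follows. To make the pattern easy to read it uses all-zero scores for a sequence of 4 positions, so each row ends up uniform over the positions it is allowed to see. The function name is illustrative.

```python
import numpy as np

def causal_attention_weights(scores):
    """Apply a causal mask to (n, n) scores, then a row-wise softmax."""
    n = scores.shape[-1]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)     # True where j > i
    masked = np.where(future, -np.inf, scores)
    # Subtract the row maximum for stability; the diagonal is never masked,
    # so every row has at least one finite entry.
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)                               # exp(-inf) = 0
    return weights / weights.sum(axis=-1, keepdims=True)

weights = causal_attention_weights(np.zeros((4, 4)))
print(weights)
# Row 0 is [1, 0, 0, 0]; row 1 is [0.5, 0.5, 0, 0]; row 3 is [0.25, 0.25, 0.25, 0.25].
```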
Efficient Causal Mask Implementation
The naive implementation of causal masking stores a full $n \times n$ mask matrix in memory. For sequence length 2048, this requires $2048^2 \times 4 = 16{,}777{,}216$ bytes (16 MB) in FP32. If the mask is replicated for each sequence in a batch of 32 rather than broadcast, this amounts to 512 MB just for the mask, a significant memory overhead. Moreover, the mask must be added to the attention scores before softmax, requiring a memory read of the mask matrix.
Efficient implementations avoid storing a floating-point mask for every sequence. Modern deep learning frameworks support this through boolean masking or by computing the upper triangular structure inside the attention kernel. For example, in PyTorch, scores.masked\_fill\_(mask, float('-inf')) with mask = torch.ones(n, n, dtype=torch.bool).triu(diagonal=1) writes $-\infty$ into the strictly upper triangular entries of the scores in place. The boolean mask costs one byte per entry and is broadcast across the batch and heads, so it is shared rather than replicated. Fused kernels such as torch.nn.functional.scaled\_dot\_product\_attention with is\_causal=True need no mask tensor at all, though unfused implementations must still store the attention scores matrix.
Flash Attention takes this optimization further by fusing the masking operation with the attention computation and tiling the computation to fit in SRAM. Instead of computing the full attention matrix, materializing it in global memory, applying the mask, and then computing softmax, Flash Attention computes attention in tiles that fit in on-chip memory. For each tile, the mask is computed on-the-fly, attention is computed, and the result is written back to global memory. This approach reduces memory usage from $O(n^2)$ to $O(n)$ and provides 2-4$\times$ speedup for long sequences by minimizing global memory traffic.
The impact of causal masking differs between training and inference. During training, the entire sequence is processed in parallel, with masking ensuring that each position only attends to previous positions. The forward pass computes outputs for all positions simultaneously, and the backward pass computes gradients for all positions simultaneously. The masking is explicit in the attention computation. During inference, text generation is inherently sequential: we generate one token at a time, appending it to the context and generating the next token. In this setting, the masking is implicit: when generating position $t$, we only have tokens $0, \ldots, t-1$ available, so there are no future tokens to mask. However, naive inference would recompute attention over the entire sequence for each new token, costing $O(t^2)$ attention work at step $t$ and $O(n^3)$ in total for generating $n$ tokens. Key-value caching addresses this by storing the key and value vectors for all previous tokens, allowing each new token to attend to the cached keys and values without recomputation. This reduces the attention work to $O(t)$ at step $t$, or $O(n^2)$ in total for $n$ tokens, at the cost of $O(nd_{\text{model}})$ memory per layer for the cache.
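The idea behind key-value caching can be sketched for a single head as follows. This is an illustrative toy rather than a production design: the class name `KVCache` and its `attend` method are hypothetical, the queries, keys, and values are random placeholders, and a real implementation would preallocate the cache instead of growing it with `np.vstack`.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

class KVCache:
    """Stores the keys and values of already-processed tokens for one head."""

    def __init__(self, d_k):
        self.keys = np.empty((0, d_k))
        self.values = np.empty((0, d_k))

    def attend(self, q, k, v):
        """Append this token's key and value, then attend from q over the cache."""
        self.keys = np.vstack([self.keys, k])
        self.values = np.vstack([self.values, v])
        weights = softmax(self.keys @ q / np.sqrt(q.shape[-1]))  # O(t) work at step t
        return weights @ self.values

rng = np.random.default_rng(0)
n, d_k = 5, 4
Q, K, V = (rng.normal(size=(n, d_k)) for _ in range(3))

# Incremental decoding: one token at a time, reusing cached keys and values.
cache = KVCache(d_k)
incremental = np.stack([cache.attend(Q[t], K[t], V[t]) for t in range(n)])

# Reference: full causal attention computed in one shot.
scores = Q @ K.T / np.sqrt(d_k)
scores[np.triu(np.ones((n, n), dtype=bool), k=1)] = -np.inf
full = np.apply_along_axis(softmax, 1, scores) @ V

print(np.allclose(incremental, full))  # True
```

The two results agree because a token at step $t$ only ever needs the keys and values of steps $0, \ldots, t$, which is exactly what the cache holds.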
Attention Patterns and Interpretability
Analysis of trained transformer models reveals that different attention heads learn to capture different types of linguistic relationships. Some heads focus on syntactic structure, attending strongly between words that have grammatical dependencies such as subject-verb agreement or determiner-noun relationships. For example, in the sentence "The cat that chased the mouse was hungry," a syntactic head might show strong attention from "was" to "cat" (the subject), skipping over the relative clause. Other heads capture semantic relationships, attending between words with similar meanings or words that are topically related. In a sentence about cooking, a semantic head might show attention between "recipe," "ingredients," and "oven," even if these words are not syntactically related.
Positional heads exhibit attention patterns based primarily on token distance rather than content. Some heads attend primarily to adjacent tokens, capturing local context. Others attend to tokens at specific relative positions, such as attending to the previous token or to tokens at fixed offsets. These positional patterns can be useful for tasks like copying or for capturing regular linguistic structures. Rare word heads show distinctive behavior where attention is concentrated on infrequent tokens, potentially allowing the model to give special processing to unusual or important words that might otherwise be overwhelmed by common function words.
Attention visualization provides insight into model behavior by displaying the attention weights as heatmaps or graphs. For a given input sentence, we can visualize the attention distribution for each head in each layer, showing which tokens each position attends to. These visualizations often reveal interpretable patterns: early layers tend to focus on local, syntactic relationships, while later layers capture more abstract, semantic relationships. However, interpretation must be approached with caution: attention weights show where the model looks, but not necessarily what information is extracted or how it is used. High attention weight does not necessarily imply high importance for the final prediction.
Research on attention head importance has shown that many heads can be pruned without significant performance degradation. In BERT-base with 144 attention heads (12 heads per layer $\times$ 12 layers), pruning 40-50\% of heads typically reduces downstream task accuracy by less than 1\%. This suggests substantial redundancy in the multi-head attention mechanism. Some heads are consistently important across tasks, often those capturing syntactic relationships or attending to special tokens like [CLS] or [SEP]. Other heads appear to be less critical, and their removal has minimal impact. This redundancy may serve an important role during training by providing multiple gradient pathways and helping optimization, even if the final model does not require all heads for inference.
Hardware-Specific Optimizations
Flash Attention
Flash Attention represents a fundamental rethinking of how attention is computed on modern GPUs. The standard attention implementation computes the full attention matrix $\mA = \text{softmax}(\mQ \mK\transpose / \sqrt{d_k})$ and materializes it in GPU global memory before multiplying by $\mV$. For long sequences, this attention matrix is large: for sequence length 4096, a single attention head requires $4096^2 \times 4 \text{ bytes} \approx 67$ MB in FP32. Reading and writing this matrix to global memory becomes the performance bottleneck, as global memory bandwidth (approximately 1.5 TB/s on an A100) is low relative to compute throughput (312 TFLOPS for FP16).
Flash Attention addresses this by tiling the attention computation to fit in SRAM, the fast on-chip memory available on each streaming multiprocessor. SRAM has much higher bandwidth (approximately 19 TB/s on A100) but limited capacity (192 KB per SM, totaling about 20 MB across the 108 SMs). The key insight is that attention can be computed in blocks: we partition the query, key, and value matrices into tiles, load each tile into SRAM, compute attention for that tile, and accumulate the results. The attention matrix is never fully materialized in global memory; only the tiles currently being processed reside in SRAM.
The tiling strategy works as follows. Partition the queries into blocks $\mQ_1, \ldots, \mQ_T$ and the keys and values into blocks $\mK_1, \mV_1, \ldots, \mK_T, \mV_T$. For each query block $\mQ_i$, iterate over all key-value blocks $(\mK_j, \mV_j)$, computing the block scores $\mQ_i \mK_j\transpose / \sqrt{d_k}$ in SRAM, exponentiating them, multiplying by $\mV_j$, and accumulating the results. The softmax normalization requires special handling because no single block sees a full row of scores: we maintain running statistics (the row maximum and the sum of exponentials), rescale the accumulated output whenever the maximum changes, and normalize at the end. This online softmax algorithm ensures numerical stability while avoiding materialization of the full attention matrix.
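The running-statistics update described above can be sketched in NumPy. This is an illustrative single-head version that mirrors the tiling logic only: it runs on the CPU, does not model SRAM or global memory, and the function names and block size are assumptions for the example.

```python
import numpy as np

def reference_attention(Q, K, V):
    """Standard attention that materializes the full n x n matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=4):
    """Blockwise attention with an online softmax (never builds the full matrix)."""
    n, d_k = Q.shape
    out = np.zeros((n, V.shape[1]))
    for i in range(0, n, block):
        Qi = Q[i:i + block]
        m = np.full(Qi.shape[0], -np.inf)           # running row maximum
        l = np.zeros(Qi.shape[0])                   # running sum of exponentials
        acc = np.zeros((Qi.shape[0], V.shape[1]))   # unnormalized output
        for j in range(0, K.shape[0], block):
            S = Qi @ K[j:j + block].T / np.sqrt(d_k)
            m_new = np.maximum(m, S.max(axis=-1))
            scale = np.exp(m - m_new)               # rescales earlier contributions
            P = np.exp(S - m_new[:, None])
            l = l * scale + P.sum(axis=-1)
            acc = acc * scale[:, None] + P @ V[j:j + block]
            m = m_new
        out[i:i + block] = acc / l[:, None]         # final normalization
    return out

rng = np.random.default_rng(0)
n, d_k = 10, 8
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
tiled = tiled_attention(Q, K, V, block=4)
reference = reference_attention(Q, K, V)
print(np.allclose(tiled, reference))  # True
```

The sequence length is deliberately not a multiple of the block size, so the last tile is smaller than the others and the update still holds.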
The memory usage of Flash Attention is $O(n)$ rather than $O(n^2)$, as we only store the query, key, and value matrices (each $O(n \times d)$) and the output, not the attention matrix. The computational cost remains the same (we perform the same number of FLOPs as standard attention), but the memory traffic is dramatically reduced. For sequence length 4096 with $d_k = 64$, standard attention reads/writes approximately $4096^2 \times 4 \text{ bytes} \approx 67$ MB for the attention matrix, while Flash Attention reads/writes only the QKV matrices, approximately $3 \times 4096 \times 64 \times 4 \text{ bytes} \approx 3$ MB. This roughly 20$\times$ reduction in memory traffic translates to a 2-4$\times$ speedup in practice, with larger speedups for longer sequences where memory bandwidth is the primary bottleneck.
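The byte counts above can be reproduced with a few lines of arithmetic. This sketch counts a single pass over each tensor for one head in FP32. A real tiled kernel re-reads key and value tiles once per query block, so the second figure is best read as a lower bound on Flash Attention's traffic.

```python
n, d_k, bytes_fp32 = 4096, 64, 4

# Full attention matrix for one head
attn_matrix_bytes = n * n * bytes_fp32
# Q, K, and V for one head
qkv_bytes = 3 * n * d_k * bytes_fp32

print(attn_matrix_bytes)              # 67108864 (about 67 MB)
print(qkv_bytes)                      # 3145728 (about 3 MB)
print(attn_matrix_bytes / qkv_bytes)  # about 21.3
```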
Fused Kernels
Kernel fusion combines multiple operations into a single GPU kernel, reducing memory traffic by keeping intermediate results in registers or shared memory rather than writing them to global memory. For attention, a common fusion is combining the softmax operation with the attention score computation and the multiplication by values. The standard implementation computes $\mS = \mQ \mK\transpose$, writes $\mS$ to global memory, launches a separate kernel to compute $\mA = \text{softmax}(\mS)$, writes $\mA$ to global memory, and launches another kernel to compute $\mO = \mA \mV$. Each write and read to global memory incurs latency and consumes bandwidth.
A fused attention kernel computes all these operations in a single kernel launch. The kernel loads tiles of $\mQ$, $\mK$, and $\mV$ into shared memory, computes attention scores in registers, applies softmax, multiplies by values, and writes the final output, all without intermediate global memory traffic. This fusion reduces memory bandwidth requirements by approximately 2$\times$, as we eliminate the reads and writes of $\mS$ and $\mA$. The speedup is typically 1.5-2$\times$ for attention-dominated workloads, with larger benefits for smaller batch sizes where memory bandwidth is the primary bottleneck.
Fused kernels require careful implementation to maximize occupancy and minimize register pressure. The kernel must balance the tile size (larger tiles reduce global memory traffic but increase shared memory and register usage) with occupancy (the number of thread blocks that can run concurrently on each SM). Modern deep learning frameworks like PyTorch and TensorFlow provide fused attention implementations through libraries like cuDNN and custom CUDA kernels, making these optimizations accessible without manual kernel development.
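As one concrete entry point, PyTorch 2.0 and later expose `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a fused backend. The sketch below compares it with the three-step unfused computation on small random tensors. Whether a fused kernel is actually selected depends on the PyTorch version, device, and dtype, so this only demonstrates numerical agreement, not speed.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, n, d_k = 2, 4, 16, 8
Q = torch.randn(batch, heads, n, d_k)
K = torch.randn(batch, heads, n, d_k)
V = torch.randn(batch, heads, n, d_k)

# Unfused: S and A are separate intermediate tensors
S = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
A = torch.softmax(S, dim=-1)
unfused = A @ V

# Single call; the backend may fuse all three steps
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(unfused, fused, atol=1e-5))
```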
Tensor Core Optimization
Tensor Cores on NVIDIA GPUs provide specialized hardware for matrix multiplication, achieving much higher throughput than standard CUDA cores for reduced-precision formats. To fully utilize Tensor Cores, matrix dimensions should be multiples of 8 for FP16 or multiples of 16 for INT8. For attention, this means padding $d_k$, $n$, and batch size to these multiples when necessary. For example, if $d_k = 63$, padding to 64 wastes 1.6\% of computation but enables Tensor Core usage, providing a net speedup of 10-15$\times$.
The WMMA (Warp Matrix Multiply-Accumulate) API provides access to Tensor Cores from CUDA code. A warp (32 threads) cooperatively loads matrix tiles into registers, performs matrix multiplication using Tensor Cores, and stores the result. For attention, the query-key multiplication $\mQ \mK\transpose$ and the attention-value multiplication $\mA \mV$ are both matrix multiplications that can leverage Tensor Cores. Achieving 70-80\% of peak TFLOPS requires careful attention to data layout (row-major vs column-major), tile sizes, and memory access patterns to ensure coalesced loads and stores.
In practice, modern deep learning frameworks handle Tensor Core optimization automatically for standard operations like matrix multiplication. However, custom attention implementations or fused kernels may require explicit use of WMMA or the higher-level cuBLAS library to achieve peak performance. The key takeaway for practitioners is that attention performance depends critically on matrix dimensions being multiples of 8 or 16, and that padding dimensions to meet this requirement is almost always worthwhile.
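The padding rule is simple to express in code. The helpers below are hypothetical names written for this illustration, not part of any library. They round a dimension up to the next multiple and report the fraction of the padded dimension that is wasted.

```python
def pad_to_multiple(size: int, multiple: int = 8) -> int:
    """Round size up to the nearest multiple (8 for FP16, 16 for INT8)."""
    return ((size + multiple - 1) // multiple) * multiple

def padding_overhead(size: int, multiple: int = 8) -> float:
    """Fraction of the padded dimension that holds padding."""
    padded = pad_to_multiple(size, multiple)
    return (padded - size) / padded

print(pad_to_multiple(63))                 # 64
print(f"{padding_overhead(63):.1%}")       # 1.6%
print(pad_to_multiple(64))                 # 64 (already aligned)
print(pad_to_multiple(100, multiple=16))   # 112
```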
Memory-Efficient Attention Variants
The quadratic memory and time complexity of standard attention motivates approximate mechanisms that reduce complexity while maintaining most modeling power. Three main approaches exist:
- Sparse attention restricts each query to a subset of keys (e.g., local window $\pm w$, strided, or random positions), reducing complexity from $O(n^2)$ to $O(ns)$ where $s \ll n$. Accuracy loss is typically $<$1\% for local-context tasks with $w = 512$.
- Linear attention replaces the softmax kernel with a decomposable kernel $\phi(\vq) \cdot \phi(\vk)$, enabling $O(nd^2)$ computation by reordering the matrix multiplications. Accuracy loss is 1--3\% due to imperfect kernel approximation.
- Low-rank attention (e.g., Linformer) projects keys and values to dimension $r \ll n$, giving an attention matrix of shape $n \times r$ instead of $n \times n$. Memory and computation reduce by a factor of $n/r$.
These variants are covered in depth in Chapter~9 (attention variants and mechanisms) and Chapter~16 (efficient transformers), including detailed complexity analysis, implementation strategies, and benchmark comparisons.
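To make the reordering behind linear attention concrete, the NumPy sketch below computes the same output two ways: by forming the $n \times n$ matrix, and by first contracting keys with values. The feature map $\phi(x) = \text{elu}(x) + 1$ is an assumption chosen for illustration, since the text does not fix a particular kernel.

```python
import numpy as np

def feature_map(x):
    # elu(x) + 1: strictly positive, so the normalizer is never zero
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_quadratic(Q, K, V):
    """Forms the n x n matrix explicitly: O(n^2 d) time, O(n^2) memory."""
    A = feature_map(Q) @ feature_map(K).T
    return (A @ V) / A.sum(axis=-1, keepdims=True)

def linear_attention_reordered(Q, K, V):
    """Contracts keys with values first: O(n d^2) time, no n x n matrix."""
    phi_q, phi_k = feature_map(Q), feature_map(K)
    kv = phi_k.T @ V          # (d, d_v) summary of all keys and values
    z = phi_k.sum(axis=0)     # (d,) normalizer summary
    return (phi_q @ kv) / (phi_q @ z)[:, None]

rng = np.random.default_rng(0)
n, d = 32, 8
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
quadratic = linear_attention_quadratic(Q, K, V)
reordered = linear_attention_reordered(Q, K, V)
print(np.allclose(quadratic, reordered))  # True
```

The two functions agree exactly up to floating-point error because matrix multiplication is associative. The approximation error mentioned above comes from replacing softmax with $\phi$, not from the reordering.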
Exercises
Solutions
Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.
(1) Attention matrix memory:
(2) Parameters in multi-head attention:
- $\mW_Q, \mW_K, \mW_V$: $3 \times d^2 = 3 \times 1024^2 = 3{,}145{,}728$
- $\mW_O$: $d^2 = 1{,}048{,}576$
- Total: $4{,}194{,}304$ parameters
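The count above can be checked with a short calculation. This sketch uses $d = 1024$ as in the solution and, like the solution, ignores bias terms. The last line shows how many parameters biases would add if each of the four projections had one.

```python
d_model = 1024

qkv_params = 3 * d_model * d_model   # W_Q, W_K, W_V
out_params = d_model * d_model       # W_O
total = qkv_params + out_params

print(qkv_params)   # 3145728
print(out_params)   # 1048576
print(total)        # 4194304

# Optional bias vectors, one per projection (not counted above)
print(4 * d_model)  # 4096
```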
(3) FLOPs for forward pass:
- QKV projections: $3 \times 2nd^2 = 6{,}442{,}450{,}944$ FLOPs
- Attention scores: $2hn^2d_k = 2{,}147{,}483{,}648$ FLOPs
- Attention output: $2hn^2d_k = 2{,}147{,}483{,}648$ FLOPs
- Output projection: $2nd^2 = 2{,}147{,}483{,}648$ FLOPs
- Total: $\approx 12.9$ GFLOPs
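These figures can be reproduced programmatically. The values $n = 1024$ and $d = 1024$ are inferred from the numbers above. The head count of 16 is an assumption for illustration and does not affect the totals, because $h \cdot d_k = d$ regardless of how the model dimension is split.

```python
n, d_model, num_heads = 1024, 1024, 16   # num_heads is an assumed value
d_k = d_model // num_heads

qkv = 3 * 2 * n * d_model**2             # three projections
scores = 2 * num_heads * n**2 * d_k      # Q K^T for every head
attn_out = 2 * num_heads * n**2 * d_k    # A V for every head
out_proj = 2 * n * d_model**2            # output projection
total = qkv + scores + attn_out + out_proj

print(qkv)                               # 6442450944
print(scores)                            # 2147483648
print(f"{total / 1e9:.1f} GFLOPs")       # 12.9 GFLOPs
```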
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # The model dimension must split evenly across heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Project and reshape to (batch, heads, seq_len, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(out)


# Test
mha = MultiHeadAttention(d_model=128, num_heads=4)
x = torch.randn(32, 20, 128)
output = mha(x)
print(f"Output shape: {output.shape}")  # torch.Size([32, 20, 128])
print(f"Parameters: {sum(p.numel() for p in mha.parameters())}")  # 66048 (includes biases)
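Expected parameters: $4 \times 128^2 = 65{,}536$ weights plus $4 \times 128 = 512$ biases (nn.Linear includes a bias by default), giving $66{,}048$, which matches the implementation.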
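For position $\text{pos}$ and dimension $2i$, assuming the standard sinusoidal encoding with $\omega_i = 1/10000^{2i/d_{\text{model}}}$:
\[ \text{PE}_{(\text{pos}, 2i)} = \sin(\omega_i\, \text{pos}), \qquad \text{PE}_{(\text{pos}, 2i+1)} = \cos(\omega_i\, \text{pos}) \]
For position $\text{pos}+k$:
\[ \text{PE}_{(\text{pos}+k, 2i)} = \sin(\omega_i (\text{pos}+k)), \qquad \text{PE}_{(\text{pos}+k, 2i+1)} = \cos(\omega_i (\text{pos}+k)) \]
Using the angle-addition identities:
\[ \sin(\omega_i (\text{pos}+k)) = \sin(\omega_i\, \text{pos}) \cos(\omega_i k) + \cos(\omega_i\, \text{pos}) \sin(\omega_i k) \]
\[ \cos(\omega_i (\text{pos}+k)) = \cos(\omega_i\, \text{pos}) \cos(\omega_i k) - \sin(\omega_i\, \text{pos}) \sin(\omega_i k) \]
Therefore:
\[ \text{PE}_{\text{pos}+k} = \mM_k\, \text{PE}_{\text{pos}} \]
where $\mM_k$ is a linear transformation matrix depending only on $k$: it is block-diagonal, with one $2 \times 2$ block per frequency $\omega_i$ whose rows are $(\cos(\omega_i k), \sin(\omega_i k))$ and $(-\sin(\omega_i k), \cos(\omega_i k))$. This allows the model to learn relative positions through linear transformations.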
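The claim that the encoding at position $\text{pos}+k$ is a linear function of the encoding at $\text{pos}$ can be verified numerically. The sketch below assumes the standard sinusoidal encoding with base 10000 and interleaved sine and cosine dimensions. The function names are illustrative.

```python
import numpy as np

def frequencies(d_model):
    i = np.arange(d_model // 2)
    return 1.0 / (10000 ** (2 * i / d_model))

def sinusoidal_pe(pos, d_model):
    """Encoding of a single position: sin on even dims, cos on odd dims."""
    omega = frequencies(d_model)
    pe = np.empty(d_model)
    pe[0::2] = np.sin(omega * pos)
    pe[1::2] = np.cos(omega * pos)
    return pe

def shift_matrix(k, d_model):
    """Block-diagonal matrix M_k built from k alone (no dependence on pos)."""
    M = np.zeros((d_model, d_model))
    for idx, w in enumerate(frequencies(d_model)):
        c, s = np.cos(w * k), np.sin(w * k)
        M[2 * idx:2 * idx + 2, 2 * idx:2 * idx + 2] = [[c, s], [-s, c]]
    return M

d_model, pos, k = 16, 7, 5
shifted = shift_matrix(k, d_model) @ sinusoidal_pe(pos, d_model)
direct = sinusoidal_pe(pos + k, d_model)
print(np.allclose(shifted, direct))  # True
```

The same `shift_matrix(k, d_model)` works for every starting position, which is what makes relative offsets recoverable by a linear map.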
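Without PE, self-attention is permutation-equivariant: for any permutation matrix $\mP$, $\text{Attention}(\mP \mX) = \mP\, \text{Attention}(\mX)$, so reordering the input tokens merely reorders the outputs and word order carries no information: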
Numerical example: Sentence: "cat sat mat"
Without PE:
With PE:
With PE, each token attends more strongly to nearby positions, capturing word order information.
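The effect of positional encodings on order sensitivity can be demonstrated with a small experiment. The sketch below uses random vectors as stand-in embeddings for the three tokens and random projection matrices, so the numbers carry no linguistic meaning. It shows only that reordering tokens reorders the outputs without PE and changes them with PE, not the locality effect described above.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def sinusoidal_pe(n, d_model):
    pos = np.arange(n)[:, None]
    i = np.arange(d_model // 2)[None, :]
    angles = pos / (10000 ** (2 * i / d_model))
    pe = np.empty((n, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

rng = np.random.default_rng(0)
d_model = 4
X = rng.standard_normal((3, d_model))   # stand-ins for "cat", "sat", "mat"
Wq, Wk, Wv = (rng.standard_normal((d_model, d_model)) for _ in range(3))
perm = np.array([2, 0, 1])              # reordered sentence: "mat cat sat"

# Without PE: permuting the inputs just permutes the outputs
out = self_attention(X, Wq, Wk, Wv)
out_perm = self_attention(X[perm], Wq, Wk, Wv)
equivariant_without_pe = np.allclose(out[perm], out_perm)

# With PE: the same tokens in new positions give different outputs
pe = sinusoidal_pe(3, d_model)
out_pe = self_attention(X + pe, Wq, Wk, Wv)
out_pe_perm = self_attention(X[perm] + pe, Wq, Wk, Wv)
equivariant_with_pe = np.allclose(out_pe[perm], out_pe_perm)

print(equivariant_without_pe, equivariant_with_pe)
```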