Attention Variants and Mechanisms

Chapter Overview

Beyond standard scaled dot-product attention, numerous variants have been developed for specific use cases and improved efficiency. This chapter explores cross-attention for encoder-decoder models, relative position representations, attention masking, dropout and layer normalization around attention, the visualization of attention patterns, and efficient attention variants for long sequences, together with practical considerations for implementing attention mechanisms.

Learning Objectives

  1. Distinguish between self-attention and cross-attention
  2. Understand relative position representations
  3. Implement attention with different scoring functions
  4. Apply attention masking for various scenarios
  5. Understand attention dropout and layer normalization
  6. Visualize and interpret attention patterns

Cross-Attention

Definition: In encoder-decoder architectures, the decoder attends to the encoder output via cross-attention:
$$\begin{align} \mQ &= \mX_{\text{dec}} \mW^Q \quad \text{(queries from decoder)} \\ \mK &= \mX_{\text{enc}} \mW^K \quad \text{(keys from encoder)} \\ \mV &= \mX_{\text{enc}} \mW^V \quad \text{(values from encoder)} \\ \text{CrossAttn}(\mX_{\text{dec}}, \mX_{\text{enc}}) &= \text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right) \mV \end{align}$$

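Dimensions: for a decoder sequence of length $m$ and an encoder sequence of length $n$, $\mQ \in \R^{m \times d_k}$, $\mK \in \R^{n \times d_k}$, $\mV \in \R^{n \times d_v}$, the attention matrix is $m \times n$, and the output is $m \times d_v$. The two sequences may have different lengths.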

Example: English source: "The cat sat" (3 tokens encoded to $\mX_{\text{enc}} \in \R^{3 \times 512}$)

French target: "Le chat" (2 tokens so far, $\mX_{\text{dec}} \in \R^{2 \times 512}$)

Cross-attention computes:

$$ \mA = \begin{bmatrix} \alpha_{1,1} & \alpha_{1,2} & \alpha_{1,3} \\ \alpha_{2,1} & \alpha_{2,2} & \alpha_{2,3} \end{bmatrix} \in \R^{2 \times 3} $$

where $\alpha_{1,j}$ is the attention weight from decoder position 1 ("Le") to encoder position $j$; each row of $\mA$ sums to 1.

When generating "Le" (the), the model should attend strongly to "The" in the source.

When generating "chat" (cat), the model should attend strongly to "cat" in the source.
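The cross-attention computation above can be sketched in NumPy with the example's shapes (2 decoder tokens, 3 encoder tokens, $d_{\text{model}} = 512$). This is a minimal single-head sketch under stated assumptions: the inputs and projection matrices are random placeholders rather than trained values, and $d_k = 64$ is an assumed head size. The point is the shapes: queries come from the decoder, keys and values from the encoder, so $\mA$ is $2 \times 3$.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum before exponentiating for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(X_dec, X_enc, W_q, W_k, W_v):
    Q = X_dec @ W_q                        # (m, d_k): queries from the decoder
    K = X_enc @ W_k                        # (n, d_k): keys from the encoder
    V = X_enc @ W_v                        # (n, d_v): values from the encoder
    A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))   # (m, n) attention weights
    return A @ V, A

rng = np.random.default_rng(0)
d_model, d_k = 512, 64
X_enc = rng.standard_normal((3, d_model))  # stands in for "The cat sat"
X_dec = rng.standard_normal((2, d_model))  # stands in for "Le chat"
W_q, W_k, W_v = (rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(3))

out, A = cross_attention(X_dec, X_enc, W_q, W_k, W_v)
print(A.shape, out.shape)  # (2, 3) (2, 64)
```

With trained weights, row 1 of `A` would ideally concentrate on "The" and row 2 on "cat"; with random weights only the shapes and the row normalization are meaningful.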

Transformer Decoder Attention Layers

A transformer decoder block contains three sublayers, two of which are attention mechanisms:

  1. Masked self-attention: the decoder attends to previous decoder positions
    $$ \mQ, \mK, \mV \text{ all computed from } \mX_{\text{dec}} \quad \text{(with causal mask)} $$

  2. Cross-attention: the decoder attends to the encoder output
    $$ \mQ \text{ from } \mX_{\text{dec}}, \quad \mK, \mV \text{ from } \mX_{\text{enc}} $$

  3. Feed-forward: position-wise MLP (not attention)

Encoder-only models (BERT) use only bidirectional self-attention. Decoder-only models (GPT) use only masked self-attention. Encoder-decoder models (T5, BART) use bidirectional self-attention in the encoder and both masked self-attention and cross-attention in the decoder.

Relative Position Representations

Problem with absolute positions: a model with learned absolute position embeddings only sees positions 0-511 during training (for a 512-token limit). How should it handle position 600 at inference?

Solution: Relative position representations—encode distance between positions, not absolute positions.

Shaw et al. Relative Attention

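Definition: Modify attention scores to include relative position information:
$$ e_{ij} = \frac{\vq_i\transpose \vk_j + \vq_i\transpose \vr^{K}_{i-j}}{\sqrt{d_k}} $$
where $\vr^{K}_{i-j} \in \R^{d_k}$ is a learned embedding of the relative position $i-j$, clipped to the range $[-k_{\max}, k_{\max}]$ for a maximum distance $k_{\max}$.

Advantages: only $2k_{\max}+1$ relative embeddings are learned, every pair of tokens with the same offset shares the same embedding, and because of the clipping the model can be applied to sequences longer than those seen during training.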
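A minimal single-head NumPy sketch of Shaw-style relative scores, assuming a table `R` of `2 * max_dist + 1` relative key embeddings (random here, learned in practice) and dividing both terms by $\sqrt{d_k}$ as in Shaw et al.'s formulation. For clarity it materializes an $n \times n \times d_k$ tensor of relative embeddings, which an efficient implementation would avoid.

```python
import numpy as np

def shaw_relative_scores(Q, K, R, max_dist):
    """Attention logits with relative key embeddings (single head).

    Q, K: (n, d_k). R: (2 * max_dist + 1, d_k), where row o + max_dist
    holds the embedding for the clipped offset o = i - j.
    """
    n, d_k = Q.shape
    pos = np.arange(n)
    offsets = np.clip(pos[:, None] - pos[None, :], -max_dist, max_dist)  # (n, n)
    R_ij = R[offsets + max_dist]                  # (n, n, d_k) embedding per pair
    content = Q @ K.T                             # q_i . k_j
    position = np.einsum("id,ijd->ij", Q, R_ij)   # q_i . r^K_{i-j}
    return (content + position) / np.sqrt(d_k)

rng = np.random.default_rng(0)
n, d_k, max_dist = 6, 8, 2
Q, K = rng.standard_normal((n, d_k)), rng.standard_normal((n, d_k))
R = rng.standard_normal((2 * max_dist + 1, d_k))  # learned in practice
scores = shaw_relative_scores(Q, K, R, max_dist)  # (6, 6) logits, fed to softmax
```

With `R` set to zeros the scores reduce to ordinary scaled dot-product logits, and any offset beyond `max_dist` reuses the edge embedding.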

T5 Relative Position Bias

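T5 uses an even simpler approach, adding a learned scalar bias based on relative position:

$$ \mA_{ij} = \text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}} + \mB\right)_{ij} $$
where $B_{ij}$ depends only on the relative position of $i$ and $j$, mapped to one of a small number of buckets: each small distance gets its own bucket, larger distances share logarithmically wider buckets, and the bidirectional encoder uses separate buckets for keys before and after the query. Each head has its own bias values, shared across layers.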
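A NumPy sketch of T5-style bucketing. It is modeled on the scheme T5 describes (32 buckets, logarithmic spacing up to an offset of 128, separate buckets per direction in the encoder), but the function name and code are illustrative rather than taken from a particular library, and the random `bias_table` stands in for one head's learned biases.

```python
import numpy as np

def t5_relative_bucket(rel_pos, bidirectional=True, num_buckets=32, max_distance=128):
    """Map relative positions (j - i) to integer bucket ids, T5-style."""
    rel_pos = np.asarray(rel_pos)
    buckets = np.zeros_like(rel_pos)
    if bidirectional:
        num_buckets //= 2
        # Keys after the query (j > i) use the upper half of the buckets
        buckets = buckets + (rel_pos > 0).astype(rel_pos.dtype) * num_buckets
        dist = np.abs(rel_pos)
    else:
        dist = np.maximum(-rel_pos, 0)  # causal: only past positions matter
    max_exact = num_buckets // 2
    is_small = dist < max_exact         # each small distance gets its own bucket
    # Larger distances share logarithmically wider buckets, capped at the last one
    log_ratio = np.log(np.maximum(dist, 1) / max_exact) / np.log(max_distance / max_exact)
    large = max_exact + (log_ratio * (num_buckets - max_exact)).astype(rel_pos.dtype)
    large = np.minimum(large, num_buckets - 1)
    return buckets + np.where(is_small, dist, large)

n = 6
pos = np.arange(n)
rel = pos[None, :] - pos[:, None]       # rel[i, j] = j - i
bucket_ids = t5_relative_bucket(rel)    # (n, n) integer bucket ids
bias_table = np.random.default_rng(0).standard_normal(32)  # one head's biases (learned in practice)
B = bias_table[bucket_ids]              # (n, n) bias added to the attention logits
```

Because only 32 scalars per head are learned, the parameter count does not grow with sequence length; every offset of 128 or more in a given direction falls into the same bucket.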

Alternative Attention Scoring Functions

Beyond the scaled dot-product used in transformers, several alternative scoring functions exist---additive (Bahdanau), multiplicative (Luong), and general bilinear forms---each with different trade-offs between expressiveness and computational efficiency. These are defined and compared in Chapter~[ref] (Section~7.3). In practice, scaled dot-product attention dominates in transformer architectures due to its hardware-efficient batched matrix multiplication and strong empirical performance.
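Although the scoring functions are defined elsewhere, a side-by-side NumPy sketch makes their cost difference visible. All weight matrices are random placeholders for learned parameters, and the additive scorer's hidden size `d_a` is an arbitrary choice. Note the $(n, m, d_a)$ intermediate tensor built by additive attention, compared with a single matrix product for the dot-product form.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 4, 5, 8                       # n queries, m keys, dimension d
Q, K = rng.standard_normal((n, d)), rng.standard_normal((m, d))

# Scaled dot-product: q^T k / sqrt(d), one matrix multiplication
dot_scores = Q @ K.T / np.sqrt(d)

# Multiplicative ("general", Luong): q^T W k with a learned W (random here)
W = rng.standard_normal((d, d))
general_scores = Q @ W @ K.T

# Additive (Bahdanau): v^T tanh(W1 q + W2 k) with learned W1, W2, v (random here)
d_a = 8                                 # hidden size of the scoring MLP
W1, W2 = rng.standard_normal((d_a, d)), rng.standard_normal((d_a, d))
v = rng.standard_normal(d_a)
hidden = np.tanh((Q @ W1.T)[:, None, :] + (K @ W2.T)[None, :, :])  # (n, m, d_a)
additive_scores = hidden @ v            # (n, m)
```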

Attention Masking

\begin{tikzpicture}[ node/.style={circle, draw, minimum size=0.7cm, font=\small}, arrow/.style={->, >=stealth, thick}, blocked/.style={->, >=stealth, thick, red, dashed} ]

\node[node] (e1) at (0,0) {$x_1$}; \node[node] (e2) at (2,0) {$x_2$}; \node[node] (e3) at (4,0) {$x_3$};

\draw[arrow, blue!60] (e1) to[bend left=20] (e2); \draw[arrow, blue!60] (e2) to[bend left=20] (e1); \draw[arrow, blue!60] (e2) to[bend left=20] (e3); \draw[arrow, blue!60] (e3) to[bend left=20] (e2); \draw[arrow, blue!60] (e1) to[bend left=30] (e3); \draw[arrow, blue!60] (e3) to[bend left=30] (e1); \draw[arrow, blue!60] (e1) to[loop left] (e1); \draw[arrow, blue!60] (e2) to[loop above] (e2); \draw[arrow, blue!60] (e3) to[loop right] (e3);

\begin{scope}[shift={(7,0)}] \node[node] (d1) at (0,0) {$x_1$}; \node[node] (d2) at (2,0) {$x_2$}; \node[node] (d3) at (4,0) {$x_3$};

\draw[arrow, green!60] (d1) to[loop left] (d1); \draw[arrow, green!60] (d1) to[bend left=20] (d2); \draw[arrow, green!60] (d2) to[loop above] (d2); \draw[arrow, green!60] (d1) to[bend left=30] (d3); \draw[arrow, green!60] (d2) to[bend left=20] (d3); \draw[arrow, green!60] (d3) to[loop right] (d3);

\draw[blocked] (d2) to[bend right=20] (d1); \draw[blocked] (d3) to[bend right=20] (d2); \draw[blocked] (d3) to[bend right=30] (d1);

\end{scope}

\begin{scope}[shift={(0,-3)}] \foreach \i in {1,2,3} { \foreach \j in {1,2,3} { \fill[blue!30] (\j*0.8-0.8, -\i*0.8+0.8) rectangle (\j*0.8-0.4, -\i*0.8+0.4); } } \end{scope}

\begin{scope}[shift={(7,-4)}] \foreach \i in {1,2,3} { \foreach \j in {1,2,3} { \pgfmathtruncatemacro{\valid}{\j <= \i ? 1 : 0} \ifnum\valid=1 \fill[green!30] (\j*0.8-0.8, -\i*0.8+0.8) rectangle (\j*0.8-0.4, -\i*0.8+0.4); \else \fill[red!30] (\j*0.8-0.8, -\i*0.8+0.8) rectangle (\j*0.8-0.4, -\i*0.8+0.4); \node at (\j*0.8-0.6, -\i*0.8+0.6) {\footnotesize $-\infty$}; \fi } } \end{scope}

\end{tikzpicture}

Bidirectional vs causal attention masking. Left: Bidirectional attention (encoder) allows each position to attend to all positions, creating a fully-connected graph and full attention matrix. Right: Causal attention (decoder) masks future positions by setting them to $-\infty$ before softmax, creating a triangular connectivity pattern. Position 1 can only see itself, position 2 can see positions 1-2, and position 3 can see all positions 1-3. This prevents the model from "cheating" by looking at future tokens during training.

Padding Mask

For variable-length sequences in a batch, mask the padding tokens:

$$ M_{ij} = \begin{cases} 0 & \text{if position } j \text{ is valid} \\ -\infty & \text{if position } j \text{ is padding} \end{cases} $$
Example: Batch with sequences of length [5, 7, 4], padded to length 7:
$$\begin{align} \text{Seq 1:} & \quad [w_1, w_2, w_3, w_4, w_5, \text{PAD}, \text{PAD}] \\ \text{Seq 2:} & \quad [w_1, w_2, w_3, w_4, w_5, w_6, w_7] \\ \text{Seq 3:} & \quad [w_1, w_2, w_3, w_4, \text{PAD}, \text{PAD}, \text{PAD}] \end{align}$$

Mask for Seq 1:

$$ [0, 0, 0, 0, 0, -\infty, -\infty] $$

Prevents attending to padding tokens.

Combined Masks

For the decoder, combine the causal mask and the padding mask:

$$ \mM_{\text{total}} = \mM_{\text{causal}} + \mM_{\text{padding}} $$

Because the masks are added, the result is the most restrictive of the two: if either mask blocks a position ($-\infty$), the sum blocks it too.
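The masks above can be built and combined in a few lines of NumPy. This sketch reuses the padding example's batch (lengths 5, 7, 4, padded to 7) and random scores in place of real $\mQ\mK\transpose/\sqrt{d_k}$ values. Because the masks are additive with $-\infty$ entries, summing them keeps whichever is more restrictive.

```python
import numpy as np

def causal_mask(n):
    # 0 on and below the diagonal, -inf above it (future positions)
    return np.triu(np.full((n, n), -np.inf), k=1)

def padding_mask(lengths, n):
    # (batch, 1, n): -inf for key positions beyond each sequence's length
    valid = np.arange(n)[None, :] < np.asarray(lengths)[:, None]
    return np.where(valid, 0.0, -np.inf)[:, None, :]

def masked_softmax(scores, mask):
    s = scores + mask                   # blocked entries become -inf
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)                       # exp(-inf) = 0
    return e / e.sum(axis=-1, keepdims=True)

lengths, n = [5, 7, 4], 7
M_total = causal_mask(n)[None] + padding_mask(lengths, n)  # broadcasts to (3, 7, 7)
scores = np.random.default_rng(0).standard_normal((3, n, n))
A = masked_softmax(scores, M_total)
```

With right-padding and a causal mask, every query can always see key position 0, so no row is fully masked and the softmax never divides by zero.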

Attention Dropout

Apply dropout to attention weights for regularization:

$$ \mA = \text{Dropout}\left(\text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right)\right) $$

Typical dropout rate: 0.1 (10\%)

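Effect: during training, some attention weights are randomly zeroed (and the remaining ones scaled by $1/(1-p)$ for dropout rate $p$), preventing over-reliance on specific positions. Dropout is turned off at inference.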

Layer Normalization with Attention

Two architectures for combining attention with layer norm:

Post-Norm (Original Transformer)

$$\begin{align} \vh &= \mX + \text{MultiHeadAttn}(\mX) \\ \mZ &= \text{LayerNorm}(\vh) \end{align}$$

Pre-Norm (More Common Now)

$$\begin{align} \vh &= \mX + \text{MultiHeadAttn}(\text{LayerNorm}(\mX)) \\ \mZ &= \vh \end{align}$$

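Pre-norm advantages: the residual path from input to output contains no normalization, so gradients flow directly through deep stacks of layers. This generally makes training more stable and less dependent on learning-rate warmup. Pre-norm models usually add one final LayerNorm after the last block.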
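A small NumPy sketch contrasting the two orderings. The `sublayer` function is a hypothetical stand-in for multi-head attention, and LayerNorm's learned gain and bias are omitted. Post-norm outputs are always normalized, while pre-norm leaves the residual stream $\mX$ untouched and only normalizes what enters the sublayer.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each row to zero mean and unit variance (learned gain/bias omitted)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(X, sublayer):
    return layer_norm(X + sublayer(X))       # add the residual, then normalize

def pre_norm_block(X, sublayer):
    return X + sublayer(layer_norm(X))       # normalize only the sublayer input

rng = np.random.default_rng(0)
W = rng.standard_normal((16, 16)) / 4

def sublayer(H):
    # Hypothetical stand-in for MultiHeadAttn
    return np.tanh(H @ W)

X = 3.0 + 2.0 * rng.standard_normal((5, 16))
Z_post = post_norm_block(X, sublayer)
Z_pre = pre_norm_block(X, sublayer)
```

In the pre-norm version, `Z_pre - X` is exactly the sublayer's contribution, which is why the identity path carries gradients unchanged.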

Visualizing Attention

Attention weights $\mA \in \R^{n \times n}$ reveal where the model attends:

Attention Heatmaps

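For the sentence "The cat sat on the mat", plot $\mA$ as a heatmap with query tokens as rows and key tokens as columns.

Patterns observed: some heads attend mostly to neighboring tokens, while others connect syntactically related words, such as determiners to their nouns ("the" to "mat") and subjects to their verbs ("cat" to "sat").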

Interpreting Multiple Heads

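In 12-head attention, different heads specialize: some follow fixed positional offsets (such as the previous or next token), some spread attention broadly or concentrate it on special tokens such as [SEP], and others track particular syntactic relations.

Attention weights are NOT necessarily model explanations! High attention doesn't always mean high importance for prediction. Attention shows where the model looks, not why it makes its decisions.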

Practical Implementation Considerations

Memory-Efficient Attention

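For very long sequences, avoid materializing the full attention matrix by processing the queries in chunks:

  1. Compute $\mQ \mK\transpose$ for a chunk of $c$ queries
  2. Apply the softmax to each row
  3. Multiply by the full $\mV$
  4. Concatenate the output rows of all chunks

Because each softmax row depends only on its own query, the result is exact. Peak memory for attention scores drops from $O(n^2)$ to $O(nc)$, where $c$ is the chunk size.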
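The query-chunking procedure can be sketched in NumPy (chunk size and shapes are arbitrary choices for illustration). Each softmax row depends only on its own query, so processing query rows in blocks reproduces full attention exactly while only a $c \times n$ block of scores exists at a time.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def chunked_attention(Q, K, V, chunk_size):
    """Exact attention computed one block of query rows at a time."""
    d_k = Q.shape[-1]
    outputs = []
    for start in range(0, Q.shape[0], chunk_size):
        q = Q[start:start + chunk_size]                # chunk of queries
        weights = softmax(q @ K.T / np.sqrt(d_k))      # (chunk, n): softmax over all keys
        outputs.append(weights @ V)                    # multiply by the full V
    return np.concatenate(outputs, axis=0)            # stitch the rows back together

rng = np.random.default_rng(0)
n, d = 1000, 32
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out = chunked_attention(Q, K, V, chunk_size=128)
full = softmax(Q @ K.T / np.sqrt(d)) @ V            # reference that stores all n x n scores
```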

Fused Attention Kernels

Modern implementations fuse operations:

$$ \mQ\mK\transpose \to \text{Scale} \to \text{Mask} \to \text{Softmax} \to \text{Dropout} \to \text{multiply by } \mV $$

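A single fused kernel is faster than running these operations separately because intermediate results stay in fast on-chip memory instead of making round trips through GPU global memory.

Example: FlashAttention achieves a 2-4x speedup through fused operations and memory hierarchy optimization.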

Efficient Attention Variants

The standard self-attention mechanism has computational complexity $O(n^2d)$ and memory complexity $O(n^2)$, where $n$ is the sequence length and $d$ is the model dimension. This quadratic scaling in sequence length becomes prohibitive for long sequences. For a sequence of length 4096 with 12 attention heads, the attention matrices alone require $12 \times 4096^2 \times 4 \approx 805$ MB in FP32 format per example and per layer. With batch size 32, this amounts to 25.8 GB per layer just for attention weights, exceeding the memory capacity of many GPUs. This fundamental limitation has motivated extensive research into efficient attention variants that reduce the quadratic complexity while maintaining model quality.

The key insight underlying efficient attention is that not all token pairs require equal attention. In practice, attention patterns often exhibit structure—tokens primarily attend to nearby tokens, specific global tokens, or sparse subsets of the sequence. By exploiting this structure, efficient attention mechanisms can dramatically reduce computational and memory requirements while preserving most of the modeling capacity of full attention. The following sections examine the major classes of efficient attention variants, analyzing their complexity trade-offs, implementation considerations, and practical use cases.

Local Attention

Local attention restricts each token to attend only to tokens within a fixed window around its position, rather than attending to all tokens in the sequence. For a window size $w$, the token at position $i$ attends only to positions $[i-w/2, i+w/2]$, clipped at the sequence boundaries. This reduces the attention matrix from $n \times n$ to about $n \times w$, yielding linear scaling in sequence length.

The computational complexity of local attention is $O(nwd)$, where $n$ is sequence length, $w$ is window size, and $d$ is model dimension. Compared to standard attention's $O(n^2d)$, this represents a reduction factor of $n/w$. For a sequence of length 4096 with window size 256, local attention performs 16 times fewer operations than full attention. The memory complexity similarly reduces from $O(n^2)$ to $O(nw)$, enabling much longer sequences to fit in GPU memory. For the same 4096-token sequence with 12 heads, local attention with window 256 requires only $12 \times 4096 \times 256 \times 4 \approx 50.3$ MB per example, a 16-fold reduction from the 805 MB required by full attention.
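A single-head NumPy sketch of sliding-window attention, assuming an even window $w$ and NumPy 1.20 or later (for `sliding_window_view`). Instead of forming the $n \times n$ matrix, it gathers each query's $w+1$ neighboring keys and values; window slots that fall outside the sequence are masked with $-\infty$.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def local_attention(Q, K, V, w):
    """Query i attends to keys in [i - w/2, i + w/2]; w is assumed even."""
    n, d_k = Q.shape
    half = w // 2
    # Zero-pad both ends so every query has a full window; padded slots get masked
    K_pad = np.pad(K, ((half, half), (0, 0)))
    V_pad = np.pad(V, ((half, half), (0, 0)))
    K_win = sliding_window_view(K_pad, w + 1, axis=0)   # (n, d_k, w + 1) view
    V_win = sliding_window_view(V_pad, w + 1, axis=0)   # (n, d_v, w + 1) view
    scores = np.einsum("nd,ndw->nw", Q, K_win) / np.sqrt(d_k)   # (n, w + 1)
    key_pos = np.arange(n)[:, None] - half + np.arange(w + 1)[None, :]
    scores = np.where((key_pos >= 0) & (key_pos < n), scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    return np.einsum("nw,ndw->nd", weights, V_win)

rng = np.random.default_rng(0)
n, d, w = 64, 16, 8
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out = local_attention(Q, K, V, w)   # (64, 16)
```

The score array has shape $(n, w+1)$, which is where the $O(nw)$ memory bound comes from; the result matches full attention with a band mask of half-width $w/2$.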

The primary trade-off of local attention is the loss of long-range dependencies. Tokens separated by more than $w/2$ positions cannot directly attend to each other, requiring information to propagate through multiple layers. In practice, this limitation is often acceptable. Many natural language tasks exhibit strong locality—syntactic dependencies are typically short-range, and semantic relationships can be captured through multiple layers of local attention. Empirical studies show that local attention with window size 256-512 typically achieves 98-99\% of full attention's accuracy on language modeling tasks, while enabling sequences 10-20 times longer.

The Longformer architecture demonstrates effective use of local attention for document-level understanding. Longformer combines local windowed attention for most tokens with global attention for special tokens like [CLS] and task-specific tokens. This hybrid approach maintains $O(n)$ complexity while allowing critical tokens to aggregate information from the entire sequence. On document classification tasks with 4096-token inputs, Longformer achieves comparable accuracy to BERT while processing sequences 8 times longer. The local attention pattern also enables efficient implementation on GPUs through blocked matrix operations, achieving 2-3x speedup over naive implementations.

Sparse Attention

Sparse attention generalizes local attention by allowing each token to attend to a sparse subset of positions according to a predefined pattern, rather than a contiguous window. The key insight is that attention patterns in trained transformers often exhibit structure—certain positions are consistently important while others receive minimal attention. By designing sparsity patterns that capture this structure, sparse attention can dramatically reduce computation while maintaining model quality.

Several sparsity patterns have proven effective in practice. Strided attention divides the sequence into blocks and allows each token to attend within its block and to every $k$-th token globally, where $k$ is the stride. This pattern captures both local context and evenly-spaced global context. Fixed attention combines local attention with attention to a fixed set of global tokens, similar to Longformer. Learned sparse attention uses a separate network to predict which positions each token should attend to, adapting the sparsity pattern to the input. The Sparse Transformer architecture uses a factorized attention pattern where each token attends to positions in a strided pattern in one head and a local pattern in another head, allowing information to flow efficiently across the sequence.
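To make such patterns concrete, this NumPy sketch builds a boolean mask combining a local window, a strided pattern (here, every key whose index is a multiple of the stride, one of several possible strided layouts), and global tokens at the start of the sequence. It materializes the dense $n \times n$ mask only to count connections; an efficient kernel would never store it.

```python
import numpy as np

def sparse_pattern(n, window, stride, n_global=1):
    """Boolean (n, n) mask: mask[i, j] is True if query i may attend to key j."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    local = np.abs(i - j) <= window // 2    # contiguous neighborhood
    strided = (j % stride) == 0             # every stride-th key, anywhere in the sequence
    global_keys = j < n_global              # every query sees the global tokens
    global_queries = i < n_global           # global tokens see every key
    return local | strided | global_keys | global_queries

mask = sparse_pattern(n=1024, window=64, stride=128)
per_query = mask.sum(axis=1)
print(per_query[500], mask.mean())   # connections for one query, overall density
```

For query 500, the local window contributes 65 keys and the stride adds the 7 multiples of 128 that lie outside that window, far fewer than the 1024 keys of full attention.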

The computational complexity of sparse attention is $O(n \sqrt{n} d)$ for typical sparsity patterns, where each token attends to approximately $\sqrt{n}$ other tokens. This represents a substantial improvement over full attention's $O(n^2 d)$, particularly for long sequences. For a sequence of length 4096, sparse attention with $\sqrt{n} = 64$ positions per token performs 64 times fewer operations than full attention. The memory complexity is similarly $O(n \sqrt{n})$, enabling sequences that would be impossible with full attention. For 4096 tokens with 12 heads, sparse attention requires approximately $12 \times 4096 \times 64 \times 4 \approx 12.6$ MB per example, a 64-fold reduction from full attention's 805 MB.

The accuracy trade-off of sparse attention depends critically on the choice of sparsity pattern. Well-designed patterns that align with the task's dependency structure can achieve 97-99\% of full attention's accuracy. The Sparse Transformer achieves perplexity within 0.1 of full attention on language modeling while using only $\sqrt{n}$ attention per token. BigBird, which combines local, global, and random attention patterns, matches BERT's accuracy on question answering and document classification while processing sequences up to 8 times longer. However, poorly chosen sparsity patterns can significantly degrade accuracy, particularly on tasks requiring long-range reasoning.

Implementation of sparse attention on GPUs presents challenges because modern GPUs are optimized for dense matrix operations. Sparse matrix multiplication is less efficient than dense multiplication due to irregular memory access patterns and reduced arithmetic intensity. Specialized kernels and libraries like cuSPARSE can partially mitigate this, but sparse attention typically achieves only 50-70\% of the theoretical speedup in practice. Recent work on block-sparse attention, which operates on blocks of the attention matrix rather than individual elements, achieves better GPU utilization by maintaining some regularity in memory access patterns. The Triton framework enables efficient implementation of custom sparse attention patterns through automatic optimization of memory access.

Linear Attention

Linear attention achieves $O(nd^2)$ complexity by reformulating the attention computation to avoid explicitly constructing the $n \times n$ attention matrix. The key insight is that attention can be viewed as a kernel operation, and by choosing an appropriate kernel function, the computation can be reordered to compute the output directly without materializing the full attention matrix.

The standard attention computation is:

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}(\mQ \mK\transpose) \mV $$

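This requires computing $\mQ \mK\transpose \in \R^{n \times n}$ before applying softmax and multiplying by $\mV$. Linear attention approximates the exponential kernel inside the softmax with a feature map $\phi: \R^{d_k} \to \R^{d'}$ such that:

$$ \exp(\vq\transpose \vk) \approx \phi(\vq)\transpose \phi(\vk) $$

With this approximation, and normalizing each row as softmax does, attention becomes:

$$ \text{LinearAttn}(\mQ, \mK, \mV) = \frac{\phi(\mQ) \left(\phi(\mK)\transpose \mV\right)}{\phi(\mQ) \left(\phi(\mK)\transpose \mathbf{1}_n\right)} $$

where the division is applied row by row and $\mathbf{1}_n$ is the all-ones vector of length $n$.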

The crucial observation is that the parentheses can be reordered. Instead of computing $\phi(\mQ) \phi(\mK)\transpose$ (which is $n \times n$) and then multiplying by $\mV$, we first compute $\phi(\mK)\transpose \mV \in \R^{d' \times d_v}$ and then multiply by $\phi(\mQ)$. This reordering changes complexity from $O(n^2 d)$ to $O(n d'^2)$, where $d'$ is the feature dimension (typically equal to $d_k$).
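The reordering can be checked numerically. This NumPy sketch uses the Linear Transformer's $\phi(\vx) = \text{elu}(\vx) + 1$ as the feature map and normalizes each row so the weights sum to one, as softmax would. Both functions compute the same result; only `linear_attention` avoids the $n \times n$ matrix.

```python
import numpy as np

def elu_plus_one(x):
    # phi(x) = elu(x) + 1: equals x + 1 for x > 0 and exp(x) otherwise, so always positive
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention(Q, K, V):
    phi_q, phi_k = elu_plus_one(Q), elu_plus_one(K)   # (n, d')
    kv = phi_k.T @ V                  # (d', d_v): phi(K)^T V, computed once
    z = phi_k.sum(axis=0)             # (d',): phi(K)^T 1, the normalizer
    return (phi_q @ kv) / (phi_q @ z)[:, None]

def quadratic_reference(Q, K, V):
    # Same kernel, but materializes the (n, n) matrix phi(Q) phi(K)^T
    S = elu_plus_one(Q) @ elu_plus_one(K).T
    return (S / S.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
n, d = 256, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out = linear_attention(Q, K, V)   # (256, 16), no (256, 256) matrix formed
```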

The computational savings of linear attention are substantial for long sequences. For sequence length 4096 and BERT-base dimensions (12 heads with $d_k = 64$, so $d = 768$), standard attention requires approximately $4096^2 \times 768 \approx 12.9$ billion multiply-adds for $\mQ\mK\transpose$ summed over all heads, while linear attention with $d' = d_k$ requires $4096 \times 12 \times 64^2 \approx 0.2$ billion, a 64-fold ($n/d_k$) reduction. The memory complexity is even more favorable: per head, linear attention stores only the $d' \times d_v$ matrix $\phi(\mK)\transpose \mV$ ($O(d_k^2)$, independent of $n$) instead of the $n \times n$ attention matrix. For 12 heads, this is approximately $12 \times 64 \times 64 \times 4 \approx 0.2$ MB in FP32, compared to 805 MB for full attention.

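The primary challenge of linear attention is choosing a feature map $\phi$ that accurately approximates the softmax kernel while remaining computationally efficient. The Performer architecture uses random features built from $m$ Gaussian projection vectors $\omega_1, \ldots, \omega_m$. The trigonometric (random Fourier) version is $\phi(\vx) = \frac{e^{\|\vx\|^2/2}}{\sqrt{m}} [\cos(\omega_1\transpose \vx), \sin(\omega_1\transpose \vx), \ldots, \cos(\omega_m\transpose \vx), \sin(\omega_m\transpose \vx)]$, while Performer's FAVOR+ mechanism uses strictly positive features $\phi(\vx) = \frac{e^{-\|\vx\|^2/2}}{\sqrt{m}} [\exp(\omega_1\transpose \vx), \ldots, \exp(\omega_m\transpose \vx)]$, which avoid negative kernel estimates and are more accurate for small attention weights. Both provide an unbiased approximation of the softmax kernel with accuracy controlled by the number of random features $m$. The Linear Transformer uses a simpler feature map $\phi(\vx) = \text{elu}(\vx) + 1$, which is faster to compute but provides a looser approximation.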
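A NumPy sketch of the positive random-feature estimate of $\exp(\vq\transpose \vk)$. It uses i.i.d. Gaussian projections; Performer additionally orthogonalizes them and scales queries and keys by $d_k^{-1/4}$ so that the kernel matches $\exp(\vq\transpose\vk/\sqrt{d_k})$, both omitted here. The number of features `m` and the small input scale are arbitrary choices for the demonstration.

```python
import numpy as np

def positive_random_features(X, omega):
    """phi(x) = exp(omega_i^T x - ||x||^2 / 2) / sqrt(m): strictly positive features."""
    m = omega.shape[0]
    proj = X @ omega.T                                          # (n, m) random projections
    half_sq_norm = 0.5 * np.sum(X**2, axis=-1, keepdims=True)   # (n, 1)
    return np.exp(proj - half_sq_norm) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 8, 20000
omega = rng.standard_normal((m, d))        # i.i.d. Gaussian projections (not orthogonalized)
q = 0.2 * rng.standard_normal((1, d))
k = 0.2 * rng.standard_normal((1, d))
exact = np.exp(q @ k.T)[0, 0]              # exp(q^T k)
approx = (positive_random_features(q, omega) @ positive_random_features(k, omega).T)[0, 0]
print(exact, approx)
```

Every feature is positive, so the estimate can never be negative; its variance shrinks as `m` grows.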

The accuracy trade-off of linear attention is more significant than local or sparse attention. Empirical studies show that linear attention typically achieves 95-98\% of full attention's accuracy on language modeling, with larger degradation on tasks requiring precise attention patterns. The approximation error is particularly noticeable for small attention weights—the softmax function's sharp peaking is difficult to approximate with simple feature maps. However, for applications where extreme sequence length is critical, such as processing entire books or long-form video, the 2-5\% accuracy loss is often acceptable given the dramatic computational savings. Recent work on learned feature maps and adaptive kernel approximations aims to close this accuracy gap while maintaining linear complexity.

Low-Rank Attention

Low-rank attention exploits the observation that attention matrices in trained transformers often have low effective rank—most of the variance is captured by a small number of singular values. By explicitly factorizing the attention computation through a low-dimensional bottleneck, low-rank attention reduces complexity from $O(n^2 d)$ to $O(nrd)$, where $r$ is the rank and typically $r \ll n$.

The Linformer architecture implements low-rank attention by projecting the keys and values to a lower-dimensional space before computing attention. Specifically, Linformer adds projection matrices $\mE, \mF \in \R^{r \times n}$ that reduce the sequence length dimension:

$$ \text{LinformerAttn}(\mQ, \mK, \mV) = \text{softmax}\left(\frac{\mQ (\mE \mK)\transpose}{\sqrt{d_k}}\right) (\mF \mV) $$

The key insight is that $\mE \mK \in \R^{r \times d_k}$ and $\mF \mV \in \R^{r \times d_v}$ have reduced sequence length $r$ instead of $n$. The attention matrix is now $n \times r$ instead of $n \times n$, reducing both computation and memory by a factor of $n/r$.
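A single-head NumPy sketch of the Linformer computation. The projections `E` and `F` are learned in Linformer; here they are random placeholders, and the sizes are arbitrary. The attention matrix has shape $n \times r$ rather than $n \times n$.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def linformer_attention(Q, K, V, E, F):
    """Q, K, V: (n, d); E, F: (r, n) projections along the sequence axis."""
    d_k = Q.shape[-1]
    K_proj = E @ K                                  # (r, d_k): r "summary" keys
    V_proj = F @ V                                  # (r, d_v): r "summary" values
    A = softmax(Q @ K_proj.T / np.sqrt(d_k))        # (n, r) instead of (n, n)
    return A @ V_proj, A

rng = np.random.default_rng(0)
n, d, r = 1024, 64, 128
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
E, F = (rng.standard_normal((r, n)) / np.sqrt(n) for _ in range(2))  # learned in Linformer
out, A = linformer_attention(Q, K, V, E, F)
```

Because `E` and `F` have a fixed second dimension $n$, this formulation assumes a fixed (maximum) sequence length.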

For sequence length 4096 and rank 256, low-rank attention reduces computation from $4096^2 \times 768 \approx 12.9$ billion operations to $4096 \times 256 \times 768 \approx 805$ million operations, summed over all heads ($d = 768$), a 16-fold reduction. The memory savings are equally dramatic: the attention matrix requires $4096 \times 256 \times 4 \approx 4.2$ MB per head instead of $4096^2 \times 4 \approx 67.1$ MB, a 16-fold reduction. With 12 heads, total attention memory drops from 805 MB to 50.3 MB per example.

The accuracy of low-rank attention depends on the choice of rank $r$ and the projection matrices $\mE$ and $\mF$. Linformer uses learned projection matrices that are shared across all layers, reducing the parameter overhead. Empirical studies show that rank $r = 256$ achieves 96-98\% of full attention's accuracy for sequences up to 4096 tokens, with minimal degradation on most language understanding tasks. The accuracy loss is more pronounced for tasks requiring fine-grained attention patterns, such as coreference resolution or syntactic parsing, where the low-rank approximation may miss subtle dependencies.

An important consideration for low-rank attention is that the projection matrices $\mE$ and $\mF$ introduce additional parameters and computation. For rank $r$ and sequence length $n$, the projections add $2rn$ parameters per layer. However, these projections can be implemented efficiently as 1D convolutions or learned position-wise projections, and the parameter cost is typically small compared to the savings in attention computation. The projection operations themselves require $O(rnd)$ computation, which is negligible compared to the $O(n^2d)$ cost of full attention for $r \ll n$.

Comprehensive Complexity Comparison

Understanding the trade-offs between different attention variants requires examining multiple dimensions: computational complexity, memory requirements, accuracy preservation, and practical implementation efficiency. The following analysis provides concrete comparisons across these dimensions for typical transformer configurations.

\begin{tabular}{llllll}
Variant & Time & Memory & Accuracy & Max Length & Use Case \\
\hline
Full Attention & $O(n^2d)$ & $O(n^2)$ & 100\% & 512-1024 & Standard tasks \\
Local Attention & $O(nwd)$ & $O(nw)$ & 98-99\% & 4096-8192 & Document processing \\
Sparse Attention & $O(n\sqrt{n}d)$ & $O(n\sqrt{n})$ & 97-99\% & 8192-16384 & Long documents \\
Linear Attention & $O(nd^2)$ & $O(nd)$ & 95-98\% & 16384+ & Extreme length \\
Low-Rank Attention & $O(nrd)$ & $O(nr)$ & 96-98\% & 4096-8192 & Compression \\
\end{tabular}

To make these complexity bounds concrete, consider processing sequences of varying lengths with BERT-base configuration ($d = 768$, 12 heads, $d_k = 64$ per head). The following table shows memory requirements for attention matrices (FP32, all 12 heads, one example, one layer) across different sequence lengths and attention variants.

\begin{tabular}{lllll}
Variant & $n=512$ & $n=4096$ & $n=8192$ & $n=16384$ \\
\hline
Full Attention & 12.6 MB & 805 MB & 3.2 GB & 12.9 GB \\
Local Attention ($w=256$) & 6.3 MB & 50.3 MB & 101 MB & 201 MB \\
Sparse Attention ($\sqrt{n}$) & 0.56 MB & 12.6 MB & 35.7 MB & 101 MB \\
Linear Attention ($d_k \times d_k$ per head) & 0.2 MB & 0.2 MB & 0.2 MB & 0.2 MB \\
Low-Rank ($r=256$) & 6.3 MB & 50.3 MB & 101 MB & 201 MB \\
\end{tabular}

The memory savings become dramatic for long sequences. At 16,384 tokens, full attention requires 12.9 GB per example and per layer; keeping it for all 12 layers of BERT-base during training would take about 155 GB, far beyond the 40 GB of an A100 GPU. Local attention reduces this to 201 MB per layer, a 64-fold reduction. Linear attention's state does not grow with sequence length at all (about 0.2 MB here), so attention stops being the dominant memory cost.
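The memory figures follow from a few formulas. This plain-Python helper computes them in decimal MB under simplifying assumptions: FP32 scores, $\sqrt{n}$ keys per query for sparse attention, and one $d_k \times d_k$ state per head for linear attention. The function and its parameter names are illustrative.

```python
def attention_memory_mb(variant, n, heads=12, d_k=64, w=256, r=256, bytes_per_value=4):
    """Attention-score memory for one example and one layer, all heads, in decimal MB."""
    values_per_head = {
        "full": n * n,
        "local": n * w,
        "sparse": n * n ** 0.5,   # about sqrt(n) keys per query
        "linear": d_k * d_k,      # one d_k x d_k state per head, independent of n
        "low_rank": n * r,
    }[variant]
    return heads * values_per_head * bytes_per_value / 1e6

for n in (512, 4096, 8192, 16384):
    row = {v: round(attention_memory_mb(v, n), 2)
           for v in ("full", "local", "sparse", "linear", "low_rank")}
    print(n, row)
```

Passing `bytes_per_value=2` gives the FP16 figures used in the exercise solutions.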

The computational cost comparison is equally striking. For a sequence of 8192 tokens with $d=768$ (12 heads, $d_k = 64$), counting both matrix products of attention and two floating-point operations (FLOPs) per multiply-add, full attention requires about $4n^2d \approx 206$ billion FLOPs per layer. Local attention with window 256 reduces this to about 6.4 billion FLOPs (32x fewer), sparse attention with $\sqrt{n} \approx 90$ keys per query to about 2.3 billion (about 90x fewer), linear attention to $4 n h d_k^2 \approx 1.6$ billion with $h = 12$ heads (128x fewer), and low-rank attention with rank 256 to about 6.4 billion (32x fewer). On an NVIDIA A100 GPU with 312 TFLOPS of peak FP16 tensor-core throughput, full attention would take about 0.66 ms per layer at peak, while the efficient variants take roughly 5-21 microseconds; in practice, attention kernels run well below peak throughput, as the implementation considerations below explain.

The accuracy trade-offs vary by task and sequence length. For sequences up to 2048 tokens, local attention with window 512 typically matches full attention within 0.5\% on language modeling perplexity. Sparse attention with well-designed patterns achieves similar accuracy. Linear attention shows 2-3\% degradation, while low-rank attention with rank 256 shows 1-2\% degradation. For longer sequences exceeding 4096 tokens, the accuracy gaps widen slightly, but efficient variants remain highly competitive. Importantly, the accuracy loss is often task-dependent—some tasks like document classification are more tolerant of approximate attention than tasks like machine translation or question answering that require precise alignment.

Implementation Considerations

Implementing efficient attention variants requires careful consideration of hardware characteristics, numerical stability, and software frameworks. The theoretical complexity improvements do not always translate directly to wall-clock speedups due to GPU architecture constraints and implementation details.

Modern GPUs achieve peak performance on dense matrix multiplications with dimensions that are multiples of 16 or 32 (for tensor cores). Sparse attention patterns that result in irregular memory access or non-aligned dimensions can suffer significant performance degradation. For example, a naive implementation of sparse attention with random sparsity patterns may achieve only 30-40\% of the theoretical speedup due to poor memory coalescing and reduced arithmetic intensity. Block-sparse patterns that operate on 16x16 or 32x32 blocks achieve much better GPU utilization, typically reaching 60-80\% of theoretical speedup.

Memory bandwidth is often the limiting factor for attention computation, particularly for efficient variants. The attention mechanism is memory-bound rather than compute-bound for typical sequence lengths—the GPU spends more time loading data from memory than performing arithmetic operations. This means that reducing the number of operations (FLOPs) does not always proportionally reduce runtime. Efficient implementations must minimize memory transfers through kernel fusion, where multiple operations are combined into a single GPU kernel that keeps intermediate results in fast on-chip memory. FlashAttention demonstrates this principle by fusing the attention computation ($\mQ\mK\transpose$, softmax, multiply by $\mV$) into a single kernel that never materializes the full attention matrix in global memory, achieving 2-4x speedup over standard implementations even for full attention.
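The core trick that lets a fused kernel avoid storing the attention matrix is an online softmax over tiles of keys. This NumPy sketch shows only that algorithmic skeleton: a real fused kernel such as FlashAttention also tiles the queries, keeps tiles in on-chip SRAM, and implements a matching backward pass, none of which NumPy can express. The block size is arbitrary.

```python
import numpy as np

def tiled_attention(Q, K, V, block=64):
    """Exact softmax attention that visits keys and values one tile at a time.

    Per query, only a running max, a running denominator and a running output
    are kept, so the full (n, n) score matrix is never stored.
    """
    n, d_k = Q.shape
    row_max = np.full(n, -np.inf)
    denom = np.zeros(n)
    acc = np.zeros((n, V.shape[1]))
    for start in range(0, K.shape[0], block):
        S = Q @ K[start:start + block].T / np.sqrt(d_k)   # (n, block) scores for this tile
        new_max = np.maximum(row_max, S.max(axis=1))
        scale = np.exp(row_max - new_max)    # rescale earlier statistics to the new max
        P = np.exp(S - new_max[:, None])
        denom = denom * scale + P.sum(axis=1)
        acc = acc * scale[:, None] + P @ V[start:start + block]
        row_max = new_max
    return acc / denom[:, None]

rng = np.random.default_rng(0)
n, d = 300, 32
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out = tiled_attention(Q, K, V, block=64)   # 300 keys: four full tiles and one partial tile
```

On the first tile, `scale` is `exp(-inf) = 0`, so the empty initial statistics drop out cleanly; afterwards each tile only rescales what has been accumulated so far.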

Numerical stability is a critical concern for efficient attention variants. The softmax operation in attention is numerically sensitive—subtracting the maximum value before exponentiation is essential to prevent overflow. Linear attention approximations must carefully handle the feature map computation to avoid numerical issues. The Performer's random Fourier features require computing exponentials of potentially large values, necessitating careful scaling and normalization. Low-rank attention must ensure that the projection matrices are well-conditioned to avoid amplifying numerical errors.

Framework support for efficient attention varies significantly. PyTorch and TensorFlow provide optimized implementations of standard attention through torch.nn.MultiheadAttention and tf.keras.layers.MultiHeadAttention, but efficient variants often require custom implementations. The xFormers library provides optimized implementations of several efficient attention variants, including memory-efficient attention and block-sparse attention. The Triton framework enables writing custom GPU kernels in Python that achieve performance comparable to hand-written CUDA, making it easier to implement and experiment with novel attention patterns. For production deployment, specialized libraries like FasterTransformer and TensorRT provide highly optimized implementations of common attention variants with automatic kernel selection based on input dimensions and hardware capabilities.

Exercises

Exercise 1: Implement cross-attention layer in PyTorch. Test with encoder output (length 10, dim 128) and decoder input (length 7, dim 128). Verify attention matrix shape is $7 \times 10$.
Exercise 2: Calculate the memory requirements for attention matrices in a BERT-base model (12 heads, $d_{\text{model}} = 768$) processing sequences of length 512, 2048, and 4096 tokens. Compare full attention, local attention with window size 256, and linear attention. How much memory is saved at each sequence length?
Exercise 3: Implement local attention with window size $w=128$ for a sequence of length 1024. Compare the computational cost (FLOPs) and memory usage to full attention. Measure actual runtime on GPU and explain any discrepancy between theoretical and observed speedup.
Exercise 4: Design a sparse attention pattern for document understanding that combines local attention (window 64), strided attention (stride 128), and global attention to the first token. Calculate the number of attention connections per token and total memory requirements for a 4096-token sequence. What percentage of full attention's connections does this pattern use?
Exercise 5: Implement linear attention using the feature map $\phi(\vx) = \text{elu}(\vx) + 1$. Compare attention patterns to standard softmax attention on a sample sequence. Measure the approximation error and identify cases where linear attention diverges most from full attention.
Exercise 6: For a transformer with 24 layers processing 8192-token sequences, calculate the total memory required for attention matrices using: (1) full attention, (2) local attention with window 512, (3) sparse attention with $\sqrt{n}$ connections per token, (4) linear attention, and (5) low-rank attention with rank 256. Assume 12 heads, $d_{\text{model}} = 1024$, batch size 8, and FP16 precision.
Exercise 7: Implement relative position bias as in T5. Use buckets: [0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32+]. Show how attention scores change with relative distance and compare to absolute position encodings.
Exercise 8: Analyze the trade-off between window size and accuracy for local attention. Train a small transformer on a language modeling task with window sizes [64, 128, 256, 512, full]. Plot perplexity vs window size and identify the point of diminishing returns. How does this relate to the average dependency length in the dataset?
Exercise 9: Create visualization showing: (1) Self-attention patterns for sentence "The quick brown fox jumps", (2) Effect of causal masking, (3) Difference between heads 1 and 12 in multi-head attention. What patterns emerge?
Exercise 10: Compare computational cost of: (1) Additive (Bahdanau) attention, (2) Multiplicative attention, (3) Scaled dot-product attention. For $n = 512$, $d_k = 64$, which is most efficient? How does the ranking change for $n = 4096$?

Solutions

Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.

Solution: Cross-attention PyTorch implementation:
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # batch_first=True: tensors are (batch, seq_len, d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, decoder_input, encoder_output):
        # Q from decoder, K and V from encoder; keep per-head attention weights
        output, weights = self.mha(decoder_input, encoder_output, encoder_output,
                                   average_attn_weights=False)
        return output, weights

# Test
cross_attn = CrossAttention(d_model=128, num_heads=4)
decoder_in = torch.randn(1, 7, 128)    # decoder length 7
encoder_out = torch.randn(1, 10, 128)  # encoder length 10
output, attn = cross_attn(decoder_in, encoder_out)
print(f"Output shape: {output.shape}")     # torch.Size([1, 7, 128])
print(f"Attention shape: {attn.shape}")    # torch.Size([1, 4, 7, 10])

The attention matrix has shape $7 \times 10$, showing how each of the 7 decoder positions attends to the 10 encoder positions.

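Solution: For BERT-base (12 heads, $d=768$, $d_k = 64$), batch size 1, FP16 (2 bytes per value):

Full attention memory: $12 \times n^2 \times 2$ bytes, about 6.3 MB for $n=512$, 100.7 MB for $n=2048$, and 402.7 MB for $n=4096$.

Local attention (window 256): $12 \times n \times 256 \times 2$ bytes, about 3.1 MB, 12.6 MB, and 25.2 MB, i.e. $2\times$, $8\times$, and $16\times$ less than full attention.

Linear attention: each head stores a $d_k \times d_k$ matrix instead of an $n \times n$ one, so memory is $O(d_k^2)$ instead of $O(n^2)$: about $12 \times 64^2 \times 2 \approx 0.1$ MB regardless of sequence length.

Savings increase dramatically with sequence length, making efficient attention essential for long contexts.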

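Solution: For local attention with window $w=128$ and sequence length $n=1024$:

Computational cost: per head, full attention needs $n^2 d_k = 1024^2 d_k$ multiply-adds for the scores, while local attention needs $n w d_k = 1024 \times 128 \times d_k$, a theoretical $n/w = 8\times$ reduction (the product with $\mV$ scales the same way).

Memory usage: $1024^2 = 1{,}048{,}576$ scores per head for full attention versus $1024 \times 128 = 131{,}072$ for local attention, also $8\times$ less.

Observed GPU speedup: Typically $5$-$6\times$ instead of the theoretical $8\times$, because attention is largely memory-bound rather than compute-bound, and gathering each query's window produces less regular memory access than a single dense matrix multiplication.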

Solution: Sparse attention pattern design:

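For a 4096-token sequence, each token attends to 64 local positions, $4096/128 = 32$ strided positions, and 1 global token, for $64 + 32 + 1 = 97$ connections per token (an upper bound, since positions covered by more than one pattern are counted twice).

Memory requirements (one head, FP16, 2 bytes per score):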

$$ 4096 \times 97 \times 2 \text{ bytes} = 794{,}624 \text{ bytes} \approx 0.76 \text{ MB} $$

Percentage of full attention:

$$ \frac{97}{4096} \approx 2.37\% $$

This sparse pattern uses only 2.37\% of full attention's connections while maintaining both local and long-range dependencies.

Solution: Linear attention with $\phi(\vx) = \text{elu}(\vx) + 1$:

Standard attention:

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}(\mQ\mK\transpose)\mV $$

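Linear attention:

$$ \text{Attention}(\mQ, \mK, \mV) = \frac{\phi(\mQ)\left(\phi(\mK)\transpose \mV\right)}{\phi(\mQ)\left(\phi(\mK)\transpose \mathbf{1}_n\right)} $$

with the division applied row by row.

Approximation error: Linear attention diverges most when softmax attention is sharply peaked. The kernel $\phi(\vq)\transpose\phi(\vk)$ grows only polynomially with the magnitudes of $\vq$ and $\vk$, whereas $\exp(\vq\transpose\vk)$ grows exponentially, so linear attention cannot reproduce very peaked distributions and spreads weight more evenly across keys.

Cases of largest divergence: queries whose dot product with one key is much larger than with all others (near one-hot softmax rows), such as exact-match lookups or copying behavior.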

Solution: For 24 layers, 8192 tokens, 12 heads, $d=1024$, batch size 8, FP16:

(1) Full attention:

$$ 24 \times 8 \times 12 \times 8192^2 \times 2 = 309{,}237{,}645{,}312 \text{ bytes} \approx 288 \text{ GB} $$

(2) Local attention (window 512):

$$ 24 \times 8 \times 12 \times 8192 \times 512 \times 2 = 19{,}327{,}352{,}832 \text{ bytes} \approx 18 \text{ GB} $$

(3) Sparse attention ($\sqrt{n} \approx 90$ connections):

$$ 24 \times 8 \times 12 \times 8192 \times 90 \times 2 = 3{,}397{,}386{,}240 \text{ bytes} \approx 3.2 \text{ GB} $$

(4) Linear attention (each head stores a $d_k \times d_k$ matrix with $d_k = 1024/12$):

$$ 24 \times 8 \times 12 \times (1024/12)^2 \times 2 = 33{,}554{,}432 \text{ bytes} \approx 32 \text{ MB} $$

(5) Low-rank attention (rank 256):

$$ 24 \times 8 \times 12 \times 8192 \times 256 \times 2 = 9{,}663{,}676{,}416 \text{ bytes} \approx 9 \text{ GB} $$

Linear attention provides the best memory efficiency for this configuration, followed by sparse attention; both are orders of magnitude below full attention.

Solution: Exercises 7-10 involve implementation and visualization tasks; key points:

Exercise 7 (Relative position bias): T5 uses bucketed relative positions to limit parameter growth while capturing distance information. The learned biases often favor nearby positions, so scores tend to decrease with distance.

Exercise 8 (Window size trade-off): Perplexity improves rapidly up to window 256-512, then plateaus. Optimal window correlates with average dependency length in data.

Exercise 9 (Attention visualization): Self-attention shows syntactic patterns (subject-verb, determiner-noun). Causal masking creates triangular pattern. Different heads specialize in different linguistic phenomena.

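Exercise 10 (Attention mechanism comparison): Scaled dot-product attention is the most efficient at both $n = 512$ and $n = 4096$. All three compute $O(n^2)$ pairwise scores, but dot-product attention does so with a single batched matrix multiplication, multiplicative attention adds only an $O(n d_k^2)$ projection, and additive attention evaluates $\tanh$ over an $n \times n \times d_a$ intermediate tensor (hidden size $d_a$). The ranking does not change at $n = 4096$, but the gap in time and memory grows with $n^2$.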
