Attention Variants and Mechanisms
Chapter Overview
Beyond standard scaled dot-product attention, numerous variants have been developed for specific use cases and improved efficiency. This chapter explores cross-attention for encoder-decoder models, soft vs hard attention, attention with relative position representations, and practical considerations for implementing attention mechanisms.
Learning Objectives
- Distinguish between self-attention and cross-attention
- Understand relative position representations
- Implement attention with different scoring functions
- Apply attention masking for various scenarios
- Understand attention dropout and layer normalization
- Visualize and interpret attention patterns
Cross-Attention
In an encoder-decoder model, cross-attention lets each decoder position attend to the encoder's output: queries come from the decoder, while keys and values come from the encoder.
Dimensions:
- Decoder input: $\mX_{\text{dec}} \in \R^{m \times d}$ ($m$ decoder positions)
- Encoder output: $\mX_{\text{enc}} \in \R^{n \times d}$ ($n$ encoder positions)
- Attention matrix: $\mA \in \R^{m \times n}$ (decoder $\times$ encoder)
- Output: $\R^{m \times d_v}$ (same decoder length)
For example, when translating the English source "The cat sat on the mat" into French, suppose the decoder has generated two target tokens so far: "Le chat" ($\mX_{\text{dec}} \in \R^{2 \times 512}$).
Cross-attention computes:
$$ \mA = \mathrm{softmax}\!\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right), \qquad \mQ = \mX_{\text{dec}} \mW^Q, \quad \mK = \mX_{\text{enc}} \mW^K $$
where $\alpha_{1,j}$ is the attention from decoder position 1 ("Le") to encoder position $j$.
When generating "Le" (the), the model should attend strongly to "The" in the source.
When generating "chat" (cat), the model should attend strongly to "cat" in the source.
Transformer Decoder Attention Layers
A transformer decoder block contains three attention mechanisms:
- Masked self-attention: Decoder attends to previous decoder positions
$$ \mQ = \mK = \mV = \mX_{\text{dec}} \quad \text{(with causal mask)} $$
- Cross-attention: Decoder attends to encoder output
$$ \mQ = \mX_{\text{dec}}, \quad \mK = \mV = \mX_{\text{enc}} $$
- Feed-forward: Position-wise MLP (not attention)
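As a sketch, the three sublayers can be wired together using PyTorch's built-in `nn.MultiheadAttention` (a minimal post-norm block; the class name and dimensions are illustrative, not the chapter's reference implementation):

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Minimal sketch of one transformer decoder block (post-norm)."""
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x_dec, x_enc):
        m = x_dec.size(1)
        # causal mask: True entries are *blocked* (future positions)
        causal = torch.triu(torch.ones(m, m, dtype=torch.bool), diagonal=1)
        # 1) masked self-attention: Q = K = V = decoder states
        h, _ = self.self_attn(x_dec, x_dec, x_dec, attn_mask=causal)
        x = self.norm1(x_dec + h)
        # 2) cross-attention: Q = decoder states, K = V = encoder output
        h, _ = self.cross_attn(x, x_enc, x_enc)
        x = self.norm2(x + h)
        # 3) position-wise feed-forward
        return self.norm3(x + self.ffn(x))
```

Note that only the first sublayer receives the causal mask; cross-attention may look at the entire encoder output.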
Relative Position Representations
Problem with absolute positions: the model only sees positions 0-511 during training. How should it handle position 600 at inference?
Solution: relative position representations, which encode the distance between positions rather than absolute positions.
Shaw et al. Relative Attention
Shaw et al. add learned embeddings for (clipped) relative distances to the keys, so the attention logits become
$$ e_{ij} = \frac{\vx_i \mW^Q \left(\vx_j \mW^K + \va^K_{ij}\right)\transpose}{\sqrt{d_k}} $$
where $\va^K_{ij}$ is a learned embedding indexed by the clipped offset $j - i$.
Advantages:
- Generalize to longer sequences
- Model learns distance-based patterns
- More parameter efficient
T5 Relative Position Bias
T5 uses an even simpler approach: add a learned scalar bias based on the (bucketed) relative position directly to the attention logits,
$$ e_{ij} = \frac{(\mQ \mK\transpose)_{ij}}{\sqrt{d_k}} + b_{r(i,j)} $$
where $r(i,j)$ maps the offset $j - i$ to one of a fixed number of buckets and each head learns its own bias table.
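A minimal sketch of a learned relative position bias, assuming simple clipping of the offset rather than T5's log-spaced buckets (the class name and defaults are illustrative):

```python
import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    """One learned scalar per (head, clipped relative distance).
    T5's real scheme additionally uses log-spaced distance buckets."""
    def __init__(self, num_heads, max_distance=128):
        super().__init__()
        self.max_distance = max_distance
        # table of biases, indexed by clipped relative offset
        self.bias = nn.Embedding(2 * max_distance + 1, num_heads)

    def forward(self, m, n):
        # offset j - i for every (query, key) pair
        rel = torch.arange(n)[None, :] - torch.arange(m)[:, None]      # (m, n)
        rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
        return self.bias(rel).permute(2, 0, 1)                         # (heads, m, n)

# the bias is added to the attention logits before the softmax:
# scores = q @ k.transpose(-2, -1) / d_k ** 0.5 + rel_bias(m, n)
```

Because the table is indexed by offset rather than absolute position, the same parameters apply at any sequence length.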
Alternative Attention Scoring Functions
Beyond the scaled dot-product used in transformers, several alternative scoring functions exist---additive (Bahdanau), multiplicative (Luong), and general bilinear forms---each with different trade-offs between expressiveness and computational efficiency. These are defined and compared in Chapter~[ref] (Section~7.3). In practice, scaled dot-product attention dominates in transformer architectures due to its hardware-efficient batched matrix multiplication and strong empirical performance.
Attention Masking
[Figure: left, bidirectional (encoder) self-attention, where each of $x_1, x_2, x_3$ attends to every position including itself, giving a fully filled $3 \times 3$ attention matrix; right, causal (decoder) self-attention, where position $i$ attends only to positions $j \leq i$, giving a lower-triangular matrix whose blocked entries are set to $-\infty$ before the softmax.]
Padding Mask
For variable-length sequences in a batch, mask the padding tokens. For example, if the batch is padded to length 5 and sequence 1 contains 3 real tokens:

Mask for Seq 1: $[1, 1, 1, 0, 0]$ (1 = real token, 0 = padding)

Setting the masked key positions to $-\infty$ before the softmax prevents any query from attending to padding tokens.
Combined Masks
For the decoder, combine the causal mask and the padding mask element-wise, keeping the most restrictive value: if either mask blocks a position, the combined mask blocks it (a logical AND of the "allowed" entries).
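The combination can be sketched as follows (the helper name and the True-means-allowed convention are assumptions for illustration):

```python
import torch

def combined_mask(lengths, max_len):
    """Combine a causal mask with a padding mask.
    lengths: number of real tokens in each batch element.
    Returns a (batch, max_len, max_len) boolean mask, True = may attend."""
    # causal: query i may attend to key j only if j <= i
    causal = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))       # (L, L)
    # padding: key j is real only if j < length of that sequence
    padding = torch.arange(max_len)[None, :] < torch.tensor(lengths)[:, None] # (B, L)
    # most restrictive wins: allowed only if BOTH masks allow
    return causal[None, :, :] & padding[:, None, :]                           # (B, L, L)

mask = combined_mask([3, 5], max_len=5)
# blocked positions get -inf before the softmax:
# scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))
```

Broadcasting does the combination: the causal mask is shared across the batch, the padding mask across query positions.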
Attention Dropout
Apply dropout to the attention weights for regularization:
$$ \text{output} = \mathrm{Dropout}\!\left(\mathrm{softmax}\!\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right)\right) \mV $$
Typical dropout rate: 0.1 (10\%)
Effect: randomly zeroes out some attention connections during training, preventing over-reliance on specific positions.
Layer Normalization with Attention
Two architectures for combining attention with layer norm:
Post-Norm (Original Transformer)
$$ \vx' = \mathrm{LayerNorm}\big(\vx + \mathrm{Sublayer}(\vx)\big) $$
Pre-Norm (More Common Now)
$$ \vx' = \vx + \mathrm{Sublayer}\big(\mathrm{LayerNorm}(\vx)\big) $$
where Sublayer is either attention or the feed-forward network.
Pre-norm advantages:
- More stable training
- Easier gradient flow
- Used in GPT-2, GPT-3, modern transformers
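The two orderings can be sketched as thin wrappers around an arbitrary sublayer (class names are illustrative):

```python
import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    """Original transformer: sublayer, then add & normalize: LN(x + f(x))."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreNormBlock(nn.Module):
    """GPT-2 style: normalize first, then sublayer, then add: x + f(LN(x))."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))
```

In the pre-norm block the residual path is a pure identity (no normalization on it), which is one intuition for its easier gradient flow.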
Visualizing Attention
Attention weights $\mA \in \R^{n \times n}$ reveal what the model attends to:
Attention Heatmaps
For sentence "The cat sat on the mat":
- Row $i$: attention distribution when processing token $i$
- Bright cell $(i,j)$: token $i$ strongly attends to token $j$
Patterns observed:
- Diagonal: Attending to self
- Vertical lines: Attending to specific important words (e.g., subject, verb)
- Symmetric patterns: Mutual attention between related words
- Head-specific patterns: Different heads learn different relationships
Interpreting Multiple Heads
In 12-head attention, different heads specialize:
- Some heads attend to adjacent words (local syntax)
- Some heads attend to distant words (long-range dependencies)
- Some heads attend to specific parts of speech
- Some heads attend based on semantic similarity
Practical Implementation Considerations
Memory-Efficient Attention
For very long sequences, store attention matrix in chunks:
- Compute $\mQ \mK\transpose$ for chunk of queries
- Apply softmax
- Multiply by $\mV$ chunk
- Accumulate results
Reduces peak memory from $O(n^2)$ to $O(nc)$ where $c$ is chunk size.
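The chunking procedure above can be sketched as follows; because softmax rows are independent, chunking over queries is exact, not an approximation (the function name is illustrative):

```python
import math
import torch

def chunked_attention(q, k, v, chunk_size=128):
    """Exact attention computed one block of queries at a time.
    Peak memory for the score matrix is O(c * n) instead of O(n^2)."""
    d_k = q.size(-1)
    outputs = []
    for start in range(0, q.size(0), chunk_size):
        q_chunk = q[start:start + chunk_size]                     # (c, d_k)
        scores = q_chunk @ k.transpose(-2, -1) / math.sqrt(d_k)   # (c, n)
        outputs.append(torch.softmax(scores, dim=-1) @ v)         # (c, d_v)
    return torch.cat(outputs, dim=0)                              # (n, d_v)
```

Chunking over keys as well (as FlashAttention does) additionally requires tracking running softmax statistics, which this sketch omits.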
Fused Attention Kernels
Modern implementations fuse the score computation, softmax, and value multiplication into a single kernel:
$$ \text{output} = \mathrm{softmax}\!\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right) \mV \quad \text{(computed in one kernel)} $$
A single fused kernel is faster than the separate operations because intermediate results stay in fast on-chip memory (fewer memory transfers).
Example: FlashAttention achieves 2-4x speedup through fused operations and memory hierarchy optimization.
Efficient Attention Variants
The standard self-attention mechanism has computational complexity $O(n^2d)$ and memory complexity $O(n^2)$, where $n$ is the sequence length and $d$ is the model dimension. This quadratic scaling in sequence length becomes prohibitive for long sequences. For a sequence of length 4096 with 12 attention heads, the attention matrices alone require $12 \times 4096^2 \times 4 = 805$ MB in FP32 format per example. With batch size 32, this amounts to 25.8 GB just for attention weights, exceeding the memory capacity of most GPUs. This fundamental limitation has motivated extensive research into efficient attention variants that reduce the quadratic complexity while maintaining model quality.
The key insight underlying efficient attention is that not all token pairs require equal attention. In practice, attention patterns often exhibit structure: tokens primarily attend to nearby tokens, specific global tokens, or sparse subsets of the sequence. By exploiting this structure, efficient attention mechanisms can dramatically reduce computational and memory requirements while preserving most of the modeling capacity of full attention. The following sections examine the major classes of efficient attention variants, analyzing their complexity trade-offs, implementation considerations, and practical use cases.
Local Attention
Local attention restricts each token to attend only to tokens within a fixed window around its position, rather than attending to all tokens in the sequence. For a window size $w$, token at position $i$ attends only to positions $[i-w/2, i+w/2]$. This reduces the attention matrix from $n \times n$ to $n \times w$, yielding linear scaling in sequence length.
The computational complexity of local attention is $O(nwd)$, where $n$ is sequence length, $w$ is window size, and $d$ is model dimension. Compared to standard attention's $O(n^2d)$, this represents a reduction factor of $n/w$. For a sequence of length 4096 with window size 256, local attention is 16 times faster than full attention. The memory complexity similarly reduces from $O(n^2)$ to $O(nw)$, enabling much longer sequences to fit in GPU memory. For the same 4096-token sequence with 12 heads, local attention with window 256 requires only $12 \times 4096 \times 256 \times 4 = 50.3$ MB per example, a 16-fold reduction from the 805 MB required by full attention.
The primary trade-off of local attention is the loss of long-range dependencies. Tokens separated by more than $w/2$ positions cannot directly attend to each other, requiring information to propagate through multiple layers. In practice, this limitation is often acceptable. Many natural language tasks exhibit strong locality: syntactic dependencies are typically short-range, and semantic relationships can be captured through multiple layers of local attention. Empirical studies show that local attention with window size 256-512 typically achieves 98-99\% of full attention's accuracy on language modeling tasks, while enabling sequences 10-20 times longer.
The Longformer architecture demonstrates effective use of local attention for document-level understanding. Longformer combines local windowed attention for most tokens with global attention for special tokens like [CLS] and task-specific tokens. This hybrid approach maintains $O(n)$ complexity while allowing critical tokens to aggregate information from the entire sequence. On document classification tasks with 4096-token inputs, Longformer achieves comparable accuracy to BERT while processing sequences 8 times longer. The local attention pattern also enables efficient implementation on GPUs through blocked matrix operations, achieving 2-3x speedup over naive implementations.
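A minimal mask-based sketch of local attention (a real implementation such as Longformer's uses blocked matrix operations so the full $n \times n$ score matrix is never formed; this version trades that efficiency for clarity):

```python
import math
import torch

def local_attention(q, k, v, window=4):
    """Windowed self-attention via masking: position i attends only
    to positions within window // 2 on either side."""
    n, d_k = q.shape
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)          # (n, n)
    idx = torch.arange(n)
    # block everything outside the band |i - j| <= window // 2
    outside = (idx[None, :] - idx[:, None]).abs() > window // 2
    scores = scores.masked_fill(outside, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v                   # (n, d_v)
```

Every row keeps at least its diagonal entry, so no softmax row is entirely $-\infty$.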
Sparse Attention
Sparse attention generalizes local attention by allowing each token to attend to a sparse subset of positions according to a predefined pattern, rather than a contiguous window. The key insight is that attention patterns in trained transformers often exhibit structure: certain positions are consistently important while others receive minimal attention. By designing sparsity patterns that capture this structure, sparse attention can dramatically reduce computation while maintaining model quality.
Several sparsity patterns have proven effective in practice. Strided attention divides the sequence into blocks and allows each token to attend within its block and to every $k$-th token globally, where $k$ is the stride. This pattern captures both local context and evenly-spaced global context. Fixed attention combines local attention with attention to a fixed set of global tokens, similar to Longformer. Learned sparse attention uses a separate network to predict which positions each token should attend to, adapting the sparsity pattern to the input. The Sparse Transformer architecture uses a factorized attention pattern where each token attends to positions in a strided pattern in one head and a local pattern in another head, allowing information to flow efficiently across the sequence.
The computational complexity of sparse attention is $O(n \sqrt{n} d)$ for typical sparsity patterns, where each token attends to approximately $\sqrt{n}$ other tokens. This represents a substantial improvement over full attention's $O(n^2 d)$, particularly for long sequences. For a sequence of length 4096, sparse attention with $\sqrt{n} = 64$ positions per token is 64 times faster than full attention. The memory complexity is similarly $O(n \sqrt{n})$, enabling sequences that would be impossible with full attention. For 4096 tokens with 12 heads, sparse attention requires approximately $12 \times 4096 \times 64 \times 4 = 12.6$ MB per example, a 64-fold reduction from full attention's 805 MB.
The accuracy trade-off of sparse attention depends critically on the choice of sparsity pattern. Well-designed patterns that align with the task's dependency structure can achieve 97-99\% of full attention's accuracy. The Sparse Transformer achieves perplexity within 0.1 of full attention on language modeling while using only $\sqrt{n}$ attention per token. BigBird, which combines local, global, and random attention patterns, matches BERT's accuracy on question answering and document classification while processing sequences up to 8 times longer. However, poorly chosen sparsity patterns can significantly degrade accuracy, particularly on tasks requiring long-range reasoning.
Implementation of sparse attention on GPUs presents challenges because modern GPUs are optimized for dense matrix operations. Sparse matrix multiplication is less efficient than dense multiplication due to irregular memory access patterns and reduced arithmetic intensity. Specialized kernels and libraries like cuSPARSE can partially mitigate this, but sparse attention typically achieves only 50-70\% of the theoretical speedup in practice. Recent work on block-sparse attention, which operates on blocks of the attention matrix rather than individual elements, achieves better GPU utilization by maintaining some regularity in memory access patterns. The Triton framework enables efficient implementation of custom sparse attention patterns through automatic optimization of memory access.
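The sparsity patterns described above can be combined into a single boolean attention mask, sketched here with illustrative sizes (the window, stride, and global-token count are not taken from any particular paper):

```python
import torch

def sparse_mask(n, window=8, stride=16, num_global=1):
    """Boolean attention mask (True = allowed) combining three patterns:
    a local window, strided global positions, and a few global tokens."""
    idx = torch.arange(n)
    # local: |i - j| within the window
    local = (idx[None, :] - idx[:, None]).abs() <= window // 2
    # strided: every stride-th position is visible to everyone
    strided = (idx[None, :] % stride == 0).expand(n, n)
    # global tokens attend everywhere and are attended to by everyone
    global_tok = (idx[None, :] < num_global) | (idx[:, None] < num_global)
    return local | strided | global_tok

mask = sparse_mask(64)
density = mask.float().mean().item()   # fraction of allowed connections
```

Applying the mask is identical to any other attention mask: blocked scores are set to $-\infty$ before the softmax.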
Linear Attention
Linear attention achieves $O(nd^2)$ complexity by reformulating the attention computation to avoid explicitly constructing the $n \times n$ attention matrix. The key insight is that attention can be viewed as a kernel operation, and by choosing an appropriate kernel function, the computation can be reordered to compute the output directly without materializing the full attention matrix.
The standard attention computation is:
$$ \mathrm{Attention}(\mQ, \mK, \mV) = \mathrm{softmax}\!\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right) \mV $$
This requires computing $\mQ \mK\transpose \in \R^{n \times n}$ before applying softmax and multiplying by $\mV$. Linear attention approximates the softmax kernel with a feature map $\phi: \R^{d_k} \to \R^{d'}$ such that:
$$ \exp\!\left(\frac{\vq \vk\transpose}{\sqrt{d_k}}\right) \approx \phi(\vq)\, \phi(\vk)\transpose $$
With this approximation, attention becomes:
$$ \left(\phi(\mQ)\, \phi(\mK)\transpose\right) \mV = \phi(\mQ) \left(\phi(\mK)\transpose \mV\right) $$
(up to the softmax normalization, which is computed by an analogous reordered sum).
The crucial observation is that the parentheses can be reordered. Instead of computing $\phi(\mQ) \phi(\mK)\transpose$ (which is $n \times n$) and then multiplying by $\mV$, we first compute $\phi(\mK)\transpose \mV \in \R^{d' \times d_v}$ and then multiply by $\phi(\mQ)$. This reordering changes complexity from $O(n^2 d)$ to $O(n d'^2)$, where $d'$ is the feature dimension (typically equal to $d_k$).
The computational savings of linear attention are substantial for long sequences. For sequence length 4096 and model dimension 768, standard attention requires approximately $4096^2 \times 768 = 12.9$ billion operations per head, while linear attention requires $4096 \times 768^2 = 2.4$ billion operations, a 5.4x reduction. The memory complexity is even more favorable: linear attention requires only $O(nd)$ memory for the intermediate $\phi(\mK)\transpose \mV$ matrix, compared to $O(n^2)$ for the full attention matrix. For 4096 tokens with 12 heads, linear attention requires approximately $12 \times 768 \times 768 \times 4 = 28.3$ MB, compared to 805 MB for full attention, a 28-fold reduction.
The primary challenge of linear attention is choosing a feature map $\phi$ that accurately approximates the softmax kernel while remaining computationally efficient. The Performer architecture uses random Fourier features with $\phi(\vx) = \exp(\|\vx\|^2/2)\, [\cos(\omega_1\transpose \vx), \sin(\omega_1\transpose \vx), \ldots]$ where the $\omega_i$ are random projection vectors. This provides an unbiased approximation of the softmax kernel with controllable accuracy based on the number of random features. The Linear Transformer uses a simpler feature map $\phi(\vx) = \mathrm{elu}(\vx) + 1$, which is faster to compute but provides a looser approximation.
The accuracy trade-off of linear attention is more significant than local or sparse attention. Empirical studies show that linear attention typically achieves 95-98\% of full attention's accuracy on language modeling, with larger degradation on tasks requiring precise attention patterns. The approximation error is particularly noticeable for small attention weights: the softmax function's sharp peaking is difficult to approximate with simple feature maps. However, for applications where extreme sequence length is critical, such as processing entire books or long-form video, the 2-5\% accuracy loss is often acceptable given the dramatic computational savings. Recent work on learned feature maps and adaptive kernel approximations aims to close this accuracy gap while maintaining linear complexity.
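A sketch of linear attention with the Linear Transformer's $\mathrm{elu}(\vx) + 1$ feature map, including the reordered softmax normalization (single head, no batch dimension, for clarity):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    """Linear-Transformer-style attention with phi(x) = elu(x) + 1.
    Computes phi(K)^T V (a d' x d_v matrix) first, so the n x n
    attention matrix is never materialized."""
    phi_q = F.elu(q) + 1                       # (n, d'), strictly positive
    phi_k = F.elu(k) + 1                       # (n, d')
    kv = phi_k.transpose(-2, -1) @ v           # (d', d_v)
    # reordered normalizer: phi(Q) @ (sum over rows of phi(K))
    normalizer = phi_q @ phi_k.sum(dim=0, keepdim=True).transpose(-2, -1)  # (n, 1)
    return (phi_q @ kv) / normalizer           # (n, d_v)
```

Because the features are positive, each output row is a convex combination of the value rows, mirroring the softmax case.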
Low-Rank Attention
Low-rank attention exploits the observation that attention matrices in trained transformers often have low effective rank: most of the variance is captured by a small number of singular values. By explicitly factorizing the attention computation through a low-dimensional bottleneck, low-rank attention reduces complexity from $O(n^2 d)$ to $O(nrd)$, where $r$ is the rank and typically $r \ll n$.
The Linformer architecture implements low-rank attention by projecting the keys and values to a lower-dimensional space before computing attention. Specifically, Linformer adds projection matrices $\mE, \mF \in \R^{r \times n}$ that reduce the sequence length dimension:
$$ \mathrm{Attention}(\mQ, \mK, \mV) = \mathrm{softmax}\!\left(\frac{\mQ\, (\mE \mK)\transpose}{\sqrt{d_k}}\right) (\mF \mV) $$
The key insight is that $\mE \mK \in \R^{r \times d_k}$ and $\mF \mV \in \R^{r \times d_v}$ have reduced sequence length $r$ instead of $n$. The attention matrix is now $n \times r$ instead of $n \times n$, reducing both computation and memory by a factor of $n/r$.
For sequence length 4096 and rank 256, low-rank attention reduces computation from $4096^2 \times 768 = 12.9$ billion operations to $4096 \times 256 \times 768 = 805$ million operations per headâa 16-fold reduction. The memory savings are equally dramatic: the attention matrix requires $4096 \times 256 \times 4 = 4.2$ MB per head instead of $4096^2 \times 4 = 67.1$ MB, a 16-fold reduction. With 12 heads, total attention memory drops from 805 MB to 50.3 MB per example.
The accuracy of low-rank attention depends on the choice of rank $r$ and the projection matrices $\mE$ and $\mF$. Linformer uses learned projection matrices that are shared across all layers, reducing the parameter overhead. Empirical studies show that rank $r = 256$ achieves 96-98\% of full attention's accuracy for sequences up to 4096 tokens, with minimal degradation on most language understanding tasks. The accuracy loss is more pronounced for tasks requiring fine-grained attention patterns, such as coreference resolution or syntactic parsing, where the low-rank approximation may miss subtle dependencies.
An important consideration for low-rank attention is that the projection matrices $\mE$ and $\mF$ introduce additional parameters and computation. For rank $r$ and sequence length $n$, the projections add $2rn$ parameters per layer. However, these projections can be implemented efficiently as 1D convolutions or learned position-wise projections, and the parameter cost is typically small compared to the savings in attention computation. The projection operations themselves require $O(rnd)$ computation, which is negligible compared to the $O(n^2d)$ cost of full attention for $r \ll n$.
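A Linformer-style sketch with learned length-dimension projections (single head, no batch dimension, for clarity; the class name is illustrative):

```python
import math
import torch
import torch.nn as nn

class LowRankAttention(nn.Module):
    """Project keys and values from length n down to rank r before
    attention, so the score matrix is n x r instead of n x n."""
    def __init__(self, n, r):
        super().__init__()
        # E and F act on the *length* dimension, not the feature dimension
        self.E = nn.Linear(n, r, bias=False)
        self.F = nn.Linear(n, r, bias=False)

    def forward(self, q, k, v):
        d_k = q.size(-1)
        k_proj = self.E(k.transpose(-2, -1)).transpose(-2, -1)   # (r, d_k)
        v_proj = self.F(v.transpose(-2, -1)).transpose(-2, -1)   # (r, d_v)
        scores = q @ k_proj.transpose(-2, -1) / math.sqrt(d_k)   # (n, r)
        return torch.softmax(scores, dim=-1) @ v_proj            # (n, d_v)
```

The transposes route the linear layers over the sequence-length axis; in practice these projections can also be implemented as 1D convolutions.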
Comprehensive Complexity Comparison
Understanding the trade-offs between different attention variants requires examining multiple dimensions: computational complexity, memory requirements, accuracy preservation, and practical implementation efficiency. The following analysis provides concrete comparisons across these dimensions for typical transformer configurations.
| Variant | Time | Memory | Accuracy | Max Length | Use Case |
|---|---|---|---|---|---|
| Full Attention | $O(n^2d)$ | $O(n^2)$ | 100\% | 512-1024 | Standard tasks |
| Local Attention | $O(nwd)$ | $O(nw)$ | 98-99\% | 4096-8192 | Document processing |
| Sparse Attention | $O(n\sqrt{n}d)$ | $O(n\sqrt{n})$ | 97-99\% | 8192-16384 | Long documents |
| Linear Attention | $O(nd^2)$ | $O(nd)$ | 95-98\% | 16384+ | Extreme length |
| Low-Rank Attention | $O(nrd)$ | $O(nr)$ | 96-98\% | 4096-8192 | Compression |
To make these complexity bounds concrete, consider processing sequences of varying lengths with BERT-base configuration ($d = 768$, 12 heads, $d_k = 64$ per head). The following table shows actual memory requirements for attention matrices across different sequence lengths and attention variants.
| Variant | n=512 | n=4096 | n=8192 | n=16384 |
|---|---|---|---|---|
| Full Attention | 12.6 MB | 805 MB | 3.2 GB | 12.9 GB |
| Local Attention ($w=256$) | 6.3 MB | 50.3 MB | 101 MB | 201 MB |
| Sparse Attention ($\sqrt{n}$) | 1.1 MB | 12.6 MB | 35.7 MB | 101 MB |
| Linear Attention | 0.3 MB | 2.3 MB | 4.7 MB | 9.4 MB |
| Low-Rank ($r=256$) | 6.3 MB | 50.3 MB | 101 MB | 201 MB |
The memory savings become dramatic for long sequences. At 16,384 tokens, full attention requires 12.9 GB per example, impossible to fit on most GPUs even with batch size 1. Local attention reduces this to 201 MB, enabling batch size 32 on a 40 GB A100 GPU. Linear attention requires only 9.4 MB, enabling batch sizes of several hundred even for very long sequences.
The computational cost comparison is equally striking. For a sequence of 8192 tokens with $d=768$ and 12 heads, full attention requires approximately 48.3 billion floating-point operations (FLOPs) per layer. Local attention with window 256 reduces this to 3.0 billion FLOPs (16x speedup), sparse attention to 6.0 billion FLOPs (8x speedup), linear attention to 4.5 billion FLOPs (10.7x speedup), and low-rank attention to 3.0 billion FLOPs (16x speedup). On an NVIDIA A100 GPU with 312 TFLOPS of FP16 throughput, full attention takes approximately 0.15 ms per layer, while efficient variants take 10-20 microseconds, enabling much faster inference and training.
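The memory-table entries can be reproduced with a small helper that computes per-example attention-matrix memory under the same assumptions (12 heads, FP32; the function name and parameter defaults are illustrative):

```python
def attention_memory_bytes(n, num_heads=12, bytes_per_el=4, variant="full",
                           window=256, rank=256):
    """Bytes needed for one example's attention matrices, matching the
    per-example FP32 estimates used in the tables above."""
    if variant == "full":
        per_head = n * n                 # full n x n score matrix
    elif variant == "local":
        per_head = n * window            # banded scores only
    elif variant == "sparse":
        per_head = n * int(n ** 0.5)     # ~sqrt(n) connections per token
    elif variant == "low_rank":
        per_head = n * rank              # n x r projected scores
    else:
        raise ValueError(f"unknown variant: {variant}")
    return num_heads * per_head * bytes_per_el

# full attention at n = 4096: 12 * 4096^2 * 4 bytes, roughly 805 MB
full_4096 = attention_memory_bytes(4096)
```
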
The accuracy trade-offs vary by task and sequence length. For sequences up to 2048 tokens, local attention with window 512 typically matches full attention within 0.5\% on language modeling perplexity. Sparse attention with well-designed patterns achieves similar accuracy. Linear attention shows 2-3\% degradation, while low-rank attention with rank 256 shows 1-2\% degradation. For longer sequences exceeding 4096 tokens, the accuracy gaps widen slightly, but efficient variants remain highly competitive. Importantly, the accuracy loss is often task-dependent: some tasks like document classification are more tolerant of approximate attention than tasks like machine translation or question answering that require precise alignment.
Implementation Considerations
Implementing efficient attention variants requires careful consideration of hardware characteristics, numerical stability, and software frameworks. The theoretical complexity improvements do not always translate directly to wall-clock speedups due to GPU architecture constraints and implementation details.
Modern GPUs achieve peak performance on dense matrix multiplications with dimensions that are multiples of 16 or 32 (for tensor cores). Sparse attention patterns that result in irregular memory access or non-aligned dimensions can suffer significant performance degradation. For example, a naive implementation of sparse attention with random sparsity patterns may achieve only 30-40\% of the theoretical speedup due to poor memory coalescing and reduced arithmetic intensity. Block-sparse patterns that operate on 16x16 or 32x32 blocks achieve much better GPU utilization, typically reaching 60-80\% of theoretical speedup.
Memory bandwidth is often the limiting factor for attention computation, particularly for efficient variants. The attention mechanism is memory-bound rather than compute-bound for typical sequence lengths: the GPU spends more time loading data from memory than performing arithmetic operations. This means that reducing the number of operations (FLOPs) does not always proportionally reduce runtime. Efficient implementations must minimize memory transfers through kernel fusion, where multiple operations are combined into a single GPU kernel that keeps intermediate results in fast on-chip memory. FlashAttention demonstrates this principle by fusing the attention computation ($\mQ\mK\transpose$, softmax, multiply by $\mV$) into a single kernel that never materializes the full attention matrix in global memory, achieving 2-4x speedup over standard implementations even for full attention.
Numerical stability is a critical concern for efficient attention variants. The softmax operation in attention is numerically sensitive: subtracting the maximum value before exponentiation is essential to prevent overflow. Linear attention approximations must carefully handle the feature map computation to avoid numerical issues. The Performer's random Fourier features require computing exponentials of potentially large values, necessitating careful scaling and normalization. Low-rank attention must ensure that the projection matrices are well-conditioned to avoid amplifying numerical errors.
Framework support for efficient attention varies significantly. PyTorch and TensorFlow provide optimized implementations of standard attention through torch.nn.MultiheadAttention and tf.keras.layers.MultiHeadAttention, but efficient variants often require custom implementations. The xFormers library provides optimized implementations of several efficient attention variants, including memory-efficient attention and block-sparse attention. The Triton framework enables writing custom GPU kernels in Python that achieve performance comparable to hand-written CUDA, making it easier to implement and experiment with novel attention patterns. For production deployment, specialized libraries like FasterTransformer and TensorRT provide highly optimized implementations of common attention variants with automatic kernel selection based on input dimensions and hardware capabilities.
Exercises
Solutions
Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # MultiHeadAttention is the module implemented earlier in the chapter;
        # nn.MultiheadAttention(d_model, num_heads, batch_first=True) also works.
        self.mha = MultiHeadAttention(d_model, num_heads)

    def forward(self, decoder_input, encoder_output):
        # Q from decoder, K and V from encoder
        return self.mha(decoder_input, encoder_output, encoder_output)

# Test
cross_attn = CrossAttention(d_model=128, num_heads=4)
decoder_in = torch.randn(1, 7, 128)    # decoder length 7
encoder_out = torch.randn(1, 10, 128)  # encoder length 10
output = cross_attn(decoder_in, encoder_out)
print(f"Output shape: {output.shape}")  # (1, 7, 128)
# Attention matrix shape internally: (1, 4, 7, 10)
The attention matrix has shape $7 \times 10$, showing how each of the 7 decoder positions attends to the 10 encoder positions.
Full attention memory:
- $n=512$: $12 \times 512^2 \times 2 = 6{,}291{,}456$ bytes $\approx 6$ MB
- $n=2048$: $12 \times 2048^2 \times 2 = 100{,}663{,}296$ bytes $\approx 96$ MB
- $n=4096$: $12 \times 4096^2 \times 2 = 402{,}653{,}184$ bytes $\approx 384$ MB
Local attention (window 256):
- $n=512$: $12 \times 512 \times 256 \times 2 = 3{,}145{,}728$ bytes $\approx 3$ MB (50\% savings)
- $n=2048$: $12 \times 2048 \times 256 \times 2 = 12{,}582{,}912$ bytes $\approx 12$ MB (87.5\% savings)
- $n=4096$: $12 \times 4096 \times 256 \times 2 = 25{,}165{,}824$ bytes $\approx 24$ MB (93.75\% savings)
Linear attention: Memory: $O(d^2)$ instead of $O(n^2)$, approximately $12 \times 768^2 \times 2 \approx 14$ MB regardless of sequence length.
Savings increase dramatically with sequence length, making efficient attention essential for long contexts.
Computational cost:
- Full attention: $2n^2d_k = 2 \times 1024^2 \times 64 = 134{,}217{,}728$ FLOPs
- Local attention: $2nwd_k = 2 \times 1024 \times 128 \times 64 = 16{,}777{,}216$ FLOPs
- Theoretical speedup: $\frac{n}{w} = \frac{1024}{128} = 8\times$
Memory usage:
- Full: $n^2 = 1{,}048{,}576$ elements
- Local: $n \times w = 131{,}072$ elements
- Memory reduction: $8\times$
Observed GPU speedup: Typically $5$-$6\times$ instead of theoretical $8\times$ due to:
- Kernel launch overhead
- Less efficient memory access patterns
- Reduced parallelism for smaller operations
For 4096-token sequence:
- Local attention (window 64): $64$ connections per token
- Strided attention (stride 128): $\frac{4096}{128} = 32$ connections per token
- Global attention to first token: $1$ connection per token
- Total: $64 + 32 + 1 = 97$ connections per token
Memory requirements (12 heads, FP16):
- Sparse: $12 \times 4096 \times 97 \times 2 = 9{,}535{,}488$ bytes $\approx 9$ MB
- Full: $12 \times 4096^2 \times 2 = 402{,}653{,}184$ bytes $\approx 384$ MB

Percentage of full attention:
$$ \frac{97}{4096} \approx 2.37\% $$
This sparse pattern uses only 2.37\% of full attention's connections while maintaining both local and long-range dependencies.
Standard attention:
$$ \mathrm{softmax}\!\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right) \mV \qquad O(n^2 d) \text{ time} $$
Linear attention:
$$ \phi(\mQ) \left(\phi(\mK)\transpose \mV\right) \qquad O(n d'^2) \text{ time (normalization omitted)} $$
Approximation error: Linear attention diverges most when:
- Attention should be highly peaked (one dominant position)
- Softmax creates sharp distinctions that linear kernel cannot capture
- Typical error: 5-15\% in attention weight distribution
Cases of largest divergence:
- Copying tasks requiring precise attention to single token
- Syntactic dependencies with clear head-dependent relationships
- Tasks requiring hard attention decisions
(1) Full attention:
(2) Local attention (window 512):
(3) Sparse attention ($\sqrt{n} = 90$ connections):
(4) Linear attention:
(5) Low-rank attention (rank 256):
Sparse attention provides the best memory efficiency for this configuration.
Exercise 7 (Relative position bias): T5 uses bucketed relative positions to limit parameter growth while capturing distance information. Attention scores decay with distance.
Exercise 8 (Window size trade-off): Perplexity improves rapidly up to window 256-512, then plateaus. Optimal window correlates with average dependency length in data.
Exercise 9 (Attention visualization): Self-attention shows syntactic patterns (subject-verb, determiner-noun). Causal masking creates triangular pattern. Different heads specialize in different linguistic phenomena.
Exercise 10 (Attention mechanism comparison): Scaled dot-product is most efficient for all sequence lengths due to optimized matrix multiplication. Additive attention has higher constant overhead.