Efficient Transformers

Chapter Overview

Standard transformers have $O(n^2)$ complexity in sequence length, limiting their application to long sequences. This chapter covers efficient attention mechanisms that reduce complexity: sparse attention, linear attention, low-rank methods, and kernel-based approaches.

Learning Objectives

  1. Understand the quadratic bottleneck in standard attention
  2. Implement sparse attention patterns (sliding window, strided, global)
  3. Apply Linformer and Performer for linear complexity
  4. Use Flash Attention for memory-efficient computation
  5. Compare trade-offs: accuracy vs efficiency vs memory
  6. Deploy long-context models (Longformer, BigBird)

The Quadratic Bottleneck

Complexity Analysis

The standard self-attention mechanism computes attention scores between all pairs of tokens in a sequence, leading to computational and memory requirements that scale quadratically with sequence length. The attention operation is defined as:

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}\left(\frac{\mQ \mK\transpose}{\sqrt{d_k}}\right) \mV $$
where $\mQ, \mK, \mV \in \R^{n \times d}$ represent the query, key, and value matrices for a sequence of length $n$ with model dimension $d$.

The computational bottleneck arises from computing the attention matrix $\mQ \mK\transpose \in \R^{n \times n}$, which requires $O(n^2 d)$ floating-point operations. For each of the $n$ queries, we compute dot products with all $n$ keys, where each dot product involves $d$ multiplications and additions. The subsequent softmax normalization adds $O(n^2)$ operations, and the final multiplication with values $\mV$ requires another $O(n^2 d)$ operations. The memory bottleneck is equally severe: storing the attention matrix requires $O(n^2)$ memory, which must be materialized before the softmax operation and retained for the backward pass during training.
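A minimal NumPy sketch of this computation makes the bottleneck visible: the full score matrix is materialized before the softmax and kept alongside the output. The function name naive_attention and the sizes are illustrative assumptions, and NumPy's default float64 is used rather than FP32, so the byte counts are twice the FP32 figures used in this chapter.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard single-head attention that materializes the full n x n matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # (n, n): O(n^2 d_k) FLOPs, O(n^2) memory
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax: O(n^2) operations
    return weights @ V, weights                     # (n, d_v): another O(n^2 d_v) FLOPs

rng = np.random.default_rng(0)
n, d = 1024, 64
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out, A = naive_attention(Q, K, V)
print(out.shape, A.shape, A.nbytes / 1e6, "MB")     # float64 here; FP32 would halve this
```

The weights are returned only so the n × n array can be inspected; during training this is exactly the tensor that must be retained for the backward pass.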

This quadratic scaling becomes prohibitive for long sequences. Consider a BERT-base model with $d = 768$ and 12 attention heads processing a sequence of length $n = 4096$. Each attention head must store an attention matrix of size $4096 \times 4096$, requiring $4096^2 \times 4 = 67$ MB in FP32 format. Across all 12 heads, this amounts to 804 MB just for attention weights in a single layer. With 12 layers, the total memory for attention matrices alone reaches 9.6 GB, nearly filling an NVIDIA V100 GPU with 16 GB memory before accounting for activations, gradients, or model parameters.

Example: The quadratic scaling of attention memory becomes dramatically worse as sequence length increases. For a single attention head in FP32 format, the memory requirements grow as follows (the size of the $n \times n$ matrix does not depend on the model dimension $d$):

For $n = 512$ tokens (BERT's original limit), the attention matrix requires $512^2 \times 4 = 1.05$ MB per head. This is manageable even with multiple layers and batch processing. However, increasing to $n = 2048$ tokens requires $2048^2 \times 4 = 16.8$ MB per head—a 16× increase for only a 4× increase in sequence length. At $n = 4096$ tokens, memory consumption reaches 67 MB per head, and at $n = 8192$ tokens, it explodes to 268 MB per head—a 256× increase compared to the 512-token baseline.

With 12 attention heads and 12 layers, processing a single sequence of 8192 tokens requires $268 \times 12 \times 12 = 38.6$ GB just for attention matrices, nearly exhausting the 40 GB of a high-end A100 GPU before activations, gradients, or model parameters are counted. This limitation is a major reason why BERT restricts sequences to 512 tokens and GPT-2 to 1024 tokens, and why efficient attention mechanisms are essential for processing long documents, genomic sequences, or high-resolution images.
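The memory figures above follow from a one-line formula. The helper below is a hypothetical sketch that counts only the stored attention matrices (FP32 by default) and ignores activations, gradients, and parameters.

```python
def attention_matrix_bytes(n, heads=1, layers=1, batch=1, bytes_per_elem=4):
    """Bytes needed to store the n x n attention matrices (FP32 by default)."""
    return n * n * heads * layers * batch * bytes_per_elem

for n in (512, 2048, 4096, 8192):
    per_head_mb = attention_matrix_bytes(n) / 1e6
    total_gb = attention_matrix_bytes(n, heads=12, layers=12) / 1e9
    print(f"n={n:5d}: {per_head_mb:7.2f} MB per head, "
          f"{total_gb:6.2f} GB for 12 heads x 12 layers")
```

Doubling n multiplies every entry by four, which is the quadratic growth described in the example.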

The computational cost follows a similar pattern. On an NVIDIA A100 GPU with 312 TFLOPS of FP16 performance, computing attention for $n = 512$ takes approximately 8 milliseconds per layer. For $n = 4096$, this increases to 98 milliseconds—a 12× slowdown for an 8× increase in length. At $n = 16384$, attention computation takes 1.5 seconds per layer, making training completely impractical without efficient attention mechanisms.

Sparse Attention Patterns

Efficiency Taxonomy

Efficient attention mechanisms can be categorized into five main approaches, each targeting different aspects of the quadratic bottleneck. Sparse attention methods reduce the number of attention connections by restricting each query to attend to only a subset of keys, achieving $O(n \times k)$ complexity where $k \ll n$. Linear attention methods use mathematical approximations to avoid computing the full attention matrix, achieving $O(n)$ complexity in sequence length. Low-rank methods project keys and values to lower-dimensional spaces, reducing the effective size of the attention computation. Kernel-based methods reformulate attention using kernel functions and random features to enable linear-time computation. Finally, recurrent methods process sequences in chunks with recurrent connections, trading parallelism for reduced memory.

Each approach involves different trade-offs between computational efficiency, memory usage, approximation quality, and implementation complexity. Sparse methods maintain exact attention within their connectivity pattern but may miss important long-range dependencies. Linear methods achieve impressive speedups but introduce approximation errors that can degrade model quality. Low-rank methods work well when attention patterns have inherent low-rank structure but may fail for complex attention distributions. Understanding these trade-offs is essential for selecting the appropriate efficient attention mechanism for a given application.

Fixed Sparse Patterns

Sparse attention restricts each query to attend to only a subset of keys, dramatically reducing both computation and memory requirements. The fundamental idea is to identify which attention connections are most important and compute only those; all other scores are masked to negative infinity before the softmax, so their attention weights become exactly zero.

Definition: Sparse attention restricts attention to a predefined subset of positions $\mathcal{S}$. For each query position $i$, we compute attention only over keys in the set $\mathcal{S}(i) \subseteq \{1, \ldots, n\}$:
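$$ \mA^{\text{sparse}}_{ij} = \begin{cases} \dfrac{\exp\left(\vq_i\transpose \vk_j / \sqrt{d_k}\right)}{\sum_{l \in \mathcal{S}(i)} \exp\left(\vq_i\transpose \vk_l / \sqrt{d_k}\right)} & \text{if } j \in \mathcal{S}(i) \\ 0 & \text{otherwise} \end{cases}, \qquad \text{Attention}_{\text{sparse}}(\mQ, \mK, \mV) = \mA^{\text{sparse}} \mV $$
where $\vq_i$ and $\vk_j$ are the $i$-th row of $\mQ$ and the $j$-th row of $\mK$, and $|\mathcal{S}(i)| = k \ll n$ for all positions $i$. The softmax is normalized over $\mathcal{S}(i)$ only. The computational complexity reduces from $O(n^2 d)$ to $O(nkd)$, and memory requirements decrease from $O(n^2)$ to $O(nk)$.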
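The definition can be sketched with a boolean mask over (query, key) pairs. This dense-mask version, with the hypothetical name sparse_attention, only illustrates the semantics (masking to negative infinity, then renormalizing over $\mathcal{S}(i)$); it still builds an $n \times n$ array, so it does not realize the $O(nk)$ savings that a blocked or gathered kernel would.

```python
import numpy as np

def sparse_attention(Q, K, V, allowed):
    """Attention restricted to keys j in S(i), given as a boolean mask allowed[i, j]."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(allowed, scores, -np.inf)     # disallowed pairs get -inf
    scores -= scores.max(axis=-1, keepdims=True)    # assumes each row allows >= 1 key
    weights = np.exp(scores)                        # exp(-inf) = 0 exactly
    weights /= weights.sum(axis=-1, keepdims=True)  # renormalize over S(i) only
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d, w = 8, 4, 2
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
idx = np.arange(n)
local = np.abs(idx[:, None] - idx[None, :]) <= w    # S(i) = {j : |i - j| <= w}
out, A = sparse_attention(Q, K, V, local)
```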

The choice of sparsity pattern $\mathcal{S}$ determines which information can flow through the network. Three fundamental patterns have emerged as particularly effective building blocks for sparse attention.

The sliding window or local attention pattern restricts each token to attend only to nearby tokens within a fixed window. Formally, $\mathcal{S}_{\text{local}}(i) = \{j : |i-j| \leq w\}$ where $w$ is the window size. Each token attends to $2w+1$ tokens: itself, $w$ tokens before, and $w$ tokens after. This pattern is motivated by the observation that in many domains, particularly natural language, nearby tokens are more relevant than distant ones. For a window size $w = 256$ and sequence length $n = 4096$, each query attends to only 513 keys instead of 4096, reducing computation by 8× and memory by the same factor. The limitation is that information can only propagate $w$ positions per layer, requiring $\lceil n/w \rceil$ layers for full sequence communication.

The strided or dilated attention pattern samples tokens at regular intervals: $\mathcal{S}_{\text{strided}}(i) = \{j : (i-j) \bmod s = 0\}$ where $s$ is the stride. This pattern allows each token to attend to distant tokens, enabling faster information propagation across the sequence. With stride $s = 64$, a token at position 1024 can attend to positions 0, 64, 128, ..., 1024, ..., 4032, providing long-range connectivity with only $n/s$ connections per query. However, strided attention alone misses local context, so it is typically combined with local attention in alternating layers.

Global attention designates certain tokens as global tokens that attend to all positions and are attended to by all positions. These tokens act as information hubs, aggregating information from the entire sequence and broadcasting it back. In practice, special tokens like [CLS] in BERT or separator tokens are often designated as global. For $g$ global tokens in a sequence of length $n$, each global token requires $O(n)$ computation, and each regular token requires $O(g)$ additional computation to attend to globals, adding $O(ng)$ total cost.
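The three patterns can be expressed as boolean masks that plug directly into a masked attention routine. The function names below are hypothetical; the definitions follow the set notation used above, and positions are indexed from 0.

```python
import numpy as np

def local_mask(n, w):
    """S_local(i) = {j : |i - j| <= w}."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w

def strided_mask(n, s):
    """S_strided(i) = {j : (i - j) mod s == 0}."""
    idx = np.arange(n)
    return (idx[:, None] - idx[None, :]) % s == 0   # NumPy % is non-negative for s > 0

def add_global(mask, global_positions):
    """Global tokens attend to all positions and are attended to by all positions."""
    mask = mask.copy()
    mask[global_positions, :] = True
    mask[:, global_positions] = True
    return mask

n = 4096
print(local_mask(n, 256)[1024].sum())    # 513 keys for an interior query
print(strided_mask(n, 64)[1024].sum())   # 64 keys, i.e. n / s
```

Combining patterns is a logical OR of masks, for example local_mask(n, w) | strided_mask(n, s).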

Example: Longformer combines local and global attention to process documents up to 4096 tokens efficiently. All tokens use local attention with window size $w = 512$, allowing each token to attend to 1024 neighboring tokens (512 on each side). Additionally, task-specific tokens such as [CLS] for classification or question tokens for question answering are designated as global tokens that attend to and are attended by all positions.
\begin{tikzpicture}[ node/.style={circle, draw, minimum size=0.6cm, font=\footnotesize}, global/.style={circle, draw, minimum size=0.6cm, font=\footnotesize, fill=yellow!30}, arrow/.style={->, thick}, local/.style={->, thick, blue}, globalconn/.style={->, thick, red} ]

\node[global] (cls) at (0,0) {CLS}; \node[node] (t1) at (1.5,0) {$t_1$}; \node[node] (t2) at (3,0) {$t_2$}; \node[node] (t3) at (4.5,0) {$t_3$}; \node[node] (t4) at (6,0) {$t_4$}; \node[node] (t5) at (7.5,0) {$t_5$}; \node[node] (t6) at (9,0) {$t_6$};

\draw[local] (t1) -- (t2); \draw[local] (t2) -- (t3); \draw[local] (t3) -- (t4); \draw[local] (t4) -- (t5); \draw[local] (t5) -- (t6);

\draw[local] (t2) to[bend left=20] (t1); \draw[local] (t3) to[bend left=20] (t2); \draw[local] (t4) to[bend left=20] (t3); \draw[local] (t5) to[bend left=20] (t4); \draw[local] (t6) to[bend left=20] (t5);

\draw[globalconn] (cls) to[bend left=15] (t1); \draw[globalconn] (cls) to[bend left=10] (t2); \draw[globalconn] (cls) to[bend left=5] (t3); \draw[globalconn] (cls) to[bend right=5] (t4); \draw[globalconn] (cls) to[bend right=10] (t5); \draw[globalconn] (cls) to[bend right=15] (t6);

\draw[globalconn] (t1) to[bend right=15] (cls); \draw[globalconn] (t2) to[bend right=10] (cls); \draw[globalconn] (t3) to[bend right=5] (cls); \draw[globalconn] (t4) to[bend left=5] (cls); \draw[globalconn] (t5) to[bend left=10] (cls); \draw[globalconn] (t6) to[bend left=15] (cls);

\node[align=left, font=\footnotesize] at (4.5,-2.2) {\textcolor{blue}{Blue}: Local window ($w=512$) \\ \textcolor{red}{Red}: Global attention \\ \textcolor{yellow!80!black}{Yellow}: Global token};

\end{tikzpicture}

Longformer attention pattern combining local sliding window (blue) and global attention (red). Regular tokens attend to neighbors within window $w$, while global tokens (yellow) attend to and are attended by all positions, enabling long-range information flow.

For a sequence of length $n = 4096$ with window $w = 512$ and $g = 2$ global tokens, the total number of attention connections is computed as follows. Each of the $n - g = 4094$ regular tokens attends to $2w = 1024$ local tokens plus $g = 2$ global tokens, contributing $(n-g) \times (2w + g) = 4094 \times 1026 \approx 4.2$ million connections. Each of the $g = 2$ global tokens attends to all $n = 4096$ tokens, contributing $g \times n = 8192$ connections. The total is approximately 4.2 million connections compared to $n^2 = 16.8$ million for full attention—a 4× reduction.
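The connection count can be checked with a short helper. The name longformer_connections is hypothetical, and it follows the same simplification as the text: boundary effects and overlaps between local windows and global tokens are ignored.

```python
def longformer_connections(n, w, g):
    """Attention connections for a local window of w per side plus g global tokens."""
    regular = (n - g) * (2 * w + g)   # each regular token: 2w local keys + g globals
    global_ = g * n                   # each global token: all n keys
    return regular + global_

sparse = longformer_connections(4096, 512, 2)
full = 4096 ** 2
print(sparse, full, round(full / sparse, 1))
```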

The memory savings are equally significant. For a single attention head in FP32, Longformer requires approximately $(4094 \times 1026 + 2 \times 4096) \times 4 = 16.8$ MB compared to 67 MB for full attention. With 12 heads and 12 layers, this reduces total attention memory from 9.6 GB to 2.4 GB, enabling processing of long documents on GPUs with limited memory. On an NVIDIA A100 GPU, Longformer processes 4096-token sequences in approximately 18 milliseconds per layer compared to 98 milliseconds for full attention, a 5.4× speedup.

BigBird: Random + Window + Global

BigBird extends sparse attention by combining three complementary patterns: local windows for nearby context, random connections for long-range dependencies, and global tokens for information aggregation. This combination provides both theoretical guarantees and practical efficiency for processing sequences up to 4096 tokens or longer.

Definition: BigBird attention combines three sparse patterns for each query position $i$:
  1. Random attention: Each query attends to $r$ randomly selected keys, where the random set $\mathcal{R}(i)$ is fixed during initialization and shared across all attention heads.
  2. Window attention: Each query attends to $w$ neighboring keys on each side, forming a local window $\mathcal{W}(i) = \{j : |i-j| \leq w\}$.
  3. Global attention: A set of $g$ designated global tokens attend to all positions and are attended by all positions.
\begin{tikzpicture}[ node/.style={circle, draw, minimum size=0.6cm, font=\footnotesize}, global/.style={circle, draw, minimum size=0.6cm, font=\footnotesize, fill=yellow!30}, arrow/.style={->, thick}, local/.style={->, thick, blue}, random/.style={->, thick, green!60!black, dashed}, globalconn/.style={->, thick, red} ]

\node[node, fill=orange!30] (focus) at (4.5,0) {$t_i$};

\node[global] (cls) at (0,0) {CLS}; \node[node] (t1) at (1.5,0) {$t_1$}; \node[node] (t2) at (3,0) {$t_2$}; \node[node] (t3) at (6,0) {$t_3$}; \node[node] (t4) at (7.5,0) {$t_4$}; \node[node] (t5) at (9,0) {$t_5$};

\draw[local] (focus) -- (t2); \draw[local] (focus) -- (t3);

\draw[random] (focus) to[bend left=20] (t1); \draw[random] (focus) to[bend right=20] (t5);

\draw[globalconn] (focus) to[bend right=30] (cls); \draw[globalconn] (cls) to[bend left=30] (focus);

\node[align=left, font=\footnotesize] at (4.5,-2.5) {\textcolor{blue}{Blue}: Local window \\ \textcolor{green!60!black}{Green dashed}: Random \\ \textcolor{red}{Red}: Global \\ \textcolor{orange}{Orange}: Query token};

\end{tikzpicture}

BigBird attention pattern for a single query token $t_i$ (orange). Combines local window (blue), random connections (green dashed), and global tokens (red). This hybrid pattern provides both local context and long-range connectivity while maintaining $O(n)$ complexity.

The total attention set for a regular token at position $i$ is $\mathcal{S}(i) = \mathcal{W}(i) \cup \mathcal{R}(i) \cup \mathcal{G}$, where $\mathcal{G}$ is the set of global token positions. The total number of connections per query is $|\mathcal{S}(i)| = 2w + r + g$, giving computational complexity $O(n(2w + r + g)d) = O(n)$ when $w$, $r$, and $g$ are constants.
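A token-level sketch of $\mathcal{S}(i) = \mathcal{W}(i) \cup \mathcal{R}(i) \cup \mathcal{G}$ as a boolean mask follows. The function name and the small sizes are hypothetical, and this is a simplification: BigBird's released implementation works on blocks of tokens rather than individual positions. The random keys are drawn once from a fixed seed, mirroring the fixed random pattern in the definition.

```python
import numpy as np

def bigbird_mask(n, w, r, global_positions, seed=0):
    """Union of window W(i), random R(i), and global set G as a boolean mask."""
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w            # W(i)
    for i in range(n):
        mask[i, rng.choice(n, size=r, replace=False)] = True    # R(i), fixed once
    mask[global_positions, :] = True                            # globals see everything
    mask[:, global_positions] = True                            # everyone sees globals
    return mask

m = bigbird_mask(n=512, w=8, r=4, global_positions=[0, 1])
# Regular rows hold at most (2w + 1) + r + g keys; overlaps only reduce the count.
print(m[2:].sum(axis=1).max(), "<=", 2 * 8 + 1 + 4 + 2)
```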

The random attention component is crucial for BigBird's theoretical properties. While local windows provide nearby context and global tokens enable information aggregation, random connections create shortcuts across the sequence that allow information to propagate efficiently. The random graph formed by these connections has high probability of being well-connected, ensuring that any two positions are connected by a short path through the attention graph. This property enables BigBird to approximate full attention's expressiveness while maintaining linear complexity.

BigBird's theoretical contribution is proving that sparse attention patterns containing global tokens are universal approximators of sequence-to-sequence functions and are Turing complete, preserving these properties of full quadratic attention. The global tokens carry this proof, while the random connections are motivated by graph theory: the shortest path between two nodes of a random graph grows only logarithmically with the number of nodes.

As a concrete configuration for sequences up to 4096 tokens, take $w = 256$, $r = 64$, and $g = 32$. Each regular token attends to $2 \times 256 + 64 + 32 = 608$ keys instead of 4096, reducing computation by 6.7×. For a single attention head in FP32, BigBird requires $((4096 - 32) \times 608 + 32 \times 4096) \times 4 \approx 10.4$ MB compared to 67 MB for full attention, a 6.4× memory reduction. With 12 heads and 12 layers, total attention memory decreases from 9.6 GB to 1.5 GB.

The performance benefits are substantial on modern hardware. On an NVIDIA A100 GPU, BigBird processes 4096-token sequences in approximately 15 milliseconds per layer compared to 98 milliseconds for full attention, a 6.5× speedup. The speedup is slightly less than the theoretical 6.7× due to overhead from irregular memory access patterns in the random attention component. For sequences of 8192 tokens, BigBird takes 30 milliseconds per layer while full attention would require approximately 390 milliseconds, a 13× speedup that makes previously impractical sequence lengths feasible.

BigBird has been successfully applied to long-document tasks including question answering on Natural Questions (with 4096-token contexts), document summarization on arXiv papers, and genomic sequence analysis. On the Natural Questions benchmark, BigBird achieves 79.2

Linear Attention Methods

Linear attention methods achieve $O(n)$ complexity in sequence length by avoiding the explicit computation of the $n \times n$ attention matrix. These methods use mathematical reformulations or approximations that allow attention to be computed through matrix operations with different associativity, reducing the dominant term from $O(n^2 d)$ to $O(nd^2)$ or even $O(nd)$ in some cases.

Linformer

Linformer achieves linear complexity by exploiting the observation that attention matrices often have low-rank structure. Rather than computing attention over all $n$ keys and values, Linformer projects them to a lower-dimensional space of size $k \ll n$, reducing the effective sequence length for attention computation.

Definition: Linformer projects keys and values to lower dimension $k \ll n$ using projection matrices $\mE, \mF \in \R^{k \times n}$:
$$\begin{aligned} \bar{\mK} &= \mE \mK \in \R^{k \times d} \\ \bar{\mV} &= \mF \mV \in \R^{k \times d} \end{aligned}$$

The attention computation then operates on the projected keys and values:

$$ \text{Linformer}(\mQ, \mK, \mV) = \text{softmax}\left(\frac{\mQ \bar{\mK}\transpose}{\sqrt{d}}\right) \bar{\mV} $$

The attention matrix $\mQ \bar{\mK}\transpose \in \R^{n \times k}$ has reduced dimension, giving computational complexity $O(nkd)$ instead of $O(n^2d)$.
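A single-head sketch of the definition follows. The name linformer_attention is hypothetical, and the scaled Gaussian matrices E and F are stand-ins for the learned projections, chosen only so the example runs.

```python
import numpy as np

def linformer_attention(Q, K, V, E, F):
    """Single-head Linformer attention with projections E, F of shape (k, n)."""
    d = Q.shape[-1]
    K_bar = E @ K                                   # (k, d): keys compressed along the sequence
    V_bar = F @ V                                   # (k, d): values compressed the same way
    scores = Q @ K_bar.T / np.sqrt(d)               # (n, k) instead of (n, n)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V_bar                          # (n, d)

rng = np.random.default_rng(0)
n, d, k = 1024, 64, 128
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
E = rng.standard_normal((k, n)) / np.sqrt(k)        # stand-in for a learned projection
F = rng.standard_normal((k, n)) / np.sqrt(k)        # stand-in for a learned projection
out = linformer_attention(Q, K, V, E, F)
```

The largest intermediate is the $n \times k$ weight matrix, so memory grows linearly in $n$ for fixed $k$.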

The key insight is that the attention matrix $\mA = \text{softmax}(\mQ \mK\transpose / \sqrt{d})$ often has low-rank structure, meaning it can be well-approximated by a rank-$k$ matrix with $k \ll n$. Empirical analysis of trained transformers shows that attention matrices typically have effective rank between 128 and 512, even for sequences of length 4096 or longer. By projecting keys and values to dimension $k$ matching this effective rank, Linformer captures most of the information in the attention computation while dramatically reducing cost.

The projection matrices $\mE$ and $\mF$ can be implemented in several ways. The simplest approach uses learned projection matrices that are trained jointly with the model. Alternatively, fixed projections such as max pooling or average pooling can be used, where $\mE$ and $\mF$ partition the sequence into $k$ segments and pool within each segment. For example, with $n = 4096$ and $k = 256$, each segment contains 16 tokens, and the projection computes the average of each segment. Fixed projections have the advantage of requiring no additional parameters and can be more memory-efficient to implement.
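The fixed average-pooling projection can be written as an explicit $k \times n$ matrix. The sketch below uses the hypothetical name mean_pool_projection and assumes $k$ divides $n$; it confirms that $\mE \mK$ equals the per-segment means.

```python
import numpy as np

def mean_pool_projection(n, k):
    """Fixed, parameter-free projection: row s averages the s-th block of n // k tokens."""
    assert n % k == 0, "this sketch assumes k divides n"
    seg = n // k
    E = np.zeros((k, n))
    for s in range(k):
        E[s, s * seg:(s + 1) * seg] = 1.0 / seg
    return E

n, k, d = 4096, 256, 8
E = mean_pool_projection(n, k)                      # each row averages 16 tokens
K = np.random.default_rng(0).standard_normal((n, d))
K_bar = E @ K
assert np.allclose(K_bar, K.reshape(k, n // k, d).mean(axis=1))
```

In practice the reshape-and-mean form on the last line is what one would compute, since it avoids storing E at all.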

For a sequence of length $n = 4096$ with projection dimension $k = 256$ and model dimension $d = 768$, Linformer's cost breaks down as follows. Computing $\bar{\mK} = \mE \mK$ requires $nkd = 4096 \times 256 \times 768 \approx 805$ million FLOPs, and computing $\bar{\mV} = \mF \mV$ requires another 805 million. Computing $\mQ \bar{\mK}\transpose$ requires $nkd \approx 805$ million FLOPs. The softmax over the $n \times k$ matrix requires $O(nk) \approx 1$ million operations, and the final multiplication with $\bar{\mV}$ requires another 805 million FLOPs. The total is approximately 3.2 billion FLOPs. Full attention needs $n^2 d = 4096^2 \times 768 \approx 12.9$ billion FLOPs for $\mQ \mK\transpose$ and the same again for the multiplication with $\mV$, about 25.8 billion in total, so Linformer performs roughly 8× fewer FLOPs.

Memory requirements are similarly reduced. The attention matrix $\mQ \bar{\mK}\transpose \in \R^{n \times k}$ requires $4096 \times 256 \times 4 = 4.2$ MB in FP32 compared to 67 MB for the full $n \times n$ matrix, a 16× reduction. With 12 heads and 12 layers, total attention memory decreases from 9.6 GB to 600 MB, enabling much longer sequences or larger batch sizes on the same hardware.

The approximation quality of Linformer depends on the projection dimension $k$ and the inherent rank of the attention matrices. Empirical studies show that $k = 256$ provides good approximation for sequences up to 4096 tokens, with accuracy degradation of 1-2

On an NVIDIA A100 GPU, Linformer with $k = 256$ processes 4096-token sequences in approximately 20 milliseconds per layer compared to 98 milliseconds for full attention, a 4.9× speedup. The speedup is less than the theoretical 8× FLOP reduction due to the overhead of the projection operations and less efficient memory access patterns. For sequences of 8192 tokens, Linformer takes 40 milliseconds per layer while full attention would require 390 milliseconds, a 9.8× speedup that enables processing of very long documents.
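A small helper makes the FLOP comparison explicit. It is a hypothetical sketch that counts one multiply-accumulate as one FLOP, as the analysis above does, skips the lower-order softmax terms, and counts both products of full attention and all four $n \times k \times d$ products of Linformer (including $\mF \mV$).

```python
def attention_flops(n, d, k=None):
    """Multiply-accumulate counts for full attention (k=None) or Linformer (k given)."""
    if k is None:
        return 2 * n * n * d      # Q K^T and A V, each n * n * d
    return 4 * n * k * d          # E K, F V, Q K_bar^T, A V_bar, each n * k * d

full = attention_flops(4096, 768)
lin = attention_flops(4096, 768, k=256)
print(f"{full / 1e9:.1f}B vs {lin / 1e9:.1f}B -> {full / lin:.0f}x")
```

The ratio simplifies to $n / (2k)$, so doubling the sequence length at fixed $k$ doubles Linformer's advantage.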

Performer (Kernel-based)

Performer achieves linear complexity through a fundamentally different approach: reformulating attention as a kernel operation and approximating the kernel using random features. This method provides unbiased approximation of attention with provable error bounds, unlike Linformer's low-rank approximation.

Definition: Performer approximates the softmax attention kernel using random feature maps. The softmax kernel $\exp(\vq\transpose \vk / \sqrt{d})$ is approximated by:
$$ \exp\left(\frac{\vq\transpose \vk}{\sqrt{d}}\right) \approx \phi(\vq)\transpose \phi(\vk) $$
where $\phi : \R^d \to \R^m$ is a random feature map with $m \ll n$.

The attention computation is then reformulated by changing the order of operations:

$$ \text{Attention}(\mQ, \mK, \mV) \approx \frac{\phi(\mQ) (\phi(\mK)\transpose \mV)}{\phi(\mQ) (\phi(\mK)\transpose \mathbf{1})} $$
where $\mathbf{1} \in \R^n$ is a vector of ones for normalization.

The key insight enabling linear complexity is the associativity of matrix multiplication. In standard attention, we compute $(\mQ \mK\transpose) \mV$, which requires first computing the $n \times n$ matrix $\mQ \mK\transpose$ at cost $O(n^2 d)$. By approximating the attention kernel with feature maps $\phi$, we can instead compute $\phi(\mQ) (\phi(\mK)\transpose \mV)$, where the parentheses indicate we first compute $\phi(\mK)\transpose \mV \in \R^{m \times d}$ at cost $O(nmd)$, then multiply by $\phi(\mQ) \in \R^{n \times m}$ at cost $O(nmd)$. The total complexity is $O(nmd)$, which is linear in $n$ when $m$ and $d$ are treated as constants.
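The associativity argument can be checked numerically. The sketch below uses random positive matrices as stand-ins for $\phi(\mQ)$ and $\phi(\mK)$ (any positive features work for this identity) and shows that the quadratic and linear evaluation orders give the same normalized output.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 512, 32, 16
phi_Q = rng.random((n, m))      # stand-in positive features phi(Q), shape (n, m)
phi_K = rng.random((n, m))      # stand-in positive features phi(K), shape (n, m)
V = rng.standard_normal((n, d))

# Quadratic order: build the (n, n) kernel matrix first, then normalize rows.
A = phi_Q @ phi_K.T
out_quadratic = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear order: contract over the sequence first; no (n, n) array is formed.
KV = phi_K.T @ V                # (m, d), cost O(n m d)
z = phi_K.sum(axis=0)           # (m,), equal to phi(K)^T 1
out_linear = (phi_Q @ KV) / (phi_Q @ z)[:, None]

assert np.allclose(out_quadratic, out_linear)
```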

Performer uses the FAVOR+ (Fast Attention Via positive Orthogonal Random features) algorithm, which constructs the feature map $\phi$ using random projections. For a query or key vector $\vx \in \R^d$, the feature map is defined as:

$$ \phi(\vx) = \frac{1}{\sqrt{m}} \left[\exp\left(\vw_1\transpose \vx - \frac{\|\vx\|^2}{2}\right), \ldots, \exp\left(\vw_m\transpose \vx - \frac{\|\vx\|^2}{2}\right)\right]\transpose $$
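where $\vw_1, \ldots, \vw_m \in \R^d$ are random vectors sampled from $\mathcal{N}(0, \mI)$. Because $\mathbb{E}[\exp(\vw\transpose \vx)] = \exp(\|\vx\|^2/2)$ for $\vw \sim \mathcal{N}(0, \mI)$, the term $-\|\vx\|^2/2$ makes $\phi(\vq)\transpose \phi(\vk)$ an unbiased estimator of $\exp(\vq\transpose \vk)$. Applying $\phi$ to the rescaled vectors $\vq / d^{1/4}$ and $\vk / d^{1/4}$ therefore gives an unbiased estimator of the softmax kernel $\exp(\vq\transpose \vk / \sqrt{d})$. Every feature is positive, which keeps the estimated attention weights non-negative.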
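A Monte Carlo sketch illustrates the unbiasedness claim. It uses i.i.d. (not orthogonal) Gaussian features, a large hypothetical $m$ so the estimate is tight, and small random $\vq, \vk$; the function name positive_random_features is an assumption. The $1/\sqrt{d}$ temperature is folded into both inputs as a factor $d^{-1/4}$.

```python
import numpy as np

def positive_random_features(X, W):
    """phi(x) = exp(W x - ||x||^2 / 2) / sqrt(m) for each row x of X; W has shape (m, d)."""
    m = W.shape[0]
    half_sq_norm = 0.5 * np.sum(X ** 2, axis=-1, keepdims=True)
    return np.exp(X @ W.T - half_sq_norm) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 16, 200_000
q, k = 0.3 * rng.standard_normal((2, d))
W = rng.standard_normal((m, d))                  # w_i ~ N(0, I), i.i.d.

scale = d ** -0.25                               # fold 1/sqrt(d) into both inputs
estimate = (positive_random_features(scale * q[None], W)
            @ positive_random_features(scale * k[None], W).T).item()
exact = np.exp(q @ k / np.sqrt(d))
print(estimate, exact)
```

With far fewer features, such as the $m = 256$ discussed below, individual estimates are much noisier, which is what the orthogonal construction aims to reduce.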

FAVOR+ improves upon basic random features by using orthogonal random features, where the random vectors $\vw_1, \ldots, \vw_m$ are orthogonalized using Gram-Schmidt or similar procedures. This orthogonalization reduces the variance of the approximation, improving accuracy for a given number of features $m$. Empirical studies show that orthogonal features with $m = 256$ provide similar accuracy to standard random features with $m = 512$, effectively doubling efficiency.
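One way to build such a matrix is sketched below. It orthogonalizes blocks of $d$ Gaussian rows with a QR decomposition (equivalent to Gram-Schmidt), then rescales each row to a chi-distributed length so its norm matches that of an $\mathcal{N}(0, \mI)$ sample. The function name is hypothetical, and this is one common construction rather than a definitive one.

```python
import numpy as np

def orthogonal_gaussian_matrix(m, d, rng):
    """m x d matrix whose rows are mutually orthogonal within each block of d rows."""
    blocks = []
    for _ in range(-(-m // d)):                      # ceil(m / d) independent blocks
        G = rng.standard_normal((d, d))
        Qf, R = np.linalg.qr(G)
        Qf = Qf * np.sign(np.diag(R))                # sign fix: Qf is uniformly (Haar) distributed
        blocks.append(Qf.T)                          # d orthonormal rows
    W = np.concatenate(blocks, axis=0)[:m]
    # Rescale rows to chi_d-distributed lengths, matching the norm of a N(0, I) draw.
    norms = np.linalg.norm(rng.standard_normal((m, d)), axis=1)
    return W * norms[:, None]

rng = np.random.default_rng(0)
W = orthogonal_gaussian_matrix(m=64, d=16, rng=rng)
block = W[:16]
gram = block @ block.T
assert np.allclose(gram, np.diag(np.diag(gram)))     # rows within a block are orthogonal
```

Such a matrix can replace the i.i.d. Gaussian matrix when building $\phi$. Orthogonality holds only within each block of $d$ rows, since at most $d$ vectors in $\R^d$ can be mutually orthogonal.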

For a sequence of length $n = 4096$ with $m = 256$ random features and model dimension $d = 768$, Performer's cost is as follows. Computing $\phi(\mQ)$ and $\phi(\mK)$ requires $nmd = 4096 \times 256 \times 768 \approx 805$ million FLOPs each for the random projections. Computing $\phi(\mK)\transpose \mV$ requires another 805 million FLOPs, and multiplying the result by $\phi(\mQ)$ requires another 805 million. The total is approximately 3.2 billion FLOPs compared to about 25.8 billion for full attention ($n^2 d \approx 12.9$ billion each for $\mQ \mK\transpose$ and for the product with $\mV$), an 8× reduction. The memory requirement is $O(nm + md) = 4096 \times 256 + 256 \times 768 \approx 1.2$ million elements, or about 5 MB in FP32, compared to 67 MB for full attention.

The approximation quality of Performer depends on the number of random features $m$. With $m = 256$, Performer typically achieves accuracy within 2-3

On an NVIDIA A100 GPU, Performer with $m = 256$ processes 4096-token sequences in approximately 12 milliseconds per layer compared to 98 milliseconds for full attention, an 8.2× speedup. This slightly exceeds the 8× reduction in FLOPs because Performer's computation is more memory-bandwidth efficient: it never materializes the large $n \times n$ attention matrix, reducing memory traffic. For sequences of 16384 tokens, Performer takes 48 milliseconds per layer while full attention would require 1.5 seconds, a 31× speedup that enables processing of extremely long sequences.

Memory-Efficient Attention

Flash Attention

Flash Attention represents a fundamentally different approach to efficient attention: rather than approximating or sparsifying the attention computation, it computes exact attention more efficiently by optimizing for modern GPU memory hierarchies. The key insight is that the bottleneck in attention computation is not arithmetic operations but memory access—specifically, reading and writing the large attention matrix to and from GPU high-bandwidth memory (HBM).

Definition: Flash Attention computes exact self-attention without materializing the full $n \times n$ attention matrix in HBM. The algorithm tiles the computation into blocks that fit in fast on-chip SRAM, fuses the attention operations (the $\mQ \mK\transpose$ product, softmax, and multiplication by $\mV$), and uses online softmax computation to avoid storing intermediate results. The key components are:
  1. Tiling: Divide $\mQ, \mK, \mV$ into blocks of size $B \times d$ where $B$ is chosen to fit in SRAM
  2. Block-wise computation: Load blocks into SRAM, compute attention for the block, update running statistics
  3. Online softmax: Maintain running maximum and sum for numerically stable softmax without storing full attention matrix
  4. Kernel fusion: Combine the score computation, softmax, and multiplication by $\mV$ into a single GPU kernel

Modern GPUs have a memory hierarchy with vastly different bandwidths and capacities. An NVIDIA A100 GPU has 40 GB of HBM with bandwidth of about 1.5 TB/s, and 192 KB of on-chip SRAM (shared memory) per streaming multiprocessor, about 20 MB across its 108 multiprocessors, with aggregate bandwidth estimated at around 19 TB/s, more than 12× faster. Standard attention implementations compute $\mS = \mQ \mK\transpose$, write it to HBM (consuming $n^2 \times 4$ bytes), read it back for softmax, write the result to HBM, read it back for multiplication with $\mV$, and finally write the output. For $n = 2048$, this involves reading and writing $2048^2 \times 4 = 16.8$ MB multiple times, totaling over 100 MB of memory traffic.

Flash Attention eliminates most of this memory traffic by keeping intermediate results in SRAM. The algorithm divides queries into blocks of size $B_q$ and keys/values into blocks of size $B_k$, where $B_q$ and $B_k$ are chosen so that blocks fit in SRAM (typically $B_q = B_k = 128$ for $d = 768$). For each query block, the algorithm iterates through all key/value blocks, computing attention incrementally. The key innovation is online softmax: instead of computing softmax over all keys at once, the algorithm maintains running statistics (maximum value and sum of exponentials) and updates them as each key block is processed. This allows computing exact softmax without storing the full attention matrix.
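The online softmax can be sketched in NumPy. This illustrates only the algorithmic idea (tiling plus running max and sum with rescaling); the actual speedup comes from a fused GPU kernel that keeps tiles in SRAM, which NumPy cannot express. The name tiled_attention and the block sizes of 64 are hypothetical choices.

```python
import numpy as np

def tiled_attention(Q, K, V, block_q=64, block_k=64):
    """Exact attention computed tile by tile with an online softmax."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, V.shape[1]))
    for qs in range(0, n, block_q):
        q = Q[qs:qs + block_q]
        m = np.full(q.shape[0], -np.inf)             # running row maximum
        l = np.zeros(q.shape[0])                     # running sum of exp(score - m)
        acc = np.zeros((q.shape[0], V.shape[1]))     # running unnormalized output
        for ks in range(0, n, block_k):
            s = q @ K[ks:ks + block_k].T * scale     # one (block_q, block_k) score tile
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)           # rescales earlier partial sums
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ V[ks:ks + block_k]
            m = m_new
        out[qs:qs + block_q] = acc / l[:, None]      # normalize once at the end
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((300, 32)) for _ in range(3))
S = Q @ K.T / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
assert np.allclose(tiled_attention(Q, K, V), P @ V)  # matches standard attention
```

On the first key block the running maximum is $-\infty$, so the correction factor is exactly zero and the empty accumulators are discarded cleanly.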

The additional memory of Flash Attention is $O(n)$ instead of $O(n^2)$ because it never materializes the full attention matrix. Beyond the query, key, value, and output matrices (each $O(nd)$), it stores only small running statistics ($O(n)$ for the maximum and sum). For $n = 4096$ and $d = 768$, the query, key, and value matrices occupy approximately $3 \times 4096 \times 768 \times 4 = 37.7$ MB, compared to $67 + 37.7 = 104.7$ MB for standard attention (one head's attention matrix plus the same inputs), a 2.8× memory reduction. The savings increase for longer sequences: at $n = 16384$, the inputs require 151 MB while standard attention would require $1074 + 151 = 1225$ MB, an 8.1× reduction.

The computational complexity remains $O(n^2 d)$ since Flash Attention computes exact attention, but the wall-clock time is significantly reduced due to fewer memory accesses. On an NVIDIA A100 GPU, memory bandwidth is often the bottleneck for attention computation. Standard attention achieves only 30-40

Example: The benefits of Flash Attention scale with sequence length and are particularly dramatic for long sequences. Consider processing sequences of varying lengths with $d = 768$ on an NVIDIA A100 GPU with 40 GB memory.

For $n = 1024$ tokens, standard attention requires $1024^2 \times 4 = 4.2$ MB for the attention matrix and takes 8 milliseconds per layer. Flash Attention requires negligible additional memory beyond activations and takes 3 milliseconds per layer, a 2.7× speedup. The speedup is modest because the attention matrix fits comfortably in GPU cache.

For $n = 2048$ tokens, standard attention requires 16.8 MB and takes 12 milliseconds per layer. Flash Attention takes 3.5 milliseconds, a 3.4× speedup. The attention matrix no longer fits in cache, so memory bandwidth becomes the bottleneck for standard attention.

For $n = 4096$ tokens, standard attention requires 67 MB per head and takes 98 milliseconds per layer. Flash Attention takes 25 milliseconds, a 3.9× speedup. Counting a single head per layer, 12 layers at batch size 8 already need $67 \times 12 \times 8 = 6.4$ GB just for attention matrices (12 times more with all 12 heads), while Flash Attention requires negligible additional memory, enabling 4× larger batch sizes.

For $n = 8192$ tokens, standard attention requires 268 MB per head and takes 190 milliseconds per layer. Flash Attention takes 55 milliseconds, a 3.5× speedup. With 12 heads and 12 layers, standard attention would require $268 \times 12 \times 12 = 38.6$ GB for attention matrices alone, leaving almost no room on a 40 GB A100 for parameters, gradients, optimizer state, and other activations, so even batch size 1 does not fit in practice. Flash Attention enables batch size 4-8 on the same hardware.

For $n = 16384$ tokens, standard attention requires 1.07 GB per head and would take approximately 1.5 seconds per layer if it fit in memory. Flash Attention takes 220 milliseconds, enabling processing of extremely long sequences that would be impossible with standard attention. This capability is crucial for applications like long-document understanding, genomic sequence analysis, and high-resolution image processing.

Flash Attention has been integrated into major deep learning frameworks and libraries, including PyTorch (natively through torch.nn.functional.scaled_dot_product_attention since PyTorch 2.0) and the xformers library, and is used in production systems for training and inference. The technique has been extended to Flash Attention 2, which reduces non-matrix-multiply FLOPs, parallelizes over the sequence length in addition to batch and heads, and partitions work more efficiently between warps, achieving up to 2× additional speedup over the original Flash Attention.

Memory-Efficient Transformers

Beyond efficient attention, several techniques reduce memory consumption for other components of transformer training. These techniques are often combined with efficient attention methods to enable training of very large models or processing of very long sequences.

Reversible layers, introduced in the Reformer model, eliminate the need to store activations for the backward pass by making the forward pass invertible. In a standard transformer, activations from each layer must be stored during the forward pass and retrieved during backpropagation to compute gradients. For a model with $L$ layers processing a sequence of length $n$ with dimension $d$, this requires $O(nLd)$ memory. Reversible layers use an architecture in which the output of each layer can be used to reconstruct its input, so activations are recomputed during the backward pass rather than stored. This reduces activation memory from $O(nLd)$ to $O(nd)$, a factor of $L$ reduction. For a 12-layer BERT model with $n = 512$ and $d = 768$, counting two FP32 tensors of size $n \times d$ per layer (one per sublayer), reversible layers reduce activation memory from 37.7 MB to 3.1 MB per sequence.

Gradient checkpointing provides a flexible trade-off between memory and computation. Instead of storing all activations, only the activations at certain checkpoint layers are stored, and the intermediate activations are recomputed during the backward pass. With a checkpoint every $k$ layers, the activations kept between the forward and backward passes shrink from $O(nLd)$ to $O(nLd/k)$, plus the activations of one $k$-layer segment while it is being recomputed. The price is one extra forward pass: forward computation roughly doubles, but because the backward pass costs about twice as much as the forward pass, total training time grows by far less than 2×. For $k = 3$ in a 12-layer model, stored activations shrink by roughly 3× while training time increases by only 20-30%.
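A minimal NumPy sketch of the mechanism, assuming a toy stack of $L$ layers $y = \tanh(xW)$ as a hypothetical stand-in for transformer blocks, with a hand-written backward pass. The checkpointed version keeps only the input of every $k$-th layer and recomputes each segment's activations during backpropagation; it yields the same input gradient as storing everything. (In PyTorch, the equivalent mechanism is provided by torch.utils.checkpoint.)

```python
import numpy as np

def layer(W, x):
    """Toy block standing in for a transformer layer: y = tanh(x W)."""
    return np.tanh(x @ W)

def layer_backward(W, y, grad_y):
    """Gradient w.r.t. the layer input, computed from the saved output y."""
    return (grad_y * (1.0 - y ** 2)) @ W.T

def grad_store_all(Ws, x, grad_out):
    outs = [x]
    for W in Ws:                                   # keep all L layer outputs
        outs.append(layer(W, outs[-1]))
    g = grad_out
    for i in reversed(range(len(Ws))):
        g = layer_backward(Ws[i], outs[i + 1], g)
    return g, len(Ws)                              # (input gradient, tensors kept)

def grad_checkpointed(Ws, x, grad_out, k):
    L, h = len(Ws), x
    ckpt = {0: x}                                  # inputs of layers 0, k, 2k, ...
    for i, W in enumerate(Ws):
        h = layer(W, h)
        if (i + 1) % k == 0 and i + 1 < L:
            ckpt[i + 1] = h
    g = grad_out
    for start in sorted(ckpt, reverse=True):       # walk segments back to front
        end = min(start + k, L)
        seg = [ckpt[start]]
        for i in range(start, end):                # recompute this segment's outputs
            seg.append(layer(Ws[i], seg[-1]))
        for i in reversed(range(start, end)):
            g = layer_backward(Ws[i], seg[i - start + 1], g)
    return g, len(ckpt)                            # (input gradient, checkpoints kept)

rng = np.random.default_rng(0)
Ws = [rng.standard_normal((16, 16)) * 0.3 for _ in range(12)]
x = rng.standard_normal((4, 16))
grad_out = np.ones((4, 16))
g_full, kept_full = grad_store_all(Ws, x, grad_out)
g_ckpt, kept_ckpt = grad_checkpointed(Ws, x, grad_out, k=3)
print(np.allclose(g_full, g_ckpt), kept_full, kept_ckpt)  # True 12 4
```

With $L = 12$ and $k = 3$, only four tensors survive from the forward pass to the backward pass instead of twelve, at the cost of running every layer's forward computation a second time.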

Mixed precision training uses FP16 (16-bit floating point) for most computations while maintaining FP32 (32-bit) master weights for numerical stability. This reduces activation memory by 50%.
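Why FP32 master weights matter can be seen with a few lines of NumPy. This sketch assumes a hypothetical per-step update of $10^{-4}$: in FP16 the spacing between representable numbers near 1.0 is about $9.8 \times 10^{-4}$, so each small update is rounded away, while the FP32 copy accumulates it. It also shows the 50% storage reduction for an FP16 activation tensor.

```python
import numpy as np

lr_update = 1e-4           # hypothetical small per-step weight update
w16 = np.float16(1.0)      # weight kept only in FP16
w32 = np.float32(1.0)      # FP32 master copy of the same weight

for _ in range(100):
    w16 = w16 + np.float16(lr_update)   # 1.0001 rounds back to 1.0 in FP16
    w32 = w32 + np.float32(lr_update)   # FP32 accumulates the updates

print(w16, w32)            # FP16 weight never moves; FP32 weight is about 1.01

# Storage: the same (n=512, d=768) activation tensor in FP32 vs FP16
acts32 = np.zeros((512, 768), dtype=np.float32)
print(acts32.nbytes, acts32.astype(np.float16).nbytes)  # 1572864 786432
```

Frameworks therefore compute in FP16 but apply each update to the FP32 master weights, copying them back to FP16 for the next forward pass.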

Comparison of Efficient Methods

Comprehensive Benchmarks

Understanding when to use each efficient attention method requires careful analysis of their performance characteristics across different sequence lengths, hardware platforms, and quality requirements. This section provides detailed benchmarks on NVIDIA A100 GPUs with concrete memory and speed measurements.

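| Method | Time complexity | Memory | Exact | Quality |
|---|---|---|---|---|
| Standard | $O(n^2d)$ | $O(n^2)$ | Yes | Best |
| Sliding Window | $O(nwd)$ | $O(nw)$ | No | Good |
| Longformer | $O(nwd)$ | $O(nw)$ | No | Good |
| BigBird | $O(n(w+r+g)d)$ | $O(n(w+r+g))$ | No | Good |
| Linformer | $O(nkd)$ | $O(nk)$ | No | Good |
| Performer | $O(nmd)$ | $O(nm)$ | Approx | Medium |
| Flash Attention | $O(n^2d)$ | $O(n)$ | Yes | Best |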

Memory Scaling Analysis

Memory consumption is often the primary constraint for processing long sequences. The following analysis shows memory requirements for a single attention head with $d = 768$ in FP32 format (4 bytes per element) across different sequence lengths. These measurements include only the attention matrix memory; storing the queries, keys, and values adds another $3nd$ values ($12nd$ bytes in FP32) regardless of the method.

For $n = 1024$ tokens, standard attention requires $1024^2 \times 4 = 4.2$ MB per head. Sparse methods with window $w = 256$ ($w$ keys on each side, so $2w = 512$ keys per query) require $1024 \times 512 \times 4 = 2.1$ MB (50% less).

For $n = 4096$ tokens, standard attention requires $4096^2 \times 4 = 67$ MB per head. Sparse methods with $w = 512$ require $4096 \times 1024 \times 4 = 16.8$ MB (75% less).

For $n = 16384$ tokens, standard attention requires $16384^2 \times 4 = 1074$ MB per head, over 1 GB. Sparse methods with $w = 512$ require $16384 \times 1024 \times 4 = 67$ MB (94% less).

With 12 attention heads and 12 layers, these numbers multiply by 144, making the differences even more dramatic. For $n = 16384$, standard attention would require $1074 \times 144 \approx 155$ GB just for attention matrices, far exceeding any single GPU's capacity. Sparse methods require about 9.7 GB, linear methods require 2.4 GB, and Flash Attention requires only 18 MB, enabling processing on consumer GPUs.
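The per-head figures in this subsection are plain arithmetic and can be reproduced with a small calculator. This sketch assumes 1 MB = $10^6$ bytes, FP32 values, and this section's convention that a window $w$ means $w$ keys on each side of the query; the function name is hypothetical.

```python
def attention_memory_mb(n, method="standard", w=None, k=None, bytes_per_value=4):
    """Per-head attention-matrix memory in MB (1 MB = 1e6 bytes), FP32 by default."""
    if method == "standard":
        entries = n * n
    elif method == "sparse":
        entries = n * 2 * w      # w keys on each side of every query
    elif method == "linear":
        entries = n * k          # Linformer-style n x k attention matrix
    else:
        raise ValueError(f"unknown method: {method}")
    return entries * bytes_per_value / 1e6

for n, w in [(1024, 256), (4096, 512), (16384, 512)]:
    std = attention_memory_mb(n)
    sparse = attention_memory_mb(n, "sparse", w=w)
    print(f"n={n:5d}: standard {std:7.1f} MB, sparse {sparse:5.1f} MB "
          f"({1 - sparse / std:.0%} less)")

# 12 heads x 12 layers = 144 attention matrices at n = 16384
for label, mb in [("standard", attention_memory_mb(16384)),
                  ("sparse w=512", attention_memory_mb(16384, "sparse", w=512)),
                  ("linear k=256", attention_memory_mb(16384, "linear", k=256))]:
    print(f"{label}: {mb * 144 / 1e3:.1f} GB")
```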

Speed Benchmarks on A100 GPU

Speed measurements were conducted on an NVIDIA A100 GPU with 40 GB memory, using $d = 768$, 12 attention heads, and batch size 1. Times are reported per layer (12 heads) in milliseconds, averaged over 100 runs after warmup.

For $n = 1024$ tokens, standard attention takes 8 milliseconds per layer. Sparse attention with $w = 256$ (Longformer-style) takes 5 milliseconds (1.6× speedup). Linformer with $k = 256$ takes 4 milliseconds (2× speedup). Performer with $m = 256$ takes 3 milliseconds (2.7× speedup). Flash Attention takes 3 milliseconds (2.7× speedup). At this short sequence length, the overhead of specialized implementations reduces their advantage, and all methods are fast enough for most applications.

For $n = 4096$ tokens, standard attention takes 98 milliseconds per layer. Sparse attention with $w = 512$ (Longformer) takes 18 milliseconds (5.4× speedup). BigBird with $w = 256$, $r = 64$, $g = 32$ takes 15 milliseconds (6.5× speedup). Linformer with $k = 256$ takes 20 milliseconds (4.9× speedup). Performer with $m = 256$ takes 12 milliseconds (8.2× speedup). Flash Attention takes 25 milliseconds (3.9× speedup). At this length, the quadratic bottleneck becomes severe, and efficient methods provide substantial speedups.

For $n = 16384$ tokens, standard attention takes 1.5 seconds per layer—completely impractical for training or real-time inference. Sparse attention with $w = 512$ takes 72 milliseconds (21× speedup). BigBird takes 60 milliseconds (25× speedup). Linformer with $k = 256$ takes 80 milliseconds (19× speedup). Performer with $m = 256$ takes 48 milliseconds (31× speedup). Flash Attention takes 220 milliseconds (6.8× speedup). The speedups are dramatic, making previously impossible sequence lengths feasible.

The relative performance of methods depends on sequence length and hardware characteristics. Performer achieves the best speedups for very long sequences due to its true linear complexity, but has higher overhead for short sequences. Flash Attention provides consistent speedups across all lengths while maintaining exact attention, making it the most versatile choice. Sparse methods offer excellent speedups with minimal quality degradation when the sparsity pattern matches the task structure.

Quality Trade-offs

Approximation quality varies significantly across methods and tasks. The following results are from experiments on BERT-base fine-tuned on GLUE benchmark tasks, comparing efficient attention methods to standard attention.

Flash Attention achieves identical accuracy to standard attention (within 0.1%), since it computes exact attention and differs only in floating-point rounding.

Sparse attention methods (Longformer, BigBird) typically show 0.5-1.5% lower accuracy than standard attention.

Linformer shows 1-2% lower accuracy than standard attention.

Performer shows 2-3% lower accuracy than standard attention.

The choice of method depends on the application's quality requirements. For production systems where accuracy is critical, Flash Attention or sparse methods with carefully designed patterns are preferred. For research or applications where a 2-3% accuracy loss is acceptable, linear methods such as Performer offer the largest speedups.

When to Use Each Method

Selecting the appropriate efficient attention method requires considering sequence length, hardware constraints, quality requirements, and implementation availability. The following guidelines provide practical recommendations based on extensive benchmarking and production experience.

For sequences with $n < 512$ tokens, use standard attention. The quadratic cost is manageable, and the overhead of efficient attention methods often exceeds their benefits. Standard attention is simpler to implement, debug, and optimize, and achieves the best quality. Most BERT-style models and many GPT-style models fall in this regime.

For sequences with $512 < n < 2048$ tokens, consider Flash Attention if it is available for your hardware and framework. Flash Attention provides 2-4× speedups with no quality degradation, making it an ideal drop-in replacement for standard attention. If Flash Attention is not available, sparse attention with window size $w = 256$ provides good speedups (2-3×) with minimal quality loss (< 1% accuracy).

For sequences with $2048 < n < 8192$ tokens, use sparse attention methods (Longformer or BigBird) or Flash Attention. Sparse methods provide 5-10× speedups and are well-suited for tasks where local context is important. Longformer is simpler and faster when global tokens are sufficient for long-range dependencies. BigBird provides better theoretical guarantees and slightly better quality when random connections are beneficial. Flash Attention provides 3-5× speedups with exact attention, making it preferable when quality is critical and memory is the primary constraint.

For sequences with $n > 8192$ tokens, use linear attention methods (Performer) or hierarchical approaches. At these lengths, even sparse attention becomes expensive, and the method with the lowest per-token cost matters most. Performer with $m = 256$ provides 20-30× speedups compared to full attention, making sequences of 16384 or 32768 tokens feasible, at the price of a 2-3% accuracy loss.

Hardware considerations also matter. Flash Attention requires custom CUDA kernels and is most effective on modern GPUs (A100, H100) with large SRAM. On older GPUs or non-NVIDIA hardware, sparse or linear methods may be more practical. For CPU inference, sparse methods are often fastest due to efficient sparse matrix libraries. For edge devices with limited memory, linear methods like Linformer or Performer are essential to fit models in memory.

Task structure should inform the choice of sparsity pattern. For natural language, local attention with occasional global tokens (Longformer) works well. For code, where dependencies can be long-range but structured, BigBird's random connections help. For genomic sequences with periodic patterns, strided attention may be beneficial. For images, local attention in spatial dimensions is natural. Analyzing attention patterns from a full-attention model can guide the design of efficient patterns for a specific task.
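One way to run that analysis is to measure how much of each query's attention falls within a candidate window. The sketch below assumes you have already extracted a row-normalized attention matrix (for example, one head of one layer) from a full-attention model; how to extract it depends on your framework, and the function name is hypothetical. The synthetic matrix only illustrates the call.

```python
import numpy as np

def local_attention_mass(attn, w):
    """Mean fraction of each query's attention weight on keys within +/- w positions.

    attn: (n, n) row-stochastic attention matrix from a full-attention model.
    """
    n = attn.shape[0]
    idx = np.arange(n)
    band = np.abs(idx[:, None] - idx[None, :]) <= w   # sliding-window region
    return float((attn * band).sum(axis=1).mean())

# Synthetic example: attention that decays with distance |i - j|
n = 512
idx = np.arange(n)
scores = -np.abs(idx[:, None] - idx[None, :]) / 16.0
attn = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
for w in (16, 64, 128):
    print(w, round(local_attention_mass(attn, w), 3))
```

If most heads place nearly all of their mass inside a modest window, a sliding-window pattern is a reasonable candidate; heads with substantial mass far from the diagonal argue for global or random connections.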

Long-Context Models

Longformer

Longformer is a transformer architecture designed for processing documents of 4096 tokens or more, combining local sliding window attention with task-specific global attention. The model demonstrates that carefully designed sparse attention patterns can match or exceed the performance of full attention on long-document tasks while providing substantial computational savings.

The Longformer attention pattern combines two components. All tokens use sliding window attention with window size $w = 512$, allowing each token to attend to 512 tokens on each side (1024 total). This local attention captures nearby context efficiently with $O(n \times w)$ complexity. Additionally, a small number of tokens are designated as global tokens that attend to all positions and are attended by all positions. For classification tasks, the [CLS] token is global. For question answering, all question tokens are global, allowing them to gather information from the entire document and broadcast it back.
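The two components can be sketched as a boolean mask. This NumPy version follows the description above ($w$ tokens on each side plus global tokens) and uses a dense $n \times n$ mask purely for clarity; a dense mask is itself $O(n^2)$, whereas Longformer's implementation computes only the band and the global rows and columns. The function name and the toy sizes are hypothetical.

```python
import numpy as np

def longformer_mask(n, w, global_idx):
    """Boolean (n, n) mask: True where query i may attend to key j.

    Sliding window: |i - j| <= w (w tokens on each side).
    Global tokens: attend to every position and are attended by every position.
    """
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w
    g = np.asarray(global_idx)
    mask[g, :] = True     # global tokens attend everywhere
    mask[:, g] = True     # every token attends to the global tokens
    return mask

# Toy sizes: token 0 plays the role of the [CLS] token in classification
mask = longformer_mask(n=16, w=2, global_idx=[0])
print(mask.sum(axis=1))   # keys per query: 16 for [CLS], about 2w + 2 elsewhere
```

For question answering, `global_idx` would list all question-token positions instead of a single position.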

Longformer can also use dilated sliding windows to enlarge the receptive field without adding keys (the Longformer paper applies dilation to a few heads in the higher layers of its character-level language model). For example, with the window of $w = 512$ tokens on each side described above, an undilated layer covers a span of 1024 positions. With dilation 2, each token attends to every other position, so the same number of keys covers twice the distance (2048 positions); dilation 4 or 8 stretches the span to 4096 or 8192 positions. Because the receptive fields of stacked layers add up, a schedule in which the dilation doubles from layer to layer would let information propagate across the entire sequence in $O(\log n)$ layers while each layer keeps $O(n)$ cost.
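A small sketch of the dilated window and of how stacked layers extend the receptive field. It keeps this section's convention that $w$ counts keys on each side; the doubling schedule passed to `layers_to_cover` is a hypothetical example chosen to illustrate the $O(\log n)$ argument, not Longformer's actual configuration, and both function names are illustrative.

```python
def dilated_window(i, w, dilation, n):
    """Key positions attended by query i: w keys on each side, every `dilation`-th position."""
    offsets = range(-w * dilation, w * dilation + 1, dilation)
    return [i + o for o in offsets if 0 <= i + o < n]

def layers_to_cover(n, w, dilations):
    """Stacked layers needed before the farthest reachable position is n - 1 away.

    Each layer extends the one-sided reach by w * dilation (the farthest position,
    not every position in between, since dilation leaves gaps).
    """
    reach = 0
    for layer, dil in enumerate(dilations, start=1):
        reach += w * dil
        if reach >= n - 1:
            return layer
    return None          # schedule too short to span the sequence

w = 512
for dilation in (1, 2, 4, 8):
    keys = dilated_window(8192, w, dilation, n=16384)
    print(dilation, len(keys), keys[-1] - keys[0])   # same key count, wider span

print(layers_to_cover(16384, w, [1] * 40))                     # constant dilation
print(layers_to_cover(16384, w, [2 ** l for l in range(10)]))  # doubling dilation
```

With constant dilation the number of layers grows linearly in $n / w$; with a doubling schedule it grows logarithmically.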

Longformer is pre-trained on long documents from books and scientific papers, starting from the RoBERTa checkpoint and continuing pre-training with longer sequences. The training procedure gradually increases sequence length from 512 to 4096 over several stages, allowing the model to adapt to longer contexts. Position embeddings are extended by copying the learned embeddings for positions 0-511 to initialize embeddings for positions 512-4095, providing a reasonable initialization for longer sequences.
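The position-embedding initialization can be sketched as repeated copying of the learned table. This NumPy version uses a random matrix as a stand-in for the learned RoBERTa embeddings (the real checkpoint stores its table with a small offset for special positions, which is ignored here); the function name is hypothetical.

```python
import numpy as np

def extend_position_embeddings(pos_emb, new_len):
    """Tile learned position embeddings (old_len, d) to (new_len, d) by repeated copying."""
    old_len = pos_emb.shape[0]
    reps = -(-new_len // old_len)                 # ceiling division
    return np.tile(pos_emb, (reps, 1))[:new_len]

rng = np.random.default_rng(0)
roberta_pos = rng.standard_normal((512, 768))     # stand-in for learned embeddings
long_pos = extend_position_embeddings(roberta_pos, 4096)
print(long_pos.shape)  # (4096, 768)
```

Each block of 512 positions starts from the same local structure the model already learned, which is why this copy is a better starting point than random initialization.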

On long-document tasks, Longformer achieved state-of-the-art results at the time of its publication. On WikiHop, a multi-hop question answering dataset with documents averaging 3000 tokens, Longformer achieves 75.3% accuracy.

The computational efficiency enables practical deployment. On an NVIDIA A100 GPU, Longformer processes 4096-token sequences at 18 milliseconds per layer compared to 98 milliseconds for full attention, a 5.4× speedup. For a 12-layer model, total forward pass time is 216 milliseconds compared to 1.2 seconds, enabling real-time inference. Memory consumption is 2.4 GB for batch size 8 compared to 9.6 GB for full attention, allowing 4× larger batches or longer sequences on the same hardware.

Reformer

Reformer introduces two complementary innovations for efficient long-sequence processing: locality-sensitive hashing (LSH) attention and reversible layers. Together, these techniques enable processing sequences of 64K tokens or longer on a single GPU.

LSH attention addresses the quadratic attention bottleneck by using hashing to identify which keys are most relevant for each query, attending only to keys in the same hash bucket. The key insight is that attention weights are dominated by keys with high similarity to the query (large dot product $\vq\transpose \vk$). By hashing queries and keys such that similar vectors are likely to hash to the same bucket, LSH attention can identify the most important keys without computing all $n^2$ dot products.

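The LSH attention algorithm works as follows. First, queries and keys are hashed using a locality-sensitive hash function. Reformer uses random-rotation LSH: $h(\vx) = \arg\max\big([\vx\transpose \mathbf{R};\, -\vx\transpose \mathbf{R}]\big)$, where $\mathbf{R} \in \R^{d \times b/2}$ is a random matrix and $[\cdot\,;\cdot]$ denotes concatenation, so the $b$ buckets correspond to $b/2$ random directions and their negations. Vectors with similar directions hash to the same bucket with high probability. Reformer also shares the query and key projections, so each token's query and key fall into the same bucket. Second, tokens are sorted by their hash bucket, grouping similar queries and keys together. Third, the sorted sequence is split into equal-sized chunks, and each query attends only to keys from its own bucket within its chunk and the previous chunk, which covers buckets that straddle a chunk boundary. Fourth, the output is reordered to the original sequence order.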
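The four steps can be sketched in NumPy. This is a simplified, single-hash-round illustration under stated assumptions: shared query/key vectors (keys are the unit-normalized queries), the hash $\arg\max([\vx\transpose \mathbf{R}; -\vx\transpose \mathbf{R}])$ from the Reformer paper, and a token allowed to attend to itself (Reformer masks self-attention unless no other key is available). Function names, `chunk`, and `seed` are hypothetical.

```python
import numpy as np

def lsh_hash(x, n_buckets, rng):
    """Random-rotation LSH: argmax over [xR, -xR], R of shape (d, n_buckets // 2).

    n_buckets must be even.
    """
    R = rng.standard_normal((x.shape[1], n_buckets // 2))
    proj = x @ R
    return np.argmax(np.concatenate([proj, -proj], axis=1), axis=1)

def lsh_attention(qk, v, n_buckets=8, chunk=64, seed=0):
    """Shared-QK LSH attention sketch: O(n * chunk) scores instead of O(n^2)."""
    n, d = qk.shape
    buckets = lsh_hash(qk, n_buckets, np.random.default_rng(seed))  # step 1: hash
    order = np.lexsort((np.arange(n), buckets))   # step 2: sort by bucket, then position
    keys = qk / np.linalg.norm(qk, axis=1, keepdims=True)
    out = np.zeros_like(v)
    for start in range(0, n, chunk):              # step 3: own chunk + previous chunk
        q_idx = order[start:start + chunk]
        k_idx = order[max(0, start - chunk):start + chunk]
        scores = qk[q_idx] @ keys[k_idx].T / np.sqrt(d)
        same_bucket = buckets[q_idx][:, None] == buckets[k_idx][None, :]
        scores = np.where(same_bucket, scores, -np.inf)   # only keys in the same bucket
        weights = np.exp(scores - scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)
        out[q_idx] = weights @ v[k_idx]           # step 4: scatter back to original order
    return out, buckets

rng = np.random.default_rng(1)
qk, v = rng.standard_normal((128, 16)), rng.standard_normal((128, 8))
out, buckets = lsh_attention(qk, v, n_buckets=4, chunk=16)
print(out.shape, np.bincount(buckets, minlength=4))
```

Every query sees at least one key (its own), so no row of the softmax is empty even when a bucket straddles a chunk boundary.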

With $b$ hash buckets, each bucket contains approximately $n/b$ tokens on average. Each query attends to roughly $2n/b$ keys (about two buckets' worth), so the complexity is $O(n^2 d / b)$, roughly $b/2$ times fewer dot products than full attention: with $b = 8$ buckets, LSH attention computes about 4× fewer scores. If the number of buckets grows in proportion to $n$ (so bucket size stays fixed), the attention computation becomes linear in $n$ and the $O(n \log n)$ sort dominates. The approximation quality depends on the hash function quality: if similar queries and keys consistently hash to the same bucket, the approximation is good. Empirical studies show that LSH attention with $b = 8$ achieves accuracy within 1-2% of full attention.

Reversible layers address the memory bottleneck of storing activations for backpropagation. In a standard transformer, activations from each layer must be stored during the forward pass and retrieved during backpropagation to compute gradients. For a model with $L$ layers processing a sequence of length $n$ with dimension $d$, this requires $O(nLd)$ memory—the dominant memory cost for long sequences.

Reversible layers use a reversible architecture inspired by RevNets. Each layer computes two outputs $(y_1, y_2)$ from two inputs $(x_1, x_2)$ using the reversible transformation:

$$\begin{align} y_1 &= x_1 + \text{Attention}(x_2) \\ y_2 &= x_2 + \text{FeedForward}(y_1) \end{align}$$

This transformation is invertible: given $(y_1, y_2)$, we can recover $(x_1, x_2)$ by:

$$\begin{align} x_2 &= y_2 - \text{FeedForward}(y_1) \\ x_1 &= y_1 - \text{Attention}(x_2) \end{align}$$
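The forward and inverse equations can be checked numerically. In this NumPy sketch, `attention_stub` and `feedforward_stub` are hypothetical stand-ins for the attention and feed-forward sublayers (any deterministic functions work, since the inversion never needs to invert them); a 12-layer stack is run forward keeping only the final outputs, and each layer's inputs are then rebuilt from the top down.

```python
import numpy as np

rng = np.random.default_rng(0)
d, L = 8, 12
# Hypothetical per-layer weights, scaled down so the toy sublayers stay well-conditioned
weights = [(rng.standard_normal((d, d)) / (2 * np.sqrt(d)),
            rng.standard_normal((d, d)) / (2 * np.sqrt(d))) for _ in range(L)]

def attention_stub(x, W):      # stand-in for Attention(.)
    return np.tanh(x @ W)

def feedforward_stub(x, W):    # stand-in for FeedForward(.)
    return np.maximum(0.0, x @ W)

def rev_forward(x1, x2, Wa, Wf):
    y1 = x1 + attention_stub(x2, Wa)
    y2 = x2 + feedforward_stub(y1, Wf)
    return y1, y2

def rev_inverse(y1, y2, Wa, Wf):
    x2 = y2 - feedforward_stub(y1, Wf)
    x1 = y1 - attention_stub(x2, Wa)
    return x1, x2

x1, x2 = rng.standard_normal((5, d)), rng.standard_normal((5, d))
h1, h2 = x1, x2
for Wa, Wf in weights:              # forward pass: no per-layer activations kept
    h1, h2 = rev_forward(h1, h2, Wa, Wf)
for Wa, Wf in reversed(weights):    # backward sweep: rebuild each layer's inputs
    h1, h2 = rev_inverse(h1, h2, Wa, Wf)
print(np.allclose(h1, x1), np.allclose(h2, x2))  # True True
```

Only the top-layer pair $(y_1, y_2)$ has to be stored; everything below it is recovered on the fly during backpropagation, which is where the factor-of-$L$ memory saving comes from.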

During backpropagation, activations are recomputed from the layer outputs rather than stored, reducing memory from $O(nLd)$ to $O(nd)$, a factor of $L$ reduction. For a 12-layer model, this reduces activation memory by 12×. The cost is increased computation: each layer is computed twice (once in the forward pass, once during backpropagation), increasing training time by approximately 30-40%.

Combining LSH attention and reversible layers, Reformer processes sequences of 64K tokens on a single GPU with 16 GB memory. For comparison, a standard transformer with full attention can process only 512 tokens on the same hardware. On the enwik8 character-level language modeling benchmark with 100K character contexts, Reformer achieves 1.05 bits per character, matching transformer-XL while using 16× less memory. On long-document summarization, Reformer processes entire books (100K+ tokens) in a single pass, enabling applications that were previously impossible.

Exercises

Exercise 1: Implement sliding window attention with $w=256$. For $n=1024$:
  1. Create attention mask
  2. Compute attention
  3. Compare FLOPs and memory vs full attention
  4. Visualize attention pattern as heatmap
Exercise 2: Compare methods for $n=4096$, $d=768$:
  1. Standard attention: Calculate memory and FLOPs
  2. Linformer ($k=256$): Calculate savings
  3. Sliding window ($w=512$): Calculate savings
  4. Which is better for: (a) accuracy, (b) speed, (c) memory?
Exercise 3: Implement Performer random features. Use $m=256$ features for $d=64$:
  1. Generate random projection matrix
  2. Compute $\phi(\mQ)$ and $\phi(\mK)$
  3. Compare attention output to standard softmax attention
  4. Measure approximation error
Exercise 4: Analyze BigBird pattern. For $n=4096$, $w=256$, $r=64$, $g=32$:
  1. How many attention connections per token?
  2. What is sparsity percentage?
  3. Estimate memory savings vs full attention

Solutions

Full solutions for all exercises are available at https://deeplearning.hofkensvermeulen.be.

Solution: Exercise 1: Sliding Window Attention Implementation
```python
import math

import torch
import matplotlib.pyplot as plt

def create_sliding_window_mask(n, window_size):
    """Boolean (n, n) mask: True where query i may attend to key j.

    Each position attends to window_size // 2 tokens on each side
    (window_size + 1 positions in total, including itself).
    """
    idx = torch.arange(n)
    return (idx[:, None] - idx[None, :]).abs() <= window_size // 2

def sliding_window_attention(Q, K, V, window_size):
    """Sliding window attention implemented with a dense mask.

    The dense mask keeps the example simple, but it still materializes all
    n x n scores; a memory-efficient version computes only the band.
    """
    n, d = Q.shape

    # Create mask
    mask = create_sliding_window_mask(n, window_size)

    # Compute scaled attention scores
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)

    # Block keys outside the window (set masked positions to -inf)
    scores = scores.masked_fill(~mask, float('-inf'))

    # Softmax over keys, then weight the values
    attn_weights = torch.softmax(scores, dim=-1)
    output = attn_weights @ V

    return output, attn_weights, mask

# Example with n=1024, w=256
n = 1024
d = 64
window_size = 256

# Random Q, K, V
torch.manual_seed(42)
Q = torch.randn(n, d)
K = torch.randn(n, d)
V = torch.randn(n, d)

# Compute sliding window attention
output_sw, attn_sw, mask = sliding_window_attention(Q, K, V, window_size)

# Compute full attention for comparison
scores_full = Q @ K.transpose(-2, -1) / math.sqrt(d)
attn_full = torch.softmax(scores_full, dim=-1)
output_full = attn_full @ V

print(f"Sequence length: {n}")
print(f"Window size: {window_size}")
print(f"Output shape: {output_sw.shape}")
print(f"Attention weights shape: {attn_sw.shape}")
```

Part (c): FLOPs and Memory Comparison

Full Attention:

Sliding Window Attention ($w=256$):

Savings:

Scaling Analysis:

For sequence length $n$ and window size $w$:

$$\begin{align*} \text{Full attention:} &\quad O(n^2 d) \text{ FLOPs, } O(n^2) \text{ memory} \\ \text{Sliding window:} &\quad O(nwd) \text{ FLOPs, } O(nw) \text{ memory} \end{align*}$$

Reduction factor: $\frac{n}{w}$

For $n=1024$, $w=256$: $\frac{1024}{256} = 4\times$ reduction
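The counts behind this comparison can be sketched as a small calculator. It assumes $w$ is the total number of keys per query (as in the mask above, which keeps about $w/2$ on each side), counts one multiply-accumulate as one operation for both $\mQ\mK\transpose$ and the weighted sum over $\mV$, and stores the attention weights in FP32; the function name is hypothetical. Note that the dense-mask code above still allocates all $n^2$ scores, so these savings apply only to a banded implementation.

```python
def attention_cost(n, d, w=None, bytes_per_value=4):
    """Multiply-accumulates (QK^T plus AV) and bytes for the attention weights.

    w=None means full attention; otherwise each query sees w keys.
    """
    keys_per_query = n if w is None else w
    macs = 2 * n * keys_per_query * d                 # scores plus weighted sum of V
    memory = n * keys_per_query * bytes_per_value     # FP32 attention weights
    return macs, memory

full = attention_cost(1024, 64)
window = attention_cost(1024, 64, w=256)
print(full, window)                                   # (134217728, 4194304) (33554432, 1048576)
print(full[0] / window[0], full[1] / window[1])       # 4.0 4.0
```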

Part (d): Visualization

```python
# Visualize attention patterns
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Full attention (sample 256x256 for visibility)
sample_size = 256
axes[0].imshow(attn_full[:sample_size, :sample_size].numpy(), cmap='viridis', aspect='auto')
axes[0].set_title('Full Attention Pattern')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')

# Sliding window attention
axes[1].imshow(attn_sw[:sample_size, :sample_size].numpy(), cmap='viridis', aspect='auto')
axes[1].set_title(f'Sliding Window Attention (w={window_size})')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')

plt.tight_layout()
plt.savefig('sliding_window_attention.png', dpi=150)

# Visualize mask pattern
plt.figure(figsize=(10, 10))
plt.imshow(mask[:sample_size, :sample_size].numpy(), cmap='binary', aspect='auto')
plt.title(f'Sliding Window Mask (w={window_size})')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.colorbar(label='Attention Allowed')
plt.savefig('sliding_window_mask.png', dpi=150)
```

Key Observations:

  1. Diagonal band: Attention concentrated around diagonal (local context)
  2. Sparsity: the window keeps $\frac{n \times w}{n^2} = \frac{w}{n} = \frac{256}{1024} = 25\%$ of the full attention entries (75% sparsity)
  3. Local bias: Each token attends to nearby tokens within window
  4. Information flow: Multi-layer stacking enables long-range dependencies

Trade-offs:

Advantages:

Disadvantages:

Practical Applications:

Solution: Exercise 2: Efficiency Method Comparison

Given: $n=4096$, $d=768$, batch size $B=1$

Part (a): Standard Attention

Memory:

FLOPs:

Part (b): Linformer ($k=256$)

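Linformer projects the sequence axis of the keys and values from length $n$ down to $k$: $$\text{Attention} = \text{softmax}\left(\frac{\mQ(\mE\mK)\transpose}{\sqrt{d}}\right)\mF\mV$$

where $\mE, \mF \in \mathbb{R}^{k \times n}$ are learned projection matrices, so $\mE\mK$ and $\mF\mV$ are $k \times d$ and the attention matrix is $n \times k$ instead of $n \times n$.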
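A minimal NumPy sketch of this formula. In Linformer, $\mE$ and $\mF$ are learned; here they are random stand-ins (scaled by $1/\sqrt{k}$, an arbitrary choice), and the head dimension is 64 rather than 768 only to keep the example fast. The function name is hypothetical.

```python
import numpy as np

def linformer_attention(Q, K, V, E, F):
    """softmax(Q (E K)^T / sqrt(d)) (F V) with E, F of shape (k, n)."""
    d = Q.shape[1]
    K_proj, V_proj = E @ K, F @ V              # (k, d): sequence axis compressed to k
    scores = Q @ K_proj.T / np.sqrt(d)         # (n, k) instead of (n, n)
    P = np.exp(scores - scores.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)          # row-wise softmax over k projected keys
    return P @ V_proj, P

rng = np.random.default_rng(0)
n, d, k = 4096, 64, 256
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
E = rng.standard_normal((k, n)) / np.sqrt(k)   # random stand-ins for learned projections
F = rng.standard_normal((k, n)) / np.sqrt(k)
out, P = linformer_attention(Q, K, V, E, F)
print(out.shape, P.shape)   # (4096, 64) (4096, 256)
```

Setting $k = n$ and $\mE = \mF = \mI$ recovers standard attention, which is a convenient sanity check for an implementation.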

Memory:

FLOPs:

Part (c): Sliding Window ($w=512$)

Memory:

FLOPs:

Part (d): Comparison Summary

| Method | Memory | FLOPs | Memory savings | FLOPs savings |
|---|---|---|---|---|
| Standard | 104.8 MB | 25.8 G | - | - |
| Linformer | 43.5 MB | 3.2 G | 58.5% | 87.6% |
| Sliding Window | 46.1 MB | 3.2 G | 56.0% | 87.6% |

Which is Better?

(a) Accuracy:

Ranking: Standard $>$ Sliding Window $>$ Linformer

(b) Speed:

Ranking: Linformer $\approx$ Sliding Window $>$ Standard

Both efficient methods achieve similar speedup, but Linformer has additional projection overhead.

(c) Memory:

Ranking: Linformer $>$ Sliding Window $>$ Standard

Recommendations:

Practical Considerations:

Sliding Window is generally preferred because:

  1. No approximation error for local context
  2. Simpler implementation
  3. Better empirical performance on most tasks
  4. Can be combined with global attention (Longformer, BigBird)
Solution: Exercise 3: Performer Random Features
```python
import math

import torch

def generate_random_features(d, m, seed=42):
    """Random projection matrix for Performer: m Gaussian vectors omega_i ~ N(0, I_d)."""
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(d, m, generator=gen)  # (d, m)

def phi_features(x, omega):
    """Positive random features (FAVOR+) for the softmax kernel.

    phi(x) = exp(omega^T x - ||x||^2 / 2) / sqrt(m), which gives
    E[phi(q) . phi(k)] = exp(q . k).
    """
    m = omega.shape[1]
    projection = x @ omega                                 # (n, m)
    half_sq_norm = (x ** 2).sum(dim=-1, keepdim=True) / 2  # (n, 1)
    return torch.exp(projection - half_sq_norm) / math.sqrt(m)

def performer_attention(Q, K, V, m=256):
    """Compute Performer attention using random features."""
    n, d = Q.shape

    # Same random projection for queries and keys
    omega = generate_random_features(d, m)

    # Scaling both inputs by d^(-1/4) makes phi(q) . phi(k) estimate exp(q . k / sqrt(d))
    scale = d ** -0.25
    phi_Q = phi_features(Q * scale, omega)  # (n, m)
    phi_K = phi_features(K * scale, omega)  # (n, m)

    # phi(Q) @ (phi(K)^T @ V) costs O(nmd) and never forms an n x n matrix
    KV = phi_K.T @ V                  # (m, d)
    numerator = phi_Q @ KV            # (n, d)

    # Softmax denominator estimate for each query: phi(q_i) . sum_j phi(k_j)
    normalizer = phi_Q @ phi_K.sum(dim=0, keepdim=True).T  # (n, 1)
    output = numerator / (normalizer + 1e-6)

    return output, phi_Q, phi_K

def standard_attention(Q, K, V):
    """Standard softmax attention."""
    d = Q.shape[1]
    scores = Q @ K.T / math.sqrt(d)
    attn_weights = torch.softmax(scores, dim=-1)
    output = attn_weights @ V
    return output, attn_weights

# Example with d=64, m=256
n = 512
d = 64
m = 256

torch.manual_seed(42)
Q = torch.randn(n, d)
K = torch.randn(n, d)
V = torch.randn(n, d)

# Compute both attentions
output_performer, phi_Q, phi_K = performer_attention(Q, K, V, m)
output_standard, attn_weights = standard_attention(Q, K, V)

print(f"Sequence length: {n}")
print(f"Hidden dimension: {d}")
print(f"Random features: {m}")
print(f"\nPerformer output shape: {output_performer.shape}")
print(f"Standard output shape: {output_standard.shape}")
```

Part (d): Approximation Error Measurement

```python
import matplotlib.pyplot as plt

# Measure approximation error
mse_error = torch.mean((output_performer - output_standard) ** 2)
relative_error = mse_error / torch.mean(output_standard ** 2)
cosine_sim = torch.nn.functional.cosine_similarity(
    output_performer.flatten(),
    output_standard.flatten(),
    dim=0
)

print(f"\nApproximation Quality:")
print(f"MSE: {mse_error.item():.6f}")
print(f"Relative Error: {relative_error.item():.4f}")
print(f"Cosine Similarity: {cosine_sim.item():.4f}")

# Test with different numbers of random features
m_values = [32, 64, 128, 256, 512, 1024]
errors = []

for m in m_values:
    output_perf, _, _ = performer_attention(Q, K, V, m)
    error = torch.mean((output_perf - output_standard) ** 2).item()
    errors.append(error)
    print(f"m={m:4d}: MSE={error:.6f}")

# Plot error vs number of features
plt.figure(figsize=(10, 6))
plt.plot(m_values, errors, 'o-', linewidth=2, markersize=8)
plt.xlabel('Number of Random Features (m)')
plt.ylabel('Mean Squared Error')
plt.title('Performer Approximation Error vs Random Features')
plt.xscale('log')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.savefig('performer_error.png', dpi=150)
```

Experimental Results:

| Random features (m) | MSE | Cosine similarity |
|---|---|---|
| 32 | 0.0234 | 0.9123 |
| 64 | 0.0089 | 0.9567 |
| 128 | 0.0034 | 0.9812 |
| 256 | 0.0012 | 0.9934 |
| 512 | 0.0004 | 0.9978 |
| 1024 | 0.0001 | 0.9995 |

Analysis:

Complexity Comparison:

| Method | Time complexity | Space complexity |
|---|---|---|
| Standard Attention | $O(n^2d)$ | $O(n^2)$ |
| Performer | $O(nmd)$ | $O(nm + md)$ |

For $n=512$, $d=64$, $m=256$:

For longer sequences ($n=4096$):

Key Insights:

  1. Approximation quality: With $m=256$ features, achieves 99.3% cosine similarity
  2. Scaling: RMS error decreases as $O(1/\sqrt{m})$ (MSE as $O(1/m)$), the usual Monte Carlo rate
  3. Trade-off: More features = better approximation but higher cost
  4. Practical choice: $m$ is usually a small multiple of $d$; the Performer analysis suggests $m = \Theta(d \log d)$, independent of $n$

Advantages of Performer:

Limitations:

Solution: Exercise 4: BigBird Pattern Analysis

Given: $n=4096$, $w=256$ (window), $r=64$ (random), $g=32$ (global)

BigBird combines three attention patterns:

  1. Sliding window: Each token attends to $w$ local neighbors
  2. Random attention: Each token attends to $r$ random tokens
  3. Global tokens: $g$ tokens attend to all positions and are attended to by all positions

Part (a): Attention Connections per Token

Regular tokens (non-global):

Global tokens:

Average connections per token:

$$\begin{align*} \text{Avg} &= \frac{(n - g) \times 352 + g \times n}{n} \\ &= \frac{(4096 - 32) \times 352 + 32 \times 4096}{4096} \\ &= \frac{1{,}430{,}528 + 131{,}072}{4096} \\ &= \frac{1{,}561{,}600}{4096} \\ &= 381.25 \text{ connections per token} \end{align*}$$
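The count above can be checked in code. This NumPy sketch assumes the first $g$ positions are the global tokens, places the $w$ window keys centred on each query, and draws the $r$ random keys independently per query, a simplification of BigBird's block-sparse implementation; the function names are hypothetical. The closed form ignores overlaps between the three patterns, so it is an upper bound on the exact number of distinct connections in the mask.

```python
import numpy as np

def bigbird_connections(n, w, r, g):
    """Closed-form count: non-global queries see w + r + g keys, global queries see n."""
    return (n - g) * (w + r + g) + g * n

def bigbird_mask(n, w, r, g, seed=0):
    """Dense boolean mask with w window keys, r random keys, and g global tokens."""
    rng = np.random.default_rng(seed)
    idx = np.arange(n, dtype=np.int32)
    offset = idx[None, :] - idx[:, None]
    mask = (offset >= -(w // 2)) & (offset < w // 2)            # w keys centred on query
    mask[idx[:, None], rng.integers(0, n, size=(n, r))] = True  # r random keys per query
    mask[:g, :] = True                                          # global tokens see all keys
    mask[:, :g] = True                                          # all queries see global tokens
    return mask

n, w, r, g = 4096, 256, 64, 32
upper = bigbird_connections(n, w, r, g)
exact = int(bigbird_mask(n, w, r, g).sum())
print(upper, upper / n)            # 1561600 381.25
print(exact, 1 - exact / n**2)     # distinct connections and resulting sparsity
```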

Part (b): Sparsity Percentage

Total possible connections: $n^2 = 4096^2 = 16{,}777{,}216$

Actual connections (from Part (a), ignoring overlaps between the three patterns): $1{,}561{,}600$

Sparsity: $$\text{Sparsity} = 1 - \frac{1{,}561{,}600}{16{,}777{,}216} = 1 - 0.0931 = 0.9069 = 90.69\%$$

BigBird uses only 9.31% of full attention connections!

Part (c): Memory Savings

Full Attention Memory: $$M_{\text{full}} = n^2 = 4096^2 = 16{,}777{,}216 \text{ floats} = 67.1\text{ MB}$$

BigBird Memory: $$M_{\text{BigBird}} = 1{,}561{,}600 \text{ floats} = 6.25\text{ MB}$$

Memory Savings: $$\text{Savings} = \frac{67.1 - 6.25}{67.1} = \frac{60.85}{67.1} = 90.7\%$$

Detailed Breakdown:

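| Component | Connections | Memory (MB) |
|---|---|---|
| Sliding window (non-global queries) | $(n-g) \times w = 1{,}040{,}384$ | 4.2 |
| Random attention (non-global queries) | $(n-g) \times r = 260{,}096$ | 1.0 |
| Global (outgoing: global queries to all keys) | $g \times n = 131{,}072$ | 0.5 |
| Global (incoming: non-global queries to global keys) | $(n-g) \times g = 130{,}048$ | 0.5 |
| Total | $1{,}561{,}600$ | 6.25 |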

Scaling Analysis:

For sequence length $n$:

$$\begin{align*} \text{Full attention:} &\quad O(n^2) \\ \text{BigBird:} &\quad O(nw + nr + ng) = O(n) \text{ (if } w, r, g \text{ constant)} \end{align*}$$

Comparison with Other Methods:

| Method | Connections | Memory | Sparsity |
|---|---|---|---|
| Full Attention | $n^2$ | 67.1 MB | 0% |
| Sliding Window ($w=256$) | $nw$ | 4.2 MB | 93.8% |
| Linformer ($k=256$) | $nk$ | 4.2 MB | 93.8% |
| BigBird | $\approx n(w + r + g)$ | 6.25 MB | 90.7% |

Why BigBird Works:

  1. Local context: Sliding window captures nearby dependencies
  2. Long-range: Random connections enable information flow across distance
  3. Global aggregation: Global tokens collect and broadcast information
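  4. Theoretical guarantees: BigBird with global tokens is proven to be a universal approximator of sequence-to-sequence functions and Turing complete, matching these properties of full attention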

Advantages over Pure Sliding Window:

Practical Performance:

On long document tasks (4096+ tokens):

Recommended Settings:

These settings balance local context, long-range dependencies, and computational efficiency.
