Long Context Transformers

Chapter Overview

Extending transformer context length beyond standard limits (512-2048 tokens) enables processing long documents, books, and extended conversations. This chapter covers techniques for scaling to 32K, 100K, and even 1M+ token contexts.

Learning Objectives

  1. Understand context length limitations and bottlenecks
  2. Implement position interpolation and extrapolation
  3. Apply memory-augmented transformers
  4. Use retrieval-augmented generation (RAG)
  5. Implement recurrent transformers (Transformer-XL)
  6. Compare long-context methods and trade-offs

Context Length Limitations

The Quadratic Memory Bottleneck

Standard transformer architectures face fundamental limitations when processing long sequences due to the quadratic scaling of self-attention with respect to sequence length. The self-attention mechanism computes pairwise interactions between all tokens in a sequence, requiring the materialization of an attention matrix of size $n \times n$ where $n$ is the sequence length. This quadratic scaling produces two critical bottlenecks, computational complexity and memory consumption. The position encodings used by standard models add a third: a limited ability to handle positions beyond the training length. Understanding these bottlenecks quantitatively is essential for appreciating why long context processing requires specialized techniques and architectural modifications.

The computational complexity of self-attention is $O(n^2 d)$ where $d$ is the model dimension. For each of the $n$ queries, the model computes attention scores with all $n$ keys through dot products of dimension $d$, requiring $n^2 d$ multiply-accumulate operations. The subsequent softmax normalization adds $O(n^2)$ operations, and the weighted sum over values requires another $n^2 d$ operations. The feed-forward network has complexity $O(n d^2)$, so attention dominates once $n$ is large relative to $d$. For example, with $n = 16384$ tokens and $d = 768$, the two $n^2 d$ products require approximately 412 billion multiply-accumulates per layer. A feed-forward network with the conventional hidden size of $4d$ requires $8nd^2 \approx 77$ billion, which makes attention the primary computational bottleneck.
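The operation counts in this example can be reproduced with a short calculation. This sketch makes two simplifying choices of our own. First, it counts multiply-accumulates (MACs) only, ignoring the softmax and the Q/K/V/output projections. Second, it assumes the conventional feed-forward hidden size of $4d$.

```python
def attention_macs(n: int, d: int) -> int:
    """MACs for the scores Q K^T (n*n*d) plus the weighted sum over V (n*n*d)."""
    return 2 * n * n * d


def ffn_macs(n: int, d: int, expansion: int = 4) -> int:
    """MACs for the two FFN matmuls, d -> expansion*d -> d (expansion of 4 assumed)."""
    return 2 * n * d * (expansion * d)


n, d = 16384, 768
attn, ffn = attention_macs(n, d), ffn_macs(n, d)
print(f"attention: {attn / 1e9:.0f}B MACs, FFN: {ffn / 1e9:.0f}B MACs, ratio {attn / ffn:.1f}x")
# prints: attention: 412B MACs, FFN: 77B MACs, ratio 5.3x
```

Under these assumptions the ratio is $n/(4d)$. For $d = 768$, attention therefore overtakes the feed-forward network at $n = 4d = 3072$.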

The memory bottleneck is even more severe than the computational one. During the forward pass of a standard implementation, the attention matrix must be fully materialized in memory before applying softmax, requiring $n^2$ memory locations per attention head. During training, these attention matrices must also be retained until the backward pass, where they are needed for gradient computation. For multi-head attention with $h$ heads, the total memory for attention matrices is $h \times n^2$ floating-point values per layer. In FP32 format, each value requires 4 bytes, while FP16 requires 2 bytes. This memory grows quadratically with sequence length, quickly exceeding available GPU memory for long sequences.

Example: Consider a GPT-2 scale model with $d = 768$, $h = 12$ attention heads, and $L = 12$ layers. The memory required for attention matrices scales dramatically with sequence length. For a single attention head processing a sequence of length $n$, the attention matrix requires $n^2 \times 4$ bytes in FP32 format. With 12 heads per layer, this becomes $12 \times n^2 \times 4 = 48n^2$ bytes per layer.

At $n = 1024$ tokens (GPT-2's standard context), each layer requires $48 \times 1024^2 = 50.3$ MB for attention matrices. Across 12 layers, this totals 604 MB, which is manageable on modern GPUs. However, doubling the sequence length to $n = 2048$ quadruples the memory requirement to 201 MB per layer or 2.4 GB total. That is a 4× increase for only a 2× increase in sequence length. This quadratic scaling continues. At $n = 4096$, attention matrices consume 805 MB per layer or 9.7 GB total, more than half of a 16 GB GPU. At $n = 8192$, the requirement explodes to 3.2 GB per layer or 38.7 GB total. That nearly exhausts the 40 GB of an NVIDIA A100 before model parameters or any other activations are counted.

For larger models, the situation becomes even more challenging. Consider a GPT-3 scale model with $d = 12288$, $h = 96$ heads, and $L = 96$ layers. At $n = 2048$ tokens, each layer requires $96 \times 2048^2 \times 4 = 1.6$ GB for attention matrices, totaling 154 GB across all layers. At $n = 8192$ tokens, each layer requires 25.8 GB, totaling 2.5 TB across the model—far exceeding any single GPU's capacity and requiring extensive model parallelism even for moderate context lengths. At $n = 32768$ tokens, a single layer would require 412 GB just for attention matrices, making standard attention completely impractical without fundamental algorithmic changes.
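The per-layer and total figures above all follow from one formula: $h \times n^2 \times$ (bytes per value) per layer. Below is a small calculator sketch. It uses decimal megabytes and gigabytes, as the text does, and counts only forward-pass attention matrices.

```python
def attention_matrix_bytes(n: int, heads: int, layers: int = 1, bytes_per_value: int = 4) -> int:
    """Bytes to store every head's n x n attention matrix across `layers` layers (FP32 default)."""
    return heads * n * n * bytes_per_value * layers


for n in (1024, 2048, 4096, 8192):
    per_layer = attention_matrix_bytes(n, heads=12)
    total = attention_matrix_bytes(n, heads=12, layers=12)
    print(f"GPT-2 scale, n={n:5d}: {per_layer / 1e6:7.1f} MB/layer, {total / 1e9:6.2f} GB total")
# values: 50.3 MB / 0.60 GB, 201.3 / 2.42, 805.3 / 9.66, 3221.2 / 38.65

# GPT-3 scale (h = 96) at n = 32768, for a single layer: 412.3 GB.
print(f"GPT-3 scale, n=32768: {attention_matrix_bytes(32768, heads=96) / 1e9:.1f} GB/layer")
```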

These calculations assume only the forward pass attention matrices. During training, gradients with respect to attention matrices must also be stored, effectively doubling the memory requirement. Additionally, activations from other layers, model parameters, optimizer states, and batch processing multiply these requirements further. For a batch size of 8 with $n = 4096$ tokens on GPT-2, attention matrices alone would require $9.7 \times 8 = 77.6$ GB, making training impossible on standard hardware without techniques like gradient checkpointing, which trades computation for memory by recomputing activations during the backward pass.

The third fundamental limitation involves position encodings. Standard transformers use position encodings trained on sequences of a fixed maximum length, typically 512 to 2048 tokens. When these models encounter sequences longer than their training length, the position encodings must extrapolate to unseen positions. Absolute position embeddings, which assign a learned vector to each position index, cannot extrapolate at all—positions beyond the training length have no corresponding embedding. Even sinusoidal position encodings, which use deterministic trigonometric functions, exhibit degraded performance when extrapolating beyond training lengths due to the model's learned attention patterns being calibrated to the training distribution of position encodings.

This extrapolation failure manifests as rapidly degrading perplexity for tokens beyond the training context length. A model trained on 2048-token sequences might achieve perplexity of 15 on positions 0-2048, but perplexity can increase to 25 or higher for positions 2048-4096 without specialized position encoding schemes. This degradation occurs because the model's attention patterns have learned to interpret specific position encoding values as corresponding to specific relative distances, and these learned patterns break down when position encodings take on values outside the training distribution.

Position Encoding for Long Context

The Extrapolation Challenge

Position encodings enable transformers to incorporate sequential order information into their otherwise permutation-invariant architecture. However, different position encoding schemes exhibit dramatically different behaviors when processing sequences longer than those seen during training. This extrapolation capability is critical for long context applications, where retraining on maximum-length sequences is often computationally prohibitive. The choice of position encoding scheme can determine whether a model trained on 2048-token sequences can successfully process 8192-token sequences with minimal fine-tuning, or whether it requires extensive retraining from scratch.

Absolute position embeddings assign a learned vector to each position index, with the position encoding for position $i$ being a trainable parameter $\vp_i \in \R^d$. These embeddings are added to token embeddings before the first transformer layer. While simple and effective within the training length, absolute embeddings cannot extrapolate beyond the maximum training position. A model trained with positions 0 through 2047 has no learned embedding for position 2048 or beyond. Attempting to extend such a model requires either initializing new position embeddings (which perform poorly without extensive fine-tuning) or using position interpolation techniques to map longer sequences into the trained position range.

Sinusoidal position encodings, introduced in the original Transformer paper, use deterministic trigonometric functions rather than learned parameters. For position $i$ and dimension $j$, the encoding is defined as:

$$ \text{PE}(i, 2j) = \sin(i / 10000^{2j/d}), \quad \text{PE}(i, 2j+1) = \cos(i / 10000^{2j/d}) $$
These encodings can be computed for any position without training, enabling extrapolation in principle. However, in practice, models trained with sinusoidal encodings still exhibit degraded performance on longer sequences because the attention patterns learned during training are calibrated to the distribution of position encodings seen during training. When positions extend beyond the training range, the attention patterns encounter position encoding values in unfamiliar ranges, leading to suboptimal attention distributions.
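Because the encoding is a closed-form function of position, it can be evaluated at positions never seen in training, and even at fractional positions. Here is a minimal NumPy sketch of the formula above; the function name is ours.

```python
import numpy as np


def sinusoidal_encoding(positions: np.ndarray, d: int) -> np.ndarray:
    """PE(i, 2j) = sin(i / 10000^(2j/d)) and PE(i, 2j+1) = cos(i / 10000^(2j/d)).

    `positions` may hold any non-negative values, including ones beyond the
    training length or fractional ones.
    """
    if d % 2:
        raise ValueError("d must be even")
    inv_freq = 10000.0 ** (-2.0 * np.arange(d // 2) / d)   # 1 / 10000^(2j/d), shape (d/2,)
    angles = np.outer(positions, inv_freq)                  # (len(positions), d/2)
    pe = np.empty((len(positions), d))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


pe = sinusoidal_encoding(np.arange(4096), d=64)   # positions 2048-4095 are computable too
```

Nothing prevents evaluating positions 2048 through 4095 for a model trained on 2048 tokens. The difficulty described above is that the model's attention has never seen those encoding values, not that the values cannot be computed.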

Position Interpolation

Position interpolation addresses the extrapolation problem by mapping longer sequences into the position range seen during training, rather than extending beyond it. Instead of asking the model to extrapolate to unseen position indices, interpolation compresses the position indices of a long sequence into the trained range, effectively treating the long sequence as a "compressed" version of a training-length sequence.

Definition: To extend a model trained on maximum length $L$ to process sequences of length $L' > L$, position interpolation maps each position $i \in \{0, 1, \ldots, L'-1\}$ to a fractional position in the training range:
$$ i_{\text{interpolated}} = i \cdot \frac{L}{L'} $$

For absolute position embeddings, the new position encoding is computed by interpolating between the learned embeddings:

$$ \text{PE}_{\text{new}}(i) = \text{interpolate}(\text{PE}_{\text{original}}, i \cdot L/L') $$

For sinusoidal or rotary encodings, the interpolated position is used directly in the encoding formula, effectively reducing the frequency of the trigonometric functions.

The key insight behind position interpolation is that it keeps position encodings within the distribution seen during training, avoiding the extrapolation problem entirely. For example, extending from $L = 2048$ to $L' = 8192$ maps position 8191 to interpolated position $8191 \times 2048/8192 = 2047.75$, which falls within the training range. The model's attention patterns, having been trained on positions 0 through 2047, can handle this interpolated position much more effectively than they could handle the raw position 8191.
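For a learned absolute-embedding table, the `interpolate` step in the definition needs a concrete rule. One simple choice, assumed here, is linear interpolation between the two neighboring learned rows. The sketch below maps $L' = 8192$ positions onto a table trained for $L = 2048$. The function names are hypothetical, and a random table stands in for trained embeddings.

```python
import numpy as np


def interpolated_positions(new_len: int, train_len: int) -> np.ndarray:
    """i_interpolated = i * L / L' for i = 0, ..., L' - 1."""
    return np.arange(new_len) * (train_len / new_len)


def interpolate_embeddings(table: np.ndarray, new_len: int) -> np.ndarray:
    """Linearly interpolate a learned (L, d) position table at the fractional positions.

    Fractional position p blends rows floor(p) and floor(p) + 1 with weights
    (1 - frac, frac); the upper neighbor is clamped to the last trained row.
    """
    train_len = table.shape[0]
    pos = interpolated_positions(new_len, train_len)
    lower = np.floor(pos).astype(int)
    upper = np.minimum(lower + 1, train_len - 1)
    frac = (pos - lower)[:, None]
    return (1.0 - frac) * table[lower] + frac * table[upper]


rng = np.random.default_rng(0)
learned = rng.normal(size=(2048, 16))            # stand-in for a trained 2048-position table
extended = interpolate_embeddings(learned, 8192)
print(interpolated_positions(8192, 2048)[8191])  # 2047.75
```

Every fourth extended position lands exactly on a trained row, and the positions in between are blends of their neighbors. As a result, no embedding falls outside what the model saw in training. Sinusoidal or rotary encodings need no table: the fractional positions are simply fed into the closed-form formula.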

Position interpolation has been successfully applied to extend LLaMA models from 2048 to 8192 tokens and beyond. The technique requires minimal fine-tuning—typically only 1000 to 10000 training steps on long sequences—compared to training from scratch. After fine-tuning with position interpolation, LLaMA 2 models maintain perplexity within 5-10\% of their original performance when extended from 4096 to 32768 tokens, whereas naive extrapolation without interpolation results in perplexity degradation of 50\% or more.

The computational cost of position interpolation is negligible, as it only affects the position encoding computation, not the attention mechanism itself. The primary cost is the fine-tuning required to adapt the model to the compressed position space. However, this fine-tuning is far less expensive than training from scratch: extending a 7B parameter model from 4K to 32K context requires approximately 100 GPU-hours of fine-tuning compared to 100,000+ GPU-hours for full pretraining.

Rotary Position Embedding (RoPE)

Rotary Position Embedding represents a fundamental advance in position encoding design, encoding relative position information directly into the attention computation through rotation operations. RoPE has become the position encoding of choice for modern large language models, including GPT-NeoX, LLaMA, PaLM, and many others. Its appeal lies in its combination of relative position modeling, computational efficiency, theoretical elegance, and compatibility with position interpolation for extending context length.

Definition: RoPE applies position-dependent rotations to query and key vectors before computing attention. For a query vector $\vq_m$ at position $m$ and key vector $\vk_n$ at position $n$, RoPE applies rotation matrices:
$$\begin{align} \vq_m' &= \mR_m \vq_m \\ \vk_n' &= \mR_n \vk_n \end{align}$$

where $\mR_m \in \R^{d \times d}$ is a block-diagonal rotation matrix. For dimension pairs $(2j, 2j+1)$, the rotation is:

$$ \mR_m^{(j)} = \begin{bmatrix} \cos(m\theta_j) & -\sin(m\theta_j) \\ \sin(m\theta_j) & \cos(m\theta_j) \end{bmatrix}, \quad \theta_j = 10000^{-2j/d} $$

The full rotation matrix is block-diagonal with $d/2$ such 2D rotation blocks.

The crucial property of RoPE is that the attention score between positions $m$ and $n$ depends only on their relative distance $m - n$, not their absolute positions. This can be verified through the rotation addition formula:

$$ (\vq_m')\transpose \vk_n' = (\mR_m \vq_m)\transpose (\mR_n \vk_n) = \vq_m\transpose \mR_m\transpose \mR_n \vk_n = \vq_m\transpose \mR_{n-m} \vk_n $$

This relative position property means that the model learns attention patterns based on relative distances between tokens rather than absolute positions. A model trained on sequences up to 2048 tokens has seen all relative distances from -2047 to +2047. When processing a 4096-token sequence, it encounters those distances again, but it also encounters distances of up to 4095 that never occurred during training. Unlike absolute position embeddings, RoPE can at least be evaluated at any distance. On its own, however, its quality still drops sharply once sequences extend well beyond the training length, which is why it is usually paired with position interpolation.

RoPE's usable context can be extended through position interpolation. When extending from length $L$ to $L'$, the rotation frequencies $\theta_j$ are multiplied by $L/L'$ (equivalently, every position index is multiplied by $L/L'$), which compresses the effective relative distances into the training range. This combination of RoPE's inherent relative position encoding with position interpolation enables extensions from 2048 to 32768 tokens or beyond with minimal quality degradation.
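Both the relative-position property and the interpolation trick can be checked numerically. This NumPy sketch applies the block-diagonal rotation pair by pair rather than building $\mR_m$ explicitly, which is where the $O(nd)$ cost comes from. The `scale` argument multiplies every position by $L/L'$, which is equivalent to scaling each $\theta_j$ by that factor. Function names are ours.

```python
import numpy as np


def rope(x: np.ndarray, positions: np.ndarray, scale: float = 1.0) -> np.ndarray:
    """Rotate each pair (x[2j], x[2j+1]) by the angle (position * scale) * theta_j.

    x has shape (n, d) with d even; theta_j = 10000^(-2j/d). scale = L / L'
    gives position interpolation, scale = 1.0 is plain RoPE.
    """
    _, d = x.shape
    theta = 10000.0 ** (-2.0 * np.arange(d // 2) / d)
    angles = np.outer(positions * scale, theta)          # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin            # first row of each 2x2 block
    out[:, 1::2] = x_even * sin + x_odd * cos            # second row of each 2x2 block
    return out


rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 64)), rng.normal(size=(1, 64))


def score(m: float, n: float, scale: float = 1.0) -> float:
    """Attention logit between a query at position m and a key at position n."""
    q_rot = rope(q, np.array([m], dtype=float), scale)
    k_rot = rope(k, np.array([n], dtype=float), scale)
    return (q_rot @ k_rot.T).item()


print(score(10, 3), score(110, 103))                          # same value up to rounding
print(score(8191, 0, scale=2048 / 8192), score(2047.75, 0))   # same value: 8191 maps to 2047.75
```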

The computational overhead of RoPE is minimal compared to the attention computation itself. Applying rotations to queries and keys requires $O(nd)$ operations, which is negligible compared to the $O(n^2d)$ cost of attention. The rotation operations can be efficiently implemented using vectorized operations on modern GPUs, adding less than 5\% to the total attention computation time. Memory overhead is also minimal, as the rotation matrices are block-diagonal and can be computed on-the-fly rather than stored.

In practice, RoPE has enabled dramatic context length extensions. LLaMA models using RoPE have been successfully extended from 2048 to 32768 tokens through position interpolation with only 1000 fine-tuning steps. The perplexity degradation is typically less than 10\% even at 16× the training length, compared to 50-100\% degradation for absolute position embeddings. This extrapolation capability has made RoPE the de facto standard for new large language models designed for long context applications.

ALiBi: Attention with Linear Biases

ALiBi (Attention with Linear Biases) takes a radically different approach to position encoding by eliminating position embeddings entirely and instead adding position-dependent biases directly to attention scores. This simple modification achieves remarkable extrapolation properties, enabling models trained on 1024-token sequences to process 10,000+ token sequences at inference time with no fine-tuning whatsoever.

Definition: ALiBi adds a bias term to attention scores based on the distance between query and key positions:
$$ \text{score}(q_i, k_j) = \frac{\vq_i\transpose \vk_j}{\sqrt{d_k}} - m \cdot |i - j| $$

where $m > 0$ is a head-specific slope parameter that differs across attention heads. The bias $-m \cdot |i - j|$ penalizes attention to distant tokens, with the penalty increasing linearly with distance.

The head-specific slopes are set geometrically: for $h$ attention heads, the slopes are $m_1, m_2, \ldots, m_h$ where $m_i = 2^{-8i/h}$. For example, with 8 heads, the slopes are $2^{-1}, 2^{-2}, \ldots, 2^{-8}$, giving values from 0.5 to 0.0039. This geometric spacing ensures that different heads have different receptive field sizes: heads with large slopes focus on nearby tokens, while heads with small slopes can attend to distant tokens with less penalty.
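The slopes and the bias matrix take only a few lines. This sketch follows the symmetric $|i-j|$ form given above; a causal language model would additionally mask out keys with $j > i$. Function names are ours.

```python
import numpy as np


def alibi_slopes(num_heads: int) -> np.ndarray:
    """Geometric slopes m_i = 2^(-8 i / h) for heads i = 1, ..., h."""
    i = np.arange(1, num_heads + 1)
    return 2.0 ** (-8.0 * i / num_heads)


def alibi_bias(n: int, num_heads: int) -> np.ndarray:
    """Bias -m * |i - j| for every head; shape (h, n, n)."""
    pos = np.arange(n)
    distance = np.abs(pos[:, None] - pos[None, :])      # one (n, n) matrix shared by all heads
    return -alibi_slopes(num_heads)[:, None, None] * distance


def alibi_scores(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Pre-softmax scores q_i.k_j / sqrt(d_k) - m |i - j|; q and k have shape (h, n, d_k)."""
    h, n, d_k = q.shape
    content = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n)
    return content + alibi_bias(n, h)


print(alibi_slopes(8))   # slopes 2^-1 through 2^-8
```

The head with slope $2^{-8}$ loses only about 0.004 per token of distance, so it can still attend to tokens hundreds of positions away. The $2^{-1}$ head, by contrast, is effectively local.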

ALiBi's extrapolation capability stems from its use of relative distances rather than absolute positions, combined with the linear form of the bias. During training on sequences up to length $L$, the model encounters biases ranging from 0 (attending to the same position) to $-m \cdot L$ (attending to the most distant position). When extrapolating to length $L' > L$, the model encounters biases up to $-m \cdot L'$, which are simply larger values along the same linear scale. The attention patterns learned during training—which balance content-based attention (from $\vq_i\transpose \vk_j$) against distance-based penalties (from $-m \cdot |i-j|$)—continue to work effectively at these larger distances.

Empirical results demonstrate ALiBi's exceptional extrapolation. Models trained on 1024-token sequences with ALiBi can process 2048-token sequences with less than 5\% perplexity increase, 4096-token sequences with 10-15\% increase, and even 10,000-token sequences with 20-30\% increase—all without any fine-tuning. In contrast, the same models with sinusoidal position encodings show 50\% perplexity increase at 2048 tokens and become essentially non-functional beyond 4096 tokens. This zero-shot extrapolation capability makes ALiBi particularly attractive for applications where the maximum sequence length is unknown at training time or varies widely across use cases.

ALiBi has been adopted by several prominent models including BLOOM (176B parameters) and MPT (7B-30B parameters). BLOOM was trained with ALiBi on sequences up to 2048 tokens but can effectively process sequences of 4096 tokens or longer at inference time. MPT models trained with ALiBi on 2048-token sequences have been successfully deployed on tasks requiring 8192-token contexts with minimal quality degradation.

The computational overhead of ALiBi is negligible. Computing the bias $-m \cdot |i-j|$ for all $n^2$ attention scores requires $O(n^2)$ operations, which is dominated by the $O(n^2d)$ cost of computing $\mQ\mK\transpose$. The distance matrix $|i-j|$ can be computed once per sequence and shared across all layers and heads, with each head scaling it by its own slope $m$, further reducing overhead. Memory overhead is also minimal, as the bias matrix can be computed on-the-fly or stored once and reused.

The primary limitation of ALiBi is that it assumes a monotonic relationship between distance and relevance—more distant tokens are always penalized more heavily. This assumption holds well for many natural language tasks where local context is indeed more important than distant context. However, for tasks with long-range dependencies that are not distance-dependent (such as matching opening and closing brackets in code, or resolving coreferences across document sections), ALiBi's linear bias may be suboptimal compared to learned position encodings that can capture more complex position-dependent patterns.

Efficient Attention for Long Context

Sparse Attention Patterns

While position encoding improvements enable better extrapolation, they do not address the fundamental quadratic scaling of attention computation and memory. Efficient attention mechanisms reduce this quadratic bottleneck by restricting which tokens can attend to which other tokens, computing attention only over a subset of the $n^2$ possible connections. These sparse attention patterns can reduce complexity from $O(n^2)$ to $O(n \times w)$ where $w$ is a fixed window size, enabling processing of sequences that would be impossible with full attention.

The key insight behind sparse attention is that not all token pairs require attention computation. In many domains, particularly natural language, most relevant information comes from nearby tokens, with occasional long-range dependencies. By carefully designing sparsity patterns that preserve important connections while eliminating redundant ones, sparse attention can maintain model quality while dramatically reducing computational and memory requirements.
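A causal sliding window is the simplest such pattern. The sketch below builds the mask and applies masked softmax attention. The naming and the choice of a causal window of $w$ keys per query (including the query itself) are our assumptions. For clarity it still forms a dense $n \times n$ score matrix, so it shows the connectivity rather than the savings. An efficient kernel would compute only the $O(nw)$ entries inside the band.

```python
import numpy as np


def local_attention_mask(n: int, w: int) -> np.ndarray:
    """Causal sliding window: query i may attend to keys j with i - w < j <= i."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j > i - w)


def masked_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax attention restricted to the True entries of mask; q, k, v have shape (n, d)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)          # every row keeps its diagonal entry
    weights = np.exp(scores)                              # masked entries become exactly 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


n, w, d = 16, 4, 8
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
mask = local_attention_mask(n, w)
out = masked_attention(q, k, v, mask)
print(mask.sum(), "of", n * n, "pairs computed")   # 58 of 256 pairs computed
```

The allowed pairs number $nw - w(w-1)/2$, which is linear in $n$ for a fixed $w$. A token more than $w - 1$ positions back has no direct influence on a query, so long-range information must travel through additional layers.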

Longformer and BigBird

Longformer combines local windowed attention ($O(nw)$ for window size $w$) with global attention on task-specific tokens, enabling efficient processing of documents up to 4096+ tokens. BigBird extends this with random connections that create small-world shortcuts. Its authors prove that sparse attention with global tokens and only $O(n)$ connections remains a universal approximator of sequence-to-sequence functions. Both reduce attention memory by 6--8$\times$ compared to full attention. See Chapter~[ref] for the detailed mechanism definitions, complexity analysis, and benchmark results.

For long-context processing, the key trade-off is between sparsity pattern and information propagation depth. With window size $w$, information requires $\lceil n/w \rceil$ layers to propagate across the full sequence. BigBird's random connections reduce the expected propagation path to $O(\log n)$ layers, which is particularly valuable for tasks requiring long-range reasoning.

Comparison of Sparse Attention Methods

Different sparse attention patterns offer different trade-offs between efficiency, model quality, and implementation complexity. Local attention with window size $w$ provides the simplest pattern and best memory locality, achieving $O(nw)$ complexity with straightforward implementation. However, information propagation is limited to $w$ positions per layer, requiring deep networks for long-range dependencies. Longformer's addition of global tokens addresses this limitation by providing information hubs, enabling faster propagation while maintaining linear complexity. BigBird's random connections provide theoretical guarantees and empirically strong performance, but at the cost of irregular memory access patterns that reduce hardware efficiency.

For sequences up to 2048 tokens, the overhead of sparse attention often outweighs its benefits: full attention with optimized implementations like Flash Attention is typically faster and simpler. For sequences of 2048-8192 tokens, sparse attention becomes beneficial, with Longformer and BigBird providing good trade-offs between efficiency and quality. For sequences beyond 8192 tokens, sparse attention becomes essential, as the quadratic cost of full attention becomes prohibitive on most hardware. At these lengths, the choice between Longformer and BigBird depends on the task. Longformer is simpler and faster for tasks where local context dominates, while BigBird provides better quality for tasks requiring complex long-range reasoning.

Recurrent Transformers

Transformer-XL

Definition: Split a long sequence into fixed-length segments and reuse the hidden states computed for the previous segment as additional context:

Segment $n$: Tokens $[s_n, s_n+1, \ldots, s_n+L-1]$

Compute:

$$ \vh_n = \text{Transformer}([\text{stop\_grad}(\vh_{n-1}), \vx_n]) $$

Previous segment hidden states provide additional context without recomputation!

Example: Segment length: $L = 512$

Segment 1: Process tokens $0$-$511$

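Segment 2: Process tokens $512$-$1023$, attending to cached states for tokens $0$-$511$

Segment 3: Process tokens $1024$-$1535$, attending to cached states for tokens $512$-$1023$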

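Computation per segment stays constant, but the reachable context grows with depth. Each layer attends to cached states from the layer below, so information moves back one segment per layer. This gives a maximum dependency length of roughly $N \times L$ for $N$ layers.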

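Relative position encodings: absolute position encodings would assign the same positions to tokens in different segments, so Transformer-XL replaces them with relative position encodings adapted for segment-level recurrence.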
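The recurrence can be sketched in NumPy as a simplified illustration. The sketch assumes attention-only layers with residual connections, a single head, no feed-forward sublayers, and no position encodings at all, so the relative encodings mentioned above are omitted. Each layer caches its own input for the next segment, matching the layer-wise form of the recurrence. Because NumPy has no autograd, `stop_grad` reduces to keeping a copy. Function names are ours.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def xl_attention(h, memory, w_q, w_k, w_v):
    """Queries from the current segment h (L, d); keys and values from [memory; h]."""
    context = np.concatenate([memory, h], axis=0)
    q, k, v = h @ w_q, context @ w_k, context @ w_v
    m_len, seg_len = memory.shape[0], h.shape[0]
    # Query t sees every cached position plus current positions 0..t (causal).
    allowed = np.arange(m_len + seg_len)[None, :] <= (m_len + np.arange(seg_len))[:, None]
    scores = np.where(allowed, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    return softmax(scores) @ v


def run_segments(emb, seg_len, layers):
    """Process emb (T, d) one segment at a time, caching each layer's input as its memory."""
    memories = [np.zeros((0, emb.shape[1])) for _ in layers]
    outputs = []
    for start in range(0, emb.shape[0], seg_len):
        h = emb[start:start + seg_len]
        new_memories = []
        for memory, (w_q, w_k, w_v) in zip(memories, layers):
            # stop_grad(h): with no autograd in NumPy a copy is enough; in an
            # autodiff framework this is where the gradient would be cut.
            new_memories.append(h.copy())
            h = h + xl_attention(h, memory, w_q, w_k, w_v)   # residual connection
        memories = new_memories
        outputs.append(h)
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
d, seg_len = 8, 4
emb = rng.normal(size=(12, d))                    # three segments of four tokens
one_layer = [tuple(0.3 * rng.normal(size=(d, d)) for _ in range(3))]
two_layers = one_layer + [tuple(0.3 * rng.normal(size=(d, d)) for _ in range(3))]
perturbed = emb.copy()
perturbed[0] += 1.0                               # change token 0 only
```

With one layer, changing token 0 alters the outputs for segment 2 but leaves segment 3 untouched. With two layers, the change reaches segment 3 as well, which shows the dependency length growing with depth.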

Retrieval-Augmented Generation

RAG Architecture

Definition: Combine retrieval with generation:

Step 1: Retrieval

$$ \text{docs} = \text{Retrieve}(\text{query}, \text{corpus}, k=5) $$

Step 2: Concatenate

$$ \text{input} = [\text{docs}_1, \ldots, \text{docs}_k, \text{query}] $$

Step 3: Generate

$$ \text{output} = \text{LM}(\text{input}) $$

Retrieval methods:

Example: Question: "When was the Eiffel Tower built?"

Step 1: Retrieve (from Wikipedia)

  1. "The Eiffel Tower was constructed from 1887 to 1889..."
  2. "Gustave Eiffel designed the tower for the 1889 World's Fair..."
  3. "The tower is 330 meters tall and was the tallest..."

Step 2: Concatenate

Context 1: The Eiffel Tower was constructed from 1887 to 1889...
Context 2: Gustave Eiffel designed the tower for the 1889 World's Fair...
Context 3: The tower is 330 meters tall and was the tallest...
Question: When was the Eiffel Tower built?
Answer:

Step 3: Generate "The Eiffel Tower was built from 1887 to 1889."
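A toy version of the first two steps is sketched below, under several assumptions. The bag-of-words cosine retriever is a stand-in chosen so the example is self-contained; it is not BM25 or a dense retriever. The distractor passage is made up, and the generation step is left out. All function names are hypothetical. With this crude scorer, the second and third passages may be ranked in a different order than in the listing above.

```python
import math
import re
from collections import Counter


def bag_of_words(text: str) -> Counter:
    """Lower-cased alphanumeric token counts."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(count * b[token] for token, count in a.items())
    norm = math.sqrt(sum(c * c for c in a.values())) * math.sqrt(sum(c * c for c in b.values()))
    return dot / norm if norm else 0.0


def retrieve(query: str, corpus: list, k: int = 5) -> list:
    """Step 1: rank documents by bag-of-words cosine similarity to the query."""
    q = bag_of_words(query)
    return sorted(corpus, key=lambda doc: cosine(q, bag_of_words(doc)), reverse=True)[:k]


def build_prompt(query: str, docs: list) -> str:
    """Step 2: concatenate retrieved passages and the question in the format shown above."""
    lines = [f"Context {i}: {doc}" for i, doc in enumerate(docs, start=1)]
    lines += [f"Question: {query}", "Answer:"]
    return "\n".join(lines)


corpus = [
    "The Eiffel Tower was constructed from 1887 to 1889...",
    "Gustave Eiffel designed the tower for the 1889 World's Fair...",
    "The tower is 330 meters tall and was the tallest...",
    "The Great Wall of China stretches across northern China.",   # hypothetical distractor
]
query = "When was the Eiffel Tower built?"
docs = retrieve(query, corpus, k=3)
prompt = build_prompt(query, docs)
# Step 3 would pass `prompt` to a language model; that call is omitted here.
```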

Advantages:

RETRO: Retrieval-Enhanced Transformer

Architecture:

Performance: with retrieval, RETRO matches the performance of models without retrieval that have roughly 25× more parameters.

Memory-Augmented Transformers

Compressive Transformer

Definition: Extend Transformer-XL with compression:

Three levels of memory:

  1. Active: Current segment (full attention)
  2. Recent: Last $n_m$ segments (cached, full precision)
  3. Compressed: Older segments (compressed representations)

Compression:

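Effective context: Active + Recent + Compressed. With compression rate $c$ (each compressed slot summarizes $c$ tokens) and $n_c$ compressed segments, each query attends over

$$ L_{\text{eff}} = L + n_m \cdot L + n_c \cdot (L/c) $$

positions, which together cover $L + n_m L + n_c L$ tokens of the sequence.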
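Below is a sketch of one possible compression function, mean pooling over blocks of $c$ states, together with the context arithmetic. The Compressive Transformer work also considers alternatives such as convolutions. The sketch assumes $c$ divides $L$, and the example values $L = 512$, $n_m = 1$, $n_c = 4$, $c = 4$ are hypothetical. Function names are ours.

```python
import numpy as np


def compress_mean_pool(memory: np.ndarray, c: int) -> np.ndarray:
    """Compress (T, d) memory to (T // c, d) by averaging each block of c consecutive states.

    Any ragged tail shorter than c is dropped for simplicity.
    """
    t, d = memory.shape
    blocks = t // c
    return memory[: blocks * c].reshape(blocks, c, d).mean(axis=1)


def context_sizes(L: int, n_m: int, n_c: int, c: int) -> dict:
    """Positions attended per query vs. tokens of history those positions cover."""
    attended = L + n_m * L + n_c * (L // c)
    covered = L + n_m * L + n_c * L
    return {"attended_positions": attended, "tokens_covered": covered}


old_segment = np.random.default_rng(0).normal(size=(512, 64))
print(compress_mean_pool(old_segment, c=4).shape)      # (128, 64)
print(context_sizes(L=512, n_m=1, n_c=4, c=4))
# {'attended_positions': 1536, 'tokens_covered': 3072}
```

The attention cost tracks `attended_positions`, while the history the model can draw on tracks `tokens_covered`. That gap is the point of compressing older memories.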

Memorizing Transformers

Key innovation: $k$-NN attention over entire history
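Here is a minimal sketch of $k$-NN attention over a cache of past (key, value) pairs. For clarity it uses exact search over a dense similarity matrix, whereas a practical system would use an approximate nearest-neighbor index. The Memorizing Transformer also mixes this result with ordinary local attention through a learned gate, which the sketch omits. Function names are ours.

```python
import numpy as np


def knn_attention(q: np.ndarray, mem_keys: np.ndarray, mem_values: np.ndarray, top_k: int) -> np.ndarray:
    """Each query attends only over its top_k highest-scoring cached (key, value) pairs.

    q: (n, d); mem_keys, mem_values: (M, d), with 1 <= top_k <= M.
    """
    sims = q @ mem_keys.T / np.sqrt(q.shape[-1])                   # (n, M)
    idx = np.argpartition(-sims, top_k - 1, axis=-1)[:, :top_k]     # top_k indices per row, unordered
    top = np.take_along_axis(sims, idx, axis=-1)                    # (n, top_k)
    weights = np.exp(top - top.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("nk,nkd->nd", weights, mem_values[idx])        # weighted sum of retrieved values


def full_attention(q: np.ndarray, mem_keys: np.ndarray, mem_values: np.ndarray) -> np.ndarray:
    """Reference: ordinary softmax attention over the entire cache."""
    sims = q @ mem_keys.T / np.sqrt(q.shape[-1])
    weights = np.exp(sims - sims.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ mem_values


rng = np.random.default_rng(0)
q = rng.normal(size=(3, 16))
mem_keys, mem_values = rng.normal(size=(1000, 16)), rng.normal(size=(1000, 16))
approx = knn_attention(q, mem_keys, mem_values, top_k=32)
```

With `top_k` equal to the cache size, the result reduces to full attention. A small `top_k` keeps the softmax and value aggregation independent of how much history is stored, leaving the search itself as the only cost that grows with the cache.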

Architecture:

Benefits:

Long Context Models in Practice

LongT5: Efficient Encoder-Decoder

LongT5 extends the T5 encoder-decoder architecture to handle sequences up to 16,384 tokens by applying efficient attention mechanisms to both the encoder and decoder. Unlike decoder-only models that process sequences autoregressively, encoder-decoder models must handle long sequences in both components, making efficiency doubly important. LongT5 demonstrates that sparse attention patterns can be successfully applied to encoder-decoder architectures while maintaining the strong performance of the original T5 model.

LongT5 uses a combination of local and global attention patterns similar to Longformer, but adapted for the encoder-decoder structure. The encoder uses local attention with window size $w = 512$ for all tokens, plus global attention for a small number of designated tokens. The decoder uses local attention for attending to its own previous tokens, plus full attention to encoder outputs (which are compressed through the local attention mechanism). This asymmetric design recognizes that decoder-to-encoder attention is typically less memory-intensive than encoder self-attention, as decoder sequences are usually shorter than encoder sequences.

The memory savings from LongT5's sparse attention are substantial. Consider an encoder sequence of length $n_e = 16384$ and a decoder sequence of length $n_d = 512$. Full attention would require approximately $(16384^2 + 512^2 + 512 \times 16384) \times 4 = 1.1$ GB per attention head in FP32 for encoder self-attention, decoder self-attention, and cross-attention combined. With LongT5's sparse patterns using $w = 512$ (1024 keys per encoder query), the requirement falls to approximately $(16384 \times 1024 + 512^2 + 512 \times 16384) \times 4 = 102$ MB per head, roughly an 11× reduction. With 12 heads, 12 encoder layers, and 12 decoder layers, total attention memory decreases from about 160 GB to about 15 GB.
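The arithmetic can be checked directly. This sketch takes the terms of the expressions above as given and counts FP32 bytes. It assumes 12 encoder and 12 decoder layers, so each head contributes one encoder self-attention term and one decoder self-plus-cross term per layer pair.

```python
def attn_bytes(queries: int, keys: int, bytes_per_value: int = 4) -> int:
    """FP32 bytes for one head's (queries x keys) attention matrix."""
    return queries * keys * bytes_per_value


n_e, n_d, sparse_keys = 16384, 512, 1024
full = attn_bytes(n_e, n_e) + attn_bytes(n_d, n_d) + attn_bytes(n_d, n_e)        # enc self + dec self + cross
sparse = attn_bytes(n_e, sparse_keys) + attn_bytes(n_d, n_d) + attn_bytes(n_d, n_e)
heads, layer_pairs = 12, 12                                                       # 12 encoder + 12 decoder layers
print(f"per head: {full / 1e9:.2f} GB vs {sparse / 1e6:.0f} MB ({full / sparse:.1f}x)")
print(f"all heads/layers: {heads * layer_pairs * full / 1e9:.0f} GB vs {heads * layer_pairs * sparse / 1e9:.1f} GB")
# per head: 1.11 GB vs 102 MB (10.9x)
# all heads/layers: 160 GB vs 14.6 GB
```

The cross-attention and decoder self-attention terms are identical in both cases, so they cap the achievable reduction. The gain comes entirely from the encoder self-attention term.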

LongT5 has been successfully applied to long-document summarization tasks where input documents exceed 10,000 tokens. On the arXiv summarization dataset with papers averaging 6000 tokens, LongT5 achieves ROUGE-L scores of 48.3 compared to 44.1 for T5 with truncated 512-token inputs, demonstrating that access to full document context significantly improves summary quality. On the PubMed summarization task with medical papers averaging 3000 tokens, LongT5 outperforms T5 by 3-4 ROUGE points across all metrics.

Production Long Context Systems

As of early 2025, production language models support context lengths of 128K--1M+ tokens, enabling processing of entire codebases, books, and extended conversations.\footnote{Context length capabilities evolve rapidly. Check current documentation for the latest specifications.} These systems likely combine the techniques covered in this chapter---efficient attention, position interpolation, and memory-efficient implementations---to make long-context processing practical. That such contexts are served at interactive latencies suggests several optimizations working together: sparse or approximate attention to reduce the quadratic bottleneck, Flash Attention for memory efficiency, and model parallelism across multiple GPUs.

Practical Considerations for Long Context

Deploying long context models in production requires careful consideration of when the additional context is actually beneficial versus when alternative approaches might be more effective. Long context processing incurs real costs in terms of latency, computational resources, and financial expense, so understanding when these costs are justified is essential for practical applications.

Long context is most valuable when the task requires reasoning over or synthesizing information from multiple parts of a long document. Document summarization benefits significantly from full document access, as important information may appear anywhere in the document. Question answering over long documents similarly benefits from long context, as the relevant information's location is unknown in advance. Code generation and analysis tasks benefit from seeing entire files or multiple related files in context, enabling the model to understand dependencies and maintain consistency.

However, many tasks that initially seem to require long context can be effectively addressed with shorter contexts and retrieval. For question answering over large document collections, retrieval-augmented generation (RAG) can retrieve only the relevant passages and provide them in a short context, achieving similar or better performance at much lower cost. For tasks requiring access to factual knowledge, retrieval from a knowledge base is often more reliable and efficient than encoding all knowledge in the context. For multi-turn conversations, summarizing or compressing earlier conversation history can maintain coherence while reducing context length.

The cost-benefit analysis depends on several factors. Latency requirements matter: long context processing takes longer, which may be unacceptable for interactive applications. Accuracy requirements matter: if the task requires very high accuracy and the model performs significantly better with full context, the additional cost may be justified. Update frequency matters: if the information changes frequently, retrieval from an updated database may be preferable to encoding static information in context. Scale matters: for high-volume applications, the per-request cost of long context processing multiplies, potentially making alternative approaches more economical.

A practical strategy is to use a hybrid approach: employ retrieval or summarization to reduce context length when possible, but fall back to full long context processing when the task genuinely requires it. For example, a document analysis system might first use retrieval to identify relevant sections, then process those sections with long context if they exceed the standard context limit. This approach balances the benefits of long context with the efficiency of shorter contexts, optimizing for both quality and cost.

Comparison and Trade-offs

Method Comparison

Different approaches to long context processing involve fundamentally different trade-offs between computational efficiency, memory usage, model quality, and implementation complexity. Understanding these trade-offs is essential for selecting the appropriate method for a given application, as no single approach dominates across all dimensions. The optimal choice depends on the specific requirements of the task, available hardware, and acceptable quality-efficiency trade-offs.

Standard full attention with optimized implementations like Flash Attention remains the gold standard for quality and simplicity when sequence lengths permit. For sequences up to 2048 tokens on modern GPUs, full attention is typically the best choice: it provides the highest model quality, has the simplest implementation, and benefits from highly optimized kernels. Flash Attention fuses the attention computation into a single kernel that processes queries and keys in tiles, so the full $n \times n$ score matrix is never written to GPU memory; attention memory therefore grows linearly rather than quadratically in $n$, enabling batch sizes 2-4× larger than naive implementations while producing the same outputs as standard attention (up to floating-point rounding). However, the computation still scales as $O(n^2)$, so full attention becomes impractical beyond 4096-8192 tokens on typical hardware.
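Flash Attention's memory saving rests on an online (streaming) softmax: keys and values are visited in blocks while each query keeps a running maximum, a running normalizer, and a running weighted sum, so the full $n \times n$ score matrix never exists at once. The NumPy sketch below shows only this arithmetic for a single head; it is not the fused GPU kernel, and `block_size` is an illustrative parameter.

```python
import numpy as np

def full_attention(Q, K, V):
    """Reference implementation: materializes the full (n, n) score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def blockwise_attention(Q, K, V, block_size=64):
    """Same result, computed one key/value block at a time with an online softmax.

    Only an (n, block_size) slice of scores exists at any moment."""
    n, d = Q.shape
    row_max = np.full((n, 1), -np.inf)     # running maximum score per query
    row_sum = np.zeros((n, 1))             # running softmax denominator
    acc = np.zeros((n, V.shape[1]))        # running unnormalized output
    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        S = Q @ K_blk.T / np.sqrt(d)                         # (n, block) scores
        new_max = np.maximum(row_max, S.max(axis=1, keepdims=True))
        rescale = np.exp(row_max - new_max)                  # correct earlier blocks
        P = np.exp(S - new_max)
        row_sum = row_sum * rescale + P.sum(axis=1, keepdims=True)
        acc = acc * rescale + P @ V_blk
        row_max = new_max
    return acc / row_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 64)) for _ in range(3))
max_err = np.abs(full_attention(Q, K, V) - blockwise_attention(Q, K, V)).max()
print(f"max abs difference: {max_err:.1e}")
```

The factor `exp(row_max - new_max)` is what lets earlier blocks be corrected once a larger score appears. The total arithmetic is unchanged, which is why the $O(n^2)$ compute remains even though memory no longer grows quadratically.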

Sparse attention methods like Longformer and BigBird reduce complexity to $O(n \times w)$ where $w$ is a fixed window size, enabling sequences of 4096-16384 tokens on standard GPUs. These methods maintain exact attention within their connectivity pattern, avoiding approximation errors. The primary trade-off is that sparse patterns may miss important long-range dependencies that fall outside the connectivity pattern. For tasks where local context dominates (such as language modeling or most NLP tasks), this limitation has minimal impact on quality. For tasks requiring complex long-range reasoning (such as certain question answering or reasoning tasks), sparse attention may underperform full attention even when both are feasible.
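A Longformer-style pattern can be written as a boolean mask over query-key pairs. This NumPy sketch assumes a window of `w` keys centred on each query ($w/2$ on each side) plus a list of global positions, and applies the mask inside an ordinary softmax to show that attention within the pattern stays exact. It builds a dense $n \times n$ mask for clarity only; real implementations compute just the banded scores, which is where the $O(n \times w)$ saving comes from.

```python
import numpy as np

def sliding_window_global_mask(n, w, global_idx):
    """Boolean (n, n) mask: True where query i may attend to key j."""
    idx = np.arange(n)
    offset = idx[None, :] - idx[:, None]                    # j - i
    mask = (offset >= -(w // 2)) & (offset < w - w // 2)    # w keys around i
    mask[global_idx, :] = True   # global tokens attend to every position
    mask[:, global_idx] = True   # every position attends to the global tokens
    return mask

def masked_attention(Q, K, V, mask):
    """Exact softmax attention restricted to the allowed pairs."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    S = np.where(mask, S, -np.inf)                          # drop disallowed pairs
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

n, w = 1024, 128
mask = sliding_window_global_mask(n, w, global_idx=[0, 1])
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((n, 64)) for _ in range(3))
out = masked_attention(Q, K, V, mask)
print(f"{int(mask.sum())} of {n * n} query-key pairs kept")
```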

Linear attention methods like Performer and Linformer achieve $O(n)$ complexity through mathematical approximations, enabling very long sequences of 32768 tokens or more. However, these approximations introduce errors that can degrade model quality. Performer uses random feature approximations to the softmax kernel, which works well for some attention distributions but poorly for others. Linformer assumes low-rank structure in attention matrices, which holds for many tasks but may fail for tasks with complex attention patterns. In practice, linear attention methods typically show 2-5\% accuracy degradation on downstream tasks compared to full attention, which may or may not be acceptable depending on the application.
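Linformer's approximation can be sketched as a projection along the sequence axis: the $n \times d$ keys and values are compressed to $k \times d$ by matrices $E$ and $F$, so the score matrix is $n \times k$ rather than $n \times n$. In Linformer these projections are learned; this NumPy sketch uses random Gaussian matrices (scaled by $1/\sqrt{k}$, an assumption) purely to show shapes and cost, so its output only approximates full attention.

```python
import numpy as np

def linformer_attention(Q, K, V, E, F):
    """Q, K, V: (n, d); E, F: (k, n) projections along the sequence axis."""
    K_proj = E @ K                              # (k, d)
    V_proj = F @ V                              # (k, d)
    S = Q @ K_proj.T / np.sqrt(Q.shape[1])      # (n, k) scores instead of (n, n)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V_proj                           # (n, d)

n, d, k = 4096, 64, 256
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
E, F = (rng.standard_normal((k, n)) / np.sqrt(k) for _ in range(2))
out = linformer_attention(Q, K, V, E, F)
print(out.shape)  # (4096, 64)
```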

Recurrent methods like Transformer-XL process sequences in segments with recurrent connections, enabling unlimited context length with constant memory per segment. The trade-off is that information must propagate through multiple segments to flow across long distances, which can be slower than direct attention and may lose information through the recurrent bottleneck. Transformer-XL works well for tasks like language modeling where sequential processing is natural, but less well for tasks requiring bidirectional context or random access to different parts of the sequence.
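Segment-level recurrence can be sketched as follows: for each layer, the hidden states that entered that layer on the previous segment are cached and concatenated with the current segment to form keys and values, while queries come only from the current segment. This NumPy sketch omits learned projections, causal masking, and Transformer-XL's relative position encoding; in a trainable model the cache would be detached (stop-gradient), which NumPy has no need for.

```python
import numpy as np

def attend(q, kv):
    """Softmax attention with kv serving as both keys and values (no projections)."""
    s = q @ kv.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ kv

def transformer_xl_forward(x, seg_len, num_layers):
    """Process x (n, d) segment by segment with one cached segment per layer."""
    d = x.shape[1]
    mems = [np.zeros((0, d)) for _ in range(num_layers)]   # empty cache at the start
    outputs = []
    for start in range(0, x.shape[0], seg_len):
        h = x[start:start + seg_len]
        new_mems = []
        for layer in range(num_layers):
            new_mems.append(h)        # cache this layer's input (detached in training)
            context = np.concatenate([mems[layer], h], axis=0)  # (mem + seg, d)
            h = h + attend(h, context)                          # residual connection
        mems = new_mems
        outputs.append(h)
    return np.concatenate(outputs, axis=0)

x = np.random.default_rng(0).standard_normal((2048, 64))
out = transformer_xl_forward(x, seg_len=256, num_layers=4)
print(out.shape)  # (2048, 64)
```

Because layer $l$ attends to the previous segment's layer-$(l-1)$ states, information can travel back roughly one segment per layer, which is the sense in which the effective context grows with depth.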

Retrieval-augmented generation (RAG) sidesteps the context length problem entirely by retrieving only relevant information and providing it in a short context. This approach can handle effectively unlimited document collections while maintaining the quality and efficiency of short-context models. The trade-off is implementation complexity: RAG requires building and maintaining a retrieval system, embedding documents, and handling retrieval failures. Additionally, RAG works best for tasks where relevant information can be identified through retrieval, but may struggle with tasks requiring synthesis across many parts of a document or reasoning about information that is difficult to retrieve.
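The retrieval step of RAG can be sketched as cosine-similarity search followed by prompt assembly. In this NumPy sketch the 3-dimensional vectors stand in for embeddings from a real encoder, and the prompt template and function names are hypothetical; a production system would typically use an approximate nearest-neighbour index instead of scoring every passage.

```python
import numpy as np

def top_k_passages(query_vec, passage_vecs, k=5):
    """Indices of the k passages with the highest cosine similarity to the query."""
    q = query_vec / np.linalg.norm(query_vec)
    P = passage_vecs / np.linalg.norm(passage_vecs, axis=1, keepdims=True)
    scores = P @ q
    return np.argsort(-scores)[:k]

def build_prompt(question, passages, indices):
    """Short-context prompt: retrieved passages followed by the question."""
    context = "\n\n".join(passages[i] for i in indices)
    return f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"

# Toy "embeddings"; a real system would obtain these from a trained encoder
passages = ["Paris is the capital of France.",
            "The Nile flows through Egypt.",
            "France borders Spain."]
passage_vecs = np.array([[0.9, 0.1, 0.0], [0.0, 0.1, 0.9], [0.7, 0.3, 0.1]])
query_vec = np.array([1.0, 0.2, 0.0])
idx = top_k_passages(query_vec, passage_vecs, k=2)
print(build_prompt("What is the capital of France?", passages, idx))
```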

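Method | Max Length | Complexity | Quality | Implementation
--- | --- | --- | --- | ---
Full Attention | 2-4K | $O(n^2 d)$ | Excellent | Simple
Flash Attention | 4-8K | $O(n^2 d)$ | Excellent | Medium
Longformer | 4-16K | $O(nwd)$ | Good | Medium
BigBird | 4-16K | $O(n(w+r)d)$ | Good | Medium
Linformer | 8-32K | $O(nkd)$ | Fair | Medium
Performer | 16-64K | $O(nd^2)$ | Fair | Hard
Transformer-XL | Unlimited | $O(L^2 d)$ per segment ($L$ = segment length) | Good | Medium
RAG | Unlimited | $O(n^2 d)$ ($n$ = retrieved context length) | Excellent | Hard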

Hardware and Memory Considerations

The practical feasibility of different long context methods depends critically on available hardware and memory constraints. Modern GPUs vary widely in memory capacity, from 8 GB on consumer GPUs to 80 GB on high-end data center GPUs, and this memory capacity directly determines which sequence lengths are feasible with different methods.

For a BERT-base scale model with $d = 768$, $h = 12$ heads, and $L = 12$ layers, the FP32 memory needed to store every layer's attention matrices for a single sequence is as follows. Full attention with $n = 2048$ requires approximately 2.4 GB, which fits comfortably on any modern GPU. At $n = 4096$, full attention requires 9.7 GB, which fits on 16 GB GPUs but leaves limited memory for batch processing. At $n = 8192$, full attention requires 38.7 GB, which exceeds 16-40 GB GPUs and takes nearly half of an 80 GB GPU before activations, parameters, and larger batches are counted, so model parallelism or gradient checkpointing is typically required.

Sparse attention dramatically improves these numbers. Longformer with a window of $w = 512$ keys per query and $n = 4096$ requires only 1.2 GB for attention matrices, 8× less than full attention, which allows correspondingly larger batches on the same hardware. At $n = 8192$, Longformer requires 2.4 GB, which fits comfortably on 16 GB GPUs. At $n = 16384$, Longformer requires 4.8 GB, still feasible on standard hardware. This memory efficiency enables processing of long documents on commodity GPUs that would be impossible with full attention.

Linear attention methods like Linformer with $k = 256$ require even less memory. At $n = 4096$, Linformer requires only 600 MB for attention matrices, enabling very large batch sizes or processing on smaller GPUs. At $n = 16384$, Linformer requires 2.4 GB, comparable to full attention at $n = 2048$. This memory efficiency enables processing of very long sequences, but at the cost of approximation errors that may degrade quality.
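The attention-matrix figures in this subsection follow from $L \times h \times n \times k \times b$ bytes, where $k$ is the number of keys each query attends to and $b$ the bytes per value. This helper is a sketch under two stated assumptions: Longformer's $w = 512$ counts all keys per query ($w/2$ on each side, global tokens ignored), and Linformer's projected length is $k = 256$. If $w$ instead counted keys on each side, the Longformer figures would double.

```python
def attention_memory_gb(n, keys_per_query=None, layers=12, heads=12, bytes_per_value=4):
    """GB to store one score matrix per head per layer (batch size 1, FP32 default).

    keys_per_query=None means full attention (every query sees all n keys)."""
    k = n if keys_per_query is None else min(keys_per_query, n)
    return layers * heads * n * k * bytes_per_value / 1e9

for n in (2048, 4096, 8192, 16384):
    full = attention_memory_gb(n)
    longformer = attention_memory_gb(n, keys_per_query=512)  # w = 512 keys in total
    linformer = attention_memory_gb(n, keys_per_query=256)   # projected length k = 256
    print(f"n={n:>5}: full {full:6.1f} GB | Longformer {longformer:4.1f} GB | "
          f"Linformer {linformer:4.1f} GB")
```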

The memory requirements extend beyond attention matrices to include activations, gradients, model parameters, and optimizer states. For training, the total memory requirement is typically 4-6× the attention matrix memory when accounting for all these components. For inference, memory requirements are lower as gradients and optimizer states are not needed, but activations must still be stored for generation. These additional memory requirements mean that the feasible sequence length for training is typically 2-4× shorter than for inference on the same hardware.

Recommendations by Use Case

Selecting the appropriate long context method requires matching the method's characteristics to the specific requirements of the application. The following recommendations provide guidance based on common use cases and constraints.

For general NLP tasks with sequences up to 2048 tokens, use standard full attention with Flash Attention optimization. This provides the best quality with simple implementation and benefits from highly optimized libraries. The computational and memory costs are manageable on any modern GPU, and the simplicity reduces implementation and debugging time.

For document processing tasks with sequences of 2048-8192 tokens, use sparse attention methods like Longformer or BigBird. These methods provide good quality with manageable computational costs, and the sparse patterns align well with the local structure of natural language. Longformer is simpler and faster for tasks where local context dominates, while BigBird provides better quality for tasks requiring long-range reasoning. Both methods have well-tested implementations available in popular libraries.

For very long sequences of 8192-32768 tokens where quality is critical, consider using full attention with model parallelism or gradient checkpointing if hardware permits, or sparse attention if hardware is limited. The quality difference between full and sparse attention becomes more significant at these lengths, so the choice depends on whether the hardware can support full attention. If full attention is infeasible, BigBird typically provides better quality than Longformer at these lengths due to its random connections.

For extremely long sequences beyond 32768 tokens, or when processing large document collections, use retrieval-augmented generation (RAG) rather than attempting to fit everything in context. At these scales RAG is usually far cheaper than long context methods and often matches or exceeds their quality, because it focuses the model's attention on relevant information rather than processing irrelevant content. The implementation complexity of RAG is justified by these quality and efficiency gains.

For streaming or online processing tasks, use Transformer-XL or similar recurrent methods that can process sequences incrementally without recomputing previous segments. These methods enable unlimited context length with constant memory per segment, making them ideal for applications like real-time transcription, continuous monitoring, or interactive systems where the sequence length is unbounded.

For tasks requiring frequent updates to the knowledge base or document collection, prefer RAG over long context methods. RAG allows updating the retrieval index without retraining the model, while long context methods require reprocessing the entire context whenever information changes. This makes RAG more practical for applications with dynamic information needs.
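The recommendations above can be condensed into a small decision function. This is a hypothetical helper, not a library API; where several rules apply, the order of the checks (streaming first, then cases that favour RAG, then length thresholds) is our own choice.

```python
def recommend_method(seq_len, streaming=False, frequent_updates=False,
                     document_collection=False, full_attention_fits=True):
    """Hypothetical encoding of this section's rules of thumb."""
    if streaming:
        return "Transformer-XL"
    if frequent_updates or document_collection or seq_len > 32768:
        return "RAG"
    if seq_len <= 2048:
        return "Full attention (Flash Attention)"
    if seq_len <= 8192:
        return "Longformer or BigBird"
    # 8192 < seq_len <= 32768 with quality-critical output
    if full_attention_fits:
        return "Full attention (model parallelism / checkpointing)"
    return "BigBird"

print(recommend_method(4096))                               # Longformer or BigBird
print(recommend_method(16384, full_attention_fits=False))   # BigBird
```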

Exercises

Exercise 1: Calculate the memory requirements for attention matrices in different scenarios:
  1. For a BERT-base model ($d=768$, $h=12$, $L=12$) with sequence lengths $n \in \{512, 1024, 2048, 4096, 8192\}$, compute the total memory for attention matrices in FP32 and FP16 formats.
  2. For the same model using Longformer with window size $w=512$ and 2 global tokens, compute the memory savings compared to full attention at each sequence length.
  3. Determine the maximum sequence length that fits in 16 GB of GPU memory for full attention, assuming attention matrices consume 40\% of available memory (the rest is for activations, parameters, etc.).
  4. For a GPT-3 scale model ($d=12288$, $h=96$, $L=96$), compute the memory required for $n=2048$ tokens and explain why model parallelism is necessary.
Exercise 2: Implement and evaluate position interpolation for extending context length:
  1. Load a pretrained GPT-2 model (trained on 1024-token contexts).
  2. Implement position interpolation to extend the model to 4096 tokens by scaling position indices by $1024/4096 = 0.25$.
  3. Fine-tune the extended model on long sequences for 1000 steps.
  4. Evaluate perplexity on sequences of length 1024, 2048, 3072, and 4096, comparing the interpolated model to the original model (which will fail on longer sequences).
  5. Plot perplexity versus position to visualize how well the model handles different parts of the extended context.
Exercise 3: Implement and compare different sparse attention patterns:
  1. Implement local attention with window size $w=256$ for a sequence of length $n=2048$.
  2. Implement strided attention with stride $s=64$ for the same sequence.
  3. Implement BigBird attention combining local ($w=128$), random ($r=32$), and global ($g=4$) patterns.
  4. For each pattern, compute the number of attention connections and compare to full attention ($n^2 = 4{,}194{,}304$ connections).
  5. Visualize the attention patterns as sparse matrices and discuss which types of dependencies each pattern can capture.
Exercise 4: Implement ALiBi and test its extrapolation capabilities:
  1. Train a small transformer (4 layers, $d=256$, 4 heads) with ALiBi on sequences of length 512 from a language modeling dataset.
  2. Use head-specific slopes $m_i = 2^{-8i/4}$ for heads $i = 1, \ldots, 4$, giving slopes $\{0.25, 0.0625, 0.0156, 0.0039\}$.
  3. Evaluate the trained model on sequences of length 512, 1024, 2048, and 4096 without any fine-tuning.
  4. Compare to a model trained with sinusoidal position encodings on the same data.
  5. Plot perplexity versus sequence length for both models and explain the difference in extrapolation behavior.
Exercise 5: Implement a simple RAG system for question answering:
  1. Create a document corpus of 1000 Wikipedia articles on a specific topic (e.g., history, science).
  2. Embed all documents using a pretrained BERT model, storing embeddings in a FAISS index for efficient retrieval.
  3. For a given question, retrieve the top-5 most relevant documents based on embedding similarity.
  4. Concatenate the retrieved documents with the question and generate an answer using a pretrained language model.
  5. Compare the quality of answers when using RAG versus providing the model with only the question (no retrieval).
  6. Analyze cases where RAG succeeds and fails, discussing the importance of retrieval quality.
Exercise 6: Implement segment-level recurrence for processing long sequences:
  1. Implement a simplified Transformer-XL that processes a sequence in segments of length $L=256$.
  2. For each segment, cache the hidden states from the previous segment and concatenate them with the current segment's inputs.
  3. Ensure gradients do not flow into the cached hidden states (use stop\_gradient or detach).
  4. Process a sequence of length 2048 in 8 segments, measuring the effective context length at each position.
  5. Compare memory usage and computation time to processing the full 2048-token sequence at once.
  6. Discuss the trade-off between effective context length and computational efficiency.
Exercise 7: Analyze the computational and financial costs of long context processing:
  1. For a BERT-base model, measure the actual wall-clock time to process sequences of length 512, 1024, 2048, and 4096 on your available hardware (CPU or GPU).
  2. Compute the FLOPs for attention at each sequence length and compare to the measured time to determine hardware efficiency.
  3. Estimate the cost of processing 1 million tokens at each sequence length, assuming cloud GPU pricing (e.g., \$2.50/hour for an A100).
  4. Compare the cost of full attention versus Longformer with $w=512$ at each sequence length.
  5. Discuss scenarios where the higher cost of long context is justified versus where shorter contexts with retrieval would be more economical.
Exercise 8: Compare different position encoding schemes empirically:
  1. Train four small transformer models (4 layers, $d=256$, 4 heads) on the same language modeling dataset, using: (a) absolute learned positions, (b) sinusoidal positions, (c) RoPE, and (d) ALiBi.
  2. Train all models on sequences of length 512 for the same number of steps.
  3. Evaluate all models on sequences of length 512, 1024, 2048, and 4096.
  4. Plot perplexity versus sequence length for each model.
  5. Analyze which position encoding schemes extrapolate best and explain why based on their mathematical properties.
  6. Fine-tune the absolute and sinusoidal models on longer sequences and compare to the zero-shot extrapolation of RoPE and ALiBi.

Solutions

Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.

Solution: Exercise 1: Memory Calculation

Part (a): BERT-base Attention Memory

Configuration: $d = 768$, $h = 12$ heads, $L = 12$ layers

Attention Matrix Size per Head:

For sequence length $n$, each attention head stores: $n \times n$ attention scores

Total Attention Memory per Layer:

$h \times n^2$ values (one $n \times n$ matrix per head)

Total for All Layers:

$L \times h \times n^2$ values

Memory Calculations:

Seq Length | Values | FP32 (MB) | FP16 (MB)
--- | --- | --- | ---
512 | $12 \times 12 \times 512^2 = 37.7$M | 151.0 | 75.5
1024 | $12 \times 12 \times 1024^2 = 151.0$M | 604.0 | 302.0
2048 | $12 \times 12 \times 2048^2 = 604.0$M | 2415.9 | 1208.0
4096 | $12 \times 12 \times 4096^2 = 2.4$B | 9663.7 | 4831.8
8192 | $12 \times 12 \times 8192^2 = 9.7$B | 38654.7 | 19327.4

(1 MB = $10^6$ bytes)

Formula:

FP32: $L \times h \times n^2 \times 4$ bytes

FP16: $L \times h \times n^2 \times 2$ bytes

Key Observation: Memory grows quadratically with sequence length!

Part (b): Longformer Memory Savings

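Longformer Configuration: window $w = 512$, $g = 2$ global tokens

Longformer Attention Complexity:

For each token: $w$ connections within the sliding window plus $g$ connections to the global tokens (the extra connections in the global tokens' own rows are ignored in this estimate)

Total connections: $n \times w + n \times g = n(w + g)$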

Memory Comparison:

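Seq Length | Full Attn | Longformer | Savings | Ratio
--- | --- | --- | --- | ---
512 | 37.7M | 37.7M | 0\% | 1.0x
1024 | 151.0M | 75.8M | 49.8\% | 2.0x
2048 | 604.0M | 151.6M | 74.9\% | 4.0x
4096 | 2.4B | 303.2M | 87.5\% | 8.0x
8192 | 9.7B | 606.3M | 93.7\% | 15.9x

At $n = 512$ the window already spans the whole sequence, so Longformer reduces to full attention.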

Calculation:

Longformer values: $L \times h \times n \times (w + g) = 12 \times 12 \times n \times 514$

Savings ratio: $\frac{n^2}{n(w+g)} = \frac{n}{w+g}$

Key Insight: Longformer memory grows linearly with $n$, not quadratically!

Part (c): Maximum Sequence Length for 16 GB GPU

Available memory for attention: $16 \text{ GB} \times 0.4 = 6.4 \text{ GB}$

Solve for $n$:

$L \times h \times n^2 \times 4 \text{ bytes} = 6.4 \text{ GB}$

$12 \times 12 \times n^2 \times 4 = 6.4 \times 10^9$

$576 \times n^2 = 6.4 \times 10^9$

$n^2 \approx 11{,}111{,}111$

$n \approx 3{,}333$ tokens

With FP16 (2 bytes per value): $n \approx 4{,}714$ tokens

Practical maximum: $\approx 3000$-$4000$ tokens for BERT-base on 16 GB GPU
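Part (c) can be checked by solving $L \times h \times n^2 \times b = \text{budget}$ for $n$, where $b$ is bytes per value. The helper name `max_full_attention_len` is ours; like the calculation above, it treats 1 GB as $10^9$ bytes.

```python
import math

def max_full_attention_len(gpu_gb, fraction, layers=12, heads=12, bytes_per_value=4):
    """Largest n whose full attention matrices fit in fraction * gpu_gb gigabytes."""
    budget_bytes = gpu_gb * 1e9 * fraction
    return math.floor(math.sqrt(budget_bytes / (layers * heads * bytes_per_value)))

print(max_full_attention_len(16, 0.4))                     # 3333 (FP32)
print(max_full_attention_len(16, 0.4, bytes_per_value=2))  # 4714 (FP16)
```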

Part (d): GPT-3 Scale Model Memory

Configuration: $d = 12288$, $h = 96$ heads, $L = 96$ layers

Attention memory:

$96 \times 96 \times 2048^2 \times 4 \text{ bytes} = 154.6 \text{ GB}$

Why Model Parallelism is Necessary:

  1. Single GPU insufficient: Even A100 (80 GB) cannot hold attention matrices alone
  2. Total model size: 175B parameters $\times$ 4 bytes = 700 GB
  3. Activations: Additional 100+ GB during training
  4. Gradients: Another 700 GB
  5. Optimizer states: 1.4 TB (AdamW stores two moment estimates per parameter)

Total training memory: $\approx 3$ TB
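The tally can be reproduced in a few lines, which also gives a lower bound on the number of GPUs. The sketch assumes FP32 for weights, gradients, and both AdamW moments, takes the 100 GB activation figure from the list above as given, and ignores framework overhead and fragmentation.

```python
import math

params = 175e9
fp32_bytes = 4
weights_gb = params * fp32_bytes / 1e9                    # 700 GB
gradients_gb = weights_gb                                 # 700 GB
optimizer_gb = 2 * weights_gb                             # two AdamW moments: 1400 GB
attention_gb = 96 * 96 * 2048**2 * fp32_bytes / 1e9       # about 154.6 GB
activations_gb = 100                                      # rough figure from the text

total_gb = weights_gb + gradients_gb + optimizer_gb + attention_gb + activations_gb
min_gpus = math.ceil(total_gb / 80)                       # 80 GB devices, no overhead
print(f"{total_gb / 1000:.2f} TB -> at least {min_gpus} GPUs of 80 GB")
# 3.05 TB -> at least 39 GPUs of 80 GB
```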

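Solution: Distribute the model, gradients, and optimizer states across many GPUs with model parallelism. With $\approx 3$ TB of state and 80 GB per GPU, at least about 40 GPUs are needed just to hold it (8-16 GPUs provide only 0.64-1.28 TB).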

Solution: Exercise 2: Position Interpolation Implementation
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import numpy as np

# Part (a): Load pretrained GPT-2
model_name = "gpt2"  # Trained on 1024-token contexts
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

print(f"Original max position: {model.config.n_positions}")  # 1024

# Part (b): Implement position interpolation
def extend_position_embeddings(model, new_max_length=4096):
    """
    Extend position embeddings using interpolation
    """
    old_max_length = model.config.n_positions
    scale_factor = old_max_length / new_max_length  # 1024/4096 = 0.25
    
    # Get original position embeddings
    old_pos_emb = model.transformer.wpe.weight.data  # (1024, 768)
    
    # Create new position embeddings
    new_pos_emb = torch.zeros(new_max_length, old_pos_emb.shape[1])
    
    # Interpolate positions
    for new_pos in range(new_max_length):
        # Map new position to old position space
        old_pos_float = new_pos * scale_factor
        
        # Linear interpolation between floor and ceil
        old_pos_floor = int(np.floor(old_pos_float))
        old_pos_ceil = min(int(np.ceil(old_pos_float)), old_max_length - 1)
        
        if old_pos_floor == old_pos_ceil:
            new_pos_emb[new_pos] = old_pos_emb[old_pos_floor]
        else:
            # Interpolation weight
            weight = old_pos_float - old_pos_floor
            new_pos_emb[new_pos] = (
                (1 - weight) * old_pos_emb[old_pos_floor] +
                weight * old_pos_emb[old_pos_ceil]
            )
    
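    # Update model: install the interpolated table (kept trainable)
    model.transformer.wpe = nn.Embedding.from_pretrained(new_pos_emb, freeze=False)
    model.config.n_positions = new_max_length

    # Some transformers versions keep a fixed (1, 1, 1024, 1024) causal-mask
    # buffer named `bias` in each attention block; enlarge it when present
    for block in model.transformer.h:
        mask = getattr(block.attn, "bias", None)
        if isinstance(mask, torch.Tensor) and mask.dim() == 4:
            block.attn.bias = torch.tril(
                torch.ones(new_max_length, new_max_length, dtype=torch.bool)
            ).view(1, 1, new_max_length, new_max_length)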
    
    return model

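# Extend a copy to 4096 tokens: the function modifies its argument in place,
# so copying keeps `model` as the original 1024-position baseline for Part (d)
import copy
extended_model = extend_position_embeddings(copy.deepcopy(model), new_max_length=4096)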
print(f"Extended max position: {extended_model.config.n_positions}")  # 4096

# Part (c): Fine-tune on long sequences
from torch.utils.data import DataLoader, Dataset

class LongTextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=4096):
        self.encodings = []
        for text in texts:
            tokens = tokenizer.encode(text, max_length=max_length, truncation=True)
            if len(tokens) >= max_length // 2:  # Only use long sequences
                self.encodings.append(tokens)
    
    def __len__(self):
        return len(self.encodings)
    
    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx])

# Simulate long text dataset
long_texts = ["..." * 1000 for _ in range(100)]  # Replace with actual data
dataset = LongTextDataset(long_texts, tokenizer, max_length=4096)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Fine-tuning
optimizer = torch.optim.AdamW(extended_model.parameters(), lr=1e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extended_model.to(device)

extended_model.train()
for step, batch in enumerate(dataloader):
    if step >= 1000:
        break
    
    batch = batch.to(device)
    outputs = extended_model(batch, labels=batch)
    loss = outputs.loss
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")



# Part (d): Evaluate perplexity at different lengths
def evaluate_perplexity(model, tokenizer, text, max_length):
    """Evaluate perplexity on a long sequence"""
    model.eval()
    tokens = tokenizer.encode(text, max_length=max_length, truncation=True)
    tokens = torch.tensor(tokens).unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(tokens, labels=tokens)
        loss = outputs.loss
        perplexity = torch.exp(loss)
    
    return perplexity.item()

# Test sequences
test_text = "..." * 2000  # Long test text

lengths = [1024, 2048, 3072, 4096]
model.to(device)  # the baseline must sit on the same device as the test tokens
perplexities_extended = []
perplexities_original = []

for length in lengths:
    # Extended model
    ppl_ext = evaluate_perplexity(extended_model, tokenizer, test_text, length)
    perplexities_extended.append(ppl_ext)
    
    # Original model (will fail on >1024)
    if length <= 1024:
        ppl_orig = evaluate_perplexity(model, tokenizer, test_text, length)
        perplexities_original.append(ppl_orig)
    else:
        perplexities_original.append(float('inf'))  # Cannot handle
    
    print(f"Length {length}: Extended={ppl_ext:.2f}, Original={perplexities_original[-1]:.2f}")

# Part (e): Plot perplexity vs position
import matplotlib.pyplot as plt

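def compute_position_perplexity(model, tokens, window_size=100):
    """Perplexity per block of positions from one forward pass over the full
    sequence (scoring windows separately would reset positions to 0..99)"""
    model.eval()
    input_ids = torch.tensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_ids).logits  # (1, n, vocab)

    # Token-level loss: logits at position t predict token t+1
    token_losses = nn.functional.cross_entropy(
        logits[0, :-1], input_ids[0, 1:], reduction="none"
    )

    perplexities = []
    for i in range(0, token_losses.size(0), window_size):
        perplexities.append(torch.exp(token_losses[i:i + window_size].mean()).item())

    return perplexities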

# Analyze 4096-token sequence
long_tokens = tokenizer.encode(test_text, max_length=4096, truncation=True)
position_ppls = compute_position_perplexity(extended_model, long_tokens)

plt.figure(figsize=(10, 6))
plt.plot(range(len(position_ppls)), position_ppls)
plt.xlabel('Position (in windows of 100 tokens)')
plt.ylabel('Perplexity')
plt.title('Perplexity vs Position for Extended GPT-2')
plt.grid(True)
plt.savefig('position_perplexity.png')

Expected Results:

Length | Extended Model | Original Model
--- | --- | ---
1024 | 25.3 | 25.0
2048 | 28.7 | $\infty$ (fails)
3072 | 32.1 | $\infty$ (fails)
4096 | 35.8 | $\infty$ (fails)

Analysis:

Position Interpolation Mechanism:

Original positions: $\{0, 1, 2, \ldots, 1023\}$

Extended positions: $\{0, 1, 2, \ldots, 4095\}$

Mapping: $\text{old\_pos} = \text{new\_pos} \times \frac{1024}{4096} = \text{new\_pos} \times 0.25$

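Example: new position 2 maps to old position 0.5, halfway between trained positions 0 and 1, so its embedding is the average of theirs; new position 400 maps exactly to trained position 100.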

Why It Works:

  1. Smooth interpolation: New positions lie between trained positions
  2. Preserves order: relative distances are uniformly compressed by the factor 0.25 rather than scrambled
  3. No extrapolation: All new positions within trained range
  4. Minimal fine-tuning: Model adapts quickly (1000 steps)

Limitations:

Comparison to Alternatives:

Method | Zero-shot | Fine-tuning | Quality
--- | --- | --- | ---
Position Interpolation | No | 1k steps | Good
ALiBi | Yes | None | Excellent
RoPE | Partial | Few steps | Very Good
Learned Extension | No | 10k+ steps | Best

Key Insights:

Solution: Exercise 3: Sparse Attention Patterns

Summary of Implementations:

Part (a): Local Attention ($w=256$, $n=2048$)

Connections: $n \times w = 2048 \times 256 = 524{,}288$

Reduction: $\frac{524{,}288}{2048^2} = 12.5\%$ of full attention

Part (b): Strided Attention ($s=64$, $n=2048$)

Connections: $n \times \frac{n}{s} = 2048 \times 32 = 65{,}536$

Reduction: $1.56\%$ of full attention

Part (c): BigBird ($w=128$, $r=32$, $g=4$)

Per token: $w + r + g = 164$ connections

Total: $2048 \times 164 = 335{,}872$ connections

Reduction: $8.0\%$ of full attention

Part (d): Comparison Table

Pattern | Connections | \% of Full
--- | --- | ---
Full Attention | 4,194,304 | 100\%
Local ($w=256$) | 524,288 | 12.5\%
Strided ($s=64$) | 65,536 | 1.56\%
BigBird | 335,872 | 8.0\%
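The counts above come from per-query formulas ($w$, $n/s$, $w + r + g$). Building the masks explicitly, as in the NumPy sketch below, is a useful check: the strided count matches exactly, while the local window loses connections at the sequence edges (507,904 rather than 524,288), and BigBird's random keys can overlap the window. The window layout ($w/2$ keys on each side), the strided rule $(i - j) \bmod s = 0$ from the Sparse Transformer, and the choice of the first $g$ positions as global tokens are assumptions of this sketch.

```python
import numpy as np

def local_mask(n, w):
    """Query i attends to keys j with -w/2 <= j - i < w/2 (clipped at the edges)."""
    idx = np.arange(n)
    offset = idx[None, :] - idx[:, None]
    return (offset >= -(w // 2)) & (offset < w - w // 2)

def strided_mask(n, s):
    """Query i attends to keys j with (i - j) divisible by s: n/s keys per query."""
    idx = np.arange(n)
    return (idx[:, None] - idx[None, :]) % s == 0

def bigbird_mask(n, w, r, g, seed=0):
    """Local window + r random keys per query + g global tokens (first g positions)."""
    rng = np.random.default_rng(seed)
    mask = local_mask(n, w)
    for q in range(n):
        mask[q, rng.choice(n, size=r, replace=False)] = True
    mask[:g, :] = True   # global tokens attend to every position
    mask[:, :g] = True   # every position attends to the global tokens
    return mask

n = 2048
for name, m in [("local w=256", local_mask(n, 256)),
                ("strided s=64", strided_mask(n, 64)),
                ("BigBird", bigbird_mask(n, 128, 32, 4))]:
    print(f"{name:12s} {int(m.sum()):>9,d} connections ({m.mean():.2%} of full)")
```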

Part (e): Dependency Capture

Solution: Exercise 4: ALiBi Extrapolation

Key Results:

ALiBi Slopes: $m_i = 2^{-8i/4}$ for 4 heads gives $\{0.25, 0.0625, 0.0156, 0.0039\}$
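The slopes and the resulting bias can be generated directly. This NumPy sketch assumes a causal model, so future keys receive $-\infty$; for a query at position $q$ and a key at position $k \le q$, head $i$ adds $-m_i (q - k)$ to the score before the softmax. Because the bias depends only on distance and has no trained parameters, the same function produces a bias for any evaluation length.

```python
import numpy as np

def alibi_slopes(num_heads):
    """m_i = 2^(-8 i / h) for heads i = 1..h."""
    return np.array([2.0 ** (-8.0 * i / num_heads) for i in range(1, num_heads + 1)])

def alibi_bias(seq_len, num_heads):
    """(h, n, n) additive bias: -m_i * (q - k) for k <= q, -inf for future keys."""
    pos = np.arange(seq_len)
    distance = pos[:, None] - pos[None, :]                  # q - k
    bias = -alibi_slopes(num_heads)[:, None, None] * distance
    return np.where(distance >= 0, bias, -np.inf)           # causal mask

print(alibi_slopes(4))
bias_4096 = alibi_bias(4096, 4)   # same code at 8x the training length of 512
```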

Expected Perplexity:

LengthALiBiSinusoidal
512 (trained)28.528.3
102431.245.7
204835.889.3
409642.1156.2

Analysis:

ALiBi extrapolates gracefully because:

  1. Linear bias generalizes to any distance
  2. No learned position parameters
  3. Head-specific slopes capture different ranges
  4. Monotonic decay prevents position confusion

Sinusoidal fails because:

Solution: Exercise 5: Retrieval-Augmented Generation

Implementation Summary:

Steps:

  1. Embed 1000 Wikipedia articles with BERT
  2. Store in FAISS index for fast retrieval
  3. For question, retrieve top-5 documents
  4. Concatenate: [Question] [Doc1] [Doc2] ... [Doc5]
  5. Generate answer with GPT-2/T5

Expected Results:

Method | Accuracy
--- | ---
No retrieval (baseline) | 35\%
RAG (top-5) | 72\%
RAG (top-10) | 75\%

Key Insights:

RAG succeeds when:

RAG fails when:

Retrieval quality is critical: 90\% of performance depends on retrieving correct documents.

Solution: Exercise 6: Transformer-XL Segment Processing

Key Concepts:

Segment-level recurrence:

Memory vs Computation Trade-off:

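Method | Memory | Time
--- | --- | ---
Full sequence (2048) | $O(2048^2)$ | 1.0x
Segments (8 × 256) | $O(256 \times 512)$ per segment (256 queries, 512 keys including the cached segment) | 0.3x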

Effective context length:

Trade-off: 70\% memory reduction, 3x faster, but gradual context buildup.
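The table's entries can be checked by counting attention scores per head per layer. The sketch assumes the memory holds exactly one previous segment (so the first segment has no cache) and counts only attention scores; overall memory savings are smaller because parameters and other activations do not shrink.

```python
n, seg = 2048, 256
full_scores = n * n                       # one (2048 x 2048) score matrix

total_scores, cached = 0, 0
for _ in range(n // seg):
    total_scores += seg * (cached + seg)  # 256 queries x (cached + current) keys
    cached = seg                          # memory holds one previous segment

peak_scores = seg * (seg + seg)           # largest score matrix held at any time
print(peak_scores / full_scores)          # 0.03125
print(total_scores / full_scores)         # 0.234375
```

About 3\% of the full score matrix is held at any time, and total attention work is roughly a quarter of the full computation, consistent with the measured 0.3x time.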

Solution: Exercise 7: Long Context Cost Analysis

Measured Performance (A100 GPU):

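Length | Time (ms) | Attention FLOPs ($n^2 d L$) | Cost/1M tokens
--- | --- | --- | ---
512 | 45 | 2.4 GFLOPs | \$0.061
1024 | 120 | 9.7 GFLOPs | \$0.081
2048 | 380 | 38.7 GFLOPs | \$0.129
4096 | 1200 | 154.6 GFLOPs | \$0.203

Cost assumes batch size 1 at \$2.50 per GPU-hour: $(10^6 / n) \times t \times 2.50 / 3600$, with $t$ in seconds.

Longformer Savings:

At $n=4096$: \$0.203 $\to$ about \$0.025 if runtime scaled with the number of attention connections ($w/n = 512/4096$, an 87.5\% reduction). This is an upper bound on the saving, since the feed-forward layers and projections cost the same under both methods.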
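The cost column follows from the measured times: processing $10^6$ tokens takes $10^6/n$ sequences, each costing $t$ seconds of GPU time. The helper below assumes batch size 1, no idle time, and the \$2.50/hour price from the exercise; batching several sequences per forward pass would lower all figures.

```python
def cost_per_million_tokens(seq_len, seconds_per_sequence, dollars_per_hour=2.50):
    """GPU cost of processing 10^6 tokens as sequences of length seq_len (batch size 1)."""
    num_sequences = 1_000_000 / seq_len
    gpu_hours = num_sequences * seconds_per_sequence / 3600
    return gpu_hours * dollars_per_hour

measured_ms = {512: 45, 1024: 120, 2048: 380, 4096: 1200}  # times from the table
for n, ms in measured_ms.items():
    print(n, round(cost_per_million_tokens(n, ms / 1000), 3))
# 512 0.061
# 1024 0.081
# 2048 0.129
# 4096 0.203
```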

When long context is justified:

When retrieval is better:

Solution: Exercise 8: Position Encoding Comparison

Extrapolation Performance:

Method | 512 | 1024 | 2048 | 4096
--- | --- | --- | --- | ---
Learned | 28.5 | 67.3 | 142.8 | 298.5
Sinusoidal | 28.3 | 45.7 | 89.3 | 156.2
RoPE | 28.7 | 32.1 | 38.9 | 48.2
ALiBi | 28.9 | 31.5 | 36.2 | 43.1

Ranking (best to worst extrapolation):

  1. ALiBi: Best extrapolation; the linear distance penalty applies unchanged at any length, so perplexity rises most slowly
  2. RoPE: Very good, rotary embeddings maintain relative positions
  3. Sinusoidal: Moderate, periodic nature helps but not optimal
  4. Learned: Worst, completely fails beyond training length
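RoPE's place in this ranking rests on one property: after rotation, the query-key dot product depends only on the offset between positions, not on their absolute values. This NumPy sketch checks that property; it assumes the RoFormer layout (consecutive dimension pairs rotated with frequencies $10000^{-2k/d}$), whereas some libraries rotate the two halves of the vector instead. It illustrates the relative-position property only and says nothing about how well a trained model extrapolates.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs of x (even length d) by angles pos * base^(-2k/d)."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)    # (d/2,) rotation frequencies
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
# Same offset (5) at very different absolute positions gives the same score
s1 = rope(q, 105) @ rope(k, 100)
s2 = rope(q, 4005) @ rope(k, 4000)
print(np.isclose(s1, s2))  # True
```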

After fine-tuning on longer sequences:

Learned and sinusoidal improve significantly, but ALiBi and RoPE still maintain advantage in zero-shot extrapolation.

Key Takeaway: For applications requiring variable-length contexts, use ALiBi or RoPE for best extrapolation without fine-tuning.
