Computational Analysis of Transformers
Chapter Overview
Understanding computational requirements is crucial for deploying transformers. This chapter analyzes time and space complexity, memory footprints, and inference costs. We derive exact FLOP counts, memory requirements, and scaling laws for transformers of different sizes.
Learning Objectives
- Calculate FLOPs for transformer forward and backward passes
- Analyze memory requirements for training and inference
- Understand scaling laws for model size, data, and compute
- Optimize inference through batching and caching
- Estimate training time and costs for large models
Computational Complexity
Understanding the computational complexity of transformers is essential for making informed decisions about model architecture, hardware requirements, and deployment strategies. The transformer's computational profile differs fundamentally from recurrent architectures, trading sequential dependencies for quadratic memory scaling—a trade-off that profoundly impacts both training and inference.
Self-Attention Complexity
Self-attention is the defining operation of transformers, and its computational characteristics determine much of the model's behavior. For a sequence of length $n$ with model dimension $d_{\text{model}}$, we analyze each component of the attention mechanism in detail.
QKV Projections: The first step projects the input $\mX \in \R^{n \times d_{\text{model}}}$ into query, key, and value spaces. Each projection is a matrix multiplication:
$$\mQ = \mX\mW_Q, \quad \mK = \mX\mW_K, \quad \mV = \mX\mW_V, \qquad \mW_Q, \mW_K, \mW_V \in \R^{d_{\text{model}} \times d_k}$$
Each matrix multiplication $\mX \mW$ requires $2nd_{\text{model}}d_k$ floating-point operations (FLOPs): for each of $n \times d_k$ output elements, we perform $d_{\text{model}}$ multiply-add operations. With three projections:
$$\text{FLOPs}_{\text{QKV}} = 6nd_{\text{model}}d_k$$
For the common case where $d_k = d_{\text{model}}$ (single-head or considering all heads together):
$$\text{FLOPs}_{\text{QKV}} = 6nd_{\text{model}}^2$$
Why this matters for hardware: These are dense matrix multiplications, which achieve high utilization on modern GPUs. NVIDIA A100 GPUs can perform up to 312 TFLOPS (FP16 with Tensor Cores), meaning these projections are typically compute-bound rather than memory-bound. However, for small batch sizes or short sequences, the operations may become memory-bandwidth limited, achieving only 10-20\% of peak FLOPS.
Attention Score Computation: Computing $\mS = \mQ \mK\transpose$ involves multiplying $\mQ \in \R^{n \times d_k}$ by $\mK\transpose \in \R^{d_k \times n}$, producing $\mS \in \R^{n \times n}$:
$$\text{FLOPs}_{\text{scores}} = 2n^2 d_k$$
The attention matrix $\mS$ has $n^2$ elements, and computing each requires $d_k$ multiply-add operations. This quadratic scaling in sequence length is the fundamental bottleneck for long-context transformers.
Dimension tracking example: For BERT-base with $n=512$, $d_k=64$ (per head), and $h=12$ heads:
$$\text{FLOPs}_{\text{per head}} = 2 \times 512^2 \times 64 \approx 33.6\text{M}$$
Across 12 heads, we compute 12 separate $512 \times 512$ attention matrices, requiring:
$$12 \times 33.6\text{M} \approx 403\text{M FLOPs}$$
Hardware implications: The attention matrix requires $n^2$ memory per head. For $n=512$ and 12 heads with FP32:
$$12 \times 512^2 \times 4\text{ bytes} \approx 12.6\text{ MB per sequence}$$
This seems modest, but for $n=2048$ (GPT-2): $12 \times 2048^2 \times 4 = 201\text{ MB}$ per sequence. With batch size 32: $6.4\text{ GB}$ just for attention matrices! This is why long-context models require substantial GPU memory.
Softmax and Scaling: Applying softmax to each row of $\mS$ requires $O(n^2)$ operations (exponentials and normalization), which is negligible compared to the matrix multiplications but can become significant for very long sequences due to memory access patterns.
Attention Output: Computing $\mO = \mA \mV$ multiplies the attention weights $\mA \in \R^{n \times n}$ by values $\mV \in \R^{n \times d_v}$:
$$\text{FLOPs}_{\mA\mV} = 2n^2 d_v \text{ per head}$$
Again, this scales quadratically with sequence length. For $d_v = d_k$, summed over all heads this matches the score computation at $2n^2 d_{\text{model}}$ FLOPs.
Output Projection: Finally, concatenated head outputs are projected back to model dimension:
$$\text{FLOPs}_{\text{out}} = 2nd_{\text{model}}^2$$
Total Self-Attention FLOPs: For typical configurations where $d_k = d_v = d_{\text{model}}/h$:
$$\text{FLOPs}_{\text{attn}} = \underbrace{8nd_{\text{model}}^2}_{\text{projections}} + \underbrace{4n^2d_{\text{model}}}_{\text{attention}}$$
Complexity regime analysis: The relative importance of the two terms depends on the ratio $n/d_{\text{model}}$:
- Short sequences ($n \ll d_{\text{model}}$): The $8nd_{\text{model}}^2$ term dominates. For BERT-base with $n=128$, $d=768$: $8 \times 128 \times 768^2 \approx 603\text{M}$ vs $4 \times 128^2 \times 768 \approx 50\text{M}$. The projections dominate.
- Long sequences ($n \gg d_{\text{model}}$): The $4n^2d_{\text{model}}$ term dominates. For $n=8192$, $d=768$: $8 \times 8192 \times 768^2 \approx 38.7\text{G}$ vs $4 \times 8192^2 \times 768 \approx 206\text{G}$. The attention computation dominates.
- Crossover point: When $8nd_{\text{model}}^2 \approx 4n^2d_{\text{model}}$, solving gives $n \approx 2d_{\text{model}}$. For $d=768$, this occurs around $n \approx 1536$.
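The regime analysis above is easy to verify numerically. The sketch below (helper name is illustrative, not from the chapter) evaluates both FLOP terms at the three sequence lengths discussed; at $n = 2d_{\text{model}} = 1536$ the two terms are exactly equal.

```python
def attention_flops(n, d_model):
    """Return (projection FLOPs, attention-matrix FLOPs) for one attention block."""
    proj = 8 * n * d_model**2   # Q, K, V, and output projections
    attn = 4 * n**2 * d_model   # QK^T scores plus A @ V
    return proj, attn

# BERT-base width: projections dominate at n=128, attention at n=8192,
# with the crossover exactly at n = 2 * d_model
for n in (128, 1536, 8192):
    proj, attn = attention_flops(n, d_model=768)
    print(f"n={n:5d}  proj={proj/1e9:8.2f} GFLOPs  attn={attn/1e9:8.2f} GFLOPs")
```

At the crossover, both terms equal $8 \times 1536 \times 768^2 \approx 7.2$ GFLOPs, matching the $n \approx 2d_{\text{model}}$ rule derived above.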
[Figure: FLOP cost of each stage of a transformer layer. Top row (attention path): $\mX$ ($n \times d$) $\to$ QKV projections ($O(nd^2)$) $\to$ $\mQ\mK^\top$ ($O(n^2d)$) $\to$ $\mA\mV$ ($O(n^2d)$) $\to$ output ($n \times d$). Bottom row (FFN path): $\mX$ ($n \times d$) $\to$ $\mW_1$ ($O(nd_{ff}d)$) $\to$ GELU ($O(nd_{ff})$) $\to$ $\mW_2$ ($O(nd_{ff}d)$) $\to$ output ($n \times d$). Short sequences ($n < 2d$): FFN dominates; long sequences ($n > 2d$): attention dominates.]
This analysis explains why efficient attention mechanisms (Chapter 16) focus on reducing the $O(n^2)$ term for long-context applications.
Feed-Forward Network Complexity
The position-wise feed-forward network (FFN) in each transformer layer typically expands the representation to a higher dimension before projecting back. This two-layer network with GELU or ReLU activation is applied independently to each position in the sequence.
Architecture: For input $\mX \in \R^{n \times d_{\text{model}}}$ (biases omitted):
$$\mH = \text{GELU}(\mX\mW_1), \qquad \text{FFN}(\mX) = \mH\mW_2, \qquad \mW_1 \in \R^{d_{\text{model}} \times d_{ff}},\; \mW_2 \in \R^{d_{ff} \times d_{\text{model}}}$$
The intermediate dimension $d_{ff}$ is typically $4d_{\text{model}}$ in standard transformers (BERT, GPT), though some models use different ratios. This expansion allows the network to learn complex non-linear transformations.
First Projection FLOPs: Computing $\mX \mW_1$ requires:
$$\text{FLOPs}_{\mW_1} = 2nd_{\text{model}}d_{ff}$$
For $d_{ff} = 4d_{\text{model}}$:
$$\text{FLOPs}_{\mW_1} = 8nd_{\text{model}}^2$$
Second Projection FLOPs: Computing $\mH \mW_2$ requires the same count:
$$\text{FLOPs}_{\mW_2} = 2nd_{ff}d_{\text{model}} = 8nd_{\text{model}}^2$$
Total FFN FLOPs:
$$\text{FLOPs}_{\text{FFN}} = 4nd_{ff}d_{\text{model}} = 16nd_{\text{model}}^2$$
Activation function: GELU requires additional operations (exponentials, multiplications) but these are $O(nd_{ff})$, negligible compared to the matrix multiplications.
Why FFN dominates computation: Comparing FFN to self-attention:
$$\frac{\text{FLOPs}_{\text{FFN}}}{\text{FLOPs}_{\text{attn}}} = \frac{16nd_{\text{model}}^2}{8nd_{\text{model}}^2 + 4n^2d_{\text{model}}} = \frac{4d_{\text{model}}}{2d_{\text{model}} + n}$$
For typical sequence lengths where $n < 2d_{\text{model}}$, the FFN requires more FLOPs than attention—approaching twice as many for short sequences. This is why some efficient transformer variants (e.g., mixture-of-experts) focus on making the FFN more efficient.
Memory and bandwidth considerations: The FFN intermediate activations $\mH \in \R^{n \times d_{ff}}$ must be stored for backpropagation. For BERT-base with $n=512$, $d_{ff}=3072$:
$$512 \times 3072 \times 4\text{ bytes} \approx 6\text{ MB per sequence per layer}$$
With 12 layers and batch size 32: $6\text{ MB} \times 12 \times 32 \approx 2.3\text{ GB}$ just for FFN intermediate activations. This is a significant portion of training memory.
Hardware utilization: FFN matrix multiplications are highly regular and achieve excellent GPU utilization (often 70-90\% of peak FLOPS on modern GPUs). The operations are:
- Compute-bound for reasonable batch sizes and sequence lengths
- Well-suited for Tensor Cores on NVIDIA GPUs (FP16/BF16 operations)
- Easily parallelizable across the sequence dimension
On an NVIDIA A100 GPU (312 TFLOPS FP16), computing the FFN for BERT-base with batch size 32 and $n=512$ at 70\% utilization takes roughly
$$\frac{32 \times 16 \times 512 \times 768^2}{0.7 \times 312 \times 10^{12}} \approx 0.7\text{ ms per layer},$$
or about 8.5 ms across all 12 layers.
Per-Layer Total Complexity
Combining self-attention and FFN, a complete transformer layer requires:
$$\text{FLOPs}_{\text{layer}} = 24nd_{\text{model}}^2 + 4n^2d_{\text{model}}$$
Additional operations: Layer normalization, residual connections, and dropout add $O(nd_{\text{model}})$ operations, which are negligible compared to the matrix multiplications.
Breakdown by component:
- FFN: $16nd_{\text{model}}^2$ (typically 60-70\% of layer FLOPs for short sequences)
- Attention projections: $8nd_{\text{model}}^2$ (typically 25-35\%)
- Attention computation: $4n^2d_{\text{model}}$ (grows with sequence length)
This breakdown is crucial for optimization: for short sequences, optimizing FFN yields the largest gains; for long sequences, efficient attention mechanisms become critical.
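The component shares quoted above can be reproduced directly from the per-layer formulas (function name is illustrative):

```python
def layer_flops(n, d_model):
    """Fraction of one transformer layer's FLOPs consumed by each component."""
    ffn = 16 * n * d_model**2   # two FFN matmuls with d_ff = 4 * d_model
    proj = 8 * n * d_model**2   # attention QKV + output projections
    attn = 4 * n**2 * d_model   # score and value matmuls
    total = ffn + proj + attn
    return {"ffn": ffn / total, "proj": proj / total, "attn": attn / total}

# BERT-base at n=512: roughly 60% FFN, 30% projections, 10% attention
shares = layer_flops(512, 768)
print({k: f"{v:.0%}" for k, v in shares.items()})
```

Re-running with $n=8192$ shows the attention share overtaking the rest, confirming the crossover analysis.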
On an NVIDIA A100 (312~TFLOPS FP16, 70\% utilization), this yields a batch-of-32 forward pass in $\approx$14~ms; with the backward pass costing roughly twice the forward, training throughput is $\approx$390,000~tokens/sec. At 1.6~TB/s memory bandwidth, loading 440~MB of parameters takes 0.28~ms---comparable to the compute time of a single sequence, making small-batch inference memory-bandwidth bound.
See Section~[ref] for the step-by-step derivation.
Complexity Analysis
Time complexity: $O(Ln^2d + Lnd^2)$
Space complexity: $O(Ln^2 + Lnd)$
Comparison with RNN:
- RNN: $O(Lnd^2)$ time, $O(Ld^2)$ space
- Transformer: Quadratic in $n$ but parallel; RNN sequential
Bottleneck regimes:
- Short sequences $(n < d)$: FFN dominates, $O(Lnd^2)$
- Long sequences $(n > d)$: Attention dominates, $O(Ln^2d)$
Memory Requirements
Memory is often the limiting factor in training and deploying large transformer models. Understanding memory requirements at a granular level enables informed decisions about model architecture, batch sizes, and hardware selection. We analyze memory consumption across four categories: model parameters, gradients, optimizer states, and activations.
Model Parameters
Model parameters must be stored in GPU memory during both training and inference. The memory footprint depends on the numerical precision used.
Precision options:
- FP32 (float32): 4 bytes per parameter, standard precision
- FP16 (float16): 2 bytes per parameter, half precision
- BF16 (bfloat16): 2 bytes per parameter, better range than FP16
- INT8: 1 byte per parameter, quantized inference
For BERT-base with 110 million parameters: FP32 requires 440 MB, FP16/BF16 220 MB, and INT8 110 MB.
Parameter breakdown for BERT-base:
Per transformer layer: $4d_{\text{model}}^2 \approx 2.4\text{M}$ (attention projections) plus $2 \times d_{\text{model}} \times d_{ff} \approx 4.7\text{M}$ (FFN), about $7.1\text{M}$ parameters
12 layers: $12 \times 7.1\text{M} \approx 85\text{M}$ parameters
Total BERT-base: $85\text{M}$ (layers) plus $\approx 24\text{M}$ (token and position embeddings) $\approx 110\text{M}$ parameters
In FP32: $110\text{M} \times 4 = 440\text{ MB}$
Larger models scale dramatically: GPT-2 XL (1.5B parameters) already needs 6 GB in FP32.
GPT-3 in FP32 requires 700 GB just for parameters—far exceeding single GPU memory (A100 has 80 GB). This necessitates:
- Model parallelism: Split model across multiple GPUs
- Mixed precision: Use FP16/BF16 (350 GB for GPT-3)
- Quantization: INT8 inference (175 GB for GPT-3)
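The precision arithmetic above is a one-line calculation; this helper (names are illustrative) reproduces the GPT-3 figures:

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}

def param_memory_gb(n_params, dtype):
    """Parameter memory in GB (decimal) for a given numerical precision."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9

# GPT-3 (175B): 700 GB in FP32, 350 GB in FP16, 175 GB in INT8
for dtype in ("fp32", "fp16", "int8"):
    print(f"GPT-3 ({dtype}): {param_memory_gb(175e9, dtype):.0f} GB")
```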
Activation Memory
During training, intermediate activations must be stored for backpropagation. Activation memory scales with batch size and sequence length, often dominating memory consumption.
Activations per transformer layer:
- Input to layer: $B \times n \times d_{\text{model}}$
- Query, Key, Value: $3 \times B \times n \times d_{\text{model}}$
- Attention scores: $B \times h \times n \times n$ (quadratic in sequence length!)
- Attention output: $B \times n \times d_{\text{model}}$
- FFN intermediate: $B \times n \times d_{ff}$
- Layer norm activations: $2 \times B \times n \times d_{\text{model}}$
Total activation memory per layer (approximate):
$$M_{\text{layer}} \approx \big(\underbrace{B \times n \times (7d_{\text{model}} + d_{ff})}_{\text{linear-sized tensors}} + \underbrace{B \times h \times n^2}_{\text{attention matrices}}\big) \times 4\text{ bytes (FP32)}$$
For BERT-base ($B=32$, $n=512$, $d_{\text{model}}=768$, $h=12$, $d_{ff}=3072$):
$$M_{\text{layer}} \approx (32 \times 512 \times 8448 + 32 \times 12 \times 512^2) \times 4 \approx 554\text{ MB} + 403\text{ MB} \approx 1\text{ GB}$$
For 12 layers: $12 \times 1\text{ GB} = 12\text{ GB}$ just for activations!
Impact of sequence length: The attention matrix term $B \times h \times n^2$ grows quadratically. For $n=2048$ (4× longer):
$$32 \times 12 \times 2048^2 \times 4\text{ bytes} \approx 6.4\text{ GB per layer}$$
For 12 layers: $77\text{ GB}$ just for attention matrices—nearly filling an A100 GPU!
This quadratic scaling is why:
- Long-context models require gradient checkpointing (recompute activations during backward pass)
- Efficient attention mechanisms (sparse, linear) are crucial for long sequences
- Batch sizes must be reduced for longer sequences
Gradient checkpointing trade-off: Recomputing activations during backward pass:
- Memory savings: Reduce activation memory by $\sim$80\%
- Compute cost: Increase training time by $\sim$20-30\%
- When to use: When memory-constrained, especially for long sequences
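The activation estimates above can be folded into a small estimator (function name and signature are illustrative), which makes the quadratic blow-up at long context immediately visible:

```python
def activation_memory_gb(batch, n, d_model, n_heads, d_ff, n_layers, bytes_per=4):
    """Approximate stored activations: linear-sized tensors (input, QKV,
    attention output, two layer norms, FFN intermediate) plus the quadratic
    attention score matrices. Dropout masks etc. are not counted."""
    linear = batch * n * (7 * d_model + d_ff)
    attn = batch * n_heads * n * n
    return n_layers * (linear + attn) * bytes_per / 1e9

# BERT-base, batch 32: about 12 GB at n=512, over 100 GB at n=2048
print(activation_memory_gb(32, 512, 768, 12, 3072, 12))
print(activation_memory_gb(32, 2048, 768, 12, 3072, 12))
```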
Per-layer activation breakdown (batch size $B=1$):
QKV projections: $3 \times 512 \times 768 \times 4 \approx 4.7\text{ MB}$
Attention matrices (12 heads): $12 \times 512^2 \times 4 \approx 12.6\text{ MB}$
This is the dominant term! For $n=2048$: $12 \times 2048^2 \times 4 = 201\text{ MB}$ (16× larger, since the cost is quadratic in $n$).
Attention output: $512 \times 768 \times 4 \approx 1.6\text{ MB}$
FFN intermediate: $512 \times 3072 \times 4 \approx 6.3\text{ MB}$
Layer norm and residuals: $2 \times 512 \times 768 \times 4 \approx 3.1\text{ MB}$
Total per layer: the tensors listed sum to $\approx 28\text{ MB}$; a full accounting that also stores softmax outputs, dropout masks, and pre-normalization inputs roughly triples this in practice, to $\approx 85\text{ MB}$ per layer.
12 layers: $12 \times 85\text{ MB} \approx 1\text{ GB}$ for a single sequence
Batch size scaling: Activation memory grows linearly with batch size---at roughly 1 GB per sequence, batch size $B$ adds $\approx B$ GB on top of parameters, gradients, and optimizer states.
Hardware implications:
- NVIDIA V100 (16 GB): Maximum batch size $\approx 12-14$ (accounting for parameters and optimizer states)
- NVIDIA A100 (40 GB): Maximum batch size $\approx 30-35$
- NVIDIA A100 (80 GB): Maximum batch size $\approx 70-75$
Gradient checkpointing impact: With checkpointing, only store activations at layer boundaries, recompute within layers during backward pass. The checkpoints themselves are tiny ($12 \times 512 \times 768 \times 4 \approx 19\text{ MB}$), and only one layer's activations are live at a time, bringing the total to roughly 0.2 GB per sequence (an $\sim$80\% reduction).
This allows batch size 64 on V100 (16 GB), but increases training time by $\sim$25\%.
Mixed precision training: Using FP16 for activations (FP32 for parameters) halves activation memory, to $\approx 500\text{ MB}$ per sequence.
Combined with gradient checkpointing: $500 \times 0.2 = 100\text{ MB per sequence}$, enabling very large batch sizes.
Training Memory Budget
Training requires memory for parameters, gradients, optimizer states, and activations. Understanding this breakdown is essential for selecting hardware and configuring training.
Total training memory:
$$M_{\text{total}} = M_{\text{params}} + M_{\text{gradients}} + M_{\text{optimizer}} + M_{\text{activations}}$$
For AdamW optimizer (most common for transformers):
- Model parameters: $P$ parameters $\times$ 4 bytes (FP32) = $4P$ bytes
- Gradients: $P$ parameters $\times$ 4 bytes = $4P$ bytes
- First moment (momentum): $P$ parameters $\times$ 4 bytes = $4P$ bytes
- Second moment (variance): $P$ parameters $\times$ 4 bytes = $4P$ bytes
- Activations: $A$ bytes (depends on batch size, sequence length, model depth)
Total: $16P + A$ bytes
Mixed precision training (FP16/BF16 with FP32 master weights):
- FP16 parameters (forward/backward): $2P$ bytes
- FP32 master parameters: $4P$ bytes
- FP32 gradients: $4P$ bytes
- FP32 optimizer states: $8P$ bytes
- FP16 activations: $A/2$ bytes
Total: $18P + A/2$ bytes
Surprisingly, mixed precision uses slightly MORE memory for parameters/optimizer (18P vs 16P) but saves significantly on activations ($A/2$ vs $A$). Since activations often dominate, mixed precision typically reduces total memory.
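The two budgets ($16P + A$ vs $18P + A/2$) can be compared directly; the sketch below (names are illustrative) reproduces the BERT-base numbers used in this section:

```python
def training_memory_gb(n_params, activations_gb, mixed_precision=False):
    """AdamW training memory. FP32: 4 bytes each for params, grads, and two
    moments (16P total). Mixed precision adds a 2-byte FP16 parameter copy
    (18P total) but halves activation memory."""
    if mixed_precision:
        return 18 * n_params / 1e9 + activations_gb / 2
    return 16 * n_params / 1e9 + activations_gb

# BERT-base with 12 GB of FP32 activations (batch 32, n=512)
print(training_memory_gb(110e6, 12.0))                        # FP32: ~13.8 GB
print(training_memory_gb(110e6, 12.0, mixed_precision=True))  # mixed: ~8 GB
```

The crossover is visible in the formula: mixed precision wins whenever $A > 4P$, i.e., whenever activations dominate parameters, which is the common case.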
FP32 training (BERT-base, batch size 32, $n=512$):
$$16P + A = 1.76\text{ GB} + 12\text{ GB} \approx 13.8\text{ GB}$$
Fits on: NVIDIA V100 (16 GB), A100 (40/80 GB), RTX 3090 (24 GB)
Mixed precision training:
$$18P + A/2 = 1.98\text{ GB} + 6\text{ GB} \approx 8\text{ GB}$$
Mixed precision saves: $13.8 - 8 = 5.8\text{ GB}$ (42\% reduction)
With gradient checkpointing: Activations reduced by 80\%:
$$1.98 + 6 \times 0.2 \approx 3.2\text{ GB}$$
This enables batch size 128 on V100 (16 GB)!
Worked example: GPT-3 (175B parameters, $L=96$, $d_{\text{model}}=12{,}288$, $n=2048$).
Parameters and optimizer (FP32): $16P = 16 \times 175\text{B} = 2{,}800\text{ GB}$
Activations (batch size 1, single sequence): $\approx 704\text{ MB}$ per layer for the main $d_{\text{model}}$-sized tensors alone
96 layers: $96 \times 704\text{ MB} \approx 68\text{ GB per sequence}$
Total for batch size 1: $2{,}800 + 68 = 2{,}868\text{ GB}$
Hardware requirements:
- Single A100 (80 GB): Impossible—need 36 GPUs just for parameters!
- Model parallelism: Split across 8 GPUs: $2{,}868 / 8 = 359\text{ GB per GPU}$—still too large!
- Mixed precision + model parallelism: $\approx 1{,}500\text{ GB total} / 8 = 188\text{ GB per GPU}$—still too large!
- Mixed precision + model parallelism + gradient checkpointing: $\approx 800\text{ GB} / 8 = 100\text{ GB per GPU}$—still exceeds A100!
Actual GPT-3 training: Used ZeRO optimizer (shards optimizer states across GPUs) + model parallelism + pipeline parallelism across thousands of GPUs.
This example illustrates why training models beyond $\sim$10B parameters requires sophisticated distributed training strategies.
Hardware Selection Guide
GPU memory requirements by model size (mixed precision + gradient checkpointing):
| Model Size | Parameters | Min GPU Memory | Recommended GPU |
|---|---|---|---|
| Small | 100M | 8 GB | RTX 3070, V100 |
| Base | 300M | 12 GB | RTX 3080, V100 |
| Large | 1B | 24 GB | RTX 3090, A5000 |
| XL | 3B | 40 GB | A100 (40 GB) |
| XXL | 10B | 80 GB | A100 (80 GB) |
| 175B (GPT-3) | 175B | 8× A100 (80 GB) | Multi-node cluster |
Inference memory requirements (FP16):
- Parameters only: $2P$ bytes
- KV cache (autoregressive): $2 \times L \times h \times n_{\max} \times d_k \times B \times 2\text{ bytes}$
- Activations (single forward pass): Minimal compared to training
For GPT-2 (117M params) inference: parameters $117\text{M} \times 2 = 234\text{ MB}$; KV cache at $n_{\max}=1024$: $2 \times 12 \times 12 \times 1024 \times 64 \times 2 \approx 38\text{ MB}$; total well under 300 MB.
GPT-2 inference easily fits on consumer GPUs or even CPUs!
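The inference footprint is just weights plus cache; this helper (names and defaults are illustrative) reproduces the GPT-2 figures:

```python
def inference_memory_mb(n_params, n_layers, n_heads, d_k, n_max,
                        batch=1, bytes_per=2):
    """FP16 inference footprint: model weights plus the K/V cache
    (2 tensors per layer, one entry per head per position)."""
    params = n_params * bytes_per
    kv_cache = 2 * n_layers * n_heads * n_max * d_k * batch * bytes_per
    return params / 1e6, kv_cache / 1e6

# GPT-2 small: 117M params, 12 layers, 12 heads, d_k=64, 1024-token context
params_mb, cache_mb = inference_memory_mb(117e6, 12, 12, 64, 1024)
print(f"weights: {params_mb:.0f} MB, KV cache: {cache_mb:.1f} MB")
```

Scaling the same formula to GPT-3 dimensions shows why the cache, not the weights, becomes the batch-size bottleneck for long contexts.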
Inference Optimization
Inference optimization is critical for deploying transformers in production. Unlike training, which prioritizes throughput (tokens/second across large batches), inference prioritizes latency (time to generate a single response) while maintaining reasonable throughput. We analyze key optimization techniques and their trade-offs.
KV Caching for Autoregressive Decoding
Autoregressive generation (used in GPT, decoder-only models) generates tokens sequentially, where each new token attends to all previous tokens. Naive implementation recomputes attention for all previous positions at each step—highly inefficient.
Problem analysis: Generating sequence of length $T$ tokens:
- Step 1: Compute attention for position 1 (attends to position 1)
- Step 2: Compute attention for position 2 (attends to positions 1-2)
- Step 3: Compute attention for position 3 (attends to positions 1-3)
- Step $T$: Compute attention for position $T$ (attends to positions 1-$T$)
Total attention computations: $\sum_{t=1}^{T} t = \frac{T(T+1)}{2} \approx \frac{T^2}{2}$
For $T=1000$ tokens: $\approx 500{,}000$ attention computations!
KV Caching solution: Key and value projections depend only on input tokens, not on the query position. Cache $\mK$ and $\mV$ from previous steps; at step $t$, compute only the new key and value, append them to the cache, and attend with the single new query $q_t$:
$$\text{Attention}(q_t, \mK_{1:t}, \mV_{1:t}) = \text{softmax}\!\left(\frac{q_t \mK_{1:t}\transpose}{\sqrt{d_k}}\right)\mV_{1:t}$$
Computational savings: With caching, each step computes attention once (not recomputing previous positions), so generating $T$ tokens costs $T$ attention computations instead of $\approx T^2/2$.
Speedup: For $T=1000$: $\frac{500{,}000}{1{,}000} = 500\times$ faster!
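The counting argument is worth a two-line sanity check (function name is illustrative):

```python
def attention_steps(T, cached):
    """Per-position attention computations needed to generate T tokens."""
    if cached:
        return T                 # each new token attends once over the cache
    return T * (T + 1) // 2      # recompute attention over every prefix

T = 1000
print(attention_steps(T, cached=False) / attention_steps(T, cached=True))
```

For $T=1000$ the ratio is $500.5$, matching the $\approx 500\times$ figure above.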
Memory cost: Store keys and values for all positions and layers:
$$M_{\text{cache}} = 2 \times L \times h \times n \times d_k \times B \times \text{bytes per element}$$
For GPT-2 ($L=12$, $h=12$, $d_k=64$, FP16):
$$M_{\text{cache}} = 2 \times 12 \times 768 \times 2 \approx 37\text{ KB per token}, \qquad \approx 38\text{ MB at } n=1024$$
Memory scaling with sequence length: The cache grows linearly in $n$ (and in batch size)---far better than the quadratic attention matrices, but for large models it still becomes the dominant inference cost.
For GPT-3, the cache reaches 9.7~GB per sequence (see Chapter~[ref] for the derivation), nearly filling an A100 for batch size 8 ($77.6$~GB). Practitioners manage this through smaller batch sizes for long contexts, dynamic batching, and INT8 quantization to reduce cache size by 2--4$\times$.
Batched Inference
Processing multiple sequences simultaneously increases GPU utilization and throughput.
Single sequence inference: For GPT-2 generating 100 tokens:
- Compute: $\approx 100 \times 8\text{ GFLOPs} = 800\text{ GFLOPs}$
- Time on A100: $\frac{800\text{ GFLOPs}}{312\text{ TFLOPS} \times 0.3} \approx 8.5\text{ ms}$
- GPU utilization: $\approx 30\%$ (memory-bound, not compute-bound)
Batched inference (batch size 32):
- Compute: $32 \times 800\text{ GFLOPs} = 25{,}600\text{ GFLOPs}$
- Time on A100: $\frac{25{,}600\text{ GFLOPs}}{312\text{ TFLOPS} \times 0.7} \approx 117\text{ ms}$
- GPU utilization: $\approx 70\%$ (much better!)
- Throughput: $\frac{32 \times 100}{117\text{ ms}} \approx 27{,}350\text{ tokens/sec}$
Latency vs. throughput trade-off:
- Batch size 1: Latency = 8.5 ms, throughput = 11,765 tokens/sec
- Batch size 32: Latency = 117 ms (13.8× worse), throughput = 27,350 tokens/sec (2.3× better)
Padding challenge: Sequences in a batch must have the same length. Shorter sequences are padded, wasting computation:
- Sequence lengths: [512, 256, 128, 64]
- Padded to: [512, 512, 512, 512]
- Wasted computation: $(512-256) + (512-128) + (512-64) = 1088$ of $2048$ positions ($\approx$53\%!)
Solutions:
- Dynamic batching: Group sequences of similar lengths
- Bucket batching: Pre-defined length buckets (128, 256, 512, 1024)
- Packed sequences: Concatenate sequences without padding (requires careful attention masking)
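The padding waste is easy to quantify for any candidate batch (function name is illustrative), which is useful when tuning bucket boundaries:

```python
def padding_waste(lengths):
    """Fraction of computed positions spent on padding when a batch is
    padded to its longest sequence."""
    max_len = max(lengths)
    wasted = sum(max_len - length for length in lengths)
    return wasted / (max_len * len(lengths))

print(padding_waste([512, 256, 128, 64]))  # over half the batch is padding
print(padding_waste([512, 480, 500, 512]))  # similar lengths waste little
```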
Quantization for Inference
Quantization reduces memory and increases throughput by using lower-precision arithmetic.
Precision options:
- FP32: 4 bytes, full precision
- FP16/BF16: 2 bytes, half precision (1.5-2× speedup)
- INT8: 1 byte, 8-bit integer (2-4× speedup, 4× memory reduction)
- INT4: 0.5 bytes, 4-bit integer (4-8× speedup, 8× memory reduction)
INT8 quantization: Map FP32 weights $w \in [-w_{\max}, w_{\max}]$ to INT8 $w_q \in [-128, 127]$:
$$w_q = \text{round}\!\left(\frac{w}{w_{\max}} \times 127\right)$$
Dequantize during computation:
$$\hat{w} = w_q \times \frac{w_{\max}}{127}$$
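A minimal round trip of this symmetric per-tensor scheme (NumPy sketch; helper names are illustrative) shows the error bound is half the quantization step:

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor INT8 quantization with scale w_max / 127."""
    scale = np.abs(w).max() / 127.0
    w_q = np.clip(np.round(w / scale), -128, 127).astype(np.int8)
    return w_q, scale

def dequantize(w_q, scale):
    return w_q.astype(np.float32) * scale

# A weight matrix at a typical initialization scale
w = (np.random.randn(768, 768) * 0.02).astype(np.float32)
w_q, scale = quantize_int8(w)
err = np.abs(dequantize(w_q, scale) - w).max()
print(f"max abs error: {err:.2e} (bound: scale/2 = {scale / 2:.2e})")
```

Production INT8 pipelines add per-channel scales and activation calibration on top of this basic idea.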
Quantization impact on GPT-2 (117M parameters): weights shrink from 468 MB (FP32) to 234 MB (FP16) to 117 MB (INT8).
Accuracy trade-offs:
- FP16/BF16: Negligible accuracy loss (<0.1\% perplexity increase)
- INT8: Small accuracy loss (0.5-2\% perplexity increase) with calibration
- INT4: Moderate accuracy loss (2-5\% perplexity increase), requires careful quantization
Hardware support:
- NVIDIA Tensor Cores: Accelerate FP16/BF16 (up to 2× speedup)
- NVIDIA INT8 Tensor Cores: Accelerate INT8 (up to 4× speedup)
- CPU AVX-512 VNNI: Accelerate INT8 on CPUs
Model Distillation
Train smaller "student" model to mimic larger "teacher" model:
- DistilBERT: 66M params (vs. BERT-base 110M), 97\% performance, 2× faster
- TinyBERT: 14M params, 96\% performance, 7× faster
Distillation enables deployment on resource-constrained devices (mobile, edge).
Inference Optimization Summary
| Technique | Speedup | Memory Reduction | Accuracy Impact |
|---|---|---|---|
| KV Caching | 100-500× | None (adds cache memory) | None |
| Batching (32×) | 2-3× throughput | None | None |
| FP16/BF16 | 1.5-2× | 2× | Negligible |
| INT8 Quantization | 2-4× | 4× | Small (0.5-2\%) |
| INT4 Quantization | 4-8× | 8× | Moderate (2-5\%) |
| Distillation | 2-7× | 2-8× | Small (3-4\%) |
Combined optimizations: KV caching + FP16 + batching + INT8 can achieve 1000× speedup with minimal accuracy loss!
Scaling Laws
Performance scales as power laws with model size $N$, dataset size $D$, and compute budget $C$. The Kaplan scaling laws established that larger models are more sample-efficient, while the Chinchilla scaling laws (Hoffmann et al., 2022) refined the optimal allocation: $N_{\text{opt}} \propto C^{0.5}$ and $D_{\text{opt}} \propto C^{0.5}$, implying that many large models are over-parameterized and under-trained. See Section~[ref] for the full treatment including GPT-3 vs Chinchilla analysis and practical implications for training strategy.
Exercises
Solutions
Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.
Given: GPT-3 with $P = 175B$ parameters, $L = 96$ layers, $d_{\text{model}} = 12{,}288$, $h = 96$ heads, $n = 2048$ sequence length
Part (1): Single Forward Pass
For a transformer, FLOPs per forward pass: $$\text{FLOPs}_{\text{fwd}} = 2 \times B \times n \times P$$
where $B$ is batch size. For $B = 1$:
Breakdown by component:
Breakdown by component (per-layer formulas from this chapter, summed over $L=96$ layers):
- Attention matmuls ($\mQ\mK\transpose$ and $\mA\mV$): $4n^2d \times 96 = 4 \times 2048^2 \times 12{,}288 \times 96 \approx 19.8$ TFLOPs
- Feed-forward: $16nd^2 \times 96 = 16 \times 2048 \times 12{,}288^2 \times 96 \approx 475$ TFLOPs
- Attention and output projections: $8nd^2 \times 96 \approx 237$ TFLOPs
Note: The quadratic attention term is only $\sim$3\% of total computation because $n = 2048$ is small relative to $2d_{\text{model}} = 24{,}576$; the dense projections and FFN dominate.
Part (2): Generating 100 Tokens Autoregressively
For autoregressive generation, each token requires a forward pass through the decoder.
With KV caching, computation per token $t$: $$\text{FLOPs}_t \approx 2P + 4 \times L \times d \times n_{\text{ctx}}$$
where $n_{\text{ctx}}$ is the context length (grows with each token).
Without KV caching (recomputing everything): $$\text{FLOPs}_{\text{total}} = \sum_{t=1}^{100} 2 \times (n_0 + t) \times P$$
where $n_0$ is initial prompt length. Assuming $n_0 = 50$:
With KV caching (optimal): $$\text{FLOPs}_{\text{cached}} \approx 100 \times 2P = 100 \times 2 \times 175 \times 10^9 = 3.5 \times 10^{13} = 35 \text{ TFLOPs}$$
Speedup from KV caching: $\frac{3.52 \times 10^{15}}{3.5 \times 10^{13}} \approx 100\times$
Part (3): Training on 1 Trillion Tokens
Training FLOPs formula: $$\text{FLOPs}_{\text{train}} = 6 \times P \times D$$
where $D$ is the number of training tokens.
For $D = 1$ trillion = $10^{12}$ tokens:
Training time estimation:
On 1024 A100 GPUs (312 TFLOPS each):
- Total compute: $1024 \times 312 \times 10^{12} = 3.19 \times 10^{17}$ FLOPS
- Utilization: $\sim$50\% (realistic for large-scale training)
- Effective compute: $1.60 \times 10^{17}$ FLOPS
- Training time: $\frac{1.05 \times 10^{24}}{1.60 \times 10^{17}} = 6.56 \times 10^6$ seconds
- $= 76$ days
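The training-time and cost arithmetic above can be packaged as a reusable estimator (names and defaults are illustrative), applying the $6PD$ rule:

```python
def training_cost(n_params, n_tokens, n_gpus, tflops_per_gpu=312,
                  utilization=0.5, usd_per_gpu_hour=2.50):
    """Estimate wall-clock days and dollars for one training run,
    using FLOPs = 6 * P * D and a flat utilization factor."""
    flops = 6 * n_params * n_tokens
    seconds = flops / (n_gpus * tflops_per_gpu * 1e12 * utilization)
    days = seconds / 86400
    cost = n_gpus * (seconds / 3600) * usd_per_gpu_hour
    return days, cost

# GPT-3-scale run: 175B params, 1T tokens, 1024 A100s
days, cost = training_cost(175e9, 1e12, 1024)
print(f"{days:.0f} days, ${cost / 1e6:.1f}M")
```

The estimate matches the figures above ($\approx$76 days, $\approx$\$4.7M); real runs add restarts, evaluation, and ablation overhead on top.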
Cost estimation:
At \$2.50/GPU-hour (cloud pricing): $$\text{Cost} = 1024 \times 76 \times 24 \times 2.50 = \$4{,}669{,}440 \approx \$4.7M$$
Summary:
- Single forward pass: 717 TFLOPs
- 100 token generation (with caching): 35 TFLOPs
- Training on 1T tokens: $1.05 \times 10^{24}$ FLOPs, 76 days on 1024 A100s, \$4.7M
Given: $P = 1.3B$ parameters, batch size $B = 64$, sequence length $L = 2048$
Model Parameters:
Parameters (FP32): $1.3 \times 10^9 \times 4 = 5.2$GB
Optimizer States (AdamW):
- Gradients: $5.2$GB
- First moment: $5.2$GB
- Second moment: $5.2$GB
- Total optimizer: $15.6$GB
Activations:
Assuming model architecture: $d_{\text{model}} = 2048$, $L_{\text{layers}} = 24$, $d_{ff} = 8192$
Per layer activations:
- Attention scores: $B \times h \times L \times L \times 4\text{ bytes} = 64 \times 32 \times 2048 \times 2048 \times 4 = 34.4$GB
- Attention output: $B \times L \times d \times 4\text{ bytes} = 64 \times 2048 \times 2048 \times 4 = 1.07$GB
- FFN intermediate: $B \times L \times d_{ff} \times 4\text{ bytes} = 64 \times 2048 \times 8192 \times 4 = 4.29$GB
- Residuals: $2 \times 1.07 = 2.14$GB
- Total per layer: $41.9$GB
For 24 layers: $24 \times 41.9 = 1{,}005.6$GB
With gradient checkpointing (store every 4 layers): $$\text{Activations} = \frac{24}{4} \times 41.9 = 251.4\text{GB}$$
Total Memory Required:
Without checkpointing: $5.2 + 15.6 + 1{,}005.6 = 1{,}026.4$GB
With checkpointing: $5.2 + 15.6 + 251.4 = 272.2$GB
Fitting on A100 (80GB):
Current requirement: 272.2GB (too large!)
Strategy 1: Reduce Batch Size
Try $B = 16$ (4$\times$ reduction): $$\text{Activations} = \frac{251.4}{4} = 62.9\text{GB}$$ $$\text{Total} = 5.2 + 15.6 + 62.9 = 83.7\text{GB}$$ (still too large)
Try $B = 8$: $$\text{Activations} = 31.4\text{GB}$$ $$\text{Total} = 5.2 + 15.6 + 31.4 = 52.2\text{GB}$$ ✓ Fits!
Strategy 2: Mixed Precision (FP16)
- Parameters (FP16): $1.3 \times 10^9 \times 2 = 2.6$GB
- Optimizer states: $7.8$GB (master weights in FP32)
- Activations (FP16, $B=16$): $31.4$GB
- Total: $2.6 + 7.8 + 31.4 = 41.8$GB ✓ Fits!
Strategy 3: ZeRO Stage 2 (Optimizer State Sharding)
With 4 GPUs (data parallel, per-GPU batch 16), shard optimizer states:
- Parameters: $5.2$GB per GPU
- Optimizer states: $15.6/4 = 3.9$GB per GPU
- Activations ($B=64$ global, $B=16$ per GPU): $62.9$GB per GPU
- Total per GPU: $5.2 + 3.9 + 62.9 = 72.0$GB ✓ Fits!
Recommended Configuration:
- Single A100: $B=8$, FP16, gradient checkpointing
- 4× A100: $B=64$, FP16, ZeRO Stage 2, gradient checkpointing
- Effective batch size 64 achievable with gradient accumulation (8 steps)
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class GPT2WithCache(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)  # token embeddings
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # position embeddings (attribute name assumed)
        self.transformer = nn.ModuleList([
            GPT2Block(config) for _ in range(config.n_layer)
        ])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
    def forward(self, input_ids, past_key_values=None, use_cache=False):
        # Position indices must account for tokens already in the cache
        past_length = past_key_values[0][0].size(2) if past_key_values else 0
        positions = torch.arange(past_length, past_length + input_ids.size(1),
                                 device=input_ids.device)
        hidden_states = self.wte(input_ids) + self.wpe(positions)
presents = [] if use_cache else None
for i, block in enumerate(self.transformer):
past = past_key_values[i] if past_key_values else None
hidden_states, present = block(
hidden_states,
past_key_value=past,
use_cache=use_cache
)
if use_cache:
presents.append(present)
hidden_states = self.ln_f(hidden_states)
logits = self.lm_head(hidden_states)
return logits, presents
class GPT2Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        )
    def forward(self, x, past_key_value=None, use_cache=False):
        # GPT-2 is pre-norm: normalize before each sub-layer, then add residual
        attn_output, present = self.attn(
            self.ln_1(x),
            past_key_value=past_key_value,
            use_cache=use_cache
        )
        x = x + attn_output
        # Feed-forward
        x = x + self.mlp(self.ln_2(x))
        return x, present
class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)  # fused QKV projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
    def forward(self, x, past_key_value=None, use_cache=False):
        B, T, C = x.shape
        # Compute Q, K, V with a single projection, then split
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape to (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # Prepend cached K, V if available
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        # Store K, V for next iteration
        present = (k, v) if use_cache else None
        # Scaled dot-product attention with a causal mask over the full key length
        S = k.size(2)  # cached tokens + new tokens
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=x.device))[S - T:]
        attn = attn.masked_fill(~causal, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        # Merge heads and project back to model dimension
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.c_proj(out)
        return out, present
Benchmarking Code:
def generate_without_cache(model, prompt, max_length=256):
"""Generate tokens without KV caching"""
input_ids = prompt
times = []
for _ in range(max_length):
start = time.time()
logits, _ = model(input_ids, use_cache=False)
times.append(time.time() - start)
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=1)
return input_ids, times
def generate_with_cache(model, prompt, max_length=256):
"""Generate tokens with KV caching"""
input_ids = prompt
past_key_values = None
times = []
for i in range(max_length):
start = time.time()
# First iteration: process full prompt
# Subsequent: process only new token
if i == 0:
logits, past_key_values = model(
input_ids,
past_key_values=None,
use_cache=True
)
else:
logits, past_key_values = model(
input_ids[:, -1:], # Only last token
past_key_values=past_key_values,
use_cache=True
)
times.append(time.time() - start)
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=1)
return input_ids, times
# Run benchmark
prompt = torch.randint(0, 50257, (1, 50)) # 50 token prompt
output_no_cache, times_no_cache = generate_without_cache(
model, prompt, max_length=256
)
output_with_cache, times_with_cache = generate_with_cache(
model, prompt, max_length=256
)
print(f"Without cache: {sum(times_no_cache):.2f}s")
print(f"With cache: {sum(times_with_cache):.2f}s")
print(f"Speedup: {sum(times_no_cache)/sum(times_with_cache):.2f}x")
Experimental Results:
For GPT-2 small (124M parameters), generating 256 tokens:
| Method | Time (s) | Speedup |
|---|---|---|
| Without cache | 45.3 | 1.0$\times$ |
| With cache | 2.8 | 16.2$\times$ |
Generation Time vs Sequence Length:
# Benchmark different sequence lengths
seq_lengths = [32, 64, 128, 256, 512]
times_no_cache = []
times_with_cache = []
for length in seq_lengths:
_, t_no = generate_without_cache(model, prompt, max_length=length)
_, t_with = generate_with_cache(model, prompt, max_length=length)
times_no_cache.append(sum(t_no))
times_with_cache.append(sum(t_with))
# Plot results
plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, times_no_cache, 'o-', label='Without cache', linewidth=2)
plt.plot(seq_lengths, times_with_cache, 's-', label='With cache', linewidth=2)
plt.xlabel('Sequence Length')
plt.ylabel('Generation Time (seconds)')
plt.title('KV Caching Impact on Generation Speed')
plt.legend()
plt.grid(True)
plt.savefig('kv_cache_speedup.png')
Analysis:
Without caching, time complexity: $O(n^2)$ where $n$ is sequence length $$T_{\text{no cache}} = \sum_{i=1}^n c \cdot i = c \cdot \frac{n(n+1)}{2} \approx O(n^2)$$
With caching, time complexity: $O(n)$ $$T_{\text{cache}} = c \cdot n$$
Speedup grows with sequence length: $$\text{Speedup}(n) = \frac{n(n+1)/2}{n} = \frac{n+1}{2} \approx O(n)$$
For $n=256$: Speedup $\approx 257/2 \approx 128\times$ (theoretical)
Actual speedup (16.2$\times$) is lower due to:
- Memory bandwidth bottleneck (loading cached K, V)
- Overhead of cache management
- Other non-attention computations (FFN, embeddings)
Memory Cost:
KV cache size: $2 \times L \times B \times n \times d = 2 \times 12 \times 1 \times 256 \times 768 \approx 4.7\text{M}$ values, i.e., $\approx 18.9$MB in FP32 (9.4MB in FP16)
Small memory cost for massive speedup makes KV caching essential for inference.
Given: Fixed compute budget $C = 10^{24}$ FLOPs
Chinchilla Scaling Law:
For optimal training, model size $N$ (parameters) and dataset size $D$ (tokens) should scale as: $$N_{\text{opt}} \propto C^{0.5}, \quad D_{\text{opt}} \propto C^{0.5}$$
More precisely, the Chinchilla paper's fitted exponents are close to $0.5$ for both quantities (their parametric fit gives roughly $N_{\text{opt}} \propto C^{0.46}$ and $D_{\text{opt}} \propto C^{0.54}$). In practical terms, the compute-optimal recipe works out to roughly 20 training tokens per parameter.
Optimal Allocation for $C = 10^{24}$ FLOPs:
Training FLOPs: $C = 6ND$
Solving $C = 6ND$ under the compute-optimal ratio $D \approx 20N$: $$N_{\text{opt}} = \left(\frac{C}{6 \times 20}\right)^{0.5} = \left(\frac{10^{24}}{120}\right)^{0.5} = 9.1 \times 10^{10} \approx 91B \text{ parameters}$$
$$D_{\text{opt}} = \frac{C}{6N_{\text{opt}}} = \frac{10^{24}}{6 \times 9.1 \times 10^{10}} = 1.83 \times 10^{12} \approx 1{,}830B \text{ tokens}$$
Verification: $6 \times 91 \times 10^9 \times 1{,}830 \times 10^9 \approx 1.00 \times 10^{24}$ ✓
Chinchilla Optimal Ratio:
$$\frac{D_{\text{opt}}}{N_{\text{opt}}} = \frac{1{,}830B}{91B} \approx 20$$
Chinchilla recommends: roughly 20 training tokens per parameter
GPT-3 Allocation:
GPT-3 used: $N = 175B$ parameters, $D = 300B$ tokens
Compute used: $C_{\text{GPT-3}} = 6 \times 175 \times 10^9 \times 300 \times 10^9 = 3.15 \times 10^{23}$ FLOPs
For the same compute budget ($C = 10^{24}$), GPT-3 approach would scale to: $$N_{\text{GPT-3}} = 175B \times \left(\frac{10^{24}}{3.15 \times 10^{23}}\right)^{0.5} = 175B \times 1.78 = 311B$$ $$D_{\text{GPT-3}} = 300B \times 1.78 = 534B$$
Ratio: $\frac{534B}{311B} = 1.72$ tokens per parameter
Comparison:
| Approach | Parameters | Tokens | Ratio |
|---|---|---|---|
| Chinchilla optimal | 91B | 1,830B | 20 |
| GPT-3 scaling | 311B | 534B | 1.72 |
| Difference | 3.4× more | 3.4× fewer | - |
Key Insights:
- GPT-3 was severely undertrained: Used only 1.72 tokens/param vs optimal $\approx$20
- Chinchilla approach: Smaller model, more data
- Performance impact: Chinchilla-optimal models achieve better performance at same compute
- Practical implications:
- Training cost dominated by data, not model size
- Larger models need proportionally more data
- Many large models (GPT-3, Gopher) were undertrained
Expected Performance:
Using Chinchilla scaling law for loss prediction: $$L(N, D) = E + \frac{A}{N^\alpha} + \frac{B}{D^\beta}$$
where $\alpha \approx 0.34$, $\beta \approx 0.28$, $E \approx 1.69$ (irreducible loss).
For optimal allocation (with $N$ in parameters and $D$ in tokens): $$L_{\text{Chinchilla}} \approx 1.69 + \frac{406.4}{(9.1 \times 10^{10})^{0.34}} + \frac{410.7}{(1.83 \times 10^{12})^{0.28}} \approx 1.69 + 0.08 + 0.15 \approx 1.92$$
For GPT-3 scaling: $$L_{\text{GPT-3}} \approx 1.69 + \frac{406.4}{(3.11 \times 10^{11})^{0.34}} + \frac{410.7}{(5.34 \times 10^{11})^{0.28}} \approx 1.69 + 0.05 + 0.21 \approx 1.95$$
The Chinchilla-optimal allocation achieves roughly 2\% lower loss with the same compute budget.