Computational Analysis of Transformers

Chapter Overview

Understanding computational requirements is crucial for deploying transformers. This chapter analyzes time and space complexity, memory footprints, and inference costs. We derive exact FLOP counts, memory requirements, and scaling laws for transformers of different sizes.

Learning Objectives

  1. Calculate FLOPs for transformer forward and backward passes
  2. Analyze memory requirements for training and inference
  3. Understand scaling laws for model size, data, and compute
  4. Optimize inference through batching and caching
  5. Estimate training time and costs for large models

Computational Complexity

Understanding the computational complexity of transformers is essential for making informed decisions about model architecture, hardware requirements, and deployment strategies. The transformer's computational profile differs fundamentally from recurrent architectures, trading sequential dependencies for quadratic memory scaling—a trade-off that profoundly impacts both training and inference.

Self-Attention Complexity

Self-attention is the defining operation of transformers, and its computational characteristics determine much of the model's behavior. For a sequence of length $n$ with model dimension $d_{\text{model}}$, we analyze each component of the attention mechanism in detail.

QKV Projections: The first step projects the input $\mX \in \R^{n \times d_{\text{model}}}$ into query, key, and value spaces. Each projection is a matrix multiplication:

$$ \mQ = \mX \mW^Q, \quad \mK = \mX \mW^K, \quad \mV = \mX \mW^V $$
where $\mW^Q, \mW^K, \mW^V \in \R^{d_{\text{model}} \times d_k}$ (typically $d_k = d_{\text{model}}/h$ for $h$ heads).

Each matrix multiplication $\mX \mW$ requires $2nd_{\text{model}}d_k$ floating-point operations (FLOPs): for each of $n \times d_k$ output elements, we perform $d_{\text{model}}$ multiply-add operations. With three projections:

$$ \text{QKV FLOPs} = 3 \times 2nd_{\text{model}}d_k = 6nd_{\text{model}}d_k $$

For the common case where $d_k = d_{\text{model}}$ (single-head or considering all heads together):

$$ \text{QKV FLOPs} = 6nd_{\text{model}}^2 $$

Why this matters for hardware: These are dense matrix multiplications, which achieve high utilization on modern GPUs. NVIDIA A100 GPUs can perform up to 312 TFLOPS (FP16 with Tensor Cores), meaning these projections are typically compute-bound rather than memory-bound. However, for small batch sizes or short sequences, the operations may become memory-bandwidth limited, achieving only 10-20\% of peak FLOPS.
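The projection FLOP counts can be checked with a few lines of Python (a quick sketch; the helper names are ours):

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m x k) @ (k x n) matrix multiply: one multiply-add
    (2 FLOPs) per element of the inner product."""
    return 2 * m * k * n

def qkv_projection_flops(n: int, d_model: int, h: int) -> int:
    """Total FLOPs for the Q, K, V projections across all h heads.
    Since h * d_k = d_model, this equals 6 * n * d_model^2."""
    d_k = d_model // h
    return 3 * h * matmul_flops(n, d_model, d_k)

# BERT-base: n = 512, d_model = 768
print(qkv_projection_flops(512, 768, 12))  # 1811939328 ≈ 1.81 GFLOPs
```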

Attention Score Computation: Computing $\mS = \mQ \mK\transpose$ involves multiplying $\mQ \in \R^{n \times d_k}$ by $\mK\transpose \in \R^{d_k \times n}$, producing $\mS \in \R^{n \times n}$:

$$ \text{Score FLOPs} = 2n^2d_k $$

The attention matrix $\mS$ has $n^2$ elements, and computing each requires $d_k$ multiply-add operations. This quadratic scaling in sequence length is the fundamental bottleneck for long-context transformers.

Dimension tracking example: For BERT-base with $n=512$, $d_k=64$ (per head), and $h=12$ heads:

$$\begin{align} \mQ^{(i)} &\in \R^{512 \times 64} \quad \text{(one head)} \\ \mK^{(i)\transpose} &\in \R^{64 \times 512} \\ \mS^{(i)} &= \mQ^{(i)} \mK^{(i)\transpose} \in \R^{512 \times 512} \quad \text{(262,144 elements!)} \end{align}$$

Across 12 heads, we compute 12 separate $512 \times 512$ attention matrices, requiring:

$$ 12 \times 2 \times 512^2 \times 64 = 402{,}653{,}184 \text{ FLOPs} \approx 403 \text{ MFLOPs} $$

Hardware implications: The attention matrix requires $n^2$ memory per head. For $n=512$ and 12 heads with FP32:

$$ 12 \times 512^2 \times 4\text{ bytes} = 12{,}582{,}912\text{ bytes} \approx 12\text{ MB} $$

This seems modest, but for $n=2048$ (GPT-2): $12 \times 2048^2 \times 4 = 201\text{ MB}$ per sequence. With batch size 32: $6.4\text{ GB}$ just for attention matrices! This is why long-context models require substantial GPU memory.
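The memory figures above follow directly from $B \times h \times n^2$ elements; a small sketch (function name is ours):

```python
def attn_matrix_bytes(n: int, h: int, batch: int = 1, bytes_per_el: int = 4) -> int:
    """Memory for the h attention matrices (n x n each) per sequence,
    scaled by batch size, in FP32 by default."""
    return batch * h * n * n * bytes_per_el

print(attn_matrix_bytes(512, 12))             # 12582912    ≈ 12 MB
print(attn_matrix_bytes(2048, 12))            # 201326592   ≈ 201 MB
print(attn_matrix_bytes(2048, 12, batch=32))  # 6442450944  ≈ 6.4 GB
```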

Softmax and Scaling: Applying softmax to each row of $\mS$ requires $O(n^2)$ operations (exponentials and normalization), which is negligible compared to the matrix multiplications but can become significant for very long sequences due to memory access patterns.

Attention Output: Computing $\mO = \mA \mV$ multiplies the attention weights $\mA \in \R^{n \times n}$ by values $\mV \in \R^{n \times d_v}$:

$$ \text{Output FLOPs} = 2n^2d_v $$

Again, this scales quadratically with sequence length. For $d_v = d_k$:

$$ \text{Attention output FLOPs} = 2n^2d_k $$

Output Projection: Finally, concatenated head outputs are projected back to model dimension:

$$ \text{Output projection FLOPs} = 2n(hd_k)d_{\text{model}} = 2nd_{\text{model}}^2 $$
(assuming $hd_k = d_{\text{model}}$).

Total Self-Attention FLOPs:

$$ \text{Total} = 6nd_{\text{model}}^2 + 2n^2d_k h + 2n^2d_v h + 2nd_{\text{model}}^2 = 8nd_{\text{model}}^2 + 4n^2d_{\text{model}} $$

For typical configurations where $d_k = d_v = d_{\text{model}}/h$:

$$ \boxed{\text{Self-Attention FLOPs} = 8nd_{\text{model}}^2 + 4n^2d_{\text{model}}} $$
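The boxed formula is easy to evaluate for concrete configurations (a sketch; the function name is ours):

```python
def self_attention_flops(n: int, d_model: int) -> int:
    """8*n*d^2 (QKV + output projections) plus 4*n^2*d (scores and
    weighted sum), assuming h * d_k = d_model."""
    return 8 * n * d_model**2 + 4 * n**2 * d_model

# BERT-base layer: n = 512, d_model = 768
print(self_attention_flops(512, 768))  # 3221225472 ≈ 3.2 GFLOPs
```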

Complexity regime analysis: The relative importance of the two terms depends on the ratio $n/d_{\text{model}}$:

\begin{tikzpicture}[ node/.style={rectangle, draw, align=center, minimum width=2.5cm, minimum height=1cm, font=\small}, arrow/.style={->, thick} ]

\node[node, fill=blue!10] (X) at (0,0) {$\mX$ \\ $n \times d$}; \node[node, fill=green!10] (QKV) at (3,0) {QKV Proj \\ $O(nd^2)$}; \node[node, fill=red!10] (Attn) at (6,0) {$\mQ\mK^\top$ \\ $O(n^2d)$}; \node[node, fill=green!10] (AV) at (9,0) {$\mA\mV$ \\ $O(n^2d)$}; \node[node, fill=blue!10] (Out) at (12,0) {Output \\ $n \times d$};

\draw[arrow] (X) -- (QKV); \draw[arrow] (QKV) -- (Attn); \draw[arrow] (Attn) -- (AV); \draw[arrow] (AV) -- (Out);

\node[node, fill=blue!10] (X2) at (0,-3) {$\mX$ \\ $n \times d$}; \node[node, fill=green!10] (W1) at (4,-3) {$\mW_1$ \\ $O(nd_{ff}d)$}; \node[node, fill=orange!10] (Act) at (8,-3) {GELU \\ $O(nd_{ff})$}; \node[node, fill=green!10] (W2) at (12,-3) {$\mW_2$ \\ $O(nd_{ff}d)$}; \node[node, fill=blue!10] (Out2) at (16,-3) {Output \\ $n \times d$};

\draw[arrow] (X2) -- (W1); \draw[arrow] (W1) -- (Act); \draw[arrow] (Act) -- (W2); \draw[arrow] (W2) -- (Out2);

\node[align=center, font=\small] at (8,-5) {Short sequences ($n < 2d$): FFN dominates \\ Long sequences ($n > 2d$): Attention dominates};

\end{tikzpicture}

Computational flow comparison between self-attention and feed-forward network. Green boxes show matrix multiplications (compute-intensive), red shows the quadratic attention bottleneck, orange shows element-wise operations. For typical sequence lengths, FFN requires roughly 2× the FLOPs of attention.

This analysis explains why efficient attention mechanisms (Chapter 16) focus on reducing the $O(n^2)$ term for long-context applications.

Feed-Forward Network Complexity

The position-wise feed-forward network (FFN) in each transformer layer typically expands the representation to a higher dimension before projecting back. This two-layer network with GELU or ReLU activation is applied independently to each position in the sequence.

Architecture: For input $\mX \in \R^{n \times d_{\text{model}}}$:

$$\begin{align} \mH &= \text{GELU}(\mX \mW_1 + \vb_1) \quad &\mW_1 \in \R^{d_{\text{model}} \times d_{ff}}, \quad \mH \in \R^{n \times d_{ff}} \\ \mY &= \mH \mW_2 + \vb_2 \quad &\mW_2 \in \R^{d_{ff} \times d_{\text{model}}}, \quad \mY \in \R^{n \times d_{\text{model}}} \end{align}$$

The intermediate dimension $d_{ff}$ is typically $4d_{\text{model}}$ in standard transformers (BERT, GPT), though some models use different ratios. This expansion allows the network to learn complex non-linear transformations.

First Projection FLOPs: Computing $\mX \mW_1$ requires:

$$ \text{First projection} = 2n \cdot d_{\text{model}} \cdot d_{ff} $$

For $d_{ff} = 4d_{\text{model}}$:

$$ \text{First projection} = 2n \cdot d_{\text{model}} \cdot 4d_{\text{model}} = 8nd_{\text{model}}^2 $$

Second Projection FLOPs: Computing $\mH \mW_2$ requires:

$$ \text{Second projection} = 2n \cdot d_{ff} \cdot d_{\text{model}} = 8nd_{\text{model}}^2 $$

Total FFN FLOPs:

$$ \boxed{\text{FFN FLOPs} = 16nd_{\text{model}}^2 \quad \text{(for } d_{ff} = 4d_{\text{model}}\text{)}} $$

Activation function: GELU requires additional operations (exponentials, multiplications) but these are $O(nd_{ff})$, negligible compared to the matrix multiplications.

Why FFN dominates computation: Comparing FFN to self-attention:

$$\begin{align} \text{FFN:} \quad &16nd_{\text{model}}^2 \\ \text{Attention:} \quad &8nd_{\text{model}}^2 + 4n^2d_{\text{model}} \end{align}$$

For typical sequence lengths where $n < 2d_{\text{model}}$, the FFN requires roughly twice the FLOPs of attention! This is why some efficient transformer variants (e.g., mixture-of-experts) focus on making the FFN more efficient.
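The crossover between the two regimes can be computed directly: setting $16nd^2 = 8nd^2 + 4n^2d$ gives $n = 2d$. A quick numerical check:

```python
def ffn_flops(n, d):
    return 16 * n * d**2                    # assumes d_ff = 4*d

def attn_flops(n, d):
    return 8 * n * d**2 + 4 * n**2 * d

d = 768
for n in (512, 1024, 1536, 2048):
    # ratio > 1 while n < 2d = 1536; the two are equal exactly at n = 2d
    print(n, ffn_flops(n, d) / attn_flops(n, d))
```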

Memory and bandwidth considerations: The FFN intermediate activations $\mH \in \R^{n \times d_{ff}}$ must be stored for backpropagation. For BERT-base with $n=512$, $d_{ff}=3072$:

$$ 512 \times 3072 \times 4\text{ bytes} = 6{,}291{,}456\text{ bytes} \approx 6\text{ MB per layer} $$

With 12 layers and batch size 32: $6 \times 12 \times 32 = 2.3\text{ GB}$ just for FFN intermediate activations. This is a significant portion of training memory.

Hardware utilization: FFN matrix multiplications are highly regular and achieve excellent GPU utilization (often 70-90\% of peak FLOPS on modern GPUs).

On an NVIDIA A100 GPU (312 TFLOPS FP16), computing the FFN for BERT-base with batch size 32 and $n=512$:

$$ \text{FLOPs} = 32 \times 16 \times 512 \times 768^2 \approx 155\text{ GFLOPs} $$
$$ \text{Time} \approx \frac{155 \times 10^9\text{ FLOPs}}{312 \times 10^{12}\text{ FLOPS} \times 0.8} \approx 0.62\text{ ms} $$
(assuming 80\% utilization).

Per-Layer Total Complexity

Combining self-attention and FFN, a complete transformer layer requires:

$$ \boxed{\text{Transformer layer} = (8nd_{\text{model}}^2 + 4n^2d_{\text{model}}) + 16nd_{\text{model}}^2 = 24nd_{\text{model}}^2 + 4n^2d_{\text{model}} \text{ FLOPs}} $$

Additional operations: Layer normalization, residual connections, and dropout add $O(nd_{\text{model}})$ operations, which are negligible compared to the matrix multiplications.

Breakdown by component: of the $24nd_{\text{model}}^2 + 4n^2d_{\text{model}}$ per-layer FLOPs, the FFN contributes $16nd_{\text{model}}^2$, the attention projections $8nd_{\text{model}}^2$, and the quadratic score/output computations $4n^2d_{\text{model}}$.

This breakdown is crucial for optimization: for short sequences, optimizing FFN yields the largest gains; for long sequences, efficient attention mechanisms become critical.

Example: BERT-base ($n = 512$, $d_{\text{model}} = 768$, $h = 12$, $d_{ff} = 3072$) illustrates the component-level FLOPs breakdown. Self-attention totals 3.21~GFLOPs per layer (QKV projections: 1.81~G, attention scores: 0.40~G each for $\mQ\mK\transpose$ and $\mA\mV$, output projection: 0.60~G). The feed-forward network totals 4.84~GFLOPs (two projections of 2.42~G each). The complete layer costs 8.05~GFLOPs, giving 96.6~GFLOPs for the 12-layer forward pass and $\approx$290~GFLOPs for a full training step.
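These totals can be reproduced numerically (a sketch; the helper name is ours):

```python
def layer_flops(n: int, d: int) -> int:
    """Per-layer transformer FLOPs: 24*n*d^2 + 4*n^2*d (d_ff = 4*d)."""
    return 24 * n * d**2 + 4 * n**2 * d

n, d, layers = 512, 768, 12          # BERT-base
fwd = layers * layer_flops(n, d)
print(layer_flops(n, d) / 1e9)       # ≈ 8.05 GFLOPs per layer
print(fwd / 1e9)                     # ≈ 96.6 GFLOPs forward pass
print(3 * fwd / 1e9)                 # ≈ 290 GFLOPs per training step (fwd + ~2x bwd)
```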

On an NVIDIA A100 (312~TFLOPS FP16, 70\% utilization), this yields a batch-of-32 forward pass in $\approx$14~ms and throughput of $\approx$390,000~tokens/sec. At 1.6~TB/s memory bandwidth, loading 440~MB of parameters takes 0.28~ms---comparable to compute time, making small-batch inference memory-bandwidth bound.

See Section~[ref] for the step-by-step derivation.

Complexity Analysis

Theorem: For $L$ layers, sequence length $n$, dimension $d$:

Time complexity: $O(Ln^2d + Lnd^2)$

Space complexity: $O(Ln^2 + Lnd)$ for activations (parameters add $O(Ld^2)$)

Comparison with RNN: An $L$-layer RNN requires $O(Lnd^2)$ time—linear in $n$—but its $n$ recurrence steps are inherently sequential and cannot be parallelized across positions. The transformer adds the $O(Ln^2d)$ attention term yet processes all positions in parallel, so wall-clock training time is typically far lower despite the higher FLOP count on long sequences.

Bottleneck regimes: For $n \ll d$, the $Lnd^2$ matrix-multiplication terms (FFN and projections) dominate; for $n \gg d$, the quadratic $Ln^2d$ attention term dominates, and the $O(Ln^2)$ memory for attention matrices becomes the binding constraint.

Memory Requirements

Memory is often the limiting factor in training and deploying large transformer models. Understanding memory requirements at a granular level enables informed decisions about model architecture, batch sizes, and hardware selection. We analyze memory consumption across four categories: model parameters, gradients, optimizer states, and activations.

Model Parameters

Model parameters must be stored in GPU memory during both training and inference. The memory footprint depends on the numerical precision used.

Precision options:

For BERT-base with 110 million parameters:

$$\begin{align} \text{FP32:} \quad &110{,}000{,}000 \times 4 = 440{,}000{,}000\text{ bytes} = 440\text{ MB} \\ \text{FP16/BF16:} \quad &110{,}000{,}000 \times 2 = 220{,}000{,}000\text{ bytes} = 220\text{ MB} \\ \text{INT8:} \quad &110{,}000{,}000 \times 1 = 110{,}000{,}000\text{ bytes} = 110\text{ MB} \end{align}$$

Parameter breakdown for BERT-base:

$$\begin{align} \text{Token embeddings:} \quad &V \times d_{\text{model}} = 30{,}000 \times 768 = 23{,}040{,}000 \text{ params} \\ \text{Position embeddings:} \quad &n_{\max} \times d_{\text{model}} = 512 \times 768 = 393{,}216 \text{ params} \\ \text{Segment embeddings:} \quad &2 \times d_{\text{model}} = 2 \times 768 = 1{,}536 \text{ params} \end{align}$$

Per transformer layer:

$$\begin{align} \text{Self-attention:} \quad &4 \times d_{\text{model}}^2 = 4 \times 768^2 = 2{,}359{,}296 \text{ params} \\ \text{FFN:} \quad &2 \times d_{\text{model}} \times d_{ff} = 2 \times 768 \times 3072 = 4{,}718{,}592 \text{ params} \\ \text{Layer norms:} \quad &4 \times d_{\text{model}} = 4 \times 768 = 3{,}072 \text{ params} \\ \text{Total per layer:} \quad &7{,}080{,}960 \text{ params} \end{align}$$

12 layers:

$$ 12 \times 7{,}080{,}960 = 84{,}971{,}520 \text{ params} $$

Total BERT-base:

$$ 23{,}040{,}000 + 393{,}216 + 1{,}536 + 84{,}971{,}520 = 108{,}406{,}272 \approx 110\text{M params} $$

In FP32: $110\text{M} \times 4 = 440\text{ MB}$
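The parameter arithmetic above is easy to verify in code (a sketch; the function name is ours, and biases are excluded as in the derivation):

```python
def bert_base_params(vocab=30_000, n_max=512, d=768, d_ff=3072, layers=12):
    """Parameter count following the breakdown above (no bias terms)."""
    embeddings = vocab * d + n_max * d + 2 * d    # token + position + segment
    per_layer = 4 * d * d + 2 * d * d_ff + 4 * d  # attention + FFN + layer norms
    return embeddings + layers * per_layer

print(bert_base_params())  # 108406272 ≈ 110M
```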

Larger models scale dramatically:

$$\begin{align} \text{GPT-2 (1.5B):} \quad &1{,}500{,}000{,}000 \times 4 = 6{,}000\text{ MB} = 6\text{ GB (FP32)} \\ \text{GPT-3 (175B):} \quad &175{,}000{,}000{,}000 \times 4 = 700{,}000\text{ MB} = 700\text{ GB (FP32)} \end{align}$$

GPT-3 in FP32 requires 700 GB just for parameters—far exceeding single GPU memory (A100 has 80 GB). This necessitates model parallelism (splitting layers across GPUs), tensor parallelism (splitting individual weight matrices), and reduced-precision storage.

Activation Memory

During training, intermediate activations must be stored for backpropagation. Activation memory scales with batch size and sequence length, often dominating memory consumption.

Activations per transformer layer include the layer input, the Q/K/V projections, the attention output and its projection, and the residual/normalization buffers (together roughly $8d_{\text{model}}$ values per token), plus the FFN intermediate ($d_{ff}$ values per token) and the $h$ attention matrices of $n^2$ entries each.

Total activation memory per layer (approximate):

$$ \text{Memory} \approx B \times n \times (8d_{\text{model}} + d_{ff}) + B \times h \times n^2 $$

For BERT-base ($B=32$, $n=512$, $d_{\text{model}}=768$, $h=12$, $d_{ff}=3072$):

$$\begin{align} \text{Linear terms:} \quad &32 \times 512 \times (8 \times 768 + 3072) \times 4\text{ bytes} \\ &= 32 \times 512 \times 9{,}216 \times 4 = 603{,}979{,}776\text{ bytes} \approx 604\text{ MB} \\ \text{Attention matrices:} \quad &32 \times 12 \times 512^2 \times 4 = 402{,}653{,}184\text{ bytes} \approx 403\text{ MB} \\ \text{Total per layer:} \quad &\approx 1{,}007\text{ MB} \approx 1\text{ GB} \end{align}$$

For 12 layers: $12 \times 1\text{ GB} = 12\text{ GB}$ just for activations!
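The per-layer formula can be packaged as a small estimator (a sketch; the function name is ours):

```python
def layer_activation_bytes(B, n, d, h, d_ff, bytes_per_el=4):
    """Approximate activation memory per transformer layer:
    linear-size activations plus the h attention matrices."""
    linear = B * n * (8 * d + d_ff) * bytes_per_el
    attn = B * h * n * n * bytes_per_el
    return linear + attn

per_layer = layer_activation_bytes(32, 512, 768, 12, 3072)
print(per_layer)            # 1006632960 ≈ 1 GB
print(12 * per_layer / 1e9) # ≈ 12 GB for 12 layers
```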

Impact of sequence length: The attention matrix term $B \times h \times n^2$ grows quadratically. For $n=2048$ (4× longer):

$$ 32 \times 12 \times 2048^2 \times 4 = 6{,}442{,}450{,}944\text{ bytes} \approx 6.4\text{ GB per layer} $$

For 12 layers: $77\text{ GB}$ just for attention matrices—nearly filling an A100 GPU!

This quadratic scaling is why long-context models rely on memory-efficient attention implementations (e.g., FlashAttention, which never materializes the full $n \times n$ matrix), gradient checkpointing, and smaller batch sizes.

Gradient checkpointing trade-off: Recomputing activations during the backward pass reduces activation memory by roughly 80\% at the cost of approximately one extra forward pass ($\sim$25--33\% more compute).

Example: GPT-2 (small): $L=12$, $d_{\text{model}}=768$, $h=12$, $d_k=64$, $d_{ff}=3072$, $n=1024$

Per-layer activation breakdown (batch size $B=1$):

QKV projections:

$$ 3 \times 1024 \times 768 \times 4 = 9{,}437{,}184\text{ bytes} \approx 9.4\text{ MB} $$

Attention matrices (12 heads):

$$ 12 \times 1024^2 \times 4 = 50{,}331{,}648\text{ bytes} \approx 50.3\text{ MB} $$

This is the dominant term! For $n=2048$: $12 \times 2048^2 \times 4 = 201\text{ MB}$ (4× larger).

Attention output:

$$ 1024 \times 768 \times 4 = 3{,}145{,}728\text{ bytes} \approx 3.1\text{ MB} $$

FFN intermediate:

$$ 1024 \times 3072 \times 4 = 12{,}582{,}912\text{ bytes} \approx 12.6\text{ MB} $$

Layer norm and residuals:

$$ 3 \times 1024 \times 768 \times 4 = 9{,}437{,}184\text{ bytes} \approx 9.4\text{ MB} $$

Total per layer:

$$ 9.4 + 50.3 + 3.1 + 12.6 + 9.4 = 84.8\text{ MB} $$

12 layers: $12 \times 84.8 = 1{,}018\text{ MB} \approx 1\text{ GB}$ for single sequence

Batch size scaling:

$$\begin{align} B=8: \quad &8\text{ GB} \\ B=16: \quad &16\text{ GB} \\ B=32: \quad &32\text{ GB} \\ B=64: \quad &64\text{ GB} \end{align}$$

Hardware implications: at $\approx$1 GB of activations per sequence, a 16 GB V100 (after parameters, gradients, and optimizer states) supports only batch sizes around 8--12 without memory-saving techniques, while batch size 64 requires an 80 GB A100.

Gradient checkpointing impact: With checkpointing, only store activations at layer boundaries, recompute within layers during backward pass:

$$ \text{Memory reduction} \approx 80\% \Rightarrow 1\text{ GB} \to 200\text{ MB per sequence} $$

This allows batch size 64 on V100 (16 GB), but increases training time by $\sim$25\%.

Mixed precision training: Using FP16 for activations (FP32 for parameters):

$$ \text{Activation memory} \to 1\text{ GB} / 2 = 500\text{ MB per sequence} $$

Combined with gradient checkpointing: $500 \times 0.2 = 100\text{ MB per sequence}$, enabling very large batch sizes.
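The two savings compose multiplicatively; a minimal sketch of the accounting (the function name is ours, using the $\sim$80\% checkpointing reduction and 2$\times$ FP16 reduction from above):

```python
def activation_memory_mb(base_mb, checkpointing=False, fp16=False):
    """Apply the savings discussed above: FP16 halves bytes per element;
    checkpointing keeps only ~20% of activations."""
    mb = base_mb
    if fp16:
        mb /= 2
    if checkpointing:
        mb *= 0.2
    return mb

print(activation_memory_mb(1000))                                  # 1000 MB baseline
print(activation_memory_mb(1000, fp16=True))                       # 500 MB
print(activation_memory_mb(1000, checkpointing=True, fp16=True))   # 100 MB
```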

Training Memory Budget

Training requires memory for parameters, gradients, optimizer states, and activations. Understanding this breakdown is essential for selecting hardware and configuring training.

Total training memory:

$$ \text{Memory}_{\text{total}} = \text{Parameters} + \text{Gradients} + \text{Optimizer States} + \text{Activations} $$

For AdamW optimizer (most common for transformers): parameters ($4P$), gradients ($4P$), and the two Adam moment estimates ($8P$), all in FP32.

Total: $16P + A$ bytes

Mixed precision training (FP16/BF16 with FP32 master weights): FP16 parameters ($2P$), FP32 master weights ($4P$), FP32 gradients ($4P$), and Adam moments ($8P$).

Total: $18P + A/2$ bytes

Surprisingly, mixed precision uses slightly MORE memory for parameters/optimizer (18P vs 16P) but saves significantly on activations ($A/2$ vs $A$). Since activations often dominate, mixed precision typically reduces total memory.

Example: BERT-base: 110M parameters, batch size 32, sequence length 512

FP32 training:

$$\begin{align} \text{Parameters:} \quad &110\text{M} \times 4 = 440\text{ MB} \\ \text{Gradients:} \quad &110\text{M} \times 4 = 440\text{ MB} \\ \text{Adam states (2×):} \quad &2 \times 110\text{M} \times 4 = 880\text{ MB} \\ \text{Activations:} \quad &12 \times 1\text{ GB} = 12\text{ GB} \quad (\approx 1\text{ GB per layer at } B=32) \\ \text{Total:} \quad &440 + 440 + 880 + 12{,}000 = 13{,}760\text{ MB} \approx 13.8\text{ GB} \end{align}$$

Fits on: NVIDIA V100 (16 GB), A100 (40/80 GB), RTX 3090 (24 GB)

Mixed precision training:

$$\begin{align} \text{FP16 parameters:} \quad &110\text{M} \times 2 = 220\text{ MB} \\ \text{FP32 master + gradients + Adam:} \quad &110\text{M} \times 16 = 1{,}760\text{ MB} \\ \text{FP16 activations:} \quad &12{,}000 / 2 = 6{,}000\text{ MB} \\ \text{Total:} \quad &220 + 1{,}760 + 6{,}000 = 7{,}980\text{ MB} \approx 8\text{ GB} \end{align}$$

Mixed precision saves: $13.8 - 8 = 5.8\text{ GB}$ (42\% reduction)

With gradient checkpointing: Activations reduced by 80\%:

$$ 220 + 1{,}760 + 1{,}200 = 3{,}180\text{ MB} \approx 3.2\text{ GB} $$

This enables batch size 128 on V100 (16 GB)!
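The budget arithmetic above can be wrapped into a small calculator (a sketch; the function name is ours):

```python
def training_memory_mb(params_m, activations_mb, mixed_precision=False,
                       checkpointing=False):
    """Bytes per parameter: 16 for FP32 (params + grads + Adam) or 18 for
    mixed precision (FP16 copy + FP32 master/grads/Adam). Activations
    halve under FP16 and drop ~80% under checkpointing."""
    state_mb = params_m * (18 if mixed_precision else 16)
    act_mb = activations_mb * (0.5 if mixed_precision else 1.0)
    if checkpointing:
        act_mb *= 0.2
    return state_mb + act_mb

print(training_memory_mb(110, 12_000))                        # 13760 MB ≈ 13.8 GB
print(training_memory_mb(110, 12_000, mixed_precision=True))  # 7980 MB ≈ 8 GB
```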

Example: GPT-3: 175B parameters, sequence length 2048

Parameters and optimizer (FP32):

$$ 175\text{B} \times 16 = 2{,}800\text{ GB} $$

Activations (batch size 1, single sequence, FP16 at 2 bytes/element):

$$\begin{align} \text{Per layer:} \quad &\approx \left[2048 \times (8 \times 12{,}288 + 4 \times 12{,}288) + 96 \times 2048^2\right] \times 2\text{ bytes} \\ &\approx (2048 \times 147{,}456 + 402{,}653{,}184) \times 2 \\ &\approx 1.4\text{ GB per layer} \end{align}$$

96 layers: $96 \times 1.4\text{ GB} \approx 135\text{ GB per sequence}$

Total for batch size 1: $2{,}800 + 135 \approx 2{,}935\text{ GB}$

Hardware requirements: even ignoring activations, the 2,800 GB of parameters and optimizer states would require at least 35 A100 (80 GB) GPUs just for storage, before any working memory.

Actual GPT-3 training: Used ZeRO optimizer (shards optimizer states across GPUs) + model parallelism + pipeline parallelism across thousands of GPUs.

This example illustrates why training models beyond $\sim$10B parameters requires sophisticated distributed training strategies.

Hardware Selection Guide

GPU memory requirements by model size (mixed precision + gradient checkpointing):

\begin{tabular}{llll}
Model Size & Parameters & Min GPU Memory & Recommended GPU \\
Small & 100M & 8 GB & RTX 3070, V100 \\
Base & 300M & 12 GB & RTX 3080, V100 \\
Large & 1B & 24 GB & RTX 3090, A5000 \\
XL & 3B & 40 GB & A100 (40 GB) \\
XXL & 10B & 80 GB & A100 (80 GB) \\
GPT-3 & 175B & 8$\times$ A100 (80 GB) & Multi-node cluster \\
\end{tabular}

Inference memory requirements (FP16):

For GPT-2 (117M params) inference:

$$ \text{Parameters:} \quad 117\text{M} \times 2 = 234\text{ MB} $$
$$ \text{KV cache (batch 1, } n=1024\text{):} \quad 2 \times 12 \times 12 \times 1024 \times 64 \times 2 = 38\text{ MB} $$
$$ \text{Total:} \quad \approx 300\text{ MB} $$

GPT-2 inference easily fits on consumer GPUs or even CPUs!

Inference Optimization

Inference optimization is critical for deploying transformers in production. Unlike training, which prioritizes throughput (tokens/second across large batches), inference prioritizes latency (time to generate a single response) while maintaining reasonable throughput. We analyze key optimization techniques and their trade-offs.

KV Caching for Autoregressive Decoding

Autoregressive generation (used in GPT, decoder-only models) generates tokens sequentially, where each new token attends to all previous tokens. Naive implementation recomputes attention for all previous positions at each step—highly inefficient.

Problem analysis: Generating a sequence of length $T$, step $t$ naively reprocesses all $t$ tokens seen so far:

Total attention computations: $\sum_{t=1}^{T} t = \frac{T(T+1)}{2} \approx \frac{T^2}{2}$

For $T=1000$ tokens: $\approx 500{,}000$ attention computations!

KV Caching solution: Key and value projections depend only on input tokens, not on the query position. Cache $\mK$ and $\mV$ from previous steps:

Algorithm: Autoregressive Generation with KV Caching
Input: Prompt tokens $\vx_1, \ldots, \vx_p$, max length $T$
Output: Generated sequence $\vx_1, \ldots, \vx_T$

// Initialize cache
$\text{cache}_K = []$, $\text{cache}_V = []$

// Process prompt
for $t = 1$ to $p$ do
    $\vk_t = \mW^K \vx_t$, $\vv_t = \mW^V \vx_t$
    Append $\vk_t$ to $\text{cache}_K$, $\vv_t$ to $\text{cache}_V$
    $\vq_t = \mW^Q \vx_t$
    Compute attention using $\vq_t$ and all cached keys/values
    Generate $\vh_t$
end for

// Generate new tokens
for $t = p+1$ to $T$ do
    Sample $\vx_t$ from $\vh_{t-1}$
    $\vk_t = \mW^K \vx_t$, $\vv_t = \mW^V \vx_t$
    Append $\vk_t$ to $\text{cache}_K$, $\vv_t$ to $\text{cache}_V$
    $\vq_t = \mW^Q \vx_t$
    Compute attention: $\text{Attention}(\vq_t, \text{cache}_K, \text{cache}_V)$
    Generate $\vh_t$
end for

Computational savings: With caching, each generation step processes only the new token—one token-level forward computation—instead of reprocessing the entire prefix:

$$ \text{Total token computations} = T \quad \text{(vs. } \frac{T^2}{2} \text{ without caching)} $$

Speedup: For $T=1000$: $\frac{500{,}000}{1{,}000} = 500\times$ fewer token computations. (The attention dot products against cached keys still grow with context length, so the end-to-end wall-clock speedup is somewhat smaller.)

Memory cost: Store keys and values for all positions and layers:

$$ \text{KV cache size} = 2 \times L \times h \times T \times d_k \times \text{sizeof(float)} $$

For GPT-2 ($L=12$, $h=12$, $d_k=64$, FP16):

$$ 2 \times 12 \times 12 \times T \times 64 \times 2 = 36{,}864 \times T \text{ bytes} $$

Memory scaling with sequence length:

$$\begin{align} T=512: \quad &36{,}864 \times 512 = 18{,}874{,}368\text{ bytes} \approx 19\text{ MB} \\ T=1024: \quad &36{,}864 \times 1024 = 37{,}748{,}736\text{ bytes} \approx 38\text{ MB} \\ T=2048: \quad &36{,}864 \times 2048 = 75{,}497{,}472\text{ bytes} \approx 75\text{ MB} \\ T=4096: \quad &36{,}864 \times 4096 = 150{,}994{,}944\text{ bytes} \approx 151\text{ MB} \end{align}$$
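The cache-size formula translates directly into code (a sketch; the function name is ours):

```python
def kv_cache_bytes(L, h, T, d_k, bytes_per_el=2):
    """KV cache size: 2 (K and V) x layers x heads x positions x head_dim
    x bytes per element (FP16 by default)."""
    return 2 * L * h * T * d_k * bytes_per_el

# GPT-2: L=12, h=12, d_k=64 -> 36,864 bytes per cached position
for T in (512, 1024, 2048, 4096):
    print(T, kv_cache_bytes(12, 12, T, 64))
```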

For GPT-3, the cache reaches 9.7~GB per sequence (see Chapter~[ref] for the derivation), nearly filling an A100 for batch size 8 ($77.6$~GB). Practitioners manage this through smaller batch sizes for long contexts, dynamic batching, and INT8 quantization to reduce cache size by 2--4$\times$.

Batched Inference

Processing multiple sequences simultaneously increases GPU utilization and throughput.

Single sequence inference: For GPT-2 generating 100 tokens one at a time, each step must load all 234 MB of FP16 weights from GPU memory to produce a single token—the computation is memory-bandwidth bound and the arithmetic units sit mostly idle.

Batched inference (batch size 32): the same weight loads are amortized across 32 sequences, so throughput (tokens/second) improves by roughly an order of magnitude while per-batch latency grows only modestly.

Latency vs. throughput trade-off: larger batches increase throughput but delay individual responses; interactive applications cap batch size (or use deadlines), while offline processing uses the largest batch that fits in memory.

Padding challenge: Sequences in a batch must have the same length. Shorter sequences are padded, wasting computation.

Solutions: bucketing sequences of similar lengths, dynamic (continuous) batching that admits new requests as others finish, and packing multiple short sequences into one row with appropriate attention masks.

Quantization for Inference

Quantization reduces memory and increases throughput by using lower-precision arithmetic.

Precision options: FP32 (baseline), FP16/BF16 (2$\times$ smaller), INT8 (4$\times$ smaller), INT4 (8$\times$ smaller).

INT8 quantization: Map FP32 weights $w \in [-w_{\max}, w_{\max}]$ to INT8 $w_q \in [-128, 127]$:

$$ w_q = \text{round}\left(\frac{w}{w_{\max}} \times 127\right) $$

Dequantize during computation:

$$ w \approx \frac{w_q \times w_{\max}}{127} $$
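A minimal sketch of this symmetric quantize/dequantize round trip (function names are ours):

```python
def quantize_int8(w, w_max):
    """Symmetric INT8 quantization: map [-w_max, w_max] to integers."""
    q = round(w / w_max * 127)
    return max(-128, min(127, q))

def dequantize_int8(q, w_max):
    return q * w_max / 127

w_max = 0.8
for w in (-0.8, -0.31, 0.0, 0.25, 0.8):
    q = quantize_int8(w, w_max)
    # reconstruction error is bounded by half a quantization step
    assert abs(w - dequantize_int8(q, w_max)) <= w_max / 127 / 2 + 1e-12
```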

Quantization impact on GPT-2:

$$\begin{align} \text{FP32:} \quad &117\text{M} \times 4 = 468\text{ MB} \\ \text{FP16:} \quad &117\text{M} \times 2 = 234\text{ MB} \quad (2\times \text{ reduction}) \\ \text{INT8:} \quad &117\text{M} \times 1 = 117\text{ MB} \quad (4\times \text{ reduction}) \\ \text{INT4:} \quad &117\text{M} \times 0.5 = 58.5\text{ MB} \quad (8\times \text{ reduction}) \end{align}$$

Accuracy trade-offs: FP16/BF16 is typically lossless for inference; INT8 with proper calibration costs roughly 0.5--2\% accuracy; INT4 costs 2--5\% and usually requires more sophisticated schemes (e.g., weight-only quantization with per-group scales).

Hardware support: NVIDIA Tensor Cores accelerate FP16/BF16 and INT8 natively (INT8 at up to 2$\times$ the FP16 rate on A100); INT4 support is more limited and often implemented as weight-only quantization with FP16 compute.

Model Distillation

Train a smaller "student" model to mimic a larger "teacher" model: the student is trained on the teacher's softened output distribution (temperature-scaled softmax) in addition to the ground-truth labels. DistilBERT, for example, has 40\% fewer parameters and runs 60\% faster than BERT-base while retaining about 97\% of its performance.

Distillation enables deployment on resource-constrained devices (mobile, edge).
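The core distillation objective can be sketched in a few lines of plain Python (function names are ours; real implementations use framework tensor ops):

```python
import math

def softmax(logits, T=1.0):
    """Temperature-softened softmax over a list of logits."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_kl(teacher_logits, student_logits, T=2.0):
    """KL(teacher || student) on temperature-softened distributions,
    scaled by T^2 as in Hinton et al. (2015)."""
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    return T * T * sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

print(distillation_kl([2.0, 0.5, -1.0], [2.0, 0.5, -1.0]))      # 0.0 (perfect mimic)
print(distillation_kl([2.0, 0.5, -1.0], [0.1, 0.2, 0.3]) > 0)   # True
```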

Inference Optimization Summary

\begin{tabular}{llll}
Technique & Speedup & Memory Reduction & Accuracy Impact \\
KV Caching & 100--500$\times$ & $-$50\% (cache overhead) & None \\
Batching (32$\times$) & 2--3$\times$ throughput & None & None \\
FP16/BF16 & 1.5--2$\times$ & 2$\times$ & Negligible \\
INT8 Quantization & 2--4$\times$ & 4$\times$ & Small (0.5--2\%) \\
INT4 Quantization & 4--8$\times$ & 8$\times$ & Moderate (2--5\%) \\
Distillation & 2--7$\times$ & 2--8$\times$ & Small (3--4\%) \\
\end{tabular}

Combined optimizations: KV caching + FP16 + batching + INT8 compound to roughly 1000$\times$ higher throughput than naive FP32 single-sequence decoding, with minimal accuracy loss.

Scaling Laws

Performance scales as power laws with model size $N$, dataset size $D$, and compute budget $C$. The Kaplan scaling laws established that larger models are more sample-efficient, while the Chinchilla scaling laws (Hoffmann et al., 2022) refined the optimal allocation: $N_{\text{opt}} \propto C^{0.5}$ and $D_{\text{opt}} \propto C^{0.5}$, implying that many large models are over-parameterized and under-trained. See Section~[ref] for the full treatment including GPT-3 vs Chinchilla analysis and practical implications for training strategy.

Exercises

Exercise 1: Calculate FLOPs for GPT-3 (175B parameters, $L=96$, $d=12288$, $h=96$, $n=2048$) for: (1) Single forward pass, (2) Generating 100 tokens autoregressively, (3) Training on 1 trillion tokens.
Exercise 2: Estimate memory for training 1.3B parameter model with batch size 64, sequence length 2048. What GPU memory required? How to fit on A100 (80GB)?
Exercise 3: Implement KV caching for GPT-2. Measure speedup for generating 256 tokens. Plot generation time vs sequence length with/without caching.
Exercise 4: For fixed compute budget $C = 10^{24}$ FLOPs: Use Chinchilla scaling to find optimal model size and data size. Compare with GPT-3 allocation.

Solutions

Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.

Solution: Exercise 1: GPT-3 FLOPs Calculation

Given: GPT-3 with $P = 175B$ parameters, $L = 96$ layers, $d_{\text{model}} = 12{,}288$, $h = 96$ heads, $n = 2048$ sequence length

Part (1): Single Forward Pass

For a transformer, FLOPs per forward pass: $$\text{FLOPs}_{\text{fwd}} = 2 \times B \times n \times P$$

where $B$ is batch size. For $B = 1$:

$$\begin{align*} \text{FLOPs}_{\text{fwd}} &= 2 \times 1 \times 2048 \times 175 \times 10^9 \\ &= 716{,}800 \times 10^9 \\ &= 7.168 \times 10^{14} \text{ FLOPs} \\ &= 716.8 \text{ TFLOPs} \end{align*}$$

Breakdown by component: per layer, the FFN contributes $16nd^2$ FLOPs ($\approx$65\%), the attention projections $8nd^2$ ($\approx$32\%), and the quadratic score/output terms $4n^2d$ ($\approx$3\%).

Note: the quadratic attention terms are a small fraction of total computation because $n = 2048 \ll d = 12{,}288$.

Part (2): Generating 100 Tokens Autoregressively

For autoregressive generation, each token requires a forward pass through the decoder.

With KV caching, computation per token $t$: $$\text{FLOPs}_t \approx 2P + 4 \times L \times d \times n_{\text{ctx}}$$

where $n_{\text{ctx}}$ is the context length (grows with each token).

Without KV caching (recomputing everything): $$\text{FLOPs}_{\text{total}} = \sum_{t=1}^{100} 2 \times (n_0 + t) \times P$$

where $n_0$ is initial prompt length. Assuming $n_0 = 50$:

$$\begin{align*} \text{FLOPs}_{\text{total}} &= 2P \sum_{t=1}^{100} (50 + t) \\ &= 2P \times (50 \times 100 + \frac{100 \times 101}{2}) \\ &= 2P \times (5000 + 5050) \\ &= 2 \times 175 \times 10^9 \times 10{,}050 \\ &= 3.52 \times 10^{15} \text{ FLOPs} \\ &= 3.52 \text{ PFLOPs} \end{align*}$$

With KV caching (optimal): $$\text{FLOPs}_{\text{cached}} \approx 100 \times 2P = 100 \times 2 \times 175 \times 10^9 = 3.5 \times 10^{13} = 35 \text{ TFLOPs}$$

Speedup from KV caching: $\frac{3.52 \times 10^{15}}{3.5 \times 10^{13}} \approx 100\times$
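The two generation-cost estimates can be reproduced directly (a sketch; function names are ours):

```python
P = 175e9  # GPT-3 parameters

def gen_flops_no_cache(n0, new_tokens):
    """Each step re-runs the forward pass over the entire prefix."""
    return sum(2 * (n0 + t) * P for t in range(1, new_tokens + 1))

def gen_flops_cached(new_tokens):
    """Each step processes only the new token (ignoring the small
    attention-over-cache term)."""
    return new_tokens * 2 * P

no_cache = gen_flops_no_cache(50, 100)
cached = gen_flops_cached(100)
print(no_cache / 1e15)   # 3.5175 ≈ 3.52 PFLOPs
print(cached / 1e12)     # 35.0 TFLOPs
print(no_cache / cached) # 100.5, i.e. ≈ 100x speedup
```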

Part (3): Training on 1 Trillion Tokens

Training FLOPs formula: $$\text{FLOPs}_{\text{train}} = 6 \times P \times D$$

where $D$ is the number of training tokens.

For $D = 1$ trillion = $10^{12}$ tokens:

$$\begin{align*} \text{FLOPs}_{\text{train}} &= 6 \times 175 \times 10^9 \times 10^{12} \\ &= 1.05 \times 10^{24} \text{ FLOPs} \\ &= 1{,}050 \text{ ZFLOPs (zettaFLOPs)} \end{align*}$$

Training time estimation:

On 1024 A100 GPUs (312 TFLOPS each), assuming $\sim$50\% sustained utilization:

$$ \text{Time} = \frac{1.05 \times 10^{24}}{1024 \times 312 \times 10^{12} \times 0.5} \approx 6.6 \times 10^{6}\text{ s} \approx 76\text{ days} $$

Cost estimation:

At \$2.50/GPU-hour (cloud pricing): $$\text{Cost} = 1024 \times 76 \times 24 \times 2.50 = \$4{,}669{,}440 \approx \$4.7M$$
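The time and cost estimates can be checked with a small calculator (a sketch; function names and the 50\% utilization figure are assumptions consistent with the derivation above):

```python
def training_days(total_flops, n_gpus, peak_flops, utilization=0.5):
    """Wall-clock days given sustained utilization of peak throughput."""
    return total_flops / (n_gpus * peak_flops * utilization) / 86_400

def training_cost(n_gpus, days, usd_per_gpu_hour=2.50):
    """Cloud cost: GPU-hours times hourly rate."""
    return n_gpus * days * 24 * usd_per_gpu_hour

flops = 6 * 175e9 * 1e12                 # 6*P*D = 1.05e24 FLOPs
days = training_days(flops, 1024, 312e12)
print(days)                              # ≈ 76 days
print(training_cost(1024, days))         # ≈ $4.7M
```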

Summary: a single forward pass costs $\approx$717 TFLOPs; generating 100 tokens costs $\approx$35 TFLOPs with KV caching (vs. $\approx$3.52 PFLOPs without); training on 1 trillion tokens costs $\approx 1.05 \times 10^{24}$ FLOPs—roughly 76 days and \$4.7M on 1024 A100s.

Solution: Exercise 2: Memory Estimation for 1.3B Parameter Model

Given: $P = 1.3$B parameters, batch size $B = 64$, sequence length $n = 2048$

Model Parameters:

Parameters (FP32): $1.3 \times 10^9 \times 4 = 5.2$GB

Optimizer States (AdamW): gradients ($4P = 5.2$GB) plus first and second moments ($8P = 10.4$GB), totalling $15.6$GB

Activations:

Assuming model architecture: $d_{\text{model}} = 2048$, $L_{\text{layers}} = 24$, $d_{ff} = 8192$

Per layer activations (applying the earlier per-layer formula with $B = 64$, $n = 2048$): $\approx 41.9$GB

For 24 layers: $24 \times 41.9 = 1{,}005.6$GB

With gradient checkpointing (store every 4 layers): $$\text{Activations} = \frac{24}{4} \times 41.9 = 251.4\text{GB}$$

Total Memory Required:

Without checkpointing: $5.2 + 15.6 + 1{,}005.6 = 1{,}026.4$GB

With checkpointing: $5.2 + 15.6 + 251.4 = 272.2$GB

Fitting on A100 (80GB):

Current requirement: 272.2GB (too large!)

Strategy 1: Reduce Batch Size

Try $B = 16$ (4$\times$ reduction): $$\text{Activations} = \frac{251.4}{4} = 62.9\text{GB}$$ $$\text{Total} = 5.2 + 15.6 + 62.9 = 83.7\text{GB}$$ (still too large)

Try $B = 8$: $$\text{Activations} = 31.4\text{GB}$$ $$\text{Total} = 5.2 + 15.6 + 31.4 = 52.2\text{GB}$$ ✓ Fits!

Strategy 2: Mixed Precision (FP16)

Parameters (FP16): $1.3 \times 10^9 \times 2 = 2.6$GB

Optimizer states: $7.8$GB (master weights in FP32)

Activations (FP16, $B=16$): $31.4$GB

Total: $2.6 + 7.8 + 31.4 = 41.8$GB ✓ Fits!

Strategy 3: ZeRO Stage 2 (Optimizer State Sharding)

With 4 GPUs, shard optimizer states:

Parameters: $5.2$GB per GPU

Optimizer states: $15.6/4 = 3.9$GB per GPU

Activations (global batch 64, i.e., $B=16$ per GPU): $62.9$GB per GPU

Total per GPU: $5.2 + 3.9 + 62.9 = 72.0$GB ✓ Fits!

Recommended Configuration: mixed precision (FP16 with FP32 master weights) plus gradient checkpointing at per-GPU batch size 16 ($\approx$41.8GB, fits a single A100 80GB); scale the global batch with data parallelism and ZeRO stage 2 across additional GPUs.

Solution: Exercise 3: KV Caching Implementation

import torch
import torch.nn as nn
import time
import matplotlib.pyplot as plt

class GPT2WithCache(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = nn.ModuleList([
            GPT2Block(config) for _ in range(config.n_layer)
        ])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
    
    def forward(self, input_ids, past_key_values=None, use_cache=False):
        hidden_states = self.wte(input_ids) + self.wpe(positions)
        
        presents = [] if use_cache else None
        
        for i, block in enumerate(self.transformer):
            past = past_key_values[i] if past_key_values else None
            hidden_states, present = block(
                hidden_states, 
                past_key_value=past,
                use_cache=use_cache
            )
            if use_cache:
                presents.append(present)
        
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)
        
        return logits, presents

class GPT2Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        )

    def forward(self, x, past_key_value=None, use_cache=False):
        # Self-attention with optional KV cache (GPT-2 uses pre-norm:
        # LayerNorm before each sublayer, residual add after)
        attn_output, present = self.attn(
            self.ln_1(x),
            past_key_value=past_key_value,
            use_cache=use_cache
        )
        x = x + attn_output

        # Feed-forward, also pre-norm
        x = x + self.mlp(self.ln_2(x))

        return x, present

class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)  # fused QKV
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x, past_key_value=None, use_cache=False):
        B, T, C = x.shape
        
        # Compute Q, K, V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        
        # Reshape for multi-head attention
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        
        # Use cached K, V if available
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        
        # Store K, V for next iteration
        present = (k, v) if use_cache else None
        
        # Attention scores, scaled by sqrt(head dim)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))

        # Causal mask: query i may attend to keys 0..(past_len + i)
        total_len = k.size(2)
        mask = torch.tril(torch.ones(T, total_len, dtype=torch.bool,
                                     device=x.device), diagonal=total_len - T)
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        
        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.c_proj(out)
        
        return out, present

Benchmarking Code:

def generate_without_cache(model, prompt, max_length=256):
    """Generate tokens without KV caching"""
    input_ids = prompt
    times = []
    
    for _ in range(max_length):
        start = time.time()
        logits, _ = model(input_ids, use_cache=False)
        times.append(time.time() - start)
        
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    
    return input_ids, times

def generate_with_cache(model, prompt, max_length=256):
    """Generate tokens with KV caching"""
    input_ids = prompt
    past_key_values = None
    times = []
    
    for i in range(max_length):
        start = time.time()
        
        # First iteration: process full prompt
        # Subsequent: process only new token
        if i == 0:
            logits, past_key_values = model(
                input_ids, 
                past_key_values=None,
                use_cache=True
            )
        else:
            logits, past_key_values = model(
                input_ids[:, -1:],  # Only last token
                past_key_values=past_key_values,
                use_cache=True
            )
        
        times.append(time.time() - start)
        
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    
    return input_ids, times

# Run benchmark (GPT-2 small: 12 layers, 12 heads, d_model = 768)
from types import SimpleNamespace
config = SimpleNamespace(n_layer=12, n_head=12, n_embd=768,
                         vocab_size=50257, n_positions=1024)
model = GPT2WithCache(config).eval()

prompt = torch.randint(0, 50257, (1, 50))  # 50-token prompt

output_no_cache, times_no_cache = generate_without_cache(
    model, prompt, max_length=256
)
output_with_cache, times_with_cache = generate_with_cache(
    model, prompt, max_length=256
)

print(f"Without cache: {sum(times_no_cache):.2f}s")
print(f"With cache: {sum(times_with_cache):.2f}s")
print(f"Speedup: {sum(times_no_cache)/sum(times_with_cache):.2f}x")

Experimental Results:

For GPT-2 small (124M parameters), generating 256 tokens:

| Method        | Time (s) | Speedup       |
|---------------|----------|---------------|
| Without cache | 45.3     | 1.0$\times$   |
| With cache    | 2.8      | 16.2$\times$  |

Generation Time vs Sequence Length:

# Benchmark different sequence lengths
seq_lengths = [32, 64, 128, 256, 512]
times_no_cache = []
times_with_cache = []

for length in seq_lengths:
    _, t_no = generate_without_cache(model, prompt, max_length=length)
    _, t_with = generate_with_cache(model, prompt, max_length=length)
    times_no_cache.append(sum(t_no))
    times_with_cache.append(sum(t_with))

# Plot results
plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, times_no_cache, 'o-', label='Without cache', linewidth=2)
plt.plot(seq_lengths, times_with_cache, 's-', label='With cache', linewidth=2)
plt.xlabel('Sequence Length')
plt.ylabel('Generation Time (seconds)')
plt.title('KV Caching Impact on Generation Speed')
plt.legend()
plt.grid(True)
plt.savefig('kv_cache_speedup.png')

Analysis:

Without caching, time complexity: $O(n^2)$ where $n$ is sequence length $$T_{\text{no cache}} = \sum_{i=1}^n c \cdot i = c \cdot \frac{n(n+1)}{2} \approx O(n^2)$$

With caching, time complexity: $O(n)$ $$T_{\text{cache}} = c \cdot n$$

Speedup grows with sequence length: $$\text{Speedup}(n) = \frac{n(n+1)/2}{n} = \frac{n+1}{2} \approx O(n)$$

For $n=256$: Speedup $\approx \frac{257}{2} \approx 128\times$ (theoretical)
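The two cost models above can be checked numerically; this sketch simply assigns a unit per-token cost $c = 1$ to each model:

```python
# Numerically compare the two cost models, with unit per-token cost c = 1.
def time_no_cache(n):
    # Step i reprocesses all i tokens -> total n(n+1)/2
    return sum(range(1, n + 1))

def time_with_cache(n):
    # Each step processes only the new token -> total n
    return n

n = 256
print(time_no_cache(n) / time_with_cache(n))  # (n + 1) / 2 = 128.5
```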

Actual speedup (16.2$\times$) is lower than the theoretical bound because the bound counts only attention recomputation. In practice, per-step fixed overheads (the Python generation loop, kernel launches, sampling), low GPU utilization on single-token matrix multiplies, and the memory-bandwidth cost of reading the growing KV cache all eat into the gain.

Memory Cost:

KV cache size: $2 \times L \times B \times n \times d = 2 \times 12 \times 1 \times 256 \times 768 \approx 4.7\text{M}$ values, i.e. about $18.9$MB in FP32 (or $9.4$MB in FP16)

Small memory cost for massive speedup makes KV caching essential for inference.
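The cache-size arithmetic can be scripted; the only choice is the storage width (4 bytes for FP32, 2 for FP16):

```python
# KV cache footprint: one K and one V tensor per layer, each holding
# B x n x d_model values (d_model = n_head x d_head).
def kv_cache_bytes(n_layers, batch, seq_len, d_model, bytes_per_elem):
    return 2 * n_layers * batch * seq_len * d_model * bytes_per_elem

# GPT-2 small: 12 layers, d_model = 768, one sequence of 256 tokens
fp32 = kv_cache_bytes(12, 1, 256, 768, bytes_per_elem=4)
fp16 = kv_cache_bytes(12, 1, 256, 768, bytes_per_elem=2)
print(f"{fp32 / 1e6:.1f} MB in FP32, {fp16 / 1e6:.1f} MB in FP16")
```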

Solution: Exercise 4: Chinchilla Scaling Laws

Given: Fixed compute budget $C = 10^{24}$ FLOPs

Chinchilla Scaling Law:

For optimal training, model size $N$ (parameters) and dataset size $D$ (tokens) should scale as: $$N_{\text{opt}} \propto C^{0.5}, \quad D_{\text{opt}} \propto C^{0.5}$$

More precisely, combining the training-FLOPs approximation $C = 6ND$ with the Chinchilla finding that the optimal ratio is roughly $D/N \approx 20$ tokens per parameter gives: $$N_{\text{opt}} = \left(\frac{C}{6 \times 20}\right)^{0.5}, \quad D_{\text{opt}} = 20\,N_{\text{opt}}$$

Optimal Allocation for $C = 10^{24}$ FLOPs:

Training FLOPs: $C = 6ND$

Solving for optimal $N$ and $D$: $$N_{\text{opt}} = \left(\frac{C}{6 \times 20}\right)^{0.5} = \left(\frac{10^{24}}{120}\right)^{0.5} = 9.13 \times 10^{10} \approx 91B \text{ parameters}$$

$$D_{\text{opt}} = \frac{C}{6N_{\text{opt}}} = \frac{10^{24}}{6 \times 9.13 \times 10^{10}} = 1.83 \times 10^{12} \approx 1.83T \text{ tokens}$$

Verification: $6 \times 91.3 \times 10^9 \times 1.83 \times 10^{12} = 1.00 \times 10^{24}$ ✓

Chinchilla Optimal Ratio:

$$\frac{D_{\text{opt}}}{N_{\text{opt}}} = \frac{1.83\text{T}}{91.3\text{B}} \approx 20$$

Chinchilla recommends: roughly 20 tokens per parameter
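This allocation can be reproduced with a few lines of Python; the only inputs are the $C = 6ND$ approximation and the roughly 20 tokens-per-parameter ratio:

```python
# Compute-optimal split of a FLOP budget C under C = 6*N*D, assuming the
# Chinchilla ratio of about 20 training tokens per parameter.
def chinchilla_optimal(C, tokens_per_param=20.0):
    N = (C / (6 * tokens_per_param)) ** 0.5  # parameters
    D = tokens_per_param * N                 # tokens
    return N, D

N_opt, D_opt = chinchilla_optimal(1e24)
print(f"N_opt ~ {N_opt / 1e9:.0f}B parameters, D_opt ~ {D_opt / 1e12:.2f}T tokens")
```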

GPT-3 Allocation:

GPT-3 used: $N = 175B$ parameters, $D = 300B$ tokens

Compute used: $C_{\text{GPT-3}} = 6 \times 175 \times 10^9 \times 300 \times 10^9 = 3.15 \times 10^{23}$ FLOPs

For the same compute budget ($C = 10^{24}$), GPT-3 approach would scale to: $$N_{\text{GPT-3}} = 175B \times \left(\frac{10^{24}}{3.15 \times 10^{23}}\right)^{0.5} = 175B \times 1.78 = 311B$$ $$D_{\text{GPT-3}} = 300B \times 1.78 = 534B$$

Ratio: $\frac{534B}{311B} = 1.72$ tokens per parameter

Comparison:

| Approach           | Parameters | Tokens | Tokens/param |
|--------------------|------------|--------|--------------|
| Chinchilla optimal | 91B        | 1.83T  | 20           |
| GPT-3 scaling      | 311B       | 534B   | 1.72         |

At this budget, the GPT-3 recipe builds a model about $3.4\times$ larger than optimal while training on about $3.4\times$ fewer tokens.

Key Insights:

  1. GPT-3 was undertrained: Used only 1.72 tokens/param vs the optimal $\approx 20$
  2. Chinchilla approach: Smaller model, more data
  3. Performance impact: Chinchilla-optimal models achieve better performance at the same compute
  4. Practical implications:
    • For a fixed compute budget, dataset size should grow in proportion to model size
    • Larger models need proportionally more data
    • Many large models (GPT-3, Gopher) were undertrained

Expected Performance:

Using Chinchilla scaling law for loss prediction: $$L(N, D) = E + \frac{A}{N^\alpha} + \frac{B}{D^\beta}$$

where $\alpha \approx 0.34$, $\beta \approx 0.28$, and $E \approx 1.69$ (the irreducible loss), with $N$ measured in parameters and $D$ in tokens.

For optimal allocation: $$L_{\text{Chinchilla}} \approx 1.69 + \frac{406.4}{(9.13 \times 10^{10})^{0.34}} + \frac{410.7}{(1.83 \times 10^{12})^{0.28}} \approx 1.69 + 0.076 + 0.151 \approx 1.92$$

For GPT-3 scaling: $$L_{\text{GPT-3}} \approx 1.69 + \frac{406.4}{(3.11 \times 10^{11})^{0.34}} + \frac{410.7}{(5.34 \times 10^{11})^{0.28}} \approx 1.69 + 0.050 + 0.214 \approx 1.95$$

Chinchilla-optimal allocation achieves roughly 1.9\% lower loss at the same compute budget.
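Plugging the fitted constants into the parametric loss is a one-liner; note that $N$ and $D$ must be given in raw parameters and tokens, not billions:

```python
# Chinchilla parametric loss fit L(N, D) = E + A/N^alpha + B/D^beta, using
# the published constants; N in parameters, D in tokens (not billions).
E, A, B = 1.69, 406.4, 410.7
ALPHA, BETA = 0.34, 0.28

def predicted_loss(N, D):
    return E + A / N**ALPHA + B / D**BETA

L_chinchilla = predicted_loss(91.3e9, 1.83e12)  # compute-optimal split
L_gpt3 = predicted_loss(311e9, 534e9)           # GPT-3-style split
print(f"Chinchilla-optimal: {L_chinchilla:.2f}, GPT-3 scaling: {L_gpt3:.2f}")
```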
