Recurrent Neural Networks
Chapter Overview
Recurrent Neural Networks (RNNs) process sequential data by maintaining hidden states that capture information from previous time steps. This chapter develops RNNs from basic recurrence to modern architectures like LSTMs and GRUs, establishing foundations for understanding transformers.
Learning Objectives
- Understand recurrent architectures for sequential data
- Implement vanilla RNNs, LSTMs, and GRUs
- Understand vanishing/exploding gradient problems
- Apply RNNs to sequence modeling tasks
- Understand bidirectional and multi-layer RNNs
Vanilla RNNs
A vanilla RNN maintains a hidden state $\vh_t \in \R^h$ that is updated at every time step from the previous hidden state and the current input $\vx_t \in \R^d$:
\[
\vh_t = \tanh(\mW_{hh}\vh_{t-1} + \mW_{xh}\vx_t + \vb_h), \qquad \vy_t = \mW_{hy}\vh_t + \vb_y,
\]
where:
- $\mW_{hh} \in \R^{h \times h}$: hidden-to-hidden weights
- $\mW_{xh} \in \R^{h \times d}$: input-to-hidden weights
- $\mW_{hy} \in \R^{k \times h}$: hidden-to-output weights
- $\vh_0$ is the initial hidden state (often zeros)
\begin{tikzpicture}
\node[node] (h0) at (0,0) {$\vh_0$}; \node[node] (h1) at (3,0) {$\vh_1$}; \node[node] (h2) at (6,0) {$\vh_2$}; \node[node] (h3) at (9,0) {$\vh_3$};
\node[input] (x1) at (3,1.5) {$\vx_1$}; \node[input] (x2) at (6,1.5) {$\vx_2$}; \node[input] (x3) at (9,1.5) {$\vx_3$};
\node[input] (y1) at (3,-1.5) {$\vy_1$}; \node[input] (y2) at (6,-1.5) {$\vy_2$}; \node[input] (y3) at (9,-1.5) {$\vy_3$};
\draw[recurrent] (h0) -- (h1); \draw[recurrent] (h1) -- (h2); \draw[recurrent] (h2) -- (h3);
\draw[arrow, blue!70] (x1) -- (h1); \draw[arrow, blue!70] (x2) -- (h2); \draw[arrow, blue!70] (x3) -- (h3);
\draw[arrow] (h1) -- (y1); \draw[arrow] (h2) -- (y2); \draw[arrow] (h3) -- (y3);
\end{tikzpicture}
Input sequence: "hello" encoded as one-hot vectors $\vx_1, \ldots, \vx_5 \in \R^4$ (the character vocabulary is $\{\text{h}, \text{e}, \text{l}, \text{o}\}$)
Initialize: $\vh_0 = [0, 0, 0]\transpose$
Time step 1: Process 'h': $\vh_1 = \tanh(\mW_{hh}\vh_0 + \mW_{xh}\vx_1 + \vb_h)$
Time step 2: Process 'e' using $\vh_1$: $\vh_2 = \tanh(\mW_{hh}\vh_1 + \mW_{xh}\vx_2 + \vb_h)$
The hidden state $\vh_t$ carries information from all previous time steps.
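The worked example can be sketched in a few lines of NumPy (an illustrative sketch: the weights are random rather than trained, and the hidden dimension 3 matches $\vh_0$ above):

```python
import numpy as np

# Vanilla RNN forward pass over "hello" with random, untrained weights.
rng = np.random.default_rng(0)

chars = sorted(set("hello"))             # ['e', 'h', 'l', 'o'] -> vocab size 4
char_to_idx = {c: i for i, c in enumerate(chars)}

d, h = len(chars), 3                     # input dim 4, hidden dim 3
W_xh = rng.normal(0, 0.1, (h, d))        # input-to-hidden weights
W_hh = rng.normal(0, 0.1, (h, h))        # hidden-to-hidden weights
b_h = np.zeros(h)

def one_hot(c):
    v = np.zeros(d)
    v[char_to_idx[c]] = 1.0
    return v

h_t = np.zeros(h)                        # h_0 = [0, 0, 0]^T
states = []
for c in "hello":
    # h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h)
    h_t = np.tanh(W_hh @ h_t + W_xh @ one_hot(c) + b_h)
    states.append(h_t)
```

After five steps, `states[-1]` has been influenced by every input character, which is exactly the sense in which $\vh_t$ carries information from all previous time steps.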
Backpropagation Through Time (BPTT)
Vanishing and Exploding Gradients
The fundamental challenge in training RNNs on long sequences arises from the multiplicative nature of gradient backpropagation through time. When computing gradients with respect to early hidden states, the chain rule requires multiplying Jacobian matrices across all intermediate time steps, leading to exponential growth or decay of gradient magnitudes.
The gradient of the loss with respect to an early hidden state $\vh_0$ involves the product of Jacobians across all time steps:
\[
\frac{\partial L}{\partial \vh_0} = \frac{\partial L}{\partial \vh_T} \prod_{t=1}^{T} \frac{\partial \vh_t}{\partial \vh_{t-1}}, \qquad \frac{\partial \vh_t}{\partial \vh_{t-1}} = \operatorname{diag}\bigl(\tanh'(\va_t)\bigr)\, \mW_{hh},
\]
where $\va_t$ denotes the pre-activation at step $t$. Each factor has norm at most $\norm{\mW_{hh}}$ (since $|\tanh'| \le 1$), so the product shrinks or grows exponentially with $T$.
For a sequence of length $T = 100$, if $\norm{\mW_{hh}} = 0.95$ (slightly less than 1), the gradient magnitude decays as $0.95^{100} \approx 0.006$, reducing gradients by a factor of 167. If $\norm{\mW_{hh}} = 0.9$, the decay is $0.9^{100} \approx 2.7 \times 10^{-5}$, reducing gradients by a factor of 37,000. This exponential decay makes it nearly impossible for the network to learn long-range dependencies: the gradient signal from time step 100 is effectively zero by the time it reaches time step 0. In practice, vanilla RNNs struggle to learn dependencies longer than 10-20 time steps due to vanishing gradients.
Conversely, if $\norm{\mW_{hh}} = 1.05$, the gradient magnitude grows as $1.05^{100} \approx 131.5$, amplifying gradients by a factor of 131. If $\norm{\mW_{hh}} = 1.1$, the growth is $1.1^{100} \approx 13{,}781$, causing gradients to explode. Exploding gradients lead to numerical overflow (NaN values) and training instability, where loss suddenly spikes to infinity. While gradient clipping (capping gradient norms at a threshold like 1.0) provides a practical solution for exploding gradients, it does not address the fundamental problem of vanishing gradients.
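Gradient clipping by global norm can be sketched as follows (a minimal NumPy illustration; the threshold of 1.0 matches the text, and `clip_by_global_norm` is a hypothetical helper name, not a library function):

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a list of gradient arrays so their combined L2 norm
    does not exceed max_norm (standard guard against exploding gradients)."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# Deliberately huge gradients, as produced by an exploding backward pass.
grads = [np.full((4, 4), 10.0), np.full((4,), 10.0)]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
```

Clipping preserves the gradient direction while capping its magnitude, which is why it stabilizes training without changing which way the parameters move.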
The vanishing gradient problem is particularly severe because the effective per-step Jacobian norm must stay extremely close to 1.0 to avoid both vanishing and exploding gradients, and maintaining this balance during training is extremely difficult. Initialization schemes like orthogonal initialization set $\mW_{hh}$ to have spectral norm exactly 1.0 initially, but gradient descent updates quickly perturb this property, and the $\tanh$ derivatives (at most 1) shrink the effective norm further. Even with careful initialization, vanilla RNNs rarely learn dependencies beyond 20-30 time steps in practice.
Quantitative Analysis of Gradient Decay
To understand the severity of vanishing gradients, consider a concrete example with BERT-base dimensions. Suppose we have a vanilla RNN with hidden dimension $h = 768$ (matching BERT-base) and sequence length $n = 512$ (BERT's maximum sequence length). The recurrence matrix $\mW_{hh} \in \R^{768 \times 768}$ has 589,824 parameters. If we initialize $\mW_{hh}$ orthogonally (spectral norm exactly 1.0) and the $\tanh$ derivatives average 0.5 (typical for non-saturated activations), the effective Jacobian norm per time step is approximately $1.0 \times 0.5 = 0.5$.
Over 512 time steps, the gradient magnitude decays as $0.5^{512} \approx 10^{-154}$, which is far below machine precision for FP32 (approximately $10^{-38}$) or even FP64 (approximately $10^{-308}$). The gradient effectively becomes exactly zero after about 130 time steps in FP32 or 1,000 time steps in FP64. This means a vanilla RNN cannot learn any dependencies spanning more than 130 tokens when using FP32 arithmetic, regardless of optimization algorithm or learning rate. The mathematical structure of the recurrence fundamentally limits the learnable dependency length.
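The decay arithmetic above can be checked numerically (a NumPy sketch; the per-step factor of 0.5 is the effective Jacobian norm assumed in the text):

```python
import numpy as np

# With an effective per-step Jacobian norm of 0.5, count the steps until
# the gradient magnitude drops below FP32's smallest normal number.
per_step = 0.5
fp32_min_normal = np.finfo(np.float32).tiny    # ~1.18e-38

grad = 1.0
steps = 0
while grad >= fp32_min_normal:
    grad *= per_step
    steps += 1

# The headline figure: 0.5 ** 512 expressed in log10 terms (~ -154).
log10_decay_512 = 512 * np.log10(0.5)
```

The loop confirms the text's claim that the gradient falls below FP32 precision after roughly 130 steps, long before reaching the full 512-step sequence.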
For comparison, consider the gradient flow in a transformer with the same dimensions. The self-attention mechanism computes attention scores $\mA = \text{softmax}(\frac{\mQ\mK\transpose}{\sqrt{d_k}})$ and outputs $\mO = \mA\mV$. The gradient $\frac{\partial L}{\partial \mV}$ flows directly from the output through the attention weights, without any multiplicative accumulation across time steps. The gradient magnitude remains approximately constant regardless of sequence length, enabling transformers to learn dependencies spanning thousands of tokens. This fundamental difference in gradient flow explains why transformers replaced RNNs for nearly all sequence modeling tasks: they solve the vanishing gradient problem by design.
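A minimal NumPy sketch of single-head attention makes the direct gradient path to $\mV$ concrete (the dimensions and random inputs are illustrative assumptions, not a full transformer layer):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_k = 6, 4                              # tiny sequence and head dimension
Q = rng.normal(size=(n, d_k))
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_k))

# A = softmax(Q K^T / sqrt(d_k)), row-wise; O = A V.
scores = Q @ K.T / np.sqrt(d_k)
A = np.exp(scores - scores.max(axis=1, keepdims=True))
A = A / A.sum(axis=1, keepdims=True)
O = A @ V

# dL/dV = A^T (dL/dO): a single matmul, with no product over time steps.
dO = np.ones_like(O)
dV = A.T @ dO
```

Unlike the RNN case, `dV` is obtained in one matrix multiplication regardless of sequence length, which is the "no multiplicative accumulation" property described above.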
The LSTM architecture addresses vanishing gradients through its cell state mechanism, which provides an additive path for gradient flow. The cell state update $\mathbf{c}_t = \vf_t \odot \mathbf{c}_{t-1} + \vi_t \odot \tilde{\mathbf{c}}_t$ includes an additive term rather than purely multiplicative updates. Along this path, the gradient with respect to $\mathbf{c}_{t-1}$ is:
\[
\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \operatorname{diag}(\vf_t),
\]
so the gradient is scaled only by the forget-gate activations; when the forget gate saturates near 1, gradients pass through the cell state almost unchanged instead of decaying exponentially.
Long Short-Term Memory (LSTM)
The LSTM computes four quantities from the concatenated input $[\vh_{t-1}, \vx_t]$:
\[
\vf_t = \sigma(\mW_f[\vh_{t-1}, \vx_t] + \vb_f), \quad
\vi_t = \sigma(\mW_i[\vh_{t-1}, \vx_t] + \vb_i), \quad
\tilde{\mathbf{c}}_t = \tanh(\mW_c[\vh_{t-1}, \vx_t] + \vb_c), \quad
\vo_t = \sigma(\mW_o[\vh_{t-1}, \vx_t] + \vb_o),
\]
\[
\mathbf{c}_t = \vf_t \odot \mathbf{c}_{t-1} + \vi_t \odot \tilde{\mathbf{c}}_t, \qquad
\vh_t = \vo_t \odot \tanh(\mathbf{c}_t).
\]
Key components:
- Cell state $\mathbf{c}_t$: Long-term memory, flows with minimal modification
- Forget gate $\vf_t$: What to remove from cell state
- Input gate $\vi_t$: What new information to store
- Output gate $\vo_t$: What to output from cell state
\begin{tikzpicture}
\node[state] (input) at (0,4) {$[\vh_{t-1}, \vx_t]$};
\node[gate] (forget) at (-3,2.5) {$\sigma$}; \node[gate] (input_gate) at (-1,2.5) {$\sigma$}; \node[gate] (candidate) at (1,2.5) {$\tanh$}; \node[gate] (output_gate) at (3,2.5) {$\sigma$};
\node[state] (c_prev) at (-4,0) {$\mathbf{c}_{t-1}$}; \node[state] (c_curr) at (2,0) {$\mathbf{c}_t$};
\node[operation] (mult1) at (-3,0) {$\odot$}; \node[operation] (mult2) at (1,0) {$\odot$}; \node[operation] (add) at (-1,0) {$+$}; \node[operation] (tanh_c) at (2,-1.5) {$\tanh$}; \node[operation] (mult3) at (3,-1.5) {$\odot$};
\node[state] (h_out) at (3,-3) {$\vh_t$};
\draw[arrow] (input) -- (-3,3.2); \draw[arrow] (input) -- (-1,3.2); \draw[arrow] (input) -- (1,3.2); \draw[arrow] (input) -- (3,3.2);
\draw[arrow] (forget) -- (mult1); \draw[arrow] (input_gate) -- (mult2); \draw[arrow] (candidate) -- (mult2); \draw[arrow] (output_gate) -- (mult3);
\draw[cell_flow] (c_prev) -- (mult1); \draw[cell_flow] (mult1) -- (add); \draw[cell_flow] (add) -- (c_curr); \draw[arrow] (mult2) -- (add);
\draw[arrow] (c_curr) -- (tanh_c); \draw[arrow] (tanh_c) -- (mult3); \draw[arrow] (mult3) -- (h_out);
\end{tikzpicture}
Each gate has a weight matrix acting on $[\vh_{t-1}, \vx_t] \in \R^{h+d}$, i.e. $\mW \in \R^{h \times (h+d)}$ plus a bias in $\R^h$, giving $h(h+d) + h$ parameters per gate.
The LSTM has 4 such weight blocks (forget, input, candidate cell, output), for a total of
\[
4\bigl[h(h+d) + h\bigr] = 4h(h + d + 1) \text{ parameters}.
\]
Compare to transformer attention with the same dimensions: often fewer parameters and better parallelization!
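A single LSTM time step can be sketched in NumPy (an illustrative sketch: the weights are random and untrained, and each per-gate matrix acts on the concatenation $[\vh_{t-1}, \vx_t]$ as above):

```python
import numpy as np

rng = np.random.default_rng(2)
d, h = 4, 3                                  # tiny illustrative dimensions

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# One weight matrix per gate, each of shape (h, h + d), plus biases.
W_f, W_i, W_c, W_o = (rng.normal(0, 0.1, (h, h + d)) for _ in range(4))
b_f, b_i, b_c, b_o = (np.zeros(h) for _ in range(4))

def lstm_step(h_prev, c_prev, x_t):
    z = np.concatenate([h_prev, x_t])
    f = sigmoid(W_f @ z + b_f)               # forget gate
    i = sigmoid(W_i @ z + b_i)               # input gate
    c_tilde = np.tanh(W_c @ z + b_c)         # candidate cell state
    c = f * c_prev + i * c_tilde             # additive cell-state update
    o = sigmoid(W_o @ z + b_o)               # output gate
    h_t = o * np.tanh(c)
    return h_t, c

h_t, c_t = lstm_step(np.zeros(h), np.zeros(h), rng.normal(size=d))
```

Note how the cell-state line `c = f * c_prev + i * c_tilde` is the additive path that carries gradients across time.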
LSTM Computational Analysis
Understanding the computational cost of LSTMs is essential for comparing them to transformers and explaining why transformers have become dominant despite LSTMs' theoretical advantages for sequential processing. The LSTM's gating mechanisms provide powerful modeling capabilities but come with significant computational overhead that limits their efficiency on modern hardware.
For an LSTM with input dimension $d$ and hidden dimension $h$, each time step requires computing four gates (forget, input, candidate, output), each involving a matrix multiplication with the concatenated input $[\vh_{t-1}, \vx_t] \in \R^{h+d}$. The computational cost per time step is:
\[
4 \times 2h(h+d) = 8h(h+d) \text{ FLOPs}.
\]
For BERT-base dimensions ($h = d = 768$), this is $8 \times 768 \times 1536 \approx 9.4$ MFLOPs per time step, or approximately $4.8$ GFLOPs for a sequence of length $n = 512$.
This computational cost is deceptively modest compared to transformers. A single transformer layer with the same dimensions requires approximately 12.9 GFLOPs for self-attention (with $n = 512$) plus 9.4 GFLOPs for the feed-forward network, totaling 22.3 GFLOPs—about 4.6× more than the LSTM. However, this comparison is misleading because it ignores the critical difference in parallelization: the transformer can process all 512 positions simultaneously, while the LSTM must process them sequentially.
The sequential nature of LSTMs means that the 4.8 GFLOPs cannot be parallelized across time steps. On an NVIDIA A100 GPU with peak throughput of 312 TFLOPS (FP16), the theoretical minimum time to process a sequence of length 512 is $\frac{4.8 \times 10^9}{312 \times 10^{12}} = 15.4$ microseconds if we could achieve perfect parallelization. However, the sequential dependency forces us to process one time step at a time, with each step taking approximately $\frac{9.4 \times 10^6}{312 \times 10^{12}} = 0.03$ microseconds at peak throughput. In practice, small matrix multiplications achieve only 1-5\% of peak throughput due to insufficient parallelism, so each time step actually takes approximately 1-3 microseconds, giving a total sequence processing time of 512-1,536 microseconds (0.5-1.5 milliseconds).
For comparison, a transformer layer can process the entire sequence in parallel. The self-attention computation requires three matrix multiplications ($\mQ = \mX\mW_Q$, $\mK = \mX\mW_K$, $\mV = \mX\mW_V$) with dimensions $512 \times 768 \times 768$, followed by the attention score computation $\mA = \text{softmax}(\frac{\mQ\mK\transpose}{\sqrt{d_k}})$ and output computation $\mO = \mA\mV$. These operations can be batched into large matrix multiplications that achieve 40-60\% of peak GPU throughput, completing in approximately 50-100 microseconds total. The transformer is 5-30× faster than the LSTM despite having more FLOPs, purely due to better parallelization.
The memory requirements for LSTM hidden states are modest compared to transformer attention matrices. For batch size $B$ and sequence length $n$, the LSTM must store hidden states $\vh_t \in \R^{B \times h}$ and cell states $\mathbf{c}_t \in \R^{B \times h}$ for each time step, requiring $2Bnh \times 4 = 8Bnh$ bytes in FP32. For BERT-base dimensions with $B = 32$, $n = 512$, $h = 768$, this totals $8 \times 32 \times 512 \times 768 = 100{,}663{,}296$ bytes, or approximately 96 MB. This is substantially less than the 384 MB required for transformer attention scores in a single layer, making LSTMs more memory-efficient for long sequences.
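These memory figures can be reproduced directly (a small Python check; the attention figure assumes 12 heads and FP32 scores, consistent with the BERT-base setting used above):

```python
# LSTM state memory: hidden + cell state per time step, FP32 (4 bytes/value).
B, n, h = 32, 512, 768
lstm_state_bytes = 2 * B * n * h * 4
lstm_state_mb = lstm_state_bytes / 2**20      # bytes -> MiB

# Transformer attention scores for one layer with 12 heads: B x heads x n x n.
heads = 12
attn_bytes = B * heads * n * n * 4
attn_mb = attn_bytes / 2**20
```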
However, this memory advantage is offset by the sequential processing requirement. While transformers can trade memory for speed by using gradient checkpointing (recomputing activations during the backward pass rather than storing them), LSTMs cannot benefit from this technique as effectively because the sequential dependency prevents parallelization of the recomputation. Gradient checkpointing reduces transformer memory by 3-5× with only 20-30\% slowdown, but for LSTMs, the slowdown is 2-3× because the recomputation cannot be parallelized. This makes gradient checkpointing less attractive for LSTMs, limiting their ability to scale to very long sequences.
Gated Recurrent Unit (GRU)
The GRU merges the LSTM's cell and hidden states and uses two gates, an update gate $\vz_t$ and a reset gate $\vr_t$, plus a candidate state:
\[
\vz_t = \sigma(\mW_z[\vh_{t-1}, \vx_t] + \vb_z), \qquad
\vr_t = \sigma(\mW_r[\vh_{t-1}, \vx_t] + \vb_r),
\]
\[
\tilde{\vh}_t = \tanh(\mW_h[\vr_t \odot \vh_{t-1}, \vx_t] + \vb_h), \qquad
\vh_t = (1 - \vz_t) \odot \vh_{t-1} + \vz_t \odot \tilde{\vh}_t.
\]
Advantages over LSTM:
- Fewer parameters (3 weight blocks vs 4)
- Simpler architecture (no separate cell state)
- Often similar performance
- Faster training
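A single GRU time step can be sketched in NumPy (an illustrative sketch with random, untrained weights; biases are omitted for brevity):

```python
import numpy as np

rng = np.random.default_rng(3)
d, h = 4, 3                                  # tiny illustrative dimensions

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Three weight blocks (update gate, reset gate, candidate) vs the LSTM's four.
W_z, W_r, W_h = (rng.normal(0, 0.1, (h, h + d)) for _ in range(3))

def gru_step(h_prev, x_t):
    z_in = np.concatenate([h_prev, x_t])
    z = sigmoid(W_z @ z_in)                                  # update gate
    r = sigmoid(W_r @ z_in)                                  # reset gate
    h_tilde = np.tanh(W_h @ np.concatenate([r * h_prev, x_t]))
    return (1 - z) * h_prev + z * h_tilde                    # interpolation

h_t = gru_step(np.zeros(h), rng.normal(size=d))
```

The final line interpolates between the old and candidate hidden states, playing the same gradient-preserving role as the LSTM's additive cell-state update.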
Bidirectional RNNs
Bidirectional RNNs capture context from both past and future, useful when entire sequence is available (not for online/causal tasks).
Example: BERT uses bidirectional transformers (attention, not RNN), capturing full context.
RNN Applications
Sequence-to-Sequence:
- Machine translation: Encoder RNN $\to$ Decoder RNN
- Text summarization
- Speech recognition
Sequence Labeling:
- Part-of-speech tagging
- Named entity recognition
- Output at each time step
Sequence Generation:
- Language modeling
- Music generation
- Sample from output distribution
RNNs vs Transformers: A Computational Comparison
The transition from RNNs to transformers represents one of the most significant architectural shifts in deep learning. The following table summarizes the key computational differences:
| Property | RNNs | Transformers |
|---|---|---|
| Computation | Sequential ($n$ steps) | Parallel (constant depth) |
| Memory scaling | $O(nd)$ | $O(n^2 + nd)$ |
| GPU utilization | 1--5\% | 40--60\% |
| Memory bandwidth | Reload weights each step | Load weights once |
| Training time (BERT-scale) | Estimated 100--200 days | 4 days |
The fundamental bottleneck of RNNs is sequential processing: each hidden state $\vh_t$ depends on $\vh_{t-1}$, preventing parallelization across time steps. On an A100 GPU with 6,912 CUDA cores, an LSTM processing batch size 32 utilizes only $\sim$0.5\% of parallel capacity. Transformers eliminate this bottleneck by computing all positions simultaneously via matrix multiplication, achieving 15$\times$ or greater speedup despite having more total FLOPs.
The parallelization advantage compounds with hardware efficiency: transformers achieve high arithmetic intensity through data reuse in matrix multiplications ($\sim$256 FLOPs/byte), while RNNs perform small matrix-vector products with low data reuse ($\sim$1--10 FLOPs/byte). The combined effect is 100--500$\times$ faster training for equivalent model capacity.
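The arithmetic-intensity gap can be estimated with a back-of-envelope calculation (a sketch assuming FP16 weights and activations; real figures depend on caching and kernel details):

```python
h = 768

# RNN step: a matrix-vector product W_hh @ h_t reloads the h*h weight
# matrix from memory at every time step, so data reuse is minimal.
matvec_flops = 2 * h * h
matvec_bytes = h * h * 2 + 2 * h * 2         # weights + in/out vectors, FP16
matvec_intensity = matvec_flops / matvec_bytes

# Transformer: an (n x h) @ (h x h) matmul reuses the weights across all
# n rows, so far more FLOPs are performed per byte moved.
n = 512
matmul_flops = 2 * n * h * h
matmul_bytes = h * h * 2 + 2 * n * h * 2
matmul_intensity = matmul_flops / matmul_bytes
```

The matvec lands at roughly 1 FLOP/byte while the matmul lands in the hundreds, matching the $\sim$1--10 vs $\sim$256 FLOPs/byte figures quoted above.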
Exercises
Solutions
Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.
(1) Parameter count:
- $\mW_{xh} \in \R^{h \times d}$: $256 \times 128 = 32{,}768$ parameters
- $\mW_{hh} \in \R^{h \times h}$: $256 \times 256 = 65{,}536$ parameters
- $\mW_{hy} \in \R^{V \times h}$: Assuming vocabulary $V=10{,}000$: $10{,}000 \times 256 = 2{,}560{,}000$ parameters
- Biases: $h + h + V = 256 + 256 + 10{,}000 = 10{,}512$ parameters
- Total: $32{,}768 + 65{,}536 + 2{,}560{,}000 + 10{,}512 = 2{,}668{,}816$ parameters
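The count can be verified in a couple of lines of Python:

```python
# Parameter count for the vanilla RNN in solution (1): h = 256, d = 128,
# vocabulary V = 10,000.
h, d, V = 256, 128, 10_000
params = (h * d) + (h * h) + (V * h) + (h + h + V)   # W_xh + W_hh + W_hy + biases
```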
(2) FLOPs for forward pass: Per time step:
- $\mW_{xh}\vx_t$: $2 \times h \times d = 2 \times 256 \times 128 = 65{,}536$ FLOPs
- $\mW_{hh}\vh_{t-1}$: $2 \times h \times h = 2 \times 256 \times 256 = 131{,}072$ FLOPs
- $\tanh$ activation: $h = 256$ FLOPs
- Per time step total: $\approx 196{,}864$ FLOPs
For $T=50$ time steps: $50 \times 196{,}864 = 9{,}843{,}200 \approx 9.8$ MFLOPs
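Likewise for the FLOP count in solution (2):

```python
# Forward-pass FLOPs for the same RNN: h = 256, d = 128, T = 50 time steps.
h, d, T = 256, 128, 50
per_step = 2 * h * d + 2 * h * h + h     # two matrix-vector products + tanh
total = T * per_step
```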
(3) GPU utilization with batch size 32:
- Total FLOPs per batch: $32 \times 9.8 \text{ MFLOPs} = 313.6$ MFLOPs
- At 2\% peak throughput: $0.02 \times 312 \text{ TFLOPS} = 6.24$ TFLOPS
- Time per batch: $\frac{313.6 \text{ MFLOPs}}{6.24 \text{ TFLOPS}} = 0.05$ ms
Why utilization is so low:
- Sequential dependency: Each time step depends on previous, preventing parallelization
- Small matrix operations: $256\times256$ matrices don't saturate GPU
- Memory-bound: Constantly loading/storing hidden states
- Low arithmetic intensity: Few operations per memory access
- Kernel launch overhead dominates for small operations
Forward pass: $\vh_t = \tanh(\mW_{hh}\vh_{t-1} + \mW_{xh}\vx_t + \vb_h)$
Gradient derivation: The gradient involves products of Jacobians:
\[
\frac{\partial L}{\partial \vh_0} = \frac{\partial L}{\partial \vh_T} \prod_{t=1}^{T} \operatorname{diag}\bigl(\tanh'(\va_t)\bigr)\, \mW_{hh}
\]
Gradient magnitude decay: With $\norm{\mW_{hh}} = 0.9$ and $\tanh'$ averaging 0.5, each step contributes a factor of approximately $0.9 \times 0.5 = 0.45$.
From time step 3 to 0 (3 steps back): $0.45^3 \approx 0.091$, already a reduction by an order of magnitude.
Vanishing gradient threshold: For gradients to vanish below $10^{-38}$ (the smallest normal FP32 value), we need $0.45^t < 10^{-38}$, i.e. $t > \frac{38}{\log_{10}(1/0.45)} \approx 110$.
Gradients vanish below FP32 precision after approximately 110 time steps.
(1) LSTM: $2{,}099{,}200$ parameters, $2.15$ GFLOPs
(2) GRU: $1{,}574{,}400$ parameters, $1.61$ GFLOPs
(3) Transformer: $1{,}048{,}576$ parameters, $1.61$ GFLOPs
Most parameters: LSTM (2.1M)
Most FLOPs: LSTM (2.15 GFLOPs)
Highest GPU utilization: Transformer (40-60\% vs 2-5\% for RNNs) due to full parallelization across sequence length and large matrix operations.
Output dimensions: $6 \times 64$ (concatenated forward and backward)
Memory for LSTM states:
- Forward hidden + cell: $2 \times 6 \times 32 \times 4 = 1{,}536$ bytes
- Backward hidden + cell: $1{,}536$ bytes
- Total: $3{,}072$ bytes $\approx 3$ KB
Transformer attention scores (8 heads): $8 \times 6 \times 6 \times 4 = 1{,}152$ bytes
LSTM requires 2.7× more memory for this small sequence, but attention scales as $O(n^2)$ vs $O(n)$ for LSTM.
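These figures can be checked directly (a small Python calculation mirroring the byte counts above):

```python
# BiLSTM state memory vs attention scores: n = 6, h = 32, 8 heads, FP32.
n, h, heads = 6, 32, 8
bilstm_bytes = 2 * (2 * n * h * 4)      # forward + backward, hidden + cell
attn_bytes = heads * n * n * 4          # one n x n score matrix per head
ratio = bilstm_bytes / attn_bytes
```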
(1) LSTM memory (12 layers): $\approx 1.13$ GB
(2) Transformer attention (12 layers): $\approx 4.5$ GB
(3) Equal memory at: $n = 128$ tokens
For $n>128$, transformers use more memory due to $O(n^2)$ attention scores. Transformers are memory-limited for long sequences, while LSTMs are compute-limited due to sequential processing.
BERT training time: 4 days
Speedup: 2.1× (BERT is faster)
Three main reasons:
- Parallelization: BERT processes all tokens in parallel ($\approx 10\times$ speedup)
- Memory bandwidth: BERT has higher arithmetic intensity ($\approx 3\times$ better)
- GPU utilization: BERT achieves 40-60\% vs 2-5\% for LSTM ($\approx 17\times$ better)