Vision Transformers

Chapter Overview

Vision Transformers (ViT) apply the transformer architecture to computer vision as an alternative to convolutional neural networks (CNNs). This chapter covers patch embeddings, position encodings for 2D images, ViT architecture variants, and hybrid CNN-transformer models.

Learning Objectives

  1. Understand how to apply transformers to images
  2. Implement patch embedding and position encoding
  3. Compare ViT to CNNs (ResNet, EfficientNet)
  4. Apply data augmentation and regularization for ViT
  5. Understand ViT variants (DeiT, Swin, CoAtNet)
  6. Implement masked autoencoding (MAE) for vision

From Images to Sequences

The Patch Embedding Approach

Challenge: An image is a 2D array, but a transformer expects a 1D sequence.

Solution: Divide the image into patches and flatten each patch into a vector (token).

Definition: For image $\mI \in \R^{H \times W \times C}$ with patch size $P$:

Step 1: Divide into $N = HW/P^2$ non-overlapping patches (assuming $H$ and $W$ are divisible by $P$)

$$ \mI_{\text{patches}} \in \R^{N \times (P^2 \cdot C)} $$

Step 2: Linear projection

$$ \mX = \mI_{\text{patches}} \mW_{\text{patch}} + \vb \quad \text{where } \mW_{\text{patch}} \in \R^{(P^2C) \times d} $$

Step 3: Add position embeddings

$$ \mX = \mX + \mE_{\text{pos}} $$
Example: Image: $224 \times 224 \times 3$ (ImageNet standard)

Patch size: $P = 16$

\begin{tikzpicture}[ patch/.style={rectangle, draw, minimum size=0.8cm, font=\footnotesize}, node/.style={circle, draw, minimum size=0.6cm, font=\footnotesize}, layer/.style={rectangle, draw, minimum width=3cm, minimum height=0.8cm, font=\small}, arrow/.style={->, thick} ]

\node[font=\small] at (-1,2) {Image}; \foreach \i in {0,1,2,3} { \foreach \j in {0,1,2,3} { \node[patch, fill=blue!20] (p\i\j) at (\i*0.9, -\j*0.9) {}; } }

\node[layer, fill=green!10] (proj) at (5,0) {Linear \\ Projection};

\node[font=\small] at (9,2) {Sequence}; \node[node, fill=yellow!30] (cls) at (9,0.5) {CLS}; \node[node] (e1) at (9,-0.5) {$e_1$}; \node[node] (e2) at (9,-1.5) {$e_2$}; \node[font=\footnotesize] at (9,-2.3) {$\vdots$}; \node[node] (en) at (9,-3) {$e_N$};

\draw[arrow] (p00) to[bend left=10] (proj); \draw[arrow] (p11) -- (proj); \draw[arrow] (p22) to[bend right=10] (proj); \draw[arrow] (proj) -- (cls); \draw[arrow] (proj) -- (e1); \draw[arrow] (proj) -- (e2); \draw[arrow] (proj) -- (en);

\node[layer, fill=orange!10] (pos) at (12,-1.5) {+ Position \\ Embeddings}; \draw[arrow] (cls) to[bend left=10] (pos); \draw[arrow] (e1) -- (pos); \draw[arrow] (e2) -- (pos); \draw[arrow] (en) to[bend right=10] (pos);

\node[layer, fill=purple!10] (trans) at (15,-1.5) {Transformer \\ Encoder}; \draw[arrow] (pos) -- (trans);

\end{tikzpicture}

Vision Transformer patch embedding process. A $224 \times 224$ image is divided into $16 \times 16$ patches (196 total), each patch is flattened and linearly projected to embedding dimension $d$, a CLS token is prepended, position embeddings are added, and the sequence is processed by a standard transformer encoder.

Number of patches:

$$ N = \frac{224 \times 224}{16^2} = \frac{50176}{256} = 196 \text{ patches} $$

Each patch: $16 \times 16 \times 3 = 768$ values

Linear projection to $d = 768$:

$$ \mW_{\text{patch}} \in \R^{768 \times 768} $$

Sequence length: 196 tokens (versus 50,176 if every pixel were its own token)

With [CLS] token: 197 total sequence length
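The three steps can be checked numerically. Below is a minimal NumPy sketch of Steps 1 and 2, assuming $H$ and $W$ are divisible by $P$ and that patches are taken in raster (row-major) order; `patchify` is a hypothetical helper name, and `W_patch` is a random stand-in for the learned projection.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into an (N, P*P*C) matrix of flattened patches.

    Assumes H and W are divisible by patch_size. Patches are ordered row by
    row (raster order); each patch is flattened in (row, col, channel) order.
    """
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by patch size"
    # (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C) -> (N, P*P*C)
    x = image.reshape(H // P, P, W // P, P, C)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, P * P * C)

rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3)).astype(np.float32)
patches = patchify(image, 16)              # I_patches: (196, 768)

d = 768
W_patch = rng.standard_normal((patches.shape[1], d)).astype(np.float32) * 0.02
b = np.zeros(d, dtype=np.float32)
X = patches @ W_patch + b                  # X: (196, 768), before position embeddings
print(patches.shape, X.shape)
```

The reshape-transpose-reshape sequence avoids an explicit loop over the 196 patches; each row of `patches` is one flattened $16 \times 16 \times 3$ patch.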

Position Encodings for 2D

Option 1: 1D Position Embeddings

$$ \mE_{\text{pos}} \in \R^{N \times d} $$
Learned absolute position embeddings, one vector per patch index; the 2D patch grid is treated as a 1D sequence in raster (row-major) order.

Option 2: 2D Position Embeddings

$$ \mE_{\text{pos}}(i,j) = \mE_{\text{row}}(i) + \mE_{\text{col}}(j) $$
Separate learned embeddings for the row index $i$ and the column index $j$, summed.

The original ViT uses 1D embeddings: they are simpler, and the ViT paper reports no significant gain from 2D-aware variants.
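To make Option 2 concrete, here is a small NumPy sketch of the factorized form, assuming a $14 \times 14$ patch grid and random stand-ins for the learned row and column tables (`factorized_2d_pos_embed` is an illustrative name, not a library function). It also shows the parameter saving: $(14 + 14) \times 768$ values instead of $196 \times 768$.

```python
import numpy as np

def factorized_2d_pos_embed(E_row: np.ndarray, E_col: np.ndarray) -> np.ndarray:
    """Build E_pos(i, j) = E_row(i) + E_col(j) for every grid cell.

    E_row: (H/P, d) row embeddings, E_col: (W/P, d) column embeddings.
    Returns an (N, d) table in the same raster order used for the patches.
    """
    grid = E_row[:, None, :] + E_col[None, :, :]   # (H/P, W/P, d) via broadcasting
    return grid.reshape(-1, E_row.shape[1])

rng = np.random.default_rng(0)
d, rows, cols = 768, 14, 14
E_row = rng.standard_normal((rows, d)) * 0.02     # stand-ins for learned parameters
E_col = rng.standard_normal((cols, d)) * 0.02
E_pos = factorized_2d_pos_embed(E_row, E_col)     # (196, 768)

# Option 1 by comparison: one free vector per position, N * d parameters
print(E_pos.shape, "2D params:", (rows + cols) * d, "1D params:", rows * cols * d)
```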

Vision Transformer (ViT) Architecture

Complete ViT Model

Definition: Input: Image $\mI \in \R^{H \times W \times C}$

Step 1: Patch embedding

$$ \vx_{\text{patches}} = \text{PatchEmbed}(\mI) \in \R^{N \times d} $$

Step 2: Add [CLS] token

$$ \vx_0 = [\vx_{\text{cls}}, \vx_{\text{patches}}] \in \R^{(N+1) \times d} $$

Step 3: Add position embeddings

$$ \vx_0 = \vx_0 + \mE_{\text{pos}} $$

Step 4: Transformer encoder (L layers)

$$ \vx_L = \text{Transformer}(\vx_0) $$

Step 5: Classification head on [CLS]

$$ y = \text{softmax}(\mW_{\text{head}} \vx_L^{\text{cls}} + \vb) $$
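Steps 1 to 5 map directly onto a few PyTorch modules. The following is a compact sketch, not the reference implementation: it assumes pre-norm encoder layers (`norm_first=True`), GELU activations, and a final LayerNorm on the [CLS] state before the head, which are details of common ViT code that the definition above does not spell out; dropout and stochastic depth are omitted, and `MiniViT` is a hypothetical class name.

```python
import torch
import torch.nn as nn

class MiniViT(nn.Module):
    """Steps 1-5 of the ViT definition built from PyTorch's encoder layers (a sketch)."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, d=768,
                 depth=12, heads=12, num_classes=1000):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        # Step 1: patch embedding (a strided conv equals flatten + linear per patch)
        self.patch_embed = nn.Conv2d(in_chans, d, kernel_size=patch_size, stride=patch_size)
        # Step 2: learnable [CLS] token; Step 3: learnable 1D position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Step 4: L transformer encoder layers
        layer = nn.TransformerEncoderLayer(d, heads, dim_feedforward=4 * d,
                                           dropout=0.0, activation="gelu",
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(d)
        # Step 5: linear head on the [CLS] representation
        self.head = nn.Linear(d, num_classes)

    def forward(self, img):                                    # img: (B, C, H, W)
        x = self.patch_embed(img).flatten(2).transpose(1, 2)   # (B, N, d)
        cls = self.cls_token.expand(x.shape[0], -1, -1)        # (B, 1, d)
        x = torch.cat([cls, x], dim=1) + self.pos_embed        # (B, N+1, d)
        x = self.encoder(x)                                    # (B, N+1, d)
        logits = self.head(self.norm(x[:, 0]))                 # (B, num_classes)
        return logits.softmax(dim=-1)                          # y = softmax(W x_cls + b)

# Small configuration so the example runs quickly on CPU
model = MiniViT(img_size=32, patch_size=4, d=64, depth=2, heads=4, num_classes=10)
y = model(torch.randn(2, 3, 32, 32))
print(y.shape)  # torch.Size([2, 10])
```

For training, one would normally return `logits` and apply `nn.CrossEntropyLoss`, which includes the softmax internally; the explicit softmax here only mirrors Step 5.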

ViT Model Variants

The Vision Transformer comes in three standard configurations that scale from moderate to extremely large models. ViT-Base uses 12 layers with hidden dimension $d = 768$ and 12 attention heads, resulting in 86 million parameters. It uses the same depth, hidden size, and head count as BERT-Base and serves as the standard baseline for vision transformer research. The patch size is typically set to $P = 16$ for ImageNet-resolution images, producing 196 patches from a $224 \times 224$ input.

ViT-Large scales up to 24 layers with $d = 1024$ and 16 attention heads, totaling 307 million parameters. This represents a roughly 3.5× increase in parameters compared to ViT-Base, with the additional capacity enabling stronger performance when sufficient training data is available. The larger hidden dimension increases both the expressiveness of each layer and the computational cost per token.

ViT-Huge pushes the architecture to 32 layers with $d = 1280$ and 16 heads, reaching 632 million parameters, and is typically used with patch size $P = 14$. A model of this size needs very large pre-training datasets such as JFT-300M to train effectively, and it demonstrates that the transformer architecture scales to vision tasks. The cost is substantial: the FP32 weights alone occupy $632 \times 10^6 \times 4 \approx 2.5$ GB, which makes ViT-Huge impractical for many memory- or latency-constrained applications.

Example: Configuration: $L=12$, $d=768$, $h=12$, $P=16$, ImageNet ($N=196$)

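Patch embedding (weights + bias):

$$ 768 \times 768 + 768 = 590{,}592 $$

Position embeddings:

$$ 197 \times 768 = 151{,}296 $$

Transformer encoder (12 layers; each layer has $4d^2 + 4d$ attention parameters, $8d^2 + 5d$ feed-forward parameters, and $4d$ LayerNorm parameters, i.e. $12d^2 + 13d = 7{,}087{,}872$):

$$ 12 \times 7{,}087{,}872 = 85{,}054{,}464 $$

Classification head (ImageNet, 1000 classes, weights + bias):

$$ 768 \times 1000 + 1000 = 769{,}000 $$

Total, including the [CLS] token ($768$) and the final LayerNorm ($1{,}536$): $86{,}567{,}656 \approx$ 86M parameters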
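The count can be reproduced programmatically. This pure-Python sketch counts every weight, bias, and LayerNorm parameter under the standard layout (fused QKV projection, MLP ratio 4, two LayerNorms per block plus a final one, biases on every linear layer); `vit_param_count` is a hypothetical helper.

```python
def vit_param_count(depth=12, d=768, patch=16, in_chans=3,
                    n_patches=196, num_classes=1000, mlp_ratio=4):
    """Count ViT parameters including biases, LayerNorms, [CLS] and final norm."""
    h = mlp_ratio * d
    ln = 2 * d                                    # LayerNorm weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)      # fused QKV projection + output projection
    mlp = (d * h + h) + (h * d + d)               # two linear layers, d -> h -> d
    per_layer = 2 * ln + attn + mlp               # = 12 d^2 + 13 d for mlp_ratio = 4

    patch_embed = (patch * patch * in_chans) * d + d
    pos_embed = (n_patches + 1) * d               # +1 for the [CLS] position
    cls_token = d
    final_norm = ln
    head = d * num_classes + num_classes
    total = patch_embed + pos_embed + cls_token + depth * per_layer + final_norm + head
    return per_layer, total

per_layer, total = vit_param_count()
print(f"per layer: {per_layer:,}  total: {total:,}")
```

Changing `depth` and `d` gives the corresponding count for larger variants under the same layout assumptions.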

Memory Requirements and Computational Analysis

The memory footprint of Vision Transformers scales with both the model size and the input image resolution. For ViT-Base with 86 million parameters, storing the model weights in FP32 requires $86 \times 10^6 \times 4 = 344$ MB. Training also stores one gradient per weight and, with Adam, two moment estimates (first and second moments) per weight, so weights, gradients, and optimizer state together occupy about $4 \times 344 \approx 1.4$ GB before any activations. Additionally, activations must be stored for backpropagation, and their memory consumption depends critically on the sequence length.

For a standard $224 \times 224$ image with patch size 16, the sequence length is 196 tokens (plus one CLS token for 197 total). The activation memory for a single layer includes the attention score tensor of shape $h \times n \times n$ with $h = 12$ heads and $n = 197$, requiring $12 \times 197^2 \times 4 \text{ bytes} \approx 1.86$ MB per image in FP32. Across 12 layers with batch size 32, the attention scores alone consume about 715 MB. The feed-forward hidden activations (width $4d = 3072$) add another $32 \times 197 \times 3072 \times 4 \text{ bytes} \times 12 \approx 0.93$ GB. Together with the other stored activations (queries, keys, values, LayerNorm and GELU inputs), training ViT-Base with batch size 32 on $224 \times 224$ images requires roughly 8-10 GB of GPU memory, which fits comfortably on GPUs such as the NVIDIA RTX 3090 or A100.

However, increasing the image resolution sharply increases memory requirements because attention is quadratic in the sequence length. For $384 \times 384$ images with the same patch size of 16, the number of patches increases to $(384/16)^2 = 576$ (577 tokens with CLS). The attention scores now require $12 \times 577^2 \times 4 \text{ bytes} \approx 16.0$ MB per layer per image, or about 6.1 GB across 12 layers with batch size 32. This is an 8.6× increase in attention memory compared to $224 \times 224$ resolution. The total memory requirement grows to approximately 18-22 GB, which calls for high-end GPUs or techniques such as gradient checkpointing.
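The terms discussed above are easy to tabulate. The sketch below is a rough FP32 budget under explicit assumptions (Adam, batch size 32, and only the attention-score and feed-forward hidden activations counted); it is not a profiler, real frameworks keep several more activation tensors per layer, and `vit_training_memory_gb` is a hypothetical helper.

```python
def vit_training_memory_gb(n_params=86e6, batch=32, tokens=197, heads=12,
                           depth=12, d=768, mlp_ratio=4, bytes_per_val=4):
    """Rough FP32 training-memory terms (in GB) for a ViT trained with Adam."""
    GB = 1e9
    weights = n_params * bytes_per_val
    grads = n_params * bytes_per_val
    adam = 2 * n_params * bytes_per_val                         # first + second moments
    attn_scores = batch * heads * tokens**2 * bytes_per_val * depth
    ffn_hidden = batch * tokens * mlp_ratio * d * bytes_per_val * depth
    terms = dict(weights=weights, grads=grads, adam=adam,
                 attn_scores=attn_scores, ffn_hidden=ffn_hidden)
    return {name: value / GB for name, value in terms.items()}

for res in (224, 384):
    n_tok = (res // 16) ** 2 + 1                                # patches + [CLS]
    est = vit_training_memory_gb(tokens=n_tok)
    print(res, {name: round(v, 2) for name, v in est.items()})
```

Comparing the two resolutions shows the attention term growing from about 0.7 GB to about 6.1 GB, while the weight, gradient, and optimizer terms stay fixed.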

Example: Compare memory and computation for different resolutions with ViT-Base ($L=12$, $d=768$, $h=12$, $P=16$):

Resolution $224 \times 224$:

$$ n = \frac{224^2}{16^2} = 196 \text{ patches} $$
Attention memory per layer: $12 \times 197^2 \times 4 = 1.86$ MB

FLOPs per attention layer: $4n^2d = 4 \times 197^2 \times 768 = 119$ MFLOPs

Resolution $384 \times 384$:

$$ n = \frac{384^2}{16^2} = 576 \text{ patches} $$
Attention memory per layer: $12 \times 577^2 \times 4 = 16.0$ MB (8.6× increase)

FLOPs per attention layer: $4 \times 577^2 \times 768 \approx 1.02$ GFLOPs (8.6× increase)

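Key insight: With a fixed patch size, the token count $n$ grows with the square of the image side length, and attention memory and computation grow with $n^2$, i.e. with the fourth power of the side length. Going from 224 to 384 pixels multiplies $n$ by about 2.9 and attention cost by about 8.6×; doubling the side length multiplies $n$ by 4 and attention cost by about 16×.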

The patch size provides another lever for controlling computational cost. Using larger patches reduces the sequence length, thereby decreasing both memory and computation. For a $224 \times 224$ image, patch size $P = 32$ produces only $(224/32)^2 = 49$ patches compared to 196 for $P = 16$. This 4× reduction in sequence length translates to a 16× reduction in attention memory and computation due to the quadratic scaling. However, larger patches also reduce the model's ability to capture fine-grained visual details, creating a fundamental trade-off between efficiency and representational capacity.

Example: For $224 \times 224$ images with ViT-Base:

Patch size $P = 16$:

$$ n = 196 \text{ (197 with CLS)}, \quad \text{Attention FLOPs} = 4 \times 197^2 \times 768 \approx 119 \text{ MFLOPs per layer} $$

Patch size $P = 32$:

$$ n = 49 \text{ (50 with CLS)}, \quad \text{Attention FLOPs} = 4 \times 50^2 \times 768 \approx 7.7 \text{ MFLOPs per layer} $$

The roughly 16× reduction in attention cost (about 15.5× once the CLS token is counted) makes $P = 32$ attractive for efficiency, but the coarser granularity typically reduces accuracy by 2-3\% on ImageNet. The optimal patch size depends on the application: real-time systems may prefer $P = 32$, while accuracy-critical applications use $P = 16$ or even $P = 14$ for ViT-Huge.
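Both levers, resolution and patch size, act through the token count $n$. This sketch tabulates the per-layer, per-image attention cost using the same conventions as the examples above ($4n^2d$ FLOPs for $QK^\top$ plus the weighted sum of values, an $h \times n \times n$ FP32 score tensor, one extra CLS token); `attention_cost` is a hypothetical helper.

```python
def attention_cost(img_size, patch_size, d=768, heads=12, cls_token=True, bytes_per_val=4):
    """Return (tokens, attention FLOPs, score-tensor bytes) for one layer and one image."""
    n = (img_size // patch_size) ** 2 + (1 if cls_token else 0)
    flops = 4 * n * n * d                     # 2 FLOPs per multiply-add, two n x n x d products
    score_bytes = heads * n * n * bytes_per_val
    return n, flops, score_bytes

for img, p in [(224, 16), (384, 16), (224, 32), (448, 16)]:
    n, flops, mem = attention_cost(img, p)
    print(f"{img}px, P={p}: n={n}, {flops / 1e6:.1f} MFLOPs, {mem / 1e6:.2f} MB")
```

The $448 \times 448$ row illustrates doubling the side length at $P = 16$: 785 tokens instead of 197, and roughly 16× the attention FLOPs.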

Training Vision Transformers

Pre-training Strategies

Supervised Pre-training (Original ViT):

Key finding: ViT matches or outperforms comparable CNNs only when pre-trained on large datasets (ImageNet-21k with about 14M images, or JFT-300M)!

Data Augmentation and Regularization

Essential for ViT (lacks CNN inductive biases):

Augmentation:

Regularization:

DeiT: Data-efficient Image Transformers

Improvements for training without massive datasets:

1. Knowledge Distillation

2. Strong Augmentation

Result: DeiT-Base achieves 81.8\% on ImageNet trained only on ImageNet (1.3M images)!

Masked Autoencoders (MAE)

Self-Supervised Pre-training for Vision

Definition: BERT-style masking for images:

Step 1: Randomly mask 75\% of patches

Step 2: Encoder processes only visible patches

Step 3: Decoder reconstructs all patches (including masked)

Loss: Pixel-level MSE on masked patches

$$ \mathcal{L} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \|\hat{\vx}_i - \vx_i\|^2 $$
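The masking step and the loss can be sketched in NumPy. The masking follows the per-sample random shuffling used in MAE (argsort of uniform noise), but the arrays are random stand-ins: no encoder or decoder is implemented, and `random_masking` and `masked_mse` are hypothetical helper names.

```python
import numpy as np

def random_masking(num_patches: int, mask_ratio: float, rng: np.random.Generator):
    """Randomly split patch indices into visible and masked sets.

    Returns kept indices, masked indices, and a boolean mask (True = masked)
    in the original patch order.
    """
    num_keep = int(num_patches * (1 - mask_ratio))
    noise = rng.random(num_patches)
    shuffle = np.argsort(noise)                  # random permutation of patch indices
    keep_idx = np.sort(shuffle[:num_keep])
    mask_idx = np.sort(shuffle[num_keep:])
    mask = np.zeros(num_patches, dtype=bool)
    mask[mask_idx] = True
    return keep_idx, mask_idx, mask

def masked_mse(pred: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """L = (1/|M|) * sum over masked i of ||x_hat_i - x_i||^2."""
    sq_err = ((pred - target) ** 2).sum(axis=-1)  # squared error per patch
    return float(sq_err[mask].mean())

rng = np.random.default_rng(0)
patches = rng.standard_normal((196, 768))          # flattened 16x16x3 patches
keep_idx, mask_idx, mask = random_masking(196, 0.75, rng)
visible = patches[keep_idx]                        # encoder input: (49, 768)
pred = rng.standard_normal((196, 768))             # stand-in for decoder output
print(visible.shape, mask.sum(), masked_mse(pred, patches, mask))
```

Because the loss indexes only masked patches, errors on the 49 visible patches do not contribute.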
Example: Image: $224 \times 224$, patches $16 \times 16$ ($N=196$)

Masking: Keep 25\% = 49 patches, mask 147 patches

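Encoder: processes only the 49 visible patches (sequence length 49 instead of 196)

Decoder: receives all 196 positions (the 49 encoded patches plus 147 mask tokens) and reconstructs every patch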

Benefits:

Hierarchical Vision Transformers

Motivation for Hierarchical Architectures

The original Vision Transformer processes images at a single scale, dividing the input into fixed-size patches and maintaining the same spatial resolution throughout all layers. While this uniform approach simplifies the architecture, it has significant limitations for computer vision tasks. Many vision problems benefit from multi-scale representations: low-level features like edges and textures are best captured at high resolution with small receptive fields, while high-level semantic concepts require large receptive fields that aggregate information across the entire image. CNNs naturally provide this hierarchical structure through pooling layers that progressively reduce spatial resolution while increasing channel capacity.

Additionally, the quadratic complexity of self-attention with respect to sequence length makes standard ViT impractical for high-resolution images or dense prediction tasks like object detection and semantic segmentation. For a $512 \times 512$ image with patch size 16, the sequence length reaches 1,024 tokens, requiring attention matrices of size $1024 \times 1024$ per head. With 12 heads across 12 layers, this consumes over 600 MB (FP32, one image) just for attention weights in a single forward pass. The computational cost of $O(n^2d)$ attention becomes prohibitive, limiting ViT's applicability to tasks requiring fine-grained spatial reasoning.

Hierarchical Vision Transformers address these limitations by introducing multi-scale processing and localized attention mechanisms. These architectures progressively reduce spatial resolution while increasing feature dimensions, mimicking the pyramid structure of CNNs while retaining the flexibility of transformer layers. By restricting attention to local windows rather than the full image, they achieve linear or near-linear complexity in the number of pixels, enabling efficient processing of high-resolution inputs.

Swin Transformer

The Swin Transformer (Shifted Window Transformer) introduces a hierarchical architecture with shifted window-based attention that achieves linear complexity while maintaining the ability to model long-range dependencies. The architecture consists of four stages, each operating at a different spatial resolution. The first stage processes the image at high resolution with small patches (typically $4 \times 4$), producing a large number of tokens. Subsequent stages merge adjacent patches to reduce the spatial dimensions by 2× while doubling the feature dimension, creating a pyramid structure similar to ResNet.

Definition: For input image $\mI \in \R^{H \times W \times 3}$:

Stage 1: Patch size $4 \times 4$, dimension $C$

$$ \text{Resolution: } \frac{H}{4} \times \frac{W}{4}, \quad \text{Channels: } C $$

Stage 2: Patch merging, dimension $2C$

$$ \text{Resolution: } \frac{H}{8} \times \frac{W}{8}, \quad \text{Channels: } 2C $$

Stage 3: Patch merging, dimension $4C$

$$ \text{Resolution: } \frac{H}{16} \times \frac{W}{16}, \quad \text{Channels: } 4C $$

Stage 4: Patch merging, dimension $8C$

$$ \text{Resolution: } \frac{H}{32} \times \frac{W}{32}, \quad \text{Channels: } 8C $$

For Swin-Base: $C = 128$, producing feature maps at resolutions $\frac{H}{4}, \frac{H}{8}, \frac{H}{16}, \frac{H}{32}$ with dimensions 128, 256, 512, 1024 respectively.
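The stage-to-stage transition can be sketched as follows. This NumPy version follows the patch-merging pattern of the Swin reference code (concatenate each $2 \times 2$ neighborhood into a $4C$ vector, then project linearly to $2C$), but omits the LayerNorm that Swin applies before the projection; the random `W_reduce` matrices stand in for learned weights.

```python
import numpy as np

def patch_merging(x: np.ndarray, W_reduce: np.ndarray) -> np.ndarray:
    """Swin-style patch merging: (H, W, C) -> (H/2, W/2, 2C)."""
    x0 = x[0::2, 0::2]   # top-left token of each 2x2 block
    x1 = x[1::2, 0::2]   # bottom-left
    x2 = x[0::2, 1::2]   # top-right
    x3 = x[1::2, 1::2]   # bottom-right
    merged = np.concatenate([x0, x1, x2, x3], axis=-1)   # (H/2, W/2, 4C)
    return merged @ W_reduce                              # linear map 4C -> 2C

rng = np.random.default_rng(0)
C = 128                                  # Swin-Base stage-1 width
x = rng.standard_normal((56, 56, C))     # stage 1 for a 224x224 input (H/4 x W/4)
for stage in (2, 3, 4):
    W_reduce = rng.standard_normal((4 * x.shape[-1], 2 * x.shape[-1])) * 0.02
    x = patch_merging(x, W_reduce)
    print(f"stage {stage}: {x.shape}")
```

The printed shapes trace the pyramid from the definition: $28 \times 28 \times 256$, $14 \times 14 \times 512$, and $7 \times 7 \times 1024$.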

The key innovation of Swin Transformer is shifted window attention, which restricts self-attention to non-overlapping local windows while enabling cross-window connections through window shifting. In even-numbered layers, the token grid is partitioned into regular $M \times M$ windows (typically $M = 7$), and attention is computed independently within each window. In odd-numbered layers, the window grid is shifted by $\lfloor M/2 \rfloor$ tokens in both the horizontal and vertical directions, so each new window straddles the boundaries of windows from the previous layer. This shifting mechanism allows information to flow between windows while maintaining the computational efficiency of local attention.

The computational complexity of window-based attention is $O(M^2 \cdot HW \cdot d)$, where $M$ is the window size and $HW$ is the number of tokens at the current stage, compared with $O((HW)^2 \cdot d)$ for global attention. For $M = 7$ and a $224 \times 224$ image at stage 1 resolution ($56 \times 56 = 3{,}136$ tokens), each window contains $7 \times 7 = 49$ tokens, so each window's score matrix has $49^2 = 2{,}401$ entries per head. The 64 windows together hold $64 \times 2{,}401 = 153{,}664$ entries, compared to $3{,}136^2 \approx 9.8$ million for global attention over all tokens. This 64× reduction (equal to $HW/M^2$, so it grows with resolution) enables Swin Transformer to process high-resolution images efficiently while still capturing long-range dependencies through the hierarchical structure and window shifting.

Example: Compare attention complexity for $224 \times 224$ image at stage 1 ($56 \times 56$ tokens):

Global attention (standard ViT):

$$ \text{Complexity: } O(n^2d) = O(3136^2 \times 128) = 1.26 \text{ GFLOPs per layer} $$

Window attention (Swin, $M=7$):

$$ \text{Windows: } \frac{56}{7} \times \frac{56}{7} = 64 \text{ windows} $$
$$ \text{Complexity: } O(M^2 \cdot HW \cdot d) = O(49 \times 3136 \times 128) = 19.7 \text{ MFLOPs per layer} $$

The window-based approach reduces attention cost by 64×, making high-resolution processing practical. The shifted window mechanism ensures that information still propagates globally through the network depth.
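Window partitioning and the shift can be sketched in NumPy. This follows the cyclic-shift formulation (rolling the grid by $\lfloor M/2 \rfloor$ tokens before partitioning), but it omits the attention mask that Swin uses to stop wrapped-around tokens from attending to each other; the token grid holds token ids so the effect of the shift is visible.

```python
import numpy as np

def window_partition(x: np.ndarray, M: int) -> np.ndarray:
    """Split an (H, W, C) token grid into non-overlapping M x M windows.

    Returns (num_windows, M*M, C); attention is then computed within each window.
    """
    H, W, C = x.shape
    x = x.reshape(H // M, M, W // M, M, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, M * M, C)

def shifted_windows(x: np.ndarray, M: int) -> np.ndarray:
    """Cyclically shift the grid by floor(M/2) tokens, then partition."""
    s = M // 2
    return window_partition(np.roll(x, shift=(-s, -s), axis=(0, 1)), M)

H = W = 56; C = 128; M = 7
# Fill every channel with the token id (row * W + col) to track positions
x = np.arange(H * W, dtype=np.float64).reshape(H, W, 1).repeat(C, axis=2)
regular = window_partition(x, M)      # (64, 49, 128)
shifted = shifted_windows(x, M)       # (64, 49, 128)

global_entries = (H * W) ** 2                      # one (n x n) score matrix
window_entries = regular.shape[0] * (M * M) ** 2   # 64 matrices of 49 x 49
print(regular.shape, global_entries // window_entries)
```

The first shifted window starts at token $(3, 3)$ instead of $(0, 0)$, so it spans the corner where four regular windows meet.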

Swin Transformer achieves state-of-the-art performance across multiple vision tasks while maintaining computational efficiency. On ImageNet classification, Swin-Base reaches 83.5\% top-1 accuracy with 88 million parameters and 15.4 GFLOPs—comparable to ViT-Base in parameters but with better accuracy due to the hierarchical structure. For object detection on COCO, Swin-Base achieves 51.9 box AP, surpassing previous transformer-based detectors by significant margins. The multi-scale feature maps produced by the hierarchical architecture are particularly well-suited for dense prediction tasks, making Swin Transformer a versatile backbone for various computer vision applications.

Pyramid Vision Transformer (PVT)

Pyramid Vision Transformer takes a different approach to hierarchical vision transformers by introducing spatial-reduction attention that progressively decreases the key and value sequence lengths. Unlike Swin's window-based attention, PVT maintains global attention but reduces computational cost by downsampling the keys and values before computing attention. This design preserves the ability to attend to the entire image while dividing the cost of the quadratic attention term by a constant factor.

In PVT, each stage reduces the spatial resolution through patch merging, similar to Swin Transformer. However, within each stage, the attention mechanism uses a spatial reduction operation on keys and values. For a reduction ratio $R$, the keys and values are reshaped and downsampled by $R \times R$, reducing their sequence length by a factor of $R^2$. The queries maintain the original resolution, allowing each token to attend to a downsampled representation of the entire image. This approach reduces attention complexity from $O(n^2d)$ to $O(n^2d/R^2)$, providing a tunable trade-off between computational cost and attention granularity.
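A single-head version of spatial-reduction attention fits in a few lines of NumPy. In this sketch the reduction flattens each $R \times R$ block of tokens and projects it with `W_sr` (equivalent to a convolution with kernel size and stride $R$); PVT's multiple heads and the LayerNorm after the reduction are omitted, and $R = 8$ on a $56 \times 56$ grid is chosen only for illustration.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)    # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def spatial_reduction_attention(x, H, W, R, Wq, Wk, Wv, W_sr):
    """Single-head SRA sketch. x: (H*W, d) tokens in raster order."""
    n, d = x.shape
    # Group tokens into R x R blocks and project each block to one token
    grid = x.reshape(H // R, R, W // R, R, d).transpose(0, 2, 1, 3, 4)
    reduced = grid.reshape(-1, R * R * d) @ W_sr          # (n / R^2, d)
    q = x @ Wq                                            # queries keep full resolution
    k, v = reduced @ Wk, reduced @ Wv                     # (n / R^2, d)
    attn = softmax(q @ k.T / np.sqrt(d))                  # (n, n / R^2) instead of (n, n)
    return attn @ v, attn.shape

rng = np.random.default_rng(0)
H = W = 56; d = 64; R = 8
x = rng.standard_normal((H * W, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) * d**-0.5 for _ in range(3))
W_sr = rng.standard_normal((R * R * d, d)) * (R * R * d) ** -0.5
out, attn_shape = spatial_reduction_attention(x, H, W, R, Wq, Wk, Wv, W_sr)
print(out.shape, attn_shape)
```

Every one of the 3,136 queries still sees the whole image, but through 49 reduced keys, so the score matrix is $R^2 = 64$ times smaller.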

The hierarchical structure of PVT produces feature maps at multiple scales, making it suitable as a backbone for dense prediction tasks. PVT-Medium with 44 million parameters achieves 82.0\% ImageNet accuracy while requiring only 6.7 GFLOPs—significantly more efficient than ViT-Base. For object detection, PVT-based detectors achieve competitive performance with CNN-based methods while offering the benefits of transformer architectures, including better transfer learning and attention-based interpretability.

Hybrid Architectures: CoAtNet

Hybrid architectures combine convolutional layers and transformer layers to leverage the complementary strengths of both approaches. Convolutional layers provide efficient local feature extraction with built-in translation equivariance, while transformer layers enable global reasoning and flexible attention patterns. CoAtNet (Convolution and Attention Network) systematically explores this design space, identifying an optimal combination that achieves state-of-the-art performance with improved efficiency.

The CoAtNet architecture consists of five stages (S0 to S4) with progressively decreasing spatial resolution. After a convolutional stem (S0), the next two stages (S1 and S2) use convolutional blocks based on the MBConv (Mobile Inverted Bottleneck Convolution) design used in MobileNetV2 and EfficientNet, which efficiently extracts local features at high resolution. These convolutional stages capture low-level visual patterns like edges, textures, and simple shapes with strong inductive bias and minimal computational cost. The spatial resolution is halved at each stage.

The last two stages (S3 and S4) employ transformer blocks with relative attention, enabling global reasoning over the extracted features. By S3 the spatial resolution has been reduced by 16×, making global attention computationally feasible. The transformer stages learn high-level semantic representations and long-range dependencies that benefit from the flexibility of self-attention. A global pooling layer and a fully connected classifier produce the final prediction.

Example: CoAtNet-3 configuration for $224 \times 224$ input:

Stage 0 (Stem): Convolution, $112 \times 112$ resolution, 192 channels

Stage 1: MBConv blocks, $56 \times 56$ resolution, 192 channels

Stage 2: MBConv blocks, $28 \times 28$ resolution, 384 channels

Stage 3: Transformer blocks, $14 \times 14$ resolution, 768 channels

Stage 4: Transformer blocks, $7 \times 7$ resolution, 1536 channels

Head: Global pooling and a fully connected classifier

Total parameters: 168M, FLOPs: 34.7G

With ImageNet-21K pre-training and fine-tuning at $512 \times 512$, this hybrid design reaches 87.9\% ImageNet top-1 accuracy, outperforming pure CNN and pure transformer architectures of similar size.

The success of CoAtNet demonstrates that the choice between convolution and attention need not be binary. By using convolutions where they excel (local feature extraction at high resolution) and transformers where they excel (global reasoning at lower resolution), hybrid architectures achieve better accuracy-efficiency trade-offs than either approach alone. CoAtNet-7, the largest variant with 2.4 billion parameters, achieved 90.88\% ImageNet accuracy and state-of-the-art results on multiple vision benchmarks at the time of its release, validating the hybrid approach at scale.

ViT vs CNN Comparison

Parameter Efficiency

Vision Transformers and Convolutional Neural Networks differ fundamentally in their parameter efficiency and data requirements. ResNet-50, a standard CNN baseline, contains approximately 25 million parameters distributed across convolutional layers with small kernels (mostly $1 \times 1$ and $3 \times 3$, plus a single $7 \times 7$ stem convolution). In contrast, ViT-Base requires 86 million parameters, more than 3× the size of ResNet-50, to achieve comparable performance. This parameter gap reflects the different inductive biases: CNNs build in locality and translation equivariance through their convolutional structure, while transformers must learn these properties from data through their flexible attention mechanism.

The parameter distribution also differs significantly between the architectures. In ResNet-50, the majority of parameters reside in the later convolutional layers and the final fully-connected layer. For ViT-Base, the parameters are more evenly distributed across the 12 transformer layers, with each layer containing approximately 7 million parameters in the attention and feed-forward components. The patch embedding layer contributes only 590K parameters, while position embeddings add another 151K—both negligible compared to the transformer layers themselves.

Despite having more parameters, ViT-Base is not necessarily slower than ResNet-50 for inference. The transformer's matrix multiplications are highly optimized on modern GPUs, and the lack of spatial convolutions can actually improve throughput. On an NVIDIA A100 GPU, ViT-Base processes approximately 1,200 images per second at $224 \times 224$ resolution with batch size 128, compared to 1,400 images per second for ResNet-50. The 15\% throughput difference is much smaller than the 3× parameter gap would suggest, demonstrating the efficiency of transformer operations on modern hardware.

Computational Complexity Analysis

The computational complexity of Vision Transformers scales differently than that of CNNs, leading to different performance characteristics across image resolutions. For a convolutional layer, the cost is approximately $O(C_{\text{in}} \times C_{\text{out}} \times k^2 \times H \times W)$, where $C_{\text{in}}$ and $C_{\text{out}}$ are the input and output channel counts, $k$ is the kernel size, and $H \times W$ is the spatial resolution. Because this is linear in the number of pixels, doubling the image side length increases computation by 4×. For ResNet-50 processing a $224 \times 224$ image, the total computation is approximately 4.1 GFLOPs (counting one multiply-accumulate as one operation, the usual convention for this figure).

Vision Transformers have per-layer complexity $O(n^2d + nd^2)$, where $n = HW/P^2$ is the number of patches and $d$ is the hidden dimension. The $n^2d$ term comes from the attention products ($QK^\top$ and the weighted sum of values), while the $nd^2$ term comes from the query, key, value, and output projections and the feed-forward network. Counting multiply-accumulates, the same convention as the 4.1 GFLOPs quoted for ResNet-50, ViT-Base at $224 \times 224$ with patch size 16 ($n = 197$ including CLS, $d = 768$) spends $12 \times 2 \times 197^2 \times 768 \approx 0.72$ GFLOPs on the attention products and $12 \times 12 \times 197 \times 768^2 \approx 16.7$ GFLOPs on the projections ($4nd^2$ per layer) and the feed-forward network ($8nd^2$ per layer, hidden width $4d$). With the patch embedding ($196 \times 768^2 \approx 0.12$ GFLOPs), the total is about 17.6 GFLOPs, roughly 4× ResNet-50. At this resolution the quadratic attention term is only about 4\% of the total.

The scaling behavior also differs. At $384 \times 384$ with the same patch size, the number of patches grows to $576$ (577 tokens), a factor of $(384/224)^2 \approx 2.94$. The terms linear in $n$ grow by the same factor to $12 \times 12 \times 577 \times 768^2 \approx 49.0$ GFLOPs, while the attention products grow quadratically to $12 \times 2 \times 577^2 \times 768 \approx 6.1$ GFLOPs (an 8.6× increase). Adding the patch embedding ($\approx 0.34$ GFLOPs), ViT-Base needs about 55.5 GFLOPs, compared to 12.0 GFLOPs for ResNet-50 at the same resolution. The attention share rises from about 4\% to about 11\% and keeps growing with resolution, which is why efficient attention mechanisms become critical for high-resolution vision tasks.

Example: Compare FLOPs (multiply-accumulates) for ResNet-50 and ViT-Base/16 across resolutions:

| Resolution | ResNet-50 | ViT-Base/16 |
|---|---|---|
| $224 \times 224$ | 4.1 GFLOPs | 17.6 GFLOPs |
| $384 \times 384$ | 12.0 GFLOPs | 55.5 GFLOPs |
| $512 \times 512$ | 21.4 GFLOPs | 107.0 GFLOPs |

At standard ImageNet resolution, ViT-Base already needs about 4× the compute of ResNet-50, and the gap widens with resolution because the quadratic attention term grows from about 4\% of ViT's compute at $224 \times 224$ to about 18\% at $512 \times 512$. This scaling motivates hierarchical architectures like Swin Transformer that reduce attention to local windows.
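The per-resolution figures can be reproduced with a short calculation. This sketch counts multiply-accumulates (the convention behind the 4.1 GFLOPs figure for ResNet-50), ignores LayerNorm, softmax, GELU, and the classification head, and approximates ResNet-50 by scaling its $224 \times 224$ cost with the pixel count; both helpers are hypothetical.

```python
def vit_gmacs(img_size, patch=16, d=768, depth=12, mlp_ratio=4, in_chans=3):
    """Approximate ViT GMACs per image and the share spent on attention products."""
    n_patches = (img_size // patch) ** 2
    n = n_patches + 1                                    # + [CLS]
    patch_embed = n_patches * (patch * patch * in_chans) * d
    linear = (4 + 2 * mlp_ratio) * n * d * d             # QKV/output projections + MLP
    attention = 2 * n * n * d                            # QK^T and attention-weighted V
    total = patch_embed + depth * (linear + attention)
    return total / 1e9, depth * attention / total

def resnet50_gmacs(img_size, base=4.1):
    """Convolution cost scales with H*W: scale the 224x224 figure by the pixel ratio."""
    return base * (img_size / 224) ** 2

for res in (224, 384, 512):
    g, attn_share = vit_gmacs(res)
    print(f"{res}: ResNet-50 {resnet50_gmacs(res):5.1f}  ViT-B/16 {g:6.1f}  "
          f"(attention share {attn_share:.0%})")
```

Under these assumptions the quadratic term is a small fraction of ViT's compute at $224 \times 224$ and a growing one at higher resolutions, while the linear terms dominate throughout.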

Data Requirements and Inductive Bias

The most striking difference between Vision Transformers and CNNs lies in their data requirements, which stem from their different inductive biases. CNNs encode strong priors about images: locality (nearby pixels are related), translation equivariance (a cat is a cat regardless of position), and hierarchical structure (edges → textures → objects). These built-in assumptions allow CNNs to learn effectively from moderate-sized datasets like ImageNet with 1.3 million images. ResNet-50 trained only on ImageNet achieves 76.5\% top-1 accuracy, demonstrating that the convolutional structure provides useful inductive bias for natural images.

Vision Transformers, by contrast, have minimal inductive bias. The self-attention mechanism can attend to any patch regardless of spatial distance, and the model must learn locality and translation properties from data. When trained only on ImageNet, ViT-Base achieves only 72.3\% accuracy—4.2 percentage points below ResNet-50 despite having 3× more parameters. This performance gap reveals that the flexibility of attention becomes a liability when training data is limited: the model has too much capacity and insufficient constraints to learn good representations.

The situation reverses dramatically with large-scale pre-training. When ViT-Base is pre-trained on JFT-300M (300 million images with 18,000 classes) and then fine-tuned on ImageNet, it achieves 84.2\% accuracy, surpassing ResNet-50's 76.5\% by a substantial margin. The massive pre-training dataset provides enough examples for the transformer to learn the visual priors that CNNs encode by design. Moreover, the learned representations transfer better to downstream tasks: ViT-Base pre-trained on JFT-300M achieves higher accuracy than ResNet-50 on 19 out of 20 transfer learning benchmarks, with improvements ranging from 2-7 percentage points.

This data-efficiency trade-off has important practical implications. For applications with limited training data or computational budgets, CNNs remain the better choice. For large-scale systems with access to massive datasets and compute, Vision Transformers offer superior performance and transfer learning capabilities. The development of data-efficient training methods like DeiT (Data-efficient Image Transformers) has partially bridged this gap, enabling ViT-Base to achieve 81.8\% on ImageNet without external data through aggressive augmentation and distillation techniques.

When to Use Each Architecture

| Aspect | CNN (ResNet) | ViT |
|---|---|---|
| Inductive bias | Strong (locality, translation) | Weak |
| Data requirement | Moderate (ImageNet) | Large (JFT-300M) |
| Parameters | 25M (ResNet-50) | 86M (ViT-Base) |
| Computation | $O(HW)$ at fixed channel width | $O(n^2 d + n d^2)$ with $n = HW/P^2$ |
| Memory | 5-7 GB training | 8-10 GB training |
| Interpretability | Filter visualization | Attention maps |
| Transfer | Good | Excellent (large-scale) |
| Best use | Small/medium data | Large-scale pre-training |

The choice between CNNs and Vision Transformers depends on the specific application constraints. CNNs are preferable when training data is limited (fewer than 10 million images), when computational efficiency is critical (mobile or edge deployment), or when strong spatial priors are known to be appropriate for the task. ResNet and EfficientNet variants remain the standard choice for many production computer vision systems due to their reliability and efficiency.

Vision Transformers excel when massive pre-training data is available, when transfer learning to diverse downstream tasks is important, or when state-of-the-art performance justifies the additional computational cost. The superior scaling properties of transformers—both in terms of model size and dataset size—make them the architecture of choice for foundation models in vision. Hybrid architectures like CoAtNet attempt to combine the strengths of both approaches, using convolutional layers for early feature extraction and transformer layers for high-level reasoning.

Exercises

Exercise 1: Implement patch embedding for image $224 \times 224 \times 3$ with patch size 16:
  1. Reshape image to patches
  2. Apply linear projection
  3. Add position embeddings
  4. Verify output shape: $(196, 768)$, or $(197, 768)$ after prepending a [CLS] token
Exercise 2: Compare ViT-Base and ResNet-50:
  1. Parameter count
  2. FLOPs for $224 \times 224$ image
  3. Memory footprint
  4. Which is more efficient?
Exercise 3: Implement MAE masking:
  1. Randomly mask 75\% of 196 patches
  2. Keep 49 visible patches
  3. Add mask tokens for decoder
  4. Compute reconstruction loss
Exercise 4: Train ViT-Tiny on CIFAR-10:
  1. Use patch size 4 (for $32 \times 32$ images)
  2. 6 layers, $d=192$, 3 heads
  3. Apply RandAugment
  4. Compare to small ResNet

Solutions

Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.

Solution: Exercise 1: Patch Embedding Implementation
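import torch
import torch.nn as nn
import numpy as np

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # Linear projection of flattened patches: a Conv2d with
        # kernel_size = stride = patch_size applies the same linear map
        # to every non-overlapping patch
        self.proj = nn.Conv2d(
            in_channels, 
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        
        # Learnable position embeddings (CLS slot + one per patch),
        # with a small truncated-normal initialization (std 0.02)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # Learnable CLS token, prepended to every sequence
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        # x: (B, C, H, W)
        B, _, H, W = x.shape
        assert H == self.img_size and W == self.img_size, (
            f"expected {self.img_size}x{self.img_size} input, got {H}x{W}"
        )
        
        # Part 1: Reshape image to patches and project
        # Conv2d with stride=patch_size extracts non-overlapping patches
        x = self.proj(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)  # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        
        # Part 2: Add CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, n_patches+1, embed_dim)
        
        # Part 3: Add position embeddings (broadcast over the batch)
        x = x + self.pos_embed
        
        return x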

# Example usage
img_size = 224
patch_size = 16
in_channels = 3
embed_dim = 768

# Create model
patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

# Create sample image
batch_size = 4
image = torch.randn(batch_size, in_channels, img_size, img_size)

# Forward pass
output = patch_embed(image)

print(f"Input image shape: {image.shape}")
print(f"Number of patches: {(img_size // patch_size) ** 2}")
print(f"Output shape: {output.shape}")
print(f"Expected: (batch_size, n_patches+1, embed_dim)")
print(f"Actual: ({batch_size}, {(img_size//patch_size)**2 + 1}, {embed_dim})")

Detailed Breakdown:

Part (a): Reshape Image to Patches

Original image: $224 \times 224 \times 3$

Patch size: $16 \times 16$

Number of patches: $\frac{224}{16} \times \frac{224}{16} = 14 \times 14 = 196$ patches

Each patch: $16 \times 16 \times 3 = 768$ values

Reshaped: $(196, 768)$

Part (b): Linear Projection

Using Conv2d with kernel\_size=16, stride=16:

This is equivalent to:

  1. Extract $14 \times 14 = 196$ non-overlapping patches
  2. Flatten each patch: $16 \times 16 \times 3 = 768$ values
  3. Apply linear projection: $\mathbb{R}^{768} \to \mathbb{R}^{768}$

Part (c): Add Position Embeddings

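Position embeddings: learnable parameters of shape $(1, 197, 768)$, one vector per token (index 0 for the CLS token, indices 1 to 196 for the patches):

$$\vx_0 = \vx_{\text{cls}} + \vpos_0, \qquad \vx_i = \text{PatchEmbed}(\text{patch}_i) + \vpos_i \quad (i = 1, \ldots, 196)$$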

Part (d): Output Shape Verification


Input image shape: torch.Size([4, 3, 224, 224])
Number of patches: 196
Output shape: torch.Size([4, 197, 768])
Expected: (batch_size, n_patches+1, embed_dim)
Actual: (4, 197, 768)

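Output shape: $(B, 197, 768)$, where $B$ is the batch size, $197 = 196 + 1$ tokens (196 patches plus the CLS token), and $768$ is the embedding dimension $d$.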

Key Design Choices:

  1. Conv2d for patching: More efficient than manual reshaping
  2. CLS token: Special token for classification (like BERT)
  3. Learnable position embeddings: Unlike sinusoidal in original Transformer
  4. No overlap: Patches don't overlap (stride = patch\_size)

Alternative Implementation (Manual):

def manual_patch_embedding(image, patch_size=16):
    """Manual patch extraction without Conv2d"""
    B, C, H, W = image.shape
    P = patch_size
    
    # Reshape to patches
    patches = image.unfold(2, P, P).unfold(3, P, P)  # (B, C, H/P, W/P, P, P)
    patches = patches.contiguous().view(B, C, -1, P, P)  # (B, C, n_patches, P, P)
    patches = patches.permute(0, 2, 1, 3, 4)  # (B, n_patches, C, P, P)
    patches = patches.reshape(B, -1, C * P * P)  # (B, n_patches, C*P*P)
    
    return patches

# Check the shape of the extracted (not yet projected) patches
manual_patches = manual_patch_embedding(image, patch_size)
print(f"Manual patches shape: {manual_patches.shape}")  # (4, 196, 768)

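Note that manual_patch_embedding only extracts the raw patches, shape $(4, 196, 768)$, before any projection. Following it with an nn.Linear(768, 768) whose weight is the Conv2d kernel reshaped to (embed\_dim, C·P·P) gives exactly the Conv2d embeddings. Conv2d fuses extraction and projection into one optimized kernel, which is why it is the common choice in practice.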
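The claim that a convolution with kernel size = stride = $P$ is just "flatten each patch, then apply a linear layer" can be checked numerically. Below is a small NumPy sketch; the helpers patchify and strided_conv are our own illustrative names, not library functions. It computes the strided convolution by brute force and compares it with flatten-then-multiply on a toy $8 \times 8 \times 3$ image with $P = 4$ and $d = 5$.

```python
import numpy as np

def patchify(img, p):
    """Split a (C, H, W) image into flattened (C*p*p) vectors, row-major over the patch grid."""
    c, h, w = img.shape
    grid = img.reshape(c, h // p, p, w // p, p)   # (C, H/p, p, W/p, p)
    grid = grid.transpose(1, 3, 0, 2, 4)          # (H/p, W/p, C, p, p)
    return grid.reshape((h // p) * (w // p), c * p * p)

def strided_conv(img, kernel, bias, p):
    """Naive convolution with kernel_size = stride = p; kernel has shape (D, C, p, p)."""
    c, h, w = img.shape
    d = kernel.shape[0]
    out = np.zeros((d, h // p, w // p))
    for i in range(h // p):
        for j in range(w // p):
            window = img[:, i * p:(i + 1) * p, j * p:(j + 1) * p]        # (C, p, p)
            out[:, i, j] = np.tensordot(kernel, window, axes=3) + bias   # (D,)
    return out

rng = np.random.default_rng(0)
C, H, W, P, D = 3, 8, 8, 4, 5
img = rng.standard_normal((C, H, W))
kernel = rng.standard_normal((D, C, P, P))
bias = rng.standard_normal(D)

# Conv path: (D, H/P, W/P) -> (N, D), like flatten(2).transpose(1, 2)
conv_tokens = strided_conv(img, kernel, bias, P).reshape(D, -1).T
# Linear path: flattened patches times the reshaped kernel
linear_tokens = patchify(img, P) @ kernel.reshape(D, -1).T + bias
print(np.allclose(conv_tokens, linear_tokens))
```

The patch vectors are flattened in (C, P, P) order, the same order manual_patch_embedding produces, so reshaping the kernel with kernel.reshape(D, -1) lines up the weights with the pixels.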

Solution: Exercise 2: ViT-Base vs ResNet-50 Comparison

Part (a): Parameter Count

ViT-Base:

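Parameters (weights and biases):

  1. Patch embedding: $768 \times 768 + 768 = 590{,}592$
  2. CLS token and position embeddings: $768 + 197 \times 768 = 152{,}064$
  3. Per transformer layer: two LayerNorms ($2 \times 1{,}536$), QKV projection ($768 \times 2304 + 2304 = 1{,}771{,}776$), attention output projection ($590{,}592$), MLP ($768 \times 3072 + 3072 + 3072 \times 768 + 768 = 4{,}722{,}432$), total $7{,}087{,}872$; for 12 layers, $85{,}054{,}464$
  4. Final LayerNorm and 1000-class head: $1{,}536 + 769{,}000 = 770{,}536$

Total ViT-Base: $86{,}567{,}656 \approx 86.6$M parameters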

ResNet-50:

Total ResNet-50: $25{,}557{,}032 \approx 25.6$M parameters

Ratio: ViT-Base has $3.4\times$ more parameters than ResNet-50
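The ViT-Base breakdown can be reproduced in a few lines. The plain-Python sketch below (vit_params is our own name) assumes the standard ViT layout: a bias on every linear layer, two LayerNorms per block, a final LayerNorm, and a linear classifier with 1000 classes by default.

```python
def vit_params(img=224, patch=16, chans=3, d=768, depth=12, mlp=3072, classes=1000):
    """Count weights and biases of a standard ViT (assumed layout, see lead-in)."""
    n = (img // patch) ** 2                    # number of patches
    embed = (patch * patch * chans) * d + d    # patch projection (weights + bias)
    tokens = d + (n + 1) * d                   # CLS token + position embeddings
    ln = 2 * d                                 # one LayerNorm (gain + bias)
    attn = (d * 3 * d + 3 * d) + (d * d + d)   # QKV projection + output projection
    ffn = (d * mlp + mlp) + (mlp * d + d)      # the two MLP linear layers
    layer = 2 * ln + attn + ffn
    head = ln + d * classes + classes          # final LayerNorm + classifier
    return embed + tokens + depth * layer + head

print(f"{vit_params():,}")
```

Changing the arguments covers other configurations, e.g. vit_params(img=32, patch=4, d=192, depth=6, mlp=768, classes=10) for the CIFAR-10 ViT-Tiny of Exercise 4.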

Part (b): FLOPs for $224 \times 224$ Image

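ViT-Base FLOPs (one multiply-accumulate counted as one FLOP, the convention behind the 4.1 GFLOPs usually quoted for ResNet-50; softmax, LayerNorm, GELU, and biases are ignored):

1. Patch Embedding: $196 \times 768 \times 768 = 115{,}605{,}504$ FLOPs

2. Per Transformer Layer (12 layers, $N = 197$ tokens, $d = 768$):

Multi-Head Attention:

  1. QKV projection: $N \cdot d \cdot 3d = 348{,}585{,}984$
  2. Attention scores $\mQ\mK^\top$ (all heads): $N^2 d = 29{,}805{,}312$
  3. Weighted sum of values: $N^2 d = 29{,}805{,}312$
  4. Output projection: $N d^2 = 116{,}195{,}328$

MLP: $2 \cdot N \cdot d \cdot 4d = 929{,}562{,}624$

Total per layer: $1{,}453{,}954{,}560$ FLOPs

12 layers: $12 \times 1{,}453{,}954{,}560 = 17{,}447{,}454{,}720$ FLOPs

3. Classification Head: $768 \times 1000 = 768{,}000$ FLOPs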

Total ViT-Base: $\approx 17.5$ GFLOPs

ResNet-50 FLOPs:

Total ResNet-50: $\approx 4.1$ GFLOPs

Ratio: ViT-Base requires $4.3\times$ more FLOPs than ResNet-50
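The per-layer FLOP count follows directly from the matrix shapes. The plain-Python sketch below (vit_macs is our own name) counts multiply-accumulates for ViT-Base/16 at $224 \times 224$ and, like the breakdown, ignores softmax, LayerNorm, GELU, and bias additions.

```python
def vit_macs(n_tokens=197, d=768, depth=12, mlp=3072,
             n_patches=196, patch_dim=768, classes=1000):
    """Multiply-accumulates of a ViT forward pass (matrix products only)."""
    qkv = n_tokens * d * 3 * d               # Q, K, V projections
    scores = n_tokens * n_tokens * d         # Q K^T, summed over all heads
    weighted = n_tokens * n_tokens * d       # attention weights @ V
    proj = n_tokens * d * d                  # output projection
    ffn = 2 * n_tokens * d * mlp             # two MLP linear layers
    per_layer = qkv + scores + weighted + proj + ffn
    embed = n_patches * patch_dim * d        # patch projection
    head = d * classes                       # classifier on the CLS token
    return per_layer, embed + depth * per_layer + head

per_layer, total = vit_macs()
print(f"per layer: {per_layer:,}")
print(f"total: {total / 1e9:.2f} G multiply-accumulates")
```

The total comes to about $17.56 \times 10^9$, in line with the $\approx 17.5$ GFLOPs above; the MLP alone accounts for roughly two thirds of each layer.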

Part (c): Memory Footprint

ViT-Base Memory (Inference):

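1. Activations per layer (one image, float32): token embeddings ($197 \times 768 = 151{,}296$ values), attention maps ($12 \times 197 \times 197 = 465{,}708$ values, one map per head), and MLP hidden activations ($197 \times 3072 = 605{,}184$ values) add up to $1{,}222{,}188$ values.

Peak activation memory per layer: $\approx 1.2$M values $\times$ 4 bytes = 4.8 MB

Total for 12 layers: $\approx 58$ MB (an upper bound that assumes every layer's activations are kept, as during training; an inference runtime that frees each layer's buffers needs only about one layer's worth)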

Parameters: $86$M $\times$ 4 bytes = 344 MB

Total ViT-Base inference: $\approx 402$ MB

ResNet-50 Memory (Inference):

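Peak activation memory: the largest feature maps ($112 \times 112 \times 64$ after the stem convolution, and $56 \times 56 \times 256$ at the output of stage 1) each hold $802{,}816$ values.

Peak: $\approx 0.8$M values $\times$ 4 bytes $\approx 3.2$ MB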

Parameters: $25.6$M $\times$ 4 bytes = 102 MB

Total ResNet-50 inference: $\approx 105$ MB

Ratio: ViT-Base uses $3.8\times$ more memory than ResNet-50
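The memory estimate is simple arithmetic. Here is a rough plain-Python sketch under stated assumptions: float32, batch size 1, only the three largest tensors inside one ViT layer and the single largest ResNet feature map are counted, and framework overhead and temporary buffers are ignored. The parameter counts are those of ViT-B/16 with a 1000-class head and of torchvision's ResNet-50.

```python
BYTES = 4  # float32

# ViT-Base: values alive inside one transformer layer for one image
tokens = 197 * 768           # token embeddings
attn_maps = 12 * 197 * 197   # one 197x197 attention matrix per head
mlp_hidden = 197 * 3072      # MLP hidden activations
vit_layer_mb = (tokens + attn_maps + mlp_hidden) * BYTES / 1e6
vit_params_mb = 86_567_656 * BYTES / 1e6

# ResNet-50: largest feature map (112x112x64 after the stem convolution)
resnet_peak_mb = 112 * 112 * 64 * BYTES / 1e6
resnet_params_mb = 25_557_032 * BYTES / 1e6

print(f"ViT-Base:  {vit_layer_mb:.1f} MB per layer, {vit_params_mb:.0f} MB weights")
print(f"ResNet-50: {resnet_peak_mb:.1f} MB peak map, {resnet_params_mb:.0f} MB weights")
```

Weights dominate both budgets at batch size 1; activation memory grows linearly with the batch size and becomes the larger term once many images are processed at once.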

Part (d): Which is More Efficient?

Efficiency Analysis:

Metric | ViT-Base | ResNet-50 | Ratio
Parameters | 86M | 25.6M | $3.4\times$
FLOPs | 17.5 GFLOPs | 4.1 GFLOPs | $4.3\times$
Memory (inference) | 402 MB | 105 MB | $3.8\times$

Conclusion:

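ResNet-50 is more computationally efficient on every metric above: $3.4\times$ fewer parameters, $4.3\times$ fewer FLOPs, and $3.8\times$ less inference memory.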

However, ViT-Base has advantages:

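  1. Better scaling: Performance keeps improving as the pre-training dataset grows
  2. Transfer learning: ViT pre-trained on large datasets (e.g., ImageNet-21k, JFT-300M) transfers well to downstream tasks
  3. Hardware efficiency: Self-attention and MLP layers reduce to large dense matrix multiplications, which accelerators run at high utilization
  4. Long-range dependencies: Global receptive field from layer 1
  5. Interpretability: Attention maps give a rough view of which patches the model attends to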

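Trade-off: ResNet-50 is cheaper per image and trains well on ImageNet-1k alone; ViT-Base costs roughly $4\times$ more compute but benefits more from large-scale pre-training.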

Practical Recommendation:

For ImageNet-1k from scratch: ResNet-50

For transfer learning with pre-training: ViT-Base

For production with limited compute: ResNet-50

For research and maximum accuracy: ViT-Base

Solution: Exercise 3: MAE Masking Implementation
import torch
import torch.nn as nn
import numpy as np

class MAEMasking(nn.Module):
    def __init__(self, n_patches=196, embed_dim=768, mask_ratio=0.75):
        super().__init__()
        self.n_patches = n_patches
        self.embed_dim = embed_dim
        self.mask_ratio = mask_ratio
        self.n_visible = int(n_patches * (1 - mask_ratio))
        
        # Mask token for decoder
        self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    
    def random_masking(self, x):
        """
        Randomly mask patches
        Args:
            x: (B, N, D) where N = n_patches + 1 (including CLS)
        Returns:
            x_visible: (B, n_visible+1, D) visible patches + CLS
            mask: (B, N - 1) binary mask over patches (0 = keep, 1 = remove)
            ids_restore: (B, N - 1) indices to restore original patch order
        """
        B, N, D = x.shape
        N_patches = N - 1  # Exclude CLS token
        
        # Generate random noise for shuffling
        noise = torch.rand(B, N_patches, device=x.device)
        
        # Sort noise to get shuffle indices
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        # Keep first n_visible patches
        ids_keep = ids_shuffle[:, :self.n_visible]
        
        # Extract CLS token
        cls_token = x[:, :1, :]  # (B, 1, D)
        
        # Extract patch tokens (exclude CLS)
        x_patches = x[:, 1:, :]  # (B, N_patches, D)
        
        # Gather visible patches
        x_visible = torch.gather(
            x_patches, 
            dim=1, 
            index=ids_keep.unsqueeze(-1).expand(-1, -1, D)
        )  # (B, n_visible, D)
        
        # Concatenate CLS token
        x_visible = torch.cat([cls_token, x_visible], dim=1)  # (B, n_visible+1, D)
        
        # Generate binary mask: 0 = keep, 1 = remove
        mask = torch.ones(B, N_patches, device=x.device)
        mask[:, :self.n_visible] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        return x_visible, mask, ids_restore
    
    def add_mask_tokens(self, x_visible, ids_restore):
        """
        Add mask tokens for decoder
        Args:
            x_visible: (B, n_visible+1, D)
            ids_restore: (B, N_patches)
        Returns:
            x_full: (B, N, D) with mask tokens
        """
        B, _, D = x_visible.shape
        
        # Extract CLS token
        cls_token = x_visible[:, :1, :]
        
        # Extract visible patches
        x_patches = x_visible[:, 1:, :]  # (B, n_visible, D)
        
        # Create mask tokens
        n_mask = self.n_patches - self.n_visible
        mask_tokens = self.mask_token.expand(B, n_mask, -1)
        
        # Concatenate visible and mask tokens
        x_combined = torch.cat([x_patches, mask_tokens], dim=1)  # (B, N_patches, D)
        
        # Restore original order
        x_restored = torch.gather(
            x_combined,
            dim=1,
            index=ids_restore.unsqueeze(-1).expand(-1, -1, D)
        )  # (B, N_patches, D)
        
        # Add CLS token back
        x_full = torch.cat([cls_token, x_restored], dim=1)  # (B, N, D)
        
        return x_full
    
    def forward(self, x):
        """
        Complete MAE masking pipeline
        """
        # Part 1: Random masking
        x_visible, mask, ids_restore = self.random_masking(x)
        
        # Part 2: Add mask tokens for decoder
        x_full = self.add_mask_tokens(x_visible, ids_restore)
        
        return x_visible, x_full, mask, ids_restore



# Example usage
n_patches = 196
embed_dim = 768
mask_ratio = 0.75
batch_size = 4

# Create MAE masking module
mae_mask = MAEMasking(n_patches, embed_dim, mask_ratio)

# Simulate patch embeddings (including CLS token)
x = torch.randn(batch_size, n_patches + 1, embed_dim)

# Apply masking
x_visible, x_full, mask, ids_restore = mae_mask(x)

print(f"Input shape: {x.shape}")
print(f"Visible patches shape: {x_visible.shape}")
print(f"Full with mask tokens shape: {x_full.shape}")
print(f"Mask shape: {mask.shape}")
print(f"Number of masked patches: {mask.sum(dim=1)[0].item()}")
print(f"Number of visible patches: {(1 - mask).sum(dim=1)[0].item()}")
Output:

Input shape: torch.Size([4, 197, 768])
Visible patches shape: torch.Size([4, 50, 768])
Full with mask tokens shape: torch.Size([4, 197, 768])
Mask shape: torch.Size([4, 196])
Number of masked patches: 147.0
Number of visible patches: 49.0

Part (a): Randomly Mask 75\% of 196 Patches

Masking Strategy:

  1. Generate random noise: $\text{noise} \sim \mathcal{U}(0, 1)^{196}$
  2. Sort noise to get shuffle indices
  3. Keep the $49$ patches with the smallest noise values, i.e. the first 49 after sorting (25\% of 196)
  4. Mask remaining $147$ patches (75\% of 196)

Mathematical Formulation:

Let $\vx = [\vx_{\text{cls}}, \vx_1, \vx_2, \ldots, \vx_{196}]$ be patch embeddings.

Random permutation: $\pi: \{1, \ldots, 196\} \to \{1, \ldots, 196\}$

Visible set: $\mathcal{V} = \{\pi(1), \ldots, \pi(49)\}$

Masked set: $\mathcal{M} = \{\pi(50), \ldots, \pi(196)\}$

Binary mask: $m_i = \begin{cases} 0 & \text{if } i \in \mathcal{V} \\ 1 & \text{if } i \in \mathcal{M} \end{cases}$

Part (b): Keep 49 Visible Patches

Encoder Input:

$\vx_{\text{visible}} = [\vx_{\text{cls}}, \vx_{\pi(1)}, \vx_{\pi(2)}, \ldots, \vx_{\pi(49)}]$

Shape: $(B, 50, 768)$ where $50 = 49 + 1$ (CLS token)

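Computational Savings:

The encoder processes only 25\% of the patches: 50 tokens (49 visible plus CLS) instead of 197. Costs that are linear in sequence length (QKV, output projection, MLP) shrink by $197/50 \approx 3.9\times$, and the attention matrices shrink by $(197/50)^2 \approx 15.5\times$.

This is the key efficiency gain of MAE: the large encoder never processes the 147 masked patches.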
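The two savings factors come from how each part of a layer scales with the number of tokens $n$. A plain-Python sketch (layer_macs is our own helper) counting multiply-accumulates in one ViT-Base encoder layer, ignoring softmax, LayerNorm, and GELU:

```python
def layer_macs(n, d=768, mlp=3072):
    """Multiply-accumulates in one ViT-Base layer for n tokens, split by scaling."""
    linear = n * d * 3 * d + n * d * d + 2 * n * d * mlp   # QKV, output proj, MLP: O(n)
    attention = 2 * n * n * d                              # QK^T and weights @ V: O(n^2)
    return linear, attention

full_lin, full_att = layer_macs(197)   # CLS + all 196 patches
vis_lin, vis_att = layer_macs(50)      # CLS + 49 visible patches

print(f"linear layers: {full_lin / vis_lin:.2f}x fewer MACs")                       # 3.94x
print(f"attention:     {full_att / vis_att:.2f}x fewer MACs")                       # 15.52x
print(f"whole layer:   {(full_lin + full_att) / (vis_lin + vis_att):.2f}x fewer MACs")  # 4.06x
```

Because the linear layers dominate at this sequence length, the per-layer encoder saving is close to $4\times$ rather than the $15.5\times$ of the attention matrices alone.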

Part (c): Add Mask Tokens for Decoder

Decoder Input Construction:

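  1. Take encoder output: $\vz_{\text{visible}} = \text{Encoder}(\vx_{\text{visible}}) = [\vz_{\text{cls}}, \vz_{\pi(1)}, \ldots, \vz_{\pi(49)}]$
  2. Create mask tokens: one learnable vector $\vm_{\text{mask}} \in \mathbb{R}^{768}$, shared by all 147 masked positions
  3. Concatenate: $[\vz_{\pi(1)}, \ldots, \vz_{\pi(49)}, \vm_{\text{mask}}, \ldots, \vm_{\text{mask}}]$ (147 copies of $\vm_{\text{mask}}$)
  4. Restore original order using $\pi^{-1}$ and prepend $\vz_{\text{cls}}$
  5. Add position embeddings, so each mask token knows which location it stands for (the MAEMasking code above leaves this step to the decoder)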

Decoder Input:

$\vx_{\text{decoder}} = [\vz_{\text{cls}}, \tilde{\vz}_1, \tilde{\vz}_2, \ldots, \tilde{\vz}_{196}]$

where $\tilde{\vz}_i = \begin{cases} \vz_i \text{ (encoder output)} & \text{if } i \in \mathcal{V} \\ \vm_{\text{mask}} & \text{if } i \in \mathcal{M} \end{cases}$

Shape: $(B, 197, 768)$ - full sequence restored

Part (d): Compute Reconstruction Loss

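Decoder Output: a predicted pixel vector for every patch position,

$\hat{\vp}_i = \text{Head}(\text{Decoder}(\vx_{\text{decoder}})_i) \in \mathbb{R}^{768}$ for $i = 1, \ldots, 196$

where the head is a linear projection to pixels and $\vp_i \in \mathbb{R}^{768}$ holds the raw pixel values of patch $i$ ($16 \times 16 \times 3 = 768$). The reconstruction target is the pixels of each patch, not its embedding.

Reconstruction Loss (MSE on masked patches only):

$\mathcal{L}_{\text{MAE}} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \frac{1}{768} \|\hat{\vp}_i - \vp_i\|_2^2$

The factor $\frac{1}{768}$ corresponds to the mean over pixel values in the code below. Only compute loss on masked patches (147 patches):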

def compute_mae_loss(original_patches, reconstructed_patches, mask):
    """
    Compute MAE reconstruction loss
    Args:
        original_patches: (B, N, D) target pixel values of each patch (D = P*P*C)
        reconstructed_patches: (B, N, D) decoder predictions for each patch
        mask: (B, N) binary mask (1 = masked, 0 = visible)
    Returns:
        loss: scalar
    """
    # Compute MSE
    mse = (reconstructed_patches - original_patches) ** 2
    mse = mse.mean(dim=-1)  # (B, N) - mean over embedding dim
    
    # Apply mask - only compute loss on masked patches
    loss = (mse * mask).sum() / mask.sum()
    
    return loss

# Example
original = torch.randn(4, 196, 768)
reconstructed = torch.randn(4, 196, 768)
mask = torch.zeros(4, 196)
mask[:, 49:] = 1  # Mask last 147 patches

loss = compute_mae_loss(original, reconstructed, mask)
print(f"MAE Loss: {loss.item():.4f}")

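Why Only Masked Patches? The visible patches are fed to the encoder, so reconstructing them would let the model lower the loss by simply copying its input. Restricting the loss to the 147 masked patches means every unit of loss comes from predicting content the model has not seen, which forces it to infer that content from the visible context.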

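Complete MAE Training Step (a runnable sketch: the small nn.TransformerEncoder stacks stand in for the real MAE encoder and lightweight decoder, and the code reuses PatchEmbedding, manual_patch_embedding, MAEMasking and compute_mae_loss from above):

images = torch.randn(4, 3, 224, 224)
patch_embed = PatchEmbedding(224, 16, 3, 768)
mae_mask = MAEMasking(n_patches=196, embed_dim=768, mask_ratio=0.75)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, nhead=12, batch_first=True), num_layers=2)
decoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, nhead=12, batch_first=True), num_layers=1)
pixel_head = nn.Linear(768, 16 * 16 * 3)  # predicts the pixels of each patch

# 1. Patch embedding (+ CLS token and position embeddings)
patches = patch_embed(images)  # (B, 197, 768)

# 2. Random masking
visible_patches, mask, ids_restore = mae_mask.random_masking(patches)  # (B, 50, 768)

# 3. Encoder (only on visible patches)
encoded = encoder(visible_patches)  # (B, 50, 768)

# 4. Add mask tokens and restore order
decoder_input = mae_mask.add_mask_tokens(encoded, ids_restore)  # (B, 197, 768)

# 5. Decoder, then predict pixels; drop the CLS position
reconstructed = pixel_head(decoder(decoder_input))[:, 1:]  # (B, 196, 768)

# 6. Loss on masked patches only, against the raw pixel patches
target = manual_patch_embedding(images, patch_size=16)  # (B, 196, 768)
loss = compute_mae_loss(target, reconstructed, mask)

# 7. Backpropagation
loss.backward()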

Key Insights:

  1. High masking ratio (75\%): Forces model to learn global structure
  2. Random masking: Prevents trivial solutions (interpolation)
  3. Asymmetric encoder-decoder: Encoder is large, decoder is small
  4. Pixel-level reconstruction: Simpler than contrastive learning
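  5. Efficiency: The encoder sees 50 of 197 tokens, so its linear layers cost about $3.9\times$ less and its attention matrices about $15.5\times$ less; the MAE paper reports an overall pre-training speedup of $3\times$ or more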
Solution: Exercise 4: Train ViT-Tiny on CIFAR-10
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class ViTTiny(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, 
                 num_classes=10, embed_dim=192, depth=6, num_heads=3,
                 mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2  # 64 patches
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim, 
            kernel_size=patch_size, stride=patch_size
        )
        
        # CLS token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Small-std init for position embeddings (ViT uses std 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1, embed_dim) * 0.02)
        self.dropout = nn.Dropout(dropout)
        
        # Transformer encoder (pre-LayerNorm blocks, as in ViT)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, n_patches, embed_dim)
        
        # Add CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        
        # Add position embeddings
        x = x + self.pos_embed
        x = self.dropout(x)
        
        # Transformer
        x = self.transformer(x)
        
        # Classification
        x = self.norm(x[:, 0])  # CLS token
        x = self.head(x)
        
        return x



# Part (a): Patch size 4 for 32x32 images
print(f"Image size: 32x32")
print(f"Patch size: 4x4")
print(f"Number of patches: (32 // 4) ** 2 = {(32 // 4) ** 2}")
print(f"Each patch: 4x4x3 = 48 values")

# Part (b): Model configuration
model = ViTTiny(
    img_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=4.0
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nViT-Tiny parameters: {total_params:,}")

# Part (c): RandAugment data augmentation
from torchvision.transforms import RandAugment

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Training loop
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return total_loss / len(loader), 100. * correct / total

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return 100. * correct / total

# Train for 200 epochs
num_epochs = 200
best_acc = 0

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_acc = evaluate(model, test_loader, device)
    scheduler.step()
    
    if test_acc > best_acc:
        best_acc = test_acc
    
    if (epoch + 1) % 10 == 0:  # print progress every 10 epochs
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Test Acc: {test_acc:.2f}%")

print(f"\nFinal Best Test Accuracy: {best_acc:.2f}%")

Part (d): Compare to Small ResNet

# Small ResNet for CIFAR-10
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class SmallResNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)
    
    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = [BasicBlock(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_channels, out_channels, 1))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create and compare models
resnet = SmallResNet(num_classes=10)
vit_tiny = ViTTiny()

resnet_params = sum(p.numel() for p in resnet.parameters())
vit_params = sum(p.numel() for p in vit_tiny.parameters())

print("Model Comparison:")
print(f"ViT-Tiny parameters: {vit_params:,}")
print(f"Small ResNet parameters: {resnet_params:,}")
print(f"Ratio: {vit_params / resnet_params:.2f}x")

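Expected Results (indicative only; accuracy and training time depend on the hardware and the training recipe):

Model | Parameters | Test Acc | Training Time
ViT-Tiny (6 layers) | $\sim$2.7M | 85-87\% | Slower
Small ResNet | $\sim$2.8M | 88-90\% | Faster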

Analysis:

Part (a): Patch Size 4 for 32×32 Images

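This is appropriate for CIFAR-10 because a $4 \times 4$ patch turns a $32 \times 32$ image into an $8 \times 8$ grid of 64 tokens (65 with CLS), each holding $4 \times 4 \times 3 = 48$ values. The ImageNet setting $P = 16$ would leave only $2 \times 2 = 4$ patches, far too coarse for attention to model spatial structure, while $P = 2$ would give 256 patches and make attention $(257/65)^2 \approx 15.6\times$ more expensive.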
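The effect of the patch size on the sequence is easy to tabulate. A plain-Python sketch (vit_sequence is our own helper) for a $32 \times 32$ RGB image, counting patches, tokens including CLS, values per patch, and attention-matrix entries per head and layer:

```python
def vit_sequence(img_size=32, patch_size=4, channels=3):
    """Sequence statistics of a ViT input for a square image."""
    side = img_size // patch_size
    n_patches = side * side
    return {
        "patches": n_patches,
        "tokens": n_patches + 1,                            # + CLS token
        "values_per_patch": patch_size * patch_size * channels,
        "attention_entries": (n_patches + 1) ** 2,          # per head, per layer
    }

for p in (2, 4, 8, 16):
    print(p, vit_sequence(patch_size=p))
```

Halving the patch size quadruples the number of tokens and multiplies the attention cost by roughly 16, which is the trade-off to keep in mind when trying patch sizes 2, 4, and 8 in the variations below.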

Part (b): Model Configuration

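ViT-Tiny Architecture: patch embedding (Conv2d with a $4 \times 4$ kernel and stride 4, $48 \to 192$), a learnable CLS token, 65 learnable position embeddings, 6 transformer encoder layers ($d = 192$, 3 heads of dimension 64, MLP hidden size $4 \times 192 = 768$, GELU, dropout 0.1), a final LayerNorm, and a linear head to 10 classes.

Parameter Count:

  1. Patch embedding: $48 \times 192 + 192 = 9{,}408$
  2. CLS token and position embeddings: $192 + 65 \times 192 = 12{,}672$
  3. Per encoder layer: attention $110{,}592 + 576 + 37{,}056 = 148{,}224$, MLP $148{,}224 + 147{,}648 = 295{,}872$, two LayerNorms $768$, total $444{,}864$; for 6 layers, $2{,}669{,}184$
  4. Final LayerNorm and head: $384 + 1{,}930 = 2{,}314$

Total: $2{,}693{,}578 \approx 2.7$M parameters, the value printed by the parameter count in the code above.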

Part (c): RandAugment

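RandAugment Strategy: for each training image, RandAugment(num\_ops=2, magnitude=9) applies 2 operations drawn at random from a fixed list (e.g., rotation, shear, translation, brightness, contrast, posterize, solarize, equalize), all at magnitude 9 on torchvision's default 31-level scale. This comes after a random $32 \times 32$ crop with 4-pixel padding and a random horizontal flip.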

Why RandAugment for ViT?

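  1. Data augmentation is crucial: ViT lacks the locality and translation-equivariance biases of CNNs
  2. Prevents overfitting: CIFAR-10 is small (50k training images)
  3. Improves generalization: typically worth a few points of test accuracy for ViT on small datasets
  4. Simpler than AutoAugment: No policy search required, only two hyperparameters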

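Training Recipe: AdamW (learning rate $10^{-3}$, weight decay 0.05), batch size 128, 200 epochs with cosine annealing of the learning rate to zero, dropout 0.1, cross-entropy loss.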
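The cosine schedule in this recipe has a simple closed form. The plain-Python sketch below (cosine_lr is our own name) implements the formula given in the PyTorch documentation for CosineAnnealingLR, with eta_min = 0 (the default) and T_max = 200 as in the code above, stepped once per epoch.

```python
import math

def cosine_lr(epoch, base_lr=1e-3, t_max=200, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for epoch in (0, 50, 100, 150, 199):
    print(epoch, f"{cosine_lr(epoch):.2e}")
```

The learning rate stays near its peak for the first epochs and decays fastest in the middle of training. Many ViT recipes (DeiT, for example) also prepend a few epochs of linear warmup; that is not part of the recipe above.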

Part (d): Comparison with Small ResNet

Quantitative Comparison:

Metric | ViT-Tiny | Small ResNet
Parameters | 2.7M | 2.8M
FLOPs (multiply-accumulates) | $\sim$0.18 GFLOPs | $\sim$0.42 GFLOPs
Test Accuracy (indicative) | 85-87\% | 88-90\%
Training Time (hardware-dependent) | $\sim$3 hours | $\sim$2 hours
Convergence | Slower | Faster
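The compute figures for the two CIFAR-10 models can be derived from their layer shapes. The plain-Python sketch below (function names are ours) counts multiply-accumulates of the matrix products and convolutions only, ignoring BatchNorm, LayerNorm, activations, and biases; the ResNet count follows the SmallResNet definition above (3 stages, 2 BasicBlocks each).

```python
def vit_tiny_macs(img=32, patch=4, d=192, depth=6, mlp=768, classes=10, chans=3):
    n = (img // patch) ** 2 + 1                       # tokens incl. CLS
    per_layer = (n * d * 3 * d                        # QKV projection
                 + 2 * n * n * d                      # QK^T and attention @ V
                 + n * d * d                          # output projection
                 + 2 * n * d * mlp)                   # two MLP layers
    embed = (n - 1) * (patch * patch * chans) * d     # patch projection
    return embed + depth * per_layer + d * classes

def conv_macs(h, w, c_in, c_out, k):
    return h * w * c_in * c_out * k * k               # h, w: output spatial size

def small_resnet_macs(classes=10):
    total = conv_macs(32, 32, 3, 64, 3)               # stem
    total += 4 * conv_macs(32, 32, 64, 64, 3)         # layer1: 2 blocks x 2 convs
    # layer2 / layer3: downsampling 3x3 conv, three more 3x3 convs, 1x1 shortcut
    for size, c_in, c_out in ((16, 64, 128), (8, 128, 256)):
        total += conv_macs(size, size, c_in, c_out, 3)
        total += 3 * conv_macs(size, size, c_out, c_out, 3)
        total += conv_macs(size, size, c_in, c_out, 1)
    return total + 256 * classes                      # final linear layer

print(f"ViT-Tiny:     {vit_tiny_macs() / 1e9:.2f} G multiply-accumulates")
print(f"Small ResNet: {small_resnet_macs() / 1e9:.2f} G multiply-accumulates")
```

With only 65 tokens, the 6-layer ViT-Tiny needs fewer multiply-accumulates than the ResNet, whose first stage runs full $3 \times 3$ convolutions at $32 \times 32$ resolution; its accuracy gap comes from the lack of convolutional priors, not from a smaller compute budget.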

Why ResNet Performs Better on CIFAR-10:

  1. Inductive bias: Convolutions encode spatial locality
  2. Translation equivariance: Built into convolutions
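  3. Weight sharing: Each $3 \times 3$ filter is reused at every spatial position, a strong prior that lets a similar parameter budget (2.8M vs 2.7M) generalize well from few examples
  4. Small dataset: CIFAR-10 (50k training images) is too small for ViT to learn these spatial priors from data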
  5. Low resolution: $32 \times 32$ images have limited spatial information

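When ViT Would Win: with large-scale pre-training (e.g., ImageNet-21k or JFT-300M) followed by fine-tuning, on larger datasets, or at higher resolutions, where the missing locality bias matters less and ViT's better scaling with data shows.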

Practical Recommendations:

  1. From scratch on CIFAR-10: Use ResNet (better accuracy, faster)
  2. With pre-training: Use ViT (transfer learning advantage)
  3. Research purposes: Try both, compare carefully
  4. Production: ResNet for efficiency, ViT for maximum accuracy with pre-training

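Key Takeaways: at a similar parameter budget, a small ResNet trained from scratch is expected to beat ViT-Tiny on CIFAR-10 because convolutions build in locality and weight sharing; ViT narrows the gap with strong augmentation and regularization, and overtakes CNNs mainly when pre-trained on much more data.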

Experiment Variations to Try:

  1. Increase ViT depth to 12 layers
  2. Try different patch sizes (2, 4, 8)
  3. Add more aggressive augmentation
  4. Use mixup or cutmix
  5. Pre-train on CIFAR-100, fine-tune on CIFAR-10
  6. Compare with hybrid models (ResNet + Transformer)