Multimodal Transformers

Chapter Overview

Multimodal transformers process multiple modalities (text, images, audio, video) in a unified framework. This chapter covers vision-language models (CLIP, DALL-E), audio-text models (Whisper), and unified architectures that handle arbitrary combinations of modalities.

Learning Objectives

  1. Understand multimodal fusion strategies
  2. Implement contrastive learning (CLIP)
  3. Apply vision-language models to zero-shot classification
  4. Generate images from text (DALL-E, Stable Diffusion)
  5. Process audio with transformers (Whisper)
  6. Build unified multimodal models

Multimodal Learning Fundamentals

Fusion Strategies

The choice of fusion strategy determines how modalities interact, with direct implications for computational cost and expressiveness. Three primary approaches have emerged:

\begin{tikzpicture}[ node/.style={circle, draw, minimum size=0.6cm, font=\footnotesize}, encoder/.style={rectangle, draw, minimum width=2cm, minimum height=1.2cm, font=\small, align=center}, arrow/.style={->, thick}, bidir/.style={<->, thick, blue}, cross/.style={->, thick, green!60!black, dashed} ]

\node[font=\small\bfseries] at (0,4) {Early Fusion}; \node[node, fill=blue!20] (v1) at (-0.5,2.5) {$v_1$}; \node[node, fill=blue!20] (v2) at (0.5,2.5) {$v_2$}; \node[node, fill=red!20] (t1) at (-0.5,1.5) {$t_1$}; \node[node, fill=red!20] (t2) at (0.5,1.5) {$t_2$};

\node[encoder, fill=purple!10] (enc1) at (0,0) {Unified \\ Encoder}; \draw[arrow] (v1) -- (enc1); \draw[arrow] (v2) -- (enc1); \draw[arrow] (t1) -- (enc1); \draw[arrow] (t2) -- (enc1);

\draw[bidir] (v1) to[bend left=10] (t1); \draw[bidir] (v2) to[bend right=10] (t2);

\node[font=\small\bfseries] at (6,4) {Late Fusion}; \node[node, fill=blue!20] (v3) at (5,2.5) {$v_1$}; \node[node, fill=blue!20] (v4) at (6,2.5) {$v_2$}; \node[node, fill=red!20] (t3) at (5,1.5) {$t_1$}; \node[node, fill=red!20] (t4) at (6,1.5) {$t_2$};

\node[encoder, fill=blue!10] (venc) at (5.5,0.8) {Vision \\ Encoder}; \node[encoder, fill=red!10] (tenc) at (5.5,-0.5) {Text \\ Encoder};

\draw[arrow] (v3) -- (venc); \draw[arrow] (v4) -- (venc); \draw[arrow] (t3) -- (tenc); \draw[arrow] (t4) -- (tenc);

\node[encoder, fill=yellow!10] (fuse) at (5.5,-2) {Fusion}; \draw[arrow] (venc) -- (fuse); \draw[arrow] (tenc) -- (fuse);

\node[font=\small\bfseries] at (12,4) {Cross-Modal}; \node[node, fill=blue!20] (v5) at (11,2.5) {$v_1$}; \node[node, fill=blue!20] (v6) at (12,2.5) {$v_2$}; \node[node, fill=red!20] (t5) at (11,1.5) {$t_1$}; \node[node, fill=red!20] (t6) at (12,1.5) {$t_2$};

\node[encoder, fill=blue!10] (venc2) at (11.5,0.8) {Vision \\ Encoder}; \node[encoder, fill=red!10] (tenc2) at (11.5,-0.5) {Text \\ Encoder};

\draw[arrow] (v5) -- (venc2); \draw[arrow] (v6) -- (venc2); \draw[arrow] (t5) -- (tenc2); \draw[arrow] (t6) -- (tenc2);

\draw[cross] (venc2) to[bend left=20] (tenc2); \draw[cross] (tenc2) to[bend left=20] (venc2);

\end{tikzpicture}

Three multimodal fusion strategies. Early fusion (left): concatenates modalities and processes with unified encoder, enabling rich interactions but with quadratic cost $O((N+M)^2)$. Late fusion (center): separate encoders with fusion only at output, efficient $O(N^2 + M^2)$ but limited cross-modal interaction. Cross-modal attention (right): separate encoders with explicit cross-attention, balancing efficiency $O(N^2 + M^2 + NM)$ with rich interactions.
Strategy | Description | Pros | Cons
Early fusion | Concatenate modality tokens into one sequence; process with unified encoder | Rich cross-modal interaction at every layer; simple architecture | $O((N{+}M)^2 d)$ cost; adding patches dramatically increases compute
Late fusion | Separate encoders per modality; combine outputs at decision stage (CLIP) | Efficient $O(N^2 d {+} M^2 d)$; encoders parallelizable | No fine-grained cross-modal alignment; interaction only at output
Cross-modal attention | Separate encoders with cross-attention layers between modalities (BLIP, Flamingo) | $O(N^2 d {+} M^2 d {+} NMd)$; rich interactions with moderate cost | Additional parameters; more complex architecture

Cross-modal attention offers the best trade-off for most applications: for 196 image patches and 128 text tokens, the cross-attention term requires $196 \times 128 = 25{,}088$ score computations per head versus $324^2 = 104{,}976$ for early fusion's joint self-attention---roughly a 4$\times$ reduction in that term (the full cross-modal cost, $196^2 + 128^2 + 25{,}088 = 79{,}888$, is still about 24\% below early fusion) while preserving fine-grained alignment between modalities.
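
To make these costs concrete, a quick back-of-the-envelope sketch (ours, not from any library) tallies the attention pair counts at the sizes used above:

n_img, n_txt = 196, 128
early = (n_img + n_txt) ** 2                       # joint self-attention
late = n_img**2 + n_txt**2                         # per-modality self-attention only
cross = n_img**2 + n_txt**2 + n_img * n_txt        # self-attention plus one cross term

print(f"early fusion: {early:,} pairs")            # 104,976
print(f"late fusion:  {late:,} pairs")             # 54,800
print(f"cross-modal:  {cross:,} pairs")            # 79,888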

Alignment Objectives

Contrastive Learning:

$$ \mathcal{L}_{\text{contrastive}} = -\log \frac{\exp(\text{sim}(v_i, t_i)/\tau)}{\sum_j \exp(\text{sim}(v_i, t_j)/\tau)} $$
where $v_i$ = image embedding, $t_i$ = text embedding, $\tau$ = temperature

Matching Loss:

$$ \mathcal{L}_{\text{match}} = -\mathbb{E}[\log P(\text{match}|v, t)] $$

Reconstruction:

$$ \mathcal{L}_{\text{recon}} = \|f_{\text{dec}}(v) - t\|^2 $$

CLIP: Contrastive Language-Image Pre-training

CLIP Architecture

CLIP (Contrastive Language-Image Pre-training) represents a breakthrough in vision-language learning by training image and text encoders jointly using a contrastive objective on 400 million image-text pairs collected from the internet. Unlike traditional supervised learning that requires manually labeled categories, CLIP learns to align images with their natural language descriptions, enabling zero-shot transfer to downstream tasks without any task-specific training data.

Definition: The CLIP architecture consists of three main components that work together to create a shared embedding space for images and text. The image encoder can be either a Vision Transformer (ViT) or a ResNet, which processes an input image and produces a fixed-dimensional embedding $\vv \in \R^{d}$. For the largest CLIP model (ViT-L/14), the image encoder is a ViT with patch size 14, hidden dimension 1024, 24 layers, and 16 attention heads, totaling approximately 304 million parameters. The text encoder is a transformer with causal attention masking (similar to GPT) and a context length of 77 tokens; for ViT-L/14 it uses hidden dimension 768, 12 layers, and 12 attention heads, roughly 123 million parameters (the base models pair their image encoders with a smaller 512-dimensional, $\approx$63M-parameter text encoder). Both encoders are followed by learned linear projection layers that map their outputs to a shared embedding space---$d = 512$ for the base models, which we use in the examples below---where cosine similarity can be computed directly.

The training procedure processes batches of $(image, text)$ pairs simultaneously. For each batch of size $N$, all $N$ images are encoded to produce embeddings $\vv_1, \ldots, \vv_N \in \R^{512}$, and all $N$ text descriptions are encoded to produce $\vt_1, \ldots, \vt_N \in \R^{512}$. The model then computes an $N \times N$ similarity matrix where entry $(i,j)$ represents the cosine similarity between image $i$ and text $j$. The contrastive loss maximizes the similarity along the diagonal (correct image-text pairs) while minimizing off-diagonal similarities (incorrect pairings). This symmetric loss is computed in both directions—predicting text from image and image from text—and averaged.

The parameter count for CLIP varies significantly across model scales. CLIP ResNet-50 contains approximately 102 million parameters (38M for ResNet-50 image encoder, 63M for text encoder, 1M for projections), while CLIP ViT-L/14 totals around 428 million parameters (304M for ViT-L image encoder, 123M for a larger text encoder with 768 dimensions and 12 layers, 1M for projections). The largest variant, ViT-L/14@336px, processes higher-resolution images (336×336 instead of 224×224) with the same architecture, increasing computational cost but improving performance on fine-grained tasks.

Example: Consider a training batch with size $N = 4$ and embedding dimension $d = 512$. The image encoder processes four images to produce embeddings arranged as rows in matrix $\mV = [\vv_1, \vv_2, \vv_3, \vv_4]\transpose \in \R^{4 \times 512}$, while the text encoder processes their corresponding captions to produce $\mT = [\vt_1, \vt_2, \vt_3, \vt_4]\transpose \in \R^{4 \times 512}$.

The similarity matrix is computed as $\mS = \mV \mT\transpose \in \R^{4 \times 4}$, where each entry $S_{ij}$ represents the dot product between image embedding $i$ and text embedding $j$. To make this scale-invariant, CLIP uses cosine similarity: $S_{ij} = \frac{\vv_i \cdot \vt_j}{\|\vv_i\| \|\vt_j\|}$, which normalizes each embedding to unit length before computing the dot product. This ensures that the similarity is determined by the angle between embeddings rather than their magnitudes.

The contrastive loss for the image-to-text direction is computed as:

$$ \mathcal{L}_i^{\text{img}\to\text{txt}} = -\log \frac{\exp(S_{ii}/\tau)}{\sum_{j=1}^{N} \exp(S_{ij}/\tau)} $$
where $\tau$ is a learned temperature parameter, initialized to $0.07$ and trained jointly with the model. The temperature controls the sharpness of the distribution: smaller values make the model more confident (sharper peaks), while larger values produce softer distributions. The symmetric text-to-image loss $\mathcal{L}_i^{\text{txt}\to\text{img}}$ is computed analogously by treating text as queries and images as candidates. The total loss averages both directions:
$$ \mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N} (\mathcal{L}_i^{\text{img}\to\text{txt}} + \mathcal{L}_i^{\text{txt}\to\text{img}}) $$
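
A minimal PyTorch sketch of this symmetric loss, closely following the pseudocode in the CLIP paper (the log-space `logit_scale` parameterization matches how the released implementation stores $1/\tau$):

import torch
import torch.nn.functional as F

def clip_loss(image_emb, text_emb, logit_scale):
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = logit_scale.exp() * image_emb @ text_emb.t()    # (N, N) similarities / tau
    labels = torch.arange(logits.size(0), device=logits.device)
    # Symmetric cross-entropy: rows = image-to-text, columns = text-to-image
    return (F.cross_entropy(logits, labels)
            + F.cross_entropy(logits.t(), labels)) / 2

# tau is learned as a log-scale, initialized to log(1/0.07)
logit_scale = torch.nn.Parameter(torch.log(torch.tensor(1 / 0.07)))
loss = clip_loss(torch.randn(8, 512), torch.randn(8, 512), logit_scale)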

In practice, CLIP uses very large batch sizes to provide more negative examples for contrastive learning. The original CLIP was trained with batch size 32,768, requiring distributed training across multiple GPUs. With such large batches, each positive pair has 32,767 negative examples, providing a strong learning signal. However, this creates substantial memory requirements: storing the $32{,}768 \times 512$ embedding matrices for images and text requires $32{,}768 \times 512 \times 4 = 67$ MB per modality in FP32, and the $32{,}768 \times 32{,}768$ similarity matrix requires $4.3$ GB. To make this tractable, CLIP uses gradient checkpointing and distributes the batch across many GPUs, computing the similarity matrix in chunks.

Computational Analysis of CLIP Training

Training CLIP at scale requires careful consideration of computational and memory costs across both the image and text encoding paths. For the ViT-L/14 image encoder processing 224×224 images, each image is divided into $16 \times 16 = 256$ patches of size $14 \times 14$. These patches are linearly projected to dimension 1024 and processed through 24 transformer layers. The computational cost per image is approximately $2 \times 24 \times 256^2 \times 1024 = 3.2$ GFLOPs for the attention operations (using the $2Ln^2d$ approximation from Chapter 12) plus $2 \times 24 \times 256 \times 4 \times 1024^2 = 51.5$ GFLOPs for the feed-forward networks, totaling roughly 55 GFLOPs per image.

The text encoder processes sequences of up to 77 tokens through 12 transformer layers with dimension 768. The computational cost per text is approximately $2 \times 12 \times 77^2 \times 768 = 0.11$ GFLOPs for attention plus $2 \times 12 \times 77 \times 4 \times 768^2 = 4.4$ GFLOPs for feed-forward networks, totaling about 4.5 GFLOPs per text. This asymmetry---images requiring roughly 12× more compute than text---means that image encoding dominates the computational budget during training.
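
These per-example estimates are easy to reproduce; a small sketch using the same approximations (attention $\approx 2Ln^2d$, feed-forward $\approx 8Lnd^2$):

def encoder_gflops(layers, tokens, d):
    """Rough forward-pass cost using the approximations above."""
    attn = 2 * layers * tokens**2 * d        # attention matmuls
    ffn = 2 * layers * tokens * 4 * d**2     # feed-forward networks
    return (attn + ffn) / 1e9

print(f"ViT-L/14 image encoder: {encoder_gflops(24, 256, 1024):.1f} GFLOPs")  # ~54.8
print(f"CLIP text encoder:      {encoder_gflops(12, 77, 768):.2f} GFLOPs")    # ~4.47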

For a batch of 32,768 examples, the total forward pass requires approximately $32{,}768 \times (55 + 4.5) \approx 1{,}950{,}000$ GFLOPs or roughly 2 PFLOPs. On an NVIDIA A100 GPU with 312 TFLOPS of FP16 compute, this would take approximately 6.3 seconds per batch for the forward pass alone, not including backward propagation (which typically costs 2× the forward pass) or the contrastive loss computation. The full training of CLIP on 400 million image-text pairs with batch size 32,768 requires approximately $400{,}000{,}000 / 32{,}768 = 12{,}207$ batches per epoch. At roughly 20 seconds per batch (forward + backward + optimizer step), this amounts to 68 hours per epoch of continuous training on a single A100. In practice, OpenAI trained CLIP on 256 V100 GPUs for approximately 12 days, suggesting a total training cost of around 73,728 GPU-hours.

Memory requirements are equally demanding. Each image in the batch requires storing activations for 24 layers with 256 tokens and dimension 1024, totaling approximately $24 \times 256 \times 1024 \times 2\ \text{bytes} = 12.6$ MB per image in FP16, counting only each layer's residual-stream output; storing FFN intermediates and attention maps for backpropagation raises this several-fold. For batch size 32,768, the layer outputs alone amount to 413 GB. Text activations are smaller at approximately $12 \times 77 \times 768 \times 2\ \text{bytes} = 1.4$ MB per text, or 46 GB for the full batch. The similarity matrix requires $32{,}768 \times 32{,}768 \times 2 = 2.1$ GB in FP16. Combined with model parameters (428M parameters × 2 bytes = 856 MB) and optimizer states (typically 2× parameters for Adam), the total memory footprint exceeds 500 GB, necessitating distribution across many GPUs using techniques like ZeRO (Chapter 22) to partition optimizer states and activations.

Zero-Shot Classification with CLIP

One of CLIP's most remarkable capabilities is zero-shot classification: the ability to classify images into categories the model has never been explicitly trained on. This works by leveraging the natural language understanding of the text encoder to create classifiers on the fly from text descriptions. The procedure begins by creating text prompts for each class in the target classification task. For example, for a 10-class animal classification task, we might create prompts like "a photo of a dog", "a photo of a cat", "a photo of a bird", and so on. These prompts are encoded by the text encoder to produce class embeddings $\vt_1, \ldots, \vt_C \in \R^{512}$ where $C$ is the number of classes.

To classify a new image, we encode it with the image encoder to produce $\vv \in \R^{512}$, then compute the cosine similarity between the image embedding and each class embedding: $s_i = \frac{\vv \cdot \vt_i}{\|\vv\| \|\vt_i\|}$. The predicted class is simply $\arg\max_i s_i$, the class whose text description has the highest similarity to the image. This approach requires no training on the target dataset—the model uses only its pre-trained knowledge of how images and text relate.

The performance of this zero-shot approach is surprisingly strong. CLIP ViT-L/14 achieves 76.2\% top-1 accuracy on ImageNet without ever seeing a single ImageNet training example, matching the performance of a ResNet-50 trained directly on ImageNet's 1.28 million labeled images. This demonstrates that CLIP has learned visual concepts that generalize far beyond its training distribution. Moreover, CLIP shows remarkable robustness to distribution shift: when evaluated on ImageNet variants with different image styles (sketches, cartoons, adversarial examples), CLIP's performance degrades much less than supervised models, suggesting it has learned more robust visual representations.

The prompt engineering aspect of zero-shot classification is crucial for performance. Simple prompts like "dog" perform worse than more descriptive prompts like "a photo of a dog". OpenAI found that using prompt ensembles—averaging predictions across multiple prompt templates like "a photo of a \{class\}", "a picture of a \{class\}", "an image of a \{class\}"—improves accuracy by 1-2\% by reducing sensitivity to prompt phrasing. For fine-grained classification tasks, more specific prompts help: "a photo of a \{species\}, a type of bird" outperforms "a photo of a \{species\}" for bird species classification.

CLIP Variants and Training Requirements

Following CLIP's success, several variants have been developed with different scales and training procedures. OpenCLIP is an open-source reproduction that has trained models ranging from small (ResNet-50 with 102M parameters) to very large (ViT-G/14 with 1.8B parameters) on datasets including LAION-400M and LAION-2B. The largest OpenCLIP models require training on clusters of 128-512 A100 GPUs for several weeks, with estimated costs exceeding \$100,000 for the full training run. The training uses mixed precision (FP16) to reduce memory consumption and enable larger batch sizes, typically 32,768 to 65,536 examples distributed across all GPUs.

ALIGN, developed by Google, scales up the training data to 1.8 billion noisy image-text pairs collected from the web without extensive filtering. This demonstrates that contrastive learning is robust to noise in the training data—the model learns to ignore mismatched pairs through the contrastive objective. ALIGN uses an EfficientNet-L2 image encoder (480M parameters) and a BERT-Large text encoder (340M parameters), totaling approximately 820M parameters. Training ALIGN required a cluster of 1024 Cloud TPU v3 cores for approximately 6 days, representing roughly 150,000 TPU-hours.

Florence, Microsoft's unified vision foundation model, extends the CLIP approach to 900 million image-text pairs with a focus on creating a single model that can be adapted to diverse vision tasks. Florence uses a CoSwin transformer as the image encoder (637M parameters) and achieves state-of-the-art results on zero-shot classification, retrieval, and object detection after fine-tuning. The training infrastructure required 512 NVIDIA A100 GPUs for approximately 10 days, with an estimated cost of over \$200,000 in cloud compute.

The hardware requirements for training CLIP-scale models are substantial. A minimum viable setup might use 8-16 A100 GPUs (80GB each) to train a CLIP ResNet-50 model on a smaller dataset like Conceptual Captions (3M pairs) with batch size 2048-4096, requiring approximately 1-2 weeks. Scaling to the full CLIP ViT-L/14 with 400M training pairs and batch size 32,768 necessitates at least 64-128 A100 GPUs with high-bandwidth interconnects (NVLink or InfiniBand) to efficiently synchronize gradients across the distributed batch. The total training cost for reproducing CLIP ViT-L/14 is estimated at \$50,000-\$100,000 in cloud GPU costs, depending on the provider and optimization techniques employed.

DALL-E and Stable Diffusion

DALL-E: Text-to-Image Generation

Definition: DALL-E 1 (2021): a single autoregressive transformer over a discrete token sequence of up to 256 BPE text tokens followed by $32 \times 32 = 1024$ image tokens produced by a discrete VAE (dVAE) with an 8192-entry codebook.

Sequence:

$$ [\text{BOS}, \text{text tokens}, \text{image tokens}, \text{EOS}] $$

Generate image by: (1) Encode text, (2) Sample image tokens autoregressively
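
A schematic of this two-stage sampling, with random tensors standing in for the real 12B-parameter transformer and dVAE decoder (shapes and vocabulary sizes follow the paper; everything else here is illustrative):

import torch

text_tokens = torch.randint(0, 16_384, (1, 256))             # BPE-encoded prompt
image_tokens = torch.zeros(1, 0, dtype=torch.long)
for _ in range(1024):                                        # 32 x 32 dVAE token grid
    seq = torch.cat([text_tokens, image_tokens], dim=1)      # conditioning context
    logits = torch.randn(8_192)                              # stand-in for transformer(seq)
    next_tok = torch.multinomial(logits.softmax(-1), 1)      # sample next image token
    image_tokens = torch.cat([image_tokens, next_tok[None]], dim=1)
# image_tokens (1, 1024) would then be decoded to pixels by the dVAE decoder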

DALL-E 2 (2022): a diffusion-based successor: a prior network maps the CLIP text embedding to a corresponding CLIP image embedding, and a diffusion decoder (unCLIP) generates the image conditioned on that embedding.

Stable Diffusion

Latent Diffusion Model:

  1. Encode image to latent space (VAE)
  2. Add noise iteratively (forward diffusion)
  3. Learn to denoise (reverse diffusion)
  4. Condition on text via cross-attention

Text conditioning: the CLIP text embeddings are injected into the U-Net through cross-attention layers, where the spatial feature maps provide queries and the text token embeddings provide keys and values.

Example: Components:

1. Text Encoder: CLIP text encoder

$$ \text{prompt} \to \vt \in \R^{77 \times 768} $$

2. VAE Encoder: Image $\to$ latent

$$ \mI \in \R^{512 \times 512 \times 3} \to \vz \in \R^{64 \times 64 \times 4} $$

3. U-Net Denoiser: Diffusion model with cross-attention

4. VAE Decoder: Latent $\to$ image

$$ \vz \in \R^{64 \times 64 \times 4} \to \mI \in \R^{512 \times 512 \times 3} $$

Parameters: $\approx 860$M for the U-Net, plus $\approx 123$M for the CLIP text encoder and $\approx 84$M for the VAE (roughly 1.1B total)
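
For reference, the open-source diffusers library packages these four components behind a single pipeline; a minimal usage sketch (the checkpoint identifier is one common choice):

import torch
from diffusers import StableDiffusionPipeline

# Loads the CLIP text encoder, VAE, U-Net denoiser, and noise scheduler together
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = pipe("a photograph of an astronaut riding a horse",
             num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("astronaut.png")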

Vision-Language Understanding

BLIP: Bootstrapped Language-Image Pre-training

Architecture: a multimodal mixture of encoder-decoder (MED): a ViT image encoder, a unimodal text encoder, an image-grounded text encoder with cross-attention to visual features, and an image-grounded text decoder, with parameters shared across the text modules.

Training objectives:

  1. ITC: Image-Text Contrastive (like CLIP)
  2. ITM: Image-Text Matching (binary: match or not)
  3. LM: Language Modeling on text

Bootstrapping (CapFilt): a captioner generates synthetic captions for web images, a filter removes noisy image-text pairs, and the model is retrained on the cleaned, augmented corpus.

Flamingo: Few-Shot Learning

Flamingo represents a significant architectural innovation in multimodal transformers by enabling models to process arbitrarily interleaved sequences of images and text, supporting few-shot learning through in-context examples. Unlike CLIP, which processes single image-text pairs, Flamingo can handle inputs like "Here is an image of a cat: [image]. Here is an image of a dog: [image]. What animal is in this image: [image]?" This capability enables few-shot visual learning where the model learns new tasks from just a few examples provided in the prompt, without any parameter updates.

The Flamingo architecture consists of three main components, carefully designed to leverage pre-trained models while adding minimal trainable parameters. The vision encoder is a frozen CLIP ViT-L/14 model that processes each image independently to produce a sequence of patch embeddings. For a 224×224 image with patch size 14, this yields 256 patch tokens of dimension 1024. The vision encoder's 304M parameters remain frozen throughout training, preserving the strong visual representations learned during CLIP pre-training.

The language model is a frozen Chinchilla 70B model, a large autoregressive transformer trained on text-only data. Chinchilla uses 70 billion parameters across 80 layers with hidden dimension 8192 and 64 attention heads. Keeping this massive language model frozen is crucial for computational tractability—training 70B parameters would require prohibitive memory and compute. Instead, Flamingo inserts new trainable layers that allow the frozen language model to attend to visual information without modifying its core text processing capabilities.

The key innovation is the Perceiver Resampler, a learned module that compresses the variable-length sequence of image patch embeddings into a fixed number of visual tokens that can be efficiently processed by the language model. The Perceiver Resampler uses cross-attention where a fixed set of learned queries $\mQ \in \R^{64 \times 2048}$ (64 visual tokens, dimension 2048) attends to the image patch embeddings $\mK, \mV \in \R^{256 \times 1024}$ from the vision encoder. This produces a fixed-size representation regardless of input image resolution or the number of images in the sequence. With 6 layers of cross-attention and feed-forward blocks at width 2048, the Perceiver Resampler contains roughly 300M parameters---a small module compared to the cross-attention insertions described next.
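
A minimal sketch of the learned-query resampling idea, using the dimensions from the text (this is our simplification; the actual Flamingo module adds feed-forward blocks, time embeddings, and gating):

import torch
import torch.nn as nn

class PerceiverResampler(nn.Module):
    def __init__(self, num_latents=64, dim=2048, input_dim=1024, layers=6, heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))  # learned queries
        self.proj = nn.Linear(input_dim, dim)      # lift patch features to latent width
        self.blocks = nn.ModuleList(
            nn.MultiheadAttention(dim, heads, batch_first=True) for _ in range(layers)
        )

    def forward(self, patch_feats):                # (B, 256, input_dim) from the ViT
        kv = self.proj(patch_feats)                # (B, 256, dim)
        x = self.latents.expand(patch_feats.size(0), -1, -1)
        for attn in self.blocks:
            out, _ = attn(x, kv, kv)               # latents attend to patches
            x = x + out                            # residual latent update
        return x                                   # (B, 64, dim), fixed size

tokens = PerceiverResampler()(torch.randn(2, 256, 1024))
print(tokens.shape)                                # torch.Size([2, 64, 2048])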

Between every few layers of the frozen language model, Flamingo inserts new gated cross-attention layers that allow text tokens to attend to the visual tokens produced by the Perceiver Resampler. Specifically, for Flamingo-80B (built on Chinchilla-70B), cross-attention blocks are inserted after every 7th transformer layer, about 12 insertions across the 80 layers. Each gated cross-attention-dense block adds roughly 0.8B parameters at dimension 8192 ($4d^2 \approx 268$M for the attention projections plus $\approx 537$M for the accompanying feed-forward network), totaling roughly 9.5B parameters across all insertions. Combined with the Perceiver Resampler, Flamingo adds approximately 10B trainable parameters to the 70B frozen base model---hence the name Flamingo-80B---representing about 14\% additional parameters while enabling full multimodal capabilities.

The memory requirements for Flamingo are dominated by the frozen language model. Storing 70B parameters in FP16 requires 140 GB, which exceeds the memory of any single GPU. Flamingo uses model parallelism to partition the language model across multiple GPUs---for example, distributing across 8 A100 GPUs (80GB each) places roughly 8.75B parameters per GPU, consuming about 17.5 GB for parameters. Activations for a sequence of 2048 tokens (including both text and visual tokens) across 80 layers with dimension 8192 require approximately $2048 \times 80 \times 8192 \times 2 = 2.6$ GB per example in FP16. With batch size 8, activations consume 21 GB per GPU. Gradients for the $\approx$10B trainable parameters ($\approx$20 GB in FP16) and their Adam optimizer states ($\approx$80 GB in FP32) are likewise partitioned across the GPUs alongside the frozen weights.

Training Flamingo on a mixture of image-text pairs, interleaved image-text documents, and video-text pairs requires substantial computational resources. The training dataset consists of 2.3 billion image-text pairs (similar to CLIP), 43 million interleaved image-text web pages, and 27 million video clips. Training Flamingo-80B for 1 epoch through this data with batch size 256 distributed across 256 A100 GPUs takes approximately 15 days, representing roughly 92,000 GPU-hours. The estimated training cost exceeds \$300,000 in cloud compute. However, the key advantage is that only the roughly 10B added parameters are trained while leveraging the capabilities of the frozen 70B language model, making training far more efficient than training an 80B multimodal model from scratch.

For inference, Flamingo's few-shot learning capability means that users can provide 2-32 example image-text pairs in the prompt to demonstrate a new task, and the model adapts its predictions based on these examples without any fine-tuning. This in-context learning works because the cross-attention mechanism allows the model to attend to the example images when processing the query image. The computational cost of inference scales linearly with the number of examples in the context: each additional image adds 256 patch tokens (after vision encoding) compressed to 64 visual tokens (after Perceiver Resampler), increasing the sequence length and thus the attention cost. For a prompt with 4 example images and 1 query image (5 images total), the visual tokens contribute $5 \times 64 = 320$ tokens to the sequence, which combined with text tokens (typically 500-1000) results in sequences of 800-1300 tokens. On a single A100 GPU, Flamingo-80B can process approximately 2-3 such sequences per second, limited primarily by the memory bandwidth required to load the 70B parameter model.

Computational Analysis of Multimodal Transformers

Multimodal transformers follow the same FLOPs formulas derived in Chapter~12 for their individual encoders: each transformer layer costs $24Bnd_{\text{model}}^2 + 4Bn^2d_{\text{model}}$ FLOPs (attention plus feed-forward). The multimodal-specific addition is cross-modal attention, which costs $4mnd$ FLOPs per layer (where $m$ and $n$ are the sequence lengths of the two modalities). In practice, cross-modal attention is a small fraction of total cost---for BLIP with 128 text tokens and 196 image patches, cross-attention adds only 462~MFLOPs across 6 layers, negligible compared to the self-attention costs in each encoder.

The key computational asymmetry in multimodal models is between modalities: image encoding typically dominates. CLIP's ViT-L/14 requires $\sim$55~GFLOPs per image versus $\sim$4.5~GFLOPs per text, roughly a 12$\times$ ratio. When a large language model serves as the text backbone (as in Flamingo with Chinchilla-70B), text processing dominates instead, requiring $\sim$110~TFLOPs per sequence.

Memory requirements follow the same principles as unimodal transformers (Chapter~12): parameters, gradients, optimizer states, and activations. The multimodal-specific concern is storing activations for both modalities simultaneously. For CLIP ViT-L/14, image activations consume $\sim$75~MB per image in FP16 once FFN intermediates and attention maps are included ($\sim$13~MB counting layer outputs alone), while text activations require $\sim$1.4~MB per text. For large batch sizes (32,768 in CLIP), this necessitates distributed training with gradient checkpointing and mixed precision (see Chapter~11 for distributed training techniques).

Training Challenges for Multimodal Transformers

Batch Size Requirements for Contrastive Learning

Contrastive learning methods like CLIP require very large batch sizes to provide sufficient negative examples. CLIP's performance scales log-linearly with batch size: increasing from 256 to 32,768 improves ImageNet zero-shot accuracy from $\sim$58\% to 76\%. However, the $32{,}768 \times 32{,}768$ similarity matrix alone requires 4.3~GB in FP32. To make this tractable, CLIP distributes the batch across 256 GPUs using all-gather communication, so the full similarity matrix is never materialized on any single GPU.
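
A sketch of this pattern, assuming torch.distributed is already initialized and each rank holds one shard of the batch:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def local_contrastive_loss(img_local, txt_local, logit_scale):
    """img_local, txt_local: (b, d) shard on this rank; global batch = b * world_size."""
    world, rank = dist.get_world_size(), dist.get_rank()
    # Gather all ranks' text embeddings to serve as negatives
    txt_all = [torch.zeros_like(txt_local) for _ in range(world)]
    dist.all_gather(txt_all, txt_local)
    txt_all = torch.cat(txt_all, dim=0)                 # (b * world, d)
    logits = logit_scale * img_local @ txt_all.t()      # (b, b * world): local rows only
    b = img_local.size(0)
    labels = torch.arange(b, device=img_local.device) + rank * b
    return F.cross_entropy(logits, labels)

Note that dist.all_gather does not propagate gradients into the gathered copies, so production implementations re-insert the rank's own (gradient-carrying) tensor or use a differentiable gather.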

Distributed Training and Memory Optimization

Multimodal transformers use the same distributed training techniques as unimodal models (see Chapter~11 for detailed coverage): data parallelism for CLIP-scale models that fit on a single GPU, tensor and pipeline parallelism for larger models like Flamingo-80B where the 70B parameter language model must be partitioned across multiple GPUs. Memory optimization techniques---gradient checkpointing, mixed precision training, and ZeRO optimizer state partitioning---are essential and apply identically to the multimodal setting.

The multimodal-specific challenge is the asymmetric memory profile: image activations ($\sim$75~MB per image for ViT-L) far exceed text activations ($\sim$1.4~MB per text for CLIP's encoder), so image encoding dominates the memory budget during training. For Flamingo-80B, the frozen 70B language model requires 140~GB in FP16, necessitating model parallelism across at least 2 A100 GPUs before accounting for activations or trainable parameters.

Audio Transformers

Whisper: Speech Recognition

Definition: Encoder-decoder transformer for speech:

Input: Audio waveform (resampled to 16 kHz) $\to$ Log-mel spectrogram (80 channels, 25 ms window, 10 ms hop)

Encoder: two 1D convolutions over the spectrogram (the second with stride 2, halving the sequence to 1500 positions), sinusoidal position embeddings, then standard transformer layers

Decoder: autoregressive text decoder with learned position embeddings and cross-attention to the encoder output, steered by task-specific special tokens

Training data: 680,000 hours of multilingual audio

Tasks supported: multilingual transcription, speech-to-English translation, language identification, voice activity detection, and timestamp prediction

Example: Special tokens:

<|startoftranscript|><|en|><|transcribe|><|notimestamps|>

Spectrogram: 30 seconds of 16 kHz audio $\to$ $80 \times 3000$ log-mel features

Encoder: the convolutional stem downsamples to 1500 frames, which the self-attention layers process bidirectionally

Decoder: Generate text tokens autoregressively
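
For reference, the released openai-whisper package wraps this entire pipeline; a minimal usage sketch (the audio path is illustrative):

import whisper

model = whisper.load_model("medium")   # the 24-layer, d=1024 configuration analyzed below
# transcribe() handles resampling, log-mel extraction, and autoregressive decoding
result = model.transcribe("audio.mp3", task="transcribe", language="en")
print(result["text"])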

Audio-Text Pre-training

Contrastive learning: Like CLIP but audio-text

AudioCLIP: Tri-modal (image, text, audio)

Applications: zero-shot audio classification, cross-modal retrieval (text$\to$audio, audio$\to$image), and audio captioning

Unified Multimodal Models

Perceiver and Perceiver IO

Key idea: Map arbitrary modalities to latent space via cross-attention

Definition: Components:

1. Latent array: Fixed set of learned queries $\mZ \in \R^{M \times d}$

2. Cross-attention: Latents attend to inputs

$$ \mZ_1 = \text{CrossAttn}(\mQ=\mZ, \mK=\mX, \mV=\mX) $$

3. Transformer: Process latents

$$ \mZ_L = \text{Transformer}(\mZ_1) $$

4. Output: Decode latents to task outputs

Benefits: compute for reading the input scales as $O(MN)$ rather than $O(N^2)$, so very long inputs (raw pixels, audio samples, point clouds) become tractable; the latent size $M$ fixes the cost of the deep transformer stack regardless of input length; and the same architecture applies to any modality, as the sketch below illustrates.
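
The latent bottleneck takes only a few lines in PyTorch; a toy sketch showing that attention cost scales with the latent count $M$ rather than the input length $N$:

import torch
import torch.nn as nn

M, d, N = 256, 512, 50_176             # N = 224 x 224 pixels treated as raw inputs
latents = nn.Parameter(torch.randn(1, M, d))
cross_attn = nn.MultiheadAttention(d, num_heads=8, batch_first=True)

x = torch.randn(2, N, d)               # any modality, any length
z, _ = cross_attn(latents.expand(2, -1, -1), x, x)   # cost O(M*N), not O(N^2)
print(z.shape)                          # torch.Size([2, 256, 512]) regardless of N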

GPT-4V and LLaVA

GPT-4V (Vision): GPT-4 extended with vision capabilities; it accepts interleaved images and text in the prompt, though architectural details are undisclosed

LLaVA (Open-source): connects a frozen CLIP ViT-L/14 vision encoder to a Vicuna language model through a learned projection layer, and is instruction-tuned on GPT-4-generated visual conversation data

Exercises

Exercise 1: Implement CLIP contrastive loss for batch size 8:
  1. Generate random image embeddings $(8, 512)$
  2. Generate random text embeddings $(8, 512)$
  3. Compute $8 \times 8$ similarity matrix
  4. Calculate contrastive loss with $\tau = 0.07$
Exercise 2: Use CLIP for zero-shot classification on CIFAR-10:
  1. Load pre-trained CLIP model
  2. Create text prompts for 10 classes
  3. Encode images and prompts
  4. Compute accuracy
  5. Compare to supervised baseline
Exercise 3: Analyze Whisper architecture:
  1. Calculate parameters for encoder (24 layers, $d=1024$)
  2. Calculate parameters for decoder (24 layers)
  3. Estimate memory for 30-second audio
  4. Compare to text-only GPT-2
Exercise 4: Design multimodal fusion strategy for video understanding (visual + audio + captions):
  1. Propose architecture
  2. Define fusion mechanism
  3. Specify training objective
  4. Estimate parameter count

Solutions

Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.

Solution: Exercise 1: CLIP Contrastive Loss Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

def clip_contrastive_loss(image_embeddings, text_embeddings, temperature=0.07):
    """
    Compute CLIP contrastive loss
    Args:
        image_embeddings: (B, D) normalized image embeddings
        text_embeddings: (B, D) normalized text embeddings
        temperature: temperature parameter tau
    Returns:
        loss: scalar contrastive loss
    """
    # Normalize embeddings
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)
    
    # Compute similarity matrix (B, B)
    logits = torch.matmul(image_embeddings, text_embeddings.t()) / temperature
    
    # Labels: diagonal elements are positive pairs
    batch_size = image_embeddings.shape[0]
    labels = torch.arange(batch_size, device=image_embeddings.device)
    
    # Symmetric loss: image-to-text + text-to-image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    
    loss = (loss_i2t + loss_t2i) / 2
    
    return loss, logits

# Part (a): Generate random embeddings
batch_size = 8
embed_dim = 512

image_embeddings = torch.randn(batch_size, embed_dim)
text_embeddings = torch.randn(batch_size, embed_dim)

print(f"Image embeddings shape: {image_embeddings.shape}")
print(f"Text embeddings shape: {text_embeddings.shape}")

# Part (b): Normalize embeddings
image_embeddings = F.normalize(image_embeddings, dim=-1)
text_embeddings = F.normalize(text_embeddings, dim=-1)

print(f"\nAfter normalization:")
print(f"Image embedding norms: {torch.norm(image_embeddings, dim=-1)}")
print(f"Text embedding norms: {torch.norm(text_embeddings, dim=-1)}")

# Part (c): Compute similarity matrix
temperature = 0.07
similarity_matrix = torch.matmul(image_embeddings, text_embeddings.t()) / temperature

print(f"\nSimilarity matrix shape: {similarity_matrix.shape}")
print(f"Similarity matrix:\n{similarity_matrix}")

# Part (d): Calculate contrastive loss
loss, logits = clip_contrastive_loss(image_embeddings, text_embeddings, temperature)

print(f"\nContrastive loss: {loss.item():.4f}")
print(f"Logits shape: {logits.shape}")

# Analyze the loss
labels = torch.arange(batch_size)
predictions_i2t = logits.argmax(dim=1)
predictions_t2i = logits.t().argmax(dim=1)

accuracy_i2t = (predictions_i2t == labels).float().mean()
accuracy_t2i = (predictions_t2i == labels).float().mean()

print(f"\nImage-to-Text accuracy: {accuracy_i2t.item():.2
print(f"Text-to-Image accuracy: {accuracy_t2i.item():.2

Mathematical Derivation:

Part (a) \& (b): Embeddings

Image embeddings: $\vI = [\vi_1, \vi_2, \ldots, \vi_8] \in \mathbb{R}^{8 \times 512}$

Text embeddings: $\vT = [\vt_1, \vt_2, \ldots, \vt_8] \in \mathbb{R}^{8 \times 512}$

Normalize to unit sphere: $\hat{\vi}_i = \frac{\vi_i}{\|\vi_i\|_2}, \quad \hat{\vt}_i = \frac{\vt_i}{\|\vt_i\|_2}$

Part (c): Similarity Matrix

Cosine similarity matrix: $\vS_{ij} = \frac{\hat{\vi}_i \cdot \hat{\vt}_j}{\tau}$

where $\tau = 0.07$ is the temperature parameter.

Full matrix: $\vS = \frac{1}{\tau} \hat{\vI} \hat{\vT}^T \in \mathbb{R}^{8 \times 8}$

Example: \[ \vS = \begin{bmatrix} s_{11} & s_{12} & \cdots & s_{18} \\ s_{21} & s_{22} & \cdots & s_{28} \\ \vdots & \vdots & \ddots & \vdots \\ s_{81} & s_{82} & \cdots & s_{88} \end{bmatrix} \]

Diagonal elements $s_{ii}$ are positive pairs (matched image-text).

Off-diagonal elements $s_{ij}$ ($i \neq j$) are negative pairs.

Part (d): Contrastive Loss

Image-to-Text Loss:

For each image $i$, predict its matching text from 8 candidates:

$\mathcal{L}_{i2t} = -\frac{1}{8} \sum_{i=1}^{8} \log \frac{\exp(s_{ii})}{\sum_{j=1}^{8} \exp(s_{ij})}$

This is cross-entropy with labels $y_i = i$ (diagonal).

Text-to-Image Loss:

For each text $j$, predict its matching image from 8 candidates:

$\mathcal{L}_{t2i} = -\frac{1}{8} \sum_{j=1}^{8} \log \frac{\exp(s_{jj})}{\sum_{i=1}^{8} \exp(s_{ij})}$

Total CLIP Loss:

$\mathcal{L}_{\text{CLIP}} = \frac{1}{2}(\mathcal{L}_{i2t} + \mathcal{L}_{t2i})$

Symmetric loss ensures both modalities learn aligned representations.

Why Temperature $\tau = 0.07$?

Effect of temperature: dividing similarities by $\tau$ sharpens the softmax. A small $\tau$ (e.g., 0.07) amplifies the gap between the positive pair and hard negatives, producing strong gradients; a large $\tau$ flattens the distribution and weakens the training signal. CLIP learns $\tau$ jointly, parameterized in log space and clamped to keep training stable.

Numerical Example:

Suppose for image 1 the matched caption has similarity $s_{11} = 0.9$ while the mismatched captions have lower similarities $s_{1j}$.

Softmax probabilities: $p_1 = \frac{\exp(0.9/0.07)}{\exp(0.9/0.07) + \sum_{j=2}^{8} \exp(s_{1j}/0.07)}$

Loss for image 1: $\ell_1 = -\log p_1$

If $p_1 \approx 1$, then $\ell_1 \approx 0$ (good alignment).

If $p_1 \approx 0.125$ (uniform), then $\ell_1 \approx 2.08$ (poor alignment).

Training Dynamics:

  1. Initial: Random embeddings, $\mathcal{L} \approx \log(8) = 2.08$
  2. Training: Embeddings align, diagonal elements increase
  3. Converged: $s_{ii} \gg s_{ij}$ for $i \neq j$, $\mathcal{L} \to 0$

Key Insights:

  1. The loss is symmetric, so both encoders receive gradients and learn a shared embedding space.
  2. Random embeddings start near the chance-level loss $\log N \approx 2.08$; training drives diagonal similarities up and off-diagonal similarities down.
  3. Larger batches supply more negatives per positive, which is why CLIP trains with batch sizes in the tens of thousands.

Solution: Exercise 2: CLIP Zero-Shot Classification on CIFAR-10
import torch
import clip
from PIL import Image
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Part (a): Load pre-trained CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

print(f"CLIP model loaded on {device}")
print(f"Model: ViT-B/32")

# Part (b): Create text prompts for 10 CIFAR-10 classes
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Template-based prompts (improves accuracy)
templates = [
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a photo of many {}.",
    "a photo of the small {}.",
    "a photo of the large {}.",
]

# Encode text prompts
def encode_text_prompts(model, classes, templates):
    """Encode text prompts with multiple templates"""
    text_features = []
    
    for classname in classes:
        # Create prompts from templates
        texts = [template.format(classname) for template in templates]
        texts = clip.tokenize(texts).to(device)
        
        # Encode texts
        with torch.no_grad():
            class_embeddings = model.encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            
            # Average over templates
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            
            text_features.append(class_embedding)
    
    text_features = torch.stack(text_features, dim=0)
    return text_features

text_features = encode_text_prompts(model, cifar10_classes, templates)
print(f"\nText features shape: {text_features.shape}")  # (10, 512)

# Part (c): Load CIFAR-10 test set
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False, 
    download=True,
    transform=preprocess
)

test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

# Zero-shot classification
def zero_shot_classify(model, loader, text_features):
    """Perform zero-shot classification"""
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images = images.to(device)
            labels = labels.to(device)
            
            # Encode images
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            
            # Compute similarity with text features
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            
            # Predict
            predictions = similarity.argmax(dim=-1)
            
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
    accuracy = 100.0 * correct / total
    return accuracy

# Part (d): Compute accuracy
zero_shot_accuracy = zero_shot_classify(model, test_loader, text_features)
print(f"\nZero-shot accuracy: {zero_shot_accuracy:.2f}



# Part (e): Compare to supervised baseline
# Train a simple supervised classifier
class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(64 * 8 * 8, 128)
        self.fc2 = torch.nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Train supervised model (simplified)
supervised_model = SimpleCNN().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(supervised_model.parameters(), lr=0.001)

# Training loop (10 epochs for quick comparison)
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

for epoch in range(10):
    supervised_model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = supervised_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Evaluate supervised model on 32x32 inputs (test_loader uses CLIP's 224x224
# preprocess, which would break the CNN's fully-connected layer)
cnn_test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)
cnn_test_loader = DataLoader(cnn_test_dataset, batch_size=100, shuffle=False)

supervised_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in cnn_test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = supervised_model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

supervised_accuracy = 100.0 * correct / total

print(f"\nComparison:")
print(f"CLIP Zero-shot: {zero_shot_accuracy:.2f}
print(f"Supervised CNN (10 epochs): {supervised_accuracy:.2f}

Expected Results:

Method | Accuracy | Training Data
CLIP Zero-shot (ViT-B/32) | 89-91\% | 0 (CIFAR-10)
CLIP Zero-shot (ViT-L/14) | 93-95\% | 0 (CIFAR-10)
Supervised CNN (10 epochs) | 70-75\% | 50k (CIFAR-10)
Supervised ResNet-50 (200 epochs) | 95-96\% | 50k (CIFAR-10)

Analysis:

Part (a): Pre-trained CLIP Model

CLIP models available: RN50, RN101, RN50x4, RN50x16, RN50x64, ViT-B/32, ViT-B/16, ViT-L/14, and ViT-L/14@336px (listed by clip.available_models()).

Pre-training: 400M web image-text pairs with the contrastive objective; no CIFAR-10 images or labels were used.

Part (b): Text Prompts

Simple prompts:

"airplane", "automobile", "bird", ...
Template-based prompts (better):

"a photo of a airplane."
"a blurry photo of a airplane."
"a photo of many airplanes."

Why templates help: they match the caption-like phrasing CLIP saw during pre-training (a web caption reads "a photo of a dog", not the bare word "dog"), and averaging embeddings over several templates reduces sensitivity to any single phrasing.

Prompt engineering tips: add domain context ("a satellite photo of ...", "..., a type of bird"), avoid bare class names, and ensemble multiple templates as in the code above.

Part (c): Encoding Process

Image encoding:

  1. Preprocess: Resize to $224 \times 224$, normalize
  2. ViT encoder: Extract features
  3. Projection: Map to shared embedding space (512-dim)
  4. Normalize: $\hat{\vi} = \vi / \|\vi\|_2$

Text encoding:

  1. Tokenize: Convert text to token IDs
  2. Text encoder: Transformer processes tokens
  3. Projection: Map to shared embedding space (512-dim)
  4. Normalize: $\hat{\vt} = \vt / \|\vt\|_2$

Part (d): Zero-Shot Classification

Algorithm:

For each test image $\vx$:

  1. Encode image: $\vi = \text{ImageEncoder}(\vx)$
  2. Compute similarity with all class embeddings: $s_k = \vi \cdot \vt_k$ for $k = 1, \ldots, 10$
  3. Apply softmax: $p_k = \frac{\exp(s_k / \tau)}{\sum_{j=1}^{10} \exp(s_j / \tau)}$
  4. Predict: $\hat{y} = \argmax_k p_k$

Temperature $\tau = 0.01$ (learned during training).

Mathematical Formulation:

$P(y = k | \vx) = \frac{\exp(\text{sim}(\vi, \vt_k) / \tau)}{\sum_{j=1}^{10} \exp(\text{sim}(\vi, \vt_j) / \tau)}$

where $\text{sim}(\vi, \vt) = \vi \cdot \vt$ (cosine similarity after normalization).

Part (e): Comparison with Supervised Baseline

Why CLIP Zero-Shot Outperforms Supervised CNN:

  1. Pre-training scale: 400M image-text pairs vs 50k CIFAR-10 images
  2. Transfer learning: Leverages knowledge from diverse data
  3. Better architecture: ViT-B/32 vs simple CNN
  4. Semantic understanding: Learns concepts, not just patterns
  5. Robustness: Generalizes better to distribution shifts

When Supervised Wins: with abundant labeled in-domain data, a fixed class set, or domains far from web imagery (medical scans, industrial defects), a well-trained supervised model surpasses zero-shot CLIP, as the ResNet-50 row above shows.

CLIP Advantages:

  1. No training required: Instant deployment
  2. Flexible: Change classes without retraining
  3. Interpretable: Natural language descriptions
  4. Robust: Handles distribution shifts better
  5. Multimodal: Can do image-text retrieval, captioning, etc.

Practical Recommendations:

Scenario | Recommendation
Quick prototype | CLIP zero-shot
Fixed classes, lots of data | Supervised training
Changing classes frequently | CLIP zero-shot
Maximum accuracy | Fine-tune CLIP
Limited compute | Supervised small model
Interpretability needed | CLIP with prompts

Improving CLIP Zero-Shot:

  1. Better prompts: Domain-specific templates
  2. Larger model: ViT-L/14 instead of ViT-B/32
  3. Ensemble: Average predictions from multiple prompts
  4. Few-shot: Add a few examples with linear probe
  5. Fine-tuning: Adapt to target domain

Key Takeaways:

  1. Zero-shot CLIP beats a small supervised CNN on CIFAR-10 despite seeing no CIFAR-10 labels, because its 400M-pair pre-training transfers.
  2. Prompt quality and ensembling matter as much as choosing a larger backbone.
  3. Strong supervised models still win when abundant in-domain labels are available.

Solution: Exercise 3: Whisper Architecture Analysis

Part (a): Encoder Parameters (24 layers, $d=1024$)

Whisper Encoder Configuration: 24 layers, $d = 1024$, 16 heads, FFN dimension 4096

Parameter Breakdown:

1. Input Convolution Layers: two Conv1D layers ($80 \to 1024$ and $1024 \to 1024$, kernel size 3), $\approx 3.4$M parameters

2. Position Embeddings: fixed sinusoidal, no learned parameters

3. Per Transformer Layer:

Multi-Head Attention: Q, K, V, and output projections, $4d^2 = 4{,}194{,}304$ (biases omitted)

MLP: $1024 \times 4096 + 4096 \times 1024 = 8{,}388{,}608$

Layer Normalization: $2 \times 2 \times 1024 = 4{,}096$

Total per layer: $12{,}587{,}008$ parameters

4. All 24 Encoder Layers: $24 \times 12{,}587{,}008 = 302{,}088{,}192$ parameters

Total Encoder: $\approx 305.5$M parameters

Part (b): Decoder Parameters (24 layers)

Whisper Decoder Configuration: 24 layers, $d = 1024$, 16 heads, vocabulary size 51,865, maximum 448 text tokens

Parameter Breakdown:

1. Token Embedding: $51{,}865 \times 1024 = 53{,}109{,}760$

2. Position Embeddings: learned, $448 \times 1024 = 458{,}752$

3. Per Decoder Layer:

Masked Self-Attention: $4d^2 = 4{,}194{,}304$

Cross-Attention: $4d^2 = 4{,}194{,}304$

MLP: $8{,}388{,}608$

Layer Normalization: $3 \times 2 \times 1024 = 6{,}144$

Total per decoder layer: $16{,}783{,}360$ parameters

4. All 24 Decoder Layers: $24 \times 16{,}783{,}360 = 402{,}800{,}640$ parameters

5. Output Projection: weight-tied to the token embedding, no additional parameters

Total Decoder: $\approx 456.4$M parameters

Total Whisper Model: $305.5 + 456.4 = 761.9$M parameters

(The 24-layer, $d = 1024$ configuration computed here corresponds to Whisper-medium, $\approx 769$M parameters; the actual Whisper-large uses 32 layers with $d = 1280$, reaching $\approx 1.55$B.)

Part (c): Memory for 30-Second Audio

Input Processing:

1. Audio Preprocessing: 30 s $\times$ 16 kHz $= 480{,}000$ samples

2. Log-Mel Spectrogram: 80 mel channels $\times$ 3000 frames (10 ms hop), $\approx 0.96$ MB in FP32

Encoder Memory (Inference):

1. Activations per layer: attention score matrices dominate at $16\ \text{heads} \times 3000^2 = 144$M values (using the full 3000-frame length), plus $\approx 15$M values of hidden states

Peak per layer: $\approx 159$M values $\times$ 4 bytes = 636 MB

2. Total encoder activations: $24 \times 636$ MB = 15.3 GB (if storing all layers)

With activation checkpointing: $\approx 1.3$ GB

Decoder Memory (Inference):

For generating 448 tokens: self-attention KV caches, cross-attention KV caches over the encoder output, and output logits over the 51,865-token vocabulary

Decoder memory: $\approx 500$ MB

Total Memory (Inference): parameters ($762\text{M} \times 4 \approx 3.0$ GB) plus checkpointed encoder activations and decoder caches, roughly 8 GB in FP32

For FP16: $\approx 4.1$ GB

For INT8 quantization: $\approx 2.1$ GB

Part (d): Compare to Text-Only GPT-2

GPT-2 (1.5B parameters): decoder-only transformer with 48 layers, $d = 1600$, 25 attention heads, and a 1024-token context

Comparison Table:

Metric | Whisper-large | GPT-2 (1.5B)
Total Parameters | 1.55B | 1.5B
Encoder Layers | 32 | N/A
Decoder Layers | 32 | 48
Hidden Size | 1280 | 1600
Attention Heads | 20 | 25
Input Modality | Audio | Text
Output Modality | Text | Text
Context Length | 1500 (audio frames) + 448 (text) | 1024 (text)
Memory (FP32) | 8.1 GB | 6.5 GB
Inference Speed | Slower (audio encoding) | Faster

Key Differences:

  1. Architecture:
    • Whisper: Encoder-decoder (like T5)
    • GPT-2: Decoder-only
  2. Input Processing:
    • Whisper: Audio $\to$ Log-mel $\to$ Encoder
    • GPT-2: Text $\to$ Tokens $\to$ Decoder
  3. Computational Cost:
    • Whisper encoder: $O(T^2 d)$ where $T = 3000$
    • GPT-2: $O(n^2 d)$ where $n = 1024$
    • Whisper is $\approx 9\times$ more expensive for encoder
  4. Memory Footprint:
    • Whisper: Larger due to long audio sequences
    • GPT-2: Smaller, text-only
  5. Use Cases:
    • Whisper: Speech recognition, translation, transcription
    • GPT-2: Text generation, completion, summarization

Why Whisper Needs Encoder-Decoder: the input (a long spectrogram) and output (text tokens) are different modalities with very different lengths; the encoder builds a bidirectional representation of the entire utterance once, and the decoder attends to it via cross-attention while generating, which a single decoder-only stack over raw audio frames would make far more expensive.

Performance Comparison:

Task | Whisper | GPT-2
Speech Recognition | Excellent | N/A
Text Generation | N/A | Excellent
Multilingual | 99 languages | Limited
Robustness | High (noisy audio) | N/A
Zero-shot | Strong | Strong

Practical Considerations:

  1. Deployment:
    • Whisper: Requires audio preprocessing
    • GPT-2: Simple tokenization
  2. Latency:
    • Whisper: Higher (audio encoding + decoding)
    • GPT-2: Lower (text-only)
  3. Hardware:
    • Whisper: Needs GPU for real-time (8+ GB VRAM)
    • GPT-2: Can run on CPU for small batches

Key Insights:

  1. Whisper spends most of its compute in the encoder because the audio sequence is far longer than typical text contexts.
  2. Parameter count alone is misleading: Whisper and GPT-2 are similar in size but differ sharply in memory and latency due to sequence lengths.
  3. Encoder-decoder is the natural fit when input and output modalities differ.

Solution: Exercise 4: Multimodal Fusion for Video Understanding

Part (a): Proposed Architecture

import torch
import torch.nn as nn

class MultimodalVideoTransformer(nn.Module):
    def __init__(self, 
                 visual_dim=768,      # ViT features
                 audio_dim=512,       # Audio features
                 text_dim=768,        # BERT features
                 hidden_dim=1024,     # Fusion dimension
                 num_layers=12,       # Fusion transformer layers
                 num_heads=16,
                 num_classes=400):    # Action recognition classes
        super().__init__()
        
        # Modality-specific encoders
        self.visual_encoder = VisualEncoder(visual_dim, hidden_dim)
        self.audio_encoder = AudioEncoder(audio_dim, hidden_dim)
        self.text_encoder = TextEncoder(text_dim, hidden_dim)
        
        # Modality-specific tokens
        self.visual_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.audio_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.text_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        
        # Fusion transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True
        )
        self.fusion_transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_classes)
        )
    
    def forward(self, visual_features, audio_features, text_features):
        """
        Args:
            visual_features: (B, T_v, D_v) - video frames
            audio_features: (B, T_a, D_a) - audio segments
            text_features: (B, T_t, D_t) - caption tokens
        Returns:
            logits: (B, num_classes)
        """
        B = visual_features.shape[0]
        
        # Encode each modality
        visual_emb = self.visual_encoder(visual_features)  # (B, T_v, H)
        audio_emb = self.audio_encoder(audio_features)     # (B, T_a, H)
        text_emb = self.text_encoder(text_features)        # (B, T_t, H)
        
        # Add modality tokens
        visual_token = self.visual_token.expand(B, -1, -1)
        audio_token = self.audio_token.expand(B, -1, -1)
        text_token = self.text_token.expand(B, -1, -1)
        
        visual_emb = torch.cat([visual_token, visual_emb], dim=1)
        audio_emb = torch.cat([audio_token, audio_emb], dim=1)
        text_emb = torch.cat([text_token, text_emb], dim=1)
        
        # Concatenate all modalities
        multimodal_emb = torch.cat([visual_emb, audio_emb, text_emb], dim=1)
        # Shape: (B, 1+T_v + 1+T_a + 1+T_t, H)
        
        # Fusion transformer
        fused = self.fusion_transformer(multimodal_emb)
        
        # Aggregate: average modality tokens
        visual_rep = fused[:, 0, :]
        audio_rep = fused[:, 1+visual_features.shape[1], :]
        text_rep = fused[:, 1+visual_features.shape[1]+1+audio_features.shape[1], :]
        
        # Combine representations
        combined = (visual_rep + audio_rep + text_rep) / 3
        
        # Classification
        logits = self.classifier(combined)
        
        return logits

class VisualEncoder(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
    
    def forward(self, x):
        return self.norm(self.proj(x))

class AudioEncoder(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
    
    def forward(self, x):
        return self.norm(self.proj(x))

class TextEncoder(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
    
    def forward(self, x):
        return self.norm(self.proj(x))

# Example usage
model = MultimodalVideoTransformer()

# Simulate inputs
batch_size = 4
visual = torch.randn(batch_size, 16, 768)   # 16 frames
audio = torch.randn(batch_size, 32, 512)    # 32 audio segments
text = torch.randn(batch_size, 20, 768)     # 20 caption tokens

logits = model(visual, audio, text)
print(f"Output shape: {logits.shape}")  # (4, 400)

Part (b): Fusion Mechanism

Architecture Overview:


Input:
  Visual: (B, 16, 768)  - 16 video frames from ViT
  Audio:  (B, 32, 512)  - 32 audio segments from audio encoder
  Text:   (B, 20, 768)  - 20 caption tokens from BERT

Step 1: Modality-Specific Projection
  Visual -> (B, 16, 1024)
  Audio  -> (B, 32, 1024)
  Text   -> (B, 20, 1024)

Step 2: Add Modality Tokens
  Visual: [V_token, v1, v2, ..., v16]  -> (B, 17, 1024)
  Audio:  [A_token, a1, a2, ..., a32]  -> (B, 33, 1024)
  Text:   [T_token, t1, t2, ..., t20]  -> (B, 21, 1024)

Step 3: Concatenate
  Multimodal: [V_token, v1, ..., v16, A_token, a1, ..., a32, T_token, t1, ..., t20]
  Shape: (B, 71, 1024)

Step 4: Fusion Transformer (12 layers)
  Cross-modal attention enables interaction
  Output: (B, 71, 1024)

Step 5: Aggregate
  Extract modality tokens: V_token, A_token, T_token
  Average: (V_token + A_token + T_token) / 3
  Shape: (B, 1024)

Step 6: Classification
  MLP: (B, 1024) -> (B, 400)

Fusion Strategies Comparison:

  1. Early Fusion (Concatenation):
    • Concatenate features before transformer
    • Simple but limited cross-modal interaction
    • Used in this design
  2. Late Fusion (Ensemble):
    • Process modalities separately
    • Combine predictions at the end
    • No cross-modal learning
  3. Cross-Modal Attention:
    • Visual attends to audio and text
    • Audio attends to visual and text
    • More complex but better interaction
  4. Bottleneck Fusion:
    • Compress each modality to bottleneck tokens
    • Fuse bottlenecks
    • More efficient for long sequences

Why This Design: concatenation with per-modality summary tokens keeps the architecture simple; full self-attention over only 71 tokens is cheap; and the summary tokens give each modality a natural aggregation point for the classifier.

Part (c): Training Objective

Primary Objective: Action Recognition

$\mathcal{L}_{\text{action}} = -\frac{1}{B} \sum_{i=1}^{B} \log P(y_i | \vv_i, \va_i, \vt_i)$

where $y_i$ is the action label and $\vv_i, \va_i, \vt_i$ are the visual, audio, and text inputs for example $i$, with batch size $B$.

Auxiliary Objectives (Multi-Task Learning):

1. Contrastive Loss (Cross-Modal Alignment):

Align visual-audio, visual-text, audio-text pairs:

$\mathcal{L}_{\text{contrast}} = \mathcal{L}_{\text{VA}} + \mathcal{L}_{\text{VT}} + \mathcal{L}_{\text{AT}}$

where each term is CLIP-style contrastive loss:

$\mathcal{L}_{\text{VA}} = -\frac{1}{B} \sum_{i=1}^{B} \log \frac{\exp(\text{sim}(\vv_i, \va_i) / \tau)}{\sum_{j=1}^{B} \exp(\text{sim}(\vv_i, \va_j) / \tau)}$

2. Masked Modality Modeling:

Randomly mask one modality and predict it from others:

$\mathcal{L}_{\text{mask}} = \mathcal{L}_{\text{mask-V}} + \mathcal{L}_{\text{mask-A}} + \mathcal{L}_{\text{mask-T}}$

Example (mask visual): $\mathcal{L}_{\text{mask-V}} = \|\hat{\vv} - \vv\|_2^2$

where $\hat{\vv} = f(\va, \vt)$ is predicted visual features.

3. Temporal Ordering:

Predict correct temporal order of video segments:

$\mathcal{L}_{\text{temporal}} = -\log P(\text{order} | \vv, \va, \vt)$

Total Training Objective:

$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{action}} + \lambda_1 \mathcal{L}_{\text{contrast}} + \lambda_2 \mathcal{L}_{\text{mask}} + \lambda_3 \mathcal{L}_{\text{temporal}}$

Typical weights: $\lambda_1 = 0.1$, $\lambda_2 = 0.05$, $\lambda_3 = 0.05$

Training Recipe:

import torch
import torch.nn.functional as F

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)

# Learning rate schedule
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Training loop
for epoch in range(100):
    for batch in dataloader:
        visual, audio, text, labels = batch
        
        # Forward pass
        logits = model(visual, audio, text)
        
        # Action recognition loss
        loss_action = F.cross_entropy(logits, labels)
        
        # Contrastive loss (optional; assumes the model exposes per-modality
        # summary representations and a CLIP-style contrastive_loss helper)
        visual_rep = model.get_visual_rep(visual)
        audio_rep = model.get_audio_rep(audio)
        loss_contrast = contrastive_loss(visual_rep, audio_rep)
        
        # Total loss
        loss = loss_action + 0.1 * loss_contrast
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    scheduler.step()

Data Augmentation: random temporal crops and frame dropout for video, SpecAugment-style masking for audio, token masking for captions, and modality dropout (randomly hiding one modality) for robustness.

Part (d): Parameter Count Estimation

Component Breakdown:

1. Modality-Specific Encoders:

Visual Encoder: projection $768 \times 1024 = 786{,}432$ plus LayerNorm $2 \times 1024 = 2{,}048$, giving $788{,}480$ (biases omitted)

Audio Encoder: $512 \times 1024 + 2{,}048 = 526{,}336$

Text Encoder: $768 \times 1024 + 2{,}048 = 788{,}480$

Encoder total: $2{,}103{,}296$ parameters

2. Modality Tokens: 3 learned tokens $\times\ 1024 = 3{,}072$ parameters

3. Fusion Transformer (12 layers):

Per layer: attention $4 \times 1024^2 = 4{,}194{,}304$, FFN $2 \times 1024 \times 4096 = 8{,}388{,}608$, LayerNorms $4{,}096$, giving $12{,}587{,}008$

12 layers: $12 \times 12{,}587{,}008 = 151{,}044{,}096$ parameters

4. Classification Head: LayerNorm $2{,}048$ + Linear $1024 \times 1024 = 1{,}048{,}576$ + Linear $1024 \times 400 = 409{,}600$, giving $1{,}460{,}224$ parameters

Total Model Parameters:

$2{,}103{,}296 + 3{,}072 + 151{,}044{,}096 + 1{,}460{,}224 = 154{,}610{,}688$

Total: $\approx 155$M parameters
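
The hand count can be checked directly against the implementation above (the exact total differs slightly because nn.TransformerEncoderLayer includes bias terms that the hand count omits):

model = MultimodalVideoTransformer()
total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total / 1e6:.1f}M")     # close to the ~155M hand count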

Memory Footprint (FP32): parameters $155\text{M} \times 4 = 620$ MB, gradients another 620 MB, and Adam optimizer states $\approx 1.24$ GB, plus activations

Total training memory: $\approx 2.6$ GB

Comparison with Baselines:

Model | Parameters | Modalities
Single-modal (visual only) | 86M | 1
Two-modal (visual + audio) | 120M | 2
Our three-modal | 155M | 3
CLIP (ViT-B/32) | 151M | 2
Whisper-large | 1.55B | 2

Design Trade-offs:

  1. Parameter efficiency:
    • Shared fusion transformer reduces parameters
    • Modality-specific encoders are lightweight
    • Could use pre-trained encoders (ViT, BERT, etc.)
  2. Computational cost:
    • Sequence length: 71 tokens (manageable)
    • Attention complexity: $O(71^2 \times 1024) \approx 5$M operations
    • Inference time: $\approx 50$ ms on GPU
  3. Scalability:
    • Easy to add more modalities (depth, optical flow, etc.)
    • Can increase fusion layers for better interaction
    • Bottleneck fusion for longer sequences

Practical Recommendations:

  1. Use pre-trained encoders: ViT for visual, Wav2Vec for audio, BERT for text
  2. Freeze encoders initially: Train fusion transformer first
  3. Fine-tune end-to-end: Unfreeze all parameters later
  4. Modality dropout: Randomly drop modalities during training for robustness
  5. Temporal modeling: Add temporal attention for video sequences

Key Insights:

  1. Early fusion with modality tokens is a strong, simple baseline when the total sequence is short (71 tokens here).
  2. Most parameters live in the shared fusion transformer, so adding a modality costs little.
  3. Auxiliary contrastive and masking objectives align modalities and improve robustness to missing inputs.
