Multimodal Transformers
Chapter Overview
Multimodal transformers process multiple modalities (text, images, audio, video) in a unified framework. This chapter covers vision-language models (CLIP, DALL-E), audio-text models (Whisper), and unified architectures that handle arbitrary combinations of modalities.
Learning Objectives
- Understand multimodal fusion strategies
- Implement contrastive learning (CLIP)
- Apply vision-language models to zero-shot classification
- Generate images from text (DALL-E, Stable Diffusion)
- Process audio with transformers (Whisper)
- Build unified multimodal models
Multimodal Learning Fundamentals
Fusion Strategies
The choice of fusion strategy determines how modalities interact, with direct implications for computational cost and expressiveness. Three primary approaches have emerged:
[Figure: the three fusion strategies side by side. Early fusion concatenates vision tokens $v_1, v_2$ and text tokens $t_1, t_2$ into one sequence for a unified encoder; late fusion routes each modality through its own encoder and merges the outputs in a final fusion block; cross-modal fusion keeps separate vision and text encoders linked by bidirectional cross-attention.]
| Strategy | Description | Pros | Cons |
|---|---|---|---|
| Early fusion | Concatenate modality tokens into one sequence; process with unified encoder | Rich cross-modal interaction at every layer; simple architecture | $O((N{+}M)^2 d)$ cost; adding patches dramatically increases compute |
| Late fusion | Separate encoders per modality; combine outputs at decision stage (CLIP) | Efficient $O(N^2 d {+} M^2 d)$; encoders parallelizable | No fine-grained cross-modal alignment; interaction only at output |
| Cross-modal attention | Separate encoders with cross-attention layers between modalities (BLIP, Flamingo) | $O(N^2 d {+} M^2 d {+} NMd)$; rich interactions with moderate cost | Additional parameters; more complex architecture |
Cross-modal attention offers the best trade-off for most applications: for 196 image patches and 128 text tokens, cross-attention requires $196 \times 128 = 25{,}088$ computations per head versus $324^2 = 104{,}976$ for early fusion---a 4$\times$ reduction while preserving fine-grained alignment between modalities.
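These interaction counts can be reproduced with a few lines of arithmetic. The helper names below are illustrative, counting only the pairwise attention interactions per head (the $d$ factor is common to all three strategies and omitted):

```python
# Pairwise attention interactions per head for each fusion strategy.
# Illustrative helpers, not from any library.

def early_fusion_cost(n_image, n_text):
    """Unified encoder: self-attention over the concatenated sequence."""
    return (n_image + n_text) ** 2

def late_fusion_cost(n_image, n_text):
    """Separate encoders: self-attention per modality, no cross terms."""
    return n_image ** 2 + n_text ** 2

def cross_modal_cost(n_image, n_text):
    """Separate encoders plus one cross-attention term."""
    return n_image ** 2 + n_text ** 2 + n_image * n_text

n_img, n_txt = 196, 128  # ViT patches and text tokens from the example above
print(early_fusion_cost(n_img, n_txt))   # 324^2 = 104976
print(cross_modal_cost(n_img, n_txt) - late_fusion_cost(n_img, n_txt))  # 25088
```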
Alignment Objectives
Contrastive Learning: pull matched pairs together and push mismatched pairs apart in a shared embedding space; this is the InfoNCE-style objective used by CLIP.
Matching Loss: a binary classifier predicts whether a given image-text pair is matched, as in BLIP's image-text matching (ITM) head.
Reconstruction: predict masked or missing content in one modality from the other, as in masked modeling and captioning objectives.
CLIP: Contrastive Language-Image Pre-training
CLIP Architecture
CLIP (Contrastive Language-Image Pre-training) represents a breakthrough in vision-language learning by training image and text encoders jointly using a contrastive objective on 400 million image-text pairs collected from the internet. Unlike traditional supervised learning that requires manually labeled categories, CLIP learns to align images with their natural language descriptions, enabling zero-shot transfer to downstream tasks without any task-specific training data.
The training procedure processes batches of $(image, text)$ pairs simultaneously. For each batch of size $N$, all $N$ images are encoded to produce embeddings $\vv_1, \ldots, \vv_N \in \R^{512}$, and all $N$ text descriptions are encoded to produce $\vt_1, \ldots, \vt_N \in \R^{512}$. The model then computes an $N \times N$ similarity matrix where entry $(i,j)$ represents the cosine similarity between image $i$ and text $j$. The contrastive loss maximizes the similarity along the diagonal (correct image-text pairs) while minimizing off-diagonal similarities (incorrect pairings). This symmetric loss is computed in both directions---predicting text from image and image from text---and averaged.
The parameter count for CLIP varies significantly across model scales. CLIP ResNet-50 contains approximately 102 million parameters (38M for ResNet-50 image encoder, 63M for text encoder, 1M for projections), while CLIP ViT-L/14 totals around 428 million parameters (304M for ViT-L image encoder, 123M for a larger text encoder with 768 dimensions and 12 layers, 1M for projections). The largest variant, ViT-L/14@336px, processes higher-resolution images ($336 \times 336$ instead of $224 \times 224$) with the same architecture, increasing computational cost but improving performance on fine-grained tasks.
The similarity matrix is computed as $\mS = \mV \mT\transpose \in \R^{N \times N}$, where each entry $S_{ij}$ represents the dot product between image embedding $i$ and text embedding $j$. To make this scale-invariant, CLIP uses cosine similarity: $S_{ij} = \frac{\vv_i \cdot \vt_j}{\|\vv_i\| \|\vt_j\|}$, which normalizes each embedding to unit length before computing the dot product. This ensures that the similarity is determined by the angle between embeddings rather than their magnitudes.
The contrastive loss for the image-to-text direction is computed as:
\[ \mathcal{L}_{i2t} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(S_{ii}/\tau)}{\sum_{j=1}^{N} \exp(S_{ij}/\tau)} \]
where $\tau$ is a learned temperature parameter. The text-to-image loss $\mathcal{L}_{t2i}$ is defined symmetrically over columns, and the total loss is the average of the two.
In practice, CLIP uses very large batch sizes to provide more negative examples for contrastive learning. The original CLIP was trained with batch size 32,768, requiring distributed training across multiple GPUs. With such large batches, each positive pair has 32,767 negative examples, providing a strong learning signal. However, this creates substantial memory requirements: storing the $32{,}768 \times 512$ embedding matrices for images and text requires $32{,}768 \times 512 \times 4 = 67$ MB per modality in FP32, and the $32{,}768 \times 32{,}768$ similarity matrix requires $4.3$ GB. To make this tractable, CLIP uses gradient checkpointing and distributes the batch across many GPUs, computing the similarity matrix in chunks.
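The memory figures quoted above follow from plain byte arithmetic, reproduced here as a quick check:

```python
# Reproduce the batch-memory figures for CLIP's batch of 32,768.
batch, dim, fp32_bytes = 32_768, 512, 4

embeddings_mb = batch * dim * fp32_bytes / 1e6   # per-modality embedding matrix
similarity_gb = batch ** 2 * fp32_bytes / 1e9    # full N x N similarity matrix

print(f"embeddings: {embeddings_mb:.0f} MB")   # ~67 MB per modality
print(f"similarity: {similarity_gb:.1f} GB")   # ~4.3 GB
```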
Computational Analysis of CLIP Training
Training CLIP at scale requires careful consideration of computational and memory costs across both the image and text encoding paths. For the ViT-L/14 image encoder processing $224 \times 224$ images, each image is divided into $16 \times 16 = 256$ patches of size $14 \times 14$. These patches are linearly projected to dimension 1024 and processed through 24 transformer layers. The computational cost per image is approximately $2 \times 24 \times 256^2 \times 1024 = 3.2$ GFLOPs for the attention operations (using the $2Ln^2d$ formula from Chapter 12) plus $2 \times 24 \times 256 \times 4 \times 1024^2 = 51.5$ GFLOPs for the feed-forward networks, totaling roughly 55 GFLOPs per image.
The text encoder processes sequences of up to 77 tokens through 12 transformer layers with dimension 768. The computational cost per text is approximately $2 \times 12 \times 77^2 \times 768 = 1.1$ GFLOPs for attention plus $2 \times 12 \times 77 \times 4 \times 768^2 = 4.4$ GFLOPs for feed-forward networks, totaling about 5.5 GFLOPs per text. This asymmetry---images requiring 10$\times$ more compute than text---means that image encoding dominates the computational budget during training.
For a batch of 32,768 examples, the total forward pass requires approximately $32{,}768 \times (55 + 5.5) = 1{,}982{,}464$ GFLOPs, or roughly 2 PFLOPs of work. On an NVIDIA A100 GPU with 312 TFLOPS of FP16 compute, this would take approximately 6.4 seconds per batch for the forward pass alone, not including backward propagation (which typically costs 2$\times$ the forward pass) or the contrastive loss computation. The full training of CLIP on 400 million image-text pairs with batch size 32,768 requires approximately $400{,}000{,}000 / 32{,}768 = 12{,}207$ batches. At roughly 20 seconds per batch (forward + backward + optimizer step), this amounts to 68 hours of continuous training on a single A100. In practice, OpenAI trained CLIP on 256 V100 GPUs for approximately 12 days, suggesting a total training cost of around 73,728 GPU-hours.
Memory requirements are equally demanding. Each image in the batch requires storing activations for 24 layers with 256 tokens and dimension 1024, totaling approximately $24 \times 256 \times 1024 \times 2 = 12.6$ MB per image in FP16 (the factor of 2 is the FP16 byte width). For batch size 32,768, this amounts to 413 GB just for image activations. Text activations are smaller at approximately $12 \times 77 \times 768 \times 2 = 1.4$ MB per text, or 46 GB for the full batch. The similarity matrix requires $32{,}768 \times 32{,}768 \times 2 = 2.1$ GB in FP16. Combined with model parameters (428M parameters $\times$ 2 bytes = 856 MB) and optimizer states (typically 2$\times$ parameters for Adam), the total memory footprint exceeds 500 GB, necessitating distribution across many GPUs using techniques like ZeRO (Chapter 22) to partition optimizer states and activations.
Zero-Shot Classification with CLIP
One of CLIP's most remarkable capabilities is zero-shot classification: the ability to classify images into categories the model has never been explicitly trained on. This works by leveraging the natural language understanding of the text encoder to create classifiers on the fly from text descriptions. The procedure begins by creating text prompts for each class in the target classification task. For example, for a 10-class animal classification task, we might create prompts like "a photo of a dog", "a photo of a cat", "a photo of a bird", and so on. These prompts are encoded by the text encoder to produce class embeddings $\vt_1, \ldots, \vt_C \in \R^{512}$ where $C$ is the number of classes.
To classify a new image, we encode it with the image encoder to produce $\vv \in \R^{512}$, then compute the cosine similarity between the image embedding and each class embedding: $s_i = \frac{\vv \cdot \vt_i}{\|\vv\| \|\vt_i\|}$. The predicted class is simply $\arg\max_i s_i$, the class whose text description has the highest similarity to the image. This approach requires no training on the target datasetâthe model uses only its pre-trained knowledge of how images and text relate.
The performance of this zero-shot approach is surprisingly strong. CLIP ViT-L/14 achieves 76.2\% top-1 accuracy on ImageNet without ever seeing a single ImageNet training example, matching the performance of a ResNet-50 trained directly on ImageNet's 1.28 million labeled images. This demonstrates that CLIP has learned visual concepts that generalize far beyond its training distribution. Moreover, CLIP shows remarkable robustness to distribution shift: when evaluated on ImageNet variants with different image styles (sketches, cartoons, adversarial examples), CLIP's performance degrades much less than supervised models, suggesting it has learned more robust visual representations.
The prompt engineering aspect of zero-shot classification is crucial for performance. Simple prompts like "dog" perform worse than more descriptive prompts like "a photo of a dog". OpenAI found that using prompt ensemblesâaveraging predictions across multiple prompt templates like "a photo of a \{class\}", "a picture of a \{class\}", "an image of a \{class\}"âimproves accuracy by 1-2\% by reducing sensitivity to prompt phrasing. For fine-grained classification tasks, more specific prompts help: "a photo of a \{species\}, a type of bird" outperforms "a photo of a \{species\}" for bird species classification.
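The mechanics of prompt ensembling can be sketched without loading a model. Below, random unit vectors stand in for the encoder outputs; in practice they would come from CLIP's `encode_text` and `encode_image`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for CLIP encoder outputs: random unit-norm embeddings.
num_classes, num_templates, dim = 10, 3, 512

# One embedding per (template, class) pair, e.g. "a photo of a {}.",
# "a picture of a {}.", "an image of a {}."
template_embeds = F.normalize(torch.randn(num_templates, num_classes, dim), dim=-1)

# Ensemble: average over templates, then re-normalize to the unit sphere.
class_embeds = F.normalize(template_embeds.mean(dim=0), dim=-1)  # (C, D)

# Zero-shot prediction: highest cosine similarity wins.
image_embed = F.normalize(torch.randn(1, dim), dim=-1)
logits = image_embed @ class_embeds.t()          # (1, C) cosine similarities
pred = logits.argmax(dim=-1)
print(pred.item())
```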
CLIP Variants and Training Requirements
Following CLIP's success, several variants have been developed with different scales and training procedures. OpenCLIP is an open-source reproduction that has trained models ranging from small (ResNet-50 with 102M parameters) to very large (ViT-G/14 with 1.8B parameters) on datasets including LAION-400M and LAION-2B. The largest OpenCLIP models require training on clusters of 128-512 A100 GPUs for several weeks, with estimated costs exceeding \$100,000 for the full training run. The training uses mixed precision (FP16) to reduce memory consumption and enable larger batch sizes, typically 32,768 to 65,536 examples distributed across all GPUs.
ALIGN, developed by Google, scales up the training data to 1.8 billion noisy image-text pairs collected from the web without extensive filtering. This demonstrates that contrastive learning is robust to noise in the training dataâthe model learns to ignore mismatched pairs through the contrastive objective. ALIGN uses an EfficientNet-L2 image encoder (480M parameters) and a BERT-Large text encoder (340M parameters), totaling approximately 820M parameters. Training ALIGN required a cluster of 1024 Cloud TPU v3 cores for approximately 6 days, representing roughly 150,000 TPU-hours.
Florence, Microsoft's unified vision foundation model, extends the CLIP approach to 900 million image-text pairs with a focus on creating a single model that can be adapted to diverse vision tasks. Florence uses a CoSwin transformer as the image encoder (637M parameters) and achieves state-of-the-art results on zero-shot classification, retrieval, and object detection after fine-tuning. The training infrastructure required 512 NVIDIA A100 GPUs for approximately 10 days, with an estimated cost of over \$200,000 in cloud compute.
The hardware requirements for training CLIP-scale models are substantial. A minimum viable setup might use 8-16 A100 GPUs (80GB each) to train a CLIP ResNet-50 model on a smaller dataset like Conceptual Captions (3M pairs) with batch size 2048-4096, requiring approximately 1-2 weeks. Scaling to the full CLIP ViT-L/14 with 400M training pairs and batch size 32,768 necessitates at least 64-128 A100 GPUs with high-bandwidth interconnects (NVLink or InfiniBand) to efficiently synchronize gradients across the distributed batch. The total training cost for reproducing CLIP ViT-L/14 is estimated at \$50,000-\$100,000 in cloud GPU costs, depending on the provider and optimization techniques employed.
DALL-E and Stable Diffusion
DALL-E: Text-to-Image Generation
- Encoder: Compress images to discrete tokens (VQ-VAE)
- Transformer: Autoregressive model over text + image tokens
- Training: Next token prediction
Sequence:
Generate image by: (1) Encode text, (2) Sample image tokens autoregressively
DALL-E 2 (2022):
- Use CLIP embeddings
- Prior: Text embedding $\to$ Image embedding
- Decoder: Image embedding $\to$ Image (diffusion model)
- Much higher quality than DALL-E 1
Stable Diffusion
Latent Diffusion Model:
- Encode image to latent space (VAE)
- Add noise iteratively (forward diffusion)
- Learn to denoise (reverse diffusion)
- Condition on text via cross-attention
Text conditioning:
- Text encoder: CLIP or T5
- Cross-attention: Latent queries attend to text keys/values
- Enables text-guided image generation
1. Text Encoder: CLIP text encoder
2. VAE Encoder: Image $\to$ latent
3. U-Net Denoiser: Diffusion model with cross-attention
- Input: Noisy latent $\vz_t$
- Condition: Text embedding $\vt$
- Output: Predicted noise $\epsilon_\theta(\vz_t, t, \vt)$
4. VAE Decoder: Latent $\to$ image
Parameters: $\approx 860$M total
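The cross-attention conditioning in step 3 can be sketched with PyTorch's built-in multi-head attention. The dimensions below (latent dim 320, 77 CLIP text tokens of dim 768) echo Stable Diffusion's first U-Net level, but the module itself is a simplified stand-in for the real block:

```python
import torch
import torch.nn as nn

class TextConditioning(nn.Module):
    """Sketch: flattened latent tokens (queries) attend to text-encoder
    tokens (keys/values). Not the real U-Net block."""
    def __init__(self, latent_dim=320, text_dim=768, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(latent_dim, heads,
                                          kdim=text_dim, vdim=text_dim,
                                          batch_first=True)

    def forward(self, z, text):
        out, _ = self.attn(query=z, key=text, value=text)
        return z + out                 # residual connection

z = torch.randn(2, 64 * 64, 320)   # flattened 64x64 latent grid
text = torch.randn(2, 77, 768)     # CLIP text embeddings (77 tokens)
z = TextConditioning()(z, text)
print(z.shape)  # torch.Size([2, 4096, 320])
```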
Vision-Language Understanding
BLIP: Bootstrapped Language-Image Pre-training
Architecture:
- Image encoder (ViT)
- Text encoder (BERT)
- Multimodal encoder (cross-attention between vision and text)
Training objectives:
- ITC: Image-Text Contrastive (like CLIP)
- ITM: Image-Text Matching (binary: match or not)
- LM: Language Modeling on text
Bootstrapping: Generate synthetic captions, filter with model, retrain
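The ITM objective reduces to binary cross-entropy over the multimodal encoder's fused output. A minimal sketch with an illustrative 768-dimensional head:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d = 768  # illustrative hidden dimension

# ITM head: classify the fused [CLS]-style representation as match / no-match.
itm_head = nn.Linear(d, 2)

multimodal_cls = torch.randn(4, d)       # fused representation per pair
labels = torch.tensor([1, 0, 1, 0])      # which pairs actually match
logits = itm_head(multimodal_cls)        # (4, 2)
itm_loss = F.cross_entropy(logits, labels)
print(itm_loss.item())
```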
Flamingo: Few-Shot Learning
Flamingo represents a significant architectural innovation in multimodal transformers by enabling models to process arbitrarily interleaved sequences of images and text, supporting few-shot learning through in-context examples. Unlike CLIP, which processes single image-text pairs, Flamingo can handle inputs like "Here is an image of a cat: [image]. Here is an image of a dog: [image]. What animal is in this image: [image]?" This capability enables few-shot visual learning where the model learns new tasks from just a few examples provided in the prompt, without any parameter updates.
The Flamingo architecture consists of three main components, carefully designed to leverage pre-trained models while adding minimal trainable parameters. The vision encoder is a frozen CLIP ViT-L/14 model that processes each image independently to produce a sequence of patch embeddings. For a 224Ă224 image with patch size 14, this yields 256 patch tokens of dimension 1024. The vision encoder's 304M parameters remain frozen throughout training, preserving the strong visual representations learned during CLIP pre-training.
The language model is a frozen Chinchilla 70B model, a large autoregressive transformer trained on text-only data. Chinchilla uses 70 billion parameters across 80 layers with hidden dimension 8192 and 64 attention heads. Keeping this massive language model frozen is crucial for computational tractabilityâtraining 70B parameters would require prohibitive memory and compute. Instead, Flamingo inserts new trainable layers that allow the frozen language model to attend to visual information without modifying its core text processing capabilities.
The key innovation is the Perceiver Resampler, a learned module that compresses the variable-length sequence of image patch embeddings into a fixed number of visual tokens that can be efficiently processed by the language model. The Perceiver Resampler uses cross-attention where a fixed set of learned queries $\mQ \in \R^{64 \times 2048}$ (64 visual tokens, dimension 2048) attends to the image patch embeddings $\mK, \mV \in \R^{256 \times 1024}$ from the vision encoder. This produces a fixed-size representation regardless of input image resolution or the number of images in the sequence. The Perceiver Resampler contains approximately 1.4B parameters (6 layers of cross-attention and feed-forward networks with dimension 2048), making it the primary trainable component of Flamingo.
Between every few layers of the frozen language model, Flamingo inserts new cross-attention layers that allow text tokens to attend to the visual tokens produced by the Perceiver Resampler. Specifically, for Flamingo-80B (built on Chinchilla-70B), cross-attention layers are inserted after every 7th transformer layer, resulting in approximately 11 cross-attention insertions across the 80 layers. Each cross-attention layer adds roughly 134M parameters (for dimension 8192), totaling about 1.5B parameters for all insertions. Combined with the Perceiver Resampler, Flamingo adds approximately 2.9B trainable parameters to the 70B frozen base model, representing just 4\% additional parameters while enabling full multimodal capabilities.
The memory requirements for Flamingo are dominated by the frozen language model. Storing 70B parameters in FP16 requires 140 GB, which exceeds the memory of any single GPU. Flamingo uses model parallelism to partition the language model across multiple GPUs---for example, distributing across 8 A100 GPUs (80GB each) places roughly 8.75B parameters per GPU, consuming about 17.5 GB for parameters. Activations for a sequence of 2048 tokens (including both text and visual tokens) across 80 layers with dimension 8192 require approximately $2048 \times 80 \times 8192 \times 2 = 2.6$ GB per example in FP16. With batch size 8, activations consume 21 GB per GPU, leaving sufficient memory for gradients of the trainable parameters (2.9B parameters $\times$ 2 bytes $\times$ 2 for gradients = 11.6 GB) and optimizer states (23.2 GB for Adam).
Training Flamingo on a mixture of image-text pairs, interleaved image-text documents, and video-text pairs requires substantial computational resources. The training dataset consists of 2.3 billion image-text pairs (similar to CLIP), 43 million interleaved image-text web pages, and 27 million video clips. Training Flamingo-80B for 1 epoch through this data with batch size 256 distributed across 256 A100 GPUs takes approximately 15 days, representing roughly 92,000 GPU-hours. The estimated training cost exceeds \$300,000 in cloud compute. However, the key advantage is that only 2.9B parameters are trained while leveraging the capabilities of a 70B language model, making training far more efficient than training a 70B multimodal model from scratch.
For inference, Flamingo's few-shot learning capability means that users can provide 2-32 example image-text pairs in the prompt to demonstrate a new task, and the model adapts its predictions based on these examples without any fine-tuning. This in-context learning works because the cross-attention mechanism allows the model to attend to the example images when processing the query image. The computational cost of inference scales linearly with the number of examples in the context: each additional image adds 256 patch tokens (after vision encoding) compressed to 64 visual tokens (after Perceiver Resampler), increasing the sequence length and thus the attention cost. For a prompt with 4 example images and 1 query image (5 images total), the visual tokens contribute $5 \times 64 = 320$ tokens to the sequence, which combined with text tokens (typically 500-1000) results in sequences of 800-1300 tokens. On a single A100 GPU, Flamingo-80B can process approximately 2-3 such sequences per second, limited primarily by the memory bandwidth required to load the 70B parameter model.
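Flamingo's inserted cross-attention layers are tanh-gated and initialized at zero, so the frozen language model's behavior is unchanged at the start of training. A minimal sketch (dimensions illustrative; the real model uses $d = 8192$):

```python
import torch
import torch.nn as nn

class GatedCrossAttention(nn.Module):
    """Sketch of a Flamingo-style gated cross-attention insertion: text
    hidden states attend to visual tokens; tanh(gate)=0 at init, so the
    frozen LM starts out unmodified. Dimensions are illustrative."""
    def __init__(self, d_model=512, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, heads, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))  # tanh(0) = 0 at init

    def forward(self, text_h, visual_tokens):
        out, _ = self.attn(query=text_h, key=visual_tokens, value=visual_tokens)
        return text_h + torch.tanh(self.gate) * out

layer = GatedCrossAttention()
text_h = torch.randn(1, 100, 512)     # hidden states of 100 text tokens
visual = torch.randn(1, 5 * 64, 512)  # 5 images x 64 resampled visual tokens
out = layer(text_h, visual)
print(torch.allclose(out, text_h))    # True: gate is zero at initialization
```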
Computational Analysis of Multimodal Transformers
Multimodal transformers follow the same FLOPs formulas derived in Chapter~12 for their individual encoders: each transformer layer costs $24Bnd_{\text{model}}^2 + 4Bn^2d_{\text{model}}$ FLOPs (attention plus feed-forward). The multimodal-specific addition is cross-modal attention, which costs $4mnd$ FLOPs per layer (where $m$ and $n$ are the sequence lengths of the two modalities). In practice, cross-modal attention is a small fraction of total cost---for BLIP with 128 text tokens and 196 image patches, cross-attention adds only 462~MFLOPs across 6 layers, negligible compared to the self-attention costs in each encoder.
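The 462 MFLOPs figure checks out directly from the per-layer formula:

```python
# Cross-attention cost for BLIP-style fusion: 4*m*n*d FLOPs per layer,
# with m = 128 text tokens, n = 196 image patches, d = 768, over 6 layers.
m, n, d, layers = 128, 196, 768, 6
total_mflops = layers * 4 * m * n * d / 1e6
print(round(total_mflops))  # 462
```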
The key computational asymmetry in multimodal models is between modalities: image encoding typically dominates. CLIP's ViT-L/14 requires $\sim$55~GFLOPs per image versus $\sim$5.5~GFLOPs per text, a 10$\times$ ratio. When a large language model serves as the text backbone (as in Flamingo with Chinchilla-70B), text processing dominates instead, requiring $\sim$110~TFLOPs per sequence.
Memory requirements follow the same principles as unimodal transformers (Chapter~12): parameters, gradients, optimizer states, and activations. The multimodal-specific concern is storing activations for both modalities simultaneously. For CLIP ViT-L/14, image activations consume $\sim$75~MB per image in FP16 while text activations require $\sim$1.4~MB per text. For large batch sizes (32,768 in CLIP), this necessitates distributed training with gradient checkpointing and mixed precision (see Chapter~11 for distributed training techniques).
Training Challenges for Multimodal Transformers
Batch Size Requirements for Contrastive Learning
Contrastive learning methods like CLIP require very large batch sizes to provide sufficient negative examples. CLIP's performance scales log-linearly with batch size: increasing from 256 to 32,768 improves ImageNet zero-shot accuracy from $\sim$58\% to 76\%. However, the $32{,}768 \times 32{,}768$ similarity matrix alone requires 4.3~GB in FP32. To make this tractable, CLIP distributes the batch across 256 GPUs using all-gather communication, so the full similarity matrix is never materialized on any single GPU.
Distributed Training and Memory Optimization
Multimodal transformers use the same distributed training techniques as unimodal models (see Chapter~11 for detailed coverage): data parallelism for CLIP-scale models that fit on a single GPU, tensor and pipeline parallelism for larger models like Flamingo-80B where the 70B parameter language model must be partitioned across multiple GPUs. Memory optimization techniques---gradient checkpointing, mixed precision training, and ZeRO optimizer state partitioning---are essential and apply identically to the multimodal setting.
The multimodal-specific challenge is the asymmetric memory profile: image activations ($\sim$75~MB per image for ViT-L) far exceed text activations ($\sim$1.4~MB per text for CLIP's encoder), so image encoding dominates the memory budget during training. For Flamingo-80B, the frozen 70B language model requires 140~GB in FP16, necessitating model parallelism across at least 2 A100 GPUs before accounting for activations or trainable parameters.
Audio Transformers
Whisper: Speech Recognition
Input: Audio waveform $\to$ Log-mel spectrogram
Encoder:
- Input: Spectrogram (80 mel bins)
- Convolution layers (downsample)
- Transformer encoder layers
Decoder:
- Autoregressive text generation
- Special tokens for language, task, timestamps
Training data: 680,000 hours of multilingual audio
Tasks supported:
- Speech recognition (transcription)
- Translation (to English)
- Language identification
- Voice activity detection
- Timestamp prediction
For example, English transcription without timestamps is specified by the decoder prefix:
<|startoftranscript|><|en|><|transcribe|><|notimestamps|>
Spectrogram:
- 80 mel bins
- 3000 frames (30 seconds audio at 100 Hz)
- Input: $3000 \times 80$
Encoder:
- Conv layers: $3000 \times 80 \to 1500 \times 768$
- Transformer: Process 1500 tokens
Decoder: Generate text tokens autoregressively
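The shape bookkeeping above can be verified with Whisper's convolutional front end (two kernel-size-3 convolutions, the second with stride 2, halving 3000 frames to 1500); $d_{\text{model}} = 768$ matches the encoder dimension quoted above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model = 768  # encoder dimension from the text

# Whisper-style front end: two convolutions over the log-mel spectrogram,
# the second with stride 2 to downsample 3000 frames to 1500 tokens.
conv1 = nn.Conv1d(80, d_model, kernel_size=3, padding=1)
conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)

mel = torch.randn(1, 80, 3000)     # (batch, mel bins, frames)
x = F.gelu(conv1(mel))
x = F.gelu(conv2(x))               # (1, 768, 1500)
x = x.transpose(1, 2)              # (1, 1500, 768): tokens for the encoder
print(x.shape)
```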
Audio-Text Pre-training
Contrastive learning: Like CLIP but audio-text
AudioCLIP: Tri-modal (image, text, audio)
Applications:
- Zero-shot audio classification
- Audio captioning
- Text-to-audio generation
Unified Multimodal Models
Perceiver and Perceiver IO
Key idea: Map arbitrary modalities to latent space via cross-attention
1. Latent array: Fixed set of learned queries $\mZ \in \R^{M \times d}$
2. Cross-attention: Latents attend to inputs
3. Transformer: Process latents
4. Output: Decode latents to task outputs
Benefits:
- Handles arbitrary input sizes
- Computation independent of input size (fixed latents)
- Unified architecture for images, video, audio, text
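A minimal sketch of the latent bottleneck, with illustrative sizes: whatever the input length, the output is always $M$ latent tokens, so downstream compute is fixed:

```python
import torch
import torch.nn as nn

class LatentBottleneck(nn.Module):
    """Sketch of the Perceiver read step: M learned latent queries
    cross-attend to an arbitrarily long input sequence. Sizes are
    illustrative, not the published configuration."""
    def __init__(self, num_latents=64, d=256, heads=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, d))
        self.cross = nn.MultiheadAttention(d, heads, batch_first=True)

    def forward(self, inputs):                   # inputs: (B, N, d), any N
        q = self.latents.unsqueeze(0).expand(inputs.shape[0], -1, -1)
        z, _ = self.cross(q, inputs, inputs)
        return z                                 # (B, M, d), independent of N

model = LatentBottleneck()
shapes = []
for n in (1_000, 20_000):                        # short vs. long inputs
    out = model(torch.randn(1, n, 256))
    shapes.append(tuple(out.shape))
print(shapes)  # [(1, 64, 256), (1, 64, 256)]
```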
GPT-4V and LLaVA
GPT-4V (Vision): GPT-4 with vision capabilities
- Interleaved image and text inputs
- Strong vision-language understanding
- Details not fully disclosed
LLaVA (Open-source):
- CLIP vision encoder
- LLaMA language model
- Linear projection to align embeddings
- Instruction tuning on visual conversations
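LLaVA's bridge between the two frozen models is remarkably simple: a linear projection maps CLIP patch features into the language model's embedding space so image "tokens" can be spliced into the text sequence. A sketch using the dimensions of CLIP ViT-L/14 (1024) and LLaMA-7B (4096); the module here is a stand-in, not the released checkpoint:

```python
import torch
import torch.nn as nn

# Linear projection from CLIP feature space to LLaMA embedding space.
projector = nn.Linear(1024, 4096)

patch_feats = torch.randn(1, 256, 1024)     # CLIP ViT-L/14 patch outputs
visual_tokens = projector(patch_feats)      # (1, 256, 4096)

# Visual tokens are concatenated with embedded text tokens and fed
# to the language model as one sequence.
text_embeds = torch.randn(1, 32, 4096)      # embedded text tokens
sequence = torch.cat([visual_tokens, text_embeds], dim=1)
print(sequence.shape)  # torch.Size([1, 288, 4096])
```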
Exercises
- Generate random image embeddings $(8, 512)$
- Generate random text embeddings $(8, 512)$
- Compute $8 \times 8$ similarity matrix
- Calculate contrastive loss with $\tau = 0.07$
- Load pre-trained CLIP model
- Create text prompts for 10 classes
- Encode images and prompts
- Compute accuracy
- Compare to supervised baseline
- Calculate parameters for encoder (24 layers, $d=1024$)
- Calculate parameters for decoder (24 layers)
- Estimate memory for 30-second audio
- Compare to text-only GPT-2
- Propose architecture
- Define fusion mechanism
- Specify training objective
- Estimate parameter count
Solutions
Full solutions for all exercises are available at \url{https://deeplearning.hofkensvermeulen.be}.
import torch
import torch.nn as nn
import torch.nn.functional as F
def clip_contrastive_loss(image_embeddings, text_embeddings, temperature=0.07):
    """
    Compute the symmetric CLIP contrastive loss.
    Args:
        image_embeddings: (B, D) image embeddings (normalized internally)
        text_embeddings: (B, D) text embeddings (normalized internally)
        temperature: temperature parameter tau
    Returns:
        loss: scalar contrastive loss
        logits: (B, B) similarity matrix scaled by 1/tau
    """
    # Normalize embeddings to unit length (cosine similarity)
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)

    # Compute similarity matrix (B, B)
    logits = torch.matmul(image_embeddings, text_embeddings.t()) / temperature

    # Labels: diagonal elements are the positive pairs
    batch_size = image_embeddings.shape[0]
    labels = torch.arange(batch_size, device=image_embeddings.device)

    # Symmetric loss: image-to-text + text-to-image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    loss = (loss_i2t + loss_t2i) / 2
    return loss, logits
# Part (a): Generate random embeddings
batch_size = 8
embed_dim = 512
image_embeddings = torch.randn(batch_size, embed_dim)
text_embeddings = torch.randn(batch_size, embed_dim)
print(f"Image embeddings shape: {image_embeddings.shape}")
print(f"Text embeddings shape: {text_embeddings.shape}")
# Part (b): Normalize embeddings
image_embeddings = F.normalize(image_embeddings, dim=-1)
text_embeddings = F.normalize(text_embeddings, dim=-1)
print(f"\nAfter normalization:")
print(f"Image embedding norms: {torch.norm(image_embeddings, dim=-1)}")
print(f"Text embedding norms: {torch.norm(text_embeddings, dim=-1)}")
# Part (c): Compute similarity matrix
temperature = 0.07
similarity_matrix = torch.matmul(image_embeddings, text_embeddings.t()) / temperature
print(f"\nSimilarity matrix shape: {similarity_matrix.shape}")
print(f"Similarity matrix:\n{similarity_matrix}")
# Part (d): Calculate contrastive loss
loss, logits = clip_contrastive_loss(image_embeddings, text_embeddings, temperature)
print(f"\nContrastive loss: {loss.item():.4f}")
print(f"Logits shape: {logits.shape}")
# Analyze the loss
labels = torch.arange(batch_size)
predictions_i2t = logits.argmax(dim=1)
predictions_t2i = logits.t().argmax(dim=1)
accuracy_i2t = (predictions_i2t == labels).float().mean()
accuracy_t2i = (predictions_t2i == labels).float().mean()
print(f"\nImage-to-Text accuracy: {accuracy_i2t.item():.2%}")
print(f"Text-to-Image accuracy: {accuracy_t2i.item():.2%}")
Mathematical Derivation:
Part (a) \& (b): Embeddings
Image embeddings: $\vI = [\vi_1, \vi_2, \ldots, \vi_8] \in \mathbb{R}^{8 \times 512}$
Text embeddings: $\vT = [\vt_1, \vt_2, \ldots, \vt_8] \in \mathbb{R}^{8 \times 512}$
Normalize to unit sphere: $\hat{\vi}_i = \frac{\vi_i}{\|\vi_i\|_2}, \quad \hat{\vt}_i = \frac{\vt_i}{\|\vt_i\|_2}$
Part (c): Similarity Matrix
Cosine similarity matrix: $\vS_{ij} = \frac{\hat{\vi}_i \cdot \hat{\vt}_j}{\tau}$
where $\tau = 0.07$ is the temperature parameter.
Full matrix: $\vS = \frac{1}{\tau} \hat{\vI} \hat{\vT}^T \in \mathbb{R}^{8 \times 8}$
Example: \[ \vS = \begin{bmatrix} s_{11} & s_{12} & \cdots & s_{18} \\ s_{21} & s_{22} & \cdots & s_{28} \\ \vdots & \vdots & \ddots & \vdots \\ s_{81} & s_{82} & \cdots & s_{88} \end{bmatrix} \]
Diagonal elements $s_{ii}$ are positive pairs (matched image-text).
Off-diagonal elements $s_{ij}$ ($i \neq j$) are negative pairs.
Part (d): Contrastive Loss
Image-to-Text Loss:
For each image $i$, predict its matching text from 8 candidates:
$\mathcal{L}_{i2t} = -\frac{1}{8} \sum_{i=1}^{8} \log \frac{\exp(s_{ii})}{\sum_{j=1}^{8} \exp(s_{ij})}$
This is cross-entropy with labels $y_i = i$ (diagonal).
Text-to-Image Loss:
For each text $j$, predict its matching image from 8 candidates:
$\mathcal{L}_{t2i} = -\frac{1}{8} \sum_{j=1}^{8} \log \frac{\exp(s_{jj})}{\sum_{i=1}^{8} \exp(s_{ij})}$
Total CLIP Loss:
$\mathcal{L}_{\text{CLIP}} = \frac{1}{2}(\mathcal{L}_{i2t} + \mathcal{L}_{t2i})$
Symmetric loss ensures both modalities learn aligned representations.
Why Temperature $\tau = 0.07$?
- Sharpens distribution: Small $\tau$ makes softmax more peaked
- Emphasizes hard negatives: Distinguishes similar but incorrect pairs
- Empirically optimal: Found through hyperparameter search
- Typical range: $\tau \in [0.01, 0.1]$
Effect of temperature:
- $\tau \to 0$: Approaches hard assignment (argmax)
- $\tau \to \infty$: Uniform distribution (no learning)
- $\tau = 0.07$: Good balance for contrastive learning
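These limiting behaviors are easy to see numerically. A minimal sketch (the similarity values are illustrative) compares the softmax probability of the positive pair at several temperatures:

```python
import torch

# One row of cosine similarities: a positive pair (0.9) and 7 negatives
sims = torch.tensor([0.9, 0.3, 0.2, 0.15, 0.1, 0.1, 0.1, 0.1])

for tau in [1.0, 0.07, 0.01]:
    # Smaller tau concentrates probability mass on the positive pair
    p_pos = torch.softmax(sims / tau, dim=0)[0].item()
    print(f"tau = {tau:>4}: p(positive) = {p_pos:.4f}")
```

At $\tau = 1$ the distribution stays close to uniform; at $\tau = 0.01$ essentially all mass sits on the positive pair.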
Numerical Example:
Suppose that for image 1 the raw cosine similarities (before the $1/\tau$ scaling) are:
- $s_{11} = 0.9$ (correct text)
- $s_{12} = 0.3, s_{13} = 0.2, \ldots, s_{18} = 0.1$ (incorrect texts)
Softmax probabilities: $p_1 = \frac{\exp(0.9/0.07)}{\exp(0.9/0.07) + \sum_{j=2}^{8} \exp(s_{1j}/0.07)}$
Loss for image 1: $\ell_1 = -\log p_1$
If $p_1 \approx 1$, then $\ell_1 \approx 0$ (good alignment).
If $p_1 \approx 0.125$ (uniform), then $\ell_1 \approx 2.08$ (poor alignment).
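This arithmetic can be checked directly. The similarities between $s_{13}$ and $s_{18}$ are elided above, so illustrative values are filled in here:

```python
import math

tau = 0.07
s_pos = 0.9                                       # s_11: correct text
s_neg = [0.3, 0.2, 0.18, 0.15, 0.12, 0.11, 0.1]  # illustrative negatives

denom = math.exp(s_pos / tau) + sum(math.exp(s / tau) for s in s_neg)
p1 = math.exp(s_pos / tau) / denom
loss_aligned = -math.log(p1)

loss_uniform = -math.log(1 / 8)                   # no-alignment baseline

print(f"p1 = {p1:.4f}, aligned loss = {loss_aligned:.5f}")
print(f"uniform loss = {loss_uniform:.4f}")       # log(8) ~ 2.0794
```

With a margin of $0.6$ in raw similarity, the $1/\tau$ scaling turns it into a margin of $\approx 8.6$ in logits, so $p_1$ is already very close to 1.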
Training Dynamics:
- Initial: Random embeddings, $\mathcal{L} \approx \log(8) = 2.08$
- Training: Embeddings align, diagonal elements increase
- Converged: $s_{ii} \gg s_{ij}$ for $i \neq j$, $\mathcal{L} \to 0$
Key Insights:
- Batch size determines the number of negatives ($B - 1$ per example)
- Larger batches improve contrastive learning (more negatives)
- CLIP uses batch sizes up to 32,768 in practice
- Symmetric loss prevents modality collapse
- Temperature is a critical hyperparameter
import torch
import clip
from PIL import Image
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
# Part (a): Load pre-trained CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print(f"CLIP model loaded on {device}")
print(f"Model: ViT-B/32")
# Part (b): Create text prompts for 10 CIFAR-10 classes
cifar10_classes = [
"airplane", "automobile", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"
]
# Template-based prompts (improves accuracy)
templates = [
"a photo of a {}.",
"a blurry photo of a {}.",
"a photo of many {}.",
"a photo of the small {}.",
"a photo of the large {}.",
]
# Encode text prompts
def encode_text_prompts(model, classes, templates):
"""Encode text prompts with multiple templates"""
text_features = []
for classname in classes:
# Create prompts from templates
texts = [template.format(classname) for template in templates]
texts = clip.tokenize(texts).to(device)
# Encode texts
with torch.no_grad():
class_embeddings = model.encode_text(texts)
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
# Average over templates
class_embedding = class_embeddings.mean(dim=0)
class_embedding = class_embedding / class_embedding.norm()
text_features.append(class_embedding)
text_features = torch.stack(text_features, dim=0)
return text_features
text_features = encode_text_prompts(model, cifar10_classes, templates)
print(f"\nText features shape: {text_features.shape}") # (10, 512)
# Part (c): Load CIFAR-10 test set
test_dataset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=preprocess
)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
# Zero-shot classification
def zero_shot_classify(model, loader, text_features):
"""Perform zero-shot classification"""
correct = 0
total = 0
with torch.no_grad():
for images, labels in tqdm(loader):
images = images.to(device)
labels = labels.to(device)
# Encode images
image_features = model.encode_image(images)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Compute similarity with text features
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
# Predict
predictions = similarity.argmax(dim=-1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = 100.0 * correct / total
return accuracy
# Part (d): Compute accuracy
zero_shot_accuracy = zero_shot_classify(model, test_loader, text_features)
print(f"\nZero-shot accuracy: {zero_shot_accuracy:.2f}%")
# Part (e): Compare to supervised baseline
# Train a simple supervised classifier
class SimpleCNN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 32, 3, padding=1)
self.conv2 = torch.nn.Conv2d(32, 64, 3, padding=1)
self.pool = torch.nn.MaxPool2d(2, 2)
self.fc1 = torch.nn.Linear(64 * 8 * 8, 128)
self.fc2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 64 * 8 * 8)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# Train supervised model (simplified)
supervised_model = SimpleCNN().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(supervised_model.parameters(), lr=0.001)
# Training loop (10 epochs for quick comparison)
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
for epoch in range(10):
supervised_model.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = supervised_model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# Evaluate supervised model
supervised_model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = supervised_model(images)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
supervised_accuracy = 100.0 * correct / total
print(f"\nComparison:")
print(f"CLIP Zero-shot: {zero_shot_accuracy:.2f}%")
print(f"Supervised CNN (10 epochs): {supervised_accuracy:.2f}%")
Expected Results:
| Method | Accuracy | Training Data |
|---|---|---|
| CLIP Zero-shot (ViT-B/32) | 89-91% | 0 CIFAR-10 images |
| CLIP Zero-shot (ViT-L/14) | 93-95% | 0 CIFAR-10 images |
| Supervised CNN (10 epochs) | 70-75% | 50k CIFAR-10 images |
| Supervised ResNet-50 (200 epochs) | 95-96% | 50k CIFAR-10 images |
Analysis:
Part (a): Pre-trained CLIP Model
CLIP models available:
- RN50: ResNet-50 image encoder
- ViT-B/32: ViT-Base with patch size 32
- ViT-B/16: ViT-Base with patch size 16 (better)
- ViT-L/14: ViT-Large with patch size 14 (best)
Pre-training:
- Dataset: 400M image-text pairs collected from the web
- Training: Contrastive learning for 32 epochs
- Batch size: 32,768 (large-scale)
- Compute: 256 V100 GPUs for 12 days
Part (b): Text Prompts
Simple prompts:
"airplane", "automobile", "bird", ...
Template-based prompts (better):
"a photo of a airplane."
"a blurry photo of a airplane."
"a photo of many airplanes."
Why templates help:
- Match training distribution (natural sentences)
- Provide context for ambiguous classes
- Ensemble over multiple descriptions
- Improve robustness to variations
Prompt engineering tips:
- Use natural language sentences
- Include domain-specific context
- Try multiple templates and average
- Avoid overly specific descriptions
Part (c): Encoding Process
Image encoding:
- Preprocess: Resize to $224 \times 224$, normalize
- ViT encoder: Extract features
- Projection: Map to shared embedding space (512-dim)
- Normalize: $\hat{\vi} = \vi / \|\vi\|_2$
Text encoding:
- Tokenize: Convert text to token IDs
- Text encoder: Transformer processes tokens
- Projection: Map to shared embedding space (512-dim)
- Normalize: $\hat{\vt} = \vt / \|\vt\|_2$
Part (d): Zero-Shot Classification
Algorithm:
For each test image $\vx$:
- Encode image: $\vi = \text{ImageEncoder}(\vx)$
- Compute similarity with all class embeddings: $s_k = \vi \cdot \vt_k$ for $k = 1, \ldots, 10$
- Apply softmax: $p_k = \frac{\exp(s_k / \tau)}{\sum_{j=1}^{10} \exp(s_j / \tau)}$
- Predict: $\hat{y} = \argmax_k p_k$
Temperature $\tau = 0.01$ (learned during training).
Mathematical Formulation:
$P(y = k | \vx) = \frac{\exp(\text{sim}(\vi, \vt_k) / \tau)}{\sum_{j=1}^{10} \exp(\text{sim}(\vi, \vt_j) / \tau)}$
where $\text{sim}(\vi, \vt) = \vi \cdot \vt$ (cosine similarity after normalization).
Part (e): Comparison with Supervised Baseline
Why CLIP Zero-Shot Outperforms Supervised CNN:
- Pre-training scale: 400M image-text pairs vs 50k CIFAR-10 images
- Transfer learning: Leverages knowledge from diverse data
- Better architecture: ViT-B/32 vs simple CNN
- Semantic understanding: Learns concepts, not just patterns
- Robustness: Generalizes better to distribution shifts
When Supervised Wins:
- Sufficient training data: ResNet-50 with 200 epochs reaches 95-96\%
- Domain-specific: Fine-tuned models beat zero-shot on specialized tasks
- Computational constraints: Smaller models are faster
CLIP Advantages:
- No training required: Instant deployment
- Flexible: Change classes without retraining
- Interpretable: Natural language descriptions
- Robust: Handles distribution shifts better
- Multimodal: Can do image-text retrieval, captioning, etc.
Practical Recommendations:
| Scenario | Recommendation |
|---|---|
| Quick prototype | CLIP zero-shot |
| Fixed classes, lots of data | Supervised training |
| Changing classes frequently | CLIP zero-shot |
| Maximum accuracy | Fine-tune CLIP |
| Limited compute | Supervised small model |
| Interpretability needed | CLIP with prompts |
Improving CLIP Zero-Shot:
- Better prompts: Domain-specific templates
- Larger model: ViT-L/14 instead of ViT-B/32
- Ensemble: Average predictions from multiple prompts
- Few-shot: Add a few examples with linear probe
- Fine-tuning: Adapt to target domain
Key Takeaways:
- CLIP achieves strong zero-shot performance through large-scale pre-training
- Natural language prompts enable flexible classification
- Zero-shot CLIP often matches or exceeds supervised baselines
- Prompt engineering is crucial for optimal performance
- CLIP's multimodal nature enables many downstream tasks
Part (a): Encoder Parameters (24 layers, $d=1024$)
Whisper Encoder Configuration:
- Layers: $L = 24$
- Hidden size: $d = 1024$
- Attention heads: $h = 16$
- MLP ratio: $4.0$ (MLP size = $4096$)
- Audio features: 80-dimensional log-mel spectrogram
- Sequence length: $T = 3000$ (30 seconds at 100 Hz)
Parameter Breakdown:
1. Input Convolution Layers:
- Conv1: $80 \times 3 \times 1024 = 245{,}760$
- Conv2: $1024 \times 3 \times 1024 = 3{,}145{,}728$
- Total: $3{,}391{,}488$ parameters
2. Position Embeddings:
- Sinusoidal (not learned): 0 parameters
3. Per Transformer Layer:
Multi-Head Attention:
- $Q, K, V$ projections: $3 \times 1024^2 = 3{,}145{,}728$
- Output projection: $1024^2 = 1{,}048{,}576$
- Total attention: $4{,}194{,}304$
MLP:
- First linear: $1024 \times 4096 = 4{,}194{,}304$
- Second linear: $4096 \times 1024 = 4{,}194{,}304$
- Total MLP: $8{,}388{,}608$
Layer Normalization:
- 2 LayerNorms: $2 \times 2 \times 1024 = 4{,}096$
Total per layer: $12{,}587{,}008$ parameters
4. All 24 Encoder Layers: $24 \times 12{,}587{,}008 = 302{,}088{,}192$ parameters
Total Encoder: $\approx 305.5$M parameters
Part (b): Decoder Parameters (24 layers)
Whisper Decoder Configuration:
- Layers: $L = 24$
- Hidden size: $d = 1024$
- Attention heads: $h = 16$
- Vocabulary size: $V = 51{,}865$
- Max sequence length: $448$ tokens
Parameter Breakdown:
1. Token Embedding:
- $51{,}865 \times 1024 = 53{,}109{,}760$ parameters
2. Position Embeddings:
- $448 \times 1024 = 458{,}752$ parameters
3. Per Decoder Layer:
Masked Self-Attention:
- Same as encoder: $4{,}194{,}304$ parameters
Cross-Attention:
- $Q$ projection: $1024^2 = 1{,}048{,}576$
- $K, V$ projections (from encoder): $2 \times 1024^2 = 2{,}097{,}152$
- Output projection: $1024^2 = 1{,}048{,}576$
- Total cross-attention: $4{,}194{,}304$
MLP:
- Same as encoder: $8{,}388{,}608$ parameters
Layer Normalization:
- 3 LayerNorms: $3 \times 2 \times 1024 = 6{,}144$
Total per decoder layer: $16{,}783{,}360$ parameters
4. All 24 Decoder Layers: $24 \times 16{,}783{,}360 = 402{,}800{,}640$ parameters
5. Output Projection:
- Shared with token embedding: 0 additional parameters
Total Decoder: $\approx 456.4$M parameters
Total Whisper Model: $305.5 + 456.4 = 761.9$M parameters
(This configuration matches Whisper-medium, whose released checkpoint has $\approx 769$M parameters; Whisper-large uses 32 layers with $d = 1280$, giving $\approx 1.55$B parameters.)
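The parameter tallies in parts (a) and (b) can be reproduced with a short script (dimensions are the ones assumed in this exercise; biases are omitted, as in the counts above):

```python
# Dimensions as assumed in this exercise (Whisper-medium-like)
d, n_layers, d_mlp, vocab, max_len, n_mels = 1024, 24, 4096, 51_865, 448, 80

# Encoder: two 1-D convs (kernel size 3), then n_layers transformer blocks
conv = n_mels * 3 * d + d * 3 * d          # input convolutions (biases omitted)
attn = 4 * d * d                           # Q, K, V, output projections
mlp = 2 * d * d_mlp                        # two linear layers
ln2 = 2 * 2 * d                            # two LayerNorms (scale + shift)
encoder = conv + n_layers * (attn + mlp + ln2)

# Decoder: self-attention + cross-attention + MLP + three LayerNorms,
# plus token and position embeddings (output projection tied to embedding)
dec_layer = 2 * attn + mlp + 3 * 2 * d
decoder = vocab * d + max_len * d + n_layers * dec_layer

print(f"encoder ~ {encoder / 1e6:.1f}M")   # ~305.5M
print(f"decoder ~ {decoder / 1e6:.1f}M")   # ~456.4M
print(f"total   ~ {(encoder + decoder) / 1e6:.1f}M")
```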
Part (c): Memory for 30-Second Audio
Input Processing:
1. Audio Preprocessing:
- Sample rate: 16 kHz
- 30 seconds: $30 \times 16{,}000 = 480{,}000$ samples
- Raw audio: $480{,}000 \times 4$ bytes = 1.92 MB
2. Log-Mel Spectrogram:
- Window size: 25 ms (400 samples)
- Hop length: 10 ms (160 samples)
- Number of frames: $\frac{480{,}000}{160} = 3{,}000$
- Mel bins: 80
- Features: $3{,}000 \times 80 = 240{,}000$ values
- Memory: $240{,}000 \times 4$ bytes = 0.96 MB
Encoder Memory (Inference):
1. Activations per layer:
- Input: $3{,}000 \times 1024 = 3{,}072{,}000$ values
- Attention scores: $16 \times 3{,}000 \times 3{,}000 = 144{,}000{,}000$ values
- MLP intermediate: $3{,}000 \times 4096 = 12{,}288{,}000$ values
Peak per layer: $\approx 159$M values $\times$ 4 bytes = 636 MB
2. Total encoder activations: $24 \times 636$ MB = 15.3 GB (if storing all layers)
With activation checkpointing: $\approx 1.3$ GB
Decoder Memory (Inference):
For generating 448 tokens:
- Decoder activations: $448 \times 1024 = 458{,}752$ values per layer
- Cross-attention: $448 \times 3{,}000 = 1{,}344{,}000$ values per layer
- KV cache: $2 \times 24 \times 448 \times 1024 = 22{,}020{,}096$ values
Decoder memory: $\approx 500$ MB
Total Memory (Inference):
- Model parameters: $1.55$B $\times$ 4 bytes = 6.2 GB
- Encoder activations: $\approx 1.3$ GB (with checkpointing)
- Decoder activations: $\approx 0.5$ GB
- KV cache: $\approx 0.1$ GB
- Total: $\approx 8.1$ GB
For FP16: $\approx 4.1$ GB
For INT8 quantization: $\approx 2.1$ GB
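The dominant terms in these memory estimates can be reproduced as follows (FP32 by default, sequence lengths and dimensions as assumed above):

```python
def gb(n_values, bytes_per=4):
    """Convert a value count to gigabytes at the given precision."""
    return n_values * bytes_per / 1e9

T, d, heads, d_mlp = 3000, 1024, 16, 4096

# Peak activation values for one encoder layer (inference):
# the 16 attention score matrices of size 3000 x 3000 dominate
layer_vals = T * d + heads * T * T + T * d_mlp
print(f"per-layer peak: {gb(layer_vals):.2f} GB")

# Model weights at different precisions (~1.55B params for Whisper-large)
params = 1.55e9
for name, nbytes in [("FP32", 4), ("FP16", 2), ("INT8", 1)]:
    print(f"{name} weights: {gb(params, nbytes):.1f} GB")
```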
Part (d): Compare to Text-Only GPT-2
GPT-2 (1.5B parameters):
- Layers: 48
- Hidden size: 1600
- Attention heads: 25
- Vocabulary: 50,257
- Context length: 1024 tokens
Comparison Table:
| Metric | Whisper-large | GPT-2 (1.5B) |
|---|---|---|
| Total Parameters | 1.55B | 1.5B |
| Encoder Layers | 24 | N/A |
| Decoder Layers | 24 | 48 |
| Hidden Size | 1024 | 1600 |
| Attention Heads | 16 | 25 |
| Input Modality | Audio | Text |
| Output Modality | Text | Text |
| Context Length | 3000 (audio) + 448 (text) | 1024 (text) |
| Memory (FP32) | 8.1 GB | 6.5 GB |
| Inference Speed | Slower (audio encoding) | Faster |
Key Differences:
- Architecture:
- Whisper: Encoder-decoder (like T5)
- GPT-2: Decoder-only
- Input Processing:
- Whisper: Audio $\to$ Log-mel $\to$ Encoder
- GPT-2: Text $\to$ Tokens $\to$ Decoder
- Computational Cost:
- Whisper encoder: $O(T^2 d)$ where $T = 3000$
- GPT-2: $O(n^2 d)$ where $n = 1024$
- Whisper's encoder attention is $\approx 9\times$ more expensive, since $(3000/1024)^2 \approx 8.6$
- Memory Footprint:
- Whisper: Larger due to long audio sequences
- GPT-2: Smaller, text-only
- Use Cases:
- Whisper: Speech recognition, translation, transcription
- GPT-2: Text generation, completion, summarization
Why Whisper Needs Encoder-Decoder:
- Cross-modal: Audio input, text output
- Compression: Encoder compresses 3000 audio frames
- Attention: Decoder attends to compressed audio
- Efficiency: Encoder processes audio once, decoder generates text autoregressively
Performance Comparison:
| Task | Whisper | GPT-2 |
|---|---|---|
| Speech Recognition | Excellent | N/A |
| Text Generation | N/A | Excellent |
| Multilingual | 99 languages | Limited |
| Robustness | High (noisy audio) | N/A |
| Zero-shot | Strong | Strong |
Practical Considerations:
- Deployment:
- Whisper: Requires audio preprocessing
- GPT-2: Simple tokenization
- Latency:
- Whisper: Higher (audio encoding + decoding)
- GPT-2: Lower (text-only)
- Hardware:
- Whisper: Needs GPU for real-time (8+ GB VRAM)
- GPT-2: Can run on CPU for small batches
Key Insights:
- Whisper and GPT-2 have similar parameter counts but different architectures
- Encoder-decoder is essential for cross-modal tasks
- Audio sequences are much longer than text, requiring more memory
- Both models benefit from large-scale pre-training
- Whisper's multimodal nature enables speech-to-text applications
Part (a): Proposed Architecture
import torch
import torch.nn as nn
class MultimodalVideoTransformer(nn.Module):
def __init__(self,
visual_dim=768, # ViT features
audio_dim=512, # Audio features
text_dim=768, # BERT features
hidden_dim=1024, # Fusion dimension
num_layers=12, # Fusion transformer layers
num_heads=16,
num_classes=400): # Action recognition classes
super().__init__()
# Modality-specific encoders
self.visual_encoder = VisualEncoder(visual_dim, hidden_dim)
self.audio_encoder = AudioEncoder(audio_dim, hidden_dim)
self.text_encoder = TextEncoder(text_dim, hidden_dim)
# Modality-specific tokens
self.visual_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
self.audio_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
self.text_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
# Fusion transformer
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=0.1,
batch_first=True
)
self.fusion_transformer = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers
)
# Classification head
self.classifier = nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, num_classes)
)
def forward(self, visual_features, audio_features, text_features):
"""
Args:
visual_features: (B, T_v, D_v) - video frames
audio_features: (B, T_a, D_a) - audio segments
text_features: (B, T_t, D_t) - caption tokens
Returns:
logits: (B, num_classes)
"""
B = visual_features.shape[0]
# Encode each modality
visual_emb = self.visual_encoder(visual_features) # (B, T_v, H)
audio_emb = self.audio_encoder(audio_features) # (B, T_a, H)
text_emb = self.text_encoder(text_features) # (B, T_t, H)
# Add modality tokens
visual_token = self.visual_token.expand(B, -1, -1)
audio_token = self.audio_token.expand(B, -1, -1)
text_token = self.text_token.expand(B, -1, -1)
visual_emb = torch.cat([visual_token, visual_emb], dim=1)
audio_emb = torch.cat([audio_token, audio_emb], dim=1)
text_emb = torch.cat([text_token, text_emb], dim=1)
# Concatenate all modalities
multimodal_emb = torch.cat([visual_emb, audio_emb, text_emb], dim=1)
# Shape: (B, 1+T_v + 1+T_a + 1+T_t, H)
# Fusion transformer
fused = self.fusion_transformer(multimodal_emb)
# Aggregate: average modality tokens
visual_rep = fused[:, 0, :]
audio_rep = fused[:, 1+visual_features.shape[1], :]
text_rep = fused[:, 1+visual_features.shape[1]+1+audio_features.shape[1], :]
# Combine representations
combined = (visual_rep + audio_rep + text_rep) / 3
# Classification
logits = self.classifier(combined)
return logits
class VisualEncoder(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.proj = nn.Linear(input_dim, output_dim)
self.norm = nn.LayerNorm(output_dim)
def forward(self, x):
return self.norm(self.proj(x))
class AudioEncoder(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.proj = nn.Linear(input_dim, output_dim)
self.norm = nn.LayerNorm(output_dim)
def forward(self, x):
return self.norm(self.proj(x))
class TextEncoder(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.proj = nn.Linear(input_dim, output_dim)
self.norm = nn.LayerNorm(output_dim)
def forward(self, x):
return self.norm(self.proj(x))
# Example usage
model = MultimodalVideoTransformer()
# Simulate inputs
batch_size = 4
visual = torch.randn(batch_size, 16, 768) # 16 frames
audio = torch.randn(batch_size, 32, 512) # 32 audio segments
text = torch.randn(batch_size, 20, 768) # 20 caption tokens
logits = model(visual, audio, text)
print(f"Output shape: {logits.shape}") # (4, 400)
Part (b): Fusion Mechanism
Architecture Overview:
Input:
Visual: (B, 16, 768) - 16 video frames from ViT
Audio: (B, 32, 512) - 32 audio segments from audio encoder
Text: (B, 20, 768) - 20 caption tokens from BERT
Step 1: Modality-Specific Projection
Visual -> (B, 16, 1024)
Audio -> (B, 32, 1024)
Text -> (B, 20, 1024)
Step 2: Add Modality Tokens
Visual: [V_token, v1, v2, ..., v16] -> (B, 17, 1024)
Audio: [A_token, a1, a2, ..., a32] -> (B, 33, 1024)
Text: [T_token, t1, t2, ..., t20] -> (B, 21, 1024)
Step 3: Concatenate
Multimodal: [V_token, v1, ..., v16, A_token, a1, ..., a32, T_token, t1, ..., t20]
Shape: (B, 71, 1024)
Step 4: Fusion Transformer (12 layers)
Cross-modal attention enables interaction
Output: (B, 71, 1024)
Step 5: Aggregate
Extract modality tokens: V_token, A_token, T_token
Average: (V_token + A_token + T_token) / 3
Shape: (B, 1024)
Step 6: Classification
MLP: (B, 1024) -> (B, 400)
Fusion Strategies Comparison:
- Early Fusion (Concatenation):
- Concatenate features before transformer
- Simple but limited cross-modal interaction
- Used in this design
- Late Fusion (Ensemble):
- Process modalities separately
- Combine predictions at the end
- No cross-modal learning
- Cross-Modal Attention:
- Visual attends to audio and text
- Audio attends to visual and text
- More complex but better interaction
- Bottleneck Fusion:
- Compress each modality to bottleneck tokens
- Fuse bottlenecks
- More efficient for long sequences
Why This Design:
- Modality tokens: Aggregate information from each modality
- Shared transformer: Enables cross-modal attention
- Flexible: Can handle missing modalities
- Scalable: Easy to add more modalities
Part (c): Training Objective
Primary Objective: Action Recognition
$\mathcal{L}_{\text{action}} = -\frac{1}{B} \sum_{i=1}^{B} \log P(y_i | \vv_i, \va_i, \vt_i)$
where:
- $\vv_i$: visual features for sample $i$
- $\va_i$: audio features for sample $i$
- $\vt_i$: text features for sample $i$
- $y_i$: ground truth action class
Auxiliary Objectives (Multi-Task Learning):
1. Contrastive Loss (Cross-Modal Alignment):
Align visual-audio, visual-text, audio-text pairs:
$\mathcal{L}_{\text{contrast}} = \mathcal{L}_{\text{VA}} + \mathcal{L}_{\text{VT}} + \mathcal{L}_{\text{AT}}$
where each term is CLIP-style contrastive loss:
$\mathcal{L}_{\text{VA}} = -\frac{1}{B} \sum_{i=1}^{B} \log \frac{\exp(\text{sim}(\vv_i, \va_i) / \tau)}{\sum_{j=1}^{B} \exp(\text{sim}(\vv_i, \va_j) / \tau)}$
2. Masked Modality Modeling:
Randomly mask one modality and predict it from others:
$\mathcal{L}_{\text{mask}} = \mathcal{L}_{\text{mask-V}} + \mathcal{L}_{\text{mask-A}} + \mathcal{L}_{\text{mask-T}}$
Example (mask visual): $\mathcal{L}_{\text{mask-V}} = \|\hat{\vv} - \vv\|_2^2$
where $\hat{\vv} = f(\va, \vt)$ is predicted visual features.
3. Temporal Ordering:
Predict correct temporal order of video segments:
$\mathcal{L}_{\text{temporal}} = -\log P(\text{order} | \vv, \va, \vt)$
Total Training Objective:
$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{action}} + \lambda_1 \mathcal{L}_{\text{contrast}} + \lambda_2 \mathcal{L}_{\text{mask}} + \lambda_3 \mathcal{L}_{\text{temporal}}$
Typical weights: $\lambda_1 = 0.1$, $\lambda_2 = 0.05$, $\lambda_3 = 0.05$
Training Recipe:
import torch.nn.functional as F

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
# Learning rate schedule
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# Training loop
for epoch in range(100):
for batch in dataloader:
visual, audio, text, labels = batch
# Forward pass
logits = model(visual, audio, text)
# Action recognition loss
loss_action = F.cross_entropy(logits, labels)
# Contrastive loss (optional; contrastive_loss and the get_*_rep
# helpers are assumed here, not defined on the model above)
visual_rep = model.get_visual_rep(visual)
audio_rep = model.get_audio_rep(audio)
loss_contrast = contrastive_loss(visual_rep, audio_rep)
# Total loss
loss = loss_action + 0.1 * loss_contrast
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
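The `contrastive_loss` call in the recipe above is left undefined; one possible CLIP-style implementation (symmetric InfoNCE over the batch, matching the formulas in the training-objective section) is:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(rep_a, rep_b, temperature=0.07):
    """Symmetric InfoNCE between two batches of aligned representations."""
    rep_a = F.normalize(rep_a, dim=-1)
    rep_b = F.normalize(rep_b, dim=-1)
    logits = rep_a @ rep_b.t() / temperature          # (B, B) similarity matrix
    labels = torch.arange(rep_a.size(0), device=rep_a.device)
    # Diagonal entries are the positive pairs; both directions are penalized
    return 0.5 * (F.cross_entropy(logits, labels)
                  + F.cross_entropy(logits.t(), labels))
```

For random, unaligned representations this loss sits near $\log B$; for perfectly aligned pairs it approaches zero, mirroring the training dynamics described for CLIP.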
Data Augmentation:
- Visual: Random crop, color jitter, temporal sampling
- Audio: Time stretching, pitch shifting, noise injection
- Text: Synonym replacement, back-translation
- Multimodal: Random modality dropout (robustness)
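Random modality dropout can be implemented by zeroing out a modality's embeddings with some probability during training; a minimal sketch (the drop probability and helper name are illustrative, not part of the model above):

```python
import torch

def modality_dropout(visual, audio, text, p=0.15, training=True):
    """Randomly zero out whole modalities so the model learns to cope
    with missing inputs. Always keeps at least one modality intact."""
    if not training:
        return visual, audio, text
    mods = [visual, audio, text]
    drop = [bool(torch.rand(()) < p) for _ in mods]
    if all(drop):                                   # never drop everything
        drop[int(torch.randint(3, ()))] = False
    return tuple(torch.zeros_like(m) if d else m
                 for m, d in zip(mods, drop))
```

Applied just before the fusion transformer, this forces the shared layers to produce useful predictions from any subset of modalities.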
Part (d): Parameter Count Estimation
Component Breakdown:
1. Modality-Specific Encoders:
Visual Encoder:
- Projection: $768 \times 1024 = 786{,}432$
- LayerNorm: $2 \times 1024 = 2{,}048$
- Total: $788{,}480$
Audio Encoder:
- Projection: $512 \times 1024 = 524{,}288$
- LayerNorm: $2 \times 1024 = 2{,}048$
- Total: $526{,}336$
Text Encoder:
- Projection: $768 \times 1024 = 786{,}432$
- LayerNorm: $2 \times 1024 = 2{,}048$
- Total: $788{,}480$
Encoder total: $2{,}103{,}296$ parameters
2. Modality Tokens:
- 3 tokens $\times$ 1024 = $3{,}072$ parameters
3. Fusion Transformer (12 layers):
Per layer:
- Self-attention: $4 \times 1024^2 = 4{,}194{,}304$
- MLP: $2 \times 1024 \times 4096 = 8{,}388{,}608$
- LayerNorm: $2 \times 2 \times 1024 = 4{,}096$
- Total per layer: $12{,}587{,}008$
12 layers: $12 \times 12{,}587{,}008 = 151{,}044{,}096$ parameters
4. Classification Head:
- LayerNorm: $2 \times 1024 = 2{,}048$
- Linear 1: $1024 \times 1024 = 1{,}048{,}576$
- Linear 2: $1024 \times 400 = 409{,}600$
- Total: $1{,}460{,}224$
Total Model Parameters:
$2{,}103{,}296 + 3{,}072 + 151{,}044{,}096 + 1{,}460{,}224 = 154{,}610{,}688$
Total: $\approx 155$M parameters
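The component tally can be reproduced programmatically (biases omitted, matching the counts above):

```python
H, d_mlp, n_classes, n_layers = 1024, 4096, 400, 12

def modality_encoder(d_in):
    # Linear projection (no bias, matching the tally) + LayerNorm scale/shift
    return d_in * H + 2 * H

encoders = modality_encoder(768) + modality_encoder(512) + modality_encoder(768)
tokens = 3 * H                                    # three learned modality tokens
fusion_layer = 4 * H * H + 2 * H * d_mlp + 2 * 2 * H
fusion = n_layers * fusion_layer
head = 2 * H + H * H + H * n_classes              # LayerNorm + two linears

total = encoders + tokens + fusion + head
print(f"total ~ {total / 1e6:.1f}M")
```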
Memory Footprint (FP32):
- Parameters: $155$M $\times$ 4 bytes = 620 MB
- Activations (batch size 4):
- Input: $4 \times 71 \times 1024 = 290{,}816$ values
- Per layer: $\approx 2$M values
- Total: $\approx 24$M values $\times$ 4 bytes = 96 MB
- Gradients: 620 MB (same as parameters)
- Optimizer states (AdamW): $2 \times 620$ MB = 1.24 GB
Total training memory: $\approx 2.6$ GB
Comparison with Baselines:
| Model | Parameters | Modalities |
|---|---|---|
| Single-modal (visual only) | 86M | 1 |
| Two-modal (visual + audio) | 120M | 2 |
| Our three-modal | 155M | 3 |
| CLIP (ViT-B/32) | 151M | 2 |
| Whisper-large | 1.55B | 2 |
Design Trade-offs:
- Parameter efficiency:
- Shared fusion transformer reduces parameters
- Modality-specific encoders are lightweight
- Could use pre-trained encoders (ViT, BERT, etc.)
- Computational cost:
- Sequence length: 71 tokens (manageable)
- Attention complexity: $O(71^2 \times 1024) \approx 5$M operations
- Inference time: $\approx 50$ ms on GPU
- Scalability:
- Easy to add more modalities (depth, optical flow, etc.)
- Can increase fusion layers for better interaction
- Bottleneck fusion for longer sequences
Practical Recommendations:
- Use pre-trained encoders: ViT for visual, Wav2Vec for audio, BERT for text
- Freeze encoders initially: Train fusion transformer first
- Fine-tune end-to-end: Unfreeze all parameters later
- Modality dropout: Randomly drop modalities during training for robustness
- Temporal modeling: Add temporal attention for video sequences
Key Insights:
- Multimodal fusion requires careful architecture design
- Modality tokens enable flexible aggregation
- Shared transformer enables cross-modal learning
- Multi-task learning improves representation quality
- Parameter count is reasonable for modern GPUs
- Pre-trained encoders significantly improve performance