Contrastive Predictive Coding Representation Learning: A Comprehensive Guide
Traditionally, machine learning (ML) has been broadly classified into two types: supervised and unsupervised learning. Supervised learning trains a model on labeled data, while unsupervised learning discovers patterns in data without labels. To exploit the vast amounts of unlabeled data available, self-supervised learning (SSL) has emerged as a strategy that derives its training signal from the data itself.
Self-Supervised Learning (SSL)
In contrastive self-supervised learning, unlabeled data is turned into positive and negative samples, mimicking supervised binary classification: the sample under consideration (together with its augmented views) is treated as positive, and all other samples as negatives. Many self-supervised methods learn joint embeddings, mapping two or more views of the data into a shared embedding space, and can be broadly classified into two types: contrastive and non-contrastive methods.
Contrastive methods maximize the similarity between different views of the same sample (e.g., two augmentations of the same image) while minimizing it for other samples. Non-contrastive methods, on the other hand, do not use negative samples. SimCLR and MoCo are examples of contrastive methods, while BYOL and DINO are examples of non-contrastive methods.
Contrastive Learning (CL)
Contrastive learning centers on maximizing the similarity of positive pairs while minimizing the similarity of negative pairs. Intuitively, an image of a mango should be more similar to other mango images than to images of other objects. Without labels, the simplest setting treats each data point as its own positive and all other data points as negatives. If x is a data point and T(·) denotes its transformed representation (for example, an augmented view passed through the encoder), then the pair (x, T(x)) is positive and (x, T(y)) is negative, where y is any other data point.
Training
For each pair (x, T(y)), a similarity score is computed. The model is trained to maximize this score when y = x and to minimize it for all negative pairs, which is achieved with a suitable loss function such as cross-entropy over the similarity scores. In practice, training is performed in batches: in a batch of 32 images, each image should score highest against its own augmented view and lower against the other 31 images. Data augmentation is what creates these multiple views of the same data point, forcing the model to learn features that are invariant to the augmentations yet discriminative across images.
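The batch procedure described above can be sketched in PyTorch as follows. This is a minimal illustration rather than any particular paper's code: it assumes two augmented views per image, cosine similarity as the score, and a temperature of 0.1, and the function name batch_contrastive_loss is hypothetical. Row i of the similarity matrix compares view 1 of image i with view 2 of every image, so the diagonal holds the positive pairs and the other 31 entries in each row are negatives.

```python
import torch
import torch.nn.functional as F

def batch_contrastive_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Cross-entropy over a (B x B) similarity matrix.

    z1, z2: (B, D) embeddings of two augmented views of the same B images.
    Entry (i, j) scores view 1 of image i against view 2 of image j;
    the diagonal (i == j) holds the positive pairs, everything else is a negative.
    """
    z1 = F.normalize(z1, dim=1)  # unit vectors, so dot product = cosine similarity
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.T / temperature  # (B, B) similarity scores
    targets = torch.arange(z1.size(0), device=z1.device)  # positive for row i is column i
    return F.cross_entropy(logits, targets)

# Example: a batch of 32 images with 128-dimensional embeddings
torch.manual_seed(0)
z = torch.randn(32, 128)
aligned = batch_contrastive_loss(z, z + 0.01 * torch.randn(32, 128))
random_pairs = batch_contrastive_loss(z, torch.randn(32, 128))
print(aligned.item() < random_pairs.item())  # True: matching views give a lower loss
```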
Contrastive Learning Examples
Contrastive learning has evolved significantly over the past decade, from CNN-based instance discrimination on images to multimodal techniques like CLIP. Some modern CL algorithms include:
- Contrastive Predictive Coding (CPC)
- A Simple Framework for Contrastive Learning of (visual) Representations (SimCLR)
- Momentum Contrast (MoCo)
- Contrastive Language-Image Pretraining (CLIP)
Contrastive Predictive Coding (CPC)
Inspired by predictive coding theories from neuroscience, CPC aims to capture the high-level information shared across distant parts of a signal while discarding low-level, local noise. CPC compresses high-dimensional data into a latent embedding space, predicts future latents in this space, and trains the model with InfoNCE, a loss function based on Noise-Contrastive Estimation (NCE).
SimCLR
SimCLR is a contrastive learning framework for computer vision that doesn't require specialized architectures or a memory bank.
SimCLR works as follows:
- An image is chosen randomly, and its views (two in the original implementation) are generated using augmentation techniques like random cropping, random color distortion, or Gaussian blurring.
- Image representation/embedding is computed using a ResNet-based CNN.
- This representation is further transformed into a (non-linear) projection using an MLP.
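- Both the CNN and the MLP are trained to minimize a contrastive loss (NT-Xent, a normalized temperature-scaled cross-entropy).
After pretraining, the projection head is discarded and the CNN representation is used for downstream tasks; fine-tuning the CNN on labeled images can further improve its performance and generalization.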
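The steps above can be sketched as follows: a nonlinear projection head on top of encoder features, and the NT-Xent loss computed over all 2N views of a batch. Layer sizes and the temperature of 0.5 are illustrative assumptions, ProjectionHead and nt_xent_loss are hypothetical names, and the ResNet encoder is replaced by random stand-in features so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """Nonlinear MLP head g(.) applied on top of encoder features h (dimensions are illustrative)."""
    def __init__(self, in_dim: int = 2048, hidden_dim: int = 2048, out_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h)

def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent over 2N views: each view's positive is the other view of the same image."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    sim = z @ z.T / temperature  # (2N, 2N) scaled cosine similarities
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))  # a view is never compared with itself
    # Row i (first view) pairs with row i + n (second view), and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)

# Usage with stand-in encoder features for two views of a batch of 8 images
head = ProjectionHead()
h1, h2 = torch.randn(8, 2048), torch.randn(8, 2048)
loss = nt_xent_loss(head(h1), head(h2))
loss.backward()  # gradients flow into the projection head (and, in practice, the encoder)
```

The head only shapes the space in which the loss is computed; downstream tasks use the encoder output h.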
Insights from SimCLR
SimCLR's authors provided insights useful for almost any contrastive learning method:
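- Composing multiple augmentations is critical; random cropping combined with color distortion works particularly well.
- A nonlinear projection head improves the quality of the representation learned before it.
- Scaling up (larger batch sizes, longer training, and bigger networks) improves performance.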
Momentum Contrast (MoCo)
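Momentum Contrast (MoCo) views contrastive learning as a dictionary lookup. Data augmentation produces two views of an image, x_q and x_k. A query encoder maps x_q to a query embedding. A key encoder, a slowly updated momentum (moving-average) copy of the query encoder, maps x_k to a key; keys from recent batches are kept in a queue that serves as a large, dynamic dictionary of negatives. The encoded query should match its own key and not the other keys in the dictionary, which is enforced with a contrastive loss (InfoNCE). Only the query encoder is trained by backpropagation; the key encoder is updated as an exponential moving average of the query encoder's weights.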
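The two mechanisms that distinguish MoCo, the momentum update of the key encoder and the first-in, first-out queue of keys, can be sketched as follows. This is a simplified single-process illustration with hypothetical helper names (momentum_update, enqueue) and stand-in linear encoders. Note that this sketch stores the queue as (K, D), one key per row, and m = 0.999 is simply a commonly used value.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def momentum_update(encoder_q: nn.Module, encoder_k: nn.Module, m: float = 0.999) -> None:
    """Key encoder <- m * key encoder + (1 - m) * query encoder (no gradients involved)."""
    for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        p_k.mul_(m).add_(p_q.detach(), alpha=1.0 - m)

@torch.no_grad()
def enqueue(queue: torch.Tensor, keys: torch.Tensor, ptr: int) -> int:
    """Overwrite the oldest keys in a (K, D) queue with the newest batch; return the new pointer."""
    batch_size = keys.size(0)
    assert queue.size(0) % batch_size == 0  # keeps the ring buffer aligned with batch boundaries
    queue[ptr:ptr + batch_size] = keys
    return (ptr + batch_size) % queue.size(0)

# Usage: the key encoder starts as a frozen copy of the query encoder
encoder_q = nn.Linear(32, 16)
encoder_k = copy.deepcopy(encoder_q)
for p in encoder_k.parameters():
    p.requires_grad = False
queue, ptr = torch.randn(64, 16), 0
momentum_update(encoder_q, encoder_k)
ptr = enqueue(queue, encoder_k(torch.randn(8, 32)), ptr)
```

With m close to 1, the keys in the queue stay consistent with each other even though they were produced at different training steps.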
CLIP
CLIP learns from images paired with their captions, using two encoders: one for text and one for images. The image encoder (a ResNet or a Vision Transformer (ViT) in the original work) produces an image embedding, and a Transformer text encoder maps the tokenized caption to a text embedding; both are projected into a shared embedding space. Within a batch of N pairs, both encoders are trained to maximize the similarity of each matching pair (I_i, T_i) and to minimize its similarity to every non-matching pair (I_i, T_j), j ≠ i. At test time, CLIP can classify an image zero-shot: given a set of candidate captions, it returns the caption with the highest similarity (probability) for the image.
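A minimal sketch of CLIP's symmetric contrastive objective and of zero-shot prediction, assuming the image and text encoders have already produced embeddings of equal dimension. In CLIP the temperature is a learned parameter; the fixed value 0.07 used here is an assumption for illustration, and both function names are hypothetical.

```python
import torch
import torch.nn.functional as F

def clip_loss(image_emb: torch.Tensor, text_emb: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Symmetric cross-entropy: (I_i, T_i) are positives, (I_i, T_j) for j != i are negatives."""
    image_emb = F.normalize(image_emb, dim=1)
    text_emb = F.normalize(text_emb, dim=1)
    logits = image_emb @ text_emb.T / temperature  # (N, N) image-to-text similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, targets)  # each image must pick its own caption
    loss_t2i = F.cross_entropy(logits.T, targets)  # each caption must pick its own image
    return (loss_i2t + loss_t2i) / 2

def zero_shot_predict(image_emb: torch.Tensor, caption_embs: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Return, for each image, the index of the candidate caption with the highest probability."""
    sims = F.normalize(image_emb, dim=-1) @ F.normalize(caption_embs, dim=-1).T  # (images, captions)
    probs = (sims / temperature).softmax(dim=-1)
    return probs.argmax(dim=-1)
```

In practice, the candidate captions for zero-shot classification are often templated prompts such as "a photo of a {label}", one per class.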
Code Example
The official implementation of MoCo by Meta Research (facebookresearch/moco) centers on the MoCo class. By default, its constructor sets the feature dimension dim to 128, the queue size K to 65,536, the momentum coefficient m to 0.999, and the softmax temperature T to 0.07. The mlp flag optionally replaces the final fully connected layer with a two-layer MLP projection head, the design introduced by SimCLR and adopted by MoCo v2, so the same class supports both MoCo v1 and v2.
import torch
import torch.nn as nn

# Constructor of the MoCo(nn.Module) class (excerpt)
def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
    super(MoCo, self).__init__()
    self.K = K
    self.m = m
    self.T = T

    # create the encoders
    # num_classes is the output fc dimension
    self.encoder_q = base_encoder(num_classes=dim)
    self.encoder_k = base_encoder(num_classes=dim)

    if mlp:  # hack: brute-force replacement
        dim_mlp = self.encoder_q.fc.weight.shape[1]
        self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
        self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

    for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
        param_k.data.copy_(param_q.data)  # initialize
        param_k.requires_grad = False  # not update by gradient

    # create the queue
    self.register_buffer("queue", torch.randn(dim, K))
    self.queue = nn.functional.normalize(self.queue, dim=0)

    self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

After each training step, the keys produced by the momentum encoder (gathered from all GPUs) are written into the queue as columns, which is why the code stores keys.T, overwriting the oldest entries. The write position is tracked in self.queue_ptr, which advances by one batch and wraps around modulo K.
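# Method of the MoCo class (excerpt); concat_all_gather is a helper defined in the same file
# that gathers tensors from all GPUs with torch.distributed
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
    # gather keys before updating queue
    keys = concat_all_gather(keys)

    batch_size = keys.shape[0]

    ptr = int(self.queue_ptr)
    assert self.K % batch_size == 0  # for simplicity

    # replace the keys at ptr (dequeue and enqueue)
    self.queue[:, ptr : ptr + batch_size] = keys.T
    ptr = (ptr + batch_size) % self.K  # move pointer

    self.queue_ptr[0] = ptr

In the forward pass, the query features are computed and normalized. The key encoder is first updated by momentum and then computes the keys without gradients; before encoding, the key batch is shuffled across GPUs so that batch-normalization statistics cannot leak information between a query and its key. The positive logits come from each query-key pair and the negative logits from the queue, and finally the queue is updated with the new keys.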
# Method of the MoCo class (excerpt); _momentum_update_key_encoder, _batch_shuffle_ddp and
# _batch_unshuffle_ddp are further methods of the class that are not shown here
def forward(self, im_q, im_k):
    """
    Input:
        im_q: a batch of query images
        im_k: a batch of key images
    Output:
        logits, targets
    """
    # compute query features
    q = self.encoder_q(im_q)  # queries: NxC
    q = nn.functional.normalize(q, dim=1)

    # compute key features
    with torch.no_grad():  # no gradient to keys
        self._momentum_update_key_encoder()  # update the key encoder

        # shuffle for making use of BN
        im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

        k = self.encoder_k(im_k)  # keys: NxC
        k = nn.functional.normalize(k, dim=1)

        # undo shuffle
        k = self._batch_unshuffle_ddp(k, idx_unshuffle)

    # compute logits
    # Einstein sum is more intuitive
    # positive logits: Nx1
    l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
    # negative logits: NxK
    l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

    # logits: Nx(1+K)
    logits = torch.cat([l_pos, l_neg], dim=1)

    # apply temperature
    logits /= self.T

    # labels: positive key indicators (the positive is always at index 0)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

    # dequeue and enqueue
    self._dequeue_and_enqueue(k)

    return logits, labels

The training loop then applies nn.CrossEntropyLoss to the returned logits and labels, which is exactly InfoNCE with the positive key placed in class 0.
Contrastive Learning and Vector Databases
Contrastive learning revolves around embeddings of input data and the distances between them, which makes it directly relevant to vector databases. The goal is to learn an embedding space where similar sample pairs are close and dissimilar ones are far apart. Contrastive learning can be applied in both supervised and unsupervised settings. Early loss functions involved only one positive and one negative sample.
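The connection to vector databases is nearest-neighbour search in the learned embedding space. A minimal sketch, assuming cosine similarity and an in-memory tensor standing in for the database; a production vector database would typically use an approximate index rather than this exhaustive scan, and top_k_neighbors is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def top_k_neighbors(query: torch.Tensor, index: torch.Tensor, k: int = 3):
    """Brute-force cosine-similarity search, the operation a vector database accelerates.

    query: (D,) embedding; index: (M, D) stored embeddings. Returns (scores, ids).
    """
    sims = F.normalize(index, dim=1) @ F.normalize(query, dim=0)  # (M,) cosine similarities
    scores, ids = torch.topk(sims, k)
    return scores, ids

# Toy index: 1000 random embeddings; a slightly perturbed copy of item 42 should retrieve it first
torch.manual_seed(0)
index = torch.randn(1000, 128)
scores, ids = top_k_neighbors(index[42] + 0.05 * torch.randn(128), index)
print(ids[0].item())  # 42
```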
Contrastive Loss
Contrastive loss (Chopra et al.) aims to learn a function fθ(.) that encodes input samples xi into embedding vectors such that examples from the same class have similar embeddings and samples from different classes have very different ones.
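One common way to write this loss uses the Euclidean distance between the two embeddings: same-class pairs are penalized by their squared distance, and different-class pairs are penalized only while they are closer than a margin ε. The sketch below follows that form; the margin of 1.0 and the function name are illustrative assumptions.

```python
import torch

def pairwise_contrastive_loss(emb_i: torch.Tensor, emb_j: torch.Tensor, same_class: torch.Tensor,
                              margin: float = 1.0) -> torch.Tensor:
    """Margin-based contrastive loss over a batch of pairs.

    same_class: bool tensor (B,). Same-class pairs are pulled together (squared distance);
    different-class pairs are pushed apart until they are at least `margin` away.
    """
    dist = torch.norm(emb_i - emb_j, dim=1)  # Euclidean distance per pair
    pos_term = dist.pow(2)
    neg_term = torch.clamp(margin - dist, min=0).pow(2)
    return torch.where(same_class, pos_term, neg_term).mean()
```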
Triplet Loss
Triplet loss (Schroff et al. 2015) selects, for each anchor input x, one positive sample x+ from the same class as x and one negative sample x- from a different class, and trains the embedding so that the anchor is closer to the positive than to the negative by at least a margin.
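The triplet objective can be sketched as follows, assuming squared Euclidean distances and an illustrative margin of 0.2; triplet_loss is a hypothetical name. PyTorch also ships a built-in variant, torch.nn.TripletMarginLoss, which by default uses non-squared distances.

```python
import torch

def triplet_loss(anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor,
                 margin: float = 0.2) -> torch.Tensor:
    """max(0, ||f(x) - f(x+)||^2 - ||f(x) - f(x-)||^2 + margin), averaged over the batch."""
    d_pos = (anchor - positive).pow(2).sum(dim=1)  # squared distance to the positive
    d_neg = (anchor - negative).pow(2).sum(dim=1)  # squared distance to the negative
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()
```

Triplets that already satisfy the margin contribute zero loss, which is why the choice of negatives (see hard negative mining) matters so much for this loss.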
Lifted Structured Loss
Lifted Structured Loss (Song et al. 2016) uses all pairwise relations within a training batch and aims to learn embeddings in which every positive pair is closer than any negative pair by a fixed margin.
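A direct, loop-based sketch of the lifted structured loss, following a commonly cited smooth formulation. The margin, the function name, and the requirement that the batch contain at least one positive pair are assumptions of this sketch; a practical implementation would vectorize the loops.

```python
import torch

def lifted_structured_loss(emb: torch.Tensor, labels: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
    """Smooth lifted structured loss over all pairs in a batch.

    For each positive pair (i, j):
        J_ij = D_ij + log( sum_k exp(margin - D_ik) + sum_l exp(margin - D_jl) ),
    where k and l range over samples whose label differs from that of i and j respectively.
    Loss = 1 / (2|P|) * sum over positive pairs of max(0, J_ij)^2.
    """
    dist = torch.cdist(emb, emb)  # (B, B) Euclidean distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B) same-label mask
    neg_exp = torch.exp(margin - dist) * (~same)  # exp(margin - D), kept only for negative pairs
    losses = []
    B = emb.size(0)
    for i in range(B):
        for j in range(i + 1, B):
            if same[i, j]:  # positive pair
                j_ij = dist[i, j] + torch.log(neg_exp[i].sum() + neg_exp[j].sum())
                losses.append(torch.clamp(j_ij, min=0).pow(2))
    return torch.stack(losses).sum() / (2 * len(losses))
```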
Noise Contrastive Estimation (NCE)
Noise Contrastive Estimation (NCE) is a method for estimating the parameters of a statistical model. It trains a logistic regression classifier to distinguish the target data from samples drawn from a known noise distribution.
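The logistic-regression view of NCE can be sketched as follows. The function name and arguments are hypothetical: the sketch assumes you can evaluate the model's (possibly unnormalized) log-density and the noise distribution's log-density at both data and noise samples, with k noise samples per data sample.

```python
import math
import torch
import torch.nn.functional as F

def nce_loss(log_model_data, log_model_noise, log_noise_data, log_noise_noise, k: int = 1):
    """Binary NCE: logistic regression separating data samples (label 1) from noise samples (label 0).

    Inputs are log-densities: log p_model(.) and log q(.) evaluated at data and at noise samples.
    With k noise samples per data sample, the classifier logit is log p_model(x) - log q(x) - log k.
    """
    logit_data = log_model_data - log_noise_data - math.log(k)
    logit_noise = log_model_noise - log_noise_noise - math.log(k)
    loss_data = F.binary_cross_entropy_with_logits(logit_data, torch.ones_like(logit_data))
    loss_noise = F.binary_cross_entropy_with_logits(logit_noise, torch.zeros_like(logit_noise))
    # loss_noise is a mean over all noise samples; scaling by k sums over the k noise samples per data point
    return loss_data + k * loss_noise
```

A model that assigns higher density than the noise distribution to real data, and lower density to noise, achieves a lower loss.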
InfoNCE Loss
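The InfoNCE loss in CPC (Contrastive Predictive Coding; van den Oord et al. 2018) casts the problem as categorical cross-entropy: given one positive sample and N - 1 negative samples, the model must identify the positive. At the optimum, the score function estimates the density ratio between the conditional distribution p(x | c) and the proposal distribution p(x) from which the negatives are drawn.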
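For reference, the loss and the density ratio it estimates can be written as follows, using the notation of the CPC paper: c_t is the context at time t, and X = {x_1, ..., x_N} contains one positive sample drawn from p(x_{t+k} | c_t) and N - 1 negatives drawn from the proposal distribution p(x_{t+k}). Minimizing the loss maximizes a lower bound on the mutual information between x_{t+k} and c_t.

```latex
\mathcal{L}_N = -\,\mathbb{E}_{X}\left[\log \frac{f_k(x_{t+k}, c_t)}{\sum_{x_j \in X} f_k(x_j, c_t)}\right],
\qquad
f_k(x_{t+k}, c_t) \propto \frac{p(x_{t+k} \mid c_t)}{p(x_{t+k})},
\qquad
I(x_{t+k}; c_t) \ge \log N - \mathcal{L}_N
```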
Soft-Nearest Neighbors Loss
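Soft-Nearest Neighbors Loss (Salakhutdinov & Hinton 2007; Frosst et al. 2019) extends the contrastive idea to multiple positive samples that share a label. A temperature parameter controls how concentrated the features are in the representation space: at low temperature, only the closest neighbours contribute to the loss.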
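A vectorized sketch, assuming squared Euclidean distance as the dissimilarity and that every sample has at least one other sample with the same label in the batch (otherwise its term is infinite); the function name is hypothetical.

```python
import torch

def soft_nearest_neighbors_loss(emb: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """-mean_i log( sum_{j != i, y_j = y_i} e^{-d_ij / T} / sum_{k != i} e^{-d_ik / T} ), d = squared distance."""
    d = (emb.unsqueeze(1) - emb.unsqueeze(0)).pow(2).sum(-1)  # (B, B) squared Euclidean distances
    eye = torch.eye(emb.size(0), dtype=torch.bool, device=emb.device)
    logits = (-d / temperature).masked_fill(eye, float("-inf"))  # exclude k == i
    same = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye  # same-label neighbours, j != i
    log_num = torch.logsumexp(logits.masked_fill(~same, float("-inf")), dim=1)
    log_den = torch.logsumexp(logits, dim=1)
    return -(log_num - log_den).mean()
```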
Data Augmentation
Data augmentation creates noisy versions of training samples that are fed into the loss as positive samples. Proper data augmentation is critical for learning good, generalizable embedding features.
Batch Size
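Using a large batch size during training is another key ingredient in many contrastive learning methods, especially those that draw negatives from the same batch: a larger batch provides more, and more diverse, negative samples per step.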
Hard Negative Mining
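Hard negative samples have labels different from the anchor's but embeddings very close to the anchor's embedding. Because they are hard to tell apart from positives, they carry a stronger training signal than randomly chosen negatives.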
Debiased Contrastive Learning
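Chuang et al. (2020) studied sampling bias in contrastive learning: when negatives are drawn at random from unlabeled data, some of them belong to the same class as the anchor (false negatives). They proposed a debiased contrastive loss that corrects for this by estimating the contribution of false negatives using a class prior.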
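A sketch of the debiased loss for one positive and N negatives per anchor, assuming L2-normalized embeddings, a temperature t = 0.5, and a class prior tau_plus = 0.1 (both illustrative); the function name is hypothetical. With tau_plus = 0 it reduces to the standard InfoNCE-style contrastive loss.

```python
import math
import torch
import torch.nn.functional as F

def debiased_contrastive_loss(anchor, positive, negatives, tau_plus: float = 0.1, t: float = 0.5):
    """Debiased contrastive loss in the spirit of Chuang et al. (2020).

    anchor, positive: (B, D); negatives: (B, N, D). tau_plus is the assumed class prior, i.e. the
    probability that a randomly drawn "negative" actually shares the anchor's class.
    """
    anchor, positive, negatives = (F.normalize(x, dim=-1) for x in (anchor, positive, negatives))
    pos = torch.exp((anchor * positive).sum(-1) / t)  # (B,)
    neg = torch.exp(torch.einsum("bd,bnd->bn", anchor, negatives) / t)  # (B, N)
    n = negatives.size(1)
    # Subtract the expected contribution of false negatives, then rescale by 1 / (1 - tau_plus)
    ng = (neg.sum(-1) - tau_plus * n * pos) / (1 - tau_plus)
    # Clamp at the theoretical minimum N * e^{-1/t} so the corrected estimate never goes negative
    ng = torch.clamp(ng, min=n * math.exp(-1 / t))
    return -torch.log(pos / (pos + ng)).mean()
```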
Multi-Crop Augmentation
Multi-crop augmentation (introduced with SwAV) uses two standard-resolution crops and additionally samples a set of low-resolution crops that cover only small parts of the image; the low resolution keeps these extra views cheap to compute.
AutoAugment
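AutoAugment (Cubuk et al. 2018) frames the search for the best data augmentation operations as a learning problem: a controller trained with reinforcement learning proposes augmentation policies (which operations to apply, with what probability and magnitude), and each policy is rewarded by the validation accuracy of a model trained with it.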
CutMix
CutMix (Yun et al. 2019) performs region-level mixing: it creates a new example by replacing a rectangular region of one image with the corresponding patch from another image, and mixes the labels in proportion to the area of the patch.
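A sketch of CutMix for a batch of images, following the area-proportional label mixing described above; alpha = 1.0 (a uniformly distributed mixing ratio) and the function name are assumptions.

```python
import torch

def cutmix(images: torch.Tensor, labels: torch.Tensor, alpha: float = 1.0):
    """Paste a random box from a shuffled copy of the batch into each image.

    images: (B, C, H, W); labels: (B,) integer class ids.
    Returns mixed images, both label sets, and lam = fraction of each image that is still original.
    """
    B, _, H, W = images.shape
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(B)
    # Box covers roughly (1 - lam) of the image area, centred at a random point and clipped to the borders
    cut_h, cut_w = int(H * (1 - lam) ** 0.5), int(W * (1 - lam) ** 0.5)
    cy, cx = torch.randint(H, (1,)).item(), torch.randint(W, (1,)).item()
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lam from the actual (clipped) box so the label weights match the pixels
    lam = 1 - (y2 - y1) * (x2 - x1) / (H * W)
    return mixed, labels, labels[perm], lam

# The training loss is then weighted by area:
# loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
```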
MoCHi
MoCHi ("Mixing of Contrastive Hard Negatives"; Kalantidis et al. 2020) maintains a queue of K negative features, sorts them by similarity to the query, and synthesizes additional hard negatives by mixing the hardest ones with each other and with the query.
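A sketch of MoCHi-style negative synthesis for a single query. The number of hardest negatives kept (n_hardest), the counts of synthetic points (s and s_prime), and the mixing ranges are illustrative assumptions; the function name is hypothetical.

```python
import torch
import torch.nn.functional as F

def mochi_hard_negatives(query, queue, n_hardest: int = 16, s: int = 8, s_prime: int = 4):
    """Synthesize extra hard negatives in the spirit of MoCHi.

    query: (D,) L2-normalized query; queue: (K, D) L2-normalized negatives.
    1. Rank queue entries by similarity to the query and keep the n_hardest most similar.
    2. Mix random pairs of hard negatives: h = a * n_i + (1 - a) * n_j, a ~ U(0, 1).
    3. Mix the query into hard negatives: h' = b * q + (1 - b) * n_j, b ~ U(0, 0.5).
    All synthetic points are re-normalized to the unit sphere.
    """
    sims = queue @ query  # (K,) cosine similarities
    hard = queue[torch.topk(sims, n_hardest).indices]  # (n_hardest, D) hardest negatives
    i, j = torch.randint(n_hardest, (2, s))  # random pairs of hard negatives
    a = torch.rand(s, 1)
    mixed = F.normalize(a * hard[i] + (1 - a) * hard[j], dim=1)  # (s, D)
    k = torch.randint(n_hardest, (s_prime,))
    b = 0.5 * torch.rand(s_prime, 1)  # query weight stays below 0.5
    mixed_q = F.normalize(b * query + (1 - b) * hard[k], dim=1)  # (s_prime, D)
    return torch.cat([mixed, mixed_q], dim=0)  # appended to the negatives for this query
```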
SimCLR Framework
SimCLR (Chen et al. 2020) proposed a simple framework for contrastive learning of visual representations.
Barlow Twins
Barlow Twins (Zbontar et al. 2021) feeds two distorted versions of samples into the same network to extract features and learns to make the cross-correlation matrix between these two groups of output features close to the identity.
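A sketch of the Barlow Twins objective: each feature is standardized along the batch, the cross-correlation matrix between the two views is computed, and the loss pushes its diagonal toward 1 (invariance) and its off-diagonal entries toward 0 (redundancy reduction). The weight lambd = 5e-3 and the function names are illustrative.

```python
import torch

def _standardize(z: torch.Tensor, eps: float) -> torch.Tensor:
    """Zero mean, unit (population) standard deviation per feature, computed along the batch."""
    zc = z - z.mean(dim=0)
    return zc / (zc.pow(2).mean(dim=0).sqrt() + eps)

def barlow_twins_loss(z1: torch.Tensor, z2: torch.Tensor, lambd: float = 5e-3, eps: float = 1e-5) -> torch.Tensor:
    """Push the (D x D) cross-correlation of two views' features towards the identity matrix.

    z1, z2: (N, D) features of two distorted views of the same N samples.
    """
    n = z1.size(0)
    c = _standardize(z1, eps).T @ _standardize(z2, eps) / n  # (D, D) cross-correlation matrix
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()  # invariance: diagonal -> 1
    off_diag = c.pow(2).sum() - torch.diagonal(c).pow(2).sum()  # redundancy reduction: off-diagonal -> 0
    return on_diag + lambd * off_diag
```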
BYOL
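BYOL (Bootstrap Your Own Latent; Grill et al. 2020) reported state-of-the-art results at the time of publication without using negative samples. It relies on two networks, an online network and a target network: the online network is trained to predict the target network's representation of a different augmented view of the same image, while the target network's weights are an exponential moving average of the online network's weights.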
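A sketch of one BYOL training step with toy networks. The network sizes, learning rate, stand-in augmentations, and target decay of 0.99 are illustrative assumptions; the ideas shown are the predictor on the online branch, the stop-gradient on the target branch, and the moving-average update of the target.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def byol_loss(online_pred: torch.Tensor, target_proj: torch.Tensor) -> torch.Tensor:
    """2 - 2 * cosine similarity between the online prediction and the (stop-gradient) target projection."""
    p = F.normalize(online_pred, dim=-1)
    z = F.normalize(target_proj.detach(), dim=-1)  # no gradient flows into the target network
    return (2 - 2 * (p * z).sum(dim=-1)).mean()

# Toy networks: online = encoder/projector + predictor; target = EMA copy of the encoder/projector
online = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
predictor = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
target = copy.deepcopy(online).requires_grad_(False)
opt = torch.optim.SGD(list(online.parameters()) + list(predictor.parameters()), lr=0.1)

x = torch.randn(8, 32)
view1, view2 = x + 0.1 * torch.randn_like(x), x + 0.1 * torch.randn_like(x)  # stand-in augmentations
# Symmetric loss: each view's online prediction targets the other view's target projection
loss = byol_loss(predictor(online(view1)), target(view2)) + byol_loss(predictor(online(view2)), target(view1))
opt.zero_grad()
loss.backward()
opt.step()
with torch.no_grad():  # target <- tau * target + (1 - tau) * online, with tau = 0.99 (illustrative)
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(0.99).add_(p_o, alpha=0.01)
```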
Instance Contrastive Learning
Instance contrastive learning (Wu et al. 2018) pushes class-wise supervision to the extreme by treating each instance as a distinct class of its own.
Momentum Contrast (MoCo) Framework
Momentum Contrast (MoCo; He et al. 2019) frames unsupervised visual representation learning as a dynamic dictionary look-up.
CURL
CURL (Srinivas et al. 2020) applies ideas from SimCLR to reinforcement learning (RL). It learns a visual representation for RL tasks by matching the embeddings of two data-augmented versions of the same raw observation.
DeepCluster
DeepCluster (Caron et al. 2018) iteratively clusters deep features with k-means and uses the cluster assignments as pseudo-labels to train the network.
SwAV
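SwAV (Swapping Assignments between multiple Views; Caron et al. 2020) is an online, clustering-based contrastive learning algorithm. Instead of comparing features directly, it computes a cluster assignment (code) for one view and predicts it from the representation of another view of the same image; the codes are computed with the Sinkhorn-Knopp algorithm so that samples are spread evenly across a set of learnable prototypes.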
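A sketch of SwAV's two ingredients: Sinkhorn-Knopp normalization, which turns prototype scores into balanced soft assignments, and the swapped prediction loss. The values epsilon = 0.05, three iterations, and temperature 0.1 are illustrative, and this single-process version omits the cross-GPU reductions a distributed implementation needs.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sinkhorn(scores: torch.Tensor, epsilon: float = 0.05, n_iters: int = 3) -> torch.Tensor:
    """Turn (B, K) prototype scores into soft codes whose clusters are used roughly equally.

    Alternately normalizes over prototypes and over samples; each row of the result sums to 1.
    """
    q = torch.exp(scores / epsilon).T  # (K, B)
    q /= q.sum()
    K, B = q.shape
    for _ in range(n_iters):
        q /= q.sum(dim=1, keepdim=True)  # each prototype receives total mass 1/K
        q /= K
        q /= q.sum(dim=0, keepdim=True)  # each sample receives total mass 1/B
        q /= B
    return (q * B).T  # (B, K): rows are per-sample assignments

def swav_loss(z1, z2, prototypes, temperature: float = 0.1):
    """Swapped prediction: predict view 2's code from view 1's scores, and vice versa."""
    c = F.normalize(prototypes, dim=1)
    s1 = F.normalize(z1, dim=1) @ c.T  # (B, K) scores against the prototypes
    s2 = F.normalize(z2, dim=1) @ c.T
    q1, q2 = sinkhorn(s1), sinkhorn(s2)  # codes, computed without gradients
    loss1 = -(q2 * F.log_softmax(s1 / temperature, dim=1)).sum(dim=1).mean()
    loss2 = -(q1 * F.log_softmax(s2 / temperature, dim=1)).sum(dim=1).mean()
    return (loss1 + loss2) / 2
```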
CLIP Framework
CLIP (Contrastive Language-Image Pre-training; Radford et al.) performs contrastive pre-training over text-image pairs.
Introduction to Contrastive Predictive Coding (CPC)
CPC extracts meaningful representations from unlabeled data by maximizing the mutual information between a context built from past observations and future observations of a sequence, producing compact embeddings that are useful for downstream tasks. CPC has improved performance in domains such as speech recognition and image classification. Using CPC for feature extraction can reduce the need for labeled data while delivering competitive accuracy in transfer learning tasks.
This guide provides a clear understanding of:
- How CPC works and why it’s effective.
- How to implement CPC step-by-step with custom encoders and autoregressive models.
- How to evaluate and fine-tune your representations for specific tasks.
High-Level Overview of Contrastive Predictive Coding
CPC learns representations by predicting the future: instead of reconstructing raw inputs, it predicts future latent representations from a context that summarizes the past, and it is trained by maximizing the mutual information between these parts of the input. This forces the model to focus on the relationships and dependencies within the data, making CPC representations effective for tasks like classification and transfer learning. The approach is particularly useful for high-dimensional data, where modeling the raw signal directly is costly and wastes capacity on low-level detail.
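In the notation of the CPC paper, the encoder, the autoregressive context model, and the scoring function are defined as follows, where W_k is a separate linear transformation for each prediction step k. The model is trained with the InfoNCE loss so that f_k scores the true future latent z_{t+k} higher than negative samples.

```latex
z_t = g_{\mathrm{enc}}(x_t), \qquad
c_t = g_{\mathrm{ar}}(z_{\le t}), \qquad
f_k(x_{t+k}, c_t) = \exp\left(z_{t+k}^{\top} W_k \, c_t\right)
```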
Real-World Use Cases
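- Speech Recognition: CPC pretraining on raw audio can improve automatic speech recognition (ASR) models, particularly when labeled data is scarce.
- Computer Vision: CPC can pretrain a model on unannotated images, for example by predicting the representations of image patches from the patches above them.
- Reinforcement Learning: used as an auxiliary loss, CPC can help agents build better representations of their surroundings, improving overall performance.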
Key Components of CPC
Each component of CPC, including the encoder, autoregressive model, contrastive loss, and negative sampling, plays a critical role.
1. Data Encoder
The encoder transforms input data into compact, meaningful representations. Choosing the right encoder architecture depends on the domain:
- Images: Convolutional Neural Networks (CNNs) are the standard.
- Sequential Data: 1D CNNs, RNNs, or Transformer encoders work well.
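import torch
import torch.nn as nn

class CPCEncoder(nn.Module):
    """Maps raw input (batch, input_dim, time) to latents (batch, hidden_dim, time')."""
    def __init__(self, input_dim, hidden_dim):
        super(CPCEncoder, self).__init__()
        # Strided convolutions downsample the time axis
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=5, stride=2)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, stride=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x  # channels-first; transpose to (batch, time', hidden_dim) before a batch_first GRU

2. Autoregressive Model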
The autoregressive model summarizes past latent representations into a context from which future representations are predicted, capturing high-level dependencies in the data. GRUs are fast and work well for moderately sized datasets, while Transformers suit data with long-range dependencies.
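class AutoregressiveModel(nn.Module):
    """Summarizes latents z_1..z_t into context vectors c_t with a GRU."""
    def __init__(self, input_dim, hidden_dim):
        super(AutoregressiveModel, self).__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        # x: (batch, time, input_dim)
        outputs, hidden = self.gru(x)
        # outputs: (batch, time, hidden_dim), the context c_t at every time step
        # hidden: (num_layers, batch, hidden_dim); squeeze(0) removes the layer dimension
        return outputs, hidden.squeeze(0)

3. Contrastive Loss (InfoNCE)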
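The InfoNCE loss treats contrastive learning as classification: given the similarity score of one positive pair and the scores of several negative pairs, the model must pick out the positive.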
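import torch
import torch.nn.functional as F

def info_nce_loss(positive, negative, temperature=0.1):
    # positive: (batch, 1) similarity score of each anchor with its positive sample
    # negative: (batch, num_negatives) similarity scores with the negative samples
    logits = torch.cat([positive, negative], dim=1)
    # The positive sits in column 0, so the target class for every row is 0
    labels = torch.zeros(positive.size(0), dtype=torch.long, device=positive.device)
    return F.cross_entropy(logits / temperature, labels)

4. Choice of Negative Sampling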
Poorly chosen negatives can lead to representations that don't generalize. Common strategies include:
- In-Batch Negatives: simple and cheap, but biased, since other samples in the batch may share the anchor's class (false negatives).
- Hard Negatives: selecting negatives that lie close to the anchor in feature space can improve representation quality.
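import torch

def get_negatives(features, memory_bank, k=5):
    # features: (B, D) anchors; memory_bank: (M, D), assumed to hold embeddings of other samples only
    B = features.size(0)
    # In-batch negatives: every other sample in the batch (never the anchor itself)
    not_self = ~torch.eye(B, dtype=torch.bool, device=features.device)
    in_batch_negatives = features.unsqueeze(0).expand(B, B, -1)[not_self].view(B, B - 1, -1)
    # Hard negatives: the k memory-bank entries closest to each anchor
    distances = torch.cdist(features, memory_bank)  # (B, M)
    hard_idx = torch.topk(distances, k, largest=False).indices  # (B, k)
    hard_negatives = memory_bank[hard_idx]  # (B, k, D)
    # Combine per anchor: (B, B - 1 + k, D)
    return torch.cat([in_batch_negatives, hard_negatives], dim=1)

Setting Up the Environment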
The essential libraries for this CPC walkthrough are PyTorch (plus torchaudio for the audio example), NumPy, and Matplotlib.
pip install torch torchvision torchaudio numpy matplotlib
Step-by-Step Implementation
5.1 Data Preparation
Choosing the right dataset is crucial. LibriSpeech is suitable for audio-based tasks, while CIFAR-10 or ImageNet is suitable for vision tasks.
import torch
import torchaudio

def preprocess_audio(filepath, target_sample_rate=16000):
    # Load audio file: waveform has shape (channels, samples)
    waveform, sample_rate = torchaudio.load(filepath)
    # Resample to the target rate if needed
    if sample_rate != target_sample_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(waveform)
    # Peak-normalize to [-1, 1]; the small epsilon avoids division by zero on silent clips
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    return waveform

# Example usage
waveform = preprocess_audio("path_to_audio.wav")
print(waveform.shape)

5.2 Building the Encoder
The encoder is identical to the CPCEncoder class defined in the Data Encoder section above, so that class can be reused here without redefining it.
5.3 Designing the Autoregressive Model
The autoregressive model is the AutoregressiveModel class defined in the Autoregressive Model section above; reuse it rather than defining it a second time.
5.4 Contrastive Loss Implementation
The contrastive loss is the info_nce_loss function defined in the Contrastive Loss (InfoNCE) section above.
Training Pipeline
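The training pipeline should load data efficiently, keep gradients stable (for example, with gradient clipping), and keep the overall process simple. Gradient accumulation simulates a larger batch size without needing massive GPU memory, but it only averages gradients over micro-batches: when negatives are drawn from within the batch, each loss term still sees only the negatives of its own micro-batch.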
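The pieces can be put together in one training step as follows. This is a minimal, self-contained sketch under several assumptions: a small strided-convolution encoder on random stand-in audio, a single GRU layer, one linear predictor W_k per future step, negatives taken from the other sequences in the micro-batch at the same future step, and illustrative hyperparameters. MiniCPC is a hypothetical name, not an existing library class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniCPC(nn.Module):
    """Minimal CPC sketch: strided conv encoder -> GRU context -> one linear predictor W_k per step k."""
    def __init__(self, in_channels=1, z_dim=64, c_dim=128, n_steps=4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, z_dim, kernel_size=10, stride=5), nn.ReLU(),
            nn.Conv1d(z_dim, z_dim, kernel_size=4, stride=2), nn.ReLU(),
        )
        self.gru = nn.GRU(z_dim, c_dim, batch_first=True)
        self.predictors = nn.ModuleList([nn.Linear(c_dim, z_dim, bias=False) for _ in range(n_steps)])
        self.n_steps = n_steps

    def forward(self, x):
        z = self.encoder(x).transpose(1, 2)  # (B, T, z_dim): Conv1d is channels-first, the GRU wants time second
        c, _ = self.gru(z)  # (B, T, c_dim): one context vector per time step
        B, T, _ = z.shape
        t = torch.randint(0, T - self.n_steps, (1,)).item()  # random anchor time step
        loss = 0.0
        for k, W_k in enumerate(self.predictors, start=1):
            pred = W_k(c[:, t])  # (B, z_dim): prediction of z_{t+k}
            target = z[:, t + k]  # (B, z_dim): true future latents of every sequence
            # log f_k = z_{t+k}^T W_k c_t; the other sequences' latents act as negatives
            logits = pred @ target.T  # (B, B)
            loss = loss + F.cross_entropy(logits, torch.arange(B, device=x.device))
        return loss / self.n_steps

model = MiniCPC()
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
accum_steps = 4  # average gradients over 4 micro-batches per update (negatives stay per micro-batch)

opt.zero_grad()
for _ in range(accum_steps):
    audio = torch.randn(8, 1, 4000)  # stand-in for 8 short raw-audio clips
    loss = model(audio) / accum_steps
    loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # limits the size of each update
opt.step()
```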

