
Self-Supervised Learning (SSL) for Pretraining: A Complete Guide

Last updated: May 27, 2026
Published April 17, 2026 · Updated May 27, 2026 · 47 min read

Summary

What this post covers: A complete examination of self-supervised learning, including its taxonomy, the mathematics of contrastive learning and masked modelling, PyTorch implementations of SimCLR and MAE, and the pretraining-to-fine-tuning workflow that defines modern AI.

Key insights:

  • SSL breaks the labelling bottleneck that constrained supervised learning for decades by turning the structure of unlabelled data into its own supervisory signal. This is the same mechanism that underlies GPT, BERT, DINO, MAE, CLIP and essentially every frontier model.
  • The field has converged on four major families: contrastive methods (SimCLR, MoCo, BYOL), masked modelling (BERT, MAE, BEiT), generative methods (GPT-style autoregression) and self-distillation (DINO). Each suits specific modalities and compute budgets.
  • Contrastive methods such as SimCLR need large batches (or, as in MoCo, a memory queue of negatives) and careful augmentation design. Masked modelling tolerates smaller batches and is currently a strong default for transformer-based vision pretraining and encoder-style language models.
  • SSL representations now match or exceed supervised ImageNet pretraining on most downstream benchmarks, and the same recipe transfers to speech (wav2vec 2.0, HuBERT), time series, graphs and multimodal data (CLIP).
  • For practitioners, the recipe is to select the SSL family that matches the modality, pretrain on as much unlabelled in-domain data as the budget permits, and then fine-tune on a small labelled set. This two-stage pipeline typically outperforms training from scratch, especially when labelled data is scarce.

Main topics: Why Self-Supervised Learning Matters, The SSL Taxonomy: A Complete Map, Contrastive Learning in Depth, Masked Modeling in Depth, PyTorch Implementation from Scratch, The Pretraining to Fine-Tuning Pipeline, SSL Beyond Vision and NLP, Practical Guide: Choosing and Using SSL, Method Comparison Table, Frequently Asked Questions, Closing Thoughts, References and Further Reading.

GPT-4's base model was pretrained on an enormous text corpus without a single human label. DINO can segment objects without ever observing a segmentation mask. The underlying mechanism is Self-Supervised Learning, the technique behind almost every frontier AI model today.

The observation merits emphasis. The most powerful AI systems ever built, including those that write code, generate images, translate languages and assist in diagnosing diseases, did not learn their core representations from carefully curated, hand-labelled datasets. They learned by solving puzzles that the data itself provided: predict the next word; reconstruct a masked patch; determine whether two augmented views originated from the same image. No human annotator labelled trillions of training examples. The data itself served as the teacher.

This is not a minor technical detail. It represents a fundamental shift in how AI systems are built, and understanding it is essential for anyone working in machine learning today. Whether the task involves training vision models, language models, time series forecasters or graph neural networks, the paradigm is the same: pretrain with self-supervision on large amounts of unlabelled data, then fine-tune on the specific task with a small labelled dataset.

Key Takeaway: Self-supervised learning generates its own supervisory signal from the structure of unlabelled data. It has become the default pretraining strategy for nearly every modality, including text, images, audio, time series, graphs and multimodal systems.

The following sections present a comprehensive treatment. They cover the full taxonomy of SSL methods, examine the mathematics of contrastive and masked modelling objectives, implement SimCLR and MAE from scratch in PyTorch, walk through the pretraining-to-fine-tuning pipeline, and survey SSL’s expanding reach into domains beyond vision and NLP. By the end, the reader will have both the conceptual understanding and the working code required to apply SSL to their own problems.

Why Self-Supervised Learning Matters

The Labeling Bottleneck

Supervised learning has one major drawback: labels are expensive. ImageNet took years and millions of dollars to annotate 14 million images. Medical imaging datasets require board-certified radiologists at hundreds of dollars per hour. Autonomous driving datasets need teams of annotators drawing pixel-perfect segmentation masks for every frame. Even after all this effort, these labelled datasets remain small compared with the volume of unlabelled data that exists.

Consider the figures. YouTube receives 500 hours of video every minute. The Common Crawl contains petabytes of web text. Hospitals generate millions of medical images annually, the vast majority unlabelled. Industrial sensors stream terabytes of time series data daily. There is a substantial asymmetry between the labelled data that can be afforded and the unlabelled data that already exists.

This is the labelling bottleneck, and it has been the central constraint of applied machine learning for decades. Self-supervised learning removes that constraint by converting unlabelled data into a source of supervision.

SSL Bridges Unsupervised and Supervised Learning

Traditional unsupervised learning, including clustering, dimensionality reduction and density estimation, learns structure within data but does not produce representations optimised for downstream tasks. Supervised learning produces task-specific representations but requires labels. SSL occupies the productive middle ground: it creates its own labels from the data’s inherent structure, producing representations that transfer effectively to downstream tasks.

The key insight is simple but consequential: a pretext task can be designed that forces the model to learn useful representations without any human annotation. Predicting the next word requires the model to understand grammar, semantics and world knowledge. Reconstructing a masked image patch requires the model to understand object shapes, textures and spatial relationships. Determining whether two views originated from the same image requires the model to learn viewpoint-invariant, semantically meaningful features.

The pretext task is not the end goal. It is the mechanism by which the model acquires general-purpose representations that can later be fine-tuned for any downstream task. This is the pretraining revolution.

The Pretraining Revolution

The modern ML paradigm is a two-stage pipeline: SSL pretraining on large unlabelled data, followed by supervised fine-tuning on small labelled data. This approach now dominates virtually every domain.

  • Natural Language Processing. GPT (autoregressive pretraining), BERT (masked language modelling) and T5 (span corruption) all use SSL pretraining. The success of modern LLMs such as GPT-4 and Claude rests on this foundation.
  • Computer Vision. SimCLR, MoCo and BYOL (contrastive learning), MAE and BEiT (masked image modelling) and DINO (self-distillation) now match or exceed supervised ImageNet pretraining.
  • Speech and Audio. wav2vec 2.0 and HuBERT learn speech representations from raw audio without transcriptions.
  • Multimodal. CLIP learns joint text-image representations from 400 million image-text pairs scraped from the internet, without manual labelling.

Any reader who has worked with transfer learning and fine-tuning has already benefited from SSL: most publicly available pretrained models were trained with self-supervised objectives.

The SSL Taxonomy: A Complete Map

Self-supervised learning is not a single technique. It is a family of methods that share the principle of deriving supervision from data structure. The full landscape is examined below.

Figure: Self-Supervised Learning taxonomy.
  • Contrastive methods: SimCLR (Chen 2020), MoCo (He 2020), BYOL (Grill 2020), Barlow Twins (Zbontar 2021), SwAV (Caron 2020).
  • Masked modelling: BERT (Devlin 2019), MAE (He 2022), BEiT (Bao 2022), data2vec (Baevski 2022).
  • Generative methods: GPT autoregressive (2018+), VAE-based methods, diffusion pretraining.
  • Self-distillation: DINO (Caron 2021), DINOv2 (Oquab 2024), EsViT (Li 2022).
  • Core principles: contrastive methods pull positive pairs together and push negatives apart; masked modelling masks portions of the input and predicts the masked content; generative methods predict the next token or reconstruct the full input; self-distillation trains a student to match a teacher that is an EMA of itself. All share one goal: learning powerful representations from unlabelled data.

Contrastive Methods

Contrastive learning is built on a simple but powerful idea: learn representations in which similar items are close together and dissimilar items are far apart in embedding space. The challenge is defining “similar” without labels. The solution is data augmentation. Two augmented views of the same image, or the same sentence with different dropout masks, form a positive pair. Views from different images form negative pairs.

SimCLR (Chen et al., 2020) is the conceptually simplest contrastive method. An image is taken, two random augmentations are created, both pass through an encoder and a projection head, and the model is trained to recognise that the two resulting representations originated from the same image, while pushing apart representations from different images. The loss function is NT-Xent (Normalised Temperature-scaled Cross-Entropy), a variant of InfoNCE. SimCLR’s principal weakness is its requirement for substantial batch sizes (4,096 or more) in order to provide sufficient negatives.

MoCo (He et al., 2020) addresses the batch-size problem with a momentum encoder and a queue of negatives. Rather than requiring all negatives to be present in the current batch, MoCo maintains a queue of recent representations. The key encoder is updated via exponential moving average (EMA) of the query encoder, providing consistent targets without backpropagation through the key encoder.

BYOL (Grill et al., 2020) demonstrated a surprising result: negative pairs are not required. BYOL employs a teacher-student architecture in which the student predicts the teacher’s representation, and the teacher is an EMA of the student. A stop-gradient on the teacher prevents collapse. The approach was initially controversial owing to questions about how it avoids the trivial solution of constant outputs, but it performs strongly in practice.
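The BYOL objective and its EMA teacher update can be sketched in a few lines. This is a minimal illustration, not the reference implementation: `byol_loss` and `ema_update` are hypothetical helper names, the inputs are assumed to be the student's predictor output and the teacher's projection for two different views, and `tau=0.996` is the base EMA rate reported in the BYOL paper (which raises it toward 1 during training and symmetrises the loss by swapping the views).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def byol_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Normalised MSE between student predictions and teacher projections (sketch).

    prediction: student predictor output, shape (N, dim)
    target:     teacher projection output, shape (N, dim)
    """
    p = F.normalize(prediction, dim=1)
    # Stop-gradient: no gradient reaches the teacher through this loss
    z = F.normalize(target.detach(), dim=1)
    # For unit vectors, ||p - z||^2 = 2 - 2 * cos(p, z)
    return (2 - 2 * (p * z).sum(dim=1)).mean()


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, tau: float = 0.996) -> None:
    """teacher <- tau * teacher + (1 - tau) * student, parameter by parameter."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(tau).add_(s_param, alpha=1 - tau)
```

Only the student receives gradients; the teacher changes solely through `ema_update`, called after each optimiser step.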

Barlow Twins (Zbontar et al., 2021) takes a different approach. Rather than contrasting individual samples, it computes the cross-correlation matrix between the embeddings of two augmented views and pushes it toward the identity matrix. This achieves redundancy reduction, in which each dimension of the embedding captures distinct information.
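The Barlow Twins objective translates directly into code. The sketch below assumes embeddings of shape (N, D) for the two views; `barlow_twins_loss` and `standardise` are illustrative names, and `lambd=5e-3` is the off-diagonal weight reported in the paper, used here as an assumption rather than a tuned value.

```python
import torch


def standardise(z: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Zero mean, unit variance per embedding dimension, computed across the batch."""
    mean = z.mean(dim=0)
    var = ((z - mean) ** 2).mean(dim=0)
    return (z - mean) / torch.sqrt(var + eps)


def barlow_twins_loss(z_a: torch.Tensor, z_b: torch.Tensor, lambd: float = 5e-3) -> torch.Tensor:
    """z_a, z_b: embeddings of two augmented views, shape (N, D)."""
    n = z_a.shape[0]
    c = standardise(z_a).T @ standardise(z_b) / n  # (D, D) cross-correlation matrix
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()  # invariance: push diagonal toward 1
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()  # redundancy reduction: off-diagonal toward 0
    return on_diag + lambd * off_diag
```

A collapsed encoder that maps every input to the same vector standardises to all zeros, so the cross-correlation matrix is zero and the diagonal term alone contributes a loss of D. Collapse therefore cannot satisfy this objective.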

SwAV (Caron et al., 2020) combines contrastive learning with online clustering. Rather than directly comparing representations, it assigns augmented views to prototype clusters and trains the model so that different views of the same image are assigned to the same cluster. Multi-crop augmentation, in which multiple small crops accompany two global crops, improves performance substantially.
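The balanced cluster assignment at the heart of SwAV can be sketched with a few Sinkhorn-Knopp iterations. This is a minimal version under stated assumptions: `scores` holds one batch's similarities to the prototypes, `sinkhorn` is an illustrative name, and the defaults (`epsilon=0.05`, three iterations) are values reported for SwAV rather than tuned recommendations.

```python
import torch


@torch.no_grad()
def sinkhorn(scores: torch.Tensor, epsilon: float = 0.05, n_iters: int = 3) -> torch.Tensor:
    """Soft cluster assignments with (approximately) equal prototype usage.

    scores: (B, K) similarities between B samples and K prototypes.
    Returns Q: (B, K), each row a distribution over prototypes.
    """
    q = torch.exp(scores / epsilon).T  # (K, B)
    q /= q.sum()
    k, b = q.shape
    for _ in range(n_iters):
        q /= q.sum(dim=1, keepdim=True)  # each prototype receives total mass 1/K
        q /= k
        q /= q.sum(dim=0, keepdim=True)  # each sample receives total mass 1/B
        q /= b
    return (q * b).T  # rescale so every sample's row sums to 1
```

If every sample in a batch mildly prefers the same prototype, a plain softmax would give that prototype most of the mass, whereas the Sinkhorn output spreads the batch across prototypes. This equal-partition constraint is what rules out the trivial solution in which all samples fall into one cluster.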

Masked Modeling Methods

Masked modelling is the other major SSL paradigm. Its principle is to hide part of the input and train the model to predict the hidden portion. This forces the model to learn the statistical structure of the data.

BERT (Devlin et al., 2019) pioneered masked language modeling (MLM) for NLP. It masks 15% of input tokens and trains a Transformer to predict the masked tokens from context. This seemingly simple objective produces representations that capture deep linguistic knowledge: syntax, semantics, coreference, and even some world knowledge. BERT's representations power applications ranging from search engines to retrieval-augmented generation systems.

MAE (He et al., 2022) applied masked modeling to images with strong results. It masks 75% of image patches and trains a Vision Transformer to reconstruct the masked patches. The key innovation is an asymmetric design: only the visible 25% of patches pass through the heavy encoder, while a lightweight decoder handles reconstruction. This makes MAE highly compute-efficient.

BEiT (Bao et al., 2022) takes a different approach to masked image modeling. Instead of reconstructing raw pixels, it predicts discrete visual tokens generated by a pre-trained dVAE (discrete variational autoencoder). This makes the prediction task more semantic and less focused on low-level pixel details.

data2vec (Baevski et al., 2022) unifies masked modeling across modalities. It uses the same framework for speech, vision, and text: a student model predicts the representations of a teacher model (EMA) for masked portions of the input. The target is the teacher’s latent representation, not the raw input.

Generative Methods

Generative SSL methods learn by generating or reconstructing data.

GPT-style autoregressive pretraining is technically a form of self-supervised learning: predict the next token given all previous tokens. No labels are needed—the next token in the sequence is the label. This deceptively simple objective, scaled to trillions of tokens, produces the large language models that have transformed AI.
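The shift between inputs and labels is the whole trick. A minimal sketch, assuming a causal model has already produced logits of shape (batch, seq_len, vocab); `next_token_loss` is a hypothetical helper, not part of any library:

```python
import torch
import torch.nn.functional as F


def next_token_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """GPT-style autoregressive objective (sketch).

    logits: (batch, seq_len, vocab); position t has only seen tokens[:, :t+1]
    tokens: (batch, seq_len) integer ids; the labels are the same sequence shifted by one
    """
    predictions = logits[:, :-1, :]  # the last position has no next token to predict
    targets = tokens[:, 1:]          # the "label" at position t is simply token t + 1
    return F.cross_entropy(predictions.reshape(-1, predictions.size(-1)), targets.reshape(-1))
```

No annotation step appears anywhere: the targets are a slice of the input itself.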

VAE-based methods learn by encoding data to a latent space and reconstructing it. The encoder must capture meaningful structure to enable accurate reconstruction. While less dominant than contrastive or masked methods for representation learning, VAEs remain important for generative tasks.

Diffusion-based pretraining is an emerging area. Models like Stable Diffusion learn to denoise images, which requires understanding image structure at multiple scales. Recent work shows that diffusion model encoders can produce competitive representations for downstream tasks.

Self-Distillation Methods

DINO (Caron et al., 2021) demonstrated that self-distillation with Vision Transformers produces remarkable emergent properties. A student network learns to match the output distribution of a teacher network (EMA of the student) across different augmented views. The stunning result: DINO features contain explicit information about object boundaries—the attention maps perform unsupervised object segmentation. No segmentation labels were ever used.
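The teacher-student objective can be sketched as follows. This is a simplified, single-pair version under assumptions: the student and teacher logits come from two different views, `DINOLoss` is an illustrative class, and the temperatures (0.1 for the student, 0.04 for the teacher) and centre momentum (0.9) are values reported in the DINO paper, used here only for illustration. Centring and sharpening the teacher output are what keep this loss from collapsing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOLoss(nn.Module):
    """Cross-entropy between a centred, sharpened teacher and the student (one view pair)."""

    def __init__(self, out_dim: int, student_temp: float = 0.1,
                 teacher_temp: float = 0.04, center_momentum: float = 0.9):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        # Running mean of teacher outputs, subtracted before the teacher softmax
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_out: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
        """student_out, teacher_out: (N, out_dim) logits for two different views of the same images."""
        student_log_probs = F.log_softmax(student_out / self.student_temp, dim=-1)
        # Centring (subtract running mean) and sharpening (low temperature); no gradient to the teacher
        teacher_probs = F.softmax((teacher_out - self.center) / self.teacher_temp, dim=-1).detach()
        loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()
        self.update_center(teacher_out)
        return loss

    @torch.no_grad()
    def update_center(self, teacher_out: torch.Tensor) -> None:
        batch_center = teacher_out.mean(dim=0, keepdim=True)
        self.center.mul_(self.center_momentum).add_(batch_center, alpha=1 - self.center_momentum)
```

Centring discourages any single output dimension from dominating, while the low teacher temperature discourages a uniform output; each on its own would permit a different form of collapse.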

DINOv2 (Oquab et al., 2024) scaled up DINO with larger datasets, more compute, and a combination of self-distillation and masked image modeling. The resulting features are so powerful that they serve as general-purpose visual features competitive with or superior to OpenAI’s CLIP across a wide range of benchmarks, without any text supervision.

Contrastive Learning in Depth

The InfoNCE Loss

At the heart of contrastive learning is the InfoNCE loss (and its variants). Let us build up the mathematics carefully.

Given a batch of N images, we create two augmented views of each, yielding 2N total views. For a positive pair (i, j)—two views of the same image—the NT-Xent loss is:

L(i,j) = -log( exp(sim(z_i, z_j) / τ) / Σ_k exp(sim(z_i, z_k) / τ) )

where:
  sim(z_i, z_j) = (z_i · z_j) / (||z_i|| · ||z_j||)    # cosine similarity
  τ = temperature parameter (typically 0.07 to 0.5)
  k ranges over all 2N views except i (including all negatives and the positive j)

This is essentially a (2N-1)-way classification problem: given anchor z_i, identify which of the other 2N-1 representations is its positive pair z_j. The temperature τ controls the “hardness” of this classification. Lower temperature makes the model focus more on hard negatives (representations that are similar but from different images), while higher temperature makes the distribution more uniform.
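The effect of τ can be seen numerically. In InfoNCE, the gradient of the loss with respect to a negative's similarity s_k is p_k / τ, where p_k is that negative's softmax probability, so the probabilities show how strongly each negative is pushed away. The sketch below uses four hypothetical similarity values (one positive, one hard negative, two easy negatives); `negative_weights` is an illustrative helper.

```python
import torch


def negative_weights(similarities: torch.Tensor, tau: float) -> torch.Tensor:
    """Relative share of the gradient each negative receives in InfoNCE.

    similarities[0] is the positive; the rest are negatives.
    Since dL/ds_k = p_k / tau for every negative k, the shares are p_k normalised over the negatives.
    """
    probs = torch.softmax(similarities / tau, dim=0)
    neg = probs[1:]
    return neg / neg.sum()


# Hypothetical cosine similarities: positive, hard negative, two easy negatives
sims = torch.tensor([0.9, 0.8, 0.1, 0.0])
for tau in (0.07, 0.5):
    share = negative_weights(sims, tau)[0].item()
    print(f"tau={tau}: hard-negative share = {share:.3f}")
```

With these values the hard negative receives almost all of the negatives' gradient share at τ = 0.07 and roughly 69% at τ = 0.5, which is the sense in which a low temperature focuses learning on hard negatives.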

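The loss also has an information-theoretic interpretation: InfoNCE yields a lower bound on the mutual information between the two views, I(z_i; z_j) >= log(K) - L, where K is the number of candidates in the classification (here 2N-1). Minimizing the loss therefore maximizes the bound, and because the bound can never exceed log(K), more negatives allow a tighter estimate. Maximizing this bound encourages the encoder to capture information that is shared across views (semantic content) while discarding information that differs (augmentation-specific noise such as color jitter or crop position).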

Augmentation Strategies

Augmentation is not a minor detail in contrastive learning; it is the source of the learning signal. The choice of augmentations defines what information the model must preserve (shared across augmentations) and what it can discard (varies across augmentations).

For images, the standard SimCLR augmentation pipeline includes:

  • Random resized crop: The most important augmentation. Forces the model to recognize objects regardless of scale and position.
  • Random horizontal flip: Teaches left-right invariance.
  • Color jitter: Random changes to brightness, contrast, saturation, and hue. Prevents the model from relying on color histograms.
  • Random grayscale: Applied with 20% probability. Further reduces color dependence.
  • Gaussian blur: Forces the model to learn from shape rather than texture details.

Chen et al. showed that random resized crop combined with color jitter is by far the most important augmentation combination. Without color jitter, the model can “cheat” by simply learning to match color histograms rather than semantic content.

For text, augmentations are different: dropout masks (as used in SimCSE), token deletion, synonym replacement, or back-translation. For time series, augmentations include temporal jitter, amplitude scaling, time warping, and window cropping.
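Time-series augmentations are simple tensor operations. The sketch below assumes a single multichannel series of shape (channels, length); the function names are hypothetical, jittering is implemented as small additive Gaussian noise (one common interpretation), the noise scales are illustrative defaults, and time warping is omitted for brevity.

```python
import torch


def jitter(x: torch.Tensor, sigma: float = 0.03) -> torch.Tensor:
    """Add small Gaussian noise to every time step."""
    return x + sigma * torch.randn_like(x)


def amplitude_scale(x: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
    """Multiply each channel by one random factor drawn around 1. x: (channels, length)."""
    factor = 1.0 + sigma * torch.randn(x.shape[0], 1, device=x.device, dtype=x.dtype)
    return x * factor


def window_crop(x: torch.Tensor, crop_len: int) -> torch.Tensor:
    """Keep a random contiguous window of crop_len time steps."""
    start = int(torch.randint(0, x.shape[-1] - crop_len + 1, (1,)))
    return x[..., start:start + crop_len]


def two_views(x: torch.Tensor, crop_len: int):
    """Two independently augmented views of the same series, forming a positive pair."""
    def augment(series: torch.Tensor) -> torch.Tensor:
        return window_crop(amplitude_scale(jitter(series)), crop_len)
    return augment(x), augment(x)
```

As with images, the chosen transformations define what the encoder learns to ignore: here, small noise, overall amplitude and exact position within the window.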

The Projection Head

A surprising finding from SimCLR: representations are much better when you apply the contrastive loss to the output of a small projection head (an MLP) on top of the encoder, rather than directly to the encoder’s output. After training, you throw away the projection head and use the encoder’s output for downstream tasks.

Why does this work? The contrastive loss rewards representations that are invariant to the augmentations, which means it actively discards information such as colour or orientation, and some of that information is useful for downstream tasks. Placing a small MLP between the encoder and the loss lets the projection output z become invariant while the encoder output h retains richer information; Chen et al. found that h preserves more information about the applied transformations than z does.

Batch Size, Momentum Encoders, and Collapse Prevention

SimCLR needs large batch sizes (4096 or more) because the quality of contrastive learning depends on having enough negative pairs. With a batch of N images, you get 2(N-1) negatives per positive pair. More negatives means a harder discrimination task, which produces better representations.

MoCo elegantly avoids this requirement. It maintains a queue of 65,536 encoded representations from recent batches. The key encoder that produces queue entries is updated via exponential moving average (EMA) of the query encoder with momentum coefficient m = 0.999:

θ_key = m * θ_key + (1 - m) * θ_query

This slow update ensures that the queue entries are consistent—they all come from “similar” versions of the encoder, even though the query encoder is updating rapidly via gradient descent.
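The momentum update and the queue can be sketched as below. This is a minimal illustration, not MoCo's official code: `MoCoState` and `moco_logits` are hypothetical names, the queue starts as random unit vectors, the defaults mirror the queue size and momentum quoted above, and `temperature=0.07` is an assumption. Keys are expected to come from the key encoder under `torch.no_grad()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoCoState:
    """Key-encoder momentum update plus a FIFO queue of negative keys (illustrative)."""

    def __init__(self, dim: int, queue_size: int = 65536, momentum: float = 0.999):
        self.momentum = momentum
        self.queue = F.normalize(torch.randn(queue_size, dim), dim=1)  # random initial negatives
        self.ptr = 0

    @torch.no_grad()
    def momentum_update(self, query_encoder: nn.Module, key_encoder: nn.Module) -> None:
        # theta_key = m * theta_key + (1 - m) * theta_query
        for q_param, k_param in zip(query_encoder.parameters(), key_encoder.parameters()):
            k_param.mul_(self.momentum).add_(q_param, alpha=1 - self.momentum)

    @torch.no_grad()
    def enqueue(self, keys: torch.Tensor) -> None:
        # Overwrite the oldest entries, wrapping around like a ring buffer
        idx = (self.ptr + torch.arange(keys.shape[0])) % self.queue.shape[0]
        self.queue[idx] = keys
        self.ptr = (self.ptr + keys.shape[0]) % self.queue.shape[0]


def moco_logits(q: torch.Tensor, k: torch.Tensor, queue: torch.Tensor,
                temperature: float = 0.07) -> torch.Tensor:
    """Column 0 holds the positive logit; the other columns are queue negatives.
    q, k: L2-normalised query and key embeddings, shape (N, dim)."""
    l_pos = (q * k).sum(dim=1, keepdim=True)  # (N, 1)
    l_neg = q @ queue.T                       # (N, queue_size)
    return torch.cat([l_pos, l_neg], dim=1) / temperature
```

The loss is then ordinary cross-entropy with target index 0 for every query, after which the current keys are enqueued and the oldest ones drop out.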

Caution: Representation collapse is the existential threat to SSL methods that compare augmented views. If the objective only rewarded agreement between positive pairs, a model that outputs the same constant vector for every input would satisfy it perfectly while learning nothing. SimCLR prevents collapse through negative pairs. BYOL prevents it through stop-gradient and EMA. Barlow Twins prevents it through redundancy reduction. If your SSL training loss drops suspiciously fast and the embeddings of different inputs look nearly identical, you likely have collapse.

Each method has its own collapse prevention mechanism, and understanding this is crucial for debugging SSL training:

  • SimCLR/MoCo: Negative pairs explicitly push representations apart. No negatives → collapse.
  • BYOL: Stop-gradient on the teacher prevents the degenerate solution. The asymmetry between student (has predictor MLP) and teacher (no predictor) is essential.
  • Barlow Twins: The off-diagonal terms of the cross-correlation matrix are penalized, preventing all dimensions from encoding the same information.
  • SwAV: The Sinkhorn-Knopp algorithm ensures balanced cluster assignments, preventing all samples from collapsing to one cluster.
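A cheap way to watch for collapse during training is to track the spread of the normalised embeddings, similar to a diagnostic used in the SimSiam paper. The helper name `embedding_std` is hypothetical; the reference value of roughly 1/sqrt(dim) assumes embeddings spread isotropically over the unit sphere.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def embedding_std(z: torch.Tensor) -> float:
    """Average per-dimension standard deviation of L2-normalised embeddings.

    z: (N, dim). A value near 0 means every input maps to (almost) the same vector;
    well-spread embeddings give a value on the order of 1 / sqrt(dim).
    """
    z = F.normalize(z, dim=1)
    return z.std(dim=0).mean().item()
```

Logging this value for a held-out batch every few hundred steps makes collapse visible long before a downstream evaluation would reveal it.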

Masked Modeling in Depth

BERT’s Masked Language Modeling

BERT masks 15% of input tokens and trains a Transformer encoder to predict them. But the masking strategy has subtleties:

  • 80% of the time, the selected token is replaced with [MASK]
  • 10% of the time, it is replaced with a random token
  • 10% of the time, it is kept unchanged

Why this complexity? The [MASK] token appears during pretraining but never during fine-tuning. If every selected token were replaced with [MASK], the model would learn to make predictions only at [MASK] positions, creating a pretrain/fine-tune mismatch. The random replacement forces the model to maintain a good representation of every token position (it cannot tell which tokens are corrupted), and keeping some tokens unchanged biases the representation toward the actual observed word.
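The 80/10/10 rule can be sketched on a batch of token ids. This is a minimal illustration under assumptions: `bert_mask` is a hypothetical helper, special tokens such as [CLS] and [SEP] are not excluded from selection (a real pipeline would exclude them), and `ignore_index=-100` is chosen because it is the default `ignore_index` of PyTorch's `F.cross_entropy`.

```python
import torch


def bert_mask(tokens: torch.Tensor, mask_token_id: int, vocab_size: int,
              mask_prob: float = 0.15, ignore_index: int = -100):
    """Apply BERT-style 80/10/10 corruption to a batch of token ids.

    Returns (inputs, labels): labels hold the original id at selected positions
    and ignore_index elsewhere, so the loss is computed only on selected tokens.
    """
    inputs, labels = tokens.clone(), tokens.clone()
    selected = torch.rand(tokens.shape) < mask_prob
    labels[~selected] = ignore_index

    # 80% of the selected positions become [MASK]
    to_mask = selected & (torch.rand(tokens.shape) < 0.8)
    inputs[to_mask] = mask_token_id

    # Half of the remaining 20% (10% overall) become a random token
    to_random = selected & ~to_mask & (torch.rand(tokens.shape) < 0.5)
    inputs[to_random] = torch.randint(vocab_size, tokens.shape)[to_random]

    # The final 10% keep their original token but are still predicted
    return inputs, labels
```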

The 15% masking rate is deliberately low for text. Language is information-dense: each token carries substantial meaning, so even 15% masking produces a demanding prediction task that forces the model to develop deep contextual understanding. Masking much more would make the task too ambiguous (many valid completions become possible).

MAE: Masked Autoencoders for Vision

MAE takes masked modeling to images, but with a dramatically different masking ratio: 75%. Why can you mask three-quarters of an image when BERT only masks 15% of text? Because images have much higher spatial redundancy than language. A missing patch can often be interpolated from its neighbors. You need to mask a lot to force the model to learn real semantic understanding rather than simple local interpolation.

MAE’s architecture is brilliantly efficient through asymmetry:

  1. Divide the image into non-overlapping patches (e.g., 16×16 pixels each for a 224×224 image = 196 patches)
  2. Randomly mask 75% of patches (keep 49 patches, mask 147)
  3. Encode only the visible 25% with a large ViT encoder
  4. Add learnable mask tokens for the masked positions
  5. Decode all patches (visible + mask tokens) with a small decoder
  6. Compute loss only on the masked patches (MSE between predicted and original pixel values)
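Steps 2 and 3 need a per-sample random subset of patches. A common way to do this is to sort random noise, which yields an independent permutation for every image. The sketch below assumes patch embeddings of shape (B, L, D); `random_masking` is an illustrative name.

```python
import torch


def random_masking(patches: torch.Tensor, mask_ratio: float = 0.75):
    """Per-sample random patch masking.

    patches: (B, L, D) patch embeddings.
    Returns the visible patches (B, L_keep, D), a mask (B, L) with 1 = masked
    in the original patch order, and the indices that undo the shuffle.
    """
    B, L, D = patches.shape
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(B, L, device=patches.device)
    ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # its inverse

    ids_keep = ids_shuffle[:, :len_keep]
    visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    mask = torch.ones(B, L, device=patches.device)
    mask[:, :len_keep] = 0                     # the first len_keep shuffled positions are kept
    mask = torch.gather(mask, 1, ids_restore)  # reorder to the original patch order
    return visible, mask, ids_restore
```

For a 224×224 image with 16×16 patches, L = 196 and the encoder receives 49 patches per image; `ids_restore` is later used to put mask tokens back in their original positions before decoding.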

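The key efficiency insight: the heavy encoder processes only 25% of the patches. Self-attention cost grows as O(n^2), so processing 49 patches instead of 196 cuts attention compute by 16x ((196/49)^2 = 16). The per-token linear and MLP layers, which dominate a ViT's compute at this sequence length, scale linearly and shrink by 4x, so the encoder does roughly a quarter of the work. He et al. report an overall training speedup of 3x or more, which makes MAE much faster to train than contrastive methods that must process full images twice.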
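The arithmetic can be checked with a rough per-layer count. This is a back-of-the-envelope sketch, not a profiler: `vit_layer_macs` is a hypothetical helper that counts multiply-accumulates for a standard Transformer layer with a 4x MLP, ignores softmax, layer norm and biases, and assumes ViT-Large width (1024).

```python
def vit_layer_macs(num_tokens: int, dim: int, mlp_ratio: int = 4) -> dict:
    """Rough multiply-accumulate count for one standard Transformer encoder layer."""
    n, d = num_tokens, dim
    linear = 4 * n * d * d + 2 * mlp_ratio * n * d * d  # Q/K/V + output projections, two MLP layers
    attention = 2 * n * n * d                           # Q @ K^T and attention weights @ V
    return {"linear": linear, "attention": attention, "total": linear + attention}


full = vit_layer_macs(num_tokens=196, dim=1024)  # every patch
mae = vit_layer_macs(num_tokens=49, dim=1024)    # 25% visible patches
print(full["attention"] / mae["attention"])  # 16.0
print(full["linear"] / mae["linear"])        # 4.0
print(round(full["total"] / mae["total"], 2))  # 4.09: the linear layers dominate
```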

Figure: Masked Autoencoder (MAE) architecture (adapted from He et al. 2022, Masked Autoencoders Are Scalable Vision Learners).
  • An image split into 16 patches (4×4) has 75% masked, leaving 4 visible and 12 masked.
  • The large ViT encoder processes only the 4 visible patches. Mask tokens are then added, and a small, lightweight decoder processes all 16 tokens (4 encoded + 12 mask tokens) to reconstruct the masked patches. The MSE loss is computed on the masked patches only.
  • Compute comparison: a standard ViT encodes all 196 patches, with self-attention cost proportional to 196^2 = 38,416; the MAE encoder encodes only the 49 visible patches (25%), with cost proportional to 49^2 = 2,401, about 16x less attention compute.
  • After pretraining, the decoder is discarded and the encoder is used for downstream tasks.

Why Masking Ratio Matters

The masking ratio is one of the most important hyperparameters in masked modeling, and the optimal value depends entirely on the modality:

  • Text (BERT): 15%. Language has high information density: each token carries significant semantic content, so masking too much makes prediction too ambiguous.
  • Images (MAE): 75%. Images have high spatial redundancy, and neighboring pixels are highly correlated, so a lot must be masked to prevent trivial interpolation.
  • Audio (wav2vec 2.0): ~50%. Audio falls between text and images in information density.

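He et al. found that a 75% masking ratio works well for both linear probing and fine-tuning. Linear-probing accuracy rises steadily as the ratio increases toward 75%, while fine-tuning is less sensitive, with a broad range of ratios (roughly 40 to 80%) performing well. At low ratios the task is too easy because the model can reconstruct from local context; at very high ratios, too little information remains for meaningful reconstruction.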

Positional embeddings play a crucial role in masked modeling. When 75% of patches are masked, the decoder must know where each mask token belongs to reconstruct the correct content. Without strong positional embeddings, reconstruction would be impossible—the decoder would not know whether a mask token should contain sky, grass, or a car bumper.
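The MAE paper adds sine-cosine positional embeddings to both encoder and decoder inputs. A minimal sketch of a fixed 2D version follows, assuming half the channels encode the patch row and half the column; `sincos_1d` and `sincos_2d` are hypothetical helper names, and the frequency base of 10000 follows the usual Transformer convention.

```python
import torch


def sincos_1d(dim: int, positions: torch.Tensor) -> torch.Tensor:
    """Sinusoidal embedding of integer positions: (len(positions), dim)."""
    assert dim % 2 == 0
    omega = 1.0 / (10000 ** (torch.arange(dim // 2, dtype=torch.float32) / (dim / 2)))
    angles = positions.float()[:, None] * omega[None, :]  # (P, dim/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)


def sincos_2d(dim: int, grid_size: int) -> torch.Tensor:
    """Fixed embedding for a grid_size x grid_size patch grid: (grid_size**2, dim).

    Half of the channels encode the row index, the other half the column index.
    """
    assert dim % 4 == 0
    rows, cols = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size), indexing="ij")
    emb_rows = sincos_1d(dim // 2, rows.flatten())
    emb_cols = sincos_1d(dim // 2, cols.flatten())
    return torch.cat([emb_rows, emb_cols], dim=1)
```

Because all mask tokens share one learned vector, the positional embedding added to each is the only thing that distinguishes them, which is exactly why the decoder cannot reconstruct without it.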

PyTorch Implementation from Scratch

This section implements the two flagship SSL methods, SimCLR and a simplified MAE, in complete, runnable PyTorch code. Downstream evaluation via linear probing and fine-tuning is also implemented.

SimCLR: Contrastive Learning Implementation

First, the complete SimCLR pipeline: augmentation, encoder, projection head, NT-Xent loss, and training loop.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import numpy as np


# ============================================================
# Step 1: SimCLR Augmentation Pipeline
# ============================================================
class SimCLRAugmentation:
    """Creates two correlated views of the same image."""

    def __init__(self, size=32):
        # For CIFAR-10 (32x32). Scale sizes for larger images.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465],
                std=[0.2470, 0.2435, 0.2616]
            ),
        ])

    def __call__(self, x):
        """Return two augmented views of the same image."""
        return self.transform(x), self.transform(x)


class SimCLRDataset:
    """Wrapper that applies SimCLR augmentation to any dataset."""

    def __init__(self, dataset, augmentation):
        self.dataset = dataset
        self.augmentation = augmentation

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        view1, view2 = self.augmentation(img)
        return view1, view2, label


# ============================================================
# Step 2: SimCLR Model (Encoder + Projection Head)
# ============================================================
class SimCLR(nn.Module):
    """SimCLR model with ResNet encoder and MLP projection head."""

    def __init__(self, base_encoder='resnet18', projection_dim=128,
                 hidden_dim=256):
        super().__init__()

        # Encoder: ResNet without the final classification layer
        if base_encoder == 'resnet18':
            self.encoder = models.resnet18(weights=None)
            encoder_dim = 512
        elif base_encoder == 'resnet50':
            self.encoder = models.resnet50(weights=None)
            encoder_dim = 2048
        else:
            raise ValueError(f"Unknown encoder: {base_encoder}")

        # Remove the final fully connected layer
        self.encoder.fc = nn.Identity()

        # Projection head: 2-layer MLP
        # This is where the contrastive loss is applied.
        # After training, we DISCARD this and use encoder output.
        self.projection_head = nn.Sequential(
            nn.Linear(encoder_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim),
        )

        self.encoder_dim = encoder_dim

    def forward(self, x):
        """Returns both encoder features and projected features."""
        h = self.encoder(x)           # shape: (batch, encoder_dim)
        z = self.projection_head(h)   # shape: (batch, projection_dim)
        return h, z


# ============================================================
# Step 3: NT-Xent Loss (Normalized Temperature-scaled Cross-Entropy)
# ============================================================
class NTXentLoss(nn.Module):
    """NT-Xent loss for contrastive learning (SimCLR).

    For a batch of N images producing 2N augmented views,
    each view has exactly 1 positive and 2(N-1) negatives.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """
        Args:
            z_i: projections from first augmented view  (N, dim)
            z_j: projections from second augmented view (N, dim)
        Returns:
            Scalar loss value
        """
        batch_size = z_i.shape[0]

        # Normalize projections to unit sphere
        z_i = F.normalize(z_i, dim=1)
        z_j = F.normalize(z_j, dim=1)

        # Concatenate: [z_i_0, z_i_1, ..., z_j_0, z_j_1, ...]
        z = torch.cat([z_i, z_j], dim=0)  # (2N, dim)

        # Compute pairwise cosine similarity matrix
        sim_matrix = torch.mm(z, z.T) / self.temperature  # (2N, 2N)

        # Mask out self-similarity (diagonal)
        mask = torch.eye(2 * batch_size, dtype=torch.bool,
                         device=z.device)
        sim_matrix.masked_fill_(mask, -float('inf'))

        # For each z_i[k], positive is z_j[k] (at index k + N)
        # For each z_j[k], positive is z_i[k] (at index k)
        positive_indices = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(0, batch_size)
        ]).to(z.device)

        # NT-Xent is cross-entropy with positives as targets
        loss = F.cross_entropy(sim_matrix, positive_indices)
        return loss


# ============================================================
# Step 4: Training Loop
# ============================================================
def train_simclr(model, dataloader, optimizer, criterion,
                 epochs=100, device='cuda'):
    """Full SimCLR pretraining loop."""
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0

        for view1, view2, _ in dataloader:
            view1 = view1.to(device)
            view2 = view2.to(device)

            # Forward pass through encoder + projection head
            _, z_i = model(view1)
            _, z_j = model(view2)

            # Compute NT-Xent loss
            loss = criterion(z_i, z_j)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] | Loss: {avg_loss:.4f}")

    return model


# ============================================================
# Step 5: Full Pipeline — Pretrain on CIFAR-10
# ============================================================
def run_simclr_pretraining():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load CIFAR-10 (no labels needed for pretraining!)
    raw_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True
    )

    augmentation = SimCLRAugmentation(size=32)
    ssl_dataset = SimCLRDataset(raw_dataset, augmentation)
    dataloader = DataLoader(
        ssl_dataset, batch_size=256, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True
    )

    # Initialize model, optimizer, loss
    model = SimCLR(
        base_encoder='resnet18',
        projection_dim=128,
        hidden_dim=256
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4,
                                 weight_decay=1e-4)

    criterion = NTXentLoss(temperature=0.5)

    # Train!
    print("Starting SimCLR pretraining...")
    model = train_simclr(
        model, dataloader, optimizer, criterion,
        epochs=100, device=device
    )

    # Save pretrained encoder (without projection head)
    torch.save(model.encoder.state_dict(), 'simclr_encoder.pth')
    print("Pretrained encoder saved to simclr_encoder.pth")
    return model


if __name__ == '__main__':
    run_simclr_pretraining()

Figure: SimCLR pipeline (contrastive learning). Two augmentations t ~ T and t' ~ T turn an input image x into views x_i (crop + jitter) and x_j (crop + blur). A shared encoder f(x) = h and a projection head g(h) = z map them to z_i and z_j. The NT-Xent loss attracts the positive pair (same image, different augmentations) and repels negatives z_k from other images: $\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k \neq i} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$.

Tip: When running SimCLR on CIFAR-10 with a ResNet-18 encoder, a batch size of 256 works reasonably well. For ImageNet-scale experiments, the original paper used batch sizes of 4,096 to 8,192 with the LARS optimiser. For compute-constrained settings, MoCo or BYOL are alternatives that work well at the standard batch size of 256.
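MoCo v1/v2 tolerate a batch of 256 because their negatives do not come from the current batch. The minimal PyTorch sketch below shows the two ingredients: an exponential-moving-average ("momentum") update of a key encoder, and a fixed-size FIFO queue of keys from earlier batches that supplies the negatives. The names `momentum_update`, `KeyQueue` and `moco_logits`, the toy linear encoders and the queue size of 1,024 are illustrative assumptions, not MoCo's official code. The momentum of 0.999 and temperature of 0.2 follow values reported for MoCo v1 and v2 respectively. (MoCo v3 drops the queue in favor of large in-batch negatives.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def momentum_update(query_enc: nn.Module, key_enc: nn.Module, m: float = 0.999) -> None:
    """EMA update of the key encoder: theta_k <- m * theta_k + (1 - m) * theta_q."""
    for p_q, p_k in zip(query_enc.parameters(), key_enc.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1 - m)


class KeyQueue:
    """Fixed-size FIFO queue of L2-normalized keys from earlier batches (hypothetical helper)."""

    def __init__(self, dim: int, size: int):
        self.keys = F.normalize(torch.randn(size, dim), dim=1)  # random initial fill
        self.size = size
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, k: torch.Tensor) -> None:
        # Overwrite the oldest entries, wrapping around the end of the buffer.
        idx = (self.ptr + torch.arange(k.shape[0])) % self.size
        self.keys[idx] = k
        self.ptr = (self.ptr + k.shape[0]) % self.size


def moco_logits(q, k, queue: KeyQueue, temperature: float = 0.2):
    """InfoNCE logits: column 0 holds the positive key, the rest are queued negatives.

    k is assumed to be L2-normalized already.
    """
    q = F.normalize(q, dim=1)
    l_pos = (q * k).sum(dim=1, keepdim=True)   # (B, 1)
    l_neg = q @ queue.keys.t()                 # (B, queue size)
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    labels = torch.zeros(q.shape[0], dtype=torch.long)  # positive is always index 0
    return logits, labels


torch.manual_seed(0)
query_enc = nn.Linear(32, 16)                  # stand-in for encoder + projection head
key_enc = nn.Linear(32, 16)
key_enc.load_state_dict(query_enc.state_dict())
for p in key_enc.parameters():
    p.requires_grad = False                    # the key encoder receives no gradients

queue = KeyQueue(dim=16, size=1024)            # 1,024 negatives with a batch of only 8
view1, view2 = torch.randn(8, 32), torch.randn(8, 32)  # stand-ins for two augmented views

q = query_enc(view1)
with torch.no_grad():
    momentum_update(query_enc, key_enc)
    k = F.normalize(key_enc(view2), dim=1)
logits, labels = moco_logits(q, k, queue)
loss = F.cross_entropy(logits, labels)
loss.backward()
queue.enqueue(k)
```

The number of negatives is set by the queue size, not the batch size, which is why the memory cost of a large negative set stays modest.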

MAE: Masked Autoencoder Implementation

Now let us implement a simplified Masked Autoencoder. We will build a ViT-based encoder-decoder that masks 75% of image patches and learns to reconstruct them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import math


# ============================================================
# Patch Embedding Layer
# ============================================================
class PatchEmbedding(nn.Module):
    """Convert image into sequence of patch embeddings."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, num_patches, embed_dim)
        x = self.proj(x)                     # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)     # (B, num_patches, embed_dim)
        return x


# ============================================================
# Transformer Block
# ============================================================
class TransformerBlock(nn.Module):
    """Standard Transformer block with multi-head self-attention."""

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0,
                 dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention with residual
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        return x


# ============================================================
# MAE Encoder
# ============================================================
class MAEEncoder(nn.Module):
    """Vision Transformer encoder that only processes visible patches."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 embed_dim=192, depth=6, num_heads=6):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # Learnable positional embeddings
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask):
        """
        Args:
            x: images (B, C, H, W)
            mask: boolean mask (B, num_patches), True = KEEP (visible)
        Returns:
            Encoded visible patches (B, num_visible, embed_dim)
            The same mask, so the decoder can restore patch positions
        """
        # Patch embedding
        x = self.patch_embed(x)             # (B, N, D)
        x = x + self.pos_embed              # Add positional embeddings

        B, N, D = x.shape

        # Keep only visible (unmasked) patches. Boolean indexing returns the
        # selected tokens in row-major order, so each sample keeps its patches
        # in their original positional order. The reshape assumes every sample
        # has the same number of visible patches (true for generate_mask).
        x = x[mask].reshape(B, -1, D)       # (B, num_visible, D)

        # Apply Transformer blocks (ONLY to visible patches!)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        return x, mask


# ============================================================
# MAE Decoder
# ============================================================
class MAEDecoder(nn.Module):
    """Lightweight decoder that reconstructs masked patches."""

    def __init__(self, num_patches, embed_dim=192, decoder_dim=96,
                 decoder_depth=2, decoder_heads=3, patch_size=4,
                 in_channels=3):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size

        # Project encoder dim to decoder dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_dim)

        # Learnable mask token
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        nn.init.normal_(self.mask_token, std=0.02)

        # Decoder positional embeddings
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_dim)
        )
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

        # Decoder Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(decoder_dim, decoder_heads)
            for _ in range(decoder_depth)
        ])
        self.norm = nn.LayerNorm(decoder_dim)

        # Predict pixel values for each patch
        self.pred = nn.Linear(
            decoder_dim, patch_size * patch_size * in_channels
        )

    def forward(self, x, mask):
        """
        Args:
            x: encoded visible patches (B, num_visible, encoder_dim)
            mask: boolean (B, num_patches), True = visible
        Returns:
            Predicted patches (B, num_patches, patch_pixels)
        """
        B = x.shape[0]
        x = self.decoder_embed(x)  # (B, num_visible, decoder_dim)

        # Build full sequence: visible tokens + mask tokens
        full_seq = self.mask_token.expand(
            B, self.num_patches, -1
        ).clone()

        # Place visible tokens at their original positions. Boolean-mask
        # assignment fills the True positions in row-major order, the same
        # order in which the encoder selected the visible patches.
        full_seq[mask] = x.reshape(-1, x.shape[-1])

        # Add positional embeddings
        full_seq = full_seq + self.decoder_pos_embed

        # Apply decoder Transformer blocks
        for block in self.blocks:
            full_seq = block(full_seq)
        full_seq = self.norm(full_seq)

        # Predict pixel values
        pred = self.pred(full_seq)  # (B, num_patches, P*P*C)
        return pred


# ============================================================
# Full MAE Model
# ============================================================
class MAE(nn.Module):
    """Complete Masked Autoencoder."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 embed_dim=192, encoder_depth=6, encoder_heads=6,
                 decoder_dim=96, decoder_depth=2, decoder_heads=3,
                 mask_ratio=0.75):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        self.encoder = MAEEncoder(
            img_size, patch_size, in_channels,
            embed_dim, encoder_depth, encoder_heads
        )
        self.decoder = MAEDecoder(
            num_patches, embed_dim, decoder_dim,
            decoder_depth, decoder_heads, patch_size, in_channels
        )
        self.num_patches = num_patches

    def generate_mask(self, batch_size, device):
        """Generate random mask: True = keep, False = mask out."""
        num_keep = int(self.num_patches * (1 - self.mask_ratio))
        mask = torch.zeros(batch_size, self.num_patches,
                          dtype=torch.bool, device=device)

        for b in range(batch_size):
            keep_idx = torch.randperm(
                self.num_patches, device=device
            )[:num_keep]
            mask[b, keep_idx] = True

        return mask

    def patchify(self, imgs):
        """Convert images to patch sequences for loss computation.
        imgs: (B, C, H, W) -> (B, num_patches, patch_size^2 * C)
        """
        p = self.patch_size
        B, C, H, W = imgs.shape
        h, w = H // p, W // p
        patches = imgs.reshape(B, C, h, p, w, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5)  # (B, h, w, C, p, p)
        patches = patches.reshape(B, h * w, C * p * p)
        return patches

    def forward(self, imgs):
        """
        Args:
            imgs: (B, C, H, W)
        Returns:
            loss: MSE reconstruction loss (on masked patches only)
            pred: predicted patches (B, num_patches, patch_pixels)
            mask: the mask used (B, num_patches)
        """
        B = imgs.shape[0]
        device = imgs.device

        # Generate random mask
        mask = self.generate_mask(B, device)

        # Encode visible patches only
        encoded, mask = self.encoder(imgs, mask)

        # Decode all patches (visible + mask tokens)
        pred = self.decoder(encoded, mask)

        # Compute loss only on masked patches
        target = self.patchify(imgs)
        # mask is True for visible, we want loss on ~mask (masked)
        masked = ~mask  # True where patches were masked

        # Per-patch MSE, then average over masked patches
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)          # per-patch MSE
        loss = (loss * masked.float()).sum() / masked.float().sum()

        return loss, pred, mask


# ============================================================
# MAE Training Loop
# ============================================================
def train_mae(model, dataloader, optimizer, epochs=100,
              device='cuda'):
    """Full MAE pretraining loop."""
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0

        for imgs, _ in dataloader:
            imgs = imgs.to(device)

            # Forward pass
            loss, pred, mask = model(imgs)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] "
                  f"| Recon Loss: {avg_loss:.4f}")

    return model


def run_mae_pretraining():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        ),
    ])

    dataset = datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=transform
    )
    dataloader = DataLoader(
        dataset, batch_size=256, shuffle=True,
        num_workers=4, pin_memory=True
    )

    # Initialize MAE
    model = MAE(
        img_size=32, patch_size=4,          # 8x8 = 64 patches
        embed_dim=192, encoder_depth=6, encoder_heads=6,
        decoder_dim=96, decoder_depth=2, decoder_heads=3,
        mask_ratio=0.75
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1.5e-4,
        betas=(0.9, 0.95), weight_decay=0.05
    )

    print("Starting MAE pretraining...")
    model = train_mae(model, dataloader, optimizer,
                      epochs=100, device=device)

    # Save encoder only (discard decoder)
    torch.save(model.encoder.state_dict(), 'mae_encoder.pth')
    print("Pretrained MAE encoder saved to mae_encoder.pth")
    return model


if __name__ == '__main__':
    run_mae_pretraining()
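The `generate_mask` method above builds the mask with a Python loop over the batch. A vectorized sketch, following the noise-and-argsort idea used in the official MAE code, gives every sample an independent, uniformly random subset of kept patches. `random_keep_mask` is a hypothetical helper, not part of any library.

```python
import torch


def random_keep_mask(batch_size: int, num_patches: int, mask_ratio: float,
                     device=None) -> torch.Tensor:
    """Return a (batch_size, num_patches) boolean mask with True = keep."""
    num_keep = int(num_patches * (1 - mask_ratio))
    # Independent uniform noise per sample; the num_keep smallest values mark the
    # kept patches, which selects a uniformly random subset for each sample.
    noise = torch.rand(batch_size, num_patches, device=device)
    ids_keep = noise.argsort(dim=1)[:, :num_keep]
    mask = torch.zeros(batch_size, num_patches, dtype=torch.bool, device=device)
    mask.scatter_(1, ids_keep, torch.ones_like(ids_keep, dtype=torch.bool))
    return mask


mask = random_keep_mask(batch_size=4, num_patches=64, mask_ratio=0.75)
print(mask.sum(dim=1))  # tensor([16, 16, 16, 16])
```

Inside `MAE.generate_mask`, the loop could be replaced by `return random_keep_mask(batch_size, self.num_patches, self.mask_ratio, device)`; the mask semantics (True = keep) stay the same.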

Downstream Evaluation: Linear Probing and Fine-Tuning

After SSL pretraining, we need to evaluate how good the learned representations are. There are two standard protocols: linear probing (freeze the encoder, train only a linear classifier on top) and full fine-tuning (update all weights). If you have used transfer learning in other contexts, these concepts should feel familiar.

import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader


# ============================================================
# Linear Probing: Freeze encoder, train linear head only
# ============================================================
class LinearProbe(nn.Module):
    """Linear probe for evaluating SSL representations."""

    def __init__(self, encoder, encoder_dim, num_classes=10):
        super().__init__()
        self.encoder = encoder
        # Freeze all encoder parameters
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(encoder_dim, num_classes)

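    def train(self, mode=True):
        # Keep the frozen encoder in eval mode so its BatchNorm running
        # statistics are not updated while the linear head trains.
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x)
        return self.classifier(features)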


def train_linear_probe(encoder, encoder_dim, train_loader,
                       test_loader, epochs=50, device='cuda'):
    """Train and evaluate a linear probe on frozen SSL features."""
    model = LinearProbe(encoder, encoder_dim).to(device)
    optimizer = torch.optim.Adam(
        model.classifier.parameters(), lr=1e-3
    )
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Evaluate
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = 100 * correct / total
    print(f"Linear Probe Accuracy: {accuracy:.2f}%")
    return accuracy


# ============================================================
# Full Fine-Tuning: Update all weights with small LR
# ============================================================
class FineTuner(nn.Module):
    """Full fine-tuning of SSL-pretrained encoder."""

    def __init__(self, encoder, encoder_dim, num_classes=10):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(encoder_dim, num_classes)

    def forward(self, x):
        features = self.encoder(x)
        return self.classifier(features)


def finetune_model(encoder, encoder_dim, train_loader,
                   test_loader, epochs=30, device='cuda'):
    """Fine-tune the full model (encoder + classifier)."""
    model = FineTuner(encoder, encoder_dim).to(device)

    # Use smaller LR for encoder, larger for classifier
    optimizer = torch.optim.Adam([
        {'params': model.encoder.parameters(), 'lr': 1e-4},
        {'params': model.classifier.parameters(), 'lr': 1e-3},
    ])
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Evaluate
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = 100 * correct / total
    print(f"Fine-Tune Accuracy: {accuracy:.2f}%")
    return accuracy


# ============================================================
# Run Evaluation Pipeline
# ============================================================
def evaluate_ssl_model():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Standard transforms for evaluation (no SSL augmentation)
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        ),
    ])

    train_set = datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=eval_transform
    )
    test_set = datasets.CIFAR10(
        root='./data', train=False, download=True,
        transform=eval_transform
    )
    train_loader = DataLoader(train_set, batch_size=256, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=256)

    # Load pretrained SimCLR encoder
    encoder = models.resnet18(weights=None)
    encoder.fc = nn.Identity()
    encoder.load_state_dict(torch.load('simclr_encoder.pth', map_location=device))
    encoder.to(device)

    print("=== SimCLR Evaluation ===")
    print("Linear Probe:")
    train_linear_probe(encoder, 512, train_loader, test_loader,
                       device=device)
    print("Fine-Tuning:")
    # Reload encoder for fresh fine-tuning
    encoder2 = models.resnet18(weights=None)
    encoder2.fc = nn.Identity()
    encoder2.load_state_dict(torch.load('simclr_encoder.pth', map_location=device))
    finetune_model(encoder2, 512, train_loader, test_loader,
                   device=device)


if __name__ == '__main__':
    evaluate_ssl_model()
Key Takeaway: Linear probing measures the quality of frozen representations—it answers “how much useful information did SSL capture?” Fine-tuning measures practical downstream performance—it answers “how well does this pretrained model perform after adaptation?” A strong linear probe result with further improvement from fine-tuning is the hallmark of a good SSL method.

The Pretraining to Fine-Tuning Pipeline

The paradigm of SSL pretraining followed by supervised fine-tuning is now the default approach in modern machine learning. However, the fine-tuning stage itself has several variations, each suited to different scenarios.

Linear Probing

Freeze the entire encoder and train only a linear classifier (a single fully connected layer) on top. This is the purest test of representation quality: if a linear classifier can achieve high accuracy on the frozen features, the representations must contain rich, linearly separable information about the task.

When to use: With very little labeled data (hundreds to a few thousand samples), overfitting is a serious risk, and freezing the encoder limits the model's capacity and acts as strong regularization. Linear probing is also the standard benchmark for comparing SSL methods.

Full Fine-Tuning

Update all parameters—encoder and classifier—using the labeled data. The key practice is using a much smaller learning rate for the pretrained encoder than for the new classifier head. Typical ratios are 10x to 100x. This preserves the useful representations while allowing them to adapt to the specific downstream task.

When to use: When you have moderate amounts of labeled data (thousands to tens of thousands of samples) and the downstream task is related but not identical to the pretraining data distribution. This is the most common fine-tuning approach in practice.

Partial Fine-Tuning (Layer Freezing)

Freeze the early layers of the encoder and only fine-tune the later layers plus the classifier. The intuition: early layers learn generic features (edges, textures, basic patterns) that transfer universally, while later layers learn more task-specific features that may need adaptation.

When to use: When your downstream domain is somewhat different from the pretraining domain but you have limited data. Partial fine-tuning is a middle ground between linear probing (maximum regularization) and full fine-tuning (maximum flexibility). This approach is widely used in domain adaptation scenarios where the source and target distributions differ.
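Layer freezing can be sketched in a few lines of PyTorch. The helper below disables gradients for every parameter under the named top-level submodules and hands only the remaining parameters to the optimizer. `freeze_top_level` is a hypothetical helper, and the toy `ModuleDict` only mimics the top-level names of a torchvision ResNet (`conv1`, `bn1`, `layer1` to `layer4`, `fc`). For a real ResNet-18 one might freeze `['conv1', 'bn1', 'layer1', 'layer2']`, but that split is an assumption to tune. Note that frozen BatchNorm layers still update their running statistics in train mode unless they are put in eval mode.

```python
import torch
import torch.nn as nn


def freeze_top_level(model: nn.Module, frozen_names) -> int:
    """Disable gradients for all parameters under the named top-level submodules.

    Matching on the first component of the parameter name (e.g. 'layer1' in
    'layer1.0.conv1.weight') avoids accidental prefix matches such as 'layer1'
    vs 'layer10'. Returns the number of parameter tensors frozen.
    """
    frozen_names = set(frozen_names)
    n_frozen = 0
    for name, param in model.named_parameters():
        if name.split('.')[0] in frozen_names:
            param.requires_grad = False
            n_frozen += 1
    return n_frozen


# Toy stand-in with ResNet-style top-level names; swap in your pretrained encoder.
model = nn.ModuleDict({
    'conv1': nn.Conv2d(3, 8, kernel_size=3),
    'bn1': nn.BatchNorm2d(8),
    'layer1': nn.Linear(8, 8),
    'layer4': nn.Linear(8, 8),
    'fc': nn.Linear(8, 10),
})
n_frozen = freeze_top_level(model, ['conv1', 'bn1', 'layer1'])

# Give the optimizer only the parameters that are still trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)
```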

When Each Approach Works Best

| Strategy | Labeled Data | Domain Similarity | Best For |
|---|---|---|---|
| Linear Probing | Very small (100-1K) | High | SSL benchmarks, few-shot |
| Partial Fine-Tuning | Small (1K-10K) | Medium | Cross-domain transfer |
| Full Fine-Tuning | Moderate (10K+) | Low to High | Production models |
| Train from Scratch | Very large (100K+) | N/A | Unique domains, considerable data |

The key insight: SSL pretraining almost never hurts. Even when you have a large labeled dataset, initializing from SSL-pretrained weights typically matches or beats training from scratch, while converging faster. The main scenario where from-scratch training might win is when your data is highly domain-specific (e.g., satellite imagery or microscopy), the available pretrained weights come from a very different domain, and you have abundant labeled data.

SSL Beyond Vision and NLP

SSL is not limited to images and text. Its principles (create a pretext task from the data's structure, learn representations, then fine-tune downstream) apply to virtually any data modality.

Time Series

Time series data is abundant in industry, healthcare, and finance, but labeled anomalies or events are rare. SSL methods for time series anomaly detection have become increasingly important:

  • TS2Vec learns hierarchical representations by contrasting subseries at different temporal scales. It uses timestamp masking and random cropping as augmentations.
  • TNC (Temporal Neighborhood Coding) treats temporally adjacent windows as positive pairs and distant windows as negatives, based on the assumption that nearby time points share similar underlying state.
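  • TS-TCC (Time-Series representation learning via Temporal and Contextual Contrasting) creates a weak (jitter-and-scale) and a strong (permutation-and-jitter) view of each series. A temporal contrasting module uses one view's context to predict the other view's future timesteps, and a contextual contrasting module contrasts the resulting context vectors across samples.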

The key challenge in time series SSL is choosing augmentations that preserve semantics. Unlike images, where random cropping is nearly always safe, time series augmentations must be chosen carefully—time warping might destroy periodicity, and amplitude scaling might change the meaning of threshold crossings. This connects directly to domain adaptation challenges in time series where distribution shift is common.
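The augmentations discussed above can be sketched in PyTorch for a batch shaped (batch, time, channels). This is a minimal illustration, not TS2Vec's or TS-TCC's code: the noise levels are assumed values, TS2Vec applies timestamp masking to latent vectors rather than raw inputs, and TS2Vec contrasts two overlapping crops, whereas here both views share one crop for simplicity. Note that `scale` is exactly the kind of transform that can change the meaning of threshold crossings.

```python
import torch


def jitter(x, sigma=0.03):
    """Add small Gaussian noise. x: (batch, time, channels)."""
    return x + sigma * torch.randn_like(x)


def scale(x, sigma=0.1):
    """Multiply each channel of each series by one random factor close to 1."""
    factors = 1 + sigma * torch.randn(x.shape[0], 1, x.shape[2], device=x.device)
    return x * factors


def random_crop(x, crop_len):
    """Take the same random contiguous window of length crop_len from every series."""
    start = int(torch.randint(0, x.shape[1] - crop_len + 1, (1,)))
    return x[:, start:start + crop_len]


def timestamp_mask(x, p=0.5):
    """Zero all channels at a random subset of timesteps (each masked with probability p)."""
    keep = torch.rand(x.shape[0], x.shape[1], 1, device=x.device) >= p
    return x * keep.to(x.dtype)


x = torch.randn(4, 100, 3)          # 4 series, 100 timesteps, 3 channels
crop = random_crop(x, 80)
view1 = timestamp_mask(crop)
view2 = jitter(scale(crop))
print(view1.shape, view2.shape)  # torch.Size([4, 80, 3]) torch.Size([4, 80, 3])
```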

Audio and Speech

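wav2vec 2.0 (Baevski et al., 2020) applies masked prediction to raw audio waveforms. A convolutional feature encoder turns the waveform into latent speech representations; spans of these latents are masked before they enter a Transformer, and the model is trained with a contrastive loss to identify the true quantized latent (drawn from learned codebooks) at each masked timestep among distractors. With just 10 minutes of labeled speech, after pretraining on 53,000 hours of unlabeled audio, wav2vec 2.0 reported word error rates of 4.8 and 8.2 on the LibriSpeech test-clean and test-other sets.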

HuBERT (Hsu et al., 2021) takes a similar approach but uses offline clustering (k-means) to create pseudo-labels for masked prediction, iteratively refining the clusters as the model improves.
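The span masking used by wav2vec 2.0 (and reused by HuBERT) can be sketched as follows. The wav2vec 2.0 paper reports sampling a proportion p = 0.065 of timesteps as span starts and masking M = 10 consecutive timesteps from each, with spans allowed to overlap. This sketch approximates the start sampling with an independent Bernoulli draw per timestep; `span_mask` is a hypothetical helper.

```python
import torch


def span_mask(batch_size: int, seq_len: int, p: float = 0.065, span: int = 10,
              generator=None) -> torch.Tensor:
    """Boolean (batch_size, seq_len) mask over latent timesteps, True = masked.

    Each timestep becomes a span start with probability p; it and the following
    span - 1 timesteps are masked, and overlapping spans merge.
    """
    starts = torch.rand(batch_size, seq_len, generator=generator) < p
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    for offset in range(min(span, seq_len)):
        # Shift the start indicators right by `offset` and OR them into the mask.
        mask[:, offset:] |= starts[:, :seq_len - offset]
    return mask


g = torch.Generator().manual_seed(0)
mask = span_mask(8, 1000, generator=g)
print(f"masked fraction: {mask.float().mean():.2f}")
# Expected fraction is 1 - (1 - p) ** span, about 0.49 for p = 0.065 and span = 10.
```

The high masked fraction (roughly half of all timesteps) is what keeps the prediction task hard even though speech is locally redundant.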

Tabular Data

SSL for tabular data is harder than for images or text because tabular features lack the spatial or sequential structure that makes augmentation natural:

  • SCARF (Self-supervised Contrastive Learning using Random Feature Corruption) creates positive pairs by randomly corrupting a subset of features with values drawn from the empirical marginal distribution.
  • VIME (Value Imputation and Mask Estimation) uses a pretext task similar to BERT: mask feature values and predict both the masked values and which features were masked.
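SCARF's corruption step can be sketched in a few lines. For each selected cell, the replacement value comes from the same column of a randomly chosen row, i.e. a draw from that feature's empirical marginal. Assumptions: the batch stands in for the training set's marginal, each cell is corrupted independently with probability `corruption_rate` (the paper's exact subset sampling may differ), and 0.6 is an assumed default. `scarf_corrupt` is a hypothetical helper.

```python
import torch


def scarf_corrupt(x: torch.Tensor, corruption_rate: float = 0.6,
                  generator=None) -> torch.Tensor:
    """Corrupted view of a batch of tabular rows. x: (num_rows, num_features)."""
    num_rows, num_features = x.shape
    # Which cells to corrupt: each independently with probability corruption_rate.
    corrupt = torch.rand(num_rows, num_features, generator=generator) < corruption_rate
    # A random row index per cell; gathering along dim 0 takes the replacement from
    # the same column, i.e. from that feature's empirical marginal in the batch.
    rand_rows = torch.randint(0, num_rows, (num_rows, num_features), generator=generator)
    marginal_samples = torch.gather(x, 0, rand_rows)
    return torch.where(corrupt, marginal_samples, x)


x = torch.arange(20.0).reshape(5, 4)   # 5 rows, 4 features
x_tilde = scarf_corrupt(x, generator=torch.Generator().manual_seed(0))
# (x[i], x_tilde[i]) form a positive pair; corrupted views of other rows act as negatives.
```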

Graph Data

Graphs present unique opportunities for SSL because their structure provides rich self-supervision signals. The same graph neural network encoders used in supervised settings, such as Graph Attention Networks, can be pretrained with SSL objectives to learn node and graph representations without labels:

  • GraphCL applies contrastive learning to graphs using augmentations like node dropping, edge perturbation, attribute masking, and subgraph sampling.
  • GCC (Graph Contrastive Coding) learns structural representations by contrasting subgraph instances sampled via random walks.
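Two of the GraphCL-style augmentations, edge dropping (one form of edge perturbation) and node dropping, can be sketched in plain PyTorch. The graph is stored as a node-feature matrix plus an `edge_index` tensor of shape (2, num_edges), the convention used by PyTorch Geometric; the helper names and the 0.2 drop probability are illustrative assumptions.

```python
import torch


def drop_edges(edge_index: torch.Tensor, drop_prob: float = 0.2,
               generator=None) -> torch.Tensor:
    """Remove each edge independently with probability drop_prob. edge_index: (2, E)."""
    keep = torch.rand(edge_index.shape[1], generator=generator) >= drop_prob
    return edge_index[:, keep]


def drop_nodes(x: torch.Tensor, edge_index: torch.Tensor, drop_prob: float = 0.2,
               generator=None):
    """Remove each node with probability drop_prob, plus every edge touching it.

    Surviving nodes are renumbered 0..n_kept-1 so the returned edge_index stays valid.
    """
    num_nodes = x.shape[0]
    keep = torch.rand(num_nodes, generator=generator) >= drop_prob
    new_id = torch.full((num_nodes,), -1, dtype=torch.long)
    new_id[keep] = torch.arange(int(keep.sum()))
    edge_keep = keep[edge_index[0]] & keep[edge_index[1]]
    return x[keep], new_id[edge_index[:, edge_keep]]


x = torch.randn(6, 3)                                   # 6 nodes, 3 features each
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])         # a directed 6-cycle
g = torch.Generator().manual_seed(0)
view1 = drop_edges(edge_index, generator=g)
view2_x, view2_edges = drop_nodes(x, edge_index, generator=g)
```

Two such views of the same graph, encoded by a shared GNN and pooled, form a positive pair for the contrastive loss.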

Multimodal Learning

CLIP (Contrastive Language-Image Pre-training) is perhaps the most impactful multimodal SSL method. It learns to align text and image representations by contrasting matching image-text pairs (positives) against non-matching pairs (negatives) from a batch of 32,768 pairs. The result: zero-shot image classification by simply comparing image embeddings with text embeddings of class descriptions.
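The symmetric contrastive objective described in the CLIP paper can be sketched as follows, assuming PyTorch and already-computed image and text embeddings. In CLIP the scale is a learned parameter (the exponential of a learned log-temperature); here it is passed as a plain float for illustration, and `clip_loss` is a hypothetical name.

```python
import math
import torch
import torch.nn.functional as F


def clip_loss(image_emb: torch.Tensor, text_emb: torch.Tensor,
              logit_scale: float) -> torch.Tensor:
    """Symmetric contrastive loss for N matched (image, text) pairs.

    Row i of image_emb and row i of text_emb come from the same pair, so the
    positives sit on the diagonal of the N x N similarity matrix and every
    off-diagonal entry is a negative.
    """
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = logit_scale * image_emb @ text_emb.t()     # scaled cosine similarities
    targets = torch.arange(logits.shape[0], device=logits.device)
    loss_img = F.cross_entropy(logits, targets)        # each image must pick its caption
    loss_txt = F.cross_entropy(logits.t(), targets)    # each caption must pick its image
    return (loss_img + loss_txt) / 2


# Perfectly aligned, mutually orthogonal embeddings give a near-zero loss...
aligned = torch.eye(4)
print(clip_loss(aligned, aligned, logit_scale=100.0).item() < 1e-6)  # True
# ...while a zero scale makes every pairing equally likely, so the loss is log(N).
print(math.isclose(clip_loss(aligned, aligned, logit_scale=0.0).item(),
                   math.log(4), rel_tol=1e-6))  # True
```

Zero-shot classification reuses the same similarity: embed one text prompt per class and pick the class whose text embedding has the highest cosine similarity with the image embedding.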

ImageBind (Girdhar et al., 2023) extends this idea to six modalities (images, text, audio, depth, thermal, and IMU data), using images as the binding modality. All other modalities are aligned to the image embedding space, enabling zero-shot cross-modal retrieval between non-image modalities without ever training on pairs of them.

Practical Guide: Choosing and Using SSL

Choosing the Right SSL Method

The choice of SSL method depends on your modality, compute budget, and downstream task:

  • If you work with text: Masked language modeling (BERT-style) or autoregressive pretraining (GPT-style). This is mature and well-understood. In most cases, you should not train from scratch—use a pretrained model from HuggingFace.
  • If you work with images and have limited compute: MAE. Its encoder processes only the 25% of patches that stay visible, which the MAE paper reports speeds up pretraining by 3x or more compared with encoding all patches.
  • If you work with images and want the best representations: DINOv2. It combines self-distillation with masked image modeling and produces some of the strongest general-purpose visual features publicly available.
  • If you work with small image datasets: BYOL or Barlow Twins. They do not require large batch sizes and work well with standard hardware.
  • If you need multimodal capabilities: CLIP or its variants.
  • If you work with time series: TS2Vec or TS-TCC.

Compute Requirements

| Method | Min. Batch Size | GPU Memory | Training Time (ImageNet) |
|---|---|---|---|
| SimCLR | 4096+ (ideal) | High (multi-GPU) | ~3 days (32 TPUs) |
| MoCo v3 | 256-1024 | Moderate | ~2 days (8 GPUs) |
| BYOL | 256 | Moderate | ~2 days (8 GPUs) |
| Barlow Twins | 256-2048 | Moderate | ~2 days (8 GPUs) |
| MAE | 256-4096 | Low (efficient) | ~1 day (8 GPUs) |
| DINO | 256-1024 | High (two networks) | ~3 days (8 GPUs) |

When SSL Outperforms Supervised Learning

SSL pretraining is especially valuable in these scenarios:

  • Small labeled datasets: When you have fewer than 10,000 labeled examples, SSL pretrained models consistently outperform training from scratch. The gap widens as the labeled set shrinks.
  • Distribution shift: SSL representations are often more robust to distribution shift because they capture general structural properties rather than task-specific shortcuts.
  • Out-of-distribution detection: SSL features often enable better anomaly and OOD detection. Methods like Deep SVDD can benefit from SSL-pretrained feature extractors.
  • Semi-supervised settings: When you have a large unlabeled dataset and a small labeled subset, SSL pretraining on the unlabeled data followed by fine-tuning on the labeled data is the standard approach.

Pretrained Models vs. Training Your Own

For most practitioners, the answer is simple: download a pretrained model. Training SSL from scratch requires significant compute resources and careful hyperparameter tuning. Pretrained models are available from:

  • HuggingFace: The largest repository of pretrained models. BERT, GPT-2, ViT, CLIP, DINOv2, and hundreds more. pip install transformers and you are running in minutes.
  • timm (PyTorch Image Models): Extensive collection of vision models including MAE, DINOv2, and CLIP-pretrained ViTs. pip install timm.
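  • torchvision: ResNet, ViT, and other models pretrained on ImageNet (supervised), plus SWAG weights for some ViT and RegNet models (weakly supervised pretraining on Instagram hashtags, not SSL). Built into PyTorch.
  • DINO model zoo: Official DINOv2 checkpoints from Meta AI, among the strongest general-purpose visual features currently available.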

Train your own SSL model only when: (1) your domain is very different from standard datasets (medical imaging, satellite imagery, industrial sensors), (2) you have abundant unlabeled domain data, and (3) pretrained models perform poorly on your downstream task.

Common Pitfalls

Caution: These are the most common mistakes when implementing SSL from scratch:

  • Augmentation shortcuts: If the two views share low-level cues that trivially identify them, the model can solve the contrastive task by matching those cues instead of learning semantic representations. The SimCLR paper's example: without color jitter, crops of the same image share a color distribution, so the network can match views by color histogram alone.
  • Undetected collapse: Monitor the standard deviation of your embeddings across a batch. If it drops toward zero, your model has collapsed. Also check the rank of the embedding matrix.
  • Bad temperature: As a rough guide, temperatures below about 0.05 tend to make training unstable, while temperatures above about 1.0 flatten the softmax so much that hard negatives barely influence the loss. Start with τ between 0.1 and 0.5.
  • Not using a projection head: Applying the contrastive loss directly to encoder features produces measurably worse representations; SimCLR reports that a nonlinear projection head improves linear evaluation accuracy by more than 10% over using no projection.
  • Insufficient training: SSL pretraining typically requires more epochs than supervised training. SimCLR uses 800 epochs on ImageNet; MAE uses 1600. The 100-epoch budget in the CIFAR-10 examples above is a demonstration setting; for serious pretraining, do not stop at 100.
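The collapse check above can be made concrete with a small diagnostic sketch. It reports the average per-dimension standard deviation of L2-normalized embeddings (a quantity SimSiam monitors: roughly 1/sqrt(D) for well-spread embeddings, zero for complete collapse) and an effective rank, computed here as the exponential of the entropy of the normalized singular values. The function name and the choice of effective rank as the rank measure are assumptions.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def collapse_stats(z: torch.Tensor, eps: float = 1e-12):
    """Return (mean_std, effective_rank) for a batch of embeddings z of shape (B, D)."""
    # Per-dimension std of L2-normalized embeddings, averaged over dimensions.
    mean_std = F.normalize(z, dim=1).std(dim=0).mean().item()
    # Effective rank: exp(entropy) of the normalized singular values of the centered batch.
    s = torch.linalg.svdvals(z - z.mean(dim=0, keepdim=True))
    p = s / (s.sum() + eps)
    entropy = -(p * torch.log(p + eps)).sum()
    return mean_std, torch.exp(entropy).item()


torch.manual_seed(0)
healthy = torch.randn(512, 32)
collapsed = torch.ones(512, 32)                         # every sample identical
low_rank = torch.randn(512, 1) * torch.randn(1, 32)     # all samples on one line

for name, z in [("healthy", healthy), ("collapsed", collapsed), ("low_rank", low_rank)]:
    mean_std, eff_rank = collapse_stats(z)
    print(f"{name:10s} mean_std={mean_std:.3f} effective_rank={eff_rank:.1f}")
```

The low-rank batch illustrates why the rank check matters: its standard deviation stays well above zero, but its effective rank is close to 1, a partial ("dimensional") collapse the standard deviation alone would miss.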

Method Comparison Table

A comprehensive comparison of the major SSL methods is provided below to aid selection.

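| Method | Type | Negatives? | Architecture | Batch Size | ImageNet Top-1 (backbone, protocol) |
|---|---|---|---|---|---|
| SimCLR | Contrastive | Yes (in-batch) | ResNet + MLP | 4096+ | 76.5% (ResNet-50 4x, linear) |
| MoCo v3 | Contrastive | Yes (in-batch; v1/v2 used a queue) | ViT + momentum encoder | 256-4096 | 76.7% (ViT-B, linear) |
| BYOL | Contrastive | No | ResNet + EMA | 256-4096 | 79.6% (ResNet-200 2x, linear) |
| Barlow Twins | Redundancy Reduction | No | ResNet + MLP | 256-2048 | 73.2% (ResNet-50, linear) |
| MAE | Masked Modeling | No | ViT encoder-decoder | 256-4096 | 83.6% (ViT-B, fine-tuned) |
| DINO | Self-Distillation | No | ViT + EMA teacher | 256-1024 | 80.1% (ViT-B/8, linear) |

Note: the accuracy column mixes backbones and evaluation protocols (linear probe vs. full fine-tuning), so rows are not directly comparable.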

Key Takeaway: For a fresh start, MAE and DINOv2 represent the current best options for vision. For NLP, both BERT-style masked modelling and GPT-style autoregressive pretraining remain dominant. The trend is clear: negative-free methods (BYOL, Barlow Twins, MAE, DINO) have largely surpassed methods that require explicit negative pairs.

Frequently Asked Questions

SSL vs. unsupervised learning: what is the difference?

Unsupervised learning (clustering, PCA, autoencoders) learns data structure without any labels. Self-supervised learning also uses no human labels, but it creates pseudo-labels from the data itself—predicting masked tokens, matching augmented views, or reconstructing hidden patches. The key difference is that SSL defines a specific prediction task (pretext task) with a clear loss function, producing representations optimized for transfer to downstream tasks. Traditional unsupervised methods like k-means do not have this task-oriented structure. SSL sits between supervised and unsupervised learning, borrowing the task structure of supervised learning while using the label-free data of unsupervised learning.

Which SSL method should I use for my problem?

Start with your modality. For text, use a pretrained BERT or GPT model; do not train from scratch unless you have domain-specific text (biomedical, legal, code). For images, DINOv2 provides the best general-purpose features: download the pretrained model and fine-tune it. For time series, TS2Vec is a strong baseline; for graphs, GraphCL; for multimodal tasks, CLIP. If you must train from scratch because your domain is unusual, MAE is the most compute-efficient option for vision, and BYOL is the most forgiving of small batch sizes. Build your data pipeline in Python with PyTorch, which has the most extensive SSL ecosystem.

Do I need a GPU cluster for SSL pretraining?

For ImageNet-scale pretraining from scratch, yes: you need multiple accelerators. SimCLR used 128 TPU v3 cores, MAE's reference PyTorch recipe for ViT-Large runs on 64 GPUs (8 nodes of 8), and DINOv2 used even more. However, there are practical alternatives: (1) start from a pretrained model and only fine-tune it, which needs just one GPU; (2) pretrain on a smaller dataset such as CIFAR-10 or your own domain data, since SSL on 50K images is feasible on a single GPU in hours; (3) use an efficient method such as MAE, whose encoder processes only 25% of the patches, reducing compute by roughly 3-4x. Most practitioners should never pretrain on ImageNet from scratch; they should download the pretrained weights instead.
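The 25% figure comes from MAE's default 75% mask ratio. The plain-Python sketch below splits the 196 patches of a 224×224 image under ViT-B/16 into visible and masked sets and estimates how much less work the encoder does per layer. The function name is hypothetical; MAE's reference code draws per-patch noise and argsorts it, which a random shuffle reproduces for a single image.

```python
import random

def random_masking(num_patches, mask_ratio=0.75, seed=0):
    """Split patch indices into (visible, masked) sets, MAE-style."""
    rng = random.Random(seed)
    order = list(range(num_patches))
    rng.shuffle(order)                       # uniformly random permutation of patches
    num_keep = int(num_patches * (1 - mask_ratio))
    visible = sorted(order[:num_keep])       # only these go through the encoder
    masked = sorted(order[num_keep:])        # these are reconstructed by the decoder
    return visible, masked

# ViT-B/16 on a 224x224 image: (224 // 16) ** 2 = 196 patches.
num_patches = (224 // 16) ** 2
visible, masked = random_masking(num_patches)
print(len(visible), len(masked))  # 49 147

# Token-wise MLP cost scales with the number of tokens, self-attention with its square.
mlp_ratio = len(visible) / num_patches
attn_ratio = (len(visible) / num_patches) ** 2
print(f"MLP cost: {mlp_ratio:.2%}, attention cost: {attn_ratio:.2%}")  # MLP cost: 25.00%, attention cost: 6.25%
```

These per-layer ratios overstate the end-to-end saving, because MAE's lightweight decoder still processes all 196 tokens (visible tokens plus mask tokens).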

Can SSL work on small datasets?

Yes, with caveats. SSL from scratch on very small datasets (under 10K samples) may not produce strong representations, because there is not enough data diversity for the model to learn generalizable features. SSL still helps in two ways: (1) take an SSL model pretrained on a large external dataset and fine-tune it on your small dataset, which is highly effective; (2) if you have a large unlabeled dataset and a small labeled dataset from the same domain, pretrain on the unlabeled data and fine-tune on the labeled data. The benefit of SSL grows as the labeled set shrinks: with 1% of ImageNet labels, SSL-pretrained models can be 15-20% more accurate than models trained from scratch.
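Option (1) works because a good pretrained encoder makes classes easy to separate with very few labels. The sketch below illustrates this with a nearest-centroid classifier on frozen features, a dependency-free stand-in for the linear probe commonly used in SSL evaluation. The embedding values are hypothetical placeholders for the output of a pretrained encoder, and all function names are illustrative.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def fit_centroids(features, labels):
    """Average the normalized frozen features of each class, then re-normalize."""
    sums, counts = {}, {}
    for feat, label in zip(features, labels):
        acc = sums.setdefault(label, [0.0] * len(feat))
        for i, x in enumerate(l2_normalize(feat)):
            acc[i] += x
        counts[label] = counts.get(label, 0) + 1
    return {label: l2_normalize([x / counts[label] for x in acc])
            for label, acc in sums.items()}

def predict(centroids, feature):
    """Return the class whose centroid has the highest cosine similarity to the feature."""
    f = l2_normalize(feature)
    return max(centroids, key=lambda label: sum(a * b for a, b in zip(f, centroids[label])))

# Hypothetical frozen SSL embeddings for a handful of labelled examples.
train_x = [[0.9, 0.1], [1.0, 0.2], [0.1, 0.9], [0.2, 1.0]]
train_y = ["cat", "cat", "dog", "dog"]
centroids = fit_centroids(train_x, train_y)
print(predict(centroids, [0.8, 0.3]))  # cat
print(predict(centroids, [0.3, 0.7]))  # dog
```

Only the tiny classifier is fitted; the encoder stays frozen, which is why this regime tolerates very small labelled sets.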

SSL vs. supervised pretraining (ImageNet): which is better?

SSL pretraining has now matched or exceeded supervised ImageNet pretraining across most benchmarks. MAE with a ViT-Huge achieves 86.9 percent on ImageNet when fine-tuned, compared with 85.1 percent for supervised ViT-Huge. DINOv2 produces features that outperform supervised models on detection, segmentation and depth estimation without fine-tuning. The advantages of SSL pretraining go beyond accuracy: it does not require labels, making it scalable to larger datasets; SSL representations are generally more robust to distribution shift; and SSL models transfer more effectively across diverse downstream tasks. The only scenario in which supervised pretraining may still be preferable is one in which the downstream task closely matches ImageNet classification and the simplest possible pipeline is required.

Closing Thoughts

Self-supervised learning has fundamentally changed how AI systems are built. The two-stage paradigm, in which a model is pretrained with self-supervision on large volumes of unlabelled data and then fine-tuned on a small labelled dataset for the specific task, is now the default approach across virtually every modality, including text, images, audio, time series, graphs and multimodal systems.

The methods examined in this article, including SimCLR, MoCo and BYOL (contrastive), Barlow Twins (redundancy reduction), BERT and MAE (masked modelling), GPT (autoregressive) and DINO (self-distillation), represent the major families of SSL techniques. Each has its strengths. Contrastive methods produce excellent representations, though some require large batches. Masked modelling is compute-efficient and scalable. Self-distillation methods such as DINO produce representations with notable emergent properties, such as self-attention maps that segment objects without any segmentation labels.

The practical guidance for practitioners is as follows.

  1. Begin with pretrained models. Download them from Hugging Face, timm or torchvision. Avoid training from scratch unless there is a compelling reason.
  2. Fine-tune appropriately. Use linear probing for very small datasets, partial fine-tuning for moderate datasets, and full fine-tuning with differential learning rates for larger datasets.
  3. Know when to train independently. Domain-specific data (medical, industrial, scientific) that differs substantially from standard training sets may benefit from SSL pretraining on the user’s own unlabelled data.
  4. Monitor for collapse. Track embedding statistics during training: if the per-dimension standard deviation of the normalized embeddings falls toward zero, the model has collapsed.
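The collapse check can be sketched in plain Python. This version follows the heuristic from the SimSiam paper: L2-normalize each embedding and measure the standard deviation of each dimension across the batch. For healthy, roughly isotropic outputs that standard deviation is close to 1/√d, so the score below multiplies the mean by √d to sit near 1. The function names are hypothetical, and any alarm threshold you choose is a judgment call.

```python
import math

def per_dim_std(embeddings):
    """Standard deviation of each dimension across a batch of L2-normalized embeddings."""
    normed = []
    for e in embeddings:
        n = math.sqrt(sum(x * x for x in e)) or 1.0
        normed.append([x / n for x in e])
    stds = []
    for d in range(len(normed[0])):
        col = [v[d] for v in normed]
        mean = sum(col) / len(col)
        stds.append(math.sqrt(sum((x - mean) ** 2 for x in col) / len(col)))
    return stds

def collapse_score(embeddings):
    """Mean per-dimension std scaled by sqrt(dim): near 1 when spread out, near 0 when collapsed."""
    stds = per_dim_std(embeddings)
    return (sum(stds) / len(stds)) * math.sqrt(len(stds))

dim = 4
# Well-spread batch: +/- each basis vector.
spread = [[s if i == j else 0.0 for j in range(dim)] for i in range(dim) for s in (1.0, -1.0)]
# Collapsed batch: every input mapped to the same embedding.
collapsed = [[1.0, 0.0, 0.0, 0.0] for _ in range(8)]
print(round(collapse_score(spread), 3))     # 1.0
print(round(collapse_score(collapsed), 3))  # 0.0
```

In a real training loop you would compute this on each batch's projector outputs and log it alongside the loss; a steady slide toward zero is the warning sign.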

The trajectory of SSL is toward universal foundation models, that is, single models pretrained on multiple modalities that can be fine-tuned for any task with minimal data. DINOv2, ImageBind and data2vec are early steps in this direction: DINOv2 learns general-purpose visual features, ImageBind aligns six modalities in one embedding space, and data2vec applies a single learning objective to speech, vision and text. Understanding SSL is not merely of academic interest; it is the practical foundation of modern AI engineering.

References and Further Reading

Additional References:

  • He et al., 2020. "Momentum Contrast for Unsupervised Visual Representation Learning" (MoCo)
  • Grill et al., 2020. "Bootstrap Your Own Latent" (BYOL)
  • Zbontar et al., 2021. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction"
  • Devlin et al., 2019. "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
  • Baevski et al., 2022. "data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language"
  • Oquab et al., 2024. "DINOv2: Learning Robust Visual Features without Supervision"
  • Radford et al., 2021. "Learning Transferable Visual Models From Natural Language Supervision" (CLIP)
