DINOv2 Student-Teacher Network Training Guide

code
advanced
Author

Krishnatheja Vanka

Published

May 28, 2025

DINOv2 Student-Teacher Network Training Guide

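This guide provides a compact, single-GPU implementation for training a DINOv2 (DINO version 2) student-teacher network from scratch using PyTorch. DINOv2 is a self-supervised learning method that trains vision transformers without labels using a teacher-student distillation framework. The code below implements the core DINO self-distillation objective (multi-crop views, EMA teacher, centering and sharpening). DINOv2-specific additions such as the iBOT masked-patch loss, the KoLeo regularizer, and Sinkhorn-Knopp centering are not included.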

Table of Contents

  1. Overview
  2. Architecture Components
  3. Implementation
  4. Training Loop
  5. Usage Example

Overview

DINOv2 uses a student-teacher framework where:

  • Teacher network: Provides stable targets (EMA of student weights)
  • Student network: Learns to match teacher outputs
  • Multi-crop strategy: Uses different image crops for robustness
  • Centering and sharpening: Centering the teacher outputs and sharpening them with a low temperature together prevent collapse

Architecture Components

Vision Transformer (ViT) Backbone

Import Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import math
import numpy as np
from typing import Optional, List, Tuple

Patch Embedding

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and project each to ``embed_dim``.

    A single strided convolution (kernel = stride = ``patch_size``) performs both the
    patching and the linear projection.

    Args:
        img_size: Side length of the square input the model is configured for; used
            only to compute ``num_patches``.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Output token dimension.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        grid = img_size // patch_size
        self.num_patches = grid * grid

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, N, D]`` patch tokens."""
        # Unpacking enforces a 4-D input; the values themselves are not needed.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)           # [B, D, H/p, W/p]
        flat = feature_map.flatten(2)        # [B, D, N]
        return flat.transpose(1, 2)          # [B, N, D]

Multi-head Self-Attention

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: Token dimension; split evenly across heads (``embed_dim // num_heads``).
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``[B, N, D]`` and return a tensor of the same shape."""
        batch, tokens, width = x.shape
        heads, per_head = self.num_heads, self.head_dim

        # [B, N, 3, H, Dh] -> [3, B, H, N, Dh], then split into q, k, v.
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * (per_head ** -0.5)
        weights = self.dropout(scores.softmax(dim=-1))

        # Merge heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, width)
        return self.proj(merged)

Transformer Block

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``, where the MLP is
    Linear -> GELU -> Dropout -> Linear -> Dropout with hidden width
    ``int(embed_dim * mlp_ratio)``.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden = int(embed_dim * mlp_ratio)

        # Submodules are created in this exact order so parameter names and
        # initialisation order stay stable.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the attention and MLP residual sub-layers to ``x``."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

Vision Transformer

class VisionTransformer(nn.Module):
    """Vision Transformer backbone that returns the final-layer CLS embedding.

    Positional embeddings are learned for a square ``img_size`` input. When the input
    has a different resolution (e.g. 96px local crops in multi-crop training), the
    patch part of the positional embedding is bicubically resized to the actual patch
    grid, so the same backbone can process both global and local crops.

    Args:
        img_size: Native (square) input side length the positional embeddings are built for.
        patch_size: Patch side length.
        in_chans: Number of input channels.
        embed_dim: Token dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability used after the embedding and inside blocks.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, 
                 depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # One learnable CLS token, plus positional embeddings for CLS + every patch.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Truncated-normal init for embeddings and Linear weights; unit/zero LayerNorm."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, height, width):
        """Return positional embeddings matching an input of ``height`` x ``width`` pixels.

        At the native resolution ``self.pos_embed`` is returned unchanged. Otherwise the
        patch embeddings are reshaped to their 2-D grid, bicubically resized to the
        input's patch grid and flattened back; the CLS embedding is never resized.
        """
        patch_size = self.patch_embed.patch_size
        native = self.patch_embed.img_size // patch_size
        # Matches the output size of the stride-``patch_size`` convolution.
        grid_h, grid_w = height // patch_size, width // patch_size
        if grid_h == native and grid_w == native:
            return self.pos_embed

        dim = self.pos_embed.shape[-1]
        cls_pos = self.pos_embed[:, :1]
        patch_pos = self.pos_embed[:, 1:].reshape(1, native, native, dim).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(
            patch_pos, size=(grid_h, grid_w), mode='bicubic', align_corners=False
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)
        return torch.cat((cls_pos, patch_pos), dim=1)
    
    def forward(self, x):
        """Encode ``[B, C, H, W]`` images and return the normalised CLS token ``[B, D]``."""
        B, _, H, W = x.shape
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Resize positional embeddings when H, W differ from img_size (e.g. local crops).
        x = x + self.interpolate_pos_encoding(H, W)
        x = self.dropout(x)
        
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        return x[:, 0]  # Return CLS token

DINO Head

class DINOHead(nn.Module):
    """DINO projection head: MLP -> L2 normalisation -> weight-normalised linear layer.

    Args:
        in_dim: Backbone feature dimension.
        out_dim: Number of output prototypes (logits).
        hidden_dim: Width of the hidden MLP layers.
        bottleneck_dim: Width of the MLP output that gets L2-normalised.
        num_layers: MLP depth; ``1`` means a single Linear(in_dim, bottleneck_dim).
        use_bn: Insert BatchNorm1d after each hidden Linear.
        norm_last_layer: Freeze the weight-norm magnitude (``weight_g``) at 1.
    """
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256, 
                 num_layers=3, use_bn=False, norm_last_layer=True):
        super().__init__()

        def hidden_stage(fan_in):
            # Linear [+ BatchNorm] + GELU, built in that order.
            stage = [nn.Linear(fan_in, hidden_dim)]
            if use_bn:
                stage.append(nn.BatchNorm1d(hidden_dim))
            stage.append(nn.GELU())
            return stage

        if num_layers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            stack = hidden_stage(in_dim)
            for _ in range(num_layers - 2):
                stack.extend(hidden_stage(hidden_dim))
            stack.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*stack)

        # Re-initialise the MLP only; the last layer is created afterwards on purpose.
        self.apply(self._init_weights)

        projection = nn.Linear(bottleneck_dim, out_dim, bias=False)
        self.last_layer = nn.utils.weight_norm(projection)
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Truncated-normal weights (std 0.02) and zero bias for every Linear."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project features ``x`` to ``out_dim`` prototype logits."""
        bottleneck = nn.functional.normalize(self.mlp(x), dim=-1, p=2)
        return self.last_layer(bottleneck)

DINOv2 Model

class DINOv2(nn.Module):
    """Backbone + projection head used for both the student and the teacher.

    Args:
        backbone_args: Keyword arguments forwarded to ``VisionTransformer``.
        head_args: Keyword arguments forwarded to ``DINOHead``.
    """
    def __init__(self, backbone_args, head_args):
        super().__init__()
        # Backbone is built before the head so parameter order is unchanged.
        self.backbone = VisionTransformer(**backbone_args)
        self.head = DINOHead(**head_args)

    def forward(self, x):
        """Return head logits for the CLS embedding of images ``x``."""
        return self.head(self.backbone(x))

Data Augmentation and Multi-Crop Strategy

Multi-Crop Data Augmentation

class MultiCropDataAugmentation:
    """Multi-crop data augmentation for DINOv2.

    Produces ``global_crops_number`` large crops followed by ``local_crops_number`` small
    crops of the same image. Following the DINO/DINOv2 recipe, the global views are
    augmented asymmetrically:

    * even-indexed global views use a Gaussian blur that is always applied;
    * odd-indexed global views use a rare blur (p=0.1) plus solarization (p=0.2);
    * local views use blur with p=0.5 and no solarization.

    Every view gets a random flip, color jitter (applied with p=0.8) and random grayscale.

    Args:
        global_crops_scale: ``RandomResizedCrop`` area range for global crops.
        local_crops_scale: ``RandomResizedCrop`` area range for local crops.
        global_crops_number: Number of global crops (the trainer expects 2).
        local_crops_number: Number of local crops.
        size_crops: ``(global_size, local_size)`` output side lengths.
    """
    def __init__(self, global_crops_scale=(0.4, 1.0), local_crops_scale=(0.05, 0.4),
                 global_crops_number=2, local_crops_number=6, size_crops=(224, 96)):
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number

        def build(size, scale, blur_p, solarize_p):
            steps = [
                transforms.RandomResizedCrop(size, scale=scale,
                                             interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=blur_p),
            ]
            if solarize_p > 0:
                steps.append(Solarization(p=solarize_p))
            steps += [
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
            return transforms.Compose(steps)

        # Global crops (teacher and student): two asymmetric pipelines.
        self.global_transform = build(size_crops[0], global_crops_scale, 1.0, 0.0)
        self.global_transform2 = build(size_crops[0], global_crops_scale, 0.1, 0.2)

        # Local crops (student only)
        self.local_transform = build(size_crops[1], local_crops_scale, 0.5, 0.0)

    def __call__(self, image):
        """Return a list of augmented crop tensors: global crops first, then local crops."""
        crops = [
            (self.global_transform if i % 2 == 0 else self.global_transform2)(image)
            for i in range(self.global_crops_number)
        ]
        crops.extend(self.local_transform(image) for _ in range(self.local_crops_number))
        return crops

Augmentation Utilities

class GaussianBlur:
    """Randomly blur an image with a Gaussian whose sigma is drawn uniformly.

    Args:
        p: Probability of applying the blur.
        radius_min: Lower bound of the sigma range.
        radius_max: Upper bound of the sigma range.
    """
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        """Return ``img`` blurred with probability ``self.prob``, else ``img`` itself."""
        if not torch.rand(1) < self.prob:
            return img
        span = self.radius_max - self.radius_min
        sigma = self.radius_min + torch.rand(1) * span
        # Fixed 9x9 kernel; sigma is sampled per call.
        return transforms.functional.gaussian_blur(img, kernel_size=9, sigma=sigma.item())

class Solarization:
    """Randomly solarize an image (invert pixels at or above threshold 128).

    Args:
        p: Probability of applying solarization.
    """
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        """Return the solarized ``img`` with probability ``self.p``, else ``img`` itself."""
        apply_it = torch.rand(1) < self.p
        return transforms.functional.solarize(img, threshold=128) if apply_it else img

Loss Functions and Training Components

class DINOLoss(nn.Module):
    """DINO self-distillation loss with teacher centering and sharpening.

    The teacher distribution is ``softmax((t - center) / teacher_temp)``. The student
    distribution is ``softmax(s / student_temp)``. The loss is the cross-entropy between
    every teacher view and every *other* student view, averaged over view pairs.

    Args:
        out_dim: Dimension of the head output (size of the running center).
        ncrops: Total number of student views. ``student_output`` is split into this
            many equal chunks, and the first two must be the global views, in the same
            order as the teacher's.
        warmup_teacher_temp: Teacher temperature at epoch 0.
        teacher_temp: Teacher temperature after the warm-up.
        warmup_teacher_temp_epochs: Length of the linear warm-up. After the
            precomputed schedule ends, the last temperature is reused.
        student_temp: Student softmax temperature.
        center_momentum: EMA momentum for the teacher-output center.
    """
    def __init__(self, out_dim, ncrops, warmup_teacher_temp=0.04, 
                 teacher_temp=0.04, warmup_teacher_temp_epochs=0, 
                 student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        
        # Temperature schedule: linear warm-up, then constant (1000 extra epochs
        # precomputed; forward() clamps the index for longer runs).
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(1000) * teacher_temp
        ))
    
    def forward(self, student_output, teacher_output, epoch):
        """Return the scalar loss and update the running center from ``teacher_output``.

        ``teacher_output`` is split into 2 chunks, one per global view. For each teacher
        chunk, the student chunk with the same index is skipped so a view is never
        matched against itself.
        """
        student_out = (student_output / self.student_temp).chunk(self.ncrops)
        
        # Clamp so epochs beyond the precomputed schedule keep the final temperature
        # instead of raising IndexError.
        schedule_idx = min(epoch, len(self.teacher_temp_schedule) - 1)
        temp = self.teacher_temp_schedule[schedule_idx]
        # Teacher centering and sharpening
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)  # Only 2 global crops for teacher
        
        total_loss = 0
        n_loss_terms = 0
        
        for iq, q in enumerate(teacher_out):
            for v, student_view in enumerate(student_out):
                if v == iq:
                    continue  # Skip same crop
                loss = torch.sum(-q * F.log_softmax(student_view, dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss
    
    @torch.no_grad()
    def update_center(self, teacher_output):
        """EMA-update the center with the batch mean of the raw teacher outputs."""
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        batch_center = batch_center / len(teacher_output)
        
        # EMA update (reassigning a registered buffer name updates the buffer).
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

Training Utilities

@torch.no_grad()
def update_teacher(student, teacher, momentum):
    """Move each teacher parameter toward its student counterpart in place.

    ``teacher = momentum * teacher + (1 - momentum) * student``, paired parameter by
    parameter in ``.parameters()`` order.
    """
    student_share = 1 - momentum
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.data.mul_(momentum).add_(s_param.data, alpha=student_share)

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0):
    """Build a per-iteration schedule: linear warm-up from 0, then cosine decay.

    Returns a NumPy array of length ``epochs * niter_per_ep``. The first
    ``warmup_epochs * niter_per_ep`` entries rise linearly from 0 to ``base_value``.
    The rest follow a half cosine from ``base_value`` toward ``final_value``.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup = np.linspace(0, base_value, warmup_iters)
    else:
        warmup = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    amplitude = 0.5 * (base_value - final_value)
    decay = final_value + amplitude * (1 + np.cos(np.pi * steps / len(steps)))

    schedule = np.concatenate((warmup, decay))
    assert len(schedule) == total_iters
    return schedule

Training Loop Implementation

class DINOv2Trainer:
    """Single-process DINOv2 training pipeline.

    Builds a ViT-B/16 student and an identically initialised teacher. The student is
    optimised with AdamW on the DINO loss, and the teacher is updated as an EMA of the
    student.

    Config keys read: ``base_lr``, ``final_lr``, ``weight_decay``, ``momentum_teacher``,
    ``epochs``, ``niter_per_ep``, ``warmup_epochs`` and ``save_every``. The optional
    ``local_crops_number`` (default 6) must match the number of local crops the
    dataloader yields per sample.
    """
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Model architecture configs
        backbone_args = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'num_heads': 12,
            'mlp_ratio': 4.0,
            'dropout': 0.0
        }
        
        head_args = {
            'in_dim': 768,
            'out_dim': 65536,  # Large output dimension
            'hidden_dim': 2048,
            'bottleneck_dim': 256
        }
        
        # Initialize student and teacher networks
        self.student = DINOv2(backbone_args, head_args).to(self.device)
        self.teacher = DINOv2(backbone_args, head_args).to(self.device)
        
        # Teacher starts as copy of student
        self.teacher.load_state_dict(self.student.state_dict())
        
        # Teacher parameters are not updated by gradients
        for p in self.teacher.parameters():
            p.requires_grad = False
        
        # Loss function: 2 global crops (teacher + student) + N local crops (student only).
        ncrops = 2 + config.get('local_crops_number', 6)
        self.dino_loss = DINOLoss(
            out_dim=head_args['out_dim'],
            ncrops=ncrops,
            student_temp=0.1,
            teacher_temp=0.04,
            center_momentum=0.9
        ).to(self.device)
        
        # Optimizer: biases and 1-D (norm) parameters are excluded from weight decay.
        self.optimizer = torch.optim.AdamW(
            self._param_groups(self.student),
            lr=config['base_lr'],
            weight_decay=config['weight_decay']
        )
        
        # Learning rate scheduler
        self.lr_schedule = cosine_scheduler(
            config['base_lr'],
            config['final_lr'],
            config['epochs'],
            config['niter_per_ep'],
            config['warmup_epochs']
        )
        
        # Momentum schedule for teacher updates (rises from momentum_teacher to 1.0)
        self.momentum_schedule = cosine_scheduler(
            config['momentum_teacher'],
            1.0,
            config['epochs'],
            config['niter_per_ep']
        )

    @staticmethod
    def _param_groups(model):
        """Split trainable parameters into weight-decayed and non-decayed groups."""
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if param.ndim <= 1 or name.endswith('.bias'):
                no_decay.append(param)
            else:
                decay.append(param)
        return [{'params': decay}, {'params': no_decay, 'weight_decay': 0.0}]

    @staticmethod
    def _forward_crops(model, crops):
        """Run ``model`` on a list of crop batches and concatenate outputs in crop order.

        ``torch.cat`` cannot join tensors with different spatial sizes, so consecutive
        crops of equal resolution are batched together and forwarded as one pass. This
        requires a backbone that accepts every crop resolution it is given.
        """
        outputs = []
        start = 0
        for end in range(1, len(crops) + 1):
            if end == len(crops) or crops[end].shape[-2:] != crops[start].shape[-2:]:
                outputs.append(model(torch.cat(crops[start:end])))
                start = end
        return torch.cat(outputs)
    
    def train_epoch(self, dataloader, epoch):
        """Train for one epoch and return the mean loss over its iterations.

        Each dataloader item is expected to be ``(crops, labels)``, where ``crops`` is a
        list of image batches whose first two entries are the global crops.
        """
        self.student.train()
        self.teacher.eval()
        
        total_loss = 0
        num_batches = len(dataloader)
        
        for it, (images, _) in enumerate(dataloader):
            # Update learning rate
            lr = self.lr_schedule[epoch * num_batches + it]
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            
            # Move to device and prepare crops
            images = [im.to(self.device, non_blocking=True) for im in images]
            
            # Teacher forward pass (only on global crops, no autograd graph needed)
            with torch.no_grad():
                teacher_output = self.teacher(torch.cat(images[:2]))
            
            # Student forward pass on all crops, one pass per resolution group
            student_output = self._forward_crops(self.student, images)
            
            # Compute loss
            loss = self.dino_loss(student_output, teacher_output, epoch)
            
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.student.parameters(), max_norm=3.0)
            
            self.optimizer.step()
            
            # Update teacher with EMA
            momentum = self.momentum_schedule[epoch * num_batches + it]
            update_teacher(self.student, self.teacher, momentum)
            
            total_loss += loss.item()
            
            if it % 100 == 0:
                print(f'Epoch {epoch}, Iter {it}/{num_batches}, Loss: {loss.item():.4f}, LR: {lr:.6f}')
        
        return total_loss / num_batches
    
    def train(self, dataloader):
        """Full training loop; checkpoints every ``save_every`` epochs and after the last one."""
        epochs = self.config['epochs']
        for epoch in range(epochs):
            avg_loss = self.train_epoch(dataloader, epoch)
            print(f'Epoch {epoch}/{self.config["epochs"]}, Average Loss: {avg_loss:.4f}')
            
            # Save checkpoint (always include the final epoch)
            if epoch % self.config['save_every'] == 0 or epoch == epochs - 1:
                self.save_checkpoint(epoch)
    
    def save_checkpoint(self, epoch):
        """Save student, teacher, optimizer and loss (center buffer) state to disk."""
        checkpoint = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'teacher_state_dict': self.teacher.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # The running center is required to resume training faithfully.
            'dino_loss_state_dict': self.dino_loss.state_dict(),
            'config': self.config
        }
        torch.save(checkpoint, f'dinov2_checkpoint_epoch_{epoch}.pth')

Usage Example

def main():
    """Train DINOv2 on an ImageFolder dataset with the default configuration below."""
    # Training configuration
    config = {
        'base_lr': 5e-4,
        'final_lr': 1e-6,
        'weight_decay': 0.04,
        'momentum_teacher': 0.996,
        'epochs': 100,
        'warmup_epochs': 10,
        'batch_size': 64,
        'save_every': 10,
        'local_crops_number': 6,
        'niter_per_ep': None  # Will be set after dataloader creation
    }
    
    # Data setup: keep the augmentation's local-crop count in sync with the config.
    transform = MultiCropDataAugmentation(local_crops_number=config['local_crops_number'])
    dataset = ImageFolder(root='path/to/your/dataset', transform=transform)
    dataloader = DataLoader(
        dataset, 
        batch_size=config['batch_size'], 
        shuffle=True, 
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    
    config['niter_per_ep'] = len(dataloader)
    
    # Initialize trainer and start training
    trainer = DINOv2Trainer(config)
    trainer.train(dataloader)


# Guard the entry point: importing this module must not start training, and
# DataLoader workers started with "spawn" re-import the main module.
if __name__ == "__main__":
    main()

Key Features Implemented

  1. Vision Transformer Backbone: Complete ViT implementation with patch embedding, multi-head attention, and transformer blocks
  2. Multi-crop Strategy: Global and local crops with different augmentations
  3. Teacher-Student Framework: EMA updates for teacher network
  4. DINO Loss: Cross-entropy loss with centering mechanism to prevent collapse
  5. Learning Rate Scheduling: Cosine annealing with warmup
  6. Gradient Clipping: Stability during training
  7. Checkpointing: Periodically saves student, teacher, and optimizer states (resuming from a checkpoint is not implemented here)

Training Tips

  1. Batch Size: Use large batch sizes (256-1024) for better performance
  2. Data Augmentation: Strong augmentations are crucial for self-supervised learning
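  3. Temperature Scheduling: DINOLoss supports a linear warm-up of the teacher temperature (warmup_teacher_temp → teacher_temp over warmup_teacher_temp_epochs); the trainer above keeps it constant at 0.04, so enable the warm-up if early training is unstable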
  4. Momentum Scheduling: Start with a high teacher EMA momentum (e.g. 0.996) and increase it toward 1.0 over training, as the cosine momentum schedule above does
  5. Multi-GPU Training: Use DistributedDataParallel for faster training

This implementation provides a solid foundation for training DINOv2 models. Adjust hyperparameters based on your dataset size and computational resources.