DINOv2 Student-Teacher Network Training Guide

This guide walks through a from-scratch PyTorch implementation of the core DINO-style student-teacher objective behind DINOv2. DINOv2 is a self-supervised learning method that trains vision transformers without labels via teacher-student self-distillation. (The full DINOv2 recipe adds further objectives, such as iBOT-style masked-image modeling, which are out of scope here.)
Overview
DINOv2 uses a student-teacher framework where:
- Teacher network: Provides stable targets (EMA of student weights)
- Student network: Learns to match teacher outputs
- Multi-crop strategy: Uses different image crops for robustness
- Centering mechanism: Prevents mode collapse by subtracting a running mean from teacher outputs (see the sketch below)
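Before the full implementation, here is a minimal sketch of how these pieces interact in one training step. It is illustrative pseudocode under simplifying assumptions (global crops only, same-crop pairing; the names and the simplified pairing are not taken from the official DINOv2 codebase):

```python
import torch
import torch.nn.functional as F

def dino_step(student, teacher, global_crops, center,
              t_student=0.1, t_teacher=0.04, ema=0.996, c_mom=0.9):
    """One simplified DINO update step (global crops only)."""
    # Teacher targets: centered, sharpened softmax over the teacher outputs.
    with torch.no_grad():
        t_out = teacher(torch.cat(global_crops))
        t_probs = F.softmax((t_out - center) / t_teacher, dim=-1)
    # Student is trained to match the teacher distribution (the full loss
    # additionally pairs different crops and skips same-crop pairs).
    s_out = student(torch.cat(global_crops))
    loss = torch.sum(-t_probs * F.log_softmax(s_out / t_student, dim=-1), dim=-1).mean()
    loss.backward()
    with torch.no_grad():
        # Teacher weights track the student as an exponential moving average.
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.mul_(ema).add_(p_s, alpha=1 - ema)
        # The center is an EMA of teacher outputs; subtracting it before the
        # softmax prevents one output dimension from dominating (collapse).
        center = center * c_mom + t_out.mean(dim=0, keepdim=True) * (1 - c_mom)
    return loss.detach(), center
```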
Architecture Components
Vision Transformer (ViT) Backbone
Import Libraries
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import math
import numpy as np
from typing import Optional, List, Tuple
```
Patch Embedding
```python
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # A strided convolution extracts and embeds all patches in one pass.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # [B, N, D]
        return x
```
Multi-head Self-Attention
```python
class MultiheadAttention(nn.Module):
    """Multi-head Self Attention"""
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        return x
```
Transformer Block
```python
class TransformerBlock(nn.Module):
    """Transformer Block (pre-norm attention + MLP, each with a residual)"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
```
Vision Transformer
```python
class VisionTransformer(nn.Module):
    """Vision Transformer"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, h, w):
        """Resize the positional embeddings so crops smaller than img_size
        (e.g. the 96x96 local crops) can be processed by the same network."""
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and h == w:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, :1]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        h0 = h // self.patch_embed.patch_size
        w0 = w // self.patch_embed.patch_size
        M = int(math.sqrt(N))
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            size=(h0, w0), mode='bicubic', align_corners=False)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, x):
        B, _, H, W = x.shape
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, H, W)
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x[:, 0]  # Return the CLS token embedding
```
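A quick shape check (illustrative values): the backbone maps any batch of crops to one CLS embedding per image, and the positional-embedding interpolation above lets the 96x96 local crops pass through the same network:

```python
vit = VisionTransformer()
print(vit(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 768])
# Smaller local crops also work thanks to positional-embedding interpolation:
print(vit(torch.randn(2, 3, 96, 96)).shape)    # torch.Size([2, 768])
```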
DINO Head
```python
class DINOHead(nn.Module):
    """DINO Projection Head"""
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256,
                 num_layers=3, use_bn=False, norm_last_layer=True):
        super().__init__()
        if num_layers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(num_layers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
```
DINOv2 Model
```python
class DINOv2(nn.Module):
    """Complete DINOv2 model: backbone plus projection head"""
    def __init__(self, backbone_args, head_args):
        super().__init__()
        self.backbone = VisionTransformer(**backbone_args)
        self.head = DINOHead(**head_args)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
```
Data Augmentation and Multi-Crop Strategy
Multi-Crop Data Augmentation
```python
class MultiCropDataAugmentation:
    """Multi-crop data augmentation for DINOv2.

    GaussianBlur and Solarization are small helper transforms defined below.
    """
    def __init__(self, global_crops_scale=(0.4, 1.0), local_crops_scale=(0.05, 0.4),
                 global_crops_number=2, local_crops_number=6, size_crops=(224, 96)):
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number
        # Global crops (seen by both teacher and student)
        self.global_transform = transforms.Compose([
            transforms.RandomResizedCrop(size_crops[0], scale=global_crops_scale,
                                         interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=1.0),
            # Disabled here; the reference DINO recipe solarizes only the second global crop.
            Solarization(p=0.0),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        # Local crops (seen by the student only)
        self.local_transform = transforms.Compose([
            transforms.RandomResizedCrop(size_crops[1], scale=local_crops_scale,
                                         interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.5),
            Solarization(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

    def __call__(self, image):
        crops = []
        # Global crops
        for _ in range(self.global_crops_number):
            crops.append(self.global_transform(image))
        # Local crops
        for _ in range(self.local_crops_number):
            crops.append(self.local_transform(image))
        return crops
```
Augmentation Utilities
```python
class GaussianBlur:
    """Gaussian blur augmentation"""
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if torch.rand(1) < self.prob:
            radius = self.radius_min + torch.rand(1) * (self.radius_max - self.radius_min)
            return transforms.functional.gaussian_blur(img, kernel_size=9, sigma=radius.item())
        return img


class Solarization:
    """Solarization augmentation"""
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        if torch.rand(1) < self.p:
            return transforms.functional.solarize(img, threshold=128)
        return img
```
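As a quick sanity check (illustrative; the input image is synthetic), the pipeline should return two 224x224 global crops followed by six 96x96 local crops:

```python
from PIL import Image
import numpy as np

aug = MultiCropDataAugmentation()
img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
crops = aug(img)
print(len(crops))                      # 8 crops: 2 global + 6 local
print(crops[0].shape, crops[2].shape)  # torch.Size([3, 224, 224]) torch.Size([3, 96, 96])
```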
Loss Functions and Training Components
```python
class DINOLoss(nn.Module):
    """DINO Loss with centering and sharpening"""
    def __init__(self, out_dim, ncrops, warmup_teacher_temp=0.04,
                 teacher_temp=0.04, warmup_teacher_temp_epochs=0,
                 student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # Teacher temperature schedule: linear warmup, then constant
        # (padded generously so indexing stays valid for long runs)
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(1000) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)
        # Teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)  # the teacher sees only the 2 global crops
        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue  # skip pairs where teacher and student see the same crop
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Update the center (an EMA of teacher outputs)."""
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        if dist.is_available() and dist.is_initialized():
            # Under distributed training, average the center across all processes.
            dist.all_reduce(batch_center)
            batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
        else:
            batch_center = batch_center / len(teacher_output)
        # EMA update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
```
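A toy shape check of the loss (random tensors stand in for real network outputs; the dimensions are arbitrary):

```python
# Batch of 4 images, 8 crops each, projection dimension 128.
loss_fn = DINOLoss(out_dim=128, ncrops=8)
student_out = torch.randn(8 * 4, 128)  # all 8 crops, stacked along the batch dim
teacher_out = torch.randn(2 * 4, 128)  # teacher sees only the 2 global crops
print(loss_fn(student_out, teacher_out, epoch=0))  # scalar loss tensor
```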
Training Utilities
```python
@torch.no_grad()
def update_teacher(student, teacher, momentum):
    """EMA update of the teacher network."""
    for param_student, param_teacher in zip(student.parameters(), teacher.parameters()):
        param_teacher.data.mul_(momentum).add_(param_student.data, alpha=1 - momentum)


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0):
    """Cosine schedule with linear warmup, one value per iteration.
    Used for both the learning rate and the teacher momentum."""
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(0, base_value, warmup_iters)
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
```
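For intuition, a small worked example with assumed values: a 10-epoch run at 100 iterations per epoch with 2 warmup epochs yields one value per iteration, ramping linearly up to the base value and then decaying along a cosine:

```python
sched = cosine_scheduler(base_value=5e-4, final_value=1e-6,
                         epochs=10, niter_per_ep=100, warmup_epochs=2)
print(len(sched))                       # 1000 — one value per training iteration
print(sched[0], sched[199], sched[-1])  # 0.0, 5e-4 at the end of warmup, ~1e-6 at the end
```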
Training Loop Implementation
```python
class DINOv2Trainer:
    """DINOv2 Training Pipeline"""
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Model architecture configs
        backbone_args = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'num_heads': 12,
            'mlp_ratio': 4.0,
            'dropout': 0.0
        }
        head_args = {
            'in_dim': 768,
            'out_dim': 65536,  # Large output dimension
            'hidden_dim': 2048,
            'bottleneck_dim': 256
        }
        # Initialize student and teacher networks
        self.student = DINOv2(backbone_args, head_args).to(self.device)
        self.teacher = DINOv2(backbone_args, head_args).to(self.device)
        # Teacher starts as a copy of the student
        self.teacher.load_state_dict(self.student.state_dict())
        # Teacher parameters are not updated by gradients, only by EMA
        for p in self.teacher.parameters():
            p.requires_grad = False
        # Loss function
        self.dino_loss = DINOLoss(
            out_dim=head_args['out_dim'],
            ncrops=8,  # 2 global + 6 local crops
            student_temp=0.1,
            teacher_temp=0.04,
            center_momentum=0.9
        ).to(self.device)
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.student.parameters(),
            lr=config['base_lr'],
            weight_decay=config['weight_decay']
        )
        # Learning rate schedule (per iteration)
        self.lr_schedule = cosine_scheduler(
            config['base_lr'],
            config['final_lr'],
            config['epochs'],
            config['niter_per_ep'],
            config['warmup_epochs']
        )
        # Momentum schedule for teacher updates (anneals toward 1.0)
        self.momentum_schedule = cosine_scheduler(
            config['momentum_teacher'],
            1.0,
            config['epochs'],
            config['niter_per_ep']
        )
    def train_epoch(self, dataloader, epoch):
        """Train for one epoch"""
        self.student.train()
        self.teacher.eval()
        total_loss = 0
        num_batches = len(dataloader)
        for it, (images, _) in enumerate(dataloader):
            # Update learning rate for this iteration
            lr = self.lr_schedule[epoch * num_batches + it]
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            # Move all crops to the device
            images = [im.to(self.device, non_blocking=True) for im in images]
            # Teacher forward pass (global crops only; no gradients needed)
            with torch.no_grad():
                teacher_output = self.teacher(torch.cat(images[:2]))
            # Student forward pass. Global (224px) and local (96px) crops have
            # different resolutions, so they cannot be concatenated into a
            # single batch; forward them separately and concatenate the outputs.
            student_global = self.student(torch.cat(images[:2]))
            student_local = self.student(torch.cat(images[2:]))
            student_output = torch.cat([student_global, student_local])
            # Compute loss
            loss = self.dino_loss(student_output, teacher_output, epoch)
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.student.parameters(), max_norm=3.0)
            self.optimizer.step()
            # Update the teacher with EMA
            momentum = self.momentum_schedule[epoch * num_batches + it]
            update_teacher(self.student, self.teacher, momentum)
            total_loss += loss.item()
            if it % 100 == 0:
                print(f'Epoch {epoch}, Iter {it}/{num_batches}, Loss: {loss.item():.4f}, LR: {lr:.6f}')
        return total_loss / num_batches
    def train(self, dataloader):
        """Full training loop"""
        for epoch in range(self.config['epochs']):
            avg_loss = self.train_epoch(dataloader, epoch)
            print(f'Epoch {epoch}/{self.config["epochs"]}, Average Loss: {avg_loss:.4f}')
            # Save checkpoint
            if epoch % self.config['save_every'] == 0:
                self.save_checkpoint(epoch)

    def save_checkpoint(self, epoch):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'teacher_state_dict': self.teacher.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config
        }
        torch.save(checkpoint, f'dinov2_checkpoint_epoch_{epoch}.pth')
```
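Only saving is implemented above; a matching load sketch for resuming (the helper name and return convention are assumptions, not part of the original pipeline) could look like this:

```python
def load_checkpoint(trainer, path):
    """Minimal, assumed counterpart to save_checkpoint for resuming training."""
    checkpoint = torch.load(path, map_location=trainer.device)
    trainer.student.load_state_dict(checkpoint['student_state_dict'])
    trainer.teacher.load_state_dict(checkpoint['teacher_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1  # epoch to resume from
```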
Usage Example
```python
def main():
    # Training configuration
    config = {
        'base_lr': 5e-4,
        'final_lr': 1e-6,
        'weight_decay': 0.04,
        'momentum_teacher': 0.996,
        'epochs': 100,
        'warmup_epochs': 10,
        'batch_size': 64,
        'save_every': 10,
        'niter_per_ep': None  # set after the dataloader is created
    }
    # Data setup
    transform = MultiCropDataAugmentation()
    dataset = ImageFolder(root='path/to/your/dataset', transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    config['niter_per_ep'] = len(dataloader)
    # Initialize trainer and start training
    trainer = DINOv2Trainer(config)
    trainer.train(dataloader)


if __name__ == '__main__':
    # The guard is needed because num_workers > 0 spawns worker processes
    # on some platforms, which re-import this module.
    main()
```
Key Features Implemented
- Vision Transformer Backbone: Complete ViT implementation with patch embedding, multi-head attention, and transformer blocks
- Multi-crop Strategy: Global and local crops with different augmentations
- Teacher-Student Framework: EMA updates for teacher network
- DINO Loss: Cross-entropy loss with centering mechanism to prevent collapse
- Learning Rate Scheduling: Cosine annealing with warmup
- Gradient Clipping: Stability during training
- Checkpointing: Save/load model states
Training Tips
- Batch Size: Use large effective batch sizes (256-1024) for better performance
- Data Augmentation: Strong augmentations are crucial for self-supervised learning
- Temperature Scheduling: Warm up the teacher temperature over the first epochs rather than keeping it fixed
- Momentum Scheduling: Start with high teacher momentum (e.g. 0.996) and anneal it toward 1.0 over training, as the momentum schedule above does
- Multi-GPU Training: Use DistributedDataParallel for faster training (see the sketch below)
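As a starting point for the multi-GPU tip, here is a minimal sketch of wrapping the student in DistributedDataParallel. It assumes launching via `torchrun` (which sets `LOCAL_RANK`) and an NCCL backend; the helper name and integration details are assumptions to adapt to your setup:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup_distributed(trainer, dataset, batch_size):
    """Sketch: wrap the student in DDP and shard the data across processes."""
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    # Only the student is wrapped; the teacher receives no gradients (EMA only).
    trainer.student = DDP(trainer.student, device_ids=[local_rank])
    # A DistributedSampler gives each process its own shard of the dataset.
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                        num_workers=4, pin_memory=True, drop_last=True)
    return loader
```

With this in place, the distributed branch of `DINOLoss.update_center` above becomes active, so the center stays synchronized across processes.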
This implementation provides a solid foundation for training DINOv2 models. Adjust hyperparameters based on your dataset size and computational resources.