DINOv2 Student-Teacher Network Training Guide
This guide provides a complete implementation for training a DINOv2 (DINO version 2) student-teacher network from scratch using PyTorch. DINOv2 is a self-supervised learning method that trains vision transformers without labels using a teacher-student distillation framework.
Table of Contents
- Overview
- Architecture Components
- Data Augmentation and Multi-Crop Strategy
- Loss Functions and Training Components
- Training Loop Implementation
- Usage Example
- Key Features Implemented
- Training Tips
Overview
DINOv2 uses a student-teacher framework where:
- Teacher network: Provides stable targets (EMA of student weights)
- Student network: Learns to match teacher outputs
- Multi-crop strategy: Uses different image crops for robustness
- Centering mechanism: Prevents mode collapse
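In equations (this summarizes the loss implemented later in this guide), the teacher's logits $g_t(x)$ are centered by $c$ and sharpened with temperature $\tau_t$, and the student minimizes a cross-entropy against them over pairs of different crops:

$$P_t(x) = \mathrm{softmax}\big((g_t(x) - c)/\tau_t\big), \qquad P_s(x) = \mathrm{softmax}\big(g_s(x)/\tau_s\big)$$

$$\mathcal{L} = -\sum_{k} P_t(x)_k \log P_s(x')_k \quad (x \neq x'), \qquad c \leftarrow m\,c + (1-m)\,\tfrac{1}{B}\sum_{i=1}^{B} g_t(x_i)$$

The center $c$ is an exponential moving average of teacher outputs; subtracting it prevents one dimension from dominating, while the low teacher temperature prevents collapse to a uniform distribution.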
Architecture Components
Vision Transformer (ViT) Backbone
Import Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import math
import numpy as np
from typing import Optional, List, Tuple
Patch Embedding
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # [B, N, D]
        return x
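A quick shape check (illustrative numbers): with 224x224 inputs and 16x16 patches, the embedding produces 14 * 14 = 196 tokens.

embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
out = embed(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 196, 768])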
Multi-head Self-Attention
class MultiheadAttention(nn.Module):
    """Multi-head Self-Attention"""
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, D = x.shape
        # Project to queries, keys, values: [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        return x
Transformer Block
class TransformerBlock(nn.Module):
    """Transformer Block (pre-norm)"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Residual connections around attention and MLP
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
Vision Transformer
class VisionTransformer(nn.Module):
    """Vision Transformer"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, h, w):
        # Resize the position embeddings when the input resolution differs
        # from img_size, so the 96x96 local crops can share this backbone
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, :1]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        h0 = h // self.patch_embed.patch_size
        w0 = w // self.patch_embed.patch_size
        M = int(math.sqrt(N))
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            size=(h0, w0), mode='bicubic', align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, x):
        B, _, H, W = x.shape
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, H, W)
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return x[:, 0]  # Return CLS token
DINO Head
class DINOHead(nn.Module):
    """DINO Projection Head"""
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256,
                 num_layers=3, use_bn=False, norm_last_layer=True):
        super().__init__()
        if num_layers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(num_layers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)

        # Weight-normalized last layer; freezing its magnitude (weight_g)
        # improves training stability
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1, p=2)  # L2-normalize the bottleneck
        x = self.last_layer(x)
        return x
DINOv2 Model
class DINOv2(nn.Module):
    """Complete DINOv2 Model: backbone + projection head"""
    def __init__(self, backbone_args, head_args):
        super().__init__()
        self.backbone = VisionTransformer(**backbone_args)
        self.head = DINOHead(**head_args)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
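A quick end-to-end shape check with a deliberately tiny configuration (illustrative sizes, not the training config used later):

model = DINOv2(
    backbone_args={'img_size': 224, 'patch_size': 16, 'embed_dim': 192,
                   'depth': 4, 'num_heads': 3},
    head_args={'in_dim': 192, 'out_dim': 1024},
)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1024])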
Data Augmentation and Multi-Crop Strategy
Multi-Crop Data Augmentation
class MultiCropDataAugmentation:
    """Multi-crop data augmentation for DINOv2.

    GaussianBlur and Solarization are defined below under Augmentation Utilities.
    """
    def __init__(self, global_crops_scale=(0.4, 1.0), local_crops_scale=(0.05, 0.4),
                 global_crops_number=2, local_crops_number=6, size_crops=(224, 96)):
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number

        # Global crops (seen by both teacher and student)
        self.global_transform = transforms.Compose([
            transforms.RandomResizedCrop(size_crops[0], scale=global_crops_scale,
                                         interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=1.0),
            Solarization(p=0.0),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # Local crops (seen by the student only)
        self.local_transform = transforms.Compose([
            transforms.RandomResizedCrop(size_crops[1], scale=local_crops_scale,
                                         interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.5),
            Solarization(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __call__(self, image):
        crops = []
        # Global crops first, then local crops
        for _ in range(self.global_crops_number):
            crops.append(self.global_transform(image))
        for _ in range(self.local_crops_number):
            crops.append(self.local_transform(image))
        return crops
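Each call turns one image into a list of 8 crops, globals first. An illustrative usage with a blank PIL image:

from PIL import Image
aug = MultiCropDataAugmentation()
crops = aug(Image.new('RGB', (256, 256)))
print(len(crops))       # 8
print(crops[0].shape)   # torch.Size([3, 224, 224]) (global crop)
print(crops[-1].shape)  # torch.Size([3, 96, 96])   (local crop)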
Augmentation Utilities
class GaussianBlur:
    """Gaussian blur augmentation applied with probability p"""
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if torch.rand(1) < self.prob:
            radius = self.radius_min + torch.rand(1) * (self.radius_max - self.radius_min)
            return transforms.functional.gaussian_blur(img, kernel_size=9, sigma=radius.item())
        return img
class Solarization:
    """Solarization augmentation applied with probability p"""
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        if torch.rand(1) < self.p:
            return transforms.functional.solarize(img, threshold=128)
        return img
Loss Functions and Training Components
class DINOLoss(nn.Module):
    """DINO Loss with centering and sharpening"""
    def __init__(self, out_dim, ncrops, warmup_teacher_temp=0.04,
                 teacher_temp=0.04, warmup_teacher_temp_epochs=0,
                 student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))

        # Teacher temperature schedule: linear warmup, then constant
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(1000) * teacher_temp  # assume at most 1000 epochs
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # Teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)  # only 2 global crops for the teacher

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue  # skip the same crop
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1

        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Update center used for teacher output."""
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        batch_center = batch_center / len(teacher_output)

        # EMA update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
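A quick smoke test (illustrative: batch size 4 with the 8-crop setup used below, and a small out_dim). The loss takes flat [num_crops * batch, out_dim] tensors, with the two global crops first, and returns a scalar:

loss_fn = DINOLoss(out_dim=1024, ncrops=8)
student_logits = torch.randn(8 * 4, 1024)  # all 8 crops, batch of 4
teacher_logits = torch.randn(2 * 4, 1024)  # 2 global crops only
print(loss_fn(student_logits, teacher_logits, epoch=0))  # scalar loss tensor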
Training Utilities
@torch.no_grad()
def update_teacher(student, teacher, momentum):
    """EMA update of the teacher network."""
    for param_student, param_teacher in zip(student.parameters(), teacher.parameters()):
        param_teacher.data.mul_(momentum).add_(param_student.data, alpha=1 - momentum)
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0):
    """Cosine schedule with linear warmup, one value per training iteration."""
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(0, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
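For example, with the learning-rate hyperparameters used in the trainer below (1000 iterations per epoch is an illustrative number):

lr_schedule = cosine_scheduler(5e-4, 1e-6, epochs=100, niter_per_ep=1000, warmup_epochs=10)
print(lr_schedule[0])      # 0.0: warmup starts at zero
print(lr_schedule[10000])  # ~5e-4: peak value after 10 warmup epochs
print(lr_schedule[-1])     # ~1e-6: final value after cosine decay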
Training Loop Implementation
class DINOv2Trainer:
    """DINOv2 Training Pipeline"""
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Model architecture configs (ViT-B/16)
        backbone_args = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'num_heads': 12,
            'mlp_ratio': 4.0,
            'dropout': 0.0,
        }
        head_args = {
            'in_dim': 768,
            'out_dim': 65536,  # large output dimension
            'hidden_dim': 2048,
            'bottleneck_dim': 256,
        }

        # Initialize student and teacher networks
        self.student = DINOv2(backbone_args, head_args).to(self.device)
        self.teacher = DINOv2(backbone_args, head_args).to(self.device)

        # Teacher starts as a copy of the student
        self.teacher.load_state_dict(self.student.state_dict())

        # Teacher parameters are never updated by gradients, only by EMA
        for p in self.teacher.parameters():
            p.requires_grad = False

        # Loss function
        self.dino_loss = DINOLoss(
            out_dim=head_args['out_dim'],
            ncrops=8,  # 2 global + 6 local crops
            student_temp=0.1,
            teacher_temp=0.04,
            center_momentum=0.9,
        ).to(self.device)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.student.parameters(),
            lr=config['base_lr'],
            weight_decay=config['weight_decay'],
        )

        # Learning rate schedule
        self.lr_schedule = cosine_scheduler(
            config['base_lr'],
            config['final_lr'],
            config['epochs'],
            config['niter_per_ep'],
            config['warmup_epochs'],
        )

        # Momentum schedule for teacher EMA updates (0.996 -> 1.0)
        self.momentum_schedule = cosine_scheduler(
            config['momentum_teacher'],
            1.0,
            config['epochs'],
            config['niter_per_ep'],
        )
    def train_epoch(self, dataloader, epoch):
        """Train for one epoch"""
        self.student.train()
        self.teacher.eval()

        total_loss = 0
        num_batches = len(dataloader)

        for it, (images, _) in enumerate(dataloader):
            # Update learning rate for this iteration
            lr = self.lr_schedule[epoch * num_batches + it]
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            # Move all crops to device
            images = [im.to(self.device, non_blocking=True) for im in images]

            # Teacher forward pass (only on the 2 global crops)
            teacher_output = self.teacher(torch.cat(images[:2]))

            # Student forward pass on all crops; global (224) and local (96)
            # crops have different resolutions, so they need separate passes
            student_output = torch.cat([
                self.student(torch.cat(images[:2])),
                self.student(torch.cat(images[2:])),
            ])

            # Compute loss
            loss = self.dino_loss(student_output, teacher_output, epoch)

            # Backward pass with gradient clipping for stability
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.student.parameters(), max_norm=3.0)
            self.optimizer.step()

            # EMA update of the teacher
            momentum = self.momentum_schedule[epoch * num_batches + it]
            update_teacher(self.student, self.teacher, momentum)

            total_loss += loss.item()

            if it % 100 == 0:
                print(f'Epoch {epoch}, Iter {it}/{num_batches}, Loss: {loss.item():.4f}, LR: {lr:.6f}')

        return total_loss / num_batches
    def train(self, dataloader):
        """Full training loop"""
        for epoch in range(self.config['epochs']):
            avg_loss = self.train_epoch(dataloader, epoch)
            print(f'Epoch {epoch}/{self.config["epochs"]}, Average Loss: {avg_loss:.4f}')

            # Save checkpoint
            if epoch % self.config['save_every'] == 0:
                self.save_checkpoint(epoch)

    def save_checkpoint(self, epoch):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'student_state_dict': self.student.state_dict(),
            'teacher_state_dict': self.teacher.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
        }
        torch.save(checkpoint, f'dinov2_checkpoint_epoch_{epoch}.pth')
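The class above only writes checkpoints. A minimal companion for resuming (a sketch, not part of the original trainer; the helper name load_checkpoint is illustrative):

def load_checkpoint(trainer, path):
    """Restore student, teacher, and optimizer state from a checkpoint."""
    checkpoint = torch.load(path, map_location=trainer.device)
    trainer.student.load_state_dict(checkpoint['student_state_dict'])
    trainer.teacher.load_state_dict(checkpoint['teacher_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1  # epoch to resume from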
Usage Example
def main():
    # Training configuration
    config = {
        'base_lr': 5e-4,
        'final_lr': 1e-6,
        'weight_decay': 0.04,
        'momentum_teacher': 0.996,
        'epochs': 100,
        'warmup_epochs': 10,
        'batch_size': 64,
        'save_every': 10,
        'niter_per_ep': None,  # set after dataloader creation
    }

    # Data setup
    transform = MultiCropDataAugmentation()
    dataset = ImageFolder(root='path/to/your/dataset', transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    config['niter_per_ep'] = len(dataloader)

    # Initialize trainer and start training
    trainer = DINOv2Trainer(config)
    trainer.train(dataloader)

if __name__ == '__main__':
    main()
Key Features Implemented
- Vision Transformer Backbone: Complete ViT implementation with patch embedding, multi-head attention, and transformer blocks
- Multi-crop Strategy: Global and local crops with different augmentations
- Teacher-Student Framework: EMA updates for teacher network
- DINO Loss: Cross-entropy loss with centering mechanism to prevent collapse
- Learning Rate Scheduling: Cosine annealing with warmup
- Gradient Clipping: Stability during training
- Checkpointing: Save/load model states
Training Tips
- Batch Size: Use large batch sizes (256-1024) for better performance
- Data Augmentation: Strong augmentations are crucial for self-supervised learning
- Temperature Scheduling: Warm up the teacher temperature from a low value (e.g., 0.04) over the first epochs; lower temperatures produce sharper teacher targets
- Momentum Scheduling: Start the teacher EMA momentum high (e.g., 0.996) and increase it toward 1.0 over training, as the cosine schedule above does
- Multi-GPU Training: Use DistributedDataParallel for faster training (see the sketch after this list)
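For the multi-GPU tip, a minimal sketch of wrapping the student with DistributedDataParallel; this assumes a launch via torchrun and a trainer instance as in the Usage Example, and is not part of the pipeline above:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# Wrap only the student; the teacher receives EMA updates, not gradients.
# Remember to use a DistributedSampler in the DataLoader as well.
student = DDP(trainer.student.to(local_rank), device_ids=[local_rank])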
This implementation provides a solid foundation for training DINOv2 models. Adjust hyperparameters based on your dataset size and computational resources.