Matryoshka Transformer: Complete Implementation Guide
Introduction
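Matryoshka Transformers are a neural architecture that enables flexible computational budgets during inference by allowing early exits at different layers. Named after Russian nesting dolls, these models contain multiple “nested” representations of decreasing complexity, allowing you to trade off accuracy for speed based on your computational constraints. Note that this guide realizes the nesting along depth (early exits). Published Matryoshka work, such as Matryoshka Representation Learning and MatFormer, instead nests along width (embedding dimensions or feed-forward units).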
Key Concepts
Core Ideas
- Nested Representations: Each layer can potentially serve as a final output
- Early Exits: Inference can stop at any intermediate layer
- Adaptive Computation: Different inputs may require different amounts of computation
- Training Efficiency: Single model training for multiple computational budgets
Architecture Overview
Input → Layer 1 → [Exit 1] → Layer 2 → [Exit 2] → ... → Layer N → [Final Exit]
Implementation
1. Basic Matryoshka Transformer Block
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
class MatryoshkaTransformerBlock(nn.Module):
    """
    A single post-norm transformer block with an optional early-exit head.

    Args:
        d_model: Hidden size of the token representations.
        n_heads: Number of attention heads (nn.MultiheadAttention requires it
            to divide ``d_model``).
        d_ff: Hidden size of the position-wise feed-forward network.
        dropout: Dropout probability used in attention, FFN and residuals.
        has_exit: Whether this block should own an early-exit classifier.
        n_classes: Number of classes predicted by the exit head. The head is
            only built when ``has_exit`` is true AND ``n_classes`` is given;
            ``self.has_exit`` records whether it was actually built.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        has_exit: bool = False,
        n_classes: Optional[int] = None,
    ):
        super().__init__()

        # Standard transformer components
        self.attention = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Early exit components. has_exit must reflect whether the head really
        # exists, otherwise forward() would touch a missing exit_classifier.
        self.has_exit = bool(has_exit) and n_classes is not None
        if self.has_exit:
            self.exit_classifier = nn.Sequential(
                nn.LayerNorm(d_model),
                nn.Linear(d_model, d_model // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_model // 2, n_classes),
            )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass with optional early exit.

        Args:
            x: Input of shape [batch_size, seq_len, d_model].
            mask: Optional ``attn_mask`` forwarded unchanged to
                nn.MultiheadAttention ([seq_len, seq_len] or
                [batch_size * n_heads, seq_len, seq_len]).
            key_padding_mask: Optional [batch_size, seq_len] boolean mask with
                True at padded positions. Padded keys are ignored by attention
                and padded positions are excluded from exit pooling.

        Returns:
            x: Transformed input, same shape as the input.
            exit_logits: [batch_size, n_classes] early-exit predictions, or
                None when this block has no exit head.
        """
        # Self-attention (post-norm residual)
        attn_out, _ = self.attention(
            x, x, x, attn_mask=mask, key_padding_mask=key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward (post-norm residual)
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))

        # Early exit prediction from a mean-pooled sequence representation
        exit_logits = None
        if self.has_exit:
            exit_logits = self.exit_classifier(self._pool(x, key_padding_mask))
        return x, exit_logits

    @staticmethod
    def _pool(
        x: torch.Tensor, key_padding_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Mean over the sequence axis, skipping padded positions if given."""
        if key_padding_mask is None:
            return x.mean(dim=1)
        keep = (~key_padding_mask.bool()).unsqueeze(-1).to(x.dtype)
        # clamp avoids 0/0 for a fully padded row (its pooled vector becomes 0)
        return (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
2. Complete Matryoshka Transformer Model
class MatryoshkaTransformer(nn.Module):
    """
    Complete Matryoshka Transformer with multiple early-exit points.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Hidden size of every block.
        n_heads: Attention heads per block.
        n_layers: Number of transformer blocks.
        d_ff: Feed-forward hidden size.
        max_seq_len: Longest input the learned position embedding covers.
        n_classes: Number of output classes.
        dropout: Dropout probability.
        exit_layers: 0-based indices of the blocks that get an exit head.
            Defaults to every other block (1, 3, 5, ...) plus the last block.

    Raises:
        ValueError: If an entry of ``exit_layers`` is outside ``[0, n_layers)``.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        n_classes: int = 2,
        dropout: float = 0.1,
        exit_layers: Optional[List[int]] = None,  # Layers with early exits
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

        # Default exit layers (every 2 layers + final)
        if exit_layers is None:
            exit_layers = list(range(1, n_layers, 2)) + [n_layers - 1]
        out_of_range = sorted(i for i in set(exit_layers) if not 0 <= i < n_layers)
        if out_of_range:
            # Such indices would otherwise silently never get an exit head.
            raise ValueError(
                f"exit_layers {out_of_range} are outside [0, {n_layers})"
            )
        self.exit_layers = set(exit_layers)

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            MatryoshkaTransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                dropout=dropout,
                has_exit=(i in self.exit_layers),
                n_classes=n_classes,
            )
            for i in range(n_layers)
        ])

        # Final classifier (always present)
        self.final_classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, n_classes),
        )

        # One confidence threshold per exit layer. Note: forward() does not
        # read this tensor; it uses its confidence_threshold argument.
        self.confidence_thresholds = nn.Parameter(
            torch.full((len(self.exit_layers),), 0.8)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_all_exits: bool = False,
        confidence_threshold: float = 0.8,
        max_exit_layer: Optional[int] = None,
    ) -> dict:
        """
        Forward pass with adaptive early exiting.

        The exit decision is taken for the whole batch at once. Computation
        stops after the first exit head whose batch-mean max-softmax
        confidence reaches ``confidence_threshold``, or after block
        ``max_exit_layer``, whichever comes first.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Optional [batch_size, seq_len] mask, 1 for real
                tokens and 0 for padding.
            return_all_exits: Run every block and collect every exit's
                predictions with no early stopping (used for training).
            confidence_threshold: Minimum mean confidence for early exit.
            max_exit_layer: 0-based index of the last block allowed to run
                (budget constraint); None means no cap.

        Returns:
            Dictionary with:
              'logits': [batch_size, n_classes]. Comes from the final
                  classifier when every block ran. After an early stop it
                  comes from the most recently evaluated exit head, falling
                  back to the final classifier only if no exit head ran.
              'exit_predictions' / 'exit_confidences': logits and
                  [batch_size] max-softmax confidences of every exit head
                  evaluated, in layer order.
              'exit_layer': block index where an exit condition fired, or
                  None if none did.
              'total_layers_used': number of blocks executed.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"seq_len={seq_len} exceeds max_seq_len={self.max_seq_len}"
            )

        # Embeddings
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)

        attn_mask = self._build_attn_mask(attention_mask, x.dtype)

        # Track exits
        exit_predictions = []
        exit_confidences = []
        exit_layer = None
        last_layer = self.n_layers - 1

        for i, block in enumerate(self.blocks):
            x, exit_logits = block(x, attn_mask)

            if exit_logits is not None:
                max_confidence = F.softmax(exit_logits, dim=-1).max(dim=-1).values
                exit_predictions.append(exit_logits)
                exit_confidences.append(max_confidence)

            if return_all_exits:
                continue

            confident = (
                exit_logits is not None
                and exit_confidences[-1].mean().item() >= confidence_threshold
            )
            # Enforce the budget: never run past max_exit_layer. Reaching the
            # last block is simply finishing, so it is not treated as a stop.
            over_budget = (
                max_exit_layer is not None and max_exit_layer <= i < last_layer
            )
            if confident or over_budget:
                exit_layer = i
                break

        if exit_layer is not None and exit_layer < last_layer and exit_predictions:
            # The final classifier was trained on last-block features, so an
            # early stop reports the deepest exit head evaluated so far.
            logits = exit_predictions[-1]
        else:
            # Exclude padding from the final pooled representation.
            logits = self.final_classifier(self._pool(x, attention_mask))

        return {
            'logits': logits,
            'exit_predictions': exit_predictions,
            'exit_confidences': exit_confidences,
            'exit_layer': exit_layer,
            'total_layers_used': (exit_layer + 1) if exit_layer is not None else self.n_layers,
        }

    def _build_attn_mask(
        self, attention_mask: Optional[torch.Tensor], dtype: torch.dtype
    ) -> Optional[torch.Tensor]:
        """Expand a [batch, seq] keep-mask into the additive 3-D attn_mask
        [batch * n_heads, seq, seq] accepted by nn.MultiheadAttention."""
        if attention_mask is None:
            return None
        batch_size, seq_len = attention_mask.shape
        # 0 for real tokens, -10000 added to the scores of padded keys
        additive = (1.0 - attention_mask.to(dtype)) * -10000.0
        # Batch-major, head-minor ordering matches MultiheadAttention's layout
        return (
            additive[:, None, None, :]
            .expand(batch_size, self.n_heads, seq_len, seq_len)
            .reshape(batch_size * self.n_heads, seq_len, seq_len)
        )

    @staticmethod
    def _pool(
        x: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Mean over the sequence axis, weighted by the keep-mask if given."""
        if attention_mask is None:
            return x.mean(dim=1)
        keep = attention_mask.to(x.dtype).unsqueeze(-1)
        return (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
3. Training Strategy
class MatryoshkaTrainer:
    """
    Training strategy for Matryoshka Transformers.

    The final classifier and every exit head each get a cross-entropy loss.
    Each exit head optionally also gets a distillation loss toward the final
    classifier's (detached) prediction.

    Args:
        model: Model to train.
        exit_loss_weights: Per-exit loss weights in layer order. Exits beyond
            the end of this list reuse its last weight, so no exit head is
            left untrained. Defaults to ``[0.3, 0.3, 1.0]``.
        distillation_weight: Scale of the KL distillation term; 0 disables it.
    """

    def __init__(
        self,
        model: "MatryoshkaTransformer",
        exit_loss_weights: Optional[List[float]] = None,
        distillation_weight: float = 0.5,
    ):
        self.model = model
        # Non-decreasing weights: deeper exits contribute more
        self.exit_loss_weights = (
            list(exit_loss_weights) if exit_loss_weights else [0.3, 0.3, 1.0]
        )
        self.distillation_weight = distillation_weight

    def _exit_weight(self, index: int) -> float:
        """Weight for exit ``index``; reuses the last weight past the list end."""
        weights = self.exit_loss_weights
        return weights[index] if index < len(weights) else weights[-1]

    def compute_loss(
        self,
        outputs: dict,
        labels: torch.Tensor,
        temperature: float = 3.0,
    ) -> dict:
        """
        Compute combined loss from all exit points.

        Args:
            outputs: Model output dict with 'logits' and 'exit_predictions'.
            labels: Class indices [batch_size].
            temperature: Softening temperature for distillation.

        Returns:
            Dict with 'final', 'exit_{i}', 'distill_{i}' (when enabled) and
            'total' loss tensors.
        """
        losses = {}

        # Final layer loss
        final_loss = F.cross_entropy(outputs['logits'], labels)
        losses['final'] = final_loss
        total_loss = final_loss

        # Teacher distribution, detached so distillation trains only the
        # exit heads instead of pulling the final classifier toward them.
        teacher_probs = F.softmax(outputs['logits'].detach() / temperature, dim=-1)

        # Early exit losses (enumerate every exit; never truncate to weights)
        for i, exit_logits in enumerate(outputs['exit_predictions']):
            weight = self._exit_weight(i)

            # Classification loss
            exit_loss = F.cross_entropy(exit_logits, labels)
            losses[f'exit_{i}'] = exit_loss
            total_loss = total_loss + weight * exit_loss

            # Knowledge distillation from final layer
            if self.distillation_weight > 0:
                distill_loss = F.kl_div(
                    F.log_softmax(exit_logits / temperature, dim=-1),
                    teacher_probs,
                    reduction='batchmean',
                ) * (temperature ** 2)
                losses[f'distill_{i}'] = distill_loss
                total_loss = total_loss + self.distillation_weight * weight * distill_loss

        losses['total'] = total_loss
        return losses

    def train_step(
        self,
        batch: dict,
        optimizer: torch.optim.Optimizer,
    ) -> dict:
        """
        Single training step over every exit point.

        ``batch`` needs 'input_ids' and 'labels'; 'attention_mask' is optional.
        Returns the losses as Python floats.
        """
        self.model.train()
        optimizer.zero_grad()

        # Forward pass: all exits are needed for the multi-exit loss
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch.get('attention_mask'),
            return_all_exits=True,
        )

        # Compute loss
        losses = self.compute_loss(outputs, batch['labels'])

        # Backward pass
        losses['total'].backward()
        optimizer.step()

        return {k: v.item() for k, v in losses.items()}
4. Inference with Adaptive Computation
class AdaptiveInference:
    """
    Adaptive inference with configurable exit strategies.

    Every prediction runs with dropout disabled (``model.eval()``) and without
    autograd. The model's previous train/eval mode is restored afterwards.
    """

    def __init__(self, model: MatryoshkaTransformer):
        self.model = model

    def _run(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        max_layer: int,
        confidence_threshold: float,
    ) -> dict:
        """Run the model capped at block ``max_layer`` and add budget stats."""
        n_layers = self.model.n_layers
        # At least one block always runs; a cap past the last block is a no-op.
        max_layer = max(0, min(max_layer, n_layers - 1))

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    confidence_threshold=confidence_threshold,
                    max_exit_layer=max_layer,
                )
        finally:
            self.model.train(was_training)

        # Calculate actual computation used (fraction of blocks executed)
        actual_budget = outputs['total_layers_used'] / n_layers
        return {
            **outputs,
            'computational_savings': 1.0 - actual_budget,
            'flops_used': actual_budget,
        }

    def predict_with_budget(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flop_budget: float = 1.0,  # Fraction of full model FLOPs
        confidence_threshold: float = 0.8,
    ) -> dict:
        """
        Predict with computational budget constraint.

        ``flop_budget`` is the fraction of blocks allowed to run. For example,
        0.5 on a 12-block model allows blocks 0-5. Returns the model's output
        dict plus 'flops_used' (fraction of blocks executed) and
        'computational_savings' (1 - flops_used).
        """
        # Small epsilon so e.g. 0.29 * 100 == 28.999999999999996 keeps 29 blocks
        max_layer = int(self.model.n_layers * flop_budget + 1e-9) - 1
        return self._run(input_ids, attention_mask, max_layer, confidence_threshold)

    def predict_with_latency_constraint(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_latency_ms: float = 100.0,
        estimated_time_per_layer: float = 10.0,
        confidence_threshold: float = 0.8,
    ) -> dict:
        """
        Predict with latency constraint (simplified).

        ``estimated_time_per_layer`` (ms) is a placeholder. In practice you'd
        profile actual inference times for different exit points.
        """
        max_layers = int(max_latency_ms / estimated_time_per_layer)
        # Cap directly at a block index rather than round-tripping through a
        # float FLOP fraction, which can lose a block to rounding error.
        return self._run(input_ids, attention_mask, max_layers - 1, confidence_threshold)
5. Usage Example
# Initialize model
model = MatryoshkaTransformer(
    vocab_size=30000,
    d_model=512,
    n_heads=8,
    n_layers=12,
    n_classes=2,
    exit_layers=[2, 5, 8, 11]  # Exit points
)

# Training setup: one loss weight per exit point (4 exits above),
# increasing with depth so every exit head is trained.
trainer = MatryoshkaTrainer(model, exit_loss_weights=[0.25, 0.5, 0.75, 1.0])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Synthetic stand-in data so the example runs end to end. Replace with a real
# DataLoader yielding dicts with 'input_ids', 'attention_mask' and 'labels'.
batch_size, seq_len = 4, 32
dataloader = [
    {
        'input_ids': torch.randint(0, 30000, (batch_size, seq_len)),
        'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long),
        'labels': torch.randint(0, 2, (batch_size,)),
    }
    for _ in range(2)
]
sample_input = torch.randint(0, 30000, (1, seq_len))

# Training loop (simplified)
for batch in dataloader:
    losses = trainer.train_step(batch, optimizer)
    print(f"Total loss: {losses['total']:.4f}")

# Inference
inference_engine = AdaptiveInference(model)

# Example: Predict with 50% computational budget
result = inference_engine.predict_with_budget(
    input_ids=sample_input,
    flop_budget=0.5,
    confidence_threshold=0.85
)

print(f"Prediction: {result['logits'].argmax(-1)}")
print(f"Computational savings: {result['computational_savings']:.2%}")
print(f"Exited at layer: {result['exit_layer']}")
Advanced Features
1. Dynamic Confidence Thresholds
class DynamicThresholdStrategy:
    """
    Dynamically adjust confidence thresholds based on input characteristics.

    threshold = base_threshold * length_factor * layer_factor, where

        length_factor = 1 - (seq_len - reference_seq_len) / length_scale
        layer_factor  = 1 + layer_step * (reference_layer - layer)

    The result is clamped to [0, 1] because it is compared against a
    max-softmax probability.

    Args:
        base_threshold: Threshold at the reference length and layer.
        reference_seq_len: Sequence length at which length_factor is 1.
        length_scale: Number of tokens over which length_factor drops by 1.
        reference_layer: Layer index at which layer_factor is 1.
        layer_step: Change in layer_factor per layer away from the reference.
    """

    def __init__(
        self,
        base_threshold: float = 0.8,
        reference_seq_len: int = 50,
        length_scale: float = 500.0,
        reference_layer: int = 6,
        layer_step: float = 0.1,
    ):
        self.base_threshold = base_threshold
        self.reference_seq_len = reference_seq_len
        self.length_scale = length_scale
        self.reference_layer = reference_layer
        self.layer_step = layer_step

    def get_threshold(self, input_ids: torch.Tensor, layer: int) -> float:
        """
        Compute dynamic threshold based on input and layer.

        Args:
            input_ids: Token IDs; only ``input_ids.shape[1]`` (sequence
                length) is used.
            layer: 0-based index of the exit layer being considered.

        Returns:
            Threshold in [0, 1].
        """
        seq_len = input_ids.shape[1]

        # Example: Lower threshold for longer sequences
        length_factor = 1.0 - (seq_len - self.reference_seq_len) / self.length_scale

        # Example: Higher threshold for earlier layers (stricter early exits)
        layer_factor = 1.0 + self.layer_step * (self.reference_layer - layer)

        threshold = self.base_threshold * length_factor * layer_factor
        # A confidence lies in [0, 1]; values outside that range are meaningless
        # (above 1 blocks exiting, below 0 always exits).
        return min(1.0, max(0.0, threshold))
2. Ensemble Early Exits
class EnsembleMatryoshka(nn.Module):
    """
    Ensemble multiple exit predictions for better accuracy.

    Every exit head's logits and the final logits are combined with learnable,
    softmax-normalised weights.
    """

    def __init__(self, base_model: MatryoshkaTransformer):
        super().__init__()
        self.base_model = base_model
        # One weight per exit head that can actually exist (valid block
        # index), plus one for the final classifier.
        n_exits = sum(
            1 for i in base_model.exit_layers if 0 <= i < base_model.n_layers
        )
        self.ensemble_weights = nn.Parameter(torch.ones(n_exits + 1))

    def forward(self, input_ids: torch.Tensor, **kwargs) -> dict:
        """
        Run the base model over every block and add 'ensemble_logits'.

        ``return_all_exits`` is always forced to True, because every exit is
        needed. Setting it here also avoids a duplicate-keyword TypeError if a
        caller passes it in ``kwargs``.

        Raises:
            ValueError: If the number of predictions does not match the
                number of ensemble weights.
        """
        kwargs['return_all_exits'] = True
        outputs = self.base_model(input_ids, **kwargs)

        # Ensemble all available predictions
        all_logits = outputs['exit_predictions'] + [outputs['logits']]
        if len(all_logits) != self.ensemble_weights.numel():
            # zip() would otherwise silently mispair weights and predictions.
            raise ValueError(
                f"got {len(all_logits)} predictions for "
                f"{self.ensemble_weights.numel()} ensemble weights"
            )
        weights = F.softmax(self.ensemble_weights, dim=0)
        ensemble_logits = sum(w * logits for w, logits in zip(weights, all_logits))

        return {
            **outputs,
            'ensemble_logits': ensemble_logits
        }
Performance Optimization Tips
- Layer Selection: Choose exit layers strategically - too many exits can hurt training
- Loss Weighting: Start with lower weights for early exits, increase gradually
- Confidence Calibration: Use temperature scaling to calibrate exit confidences
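- Batch Processing: Process samples with similar complexity together. In this implementation the exit decision is made once per batch (on the batch's mean confidence), so mixing easy and hard samples forces them all to the same depth.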
- Caching: Cache intermediate representations for multiple exit strategies
Conclusion
Matryoshka Transformers offer a powerful way to build efficient models that can adapt their computational cost at inference time. The key to success is careful tuning of exit strategies, loss weights, and confidence thresholds for your specific use case.
This implementation provides a solid foundation that you can extend with additional features like cascaded exits, uncertainty estimation, or task-specific adaptations.