Introduction

Matryoshka Transformers are a neural architecture that supports flexible computational budgets at inference time through early exits at different layers. Named after Russian nesting dolls, these models contain multiple “nested” sub-models of increasing depth: the first k layers together with an exit head form a complete, cheaper predictor, so you can trade accuracy for speed to fit your computational constraints.

Key Concepts

Core Ideas

  • Nested Representations: Each layer can potentially serve as a final output
  • Early Exits: Inference can stop at any intermediate layer that has an exit head attached
  • Adaptive Computation: Different inputs may require different amounts of computation
  • Training Efficiency: Single model training for multiple computational budgets

Architecture Overview

Input → Layer 1 → [Exit 1] → Layer 2 → [Exit 2] → ... → Layer N → [Final Exit]

Implementation

1. Basic Matryoshka Transformer Block

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple

class MatryoshkaTransformerBlock(nn.Module):
    """
    A single post-norm transformer block with optional early exit capability.

    The block applies multi-head self-attention and a position-wise
    feed-forward network, each wrapped in a residual connection followed by
    LayerNorm. When ``has_exit`` is True it additionally mean-pools the
    sequence and classifies the pooled vector, so a prediction can be read
    off at this depth.

    Args:
        d_model: Hidden size of the token representations.
        n_heads: Number of attention heads.
        d_ff: Inner size of the feed-forward network.
        dropout: Dropout probability (attention, feed-forward, exit head).
        has_exit: Whether this block owns an early-exit classifier.
        n_classes: Number of output classes; required when ``has_exit`` is True.

    Raises:
        ValueError: If ``has_exit`` is True but ``n_classes`` is None.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        has_exit: bool = False,
        n_classes: Optional[int] = None
    ):
        super().__init__()

        # Fail at construction time: previously this combination built a block
        # without an exit head and only crashed later inside forward().
        if has_exit and n_classes is None:
            raise ValueError("n_classes must be provided when has_exit=True")

        # Standard transformer components
        self.attention = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Early exit components
        self.has_exit = has_exit
        if has_exit:
            self.exit_classifier = nn.Sequential(
                nn.LayerNorm(d_model),
                nn.Linear(d_model, d_model // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_model // 2, n_classes)
            )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass with optional early exit

        Args:
            x: Input of shape [batch_size, seq_len, d_model].
            mask: Optional mask forwarded unchanged as ``attn_mask`` to
                ``nn.MultiheadAttention``. It must therefore have shape
                [seq_len, seq_len] or [batch_size * n_heads, seq_len, seq_len];
                a [batch_size, seq_len] padding mask is NOT accepted here.

        Returns:
            x: Transformed input, same shape as the input
            exit_logits: Early exit predictions of shape [batch_size, n_classes]
                if has_exit=True, otherwise None
        """
        # Self-attention
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))

        # Early exit prediction
        exit_logits = None
        if self.has_exit:
            # Use mean pooling for sequence classification.
            # NOTE: the mean runs over every position, padded ones included.
            pooled = x.mean(dim=1)  # [batch_size, d_model]
            exit_logits = self.exit_classifier(pooled)

        return x, exit_logits

2. Complete Matryoshka Transformer Model

class MatryoshkaTransformer(nn.Module):
    """
    Complete Matryoshka Transformer with multiple exit points.

    Token and learned position embeddings feed a stack of
    ``MatryoshkaTransformerBlock`` layers. Blocks whose 0-based index is in
    ``exit_layers`` carry an early-exit classifier; a final classifier is
    always applied to the mean-pooled output of the last layer.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Hidden size.
        n_heads: Number of attention heads per block.
        n_layers: Number of transformer blocks.
        d_ff: Inner size of each feed-forward network.
        max_seq_len: Number of learned position embeddings.
        n_classes: Number of output classes.
        dropout: Dropout probability.
        exit_layers: 0-based indices of blocks that get an exit head.
            Defaults to every second layer plus the last one.
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        n_classes: int = 2,
        dropout: float = 0.1,
        exit_layers: Optional[List[int]] = None  # Layers with early exits
    ):
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads  # needed to build the per-head attention mask
        self.n_layers = n_layers

        # Default exit layers (every 2 layers + final)
        if exit_layers is None:
            exit_layers = list(range(1, n_layers, 2)) + [n_layers - 1]
        self.exit_layers = set(exit_layers)

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            MatryoshkaTransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                dropout=dropout,
                has_exit=(i in self.exit_layers),
                n_classes=n_classes
            )
            for i in range(n_layers)
        ])

        # Final classifier (always present)
        self.final_classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, n_classes)
        )

        # Confidence thresholds for early exits.
        # NOTE: forward() does not read this parameter (it uses its
        # ``confidence_threshold`` argument); it is kept so existing
        # checkpoints keep loading.
        self.confidence_thresholds = nn.Parameter(
            torch.full((len(self.exit_layers),), 0.8)
        )

    def _build_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        seq_len: int,
        dtype: torch.dtype
    ) -> Optional[torch.Tensor]:
        """Turn a [batch, seq] 1/0 padding mask into an additive attn_mask."""
        if attention_mask is None:
            return None
        # 0 where a key may be attended to, -10000 where it is padding.
        additive = (1.0 - attention_mask.to(dtype)) * -10000.0  # [B, S]
        # nn.MultiheadAttention reads a 2-D attn_mask as (L, S), so a
        # per-sample mask has to be given as (B * n_heads, L, S). Heads of one
        # sample are contiguous, hence repeat_interleave rather than repeat.
        additive = additive.unsqueeze(1).expand(-1, seq_len, -1)  # [B, S, S]
        return additive.repeat_interleave(self.n_heads, dim=0)  # [B*H, S, S]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_all_exits: bool = False,
        confidence_threshold: float = 0.8,
        max_exit_layer: Optional[int] = None
    ) -> dict:
        """
        Forward pass with adaptive early exiting

        The exit decision is made once per batch, on the mean of the
        per-sample maximum softmax probability.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
                (1 = real token, 0 = padding)
            return_all_exits: Whether to return predictions from all exit
                points; when True no early exit is taken
            confidence_threshold: Minimum confidence for early exit
            max_exit_layer: Maximum layer to exit at (for budget constraints).
                The deepest exit head at or before this layer forces an exit
                even if the confidence threshold was not reached. If no exit
                head lies within the budget, the full model is run.

        Returns:
            Dictionary containing predictions and exit information:
            'logits' (the exit head's logits when an exit was taken, otherwise
            the final classifier's), 'exit_predictions', 'exit_confidences',
            'exit_layer' (None without an exit) and 'total_layers_used'.
        """
        _, seq_len = input_ids.shape

        # Embeddings
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)

        # Prepare attention mask
        attn_mask = self._build_attention_mask(attention_mask, seq_len, x.dtype)

        # Deepest exit head that still fits the budget: exiting there is
        # mandatory, otherwise a low-confidence batch would run the full model
        # and ignore the budget altogether.
        forced_exit_layer = None
        if max_exit_layer is not None and not return_all_exits:
            in_budget = [
                layer for layer in self.exit_layers
                if 0 <= layer <= max_exit_layer and layer < self.n_layers
            ]
            if in_budget:
                forced_exit_layer = max(in_budget)

        # Track exits
        exit_predictions = []
        exit_confidences = []
        exit_layer = None
        early_logits = None

        # Forward through transformer blocks
        for i, block in enumerate(self.blocks):
            x, exit_logits = block(x, attn_mask)

            if exit_logits is None:
                continue

            exit_probs = F.softmax(exit_logits, dim=-1)
            max_confidence = torch.max(exit_probs, dim=-1)[0]

            exit_predictions.append(exit_logits)
            exit_confidences.append(max_confidence)

            # Early exit decision
            if return_all_exits:
                continue
            if max_exit_layer is not None and i > max_exit_layer:
                continue
            if (
                i == forced_exit_layer
                or torch.mean(max_confidence) >= confidence_threshold
            ):
                exit_layer = i
                early_logits = exit_logits
                break

        # Final prediction: report the head that justified the exit. The final
        # classifier is only meaningful on the output of the last layer.
        if early_logits is not None:
            final_output = early_logits
        else:
            final_output = self.final_classifier(x.mean(dim=1))

        return {
            'logits': final_output,
            'exit_predictions': exit_predictions,
            'exit_confidences': exit_confidences,
            'exit_layer': exit_layer,
            'total_layers_used': (exit_layer + 1) if exit_layer is not None else self.n_layers
        }

3. Training Strategy

class MatryoshkaTrainer:
    """
    Training strategy for Matryoshka Transformers.

    Combines the final-layer cross-entropy with a weighted cross-entropy and
    an optional distillation term for every exit head.

    Args:
        model: The model to train.
        exit_loss_weights: One loss weight per exit head, shallowest first.
            Defaults to ``[0.3, 0.3, 1.0]``. If the model produces more exit
            predictions than there are weights, the last weight is reused for
            the remaining exits so that no exit head is left untrained.
        distillation_weight: Scale of the distillation term; 0 disables it.
    """
    def __init__(
        self,
        model: "MatryoshkaTransformer",
        exit_loss_weights: Optional[List[float]] = None,
        distillation_weight: float = 0.5
    ):
        self.model = model
        self.exit_loss_weights = exit_loss_weights or [0.3, 0.3, 1.0]  # Increasing weights
        self.distillation_weight = distillation_weight

    def _weights_for(self, n_exits: int) -> List[float]:
        """Return exactly ``n_exits`` weights, padding with the last one."""
        weights = list(self.exit_loss_weights)
        if not weights:
            return [1.0] * n_exits
        if len(weights) < n_exits:
            weights.extend([weights[-1]] * (n_exits - len(weights)))
        return weights[:n_exits]

    def compute_loss(
        self,
        outputs: dict,
        labels: torch.Tensor,
        temperature: float = 3.0
    ) -> dict:
        """
        Compute combined loss from all exit points

        Args:
            outputs: Model output dict with 'logits' and 'exit_predictions'.
            labels: Class indices, one per sample.
            temperature: Softmax temperature for the distillation term.

        Returns:
            Dict with 'final', 'exit_{i}', 'distill_{i}' (when distillation is
            enabled) and 'total' loss tensors.
        """
        losses = {}

        # Final layer loss
        final_loss = F.cross_entropy(outputs['logits'], labels)
        losses['final'] = final_loss
        total_loss = final_loss

        exit_predictions = outputs['exit_predictions']
        if exit_predictions:
            use_distillation = self.distillation_weight > 0
            if use_distillation:
                # The final head is the teacher: detach it so the KL term
                # only trains the exit heads instead of dragging the final
                # prediction toward the weaker early exits.
                teacher_probs = F.softmax(
                    outputs['logits'].detach() / temperature, dim=-1
                )

            # Early exit losses
            weights = self._weights_for(len(exit_predictions))
            for i, (exit_logits, weight) in enumerate(zip(exit_predictions, weights)):
                # Classification loss
                exit_loss = F.cross_entropy(exit_logits, labels)
                losses[f'exit_{i}'] = exit_loss
                total_loss = total_loss + weight * exit_loss

                # Knowledge distillation from final layer
                if use_distillation:
                    # temperature ** 2 keeps the gradient scale comparable
                    # across temperatures.
                    distill_loss = F.kl_div(
                        F.log_softmax(exit_logits / temperature, dim=-1),
                        teacher_probs,
                        reduction='batchmean'
                    ) * (temperature ** 2)

                    losses[f'distill_{i}'] = distill_loss
                    total_loss = total_loss + self.distillation_weight * weight * distill_loss

        losses['total'] = total_loss
        return losses

    def train_step(
        self,
        batch: dict,
        optimizer: torch.optim.Optimizer
    ) -> dict:
        """
        Single training step

        Args:
            batch: Dict with 'input_ids', 'labels' and optionally
                'attention_mask'.
            optimizer: Optimizer over the model parameters.

        Returns:
            The losses from ``compute_loss`` converted to Python floats.
        """
        self.model.train()
        optimizer.zero_grad()

        # Forward pass: every exit head must produce a prediction, so early
        # exiting is disabled via return_all_exits=True.
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch.get('attention_mask'),
            return_all_exits=True
        )

        # Compute loss
        losses = self.compute_loss(outputs, batch['labels'])

        # Backward pass
        losses['total'].backward()
        optimizer.step()

        return {k: v.item() for k, v in losses.items()}

4. Inference with Adaptive Computation

class AdaptiveInference:
    """
    Adaptive inference with configurable exit strategies.

    Wraps a model and runs it in evaluation mode without gradient tracking,
    restoring the model's previous train/eval mode afterwards.
    """
    def __init__(self, model: "MatryoshkaTransformer"):
        self.model = model

    def predict_with_budget(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        flop_budget: float = 1.0,  # Fraction of full model FLOPs
        confidence_threshold: float = 0.8
    ) -> dict:
        """
        Predict with computational budget constraint

        Args:
            input_ids: Input token IDs.
            attention_mask: Optional attention mask passed to the model.
            flop_budget: Fraction of the model's layers that may be used.
                Values outside [1 / n_layers, 1] are clamped so that at least
                one and at most all layers are allowed.
            confidence_threshold: Minimum confidence for an early exit.

        Returns:
            The model's output dict plus 'flops_used' (fraction of layers
            actually executed; exit-head cost is not counted) and
            'computational_savings' (its complement).
        """
        n_layers = self.model.n_layers

        # The small epsilon guards against float products such as
        # 100 * 0.29 == 28.999999999999996 losing a whole layer under int().
        allowed_layers = int(n_layers * flop_budget + 1e-9)
        # Clamp: a budget below one layer or above the full model would
        # otherwise yield an index no layer can match.
        max_layer = min(max(allowed_layers - 1, 0), n_layers - 1)

        # Run in eval mode without autograd so dropout is off and no graph is
        # built, then put the model back into whatever mode it was in.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    confidence_threshold=confidence_threshold,
                    max_exit_layer=max_layer
                )
        finally:
            self.model.train(was_training)

        # Calculate actual computation used
        layers_used = outputs['total_layers_used']
        actual_budget = layers_used / n_layers

        return {
            **outputs,
            'computational_savings': 1.0 - actual_budget,
            'flops_used': actual_budget
        }

    def predict_with_latency_constraint(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_latency_ms: float = 100.0
    ) -> dict:
        """
        Predict with latency constraint (simplified)

        Converts the latency limit into a layer budget using a fixed
        per-layer estimate and delegates to ``predict_with_budget``.
        """
        # This is a simplified version - in practice, you'd profile
        # actual inference times for different exit points

        estimated_time_per_layer = 10.0  # ms
        max_layers = int(max_latency_ms / estimated_time_per_layer)

        return self.predict_with_budget(
            input_ids=input_ids,
            attention_mask=attention_mask,
            flop_budget=max_layers / self.model.n_layers
        )

5. Usage Example

# Initialize model
model = MatryoshkaTransformer(
    vocab_size=30000,
    d_model=512,
    n_heads=8,
    n_layers=12,
    n_classes=2,
    exit_layers=[2, 5, 8, 11]  # Exit points (0-based layer indices)
)

# Training setup: pass one loss weight per exit point. The trainer's default
# list has only three entries, which would leave the fourth exit untrained.
trainer = MatryoshkaTrainer(model, exit_loss_weights=[0.3, 0.3, 0.6, 1.0])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop (simplified)
# `dataloader` is a placeholder: it must yield dicts with 'input_ids',
# 'attention_mask' and 'labels'.
for batch in dataloader:
    losses = trainer.train_step(batch, optimizer)
    print(f"Total loss: {losses['total']:.4f}")

# Inference: train_step() leaves the model in training mode, so switch to
# eval mode (disables dropout) and skip gradient tracking.
model.eval()
inference_engine = AdaptiveInference(model)

# Example: Predict with 50% computational budget
# `sample_input` is a placeholder for a [batch_size, seq_len] tensor of token IDs.
with torch.no_grad():
    result = inference_engine.predict_with_budget(
        input_ids=sample_input,
        flop_budget=0.5,
        confidence_threshold=0.85
    )

print(f"Prediction: {result['logits'].argmax(-1)}")
print(f"Computational savings: {result['computational_savings']:.2%}")
print(f"Exited at layer: {result['exit_layer']}")

Advanced Features

1. Dynamic Confidence Thresholds

class DynamicThresholdStrategy:
    """
    Dynamically adjust confidence thresholds based on input characteristics.

    The threshold is ``base_threshold`` scaled by a length factor (lower for
    longer sequences) and a layer factor (higher for earlier layers), then
    clamped to the valid confidence range [0, 1].

    Args:
        base_threshold: Threshold at the reference length and reference layer.
        reference_length: Sequence length at which the length factor is 1.
        length_scale: Number of extra tokens that lowers the length factor by 1.
        reference_layer: Layer index at which the layer factor is 1.
        layer_step: Increase of the layer factor per layer below the reference.
    """
    def __init__(
        self,
        base_threshold: float = 0.8,
        reference_length: int = 50,
        length_scale: float = 500.0,
        reference_layer: int = 6,
        layer_step: float = 0.1
    ):
        self.base_threshold = base_threshold
        self.reference_length = reference_length
        self.length_scale = length_scale
        self.reference_layer = reference_layer
        self.layer_step = layer_step

    def get_threshold(self, input_ids: torch.Tensor, layer: int) -> float:
        """
        Compute dynamic threshold based on input and layer

        Args:
            input_ids: Tensor whose second dimension is the sequence length.
            layer: 0-based index of the layer whose exit is being considered.

        Returns:
            Threshold in [0, 1]. A value of 1.0 effectively disables the exit,
            0.0 always allows it.
        """
        # Example: Lower threshold for longer sequences
        seq_len = input_ids.shape[1]
        length_factor = 1.0 - (seq_len - self.reference_length) / self.length_scale

        # Example: Higher threshold for earlier layers
        layer_factor = 1.0 + self.layer_step * (self.reference_layer - layer)  # Stricter for early exits

        # Each factor is floored at zero on its own: two negative factors
        # (very long input at a very deep layer) would otherwise multiply
        # into a plausible-looking positive threshold.
        length_factor = max(length_factor, 0.0)
        layer_factor = max(layer_factor, 0.0)

        threshold = self.base_threshold * length_factor * layer_factor
        # A confidence is a probability, so keep the threshold inside [0, 1].
        return min(max(threshold, 0.0), 1.0)

2. Ensemble Early Exits

class EnsembleMatryoshka(nn.Module):
    """
    Blend the logits of every exit head and the final classifier.

    One learnable score is kept per exit layer of the base model plus one for
    the final classifier; the scores are softmax-normalised and used as
    mixing weights over the corresponding logits.
    """
    def __init__(self, base_model: MatryoshkaTransformer):
        super().__init__()
        self.base_model = base_model
        n_predictors = len(base_model.exit_layers) + 1  # exit heads + final head
        self.ensemble_weights = nn.Parameter(torch.ones(n_predictors))

    def forward(self, input_ids: torch.Tensor, **kwargs) -> dict:
        """
        Run the base model with all exits enabled and mix their logits.

        Returns:
            The base model's output dict with an extra 'ensemble_logits' entry.
        """
        outputs = self.base_model(input_ids, return_all_exits=True, **kwargs)

        # Exit heads first (shallowest to deepest), final classifier last.
        candidates = [*outputs['exit_predictions'], outputs['logits']]
        mixing = F.softmax(self.ensemble_weights, dim=0)

        # Weighted sum; zip stops at the shorter of the two sequences.
        blended = 0
        for share, logits in zip(mixing, candidates):
            blended = blended + share * logits

        result = dict(outputs)
        result['ensemble_logits'] = blended
        return result

Performance Optimization Tips

  1. Layer Selection: Choose exit layers strategically - too many exits can hurt training
  2. Loss Weighting: Start with lower weights for early exits, increase gradually
  3. Confidence Calibration: Use temperature scaling to calibrate exit confidences
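  4. Batch Processing: Process samples with similar complexity together - this implementation makes one exit decision per batch (based on the mean confidence), so a single hard sample can hold back an entire batch of easy ones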
  5. Caching: Cache intermediate representations for multiple exit strategies

Conclusion

Matryoshka Transformers offer a powerful way to build efficient models that can adapt their computational cost at inference time. The key to success is careful tuning of exit strategies, loss weights, and confidence thresholds for your specific use case.

This implementation provides a solid foundation that you can extend with additional features like cascaded exits, uncertainty estimation, or task-specific adaptations.