CLIP Code Guide: Complete Implementation and Usage
Table of Contents
- Introduction to CLIP
- Architecture Overview
- Setting Up the Environment
- Basic CLIP Usage
- Custom CLIP Implementation
- Training CLIP from Scratch
- Fine-tuning CLIP
- Advanced Applications
- Performance Optimization
- Common Issues and Solutions
- Conclusion
Introduction to CLIP
CLIP (Contrastive Language-Image Pre-training) is a neural network architecture developed by OpenAI that learns visual concepts from natural language supervision. It can understand images in the context of natural language descriptions, enabling zero-shot classification and multimodal understanding.
Key Features:
- Zero-shot image classification
- Text-image similarity computation
- Multimodal embeddings
- Transfer learning capabilities
Architecture Overview
CLIP consists of two main components:
1. Text Encoder: Processes text descriptions (typically a Transformer)
2. Image Encoder: Processes images (typically a Vision Transformer or ResNet)
The model learns to maximize the cosine similarity between corresponding text-image pairs while minimizing it for non-corresponding pairs.
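As a rough sketch of that objective: for a batch of N image-text pairs, the model builds an N×N similarity matrix whose diagonal holds the matched pairs, and training pulls the diagonal up while pushing everything else down. The snippet below illustrates this with random tensors standing in for real encoder outputs.

import torch
import torch.nn.functional as F

# Toy batch: 4 matched image-text pairs, 512-dim embeddings (random stand-ins)
image_embeds = F.normalize(torch.randn(4, 512), dim=-1)
text_embeds = F.normalize(torch.randn(4, 512), dim=-1)

# 4 x 4 cosine-similarity matrix; entry [i, j] compares image i with text j
similarity = image_embeds @ text_embeds.T

# Matched pairs sit on the diagonal, so index i is the "correct class" for row/column i
targets = torch.arange(4)
loss = (F.cross_entropy(similarity, targets) +
        F.cross_entropy(similarity.T, targets)) / 2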
Setting Up the Environment
Installing Dependencies
# Basic installation
pip install torch torchvision transformers
pip install ftfy regex tqdm                         # CLIP tokenizer dependencies
pip install git+https://github.com/openai/CLIP.git  # Official OpenAI CLIP
pip install open-clip-torch # OpenCLIP (more models)
# For development and training
pip install wandb datasets accelerate
pip install matplotlib pillow requests
Alternative Installation
# Install from source
git clone https://github.com/openai/CLIP.git
cd CLIP
pip install -e .
Basic CLIP Usage
1. Loading Pre-trained CLIP Model
import clip
import torch
from PIL import Image
import requests
from io import BytesIO
# Load model and preprocessing
= "cuda" if torch.cuda.is_available() else "cpu"
device = clip.load("ViT-B/32", device=device)
model, preprocess
# Available models: ViT-B/32, ViT-B/16, ViT-L/14, RN50, RN101, RN50x4, etc.
print(f"Available models: {clip.available_models()}")
2. Image Classification (Zero-shot)
def zero_shot_classification(image_path, text_options):
    # Load and preprocess image
    image = Image.open(image_path)
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Tokenize text options
    text_inputs = clip.tokenize(text_options).to(device)

    # Get predictions
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)

        # Calculate similarities
        similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Get results
    values, indices = similarities[0].topk(len(text_options))

    results = []
    for value, index in zip(values, indices):
        results.append({
            'label': text_options[index],
            'confidence': value.item()
        })
    return results

# Example usage
text_options = ["a dog", "a cat", "a car", "a bird", "a house"]
results = zero_shot_classification("path/to/image.jpg", text_options)

for result in results:
    print(f"{result['label']}: {result['confidence']:.2%}")
3. Text-Image Similarity
def compute_similarity(image_path, text_description):
    # Load image
    image = Image.open(image_path)
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Tokenize text
    text_input = clip.tokenize([text_description]).to(device)

    # Get features
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_input)

    # Normalize features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Compute similarity
    similarity = (image_features @ text_features.T).item()

    return similarity

# Example usage
similarity = compute_similarity("dog.jpg", "a golden retriever sitting in grass")
print(f"Similarity: {similarity:.4f}")
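The requests and BytesIO imports from the loading section are useful when the image lives at a URL rather than on disk; a small sketch (the URL is a placeholder):

def compute_similarity_from_url(image_url, text_description):
    # Download the image into memory and decode it with PIL
    response = requests.get(image_url, timeout=10)
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image_input = preprocess(image).unsqueeze(0).to(device)

    text_input = clip.tokenize([text_description]).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_input)

    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return (image_features @ text_features.T).item()

# similarity = compute_similarity_from_url("https://example.com/dog.jpg", "a golden retriever")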
Custom CLIP Implementation
Basic CLIP Architecture
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Tokenizer
import timm

class CLIPModel(nn.Module):
    def __init__(self,
                 image_encoder_name='resnet50',
                 text_encoder_name='gpt2',
                 embed_dim=512,
                 image_resolution=224,
                 vocab_size=49408):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_resolution = image_resolution

        # Image encoder (timm backbone without a classification head)
        self.visual = timm.create_model(image_encoder_name, pretrained=True, num_classes=0)
        visual_dim = self.visual.num_features

        # Text encoder
        self.text_encoder = GPT2Model.from_pretrained(text_encoder_name)
        text_dim = self.text_encoder.config.n_embd

        # Projection layers
        self.visual_projection = nn.Linear(visual_dim, embed_dim, bias=False)
        self.text_projection = nn.Linear(text_dim, embed_dim, bias=False)

        # Learnable temperature parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        # Initialize projection layers
        nn.init.normal_(self.visual_projection.weight, std=0.02)
        nn.init.normal_(self.text_projection.weight, std=0.02)

    def encode_image(self, image):
        # Extract visual features
        visual_features = self.visual(image)
        # Project to common embedding space
        image_features = self.visual_projection(visual_features)
        # Normalize
        image_features = F.normalize(image_features, dim=-1)
        return image_features

    def encode_text(self, text):
        # Get hidden states from the text encoder
        text_outputs = self.text_encoder(text)
        # Use the last token's representation (a simplification; OpenAI's CLIP
        # pools at the end-of-text token position instead)
        text_features = text_outputs.last_hidden_state[:, -1, :]
        # Project to common embedding space
        text_features = self.text_projection(text_features)
        # Normalize
        text_features = F.normalize(text_features, dim=-1)
        return text_features

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # Compute logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
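Before plugging this model into a training loop, a quick forward pass on dummy inputs is a cheap way to catch shape mismatches. A minimal sketch (the token IDs are random integers below GPT-2's vocabulary size of 50,257):

toy_model = CLIPModel()
dummy_images = torch.randn(2, 3, 224, 224)        # batch of 2 RGB images
dummy_tokens = torch.randint(0, 50257, (2, 77))   # 77-token sequences of random IDs

logits_per_image, logits_per_text = toy_model(dummy_images, dummy_tokens)
print(logits_per_image.shape)  # expected: torch.Size([2, 2])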
CLIP Loss Function
def clip_loss(logits_per_image, logits_per_text):
    """
    Contrastive loss for CLIP training
    """
    batch_size = logits_per_image.shape[0]
    labels = torch.arange(batch_size, device=logits_per_image.device)

    # Cross-entropy loss for both directions
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)

    # Average the losses
    loss = (loss_i + loss_t) / 2
    return loss
Training CLIP from Scratch
Dataset Preparation
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class ImageTextDataset(Dataset):
    def __init__(self, data_path, image_dir, transform=None, tokenizer=None, max_length=77):
        with open(data_path, 'r') as f:
            self.data = json.load(f)
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load image
        image_path = os.path.join(self.image_dir, item['image'])
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Tokenize text
        text = item['caption']
        if self.tokenizer:
            text_tokens = self.tokenizer(
                text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )['input_ids'].squeeze(0)
        else:
            # Placeholder tokenization for demonstration only
            text_tokens = torch.zeros(self.max_length, dtype=torch.long)

        return image, text_tokens, text
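This class assumes data_path points to a JSON list of records with "image" and "caption" keys. A hypothetical example of the expected layout and one way to build the DataLoader (file and directory names are placeholders; note that GPT-2's tokenizer needs a pad token assigned before padding='max_length' will work):

# annotations.json (hypothetical layout):
# [
#   {"image": "000001.jpg", "caption": "a dog running on the beach"},
#   {"image": "000002.jpg", "caption": "a red car parked on a street"}
# ]

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

dataset = ImageTextDataset(
    data_path="annotations.json",
    image_dir="images/",
    transform=preprocess,   # CLIP preprocessing loaded earlier
    tokenizer=tokenizer,
)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, drop_last=True)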
Training Loop
def train_clip(model, dataloader, optimizer, scheduler, device, num_epochs):
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0

        for batch_idx, (images, text_tokens, _) in enumerate(dataloader):
            images = images.to(device)
            text_tokens = text_tokens.to(device)

            # Forward pass
            logits_per_image, logits_per_text = model(images, text_tokens)

            # Compute loss
            loss = clip_loss(logits_per_image, logits_per_text)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            # Logging
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Update learning rate
        scheduler.step()

        avg_loss = total_loss / num_batches
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')

# Training setup
num_epochs = 100
model = CLIPModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Start training
train_clip(model, dataloader, optimizer, scheduler, device, num_epochs=num_epochs)
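Longer runs usually warrant periodic checkpoints so training can resume after an interruption; a minimal sketch (the file name is arbitrary):

def save_checkpoint(model, optimizer, scheduler, epoch, path="clip_checkpoint.pt"):
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, scheduler, path="clip_checkpoint.pt"):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["epoch"]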
Fine-tuning CLIP
Domain-Specific Fine-tuning
def fine_tune_clip(pretrained_model, dataloader, num_epochs=10, lr=1e-5):
    # Freeze the encoders; only fine-tune the projection layers
    for param in pretrained_model.visual.parameters():
        param.requires_grad = False

    for param in pretrained_model.text_encoder.parameters():
        param.requires_grad = False

    # Only train projection layers (and the temperature)
    optimizer = torch.optim.Adam([
        {'params': pretrained_model.visual_projection.parameters()},
        {'params': pretrained_model.text_projection.parameters()},
        {'params': [pretrained_model.logit_scale]}
    ], lr=lr)

    pretrained_model.train()

    for epoch in range(num_epochs):
        for batch_idx, (images, text_tokens, _) in enumerate(dataloader):
            images = images.to(device)
            text_tokens = text_tokens.to(device)

            logits_per_image, logits_per_text = pretrained_model(images, text_tokens)
            loss = clip_loss(logits_per_image, logits_per_text)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 50 == 0:
                print(f'Fine-tune Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
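When the temperature stays trainable (as it does here), it is worth clamping logit_scale after each update; the CLIP paper caps the scaling factor at 100 to keep training stable. A small sketch of the clamp, to be placed right after optimizer.step() inside the batch loop:

import math

# Right after optimizer.step(): cap exp(logit_scale) at 100
with torch.no_grad():
    pretrained_model.logit_scale.clamp_(0, math.log(100))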
Advanced Applications
1. Image Search with CLIP
class CLIPImageSearch:
    def __init__(self, model, preprocess):
        self.model = model
        self.preprocess = preprocess
        self.image_features = None
        self.image_paths = None

    def index_images(self, image_paths):
        """Pre-compute features for all images"""
        self.image_paths = image_paths
        features = []

        for img_path in image_paths:
            image = Image.open(img_path)
            image_input = self.preprocess(image).unsqueeze(0).to(device)

            with torch.no_grad():
                image_feature = self.model.encode_image(image_input)

            features.append(image_feature)

        self.image_features = torch.cat(features, dim=0)
        self.image_features = self.image_features / self.image_features.norm(dim=-1, keepdim=True)

    def search(self, query_text, top_k=5):
        """Search for images matching the text query"""
        text_input = clip.tokenize([query_text]).to(device)

        with torch.no_grad():
            text_features = self.model.encode_text(text_input)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Compute similarities
        similarities = (text_features @ self.image_features.T).squeeze(0)

        # Get top-k results
        top_similarities, top_indices = similarities.topk(top_k)

        results = []
        for sim, idx in zip(top_similarities, top_indices):
            results.append({
                'path': self.image_paths[idx],
                'similarity': sim.item()
            })
        return results

# Usage example
search_engine = CLIPImageSearch(model, preprocess)
search_engine.index_images(list_of_image_paths)
results = search_engine.search("a red sports car", top_k=10)
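Re-encoding every image on each run is wasteful, so the precomputed index is worth persisting; a minimal sketch with torch.save (the file name is a placeholder, and list_of_image_paths stands for your own collection):

# Persist the index to disk
torch.save({
    "features": search_engine.image_features.cpu(),
    "paths": search_engine.image_paths,
}, "clip_image_index.pt")

# Restore it later without re-encoding anything
index = torch.load("clip_image_index.pt")
search_engine.image_features = index["features"].to(device)
search_engine.image_paths = index["paths"]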
2. Content-Based Image Clustering
from sklearn.cluster import KMeans
import numpy as np
def cluster_images_by_content(image_paths, n_clusters=5):
    # Extract features for all images
    features = []

    for img_path in image_paths:
        image = Image.open(img_path)
        image_input = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            feature = model.encode_image(image_input)

        features.append(feature.cpu().numpy())

    # Convert to numpy array
    features = np.vstack(features)

    # Perform clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(features)

    # Organize results
    clusters = {}
    for i, label in enumerate(cluster_labels):
        if label not in clusters:
            clusters[label] = []
        clusters[label].append(image_paths[i])

    return clusters
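A quick way to inspect the result, assuming the same list_of_image_paths used in the search example:

clusters = cluster_images_by_content(list_of_image_paths, n_clusters=5)
for label, paths in sorted(clusters.items()):
    print(f"Cluster {label}: {len(paths)} images, e.g. {paths[0]}")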
3. Visual Question Answering
def visual_qa(image_path, question, answer_choices):
    """Simple VQA using CLIP"""
    image = Image.open(image_path)
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Create prompts combining the question with each answer
    prompts = [f"Question: {question} Answer: {choice}" for choice in answer_choices]
    text_inputs = clip.tokenize(prompts).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)

        # Compute similarities
        similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Return the most likely answer
    best_idx = similarities.argmax().item()
    return answer_choices[best_idx], similarities[0][best_idx].item()

# Example usage
answer, confidence = visual_qa(
    "image.jpg",
    "What color is the car?",
    ["red", "blue", "green", "yellow", "black"]
)
print(f"Answer: {answer}, Confidence: {confidence:.2%}")
Performance Optimization
1. Batch Processing
def batch_encode_images(image_paths, batch_size=32):
    """Process images in batches for better efficiency"""
    all_features = []

    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i+batch_size]
        batch_images = []

        for path in batch_paths:
            image = Image.open(path)
            image_input = preprocess(image)
            batch_images.append(image_input)

        batch_tensor = torch.stack(batch_images).to(device)

        with torch.no_grad():
            batch_features = model.encode_image(batch_tensor)

        all_features.append(batch_features.cpu())

    return torch.cat(all_features, dim=0)
2. Mixed Precision Training
from torch.cuda.amp import autocast, GradScaler
def train_with_mixed_precision(model, dataloader, optimizer, num_epochs):
    scaler = GradScaler()

    model.train()
    for epoch in range(num_epochs):
        for images, text_tokens, _ in dataloader:
            images = images.to(device)
            text_tokens = text_tokens.to(device)

            optimizer.zero_grad()

            # Forward pass with autocast
            with autocast():
                logits_per_image, logits_per_text = model(images, text_tokens)
                loss = clip_loss(logits_per_image, logits_per_text)

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
3. Model Quantization
import torch.quantization as quantization
def quantize_clip_model(model):
    """Quantize CLIP model for inference"""
    model.eval()

    # Specify quantization configuration
    model.qconfig = quantization.get_default_qconfig('fbgemm')

    # Prepare model for quantization
    model_prepared = quantization.prepare(model, inplace=False)

    # Calibrate with sample data (you need to provide calibration data)
    # ... calibration code here ...

    # Convert to quantized model
    model_quantized = quantization.convert(model_prepared, inplace=False)

    return model_quantized
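Static quantization as sketched above also needs quant/dequant stubs and a calibration pass, which is awkward for a model this heterogeneous. Dynamic quantization of the linear layers is a simpler alternative worth trying first; a minimal sketch (the int8 path helps CPU inference only):

import torch
import torch.nn as nn

def dynamic_quantize_clip(model):
    """Quantize linear layers to int8 on the fly for CPU inference."""
    model = model.cpu().eval()
    return torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},        # only the linear layers are quantized
        dtype=torch.qint8,
    )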
Common Issues and Solutions
1. Memory Management
# Clear GPU cache
torch.cuda.empty_cache()
# Use gradient checkpointing for large models
def enable_gradient_checkpointing(model):
    if hasattr(model.visual, 'set_grad_checkpointing'):
        model.visual.set_grad_checkpointing(True)
    if hasattr(model.text_encoder, 'gradient_checkpointing_enable'):
        model.text_encoder.gradient_checkpointing_enable()
2. Handling Different Image Sizes
from torchvision import transforms
def create_adaptive_transform(target_size=224):
    return transforms.Compose([
        transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        # Use CLIP's own normalization statistics rather than the ImageNet defaults
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711]
        )
    ])
3. Text Preprocessing
import re
def preprocess_text(text, max_length=77):
    """Clean and preprocess text for CLIP"""
    # Remove special characters and extra whitespace
    text = re.sub(r'[^\w\s]', '', text)
    text = ' '.join(text.split())

    # Truncate if too long (word-level heuristic; CLIP's real limit is 77 tokens)
    words = text.split()
    if len(words) > max_length - 2:  # Account for special tokens
        text = ' '.join(words[:max_length - 2])

    return text
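If the only concern is captions that exceed the context window, clip.tokenize can also truncate at the token level via its truncate flag, which cuts sequences down to the 77-token limit:

# Token-level truncation handled by the tokenizer itself
tokens = clip.tokenize(["a very long caption " * 30], truncate=True)
print(tokens.shape)  # torch.Size([1, 77])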
4. Model Evaluation Utilities
def evaluate_zero_shot_accuracy(model, preprocess, test_loader, class_names):
    """Evaluate zero-shot classification accuracy"""
    model.eval()
    correct = 0
    total = 0

    # Encode class names
    text_inputs = clip.tokenize([f"a photo of a {name}" for name in class_names]).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for images, labels in test_loader:
            images = images.to(device)

            # Encode images
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Compute similarities
            similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            predictions = similarities.argmax(dim=-1)

            correct += (predictions == labels.to(device)).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    return accuracy

# Usage
accuracy = evaluate_zero_shot_accuracy(model, preprocess, test_loader, class_names)
print(f"Zero-shot accuracy: {accuracy:.2%}")
Conclusion
This guide covers the essential aspects of working with CLIP, from basic usage to advanced implementations. Key takeaways:
- Start Simple: Use pre-trained models for most applications
- Understand the Architecture: CLIP’s power comes from joint text-image training
- Optimize for Your Use Case: Fine-tune or customize based on your specific needs
- Monitor Performance: Use proper evaluation metrics and optimization techniques
- Handle Edge Cases: Implement robust preprocessing and error handling
For production deployments, consider:
- Model quantization for faster inference
- Batch processing for efficiency
- Proper error handling and fallbacks
- Monitoring and logging for performance tracking
The field of multimodal AI is rapidly evolving, so stay updated with the latest research and implementations to leverage CLIP’s full potential in your applications.