%%{init: {
"theme": "base"
}}%%
flowchart LR
subgraph TRAIN["Training Frameworks"]
direction TB
PT["PyTorch"]
TF["TensorFlow / Keras"]
SK["scikit-learn"]
JAX["JAX / Flax"]
end
subgraph ONNX_CORE["ONNX Ecosystem"]
direction TB
MODEL["ONNX Model (.onnx)"]
OPT["Optimizer / Quantizer"]
MODEL -- "graph opt + quantize" --> OPT
OPT -- "optimized model" --> MODEL
end
subgraph DEPLOY["Deployment Targets"]
direction TB
ORT["ONNX Runtime Python · C++ · C# · Java"]
WEB["ONNX Runtime Web WASM · WebGL"]
TRT["TensorRT (NVIDIA)"]
OVI["OpenVINO (Intel)"]
CML["CoreML (Apple)"]
WML["Windows ML"]
end
PT -- export --> MODEL
TF -- export --> MODEL
SK -- export --> MODEL
JAX -- export --> MODEL
MODEL --> ORT
MODEL --> WEB
MODEL --> TRT
MODEL --> OVI
MODEL --> CML
MODEL --> WML
Training Computer Vision Models and Running Them with ONNX Runtime

Introduction
Computer vision is one of the most vibrant areas of applied machine learning. Whether you are building an image classifier, a real-time object detector, a segmentation model, or a pose estimator, the challenge after training is always the same: how do you efficiently deploy the model across diverse hardware—cloud GPUs, edge CPUs, mobile SoCs, FPGAs, or browsers—without rewriting your inference code for every target?
ONNX (Open Neural Network Exchange) and ONNX Runtime solve this problem. ONNX provides a standardized intermediate representation for neural network computation graphs, while ONNX Runtime is a high-performance inference engine that executes those graphs across a wide range of hardware backends.
This guide walks you through the entire lifecycle: training a vision model in PyTorch or TensorFlow, exporting it to the ONNX format, validating and optimizing the exported graph, running production-grade inference with ONNX Runtime, and deploying to various targets. By the end, you will have a reliable, reproducible workflow you can apply to nearly any computer vision project.
What is ONNX?
ONNX is an open standard created jointly by Microsoft and Facebook (Meta) in 2017, now maintained by the Linux Foundation under the ONNX community. Its core purpose is to allow models trained in one framework to be run in another.
At its heart, an ONNX model is a protobuf-serialized computation graph. Each node in the graph corresponds to a mathematical operator (Conv, BatchNormalization, Relu, MaxPool, Gemm, etc.), and edges are typed tensors that flow between them.
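To make the "graph of operators connected by tensors" idea concrete, here is a minimal sketch of a graph interpreter in plain Python. It is a toy stand-in, not the real ONNX protobuf schema: the tuple layout, the `run_graph` helper, and the three element-wise ops are assumptions chosen for illustration.

```python
# Toy model of a computation graph: nodes are (op_type, input_names, output_name).
# This only mimics the idea behind ONNX; it is not the onnx library's data model.

def run_graph(nodes, initializers, inputs):
    """Execute nodes in order; every edge is a named tensor (here: a list of floats)."""
    ops = {
        "Mul": lambda a, b: [x * y for x, y in zip(a, b)],
        "Add": lambda a, b: [x + y for x, y in zip(a, b)],
        "Relu": lambda a: [max(0.0, x) for x in a],
    }
    tensors = {**initializers, **inputs}  # weights and runtime inputs share one namespace
    for op_type, input_names, output_name in nodes:
        args = [tensors[name] for name in input_names]
        tensors[output_name] = ops[op_type](*args)
    return tensors

nodes = [
    ("Mul", ["x", "w"], "scaled"),
    ("Add", ["scaled", "b"], "shifted"),
    ("Relu", ["shifted"], "y"),
]
initializers = {"w": [2.0, 2.0, 2.0], "b": [1.0, -5.0, 0.0]}  # constant "weights"

result = run_graph(nodes, initializers, {"x": [1.0, 2.0, -3.0]})
print(result["y"])  # [3.0, 0.0, 0.0]
```

A real ONNX file stores the same three ingredients (nodes, initializers, and declared inputs/outputs) in protobuf form, with typed tensors instead of Python lists.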
Key concepts:
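- Opset version: ONNX defines its operators in versioned opsets. As of 2025, opsets 19–21 are recent and widely supported, and newer ones continue to appear. Export with the highest opset that both your exporter and your target runtime support, so you get the newest operator definitions without breaking compatibility.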
- IR version: The overall file format version, independent of opset.
- Initializers: Constant tensors (model weights) stored inside the graph.
- Dynamic shapes: Axes can be marked symbolic (e.g., batch_size, height) to allow variable-size inputs at runtime.
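As a rough illustration of what a symbolic axis means, the sketch below checks a concrete input shape against a declared shape that mixes fixed sizes and named axes. The `shape_matches` helper is hypothetical and only mirrors the idea; a real runtime performs its own shape checks.

```python
def shape_matches(declared, actual):
    """declared mixes ints (fixed axes) and strings (symbolic axes)."""
    if len(declared) != len(actual):
        return False
    bound = {}  # remembers the size each symbolic name was first seen with
    for dim, size in zip(declared, actual):
        if isinstance(dim, str):
            # A symbolic axis accepts any size, but the same name must stay consistent.
            if bound.setdefault(dim, size) != size:
                return False
        elif dim != size:
            return False
    return True

declared = ["batch_size", 3, 32, 32]
print(shape_matches(declared, (1, 3, 32, 32)))   # True
print(shape_matches(declared, (64, 3, 32, 32)))  # True
print(shape_matches(declared, (8, 3, 64, 64)))   # False: height/width are fixed at 32
```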
Prerequisites and Environment Setup
Python Environment
It is best practice to use a virtual environment or conda environment per project.
# Using conda
conda create -n cv-onnx python=3.11
conda activate cv-onnx
# Or using venv
python -m venv .venv
source .venv/bin/activate # Linux / macOS
.venv\Scripts\activate # Windows
Core Packages
# Deep learning framework (choose one or both)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install tensorflow
# ONNX core
pip install onnx onnxscript
# ONNX Runtime — CPU only
pip install onnxruntime
# ONNX Runtime — GPU (CUDA 12.x)
pip install onnxruntime-gpu
# Optimization and quantization tools
pip install onnxruntime-tools
pip install onnxoptimizer
# Visualization
pip install netron # or open https://netron.app in a browser
# Utilities
pip install numpy pillow opencv-python-headless matplotlib
onnxruntime and onnxruntime-gpu are mutually exclusive packages. Install only one per environment. The GPU package automatically falls back to CPU when CUDA is unavailable.
Verifying the Installation
import onnx
import onnxruntime as ort
import torch
print(f"ONNX version: {onnx.__version__}")
print(f"ONNX Runtime version: {ort.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"Available ORT providers: {ort.get_available_providers()}")
Understanding the ONNX Ecosystem
Before diving into code, it helps to understand how the different components fit together. Training frameworks export to the ONNX intermediate representation, which is then consumed by ONNX Runtime or converted into other deployment backends.
The ONNX Runtime (ORT) sits at the center of the deployment story. It supports multiple Execution Providers (EPs):
| Execution Provider | Hardware Target |
|---|---|
| CPUExecutionProvider | Any x86/ARM CPU |
| CUDAExecutionProvider | NVIDIA GPUs (CUDA) |
| TensorrtExecutionProvider | NVIDIA GPUs (TensorRT) |
| ROCMExecutionProvider | AMD GPUs |
| CoreMLExecutionProvider | Apple Silicon / iOS |
| DmlExecutionProvider | Windows GPU via DirectML |
| OpenVINOExecutionProvider | Intel CPUs, iGPUs, VPUs |
| QNNExecutionProvider | Qualcomm NPU |
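Since the set of installed providers differs from machine to machine, a common pattern is to keep a preference list and filter it against what is actually available. The sketch below shows that filtering in plain Python; `pick_providers` and the hard-coded `available` list are illustrative assumptions (in real code the list would come from `onnxruntime.get_available_providers()`).

```python
# Preference order: fastest first, CPU as the last resort.
PREFERRED = [
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]

def pick_providers(preferred, available):
    """Keep the preferred order, drop anything not installed, always end with CPU."""
    chosen = [p for p in preferred if p in available]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen

# Stand-in for onnxruntime.get_available_providers() on a CUDA-only machine.
available = ["CUDAExecutionProvider", "CPUExecutionProvider"]
providers = pick_providers(PREFERRED, available)
print(providers)  # ['CUDAExecutionProvider', 'CPUExecutionProvider']
```

The resulting list can be passed as the `providers` argument when creating an inference session.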
Training a Computer Vision Model
PyTorch Workflow
We will train a simple ResNet-18-based image classifier on CIFAR-10 as a concrete example. The principles generalize to any architecture.
# train_cifar_pytorch.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# ──────────────────────────────────────────────────────────────
# 1. Hyperparameters
# ──────────────────────────────────────────────────────────────
BATCH_SIZE = 128
NUM_EPOCHS = 20
LEARNING_RATE = 1e-3
NUM_CLASSES = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ──────────────────────────────────────────────────────────────
# 2. Data pipeline
# ──────────────────────────────────────────────────────────────
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010]),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010]),
])
train_dataset = datasets.CIFAR10(root="./data", train=True,
                                 download=True, transform=train_transform)
val_dataset = datasets.CIFAR10(root="./data", train=False,
                               download=True, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=4, pin_memory=True)
# ──────────────────────────────────────────────────────────────
# 3. Model definition
# ──────────────────────────────────────────────────────────────
# ResNet-18 adapted for CIFAR-10's 32×32 inputs
model = models.resnet18(weights=None)
# Replace the first conv to handle small spatial dimensions
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity() # remove aggressive spatial downsampling
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(DEVICE)
# ──────────────────────────────────────────────────────────────
# 4. Loss, optimizer, scheduler
# ──────────────────────────────────────────────────────────────
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
# ──────────────────────────────────────────────────────────────
# 5. Training loop
# ──────────────────────────────────────────────────────────────
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(1) == labels).sum().item()
        total += images.size(0)
    return running_loss / total, correct / total
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += images.size(0)
    return running_loss / total, correct / total
best_val_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader,
                                            criterion, optimizer, DEVICE)
    val_loss, val_acc = evaluate(model, val_loader, criterion, DEVICE)
    scheduler.step()
    print(f"Epoch {epoch:02d}/{NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")
    # Keep only the checkpoint with the best validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_resnet18_cifar10.pth")
print(f"Best validation accuracy: {best_val_acc:.4f}")
TensorFlow / Keras Workflow
# train_cifar_tf.py
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
# ──────────────────────────────────────────────────────────────
# 1. Load and preprocess data
# ──────────────────────────────────────────────────────────────
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Normalize to [0, 1]
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Channel-wise normalization with commonly used CIFAR-10 mean/std (same values as the PyTorch example)
mean = tf.constant([0.4914, 0.4822, 0.4465], dtype=tf.float32)
std = tf.constant([0.2023, 0.1994, 0.2010], dtype=tf.float32)
x_train = (x_train - mean) / std
x_test = (x_test - mean) / std
# ──────────────────────────────────────────────────────────────
# 2. Model: EfficientNetB0 with custom head
# ──────────────────────────────────────────────────────────────
base = tf.keras.applications.EfficientNetB0(
    include_top=False,
    weights=None,
    input_shape=(32, 32, 3),
)
inputs = tf.keras.Input(shape=(32, 32, 3))
# Do not hard-code training=True here: Keras passes the right mode during
# fit() and inference, so BatchNorm uses its moving statistics after training.
x = base(inputs)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# ──────────────────────────────────────────────────────────────
# 3. Training
# ──────────────────────────────────────────────────────────────
cb = [
    callbacks.ReduceLROnPlateau(patience=5, factor=0.5, verbose=1),
    callbacks.EarlyStopping(patience=10, restore_best_weights=True),
    callbacks.ModelCheckpoint("best_efficientnet_cifar10.h5",
                              save_best_only=True),
]
model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=50,
    batch_size=128,
    callbacks=cb,
)
Exporting a Trained Model to ONNX
Exporting from PyTorch
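PyTorch has two export pathways: the classic, tracing-based torch.onnx.export and the newer TorchDynamo-based torch.onnx.dynamo_export (available since PyTorch 2.0). The dynamo path handles more complex dynamic models but is still maturing, and its entry point has changed between releases: recent PyTorch versions expose it as torch.onnx.export(..., dynamo=True) and deprecate dynamo_export, so check the documentation for the version you have installed.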
flowchart TD
A["Trained PyTorch Model (.pth weights)"] --> B{"Export Strategy?"}
B --> C["Tracing torch.onnx.export"]
B --> D["Dynamo torch.onnx.dynamo_export (PyTorch ≥ 2.0)"]
C --> E["Standard CNNs ResNet · EfficientNet · YOLO"]
C --> F["Fixed control flow no data-dependent branches"]
D --> G["Transformers ViT · DETR · CLIP"]
D --> H["Dynamic control flow data-dependent branches"]
E --> I["ONNX Model (.onnx)"]
F --> I
G --> I
H --> I
Classic Export (Tracing)
# export_pytorch_to_onnx.py
import torch
from torchvision import models
import torch.nn as nn
import onnx
# ──────────────────────────────────────────────────────────────
# 1. Reconstruct the model and load weights
# ──────────────────────────────────────────────────────────────
DEVICE = torch.device("cpu") # always export from CPU
model = models.resnet18(weights=None)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load("best_resnet18_cifar10.pth", map_location=DEVICE))
# ──────────────────────────────────────────────────────────────
# 2. Set model to evaluation mode — CRITICAL
# This disables dropout and switches BatchNorm to eval statistics.
# ──────────────────────────────────────────────────────────────
model.eval()
# ──────────────────────────────────────────────────────────────
# 3. Create a representative dummy input
# Shape: (batch_size, channels, height, width)
# ──────────────────────────────────────────────────────────────
dummy_input = torch.randn(1, 3, 32, 32, device=DEVICE)
# ──────────────────────────────────────────────────────────────
# 4. Export
# ──────────────────────────────────────────────────────────────
ONNX_PATH = "resnet18_cifar10.onnx"
torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    export_params=True,        # store weights inside the .onnx file
    opset_version=18,          # target opset; 17–19 recommended
    do_constant_folding=True,  # fold constant expressions into weights
    input_names=["images"],    # name the input tensor(s)
    output_names=["logits"],   # name the output tensor(s)
    dynamic_axes={             # mark batch dimension as dynamic
        "images": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
    verbose=False,
)
print(f"Model exported to {ONNX_PATH}")
# ──────────────────────────────────────────────────────────────
# 5. Quick sanity check
# ──────────────────────────────────────────────────────────────
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)
print("ONNX model is valid ✓")
Dynamo-Based Export (PyTorch ≥ 2.0)
import torch
import torch.onnx
# The dynamo exporter captures the full computational graph
# including Python control flow, which tracing cannot.
export_output = torch.onnx.dynamo_export(
    model,
    dummy_input,
)
export_output.save("resnet18_cifar10_dynamo.onnx")
Tracing runs the model once on the example input and records only the operations that actually executed, so data-dependent control flow (e.g., if x.shape[0] > 1:) is frozen to whichever branch that input took. Dynamo (TorchDynamo + FX graph) analyzes the Python code itself and can therefore capture control flow that tracing misses. For standard CNN architectures, tracing is simpler and more mature. For transformer models with dynamic attention patterns, dynamo is preferred.
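The "single execution path" limitation of tracing is easy to reproduce with a toy tracer. The sketch below is not how PyTorch implements tracing; `Traced`, `trace`, and `model` are hypothetical names used only to show why a branch taken during tracing gets baked into the recorded graph.

```python
class Traced:
    """Wraps a concrete number and records arithmetic onto a shared tape."""
    def __init__(self, value, tape):
        self.value = value
        self.tape = tape
    def __mul__(self, k):
        self.tape.append(("mul", k))
        return Traced(self.value * k, self.tape)
    def __sub__(self, k):
        self.tape.append(("sub", k))
        return Traced(self.value - k, self.tape)
    def __gt__(self, k):
        # The comparison yields a plain bool, so the branch decision is NOT recorded.
        return self.value > k

def model(x):
    if x > 0:        # data-dependent control flow
        return x * 2
    return x - 1

def trace(fn, example):
    """Run fn once on the example and return a function that replays the tape."""
    tape = []
    fn(Traced(example, tape))
    def replay(x):
        for op, k in tape:
            x = x * k if op == "mul" else x - k
        return x
    return replay

traced = trace(model, 3.0)        # the example input takes the x > 0 branch
print(model(3.0), traced(3.0))    # 6.0 6.0   (same path, results agree)
print(model(-2.0), traced(-2.0))  # -3.0 -4.0 (traced copy still multiplies by 2)
```

This is why it pays to validate an exported model on inputs that exercise every branch, not just the dummy input used for export.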
Exporting from TensorFlow / Keras
pip install tf2onnx
# export_tf_to_onnx.py
import tensorflow as tf
import tf2onnx
import onnx
# Load the saved Keras model
model = tf.keras.models.load_model("best_efficientnet_cifar10.h5")
# Specify the input signature explicitly for reliable export
input_signature = [
    tf.TensorSpec(shape=[None, 32, 32, 3], dtype=tf.float32, name="images")
]
# Convert to ONNX
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=18,
    output_path="efficientnet_cifar10.onnx",
)
print("TensorFlow model successfully converted to ONNX ✓")
You can also convert from a TensorFlow SavedModel directory:
python -m tf2onnx.convert \
--saved-model ./saved_model_dir \
--output efficientnet_cifar10.onnx \
--opset 18 \
--inputs images:0[-1,32,32,3] \
--outputs softmax:0
Exporting from scikit-learn (sklearn-onnx)
While scikit-learn models are rarely used for deep vision, they appear in feature-based vision pipelines (e.g., HOG + SVM).
pip install skl2onnx
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import skl2onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# Assume `pipeline` is a trained sklearn Pipeline
# with input features of dimension 1764 (HOG features from 32x32 images)
initial_type = [("float_input", FloatTensorType([None, 1764]))]
onnx_model = convert_sklearn(pipeline, initial_types=initial_type,
                             target_opset=18)
with open("hog_svm.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
Validating and Inspecting the ONNX Model
Before deploying, always validate and inspect the exported model. Subtle bugs in export (wrong opset, un-exported operators, shape errors) can silently produce wrong predictions.
flowchart TD
A["Exported .onnx file"] --> B["Structural Validation onnx.checker.check_model"]
B --> C{"Valid?"}
C -- No --> D["Fix export: check opset, custom ops, eval mode"]
D --> A
C -- Yes --> E["Shape Inference onnx.shape_inference.infer_shapes"]
E --> F["Numerical Validation Compare ORT vs source framework"]
F --> G{"Max diff < 1e-4?"}
G -- No --> H["Investigate: NHWC/NCHW mismatch, Dropout not disabled, opset operator gap"]
H --> A
G -- Yes --> I["Visual Inspection Netron"]
I --> J["Model Ready for Optimization"]
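The "Max diff < 1e-4?" gate in the flow above is an absolute tolerance. When output magnitudes vary widely, a combined absolute-plus-relative check is a common alternative. The following is a minimal sketch in plain Python; the helper names and the sample numbers are made up for illustration and the default tolerances are assumptions, not recommendations from ONNX Runtime.

```python
def max_abs_diff(ref, test):
    """Largest element-wise absolute difference between two flat sequences."""
    return max(abs(r - t) for r, t in zip(ref, test))

def outputs_match(ref, test, atol=1e-4, rtol=1e-3):
    """True when every element satisfies |ref - test| <= atol + rtol * |ref|."""
    return all(abs(r - t) <= atol + rtol * abs(r) for r, t in zip(ref, test))

reference = [12.5, -0.031, 3.2e-5]       # made-up source-framework outputs
exported = [12.50004, -0.03102, 3.1e-5]  # made-up exported-model outputs
broken = [12.6, -0.031, 3.2e-5]          # first element is clearly off

diff = max_abs_diff(reference, exported)
ok = outputs_match(reference, exported)
bad = outputs_match(reference, broken)
print(f"max abs diff = {diff:.1e}, match = {ok}, broken match = {bad}")
```

Float16 or quantized models usually need looser tolerances than the FP32 defaults sketched here.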
Structural Validation
import onnx
model = onnx.load("resnet18_cifar10.onnx")
# Full graph validity check (type-checking, shape propagation)
onnx.checker.check_model(model, full_check=True)
# Print a human-readable summary
print(onnx.helper.printable_graph(model.graph))
Inspecting Model Metadata
import onnx
model = onnx.load("resnet18_cifar10.onnx")
print(f"IR version: {model.ir_version}")
print(f"Opset imports: {[op.version for op in model.opset_import]}")
print(f"Graph name: {model.graph.name}")
print("Inputs:")
for inp in model.graph.input:
    # dim_value is set for fixed axes, dim_param for symbolic ones
    shape = [d.dim_value or d.dim_param
             for d in inp.type.tensor_type.shape.dim]
    print(f"  {inp.name}: {shape}")
print("Outputs:")
for out in model.graph.output:
    shape = [d.dim_value or d.dim_param
             for d in out.type.tensor_type.shape.dim]
    print(f"  {out.name}: {shape}")
Shape Inference
ONNX can propagate shapes through the graph without running it:
import onnx
from onnx import shape_inference
model = onnx.load("resnet18_cifar10.onnx")
inferred = shape_inference.infer_shapes(model)
onnx.save(inferred, "resnet18_cifar10_inferred.onnx")
print("Shape inference complete. Intermediate shapes are now annotated.")
Numerical Validation Against the Source Framework
This is the most important validation step—compare ONNX Runtime outputs against the original framework:
import numpy as np
import torch
import onnxruntime as ort
from torchvision import models
import torch.nn as nn
# ── Original PyTorch model ──
DEVICE = torch.device("cpu")
pt_model = models.resnet18(weights=None)
pt_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
pt_model.maxpool = nn.Identity()
pt_model.fc = nn.Linear(pt_model.fc.in_features, 10)
pt_model.load_state_dict(torch.load("best_resnet18_cifar10.pth", map_location=DEVICE))
pt_model.eval()
# ── ONNX Runtime session ──
sess = ort.InferenceSession("resnet18_cifar10.onnx",
providers=["CPUExecutionProvider"])
# ── Generate random test batch ──
np.random.seed(42)
dummy_np = np.random.randn(4, 3, 32, 32).astype(np.float32)
dummy_pt = torch.from_numpy(dummy_np)
# ── Run both ──
with torch.no_grad():
    pt_out = pt_model(dummy_pt).numpy()
ort_out = sess.run(None, {"images": dummy_np})[0]
# ── Compare ──
max_diff = np.abs(pt_out - ort_out).max()
print(f"Max absolute difference: {max_diff:.2e}")
assert max_diff < 1e-4, f"Outputs diverge! Max diff = {max_diff}"
print("Numerical validation passed ✓")
Visual Inspection with Netron
Netron is a graph visualizer for ONNX and other model formats, available as a pip package, a desktop app, and a browser app at https://netron.app. Drag and drop your .onnx file to see the full operator graph, tensor shapes, and weight statistics. It also reads many other model formats (TFLite, CoreML, PyTorch, etc.).
Optimizing the ONNX Model
Graph Optimizations
ONNX Runtime applies optimizations automatically during session creation. You can also apply offline optimizations.
flowchart LR
A["FP32 ONNX Model"] --> B["Graph Optimization ORT_ENABLE_ALL"]
B --> C["Constant Folding pre-compute static subgraphs"]
B --> D["Redundant Node Elimination no-op Reshape, Identity"]
B --> E["Operator Fusion Conv + BN + ReLU → single kernel"]
B --> F["Layout Optimization NHWC ↔ NCHW reordering"]
C & D & E & F --> G["Optimized FP32 Model"]
G --> H{"Need further speedup?"}
H -- "Yes, latency-critical" --> I["Static INT8 Quantization + calibration dataset"]
H -- "Yes, no calib data" --> J["Dynamic INT8 Quantization weights only"]
H -- "No" --> K["Deploy"]
I --> K
J --> K
from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession
# ── Option 1: Let ORT apply optimizations at session creation ──
opts = SessionOptions()
# Levels: ORT_DISABLE_ALL, ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED, ORT_ENABLE_ALL
opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# Save the optimized graph to disk for inspection
opts.optimized_model_filepath = "resnet18_cifar10_optimized.onnx"
sess = InferenceSession(
    "resnet18_cifar10.onnx",
    sess_options=opts,
    providers=["CPUExecutionProvider"],
)
print("Optimized model saved to resnet18_cifar10_optimized.onnx")
The optimizations applied include:
- Constant folding: Pre-compute subgraphs with only constant inputs
- Redundant node elimination: Remove no-op Reshape, Identity, etc.
- Operator fusion: Fuse Conv + BatchNorm + Relu into a single kernel
- Layout optimization: Reorder memory layouts for cache efficiency (NHWC → NCHW or vice versa depending on EP)
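To see why Conv + BatchNorm fusion is legal, it helps to write out the arithmetic. In eval mode BatchNorm is an affine function, so it can be folded into the preceding convolution's weight and bias. The sketch below uses a single scalar weight per channel (think of a 1×1 convolution with one input channel) to keep the algebra visible; the helper names are hypothetical and this is not ONNX Runtime's actual implementation.

```python
import math

def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Reference path: y = w*x + b, followed by eval-mode BatchNorm."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta

def fold_batchnorm(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w_f, b_f) such that w_f*x + b_f equals conv_then_bn(x, ...)."""
    s = gamma / math.sqrt(var + eps)   # per-channel scale
    return w * s, (b - mean) * s + beta

bn = dict(gamma=1.5, beta=0.2, mean=0.4, var=0.09)  # made-up BN statistics
w, b = 0.8, 0.1
w_f, b_f = fold_batchnorm(w, b, **bn)

for x in (-2.0, 0.0, 3.5):
    fused = w_f * x + b_f
    reference = conv_then_bn(x, w, b, **bn)
    assert math.isclose(fused, reference, rel_tol=1e-9, abs_tol=1e-12)
print("folded conv matches conv + BN")
```

For a real convolution the same scale is applied to every weight of a given output channel, which is why the fusion costs nothing at inference time.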
Quantization
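Quantization converts float32 weights and/or activations to int8 or uint8. Because each value shrinks from 4 bytes to 1, quantized tensors are roughly 4× smaller. Inference is often 2–4× faster as well, although the actual gain depends on the hardware, the execution provider, and the model.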
Post-Training Static Quantization (PTQ)
Static quantization requires a calibration dataset to compute the activation ranges.
# quantize_static.py
import numpy as np
from onnxruntime.quantization import (
quantize_static,
CalibrationDataReader,
QuantFormat,
QuantType,
)
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# ──────────────────────────────────────────────────────────────
# 1. Calibration data reader
# ──────────────────────────────────────────────────────────────
class CIFAR10CalibReader(CalibrationDataReader):
    def __init__(self, num_batches: int = 20, batch_size: int = 32):
        val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2023, 0.1994, 0.2010]),
        ])
        dataset = datasets.CIFAR10("./data", train=False,
                                   transform=val_transform)
        self.loader = iter(
            DataLoader(dataset, batch_size=batch_size, shuffle=False)
        )
        self.num_batches = num_batches
        self.count = 0
    def get_next(self):
        # Return one feed dict per call; None signals the end of calibration data
        if self.count >= self.num_batches:
            return None
        try:
            images, _ = next(self.loader)
            self.count += 1
            return {"images": images.numpy()}
        except StopIteration:
            return None
# ──────────────────────────────────────────────────────────────
# 2. Quantize
# ──────────────────────────────────────────────────────────────
quantize_static(
    model_input="resnet18_cifar10_optimized.onnx",
    model_output="resnet18_cifar10_int8.onnx",
    calibration_data_reader=CIFAR10CalibReader(num_batches=20),
    quant_format=QuantFormat.QDQ,  # QDQ or QOperator
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    per_channel=True,
    reduce_range=False,
)
print("Static INT8 quantization complete ✓")
Post-Training Dynamic Quantization (faster, no calibration data needed)
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
    model_input="resnet18_cifar10_optimized.onnx",
    model_output="resnet18_cifar10_dynamic_int8.onnx",
    weight_type=QuantType.QInt8,
    per_channel=True,
)
print("Dynamic INT8 quantization complete ✓")
Dynamic quantization only quantizes weights ahead of time; activations are quantized at runtime. No calibration data is needed. Works well for transformer layers (Gemm / MatMul) but is less effective for convolutions.
Static quantization quantizes both weights and activations using pre-computed scale/zero-point from a calibration dataset. Faster inference, especially for CNNs, but requires a representative calibration set.
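The scale and zero-point mentioned above come from a simple affine mapping between a float range and the 8-bit integer range. The sketch below works through that arithmetic for uint8 in plain Python. The helper names and the sample calibration values are assumptions for illustration; ONNX Runtime's quantizer offers several calibration methods and handles many more details.

```python
def quant_params(rmin, rmax, qmin=0, qmax=255):
    """Derive (scale, zero_point) so that [rmin, rmax] maps onto [qmin, qmax]."""
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)  # the range must contain 0.0
    scale = (rmax - rmin) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0                               # degenerate all-zero tensor
    zero_point = int(round(qmin - rmin / scale))
    return scale, max(qmin, min(qmax, zero_point))

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """float -> int, saturating at the ends of the integer range."""
    return max(qmin, min(qmax, int(round(x / scale)) + zero_point))

def dequantize(q, scale, zero_point):
    """int -> approximate float."""
    return (q - zero_point) * scale

calibration = [-1.0, 0.0, 0.5, 2.0]  # made-up activation values
scale, zero_point = quant_params(min(calibration), max(calibration))
print(f"scale={scale:.6f} zero_point={zero_point}")

for x in calibration:
    q = quantize(x, scale, zero_point)
    x_hat = dequantize(q, scale, zero_point)
    # Inside the calibrated range the round-trip error is at most half a step.
    assert abs(x - x_hat) <= scale / 2 + 1e-9
```

Values outside the calibrated range saturate, which is why the calibration set should be representative of production inputs.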
Pruning Before Export
For maximum compression, prune the model before exporting to ONNX. PyTorch’s torch.nn.utils.prune module makes this straightforward.
import torch
import torch.nn.utils.prune as prune
# Apply magnitude-based unstructured pruning to all Conv2d layers
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.3)  # 30% sparsity
        # Bake the mask into the weights; fine-tuning afterwards can move
        # zeroed weights away from zero again.
        prune.remove(module, "weight")
# Fine-tune the pruned model for a few epochs, then export to ONNX
# ... (fine-tuning loop) ...
torch.onnx.export(model, dummy_input, "resnet18_pruned.onnx", opset_version=18)
Note that unstructured pruning only zeroes individual weights: the tensors keep their dense shape, so the parameter count, the exported file, and the work done by standard dense ONNX kernels stay the same. To get actual speedup, you need either structured pruning (whole channels) or a sparse execution provider.
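The difference between the two pruning styles is easiest to see on a small example. The sketch below is plain Python on nested lists rather than PyTorch tensors, and both helpers are hypothetical simplifications: one zeroes the smallest-magnitude entries, the other drops whole filters by L1 norm.

```python
def prune_unstructured(weights, amount):
    """Zero the `amount` fraction of entries with the smallest magnitude."""
    k = round(amount * len(weights))
    smallest = sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:k]
    drop = set(smallest)
    return [0.0 if i in drop else w for i, w in enumerate(weights)]

def prune_channels(filters, amount):
    """Remove the `amount` fraction of filters (rows) with the smallest L1 norm."""
    k = round(amount * len(filters))
    ranked = sorted(range(len(filters)),
                    key=lambda i: sum(abs(w) for w in filters[i]))
    drop = set(ranked[:k])
    return [f for i, f in enumerate(filters) if i not in drop]

flat = [0.9, -0.05, 0.4, 0.01, -0.7, 0.2, -0.03, 0.6, 0.1, -0.3]
sparse = prune_unstructured(flat, 0.3)
print(len(sparse), sparse.count(0.0))  # 10 3 -> same size, three zeros

filters = [[0.5, -0.5], [0.01, 0.02], [1.0, 0.3], [0.1, -0.1]]
slim = prune_channels(filters, 0.5)
print(len(slim))                       # 2 -> the tensor itself got smaller
```

Only the structured variant changes tensor shapes, which is what lets a dense kernel do less work; in a real network the next layer's input channels must be trimmed to match.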
Running Inference with ONNX Runtime
Basic Inference Session
# infer_basic.py
import numpy as np
import onnxruntime as ort
from PIL import Image
import torchvision.transforms as T
# ──────────────────────────────────────────────────────────────
# 1. Create the inference session
# ──────────────────────────────────────────────────────────────
sess = ort.InferenceSession(
"resnet18_cifar10.onnx",
providers=["CPUExecutionProvider"],
)
# ──────────────────────────────────────────────────────────────
# 2. Inspect input/output metadata
# ──────────────────────────────────────────────────────────────
for inp in sess.get_inputs():
    print(f"Input name={inp.name!r} shape={inp.shape} dtype={inp.type}")
for out in sess.get_outputs():
    print(f"Output name={out.name!r} shape={out.shape} dtype={out.type}")
# ──────────────────────────────────────────────────────────────
# 3. Preprocess a single image
# ──────────────────────────────────────────────────────────────
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"]
transform = T.Compose([
T.Resize((32, 32)),
T.ToTensor(),
T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
])
image = Image.open("test_image.jpg").convert("RGB")
tensor = transform(image).unsqueeze(0).numpy() # shape: (1, 3, 32, 32)
# ──────────────────────────────────────────────────────────────
# 4. Run inference
# ──────────────────────────────────────────────────────────────
input_name = sess.get_inputs()[0].name # "images"
outputs = sess.run(None, {input_name: tensor})
logits = outputs[0] # shape: (1, 10)
# ──────────────────────────────────────────────────────────────
# 5. Decode prediction
# ──────────────────────────────────────────────────────────────
# Numerically stable softmax: subtract the row maximum before exponentiating
shifted = logits - logits.max(axis=-1, keepdims=True)
probabilities = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
predicted_class = probabilities.argmax(axis=-1)[0]
confidence = probabilities[0, predicted_class]
print(f"Predicted: {CLASSES[predicted_class]} ({confidence:.1%} confidence)")
Configuring Session Options
SessionOptions is how you tune ORT’s behavior:
import onnxruntime as ort
opts = ort.SessionOptions()
# Threading
opts.intra_op_num_threads = 4 # threads within a single operator (e.g., matrix mul)
opts.inter_op_num_threads = 2 # threads across independent operators
# Memory
opts.enable_cpu_mem_arena = True # pre-allocate a memory arena
opts.enable_mem_pattern = True # reuse memory across runs (same input shapes)
opts.enable_mem_reuse = True
# Logging
opts.log_severity_level = 3 # 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL
# Graph optimization (see above for levels)
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Profiling — dumps a JSON Chrome trace file
opts.enable_profiling = False
# opts.profile_file_prefix = "ort_profile"
sess = ort.InferenceSession(
    "resnet18_cifar10.onnx",
    sess_options=opts,
    providers=["CPUExecutionProvider"],
)
Execution Providers
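ORT assigns work to EPs in the order you list them. When the session is created, ORT partitions the graph: each operator goes to the first EP in the list that supports it, and anything that EP cannot handle falls back to the next one, with CPUExecutionProvider as the usual last resort. Because partitioning happens once at session creation, the assignment does not change from one run() call to the next.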
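The priority-order fallback can be sketched as a small assignment loop. This is a deliberately simplified model: the provider names and the capability table below are hypothetical, and the real runtime partitions whole subgraphs and inserts memory copies between devices.

```python
def assign_nodes(node_ops, providers, supported):
    """Give each node to the first provider (in priority order) that supports its op."""
    assignment = {}
    for idx, op in enumerate(node_ops):
        for ep in providers:
            if op in supported[ep]:
                assignment[idx] = ep
                break
        else:
            # No provider in the list can run this operator.
            raise RuntimeError(f"no provider supports {op!r}")
    return assignment

# Hypothetical capability table, for illustration only.
supported = {
    "GpuEP": {"Conv", "Relu", "Gemm"},
    "CpuEP": {"Conv", "Relu", "Gemm", "NonMaxSuppression"},
}
graph_ops = ["Conv", "Relu", "NonMaxSuppression", "Gemm"]

plan = assign_nodes(graph_ops, ["GpuEP", "CpuEP"], supported)
print(plan)  # {0: 'GpuEP', 1: 'GpuEP', 2: 'CpuEP', 3: 'GpuEP'}
```

In this toy plan the data would cross the device boundary twice (before and after node 2), which is the kind of hidden cost worth looking for when a GPU session is slower than expected.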
flowchart TD
A["Inference Request"] --> B["Try EP #1 e.g. CUDAExecutionProvider"]
B --> C{"Operator supported?"}
C -- Yes --> D["Run on GPU"]
C -- No --> E["Try EP #2 e.g. CPUExecutionProvider"]
E --> F{"Operator supported?"}
F -- Yes --> G["Run on CPU"]
F -- No --> H["RuntimeError: No EP can handle operator"]
D --> I["Output Tensor"]
G --> I
import onnxruntime as ort
# List EPs available on this machine
print(ort.get_available_providers())
# e.g.: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
# Use CUDA with CPU fallback
sess = ort.InferenceSession(
    "resnet18_cifar10.onnx",
    providers=[
        ("CUDAExecutionProvider", {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "gpu_mem_limit": 2 * 1024 ** 3,  # 2 GB
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }),
        "CPUExecutionProvider",
    ],
)
GPU Inference with CUDA
# gpu_inference.py
import numpy as np
import onnxruntime as ort
# ── Create CUDA session ──
providers = [
("CUDAExecutionProvider", {"device_id": 0}),
"CPUExecutionProvider",
]
sess = ort.InferenceSession("resnet18_cifar10.onnx", providers=providers)
# Confirm which EP owns the compute
print("Active providers:", sess.get_providers())
# ── IO Binding — zero-copy for GPU tensors ──
# This avoids an implicit host↔device copy for each run() call.
io_binding = sess.io_binding()
import torch
# Allocate input tensor directly on GPU
gpu_input = torch.randn(8, 3, 32, 32, device="cuda").contiguous()
io_binding.bind_input(
name="images",
device_type="cuda",
device_id=0,
element_type=np.float32,
shape=tuple(gpu_input.shape),
buffer_ptr=gpu_input.data_ptr(),
)
# Allocate output tensor
gpu_output = torch.empty(8, 10, device="cuda").contiguous()
io_binding.bind_output(
name="logits",
device_type="cuda",
device_id=0,
element_type=np.float32,
shape=(8, 10),
buffer_ptr=gpu_output.data_ptr(),
)
# Run without any host↔device copies
sess.run_with_iobinding(io_binding)
logits = gpu_output.cpu().numpy()
print("GPU inference output shape:", logits.shape)
Batch Inference
Processing images in batches amortizes kernel launch overhead and maximizes hardware utilization.
# batch_inference.py
import numpy as np
import onnxruntime as ort
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
from typing import List
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"]
transform = T.Compose([
T.Resize((32, 32)),
T.ToTensor(),
T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
])
def preprocess_batch(image_paths: List[Path]) -> np.ndarray:
    tensors = []
    for p in image_paths:
        img = Image.open(p).convert("RGB")
        tensors.append(transform(img).numpy())
    return np.stack(tensors, axis=0)  # (N, 3, 32, 32)
def infer_batch(sess: ort.InferenceSession,
                image_paths: List[Path],
                batch_size: int = 32) -> List[str]:
    input_name = sess.get_inputs()[0].name
    predictions = []
    # Walk the file list in chunks of batch_size; the last chunk may be smaller
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i : i + batch_size]
        batch_np = preprocess_batch(batch_paths)
        logits = sess.run(None, {input_name: batch_np})[0]
        batch_preds = logits.argmax(axis=-1).tolist()
        predictions.extend([CLASSES[p] for p in batch_preds])
    return predictions
# Usage
sess = ort.InferenceSession("resnet18_cifar10.onnx",
                            providers=["CPUExecutionProvider"])
image_files = list(Path("./test_images").glob("*.jpg"))
results = infer_batch(sess, image_files, batch_size=64)
for path, pred in zip(image_files, results):
    print(f"{path.name}: {pred}")
Preprocessing and Postprocessing Pipelines
Image Classification
The complete classification pipeline, including softmax and top-k decoding:
import numpy as np
import onnxruntime as ort
from PIL import Image
import torchvision.transforms as T
def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the maximum first so np.exp cannot overflow
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)
def classify_topk(model_path: str, image_path: str,
                  class_names: list, k: int = 5):
    sess = ort.InferenceSession(model_path,
                                providers=["CPUExecutionProvider"])
    transform = T.Compose([
        T.Resize((32, 32)),
        T.ToTensor(),
        T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    ])
    img = Image.open(image_path).convert("RGB")
    inp = transform(img).unsqueeze(0).numpy()
    logits = sess.run(None, {"images": inp})[0][0]  # (10,)
    probs = softmax(logits)
    top_k = probs.argsort()[::-1][:k]  # indices of the k largest probabilities
    print(f"Top-{k} predictions:")
    for rank, idx in enumerate(top_k, 1):
        print(f"  {rank}. {class_names[idx]:<15} {probs[idx]:.2%}")
Object Detection
For models like YOLOv8 or DETR exported to ONNX, the postprocessing involves non-maximum suppression (NMS) and bounding box decoding.
flowchart TD
A["Input Image BGR uint8"] --> B["BGR → RGB cv2.cvtColor"]
B --> C["Letterbox Resize 640×640 with padding"]
C --> D["Normalize ÷255 → float32"]
D --> E["HWC → CHW np.transpose"]
E --> F["Add Batch Dim np.expand_dims"]
F --> G["ORT Session sess.run()"]
G --> H["Raw Output (1, 84, 8400)"]
H --> I["Transpose → (8400, 84) boxes xywh + class scores"]
I --> J["Confidence Filter score ≥ threshold"]
J --> K["xywh → xyxy bounding box decode"]
K --> L["NMS IoU-based deduplication"]
L --> M["Final Detections boxes · scores · class IDs"]
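The NMS step in the flow above relies on intersection-over-union (IoU) between boxes in xyxy form. As a minimal worked sketch (the helper name `iou_xyxy` and the sample boxes are illustrative, not part of any library), two 10×10 boxes offset by 5 pixels share a 5×5 region:

```python
def iou_xyxy(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2); illustrative helper."""
    # Corners of the intersection rectangle
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    # Clamp at zero so disjoint boxes yield no overlap
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

a = (0, 0, 10, 10)
b = (5, 5, 15, 15)
print(f"{iou_xyxy(a, b):.3f}")  # 25 / (100 + 100 - 25) = 0.143
```

At an IoU threshold of 0.45, these two boxes would both be kept, since their overlap is well below the cutoff.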
# yolo_inference.py: demonstrates the postprocessing pattern
import numpy as np
import onnxruntime as ort
import cv2
from typing import List, Tuple

def letterbox(image: np.ndarray, target_size: Tuple[int, int] = (640, 640),
              fill_value: int = 114) -> Tuple[np.ndarray, float, Tuple[int, int]]:
    """Resize with preserved aspect ratio and pad to square."""
    h, w = image.shape[:2]
    th, tw = target_size
    ratio = min(th / h, tw / w)
    new_h, new_w = int(h * ratio), int(w * ratio)
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    pad_top = (th - new_h) // 2
    pad_left = (tw - new_w) // 2
    padded = np.full((th, tw, 3), fill_value, dtype=np.uint8)
    padded[pad_top:pad_top + new_h, pad_left:pad_left + new_w] = resized
    return padded, ratio, (pad_left, pad_top)

def preprocess_detection(image_bgr: np.ndarray) -> Tuple[np.ndarray, float, tuple]:
    """Convert BGR OpenCV image to ONNX-ready float32 tensor."""
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    padded, ratio, padding = letterbox(image_rgb, (640, 640))
    blob = padded.astype(np.float32) / 255.0
    blob = np.transpose(blob, (2, 0, 1))  # HWC → CHW
    blob = np.expand_dims(blob, axis=0)   # add batch dim
    return blob, ratio, padding

def nms(boxes: np.ndarray, scores: np.ndarray,
        iou_threshold: float = 0.45) -> List[int]:
    """Simple greedy NMS."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # Intersection of the top box with every remaining box
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # Drop boxes that overlap the top box too much
        order = order[1:][iou <= iou_threshold]
    return keep

def detect(model_path: str, image_path: str,
           conf_threshold: float = 0.25):
    sess = ort.InferenceSession(model_path,
                                providers=["CUDAExecutionProvider",
                                           "CPUExecutionProvider"])
    image_bgr = cv2.imread(image_path)
    blob, ratio, (pad_left, pad_top) = preprocess_detection(image_bgr)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    raw_output = sess.run([output_name], {input_name: blob})[0]
    # YOLOv8 output shape: (1, 84, 8400) = [batch, 4+num_classes, anchors]
    predictions = raw_output[0].T  # (8400, 84)
    boxes_xywh = predictions[:, :4]
    class_scores = predictions[:, 4:]
    confidences = class_scores.max(axis=1)
    class_ids = class_scores.argmax(axis=1)
    mask = confidences >= conf_threshold
    boxes_xywh = boxes_xywh[mask]
    confidences = confidences[mask]
    class_ids = class_ids[mask]
    # xywh → xyxy
    boxes_xyxy = np.empty_like(boxes_xywh)
    boxes_xyxy[:, 0] = boxes_xywh[:, 0] - boxes_xywh[:, 2] / 2
    boxes_xyxy[:, 1] = boxes_xywh[:, 1] - boxes_xywh[:, 3] / 2
    boxes_xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2] / 2
    boxes_xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3] / 2
    # Undo the letterbox: remove padding, then rescale to original image coordinates
    boxes_xyxy[:, [0, 2]] = (boxes_xyxy[:, [0, 2]] - pad_left) / ratio
    boxes_xyxy[:, [1, 3]] = (boxes_xyxy[:, [1, 3]] - pad_top) / ratio
    keep = nms(boxes_xyxy, confidences)
    print(f"Detected {len(keep)} objects")
    return boxes_xyxy[keep], confidences[keep], class_ids[keep]
Semantic Segmentation
# segmentation_inference.py
import numpy as np
import onnxruntime as ort
import cv2

def run_segmentation(model_path: str, image_path: str,
                     input_size: tuple = (512, 512),
                     num_classes: int = 21):  # Pascal VOC classes
    sess = ort.InferenceSession(model_path,
                                providers=["CUDAExecutionProvider",
                                           "CPUExecutionProvider"])
    image = cv2.imread(image_path)
    original_shape = image.shape[:2]
    # Preprocess
    resized = cv2.resize(image, input_size)
    blob = resized[:, :, ::-1].astype(np.float32) / 255.0  # BGR → RGB, scale to [0, 1]
    # Keep mean/std in float32 so the result is not promoted to float64,
    # which a float32 model input would reject
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    blob = (blob - mean) / std
    blob = np.transpose(blob, (2, 0, 1))[np.newaxis]  # (1, 3, H, W)
    # Inference
    input_name = sess.get_inputs()[0].name
    output = sess.run(None, {input_name: blob})[0]
    # output shape: (1, num_classes, H, W)
    seg_map = output[0].argmax(axis=0).astype(np.uint8)  # (H, W)
    # Resize back to original
    seg_map = cv2.resize(seg_map, (original_shape[1], original_shape[0]),
                         interpolation=cv2.INTER_NEAREST)
    # Colorize for visualization
    palette = np.random.randint(0, 255, (num_classes, 3), dtype=np.uint8)
    colorized = palette[seg_map]
    blended = cv2.addWeighted(image, 0.6, colorized, 0.4, 0)
    cv2.imwrite("segmentation_result.png", blended)
    print(f"Segmentation map shape: {seg_map.shape}")
    return seg_map
Benchmarking and Profiling
Latency and Throughput Benchmark
# benchmark.py
import time
import numpy as np
import onnxruntime as ort

def benchmark(model_path: str,
              input_shape: tuple = (1, 3, 32, 32),
              warmup_runs: int = 20,
              benchmark_runs: int = 200,
              providers: list = None):
    if providers is None:
        providers = ["CPUExecutionProvider"]
    sess = ort.InferenceSession(model_path, providers=providers)
    input_name = sess.get_inputs()[0].name
    dummy = np.random.randn(*input_shape).astype(np.float32)
    # Warm-up
    for _ in range(warmup_runs):
        sess.run(None, {input_name: dummy})
    # Benchmark
    latencies = []
    for _ in range(benchmark_runs):
        t0 = time.perf_counter()
        sess.run(None, {input_name: dummy})
        latencies.append((time.perf_counter() - t0) * 1000)  # ms
    latencies = np.array(latencies)
    batch_size = input_shape[0]
    fps = batch_size / (latencies.mean() / 1000)
    print(f"Model: {model_path}")
    print(f"Providers: {providers}")
    print(f"Batch size: {batch_size}")
    print(f"Latency: mean: {latencies.mean():.2f} ms "
          f"p50: {np.percentile(latencies, 50):.2f} ms "
          f"p99: {np.percentile(latencies, 99):.2f} ms")
    print(f"Throughput: {fps:.1f} images/sec")
    return latencies

# Compare FP32 vs INT8
benchmark("resnet18_cifar10.onnx", input_shape=(1, 3, 32, 32))
benchmark("resnet18_cifar10_int8.onnx", input_shape=(1, 3, 32, 32))
benchmark("resnet18_cifar10.onnx", input_shape=(32, 3, 32, 32))  # batch=32
Profiling Operator Timings
import json
import numpy as np
import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True
opts.profile_file_prefix = "ort_profile"
sess = ort.InferenceSession("resnet18_cifar10.onnx",
                            sess_options=opts,
                            providers=["CPUExecutionProvider"])
dummy = np.random.randn(1, 3, 32, 32).astype(np.float32)
for _ in range(50):
    sess.run(None, {"images": dummy})
profile_path = sess.end_profiling()
print(f"Profile saved to: {profile_path}")

# Load and inspect top-N slowest operators
with open(profile_path) as f:
    events = json.load(f)
op_events = [e for e in events if e.get("cat") == "Node"]
op_events.sort(key=lambda e: e["dur"], reverse=True)
print("\nTop-10 slowest operators (microseconds):")
for ev in op_events[:10]:
    print(f"  {ev['name']:<50} {ev['dur']:>8} µs")
Deploying ONNX Models
flowchart TD
A["Optimized ONNX Model"] --> B{"Deployment Target?"}
B --> C["Cloud / Server"]
B --> D["Edge / Embedded"]
B --> E["Browser"]
B --> F["Mobile"]
C --> C1["FastAPI + ONNX Runtime CUDA or CPU EP"]
D --> D1{"Hardware?"}
D1 --> D2["NVIDIA Jetson CUDA EP / TensorRT EP"]
D1 --> D3["ARM CPU Raspberry Pi CPU EP + NEON"]
D1 --> D4["Intel CPU/VPU OpenVINO EP"]
E --> E1["onnxruntime-web WASM or WebGL"]
F --> F1{"Platform?"}
F1 --> F2["Android QNN EP / NNAPI"]
F1 --> F3["iOS / macOS CoreML EP"]
Python Service (FastAPI)
# app.py: production-ready FastAPI inference server
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
from PIL import Image
import numpy as np
import onnxruntime as ort
import torchvision.transforms as T
import io, logging

logger = logging.getLogger(__name__)

CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

transform = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
    T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
])

# ── Global session (initialized once at startup) ──
session: ort.InferenceSession | None = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    global session
    logger.info("Loading ONNX model…")
    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.intra_op_num_threads = 4
    session = ort.InferenceSession(
        "resnet18_cifar10.onnx",
        sess_options=opts,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    logger.info("Model loaded ✓")
    yield
    session = None

app = FastAPI(title="CIFAR-10 Classifier", lifespan=lifespan)

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    if session is None:
        raise HTTPException(status_code=503, detail="Model not ready")
    try:
        data = await file.read()
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}")
    tensor = transform(image).unsqueeze(0).numpy()
    logits = session.run(None, {"images": tensor})[0][0]
    probs = softmax(logits)
    top5_idx = probs.argsort()[::-1][:5]
    top5 = [{"class": CLASSES[i], "probability": float(probs[i])}
            for i in top5_idx]
    return JSONResponse({"top5": top5, "predicted": CLASSES[top5_idx[0]]})

# Run with: uvicorn app:app --host 0.0.0.0 --port 8000 --workers 1
Edge Devices and Mobile
For edge deployment (Raspberry Pi, NVIDIA Jetson, Android, iOS), the recommended approach is:
- ARM CPU: Use the onnxruntime Python package or the C/C++ shared library. ORT’s CPU provider is highly optimized via MLAS (Microsoft Linear Algebra Subprograms) and uses NEON intrinsics on ARM.
- NVIDIA Jetson: Install onnxruntime-gpu built for JetPack (ARM64 + CUDA). Alternatively, convert to TensorRT via the TensorRT EP.
- Qualcomm SoC (Android): Use the QNN execution provider with the onnxruntime-android AAR.
- Apple Silicon / iOS: Use CoreMLExecutionProvider (available on macOS 12+ / iOS 15+).
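Because each target above exposes a different execution provider, one build often has to choose providers at startup. The following is a minimal sketch of that selection; `pick_providers` is a hypothetical helper, and the hard-coded `available` list stands in for what `onnxruntime.get_available_providers()` would return on a real device.

```python
from typing import List, Sequence

def pick_providers(preferred: Sequence[str], available: Sequence[str]) -> List[str]:
    """Keep the preferred providers that are actually available, in order,
    and always finish with the CPU provider as a fallback."""
    chosen = [p for p in preferred
              if p in available and p != "CPUExecutionProvider"]
    chosen.append("CPUExecutionProvider")
    return chosen

# In real code, `available` would come from onnxruntime.get_available_providers();
# a fixed list is used here so the sketch runs without ONNX Runtime installed.
available = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
preferred = ["TensorrtExecutionProvider", "CUDAExecutionProvider",
             "CoreMLExecutionProvider"]
print(pick_providers(preferred, available))
# ['CoreMLExecutionProvider', 'CPUExecutionProvider']
```

The resulting list can be passed as the `providers` argument of `ort.InferenceSession`, which tries the entries in order.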
# CoreML on Apple Silicon
import onnxruntime as ort

sess = ort.InferenceSession(
    "resnet18_cifar10.onnx",
    providers=[
        ("CoreMLExecutionProvider", {
            "MLComputeUnits": "ALL",  # or CPUOnly | CPUAndGPU | CPUAndNeuralEngine
        }),
        "CPUExecutionProvider",
    ],
)
ONNX Runtime Web (Browser)
ONNX Runtime Web (onnxruntime-web) runs ONNX models in a browser using WebAssembly (WASM) or WebGL.
npm install onnxruntime-web

// classifier.js
import * as ort from 'onnxruntime-web';

const CLASSES = ['airplane','automobile','bird','cat','deer',
                 'dog','frog','horse','ship','truck'];

// Load the model once and reuse the session across calls
let sessionPromise = null;
function getSession() {
  if (!sessionPromise) {
    sessionPromise = ort.InferenceSession.create('./resnet18_cifar10.onnx', {
      executionProviders: ['webgl'],  // or 'wasm' for CPU
      graphOptimizationLevel: 'all',
    });
  }
  return sessionPromise;
}

async function classifyImage(imageData) {
  const session = await getSession();
  // Preprocess: imageData is a Float32Array of shape [1, 3, 32, 32]
  const tensor = new ort.Tensor('float32', imageData, [1, 3, 32, 32]);
  const results = await session.run({ images: tensor });
  const logits = results['logits'].data;  // Float32Array of length 10
  // Softmax and argmax
  const max = Math.max(...logits);
  const exps = logits.map(v => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  const probs = exps.map(v => v / sum);
  const classIdx = probs.indexOf(Math.max(...probs));
  console.log(`Predicted: ${CLASSES[classIdx]} (${(probs[classIdx]*100).toFixed(1)}%)`);
}
Common Pitfalls and Troubleshooting
flowchart TD
A["Inference produces bad results or crashes"] --> B{"Error type?"}
B --> C["Inconsistent predictions low accuracy"]
B --> D["InvalidGraph: opset version error"]
B --> E["Shape mismatch at runtime"]
B --> F["Unregistered op or custom op error"]
B --> G["Quantization fails shape undefined"]
B --> H["NHWC/NCHW garbage output"]
B --> I["Slow first call only"]
C --> C1["Fix: call model.eval() before export"]
D --> D1["Fix: lower opset_version to 17 or 18"]
E --> E1["Fix: add dynamic_axes to export call"]
F --> F1["Fix: rewrite with standard ONNX primitives or register custom ORT kernel"]
G --> G1["Fix: run shape_inference before quantize_static"]
H --> H1["Fix: transpose input NCHW ↔ NHWC"]
I --> I1["Fix: add warm-up inference calls"]
1. Model Not in Eval Mode Before Export
Symptom: Predictions are wildly inconsistent; accuracy in ORT is lower than during training.
Cause: torch.nn.Dropout is active in training mode, randomly zeroing out activations. BatchNorm uses running stats in eval mode but batch stats in train mode.
Fix: Always call model.eval() before torch.onnx.export().
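To see why train-mode layers make predictions inconsistent, here is a small NumPy re-implementation of inverted dropout. It is an illustration only, not PyTorch's code, and the function name `dropout` and its arguments are assumptions made for this sketch. In eval mode the layer is the identity; in train mode every call applies a fresh random mask.

```python
import numpy as np

def dropout(x: np.ndarray, p: float, training: bool,
            rng: np.random.Generator) -> np.ndarray:
    """Inverted dropout, re-implemented in NumPy for illustration only."""
    if not training:
        return x                       # eval mode: identity, deterministic
    mask = rng.random(x.shape) >= p    # keep each element with probability 1 - p
    return x * mask / (1.0 - p)        # rescale so the expected value is unchanged

x = np.ones(8, dtype=np.float32)
rng = np.random.default_rng(0)
train_out = dropout(x, 0.5, True, rng)   # each entry is either 0.0 or 2.0
eval_out = dropout(x, 0.5, False, rng)
print(np.array_equal(eval_out, x))       # True
```

A model exported in train mode carries this random masking into the ONNX graph, which matches the symptom described above.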
2. Opset Mismatch
Symptom: onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: Node has unsupported opset version.
Cause: Exporting to a higher opset than ORT supports, or using operators only available in newer opsets.
Fix: Check your installed ONNX Runtime version (ort.__version__) and consult the ORT opset compatibility matrix for the highest opset that version supports. Use opset_version=17 or 18 for broad compatibility.
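When diagnosing an opset mismatch, it helps to read the opset a model actually declares. The sketch below assumes only that the model object exposes `opset_import` entries with `.domain` and `.version`, which is how `onnx.ModelProto` stores them; `default_opset` is a hypothetical helper, and a stand-in object is used so the example runs without the onnx package.

```python
from types import SimpleNamespace

def default_opset(model) -> int:
    """Return the opset version declared for the default ONNX domain."""
    for entry in model.opset_import:
        # The default domain is written as "" (or "ai.onnx")
        if entry.domain in ("", "ai.onnx"):
            return entry.version
    raise ValueError("model declares no default-domain opset")

# Stand-in for onnx.load("model.onnx"), so the sketch has no onnx dependency.
fake_model = SimpleNamespace(opset_import=[
    SimpleNamespace(domain="", version=18),
    SimpleNamespace(domain="com.microsoft", version=1),
])
print(default_opset(fake_model))  # 18
```

Comparing this number with the compatibility matrix for your installed runtime shows whether the model needs to be re-exported at a lower opset.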
3. Fixed Batch Size in the Model
Symptom: Running a batch of 8 images on a model exported with dummy_input = torch.randn(1, 3, 32, 32) fails with a shape error.
Fix: Use dynamic_axes when exporting:
torch.onnx.export(
    model, dummy_input, path,
    dynamic_axes={"images": {0: "batch"}, "logits": {0: "batch"}},
)
4. Custom / Unsupported Operators
Symptom: onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: Node ... is not a registered function/op.
Cause: Your model uses a PyTorch custom op or a very new operator not yet in ORT.
Fix: Either rewrite the custom op using standard ONNX primitives, or implement a custom ONNX Runtime operator in C++.
5. Shape Inference Failures in Quantization
Symptom: Quantization fails with ValueError: Shape of input is not fully defined.
Fix: Run shape inference before quantization:
from onnxruntime.quantization import shape_inference
shape_inference.quant_pre_process("model.onnx", "model_inferred.onnx")
# Then quantize model_inferred.onnx
6. Memory Layout Mismatches (NHWC vs. NCHW)
Symptom: Garbage outputs when running TensorFlow-exported models.
Cause: TensorFlow uses NHWC (batch, height, width, channels) by default. PyTorch uses NCHW. The export may not reorder axes correctly.
Fix: Explicitly transpose your NumPy array before feeding it to ORT, or add a Transpose node to the ONNX graph.
# TF model exported with NHWC: transpose input accordingly
image_nhwc = image_nchw.transpose(0, 2, 3, 1)  # NCHW → NHWC
sess.run(None, {"input": image_nhwc})
7. Slow Cold-Start
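Symptom: The first inference call is much slower than the rest (for example, 500 ms), while subsequent calls are fast.
Cause: The first run pays one-time costs: memory arena allocation, kernel or algorithm selection, and, for providers such as TensorRT, engine building.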
Fix: Run several warm-up inferences before measuring latency or serving traffic.
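A minimal sketch of the warm-up pattern, using only the standard library. The helper names are illustrative, and the lambda stands in for a real call such as `sess.run(None, {input_name: dummy})`:

```python
import time
from typing import Callable, List

def timed_calls(run_once: Callable[[], object], n: int) -> List[float]:
    """Call run_once n times and return per-call latency in milliseconds."""
    latencies = []
    for _ in range(n):
        t0 = time.perf_counter()
        run_once()
        latencies.append((time.perf_counter() - t0) * 1000)
    return latencies

def warm_up_then_measure(run_once: Callable[[], object],
                         warmup: int = 5, runs: int = 20) -> List[float]:
    """Discard `warmup` calls, then time `runs` calls."""
    timed_calls(run_once, warmup)        # timings intentionally thrown away
    return timed_calls(run_once, runs)

# Hypothetical stand-in for an inference call, so the sketch runs anywhere.
calls = []
steady = warm_up_then_measure(lambda: calls.append(1), warmup=5, runs=20)
print(len(calls), len(steady))  # 25 20
```

In a server, the same warm-up loop can run once at startup, before the process begins accepting traffic.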
Advanced Topics
Dynamic Axes and Variable Batch Sizes
Marking axes as symbolic allows one exported model to handle any input size:
torch.onnx.export(
    model, dummy_input, "model.onnx",
    dynamic_axes={
        "images": {0: "batch_size", 2: "height", 3: "width"},
        "logits": {0: "batch_size"},
    },
)
At inference time:
# Works for any batch size and any spatial resolution
sess.run(None, {"images": np.random.randn(16, 3, 224, 224).astype(np.float32)})
sess.run(None, {"images": np.random.randn(1, 3, 512, 512).astype(np.float32)})
Not all models tolerate fully dynamic spatial axes. Models with fixed positional embeddings (e.g., ViT) may require a fixed spatial resolution and can fail at runtime or produce wrong results if given an unexpected image size.
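One way to guard against feeding an unsupported resolution is to compare the input array against the shape the model declares. The sketch below assumes the declared shape is a list in which fixed dimensions are integers and dynamic ones are strings or None, which is the form `sess.get_inputs()[0].shape` returns; `shape_is_compatible` is a hypothetical helper.

```python
from typing import Sequence, Union

Dim = Union[int, str, None]

def shape_is_compatible(declared: Sequence[Dim], actual: Sequence[int]) -> bool:
    """True if `actual` fits `declared`; str/None dims are treated as dynamic."""
    if len(declared) != len(actual):
        return False
    # Only integer (fixed) dimensions have to match exactly
    return all(not isinstance(d, int) or d == a
               for d, a in zip(declared, actual))

# In real code, `declared` would be sess.get_inputs()[0].shape.
vit_like = ["batch_size", 3, 224, 224]   # only the batch axis is dynamic
print(shape_is_compatible(vit_like, (16, 3, 224, 224)))  # True
print(shape_is_compatible(vit_like, (1, 3, 512, 512)))   # False
```

This check only helps when the spatial axes were exported as fixed; a model exported with dynamic height and width will accept any size, so the resolution limit then has to be enforced by the caller.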
Custom Operators
If your model uses a custom PyTorch operator, you can register a corresponding ONNX operator and ORT kernel:
# Step 1: Register a custom symbolic function in PyTorch
import onnxruntime as ort
from torch.onnx import register_custom_op_symbolic

def my_custom_op_symbolic(g, input, weight):
    return g.op("custom_domain::MyCustomOp", input, weight)

register_custom_op_symbolic("my_package::my_custom_op", my_custom_op_symbolic, 1)

# Step 2: Implement the ORT kernel in C++ and compile as a shared library

# Step 3: Load the custom op library in ORT
opts = ort.SessionOptions()
opts.register_custom_ops_library("./libmy_custom_ops.so")
sess = ort.InferenceSession("model_with_custom_op.onnx", sess_options=opts)
ONNX Training API
ONNX Runtime has an experimental Training API that allows on-device fine-tuning without a separate framework dependency — useful for federated learning and on-device personalization.
# Experimental: requires the onnxruntime-training package
# Note: ORTModule accelerates training inside PyTorch; the on-device
# Training API itself lives in onnxruntime.training.api.
import torch
from onnxruntime.training.ortmodule import ORTModule

# YourModel, criterion, and train_loader are placeholders for your own
# model class, loss function, and DataLoader.
model = YourModel()
ort_model = ORTModule(model)  # wraps the model; ORT handles the backward pass

# Training loop is identical to standard PyTorch
optimizer = torch.optim.Adam(ort_model.parameters())
for images, labels in train_loader:
    outputs = ort_model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Summary and Best Practices
Workflow Checklist
Optimization Priority
The most impactful optimizations, roughly ordered by return on investment:
- Graph optimization — free, always apply
- INT8 quantization — 2–4× speedup, minimal accuracy loss with careful calibration
- Batching — dramatically increases GPU utilization
- IO binding — eliminates host↔︎device copies for GPU workloads
- TensorRT EP — maximum throughput on NVIDIA hardware
- Structured pruning — reduces model FLOPs before export
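To make the INT8 entry above more concrete, here is a minimal sketch of asymmetric per-tensor quantization to uint8 and back. It illustrates the arithmetic only and is not how ONNX Runtime's quantizer is implemented; the function names and sample values are assumptions made for this example.

```python
import numpy as np

def quantize_uint8(x: np.ndarray):
    """Asymmetric per-tensor quantization to uint8 (illustrative only)."""
    lo = min(float(x.min()), 0.0)          # the range must include zero
    hi = max(float(x.max()), 0.0)
    scale = (hi - lo) / 255.0
    if scale == 0.0:
        scale = 1.0                        # all-zero tensor: any scale works
    zero_point = int(round(-lo / scale))   # uint8 code that represents 0.0
    q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    return (q.astype(np.float32) - zero_point) * scale

x = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
q, scale, zp = quantize_uint8(x)
x_hat = dequantize(q, scale, zp)
print(zp)                                   # 85
print(np.abs(x - x_hat).max() <= scale)     # True
```

The round-trip error stays within one quantization step, which is why a well-chosen calibration range matters more than the rounding itself.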
With these tools in hand, a model trained on a research GPU cluster can be reliably deployed on everything from a server rack to a Raspberry Pi to a browser tab.
Guide written for ONNX opset 18, ONNX Runtime 1.18+, PyTorch 2.3+, and TensorFlow 2.16+. API details may change in future releases — always consult the official ONNX Runtime documentation for the latest.