ONNX Export¶
Export trained AutoTimm models to ONNX format for cross-platform deployment with ONNX Runtime, TensorRT, OpenVINO, CoreML, and more.
Overview¶
ONNX (Open Neural Network Exchange) allows you to:
- Deploy cross-platform - Run models on any platform with an ONNX runtime
- Use multiple runtimes - ONNX Runtime, TensorRT, OpenVINO, CoreML
- Dynamic batch sizes - Batch dimension is dynamic by default
- Hardware acceleration - GPU, CPU, and specialized accelerators
- Self-contained files - Single `.onnx` file for deployment
Installation¶
pip install onnx onnxruntime onnxscript
# For GPU inference
pip install onnxruntime-gpu
# Or install as part of AutoTimm
pip install autotimm[onnx]
Quick Start¶
Basic Export¶
import torch

from autotimm import ImageClassifier, export_to_onnx
# Load trained model
model = ImageClassifier.load_from_checkpoint("model.ckpt", compile_model=False)
# Export to ONNX
example_input = torch.randn(1, 3, 224, 224)
export_to_onnx(
    model,
    "model.onnx",
    example_input=example_input,
)
Convenience Method¶
# Even simpler - one line export
model = ImageClassifier(backbone="resnet50", num_classes=10)
model.to_onnx("model.onnx")
Load and Use¶
import numpy as np
import onnxruntime as ort
# No AutoTimm dependency needed!
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
image = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: image})
CLI Export (Quickest)¶
Export directly from the command line without writing Python:
python -m autotimm.export.export_onnx \
--checkpoint logs/run_1/checkpoints/best.ckpt \
--output model.onnx \
--task-class ImageClassifier \
--input-size 224
Additional options:
# With graph simplification and custom opset
python -m autotimm.export.export_onnx \
--checkpoint checkpoint.ckpt \
--output model.onnx \
--task-class ObjectDetector \
--opset-version 17 \
--simplify
# With explicit hparams file
python -m autotimm.export.export_onnx \
--checkpoint checkpoint.ckpt \
--output model.onnx \
--hparams-yaml logs/run_1/hparams.yaml
The CLI auto-detects image_size from model hparams when available, so --input-size is optional if the checkpoint was saved with hparams.
Export Methods¶
AutoTimm provides multiple ways to export models to ONNX from Python:
Method 1: export_to_onnx()¶
Full control over the export process:
from autotimm.export import export_to_onnx
path = export_to_onnx(
    model=model,
    save_path="model.onnx",
    example_input=torch.randn(1, 3, 224, 224),
    opset_version=17,  # ONNX opset version
    simplify=False,  # Optional graph simplification
)
Parameters:
- model: The PyTorch model to export
- save_path: Output file path (.onnx extension recommended)
- example_input: Example input tensor (required)
- opset_version: ONNX opset version (default: 17)
- dynamic_axes: Dynamic axes config (default: batch dimension dynamic)
- simplify: Whether to simplify the graph with onnx-simplifier
Method 2: model.to_onnx()¶
Convenience method on model instances:
# With file save
path = model.to_onnx("model.onnx")
# Without specifying path (uses temp file)
path = model.to_onnx()
# With custom options
path = model.to_onnx(
    "model.onnx",
    example_input=torch.randn(1, 3, 299, 299),
    opset_version=17,
)
Method 3: export_checkpoint_to_onnx()¶
Direct checkpoint export:
from autotimm import export_checkpoint_to_onnx, ImageClassifier
path = export_checkpoint_to_onnx(
    checkpoint_path="model.ckpt",
    save_path="model.onnx",
    model_class=ImageClassifier,
    example_input=torch.randn(1, 3, 224, 224),
)
Dynamic Batch Size¶
By default, the batch dimension is dynamic, allowing inference with any batch size:
# Export with batch size 1
export_to_onnx(model, "model.onnx", torch.randn(1, 3, 224, 224))
# Inference with any batch size
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
# Works with batch size 1, 4, 8, etc.
for batch_size in [1, 4, 8]:
    input_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
    outputs = session.run(None, {input_name: input_data})
    print(f"Batch {batch_size}: output shape {outputs[0].shape}")
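When running a whole dataset through the session, the dynamic batch axis lets you pick any chunk size. A minimal sketch (the `batched` helper below is illustrative, not part of AutoTimm):

```python
import numpy as np

def batched(images: np.ndarray, batch_size: int):
    """Yield fixed-size chunks of a stacked image array (last chunk may be smaller)."""
    for i in range(0, len(images), batch_size):
        yield images[i : i + batch_size]

images = np.random.randn(10, 3, 224, 224).astype(np.float32)
print([chunk.shape[0] for chunk in batched(images, 4)])  # [4, 4, 2]
```

Each chunk can be passed directly to `session.run()` since the batch dimension is dynamic.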
Custom Dynamic Axes¶
# Dynamic batch and spatial dimensions
dynamic_axes = {
    "input": {0: "batch_size", 2: "height", 3: "width"},
    "output": {0: "batch_size"},
}
export_to_onnx(
    model, "model.onnx", example_input,
    dynamic_axes=dynamic_axes,
)
Validation¶
Verify the exported model matches the original:
from autotimm.export import validate_onnx_export
# Export model
export_to_onnx(model, "model.onnx", example_input)
# Validate outputs match
is_valid = validate_onnx_export(
    original_model=model,
    onnx_path="model.onnx",
    example_input=example_input,
    rtol=1e-5,
    atol=1e-5,
)
if is_valid:
    print("Export verified successfully")
else:
    print("Export validation failed")
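Conceptually, validation compares the PyTorch and ONNX Runtime outputs element-wise within the given tolerances, along the lines of numpy's `allclose` (a sketch of the idea, not AutoTimm's actual implementation):

```python
import numpy as np

# Small numerical drift between backends is expected after graph-level
# optimizations; it should stay within rtol/atol.
torch_out = np.array([0.1234567, -1.0, 3.5], dtype=np.float32)
onnx_out = torch_out + 1e-7  # simulated backend drift
print(np.allclose(torch_out, onnx_out, rtol=1e-5, atol=1e-5))  # True
```

This is why the troubleshooting section suggests loosening tolerances when outputs differ only by small floating-point noise.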
Supported Models¶
All AutoTimm task models support ONNX export:
Classification¶
from autotimm import ImageClassifier
model = ImageClassifier(backbone="resnet50", num_classes=1000)
model.to_onnx("classifier.onnx")
Semantic Segmentation¶
from autotimm import SemanticSegmentor
model = SemanticSegmentor(backbone="resnet50", num_classes=19)
model.to_onnx("segmentor.onnx")
Object Detection¶
Detection models automatically flatten their list outputs into named tensors for ONNX compatibility:
from autotimm import ObjectDetector
model = ObjectDetector(backbone="resnet50", num_classes=80)
model.to_onnx("detector.onnx", example_input=torch.randn(1, 3, 640, 640))
# Outputs: cls_l0..cls_l4, reg_l0..reg_l4, ctr_l0..ctr_l4 (15 tensors)
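Consumers of the exported detector typically need to regroup the flattened tensors by head and pyramid level. A minimal sketch, assuming output names follow the `<head>_l<level>` pattern shown above (the `group_detector_outputs` helper is hypothetical, not part of AutoTimm):

```python
import numpy as np

def group_detector_outputs(names, arrays):
    """Regroup flattened ONNX outputs (e.g. cls_l0..cls_l4) into per-head lists."""
    grouped = {}
    for name, arr in sorted(zip(names, arrays), key=lambda pair: pair[0]):
        head, _level = name.rsplit("_l", 1)  # "cls_l0" -> ("cls", "0")
        grouped.setdefault(head, []).append(arr)
    return grouped

# Simulated session outputs: 3 heads x 5 pyramid levels = 15 tensors
names = [f"{head}_l{lvl}" for head in ("cls", "reg", "ctr") for lvl in range(5)]
arrays = [np.zeros((1, 80, 4, 4), dtype=np.float32) for _ in names]
grouped = group_detector_outputs(names, arrays)
print({head: len(levels) for head, levels in grouped.items()})
```

With a real session, `names` comes from `[o.name for o in session.get_outputs()]` and `arrays` from `session.run(None, ...)`.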
Instance Segmentation¶
from autotimm import InstanceSegmentor
model = InstanceSegmentor(backbone="resnet50", num_classes=80)
model.to_onnx("instance.onnx", example_input=torch.randn(1, 3, 800, 800))
# Outputs: cls_l0..cls_l4, reg_l0..reg_l4, ctr_l0..ctr_l4 (15 tensors)
YOLOX¶
from autotimm import YOLOXDetector
model = YOLOXDetector(model_name="yolox-s", num_classes=80)
model.to_onnx("yolox.onnx") # Default input: 640x640
# Outputs: cls_l0..cls_l2, reg_l0..reg_l2 (6 tensors)
Production Deployment¶
ONNX Runtime (Python)¶
import numpy as np
import onnxruntime as ort
# Load model
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
# Run inference
image = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: image})
# Get predictions
logits = outputs[0]
predicted_class = np.argmax(logits, axis=1)[0]
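In production, a real image replaces the random tensor above. A minimal preprocessing sketch, assuming the model was trained with standard ImageNet normalization (verify the mean/std against your actual training transforms):

```python
import numpy as np

# Assumed ImageNet statistics -- check they match your training pipeline.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(image: np.ndarray) -> np.ndarray:
    """HWC uint8 RGB image -> normalized NCHW float32 batch of size 1."""
    x = image.astype(np.float32) / 255.0      # scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD    # channel-wise normalization
    x = x.transpose(2, 0, 1)[np.newaxis]      # HWC -> NCHW, add batch dim
    return np.ascontiguousarray(x)

batch = preprocess(np.zeros((224, 224, 3), dtype=np.uint8))
print(batch.shape, batch.dtype)  # (1, 3, 224, 224) float32
```

Resizing to the export resolution (e.g. with Pillow or OpenCV) should happen before this step.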
GPU Inference¶
# Use CUDA provider for GPU acceleration
session = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Use TensorRT for maximum NVIDIA GPU performance
session = ort.InferenceSession(
    "model.onnx",
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)
Session Optimization¶
# Enable graph optimizations
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=["CPUExecutionProvider"],
)
Using load_onnx()¶
AutoTimm provides a convenience function that validates and loads in one step:
from autotimm import load_onnx
# Validates model integrity, then creates inference session
session = load_onnx("model.onnx")
session = load_onnx("model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
ONNX vs TorchScript¶
| Feature | ONNX | TorchScript |
|---|---|---|
| Cross-platform | Wide compatibility | PyTorch ecosystem only |
| Runtime options | ONNX Runtime, TensorRT, OpenVINO, CoreML | LibTorch only |
| Dynamic batch | Built-in support | Works but less flexible |
| C++ deployment | Via ONNX Runtime C++ API | Via LibTorch |
| Mobile | ONNX Runtime Mobile | PyTorch Mobile |
| Python-free | No Python needed | No Python needed |
| GPU optimization | TensorRT, OpenVINO | Limited |
Recommendation: Use ONNX for cross-platform deployment and hardware-optimized inference. Use TorchScript when staying within the PyTorch ecosystem.
Troubleshooting¶
Common Issues¶
1. ImportError: No module named 'onnx'
   Install the export dependencies: pip install onnx onnxruntime onnxscript
2. RuntimeError: ONNX export failed
   - Try a lower opset version: opset_version=14
   - Ensure the model is in eval mode: model.eval()
   - Disable torch.compile: compile_model=False
3. Outputs don't match original model
# Validate with looser tolerance
is_valid = validate_onnx_export(model, "model.onnx", example_input, rtol=1e-3, atol=1e-3)
4. Dynamic axes warning
The warning about dynamic_axes being deprecated in favor of dynamic_shapes is harmless and can be ignored. AutoTimm handles this internally.
Limitations¶
What Works¶
- Standard feedforward models
- CNN backbones (ResNet, EfficientNet, etc.)
- Vision Transformers (ViT, Swin, DeiT)
- Detection models (FCOS, YOLOX)
- Segmentation models (DeepLabV3+, FCN)
- Dynamic batch sizes
- Different input sizes
What Doesn't Work¶
- Nested list/dict outputs (detection outputs are auto-flattened)
- Training-specific features (optimizers, schedulers)
- Some custom Python operations without ONNX equivalents
- Mask head of InstanceSegmentor (detection head only)
Examples¶
See complete working examples in the repository:
examples/deployment/export_to_onnx.py - Comprehensive ONNX export examples
TensorRT Deployment¶
For maximum inference performance on NVIDIA GPUs, convert your ONNX model to a TensorRT engine. TensorRT applies layer fusion, kernel auto-tuning, and precision calibration to produce highly optimized inference plans.
Prerequisites¶
TensorRT requires an NVIDIA GPU and CUDA toolkit installed on the system.
Convert ONNX to TensorRT¶
import tensorrt as trt
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
# Parse the ONNX model
with open("model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parsing failed")
# Build the engine
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1 GB
# Optional: enable FP16 for ~2x speedup
# config.set_flag(trt.BuilderFlag.FP16)
engine_bytes = builder.build_serialized_network(network, config)
# Save the engine
with open("model.engine", "wb") as f:
    f.write(engine_bytes)
Run Inference with TensorRT¶
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
# Load engine
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open("model.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
# Allocate buffers
input_shape = engine.get_tensor_shape(engine.get_tensor_name(0))
output_shape = engine.get_tensor_shape(engine.get_tensor_name(1))
h_input = np.random.randn(*input_shape).astype(np.float32)
h_output = np.empty(output_shape, dtype=np.float32)
d_input = cuda.mem_alloc(h_input.nbytes)
d_output = cuda.mem_alloc(h_output.nbytes)
# Transfer input to GPU, run, transfer output back on one CUDA stream
stream = cuda.Stream()
cuda.memcpy_htod_async(d_input, h_input, stream)
context.set_tensor_address(engine.get_tensor_name(0), int(d_input))
context.set_tensor_address(engine.get_tensor_name(1), int(d_output))
context.execute_async_v3(stream_handle=stream.handle)
cuda.memcpy_dtoh_async(h_output, d_output, stream)
stream.synchronize()  # wait for inference and copies to finish
predicted_class = np.argmax(h_output, axis=1)[0]
print(f"Predicted class: {predicted_class}")
End-to-End: Checkpoint → ONNX → TensorRT¶
# Step 1: Export checkpoint to ONNX
python -m autotimm.export.export_onnx \
--checkpoint best.ckpt \
--output model.onnx \
--task-class ImageClassifier
# Step 2: Convert ONNX to TensorRT (via trtexec CLI)
trtexec --onnx=model.onnx --saveEngine=model.engine --fp16
trtexec ships with the TensorRT installation and is the simplest way to convert without writing Python.
Performance Comparison¶
| Runtime | Latency (ResNet-50, batch=1) | Notes |
|---|---|---|
| PyTorch (CPU) | ~30 ms | Baseline |
| ONNX Runtime (CPU) | ~15 ms | 2x faster |
| ONNX Runtime (CUDA) | ~5 ms | GPU acceleration |
| TensorRT (FP32) | ~2 ms | Optimized GPU kernels |
| TensorRT (FP16) | ~1 ms | Half precision |
Approximate values on an NVIDIA A100. Actual results depend on hardware and model architecture.
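To reproduce latency figures on your own hardware, a small timing helper works with any runtime (a sketch; the warmup and iteration counts are arbitrary choices):

```python
import time

def measure_latency_ms(run_fn, n_warmup: int = 5, n_iters: int = 50) -> float:
    """Average wall-clock latency of run_fn in milliseconds."""
    for _ in range(n_warmup):  # warm up caches and lazy initialization
        run_fn()
    start = time.perf_counter()
    for _ in range(n_iters):
        run_fn()
    return (time.perf_counter() - start) / n_iters * 1e3

# Usage with an ONNX Runtime session (names from the examples above):
# x = np.random.randn(1, 3, 224, 224).astype(np.float32)
# print(f"{measure_latency_ms(lambda: session.run(None, {input_name: x})):.2f} ms")
```

Warmup matters especially for GPU runtimes, where the first calls include kernel compilation and memory allocation.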
See Also¶
- TorchScript Export - TorchScript export guide
- Production Deployment - Complete deployment guide
- Model Export Guide - Overview of all export options
- API Reference - Complete API documentation