
ONNX Export

Export trained AutoTimm models to ONNX format for cross-platform deployment with ONNX Runtime, TensorRT, OpenVINO, CoreML, and more.

Overview

ONNX (Open Neural Network Exchange) allows you to:

  • Deploy cross-platform - Run models on any platform with an ONNX runtime
  • Use multiple runtimes - ONNX Runtime, TensorRT, OpenVINO, CoreML
  • Dynamic batch sizes - Batch dimension is dynamic by default
  • Hardware acceleration - GPU, CPU, and specialized accelerators
  • Self-contained files - Single .onnx file for deployment

Installation

pip install onnx onnxruntime onnxscript

# For GPU inference
pip install onnxruntime-gpu

# Or install as part of AutoTimm
pip install autotimm[onnx]

Quick Start

Basic Export

from autotimm import ImageClassifier, export_to_onnx
import torch

# Load trained model
model = ImageClassifier.load_from_checkpoint("model.ckpt", compile_model=False)

# Export to ONNX
example_input = torch.randn(1, 3, 224, 224)
export_to_onnx(
    model,
    "model.onnx",
    example_input=example_input
)

Convenience Method

# Even simpler - one line export
model = ImageClassifier(backbone="resnet50", num_classes=10)
model.to_onnx("model.onnx")

Load and Use

import numpy as np
import onnxruntime as ort

# No AutoTimm dependency needed!
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name

image = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: image})

CLI Export (Quickest)

Export directly from the command line without writing Python:

python -m autotimm.export.export_onnx \
    --checkpoint logs/run_1/checkpoints/best.ckpt \
    --output model.onnx \
    --task-class ImageClassifier \
    --input-size 224

Additional options:

# With graph simplification and custom opset
python -m autotimm.export.export_onnx \
    --checkpoint checkpoint.ckpt \
    --output model.onnx \
    --task-class ObjectDetector \
    --opset-version 17 \
    --simplify

# With explicit hparams file
python -m autotimm.export.export_onnx \
    --checkpoint checkpoint.ckpt \
    --output model.onnx \
    --hparams-yaml logs/run_1/hparams.yaml

The CLI auto-detects image_size from model hparams when available, so --input-size is optional if the checkpoint was saved with hparams.

Export Methods

AutoTimm provides multiple ways to export models to ONNX from Python:

Method 1: export_to_onnx()

Full control over the export process:

from autotimm.export import export_to_onnx

path = export_to_onnx(
    model=model,
    save_path="model.onnx",
    example_input=torch.randn(1, 3, 224, 224),
    opset_version=17,     # ONNX opset version
    simplify=False,       # Optional graph simplification
)

Parameters:

  • model: The PyTorch model to export
  • save_path: Output file path (.onnx extension recommended)
  • example_input: Example input tensor (required)
  • opset_version: ONNX opset version (default: 17)
  • dynamic_axes: Dynamic axes config (default: batch dimension dynamic)
  • simplify: Whether to simplify the graph with onnx-simplifier

Method 2: model.to_onnx()

Convenience method on model instances:

# With file save
path = model.to_onnx("model.onnx")

# Without specifying path (uses temp file)
path = model.to_onnx()

# With custom options
path = model.to_onnx(
    "model.onnx",
    example_input=torch.randn(1, 3, 299, 299),
    opset_version=17
)

Method 3: export_checkpoint_to_onnx()

Direct checkpoint export:

from autotimm import export_checkpoint_to_onnx, ImageClassifier

path = export_checkpoint_to_onnx(
    checkpoint_path="model.ckpt",
    save_path="model.onnx",
    model_class=ImageClassifier,
    example_input=torch.randn(1, 3, 224, 224),
)

Dynamic Batch Size

By default, the batch dimension is dynamic, allowing inference with any batch size:

# Export with batch size 1
export_to_onnx(model, "model.onnx", torch.randn(1, 3, 224, 224))

# Inference with any batch size
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name

# Works with batch size 1, 4, 8, etc.
for batch_size in [1, 4, 8]:
    input_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
    outputs = session.run(None, {input_name: input_data})
    print(f"Batch {batch_size}: output shape {outputs[0].shape}")
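A common pattern with a dynamic batch dimension is chunking a stream of images into batches before running the session. A minimal sketch (the stacking helper is illustrative, not an AutoTimm API):

```python
import numpy as np

def iter_batches(images, batch_size: int):
    """Yield stacked (N, 3, H, W) float32 arrays; the last batch may be
    smaller, which is fine because the exported batch dimension is dynamic."""
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        yield np.stack(chunk).astype(np.float32)

images = [np.zeros((3, 224, 224)) for _ in range(10)]
for batch in iter_batches(images, 4):
    # outputs = session.run(None, {input_name: batch})
    print(batch.shape)  # (4, 3, 224, 224), (4, 3, 224, 224), (2, 3, 224, 224)
```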

Custom Dynamic Axes

# Dynamic batch and spatial dimensions
dynamic_axes = {
    "input": {0: "batch_size", 2: "height", 3: "width"},
    "output": {0: "batch_size"},
}

export_to_onnx(
    model, "model.onnx", example_input,
    dynamic_axes=dynamic_axes
)

Validation

Verify the exported model matches the original:

from autotimm.export import validate_onnx_export

# Export model
export_to_onnx(model, "model.onnx", example_input)

# Validate outputs match
is_valid = validate_onnx_export(
    original_model=model,
    onnx_path="model.onnx",
    example_input=example_input,
    rtol=1e-5,
    atol=1e-5
)

if is_valid:
    print("Export verified successfully")
else:
    print("Export validation failed")
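Conceptually, validation runs the same input through the PyTorch model and the ONNX Runtime session and compares outputs within tolerance. A minimal sketch of that comparison, assuming both outputs are already available as NumPy arrays:

```python
import numpy as np

def outputs_match(torch_out: np.ndarray, onnx_out: np.ndarray,
                  rtol: float = 1e-5, atol: float = 1e-5) -> bool:
    """Elementwise check mirroring the rtol/atol semantics above:
    |a - b| <= atol + rtol * |b| must hold for every element."""
    return bool(np.allclose(torch_out, onnx_out, rtol=rtol, atol=atol))

a = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Tiny numerical differences (e.g. from kernel fusion) pass...
print(outputs_match(a, a + 1e-7))  # True

# ...but real divergence fails.
print(outputs_match(a, a + 0.1))   # False
```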

Supported Models

All AutoTimm task models support ONNX export:

Classification

from autotimm import ImageClassifier

model = ImageClassifier(backbone="resnet50", num_classes=1000)
model.to_onnx("classifier.onnx")

Semantic Segmentation

from autotimm import SemanticSegmentor

model = SemanticSegmentor(backbone="resnet50", num_classes=19)
model.to_onnx("segmentor.onnx")

Object Detection

Detection models automatically flatten their list outputs into named tensors for ONNX compatibility:

from autotimm import ObjectDetector

model = ObjectDetector(backbone="resnet50", num_classes=80)
model.to_onnx("detector.onnx", example_input=torch.randn(1, 3, 640, 640))

# Outputs: cls_l0..cls_l4, reg_l0..reg_l4, ctr_l0..ctr_l4 (15 tensors)
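When consuming these flattened tensors, it can help to regroup them by head and pyramid level. A hedged sketch (the `cls_l*`/`reg_l*`/`ctr_l*` naming follows the comment above; the regrouping logic is illustrative, not an AutoTimm API):

```python
from collections import defaultdict

def group_detection_outputs(names, tensors):
    """Regroup flattened ONNX outputs like 'cls_l0'...'ctr_l4' into
    {'cls': [level0, ...], 'reg': [...], 'ctr': [...]}, ordered by level."""
    grouped = defaultdict(dict)
    for name, tensor in zip(names, tensors):
        head, level = name.rsplit("_l", 1)   # 'cls_l0' -> ('cls', '0')
        grouped[head][int(level)] = tensor
    return {head: [levels[i] for i in sorted(levels)]
            for head, levels in grouped.items()}

# names would come from session.get_outputs(), tensors from session.run(...)
names = [f"{h}_l{i}" for h in ("cls", "reg", "ctr") for i in range(5)]
grouped = group_detection_outputs(names, list(range(15)))
print(list(grouped))   # ['cls', 'reg', 'ctr']
print(grouped["cls"])  # [0, 1, 2, 3, 4]
```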

Instance Segmentation

from autotimm import InstanceSegmentor

model = InstanceSegmentor(backbone="resnet50", num_classes=80)
model.to_onnx("instance.onnx", example_input=torch.randn(1, 3, 800, 800))

# Outputs: cls_l0..cls_l4, reg_l0..reg_l4, ctr_l0..ctr_l4 (15 tensors)

YOLOX

from autotimm import YOLOXDetector

model = YOLOXDetector(model_name="yolox-s", num_classes=80)
model.to_onnx("yolox.onnx")  # Default input: 640x640

# Outputs: cls_l0..cls_l2, reg_l0..reg_l2 (6 tensors)

Production Deployment

ONNX Runtime (Python)

import numpy as np
import onnxruntime as ort

# Load model
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name

# Run inference
image = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: image})

# Get predictions
logits = outputs[0]
predicted_class = np.argmax(logits, axis=1)[0]
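The exported classifier returns raw logits; if you need probabilities, apply softmax on the NumPy side. A small sketch, numerically stabilized by subtracting the row max:

```python
import numpy as np

def softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Stable softmax: shift by the max so exp() cannot overflow."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

probs = softmax(np.array([[2.0, 1.0, 0.1]]))
assert np.isclose(probs.sum(), 1.0)      # rows sum to 1
print(int(probs.argmax(axis=1)[0]))      # 0 (largest logit wins)
```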

GPU Inference

# Use CUDA provider for GPU acceleration
session = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

# Use TensorRT for maximum NVIDIA GPU performance
session = ort.InferenceSession(
    "model.onnx",
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
)

Session Optimization

# Enable graph optimizations
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=["CPUExecutionProvider"]
)

Using load_onnx()

AutoTimm provides a convenience function that validates and loads in one step:

from autotimm import load_onnx

# Validates model integrity, then creates inference session
session = load_onnx("model.onnx")
session = load_onnx("model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

ONNX vs TorchScript

| Feature          | ONNX                                     | TorchScript            |
| ---------------- | ---------------------------------------- | ---------------------- |
| Cross-platform   | Wide compatibility                       | PyTorch ecosystem only |
| Runtime options  | ONNX Runtime, TensorRT, OpenVINO, CoreML | LibTorch only          |
| Dynamic batch    | Built-in support                         | Works but less flexible |
| C++ deployment   | Via ONNX Runtime C++ API                 | Via LibTorch           |
| Mobile           | ONNX Runtime Mobile                      | PyTorch Mobile         |
| Python-free      | No Python needed                         | No Python needed       |
| GPU optimization | TensorRT, OpenVINO                       | Limited                |

Recommendation: Use ONNX for cross-platform deployment and hardware-optimized inference. Use TorchScript when staying within the PyTorch ecosystem.

Troubleshooting

Common Issues

1. ImportError: No module named 'onnx'

pip install onnx onnxruntime onnxscript

2. RuntimeError: ONNX export failed

  • Try a lower opset version: opset_version=14
  • Ensure model is in eval mode: model.eval()
  • Disable torch.compile: compile_model=False

3. Outputs don't match original model

# Validate with looser tolerance
is_valid = validate_onnx_export(model, "model.onnx", example_input, rtol=1e-3, atol=1e-3)

4. Dynamic axes warning

The warning about dynamic_axes being deprecated in favor of dynamic_shapes is harmless and can be ignored. AutoTimm handles this internally.

Limitations

What Works

  • Standard feedforward models
  • CNN backbones (ResNet, EfficientNet, etc.)
  • Vision Transformers (ViT, Swin, DeiT)
  • Detection models (FCOS, YOLOX)
  • Segmentation models (DeepLabV3+, FCN)
  • Dynamic batch sizes
  • Different input sizes

What Doesn't Work

  • Nested list/dict outputs (detection outputs are auto-flattened)
  • Training-specific features (optimizers, schedulers)
  • Some custom Python operations without ONNX equivalents
  • Mask head of InstanceSegmentor (detection head only)

Examples

See complete working examples in the repository:

  • examples/deployment/export_to_onnx.py - Comprehensive ONNX export examples

TensorRT Deployment

For maximum inference performance on NVIDIA GPUs, convert your ONNX model to a TensorRT engine. TensorRT applies layer fusion, kernel auto-tuning, and precision calibration to produce highly optimised inference plans.

Prerequisites

pip install tensorrt pycuda numpy

TensorRT requires an NVIDIA GPU and CUDA toolkit installed on the system.

Convert ONNX to TensorRT

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

# Parse the ONNX model
with open("model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parsing failed")

# Build the engine
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB

# Optional: enable FP16 for ~2x speedup
# config.set_flag(trt.BuilderFlag.FP16)

engine_bytes = builder.build_serialized_network(network, config)

# Save the engine
with open("model.engine", "wb") as f:
    f.write(engine_bytes)

Run Inference with TensorRT

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

# Load engine
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open("model.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()

# Allocate buffers
input_shape = engine.get_tensor_shape(engine.get_tensor_name(0))
output_shape = engine.get_tensor_shape(engine.get_tensor_name(1))

h_input = np.random.randn(*input_shape).astype(np.float32)
h_output = np.empty(output_shape, dtype=np.float32)

d_input = cuda.mem_alloc(h_input.nbytes)
d_output = cuda.mem_alloc(h_output.nbytes)

# Transfer input to GPU, run, and transfer output back on a single stream,
# then synchronize so the host buffer is valid before reading it
stream = cuda.Stream()
cuda.memcpy_htod_async(d_input, h_input, stream)
context.set_tensor_address(engine.get_tensor_name(0), int(d_input))
context.set_tensor_address(engine.get_tensor_name(1), int(d_output))
context.execute_async_v3(stream_handle=stream.handle)
cuda.memcpy_dtoh_async(h_output, d_output, stream)
stream.synchronize()

predicted_class = np.argmax(h_output, axis=1)[0]
print(f"Predicted class: {predicted_class}")

End-to-End: Checkpoint → ONNX → TensorRT

# Step 1: Export checkpoint to ONNX
python -m autotimm.export.export_onnx \
    --checkpoint best.ckpt \
    --output model.onnx \
    --task-class ImageClassifier

# Step 2: Convert ONNX to TensorRT (via trtexec CLI)
trtexec --onnx=model.onnx --saveEngine=model.engine --fp16

trtexec ships with the TensorRT installation and is the simplest way to convert without writing Python.

Performance Comparison

| Runtime             | Latency (ResNet-50, batch=1) | Notes                  |
| ------------------- | ---------------------------- | ---------------------- |
| PyTorch (CPU)       | ~30 ms                       | Baseline               |
| ONNX Runtime (CPU)  | ~15 ms                       | 2x faster              |
| ONNX Runtime (CUDA) | ~5 ms                        | GPU acceleration       |
| TensorRT (FP32)     | ~2 ms                        | Optimised GPU kernels  |
| TensorRT (FP16)     | ~1 ms                        | Half precision         |

Approximate values on an NVIDIA A100. Actual results depend on hardware and model architecture.
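To reproduce numbers like these on your own hardware, measure steady-state latency: warm up first, then take the median over many runs. A runtime-agnostic sketch (shown with a stand-in workload; substitute your `session.run(...)` or TensorRT execution call):

```python
import statistics
import time

def measure_latency_ms(infer, warmup: int = 10, runs: int = 100) -> float:
    """Median wall-clock latency in milliseconds after warm-up runs
    (warm-up absorbs one-time costs: lazy allocation, caches, JIT)."""
    for _ in range(warmup):
        infer()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        infer()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)

# Stand-in workload; replace with e.g.
#   lambda: session.run(None, {input_name: image})
latency = measure_latency_ms(lambda: sum(range(10_000)))
print(f"median latency: {latency:.3f} ms")
```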

See Also