Convert custom model to ONNX format
PyTorch
These steps are based on the PyTorch tutorial "Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime".
Step 0. Requirements:
- PyTorch
- ONNX
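Step 1. Load the PyTorch model:

```python
import torch
import torch.onnx
from torch import nn

model = ...  # your model instantiation
# Assumes the checkpoint stores the weights under the 'state_dict' key.
model.load_state_dict(torch.load(YOUR_MODEL_CHECKPOINT_PATH, map_location='cpu')['state_dict'])
model.eval()  # put dropout/batch-norm layers in inference mode before exporting
```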
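Indexing with ['state_dict'] assumes the checkpoint is a dictionary that wraps the weights under that key (PyTorch Lightning checkpoints, for example, are stored this way). If your checkpoint holds the bare state dict instead, that indexing fails. A minimal sketch that accepts both layouts is below; TinyNet and load_weights are hypothetical names used only for illustration.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for your own model class."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


def load_weights(model: nn.Module, checkpoint_path: str) -> nn.Module:
    """Load weights from a checkpoint that is either a bare state dict
    or a dict wrapping it under the 'state_dict' key (hypothetical helper)."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode, as required before export
    return model


# Usage: save a wrapped checkpoint, then load it into a fresh model.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    torch.save({"state_dict": TinyNet().state_dict()}, path)
    model = load_weights(TinyNet(), path)
```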
Step 2. Create a data sample with batch_size=1 and run a forward pass of your model:

```python
x = torch.rand(1, INP_CHANNEL, INP_HEIGHT, INP_WIDTH)  # e.g. torch.rand([1, 3, 256, 256])
_ = model(x)  # sanity check that the model accepts this input shape
```
Step 3a. Call the export function with a static batch_size=1:

```python
torch.onnx.export(
    model,
    x,                      # model input
    'model.onnx',           # where to save the model
    export_params=True,
    opset_version=15,
    input_names=['input'],
    output_names=['output'],
    do_constant_folding=False,
)
```
Step 3b. Call the export function with a dynamic batch_size:

```python
torch.onnx.export(
    model,
    x,                      # model input
    'model.onnx',           # where to save the model
    export_params=True,
    opset_version=15,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'},   # variable-length axes
                  'output': {0: 'batch_size'}},
)
```
TensorFlow/Keras
These steps are based on the tensorflow-onnx repository. The instructions apply to models stored in the SavedModel format; for other formats, follow the tensorflow-onnx instructions.
Requirements:
- TensorFlow
- ONNX
- tf2onnx
Then call the converter script, passing the path to the SavedModel directory:

```shell
python -m tf2onnx.convert --saved-model YOUR_MODEL_CHECKPOINT_PATH --output model.onnx --opset 15
```
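The --saved-model option takes a SavedModel directory. If your model is not stored that way yet, the sketch below shows one way to produce such a directory with tf.saved_model.save and then run the same command from Python. The Scale module and the saved_model_dir path are hypothetical, and the sketch assumes tensorflow and tf2onnx are installed in the current interpreter.

```python
import os
import subprocess
import sys

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy model: multiplies its input by a trainable weight."""

    def __init__(self) -> None:
        super().__init__()
        self.w = tf.Variable(2.0)

    # A fixed input_signature lets tf.saved_model.save export this
    # function as the default serving signature.
    @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
    def __call__(self, x):
        return x * self.w


# Write the model in SavedModel format (a directory, not a single file).
tf.saved_model.save(Scale(), "saved_model_dir")

# Same CLI call as above, with the directory passed to --saved-model.
subprocess.run(
    [sys.executable, "-m", "tf2onnx.convert",
     "--saved-model", "saved_model_dir",
     "--output", "model.onnx", "--opset", "15"],
    check=True,
)
print("Wrote model.onnx,", os.path.getsize("model.onnx"), "bytes")
```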
Update ONNX model to support dynamic batch size
To make a model support a dynamic batch size, you need to update the model.onnx file. You can do this manually using this script. Note that the script is not perfect and may not work for all models.