Making PyTorch AI Models Portable Using ONNX

8 Sep 2020 | CPOL | 5 min read
In this article, I give a brief overview of PyTorch for those looking for a deep learning framework for building and training neural networks.
I then show how to convert PyTorch models to the ONNX format using the converter that ships with PyTorch itself. I also show the best practice of adding metadata to the exported model.

In this article in our series about using portable neural networks in 2020, you’ll learn how to convert a PyTorch model to the portable ONNX format.

Since ONNX is not a framework for building and training models, I will start with a brief introduction to PyTorch. This will be useful for engineers who are starting from scratch and are considering PyTorch as a framework to build and train their models.

A Brief Introduction to PyTorch

PyTorch was released in 2016 and was developed by Facebook’s AI Research lab (FAIR). It has become the preferred framework for researchers experimenting with natural language processing and computer vision. This is interesting because TensorFlow is more widely used in production environments. This dichotomy between what is preferred in the research laboratory versus production really emphasizes the value of a standard like ONNX, which provides a common format for models and a runtime that can be used from all the popular programming languages.

As an example, let’s suppose an organization does not want to have every possible framework in its production environment and instead wants to standardize on one. Without ONNX, the model would need to be reimplemented in the framework chosen for production and deployed. This is a non-trivial engineering task. Using ONNX, the PyTorch model can be exported with just a few lines of code and consumed from any language. Only the ONNX Runtime is needed in production.

Importing the Converter

The maintainers of PyTorch have integrated the ONNX converter into PyTorch itself. You do not need to install any additional packages. Once PyTorch is installed, you can access the PyTorch to ONNX converter by including the following import in your modules:

import torch

Once the torch module is imported, you can access the conversion function as follows:

torch.onnx.export

Hopefully, this is a practice that other frameworks will adopt. Packaging and versioning the converter with the framework itself makes for one less package to install and also prevents version mismatches between the framework and converter.

A Quick Look at a Model

Before converting a PyTorch model, we need to look at the code that creates the model in order to determine the shape of the input. The code below creates a PyTorch model that predicts the digits found in the MNIST dataset. A detailed description of the model layers is beyond the scope of this article, but we do need to note the shape of the input. Here it is 784. More specifically, this code creates a model whose input is a flattened tensor: an array of 784 floats.

What is the significance of 784? Each image in the MNIST dataset is 28 × 28 pixels, and 28 × 28 = 784. So, once flattened, our input is 784 floats where each float represents a shade of gray.

The bottom line: this model expects 784 floats from a single image. It does not expect a multidimensional array and it does not expect a batch of images. It makes only one prediction at a time. This is an important fact when converting the model to ONNX.

from torch import nn


def build_model():
    # Layer details for the neural network
    input_size = 784
    hidden_sizes = [128, 64]
    output_size = 10

    # Build a feed-forward network
    model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                          nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                          nn.Linear(hidden_sizes[1], output_size))
    return model
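To make the 784-float input concrete, here is a minimal sketch in plain Python of flattening one 28 × 28 grayscale image. The `flatten_image` helper is a hypothetical name, and dividing by 255 to scale each pixel to the range 0 to 1 is an assumption of mine; the code sample for this article may prepare its data differently.

```python
def flatten_image(pixels):
    """Flatten a 28 x 28 grid of 0-255 gray values into a list of 784 floats.

    Assumption: pixels are scaled to [0, 1] by dividing by 255.
    """
    if len(pixels) != 28 or any(len(row) != 28 for row in pixels):
        raise ValueError("expected a 28 x 28 image")
    # Row-major order: all of row 0, then all of row 1, and so on.
    return [value / 255.0 for row in pixels for value in row]


# A blank image with a single white pixel at row 0, column 1.
image = [[0] * 28 for _ in range(28)]
image[0][1] = 255

flat = flatten_image(image)
print(len(flat))  # 784
print(flat[1])    # 1.0
```

Wrapping a list like `flat` in one outer list gives the (1, 784) shape that the conversion step relies on.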

Converting PyTorch Models to ONNX

The function below shows how to use the torch.onnx.export function. There are a few tricks to using this function correctly. The first and most important trick is to set up your sample input correctly. The sample input, passed as the second argument, determines the shape of the input that the ONNX model will expect. The torch.onnx.export function will accept whatever you give it, as long as it is a tensor that the model can process, and the conversion will work without error. However, if the sample input is of the wrong shape, you will get an error when you try to run the ONNX model from ONNX Runtime.

import datetime

import onnx
import torch


def export_to_onnx(model):
    # ONNX_MODEL_FILE is a module-level constant holding the output file path.
    sample_input = torch.randn(1, 784)
    torch.onnx.export(model,            # model being run
                      sample_input,     # model input (or a tuple for multiple inputs)
                      ONNX_MODEL_FILE,  # where to save the model
                      input_names=['input'],    # the model's input names
                      output_names=['output'])  # the model's output names

    # Set metadata on the model.
    onnx_model = onnx.load(ONNX_MODEL_FILE)
    meta = onnx_model.metadata_props.add()
    meta.key = "creation_date"
    meta.value = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
    meta = onnx_model.metadata_props.add()
    meta.key = "author"
    meta.value = 'keithpij'
    onnx_model.doc_string = 'MNIST model converted from PyTorch'
    onnx_model.model_version = 3  # This must be an integer or long.

    # Save the model again so the metadata is written to the file.
    onnx.save(onnx_model, ONNX_MODEL_FILE)

The sample input above has shape (1, 784): a single flattened image. A sample input of shape (100, 784) would be fine only if the original PyTorch model were designed to accept a batch of 100 images. However, as previously stated, our model was designed to accept only one image at a time when making predictions. If you export the model with a sample input of the wrong shape, you’ll get an error when you run the model.
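One way to catch a shape mismatch before it reaches the runtime is to compare the data you are about to send against the shape used at export time. The sketch below is plain Python and does not call ONNX Runtime. `EXPECTED_SHAPE`, `shape_of`, and `validate_shape` are hypothetical names of mine, and (1, 784) is taken from the sample input in the export function.

```python
EXPECTED_SHAPE = (1, 784)  # hypothetical constant: the shape used at export time


def shape_of(nested):
    """Return the shape of a rectangular nested list as a tuple of ints."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def validate_shape(batch, expected=EXPECTED_SHAPE):
    """Raise ValueError if the input does not match the exported shape."""
    actual = shape_of(batch)
    if actual != expected:
        raise ValueError(f"expected input of shape {expected}, got {actual}")
    return actual


one_image = [[0.0] * 784]                             # shape (1, 784)
hundred_images = [[0.0] * 784 for _ in range(100)]    # shape (100, 784)

validate_shape(one_image)  # passes
try:
    validate_shape(hundred_images)
except ValueError as error:
    print(error)
```

A check like this only guards the caller's side. The exported model itself still expects whatever shape the sample input had.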

Adding metadata to the model is a best practice. As the data you use to train your model evolves, so will your model. It is therefore a good idea to add metadata to your model so that you can distinguish it from previous models. The example above adds a brief description of the model to the doc_string property and sets the version. creation_date and author are custom properties added to the metadata_props property bag. You are free to create as many custom properties as you like using this property bag. Unfortunately, the model_version property requires an integer or long, so you will not be able to version the model the way you version your services, using major.minor.revision syntax. Additionally, the export function saves the model to a file automatically, so to add this metadata you need to reopen the file and resave it.
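If you still want major.minor.revision semantics, one possible workaround is to pack the three parts into a single integer before assigning it to model_version. This is only a sketch of a convention I am assuming here, not something ONNX defines. `pack_version` and `unpack_version` are hypothetical helpers, and the limit of 999 per part is an arbitrary choice.

```python
def pack_version(major, minor, revision):
    """Pack major.minor.revision into one integer, e.g. 1.2.3 -> 1002003."""
    for part in (major, minor, revision):
        if not 0 <= part < 1000:
            raise ValueError("each part must be between 0 and 999")
    return major * 1_000_000 + minor * 1_000 + revision


def unpack_version(packed):
    """Reverse pack_version, returning (major, minor, revision)."""
    major, rest = divmod(packed, 1_000_000)
    minor, revision = divmod(rest, 1_000)
    return major, minor, revision


packed = pack_version(1, 2, 3)
print(packed)                  # 1002003
print(unpack_version(packed))  # (1, 2, 3)
```

The packed value is an ordinary integer, so it can be assigned to model_version in the same way as the literal 3 in the export function. Anyone reading the model needs to know the convention in order to decode it, so it is worth recording in doc_string or in a custom metadata property.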

Summary and Next Steps

In this article, I provided a brief overview of PyTorch for those looking for a deep learning framework for building and training neural networks. I then showed how to convert PyTorch models to the ONNX format using the conversion tool which is already a part of PyTorch itself. I also showed the best practice of adding metadata to the exported model.

Since the purpose of this article was to demonstrate converting PyTorch models to the ONNX format, I did not go into detail on building and training PyTorch models. The code sample for this post contains code that explores PyTorch itself. The module is a full end-to-end demo that shows how to load the data, explore the images, and train the model.

Next, we’ll look at converting a TensorFlow model to ONNX.


This article is part of the series 'CodeProject’s Next Top Model'.


This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)

Written By
Technical Lead Bank of New York Mellon
United States
Keith is a sojourner in the software industry. He has over 30 years of experience building and bringing applications to market. He has worked for startups and large enterprises in roles ranging from tech lead to business development manager. He is currently a senior engineer on BNY Mellon's Distribution Analytics team where he is building data pipelines from on-premise data sources to the cloud.
