ONNX Runtime Training Technical Deep Dive

Published May 19 2020 08:00 AM

Author: Sherlock Huang, AI Frameworks, Microsoft

This post is co-authored by Cheng Tang, Jesse Benson, Kaarthik Sivashanmugam and Alexey Svyatkovskiy.


Today we announced the preview of a new training feature in ONNX Runtime (ORT). This blog post explains how we have been using it to accelerate training for large transformer models. ONNX Runtime Training is integrated with PyTorch, so existing PyTorch training code can be accelerated directly.

In this post, we describe some of the key aspects of ORT's design and implementation that enable these distributed training performance improvements. We then use BERT-L pre-training as a benchmark to illustrate ORT training performance. Finally, we present a case study of training a GPT-2 model for the code autocompletion feature in Visual Studio IntelliCode.


Design and Implementation

ONNX Runtime Training is built on the same open-source code as the popular inference engine for ONNX models. Figure 1 shows the high-level architecture of ONNX Runtime's ecosystem. ORT is a common runtime backend that supports multiple framework frontends, such as PyTorch and TensorFlow/Keras. It uses the Execution Provider interface to perform computation on different hardware. This enables us to build hardware-agnostic, graph-level optimizations that are extensible across platforms, as well as hardware-specific optimizations targeting platforms like NVIDIA GPUs. We have also implemented additional optimizations, outlined below, to expedite training for large transformer models.


Figure 1. ONNX Runtime High Level Architecture
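The following minimal Python sketch makes the Execution Provider idea more concrete with priority-ordered node assignment. Each graph node goes to the first provider in the list that can run its operator, and in this example a CPU provider acts as the fallback. ToyExecutionProvider and assign_nodes are hypothetical names used only for illustration. They are not ORT's API, and ORT's real graph partitioning is considerably more involved.

```python
from dataclasses import dataclass, field


@dataclass
class ToyExecutionProvider:
    """Hypothetical stand-in for an execution provider: a name plus the ops it can run."""
    name: str
    supported_ops: set = field(default_factory=set)

    def can_run(self, op_type: str) -> bool:
        return op_type in self.supported_ops


def assign_nodes(nodes, providers):
    """Map each (node_name, op_type) to the first provider, in priority order, that supports it."""
    assignment = {}
    for node_name, op_type in nodes:
        for provider in providers:
            if provider.can_run(op_type):
                assignment[node_name] = provider.name
                break
        else:
            # No provider in the list can run this operator.
            raise ValueError(f"no provider can run {op_type}")
    return assignment


if __name__ == "__main__":
    gpu = ToyExecutionProvider("CUDA", {"MatMul", "Gelu", "LayerNormalization"})
    cpu = ToyExecutionProvider("CPU", {"MatMul", "Gelu", "LayerNormalization", "Shape"})
    graph = [("fc1", "MatMul"), ("act", "Gelu"), ("shape_of", "Shape")]
    # GPU first, CPU as the fallback: fc1 and act land on CUDA, shape_of on CPU.
    print(assign_nodes(graph, [gpu, cpu]))
```

Separating "which hardware can run this op" from the graph itself is what allows graph-level optimizations to stay hardware-agnostic.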

Static Graph Optimizations

Machine learning models are commonly abstracted as computation graphs. The computation graph used by a deep learning framework can be either static or dynamic. In the current implementation, ORT has a view of the entire static computation graph. This makes it possible to apply many common graph optimization techniques, such as constant folding, redundant operation elimination, and operator fusion. These optimizations are first applied to the forward computation graph, before the automatic differentiation engine builds the backward graph. Because ORT has global knowledge of data dependencies, it builds only the minimal gradient graph needed for the targeted weights. Consequently, activation tensors that are not needed for backward computation are automatically dropped after use. A minimal training graph ensures that only essential computation is performed and that memory consumption is minimized.
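A simple reachability check illustrates the idea of building only the minimal gradient graph. The sketch below assumes the forward graph is given as (producer, consumer) edges. The hypothetical helper nodes_needing_gradients keeps only the nodes that are both downstream of a trainable weight and upstream of the loss. This is a simplification for illustration, not ORT's gradient builder.

```python
from collections import defaultdict


def nodes_needing_gradients(edges, weights, loss):
    """Return the nodes that lie on some path from a trainable weight to the loss.

    edges: (producer, consumer) pairs of the forward graph. Only these nodes
    need gradient computation; the rest can be left out of the backward graph.
    """
    downstream = defaultdict(list)
    upstream = defaultdict(list)
    for src, dst in edges:
        downstream[src].append(dst)
        upstream[dst].append(src)

    def reachable(starts, adjacency):
        # Iterative depth-first search over the given adjacency lists.
        seen, stack = set(starts), list(starts)
        while stack:
            node = stack.pop()
            for nxt in adjacency[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        return seen

    # Downstream of a trainable weight AND upstream of the loss.
    return reachable(weights, downstream) & reachable([loss], upstream)


if __name__ == "__main__":
    forward_edges = [
        ("x", "embed"), ("E", "embed"),          # embedding with a frozen table E
        ("embed", "matmul1"), ("W1", "matmul1"),
        ("matmul1", "gelu"), ("gelu", "loss"),
        ("gelu", "accuracy"),                    # metric branch, not part of the loss
    ]
    print(sorted(nodes_needing_gradients(forward_edges, ["W1"], "loss")))
```

In this example, only W1, matmul1, gelu and loss remain. The embed node is pruned because its only weight (E) is frozen and its other input is data, and accuracy is pruned because it does not feed the loss.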


Memory Usage Optimizations

Over the last few years, deep learning models have grown rapidly in size, and GPU memory has become a limiting factor for large model training. ORT makes a deliberate effort to preserve and reuse memory wherever possible. For example, ORT reuses the same buffer segments throughout a series of operations, including gradient accumulation, gradient scaling adjustment, allreduce communication, and weight update computation (if the optimizer allows). ORT also tries to perform operations in place when the source tensor is no longer consumed elsewhere in the computation graph. ORT's kernel implementations also try to minimize the use of scratch buffers, for example by avoiding some memory-intensive cuDNN functions and by reusing the output buffer as a scratch buffer where possible. As a result, ORT can train BERT with twice the batch size of PyTorch. This lets us use GPU resources more efficiently, resulting in better performance on the same model and the ability to train larger models.
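The buffer reuse described above can be sketched as a liveness-based planner. Once the last consumer of a tensor has run, its buffer returns to a free pool, and a later tensor of the same size takes it over instead of allocating new memory. The plan_buffers function is a hypothetical toy. It ignores in-place execution, alignment, and size bucketing, and it is not ORT's allocation planner.

```python
def plan_buffers(ops):
    """Toy liveness-based memory planner.

    ops: list of (output_name, output_size, input_names) in execution order.
    Returns (assignment, num_buffers): the buffer id used by each output and
    the number of distinct buffers allocated.
    """
    # Index of the last op that reads each tensor.
    last_use = {}
    for step, (_, _, inputs) in enumerate(ops):
        for name in inputs:
            last_use[name] = step

    free_pool = {}   # size -> stack of free buffer ids
    assignment = {}  # tensor name -> buffer id
    sizes = {}       # tensor name -> size
    num_buffers = 0
    for step, (out, size, inputs) in enumerate(ops):
        pool = free_pool.setdefault(size, [])
        if pool:
            buf = pool.pop()          # reuse a released buffer of the same size
        else:
            buf = num_buffers         # otherwise allocate a new one
            num_buffers += 1
        assignment[out] = buf
        sizes[out] = size
        # Release inputs whose last consumer is this op. Outputs are allocated
        # first, so this sketch never runs an op in place.
        for name in dict.fromkeys(inputs):
            if last_use[name] == step and name in assignment:
                free_pool.setdefault(sizes[name], []).append(assignment[name])
    return assignment, num_buffers


if __name__ == "__main__":
    chain = [("a", 4, []), ("b", 4, ["a"]), ("c", 4, ["b"]), ("d", 4, ["c"])]
    print(plan_buffers(chain))
```

For this four-op chain, two buffers alternate between the tensors instead of four separate allocations. Adding in-place execution, as described in the text, would reduce the count further.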


ZeRO Stage 1 Integration

Zero Redundancy Optimizer (ZeRO) is a memory optimization technique from Microsoft Research. ZeRO reduces GPU memory consumption by eliminating duplicated states across workers during distributed training. ZeRO has three main optimization stages, and ONNX Runtime currently implements Stage 1. ZeRO Stage 1, known as optimizer state partitioning, allows ORT to shard the optimizer states, including the 1st and 2nd order moments (and the fp32 copy of the weights in mixed precision mode), across multiple workers with no extra communication overhead. With ZeRO, ORT can further increase the batch size or train a larger model. In BERT-L pre-training on a 32GB V100, ZeRO allows the batch size to grow further, from 148 to 168 for phase 1 and from 23 to 27 for phase 2. Because the persistent model state is distributed across multiple workers, we have also introduced distributed checkpointing. ZeRO can be enabled with a config flag.
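The sketch below shows why sharding optimizer states saves memory. It uses the per-parameter accounting from the ZeRO paper for mixed precision training with an Adam-style optimizer. The fp16 weights and gradients take 2 bytes each and are replicated on every worker. The optimizer state takes 12 bytes (fp32 weights plus two moments), and Stage 1 divides it across workers. Several things here are assumptions for illustration: the function names, the even contiguous sharding scheme, and the rough parameter count of 340M for BERT-Large. Activations and workspace memory are ignored, so this does not model ORT's actual footprint.

```python
def zero_stage1_bytes_per_gpu(num_params: int, world_size: int) -> float:
    """Approximate per-GPU bytes for model and optimizer state under ZeRO Stage 1.

    fp16 weights (2 bytes/param) and fp16 gradients (2 bytes/param) stay
    replicated; the 12 bytes/param of optimizer state (fp32 master weights,
    1st moment, 2nd moment) are partitioned across world_size workers.
    """
    replicated = (2 + 2) * num_params
    sharded_optimizer_state = 12 * num_params / world_size
    return replicated + sharded_optimizer_state


def shard_range(num_elements: int, rank: int, world_size: int) -> tuple:
    """Contiguous [start, end) slice of a flat optimizer-state buffer owned by `rank`."""
    base, remainder = divmod(num_elements, world_size)
    start = rank * base + min(rank, remainder)
    end = start + base + (1 if rank < remainder else 0)
    return start, end


if __name__ == "__main__":
    params = 340_000_000  # roughly BERT-Large; an assumption for illustration
    for n in (1, 16):
        gib = zero_stage1_bytes_per_gpu(params, n) / 2**30
        print(f"{n:>2} workers: {gib:.2f} GiB of weights, gradients and optimizer state per GPU")
```

Each worker updates only the slice returned by shard_range and then shares the updated weights. The memory saved this way is what lets the batch size grow.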


Native Mixed Precision Training Support     

Unlike PyTorch, which depends on the NVIDIA Apex extension, ORT has implemented its own support for mixed precision training. Mixed precision training can be enabled with a config flag; no other code change is needed. Under the hood, ORT converts the static computation graph to mixed precision mode through a series of graph transformations. Most computations then run in fp16, while some numerically sensitive computations stay in fp32. ORT supports dynamic loss scaling by automatically inserting the loss scaling computation nodes into the graph.
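Dynamic loss scaling follows a well-known pattern. The loss is multiplied by a scale before the backward pass. When any gradient overflows, the update is skipped and the scale shrinks, and after a run of clean steps the scale grows. The class below is a generic sketch of that policy. Its name and default values (initial scale 2^16, factor 2, growth every 2000 steps) are assumptions borrowed from common practice. They are not the nodes or settings that ORT inserts into the graph.

```python
import math


class DynamicLossScaler:
    """Generic dynamic loss scaling policy (a sketch, not ORT's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, factor=2.0, growth_interval=2000, min_scale=1.0):
        self.scale = init_scale            # multiply the loss by this before backward
        self.factor = factor
        self.growth_interval = growth_interval
        self.min_scale = min_scale
        self._clean_steps = 0

    def unscale_and_update(self, scaled_grads):
        """Return unscaled gradients, or None if this step must be skipped."""
        if any(not math.isfinite(g) for g in scaled_grads):
            # Overflow: skip the weight update and back off the scale.
            self.scale = max(self.scale / self.factor, self.min_scale)
            self._clean_steps = 0
            return None
        # Unscale with the scale that was applied to this step's loss.
        unscaled = [g / self.scale for g in scaled_grads]
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # A long run without overflow: try a larger scale.
            self.scale *= self.factor
            self._clean_steps = 0
        return unscaled


if __name__ == "__main__":
    scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
    print(scaler.unscale_and_update([float("inf"), 1.0]), scaler.scale)
    print(scaler.unscale_and_update([8.0, 4.0]), scaler.scale)
    print(scaler.unscale_and_update([4.0]), scaler.scale)
```

The gradients are unscaled before the scale is grown. Otherwise the step that triggers growth would divide by the wrong factor.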


Highly Scalable Distributed Training

ORT aims to build a unified, highly scalable distributed training framework for hybrid parallelism, combining data and model parallelism. ORT supports data parallelism, the most popular distributed training mode and the one adopted by many internal teams. We are enhancing ORT to fully support training extremely large models (more than 100 billion parameters). ORT has an experimental implementation of Megatron-style horizontal parallelism, and we are actively developing support for pipeline parallelism, such as PipeDream.
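Data parallelism works because of a simple identity. With equally sized shards and a mean-reduced loss, averaging each worker's local gradient (the allreduce step) gives the same result as computing the gradient over the whole batch. The sketch below checks this for a one-parameter linear model. allreduce_mean is a hypothetical in-process stand-in for a real collective such as an NCCL allreduce.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(values):
    """In-process stand-in for an allreduce that averages a value across workers."""
    return sum(values) / len(values)


def data_parallel_grad(w, xs, ys, world_size):
    """Split the batch evenly across workers, compute local gradients, then average."""
    if len(xs) % world_size:
        raise ValueError("this sketch assumes equally sized shards")
    shard = len(xs) // world_size
    local_grads = [
        mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    return allreduce_mean(local_grads)


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
    full_batch = mse_grad(0.5, xs, ys)
    four_workers = data_parallel_grad(0.5, xs, ys, world_size=4)
    print(full_batch, four_workers, abs(full_batch - four_workers) < 1e-9)
```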


CUDA Kernel Optimizations

ORT has introduced highly optimized CUDA kernels for some key operations, including Reductions, Dropout, and Softmax. We have also introduced several key operator fusions, with fused kernels for LayerNormalization, Gelu, and their gradients, as well as for the LAMB optimizer.
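For reference, the computation that a fused LAMB kernel performs can be written out for a single weight tensor. The lamb_step function below follows the published LAMB update: Adam-style moments with bias correction, decoupled weight decay, and a layer-wise trust ratio ||w|| / ||update||. The hyperparameter defaults are assumptions. A fused CUDA kernel would perform these element-wise passes and norm reductions in far fewer kernel launches, instead of as the separate Python loops shown here.

```python
import math


def lamb_step(w, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.01):
    """One LAMB update for a single weight tensor stored as flat lists.

    Returns (new_w, new_m, new_v). `step` is 1-based, for bias correction.
    """
    # Exponential moving averages of the gradient and squared gradient.
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    v_hat = [vi / (1 - beta2 ** step) for vi in v]
    # Adam-style direction plus decoupled weight decay.
    update = [mh / (math.sqrt(vh) + eps) + weight_decay * wi
              for mh, vh, wi in zip(m_hat, v_hat, w)]
    # Layer-wise trust ratio; fall back to 1 when either norm is zero.
    w_norm = math.sqrt(sum(wi * wi for wi in w))
    u_norm = math.sqrt(sum(ui * ui for ui in update))
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    new_w = [wi - lr * trust_ratio * ui for wi, ui in zip(w, update)]
    return new_w, m, v


if __name__ == "__main__":
    w, m, v = [1.0, 2.0], [0.0, 0.0], [0.0, 0.0]
    w, m, v = lamb_step(w, [0.1, -0.2], m, v, step=1)
    print(w)
```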


Using ORT with PyTorch Training Code

ONNX Runtime can train existing PyTorch models through its optimized backend. For this, we have introduced a Python API for PyTorch, called ORTTrainer, which switches the training backend for PyTorch models (instances of torch.nn.Module) to ORT. This requires some changes from the user, such as replacing the PyTorch optimizer and, optionally, setting flags to enable additional features such as mixed precision training. Under the hood, as shown in Figure 2, ORTTrainer first converts the PyTorch model to ONNX format through the PyTorch-ONNX exporter. Next, the ORT backend takes over. It applies graph optimizations, builds a training graph, performs transformations on it as needed (e.g. the mixed precision transformation), and sets up the graph elements needed for distributed training. In this design, all the computation-intensive workload is offloaded onto the ORT backend, while users can still enjoy the rich PyTorch frontend utilities, such as data loading, checkpointing, and easy specification of loss functions.


Figure 2. Workflow for converting a PyTorch model into an ORT training graph

It is important to note that the current API is experimental and is expected to change significantly in the near future; a new version of the API is under active development. Our goal is to provide more seamless integration with PyTorch training that requires minimal changes to users' training code, to introduce new features, and to offer a more flexible API that covers advanced scenarios. Please refer to the training examples for more details.


Benchmarking Training Acceleration with ONNX Runtime

We now present a performance evaluation of BERT-L pre-training with ONNX Runtime on a 4-node DGX-2 cluster. In AzureML, we also reproduced pre-training convergence for BERT-Large using the sample from NVIDIA's DeepLearningExamples repo. We also validated fine-tuning accuracy on the SQuAD benchmark.


Benchmarking on DGX-2

We compared PyTorch and ORT BERT-L training performance on 4 NVIDIA DGX-2 machines (each with 16x 32GB V100 GPUs) interconnected with InfiniBand. The PyTorch result was obtained with the NGC 20.03-py3 Docker image, following NVIDIA's recipe. The ORT result was obtained with the same recipe, except that ORT used larger local batch sizes. As described above, ORT can run at twice PyTorch's batch size: ORT ran at local batch sizes of 128 and 16 for phases 1 and 2 respectively, whereas PyTorch ran at 64 and 8. The effective global batch size was the same in both cases. Overall, ORT improved throughput by 11.32% for phase 1 and 14.61% for phase 2. The total time to train was reduced by 11.16%, from 17.74 hours to 15.76 hours.
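Doubling the local batch while keeping the global batch fixed implies halving the number of gradient accumulation steps. The sketch below makes that arithmetic explicit and checks the reported time reduction. Two things in it are assumptions. The global batch of 65,536 sequences is a hypothetical value chosen for illustration. We also assume the global batch was held constant through gradient accumulation, since the text above only states that the effective global batch size was unchanged.

```python
def accumulation_steps(global_batch, local_batch, num_gpus):
    """Gradient-accumulation steps per optimizer step that keep the global batch fixed."""
    per_micro_step = local_batch * num_gpus
    if global_batch % per_micro_step:
        raise ValueError("global batch must be a multiple of local_batch * num_gpus")
    return global_batch // per_micro_step


def percent_reduction(before, after):
    """Relative reduction from `before` to `after`, in percent."""
    return 100.0 * (before - after) / before


if __name__ == "__main__":
    num_gpus = 4 * 16        # 4 DGX-2 nodes x 16 V100 GPUs each
    global_batch = 65_536    # hypothetical phase 1 global batch, for illustration only
    print("PyTorch (local batch 64):", accumulation_steps(global_batch, 64, num_gpus))
    print("ORT (local batch 128):", accumulation_steps(global_batch, 128, num_gpus))
    print(f"time reduction: {percent_reduction(17.74, 15.76):.2f}%")
```

Fewer, larger micro-batches per optimizer step make better use of each GPU. This is how the memory savings turn into higher throughput without changing the training recipe.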

Table 1. Time to train on 4 NVIDIA DGX-2 machines


Metric | PyTorch 1.5 with NGC 20.03-py3 | PyTorch 1.5 with ONNX Runtime | % Gain with ONNX Runtime
Phase 1 throughput (ex/sec) | | | 11.32%
Phase 2 throughput (ex/sec) | | | 14.61%
Phase 1 time (hours) | | |
Phase 2 time (hours) | | |
Total time (hours) | 17.74 | 15.76 | 11.16%


BERT-L Pre-training on AzureML

We performed BERT-L pre-training in AzureML on an 8-node ND40rs_v2 cluster (each node with 8x 32GB V100 GPUs) interconnected with InfiniBand. We used the same NVIDIA recipe, except that we doubled the local batch size as described above. Mixed precision mode and the LAMB optimizer were used throughout training. At the end of phase 2, we reached a training loss of 1.31. The end-to-end training time was 18.32 hours.

Table 2. Time to train on Azure ML with 8x ND40rs_v2


Metric | PyTorch 1.5 with ONNX Runtime
Phase 1 throughput (ex/sec) |
Phase 2 throughput (ex/sec) |
Phase 1 time (hours) |
Phase 2 time (hours) |
Total time (hours) | 18.32



Figure 3 shows the loss curve from a typical pre-training run. Phase 1 ends with a loss of around 1.4 after 7038 steps. Phase 2 starts with a jump in loss caused by the switch in sequence length, and the loss finally decreases to around 1.3.


Figure 3. ORT BERT-L pre-training loss curves

The pre-trained model was then fine-tuned on the SQuAD dataset. Both full precision and mixed precision fine-tuning produced satisfactory Exact Match and F1 scores.

Table 3. BERT-L fine-tuning result on SQuAD Dataset

Accuracy metric | Fine-tuning (FP32) | Fine-tuning (mixed precision)
Exact Match % | |
F1 score % | |


A Case Study with Visual Studio using GPT-2 Medium

Microsoft Visual Studio uses ONNX Runtime to accelerate pre-training of a 24-layer GPT-2 Medium model that powers code autocompletion in Visual Studio IntelliCode. IntelliCode serves as a universal programming language compiler, generating syntactically correct code in multiple programming languages and completing an entire line of code in a couple of keystrokes. The training dataset for this task comprises over 1.2 billion lines of source code in Python, C#, JavaScript, and TypeScript from 52,000 top-starred projects on GitHub. We treat the source code data as a sequence of tokens corresponding to the output of a lexical analyzer.

The training was performed on a DGX-2 cluster. Because we use a long sequence length of 1024, memory usage is very intensive, and PyTorch can only fit a batch size of 2 on a 32GB V100. With the same local batch size, ORT achieved 15.8% higher throughput. Because ORT is more memory efficient, it can run at a larger batch size of 3, which delivered an overall throughput improvement of 20.5%. As a result, the overall training time was reduced from 202 hours to 168 hours (1.2x higher throughput). The final evaluation metric also met the production shipping bar.
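A quick consistency check on these numbers follows. It assumes that time to train scales inversely with throughput, with the same number of examples processed in both runs. Under that assumption, a 20.5% throughput gain should bring 202 hours down to roughly 168 hours. The helper time_to_train is hypothetical.

```python
def time_to_train(baseline_hours, throughput_gain):
    """Projected time to train when throughput rises by `throughput_gain` (0.205 means 20.5%)."""
    return baseline_hours / (1.0 + throughput_gain)


if __name__ == "__main__":
    projected = time_to_train(202, 0.205)
    print(f"projected: {projected:.1f} h, reported: 168 h")
    print(f"implied speedup from reported times: {202 / 168:.2f}x")
```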

Table 4. GPT-2 Medium pre-training performance.


Configuration | Batch size / GPU | Throughput (ex/sec) | Time to train (hours)
PyTorch | 2 | | 202
PyTorch + ORT | 2 | |
PyTorch + ORT | 3 | | 168


Today, we announced the preview of training support in ONNX Runtime, with a focus on large-scale, computation-intensive transformer models. We have demonstrated that, on a 4-node DGX-2 cluster, ONNX Runtime achieves throughput gains of 11.32% and 14.61% over PyTorch for BERT-L phase 1 and phase 2 pre-training. The total training time was reduced by 11.16%, from 17.74 hours to 15.76 hours. ONNX Runtime can train BERT-L at twice the batch size of PyTorch. We have shown a similar 20.5% speedup on a GPT-2 model, saving 34 hours of total training time. ONNX Runtime Training is integrated with PyTorch, so existing PyTorch training code can be accelerated directly for transformer model training.


Get Started

As a part of the announcement on using ONNX Runtime for training, we have released a Docker image with ORT and made available a repo at https://github.com/microsoft/onnxruntime-training-examples that will host examples for ORT training. The first recipe available in this repo will help you get started with ORT for BERT pretraining in Azure Machine Learning service or NVIDIA DGX-2 and see the speedup in action. This recipe shows how to use ONNX Runtime training with BERT pretraining implementation in PyTorch. You can use this example either with the two datasets used in the original implementation or with your custom dataset to pretrain a BERT model and get the performance improvements with ORT reported in this blog. We are planning to add more examples for transformer models and other models. We also welcome your contribution to this repo and feedback to improve ORT training capabilities and experience.

Last updated: Oct 29 2020 10:00 AM