Accelerate PyTorch transformer model training with ONNX Runtime – a deep dive
Published Jul 13 2021 09:00 AM
Microsoft

Authors: Ravi shankar Kolli (@Ravi_Kolli), Aishwarya Bhandare (@ashbhandare), M. Zeeshan Siddiqui (@mzs-msft), Kshama Pawar (@kshama-msft), Sherlock Huang (@SherlockNoMad), and the ONNX Runtime Training Team

 

Why ONNX Runtime for PyTorch?

ONNX Runtime (ORT) for PyTorch accelerates the training of large-scale models across multiple GPUs, with up to a 37% increase in training throughput over PyTorch and up to an 86% speedup when combined with DeepSpeed. Today, transformer models are fundamental to Natural Language Processing (NLP) applications. These models, with billions of parameters, rely on multiple GPUs for distributed training. At this scale, pre-training and fine-tuning are costly and time-consuming. Training with ONNX Runtime for PyTorch, through its torch_ort.ORTModule API, speeds up training through efficient memory utilization, a highly optimized computational graph, and mixed precision execution, all with a quick, couple-line change to existing PyTorch training scripts. It also supports both NVIDIA and AMD GPUs and is extensible with custom operators, optimizers, and hardware accelerators. ONNX Runtime for PyTorch lets AI developers take full advantage of the PyTorch ecosystem, combining the flexibility of PyTorch with the performance of ONNX Runtime.

 

Flexibility in Integration

To use ONNX Runtime as the backend for training your PyTorch model, install the torch-ort package and make the following 2-line change to your training script. The ORTModule class is a simple wrapper for torch.nn.Module that optimizes the memory and computations required for training.

 

from torch_ort import ORTModule
model = ORTModule(model)
  1. from torch_ort import ORTModule: makes the ORTModule wrapper from the torch-ort package available to your training script
  2. model = ORTModule(model): wraps the torch.nn.Module in the PyTorch training script with ORTModule to allow acceleration using ONNX Runtime

The rest of the training loop is unmodified. ORTModule composes flexibly with torch.nn.Module, so you can wrap part or all of the model to run with ORT. For instance, you can wrap the encoder-decoder portion of the model while leaving the loss function in PyTorch. ORT then speeds up only the wrapped portion of the model.

 

ONNX Runtime for PyTorch has been integrated to run popular Hugging Face models with a centralized code change. Additional installation instructions can be found at pytorch/ort and samples can be found at ONNX Runtime Training Examples.

 

Performance Results

We have validated Hugging Face Transformer models with throughput gains (in samples/second) of up to 37% over baseline PyTorch and up to 86% when combined with DeepSpeed, across different models in pre-training and fine-tuning scenarios. All runs were measured on a single node of the Azure NDv2 SKU (except for the A100 results), with torch autocast as the mixed precision solution. Please refer to the configuration in the onnx-runtime-training-examples repo to reproduce the results.

 

Figure 1: ORT throughput for PyTorch Hugging Face models

 

ONNX Runtime for PyTorch also supports the A100 Tensor Core GPU, showing up to 1.31x throughput improvements on some Hugging Face models, as seen in Figure 2.

 

Figure 2: ORT throughput on A100 Tensor Core GPU
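
The results above are quoted both as percentage gains (for example 37%) and as multipliers (for example 1.31x). The sketch below shows how the two relate when throughput is measured in samples/second. The function names and the throughput numbers are hypothetical and chosen only to illustrate the arithmetic; they are not measured results.

```python
def speedup(baseline_sps: float, ort_sps: float) -> float:
    """Throughput ratio: 1.31 means 1.31x the baseline samples/second."""
    return ort_sps / baseline_sps


def gain_percent(baseline_sps: float, ort_sps: float) -> float:
    """Relative throughput gain over the baseline, in percent."""
    return (speedup(baseline_sps, ort_sps) - 1.0) * 100.0


# Hypothetical throughputs, not taken from the figures.
baseline, with_ort = 100.0, 131.0
print(f"{speedup(baseline, with_ort):.2f}x")       # 1.31x
print(f"{gain_percent(baseline, with_ort):.0f}%")  # 31%
```

Read this way, a 1.31x improvement is a 31% gain, and an 86% gain is a 1.86x improvement.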

 

The speedup is a result of various graph optimizations, use of fused and optimized GPU kernels, and efficient memory handling that ORT performs in the backend. We will discuss more details in the following sections.

 

10,000-foot view of ORTModule internals

 

Overview

ORTModule is a Python wrapper around torch.nn.Module (Figure 3) that intercepts forward and backward calls and delegates them to the ONNX Runtime backend to achieve better training performance. For ease of use, ORTModule serves as a drop-in replacement for any torch.nn.Module. Upon the initial forward call, the PyTorch module is exported to an ONNX graph using the torch-onnx exporter, and that graph is used to create a session. During session creation, ORT’s native auto-differentiation augments the forward graph with gradient nodes (the backward graph). Static graph transformations such as constant folding, redundant operation elimination, and operator fusion are then applied to optimize the computation graph. The ORTModule backend executes the graph with highly optimized kernels that make efficient use of GPU resources.

 

Figure 3: ORTModule execution flow
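
To make one of the static graph transformations named above concrete, here is a toy sketch of constant folding on a tiny expression graph. It is an illustration only, not ORT's implementation: the Const, Var, and Op classes and the fold_constants function are hypothetical names invented for this example.

```python
import operator
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Op:
    kind: str  # "add" or "mul"
    left: "Node"
    right: "Node"


Node = Union[Const, Var, Op]
KERNELS = {"add": operator.add, "mul": operator.mul}


def fold_constants(node: Node) -> Node:
    """Replace every subgraph whose inputs are all constants with one Const."""
    if not isinstance(node, Op):
        return node
    left, right = fold_constants(node.left), fold_constants(node.right)
    if isinstance(left, Const) and isinstance(right, Const):
        return Const(KERNELS[node.kind](left.value, right.value))
    return Op(node.kind, left, right)


# x * (2 + 3) is rewritten to x * 5 before the graph is ever executed.
graph = Op("mul", Var("x"), Op("add", Const(2.0), Const(3.0)))
print(fold_constants(graph))
```

Because the graph is static, the addition is computed once at optimization time instead of on every training step.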

 

Partial Graph Execution

PyTorch executes the model with separate forward and backward calls, whereas ORT represents the model as a single static computation graph. ORTModule implements a partial graph executor to mimic PyTorch’s forward and backward calls. Graph state between a pair of forward and backward calls, such as stashed intermediate tensors, is captured and shared through RunContext (Figure 3).
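
The idea can be sketched with a toy graph for y = w * x that already contains its gradient nodes. The forward call runs the first slice of the graph and the backward call runs the rest, with a shared context carrying the stashed values between them. This is a simplified illustration with plain floats, not ORT's executor: GRAPH, FORWARD_END, and the RunContext class below are hypothetical stand-ins.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

# One node of the static graph: (output name, function, input names).
GraphNode = Tuple[str, Callable[..., float], Tuple[str, ...]]

# A single graph for y = w * x, already augmented with gradient nodes.
GRAPH: List[GraphNode] = [
    ("y", lambda w, x: w * x, ("w", "x")),                      # forward
    ("grad_w", lambda grad_y, x: grad_y * x, ("grad_y", "x")),  # backward
    ("grad_x", lambda grad_y, w: grad_y * w, ("grad_y", "w")),  # backward
]
FORWARD_END = 1  # nodes [0, FORWARD_END) form the forward part


@dataclass
class RunContext:
    """Toy stand-in for the state shared between forward and backward."""
    values: Dict[str, float] = field(default_factory=dict)


def run_nodes(nodes: List[GraphNode], ctx: RunContext) -> None:
    for out, fn, inputs in nodes:
        ctx.values[out] = fn(*(ctx.values[name] for name in inputs))


def forward(ctx: RunContext, w: float, x: float) -> float:
    ctx.values.update(w=w, x=x)          # stash inputs for the backward pass
    run_nodes(GRAPH[:FORWARD_END], ctx)  # execute only the forward slice
    return ctx.values["y"]


def backward(ctx: RunContext, grad_y: float) -> Tuple[float, float]:
    ctx.values["grad_y"] = grad_y
    run_nodes(GRAPH[FORWARD_END:], ctx)  # resume where forward stopped
    return ctx.values["grad_w"], ctx.values["grad_x"]


ctx = RunContext()
y = forward(ctx, w=3.0, x=2.0)
grad_w, grad_x = backward(ctx, grad_y=1.0)
print(y, grad_w, grad_x)  # 6.0 2.0 3.0
```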

Tensor Exchange

Tensors such as module inputs, outputs, and gradients are exchanged between PyTorch and ORT using DLPack, which avoids any memory copy.

 

Unified Memory Allocator

ORTModule uses PyTorch’s allocator for GPU tensor memory management. This avoids having two allocators that can hide free memory from each other, which would lead to inefficient memory utilization and reduce the maximum batch size that can be reached.

 

Figure 4: Unified memory allocator
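
The following toy model shows how two caching allocators can hide free memory from each other. It is a deliberately simplified sketch that tracks byte counts rather than real memory blocks, and the Device and CachingAllocator classes are hypothetical; it does not reflect the internals of either PyTorch's or ORT's allocator.

```python
class Device:
    """Toy GPU with a fixed number of bytes."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.reserved = 0

    def reserve(self, size: int) -> None:
        if self.reserved + size > self.capacity:
            raise MemoryError("device out of memory")
        self.reserved += size


class CachingAllocator:
    """Toy caching allocator: freed bytes stay in a private cache."""

    def __init__(self, device: Device) -> None:
        self.device = device
        self.cached = 0

    def alloc(self, size: int) -> None:
        reuse = min(size, self.cached)
        self.device.reserve(size - reuse)  # only the shortfall hits the device
        self.cached -= reuse

    def free(self, size: int) -> None:
        self.cached += size  # kept for reuse, never returned to the device


def second_alloc_fits(shared: bool) -> bool:
    """Allocate and free 6 bytes, then try to allocate 6 more on a 10-byte device."""
    device = Device(capacity=10)
    framework = CachingAllocator(device)
    runtime = framework if shared else CachingAllocator(device)
    framework.alloc(6)
    framework.free(6)  # 6 bytes are free, but cached inside `framework`
    try:
        runtime.alloc(6)
        return True
    except MemoryError:
        return False


print(second_alloc_fits(shared=False))  # False
print(second_alloc_fits(shared=True))   # True
```

With separate allocators, the second request fails even though enough memory is logically free. With one shared allocator, the cached bytes are reused.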

 

Composability with Acceleration Libraries

 

Integration with DeepSpeed

ONNX Runtime for PyTorch integrates seamlessly with DeepSpeed to further accelerate distributed training for increasingly large models. The gains compose well: using DeepSpeed ZeRO Stage 1 with ORT yields gains ranging from 58% to 86% over PyTorch across a variety of Hugging Face models (Figure 5). Currently, ORTModule supports composing with DeepSpeed FP16 and with ZeRO Stage 1 and 2. Further improvements for ZeRO Stage 2 are in progress.

 

Figure 5: ORT throughput improvements with DeepSpeed ZeRO Stage 1
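
For orientation, a DeepSpeed run that enables FP16 and ZeRO Stage 1 is driven by a JSON configuration along the following lines. This is a minimal sketch, not the configuration behind the results above (that lives in the onnx-runtime-training-examples repo), and the batch size is a hypothetical placeholder.

```python
import json

# Illustrative values only; not the configuration used for the figures.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,  # hypothetical batch size
    "fp16": {"enabled": True},            # DeepSpeed FP16 mixed precision
    "zero_optimization": {"stage": 1},    # ZeRO Stage 1
}

config_json = json.dumps(ds_config, indent=2)
print(config_json)
```

The resulting JSON would be saved to a file and passed to DeepSpeed alongside a training script whose model is wrapped in ORTModule.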

 

Mixed precision support

ONNX Runtime supports mixed precision training with a variety of solutions, including PyTorch’s native AMP, NVIDIA’s Apex O1, and DeepSpeed FP16. This gives users the flexibility to bring ORT’s acceleration to their training workloads without changing their current setup.

 

Figure 1 in the "Performance Results" section above shows the speedup with ORT for PyTorch autocast. We see gains ranging from 11% to 37% by using ONNX Runtime for PyTorch.

 

Figure 6 shows speedup with ORT combined with DeepSpeed FP16 against baseline PyTorch. We see gains ranging from 22% to 86% on popular Hugging Face models.

 

Figure 6: ORT throughput improvements with DeepSpeed FP16

 

Figure 7 shows speedup for using ORT with NVIDIA’s Apex O1, giving 8% to 23% gains over PyTorch.

Figure 7: ORT throughput improvements with Apex O1 mixed precision

 

Looking Forward

The ONNX Runtime team is working on more optimizations to make training large workloads even faster. ONNX Runtime for PyTorch plans to add support for custom torch.autograd functions, which would allow graph execution to switch back to PyTorch for user-defined autograd functions. This would allow us to support advanced scenarios like horizontal parallelism and Mixture of Experts models. Further improvements and support for DeepSpeed ZeRO Stage 2 and 3 are also planned for future releases. The team will also add samples to the ONNX Runtime Training Examples as we support more types of models and scenarios. Be sure to check out the links below to learn more and get started with ONNX Runtime for PyTorch!

 

Getting Started

Last update: Jul 13 2021 09:03 AM