Accelerating AI applications using the JAX framework on Azure’s NDm A100 v4 Virtual Machines
Published Feb 07 2023 10:00 AM

Authors: Daramfon Akpan (Technical Program Manager - Microsoft), Leopold Cambier (Software Engineer - NVIDIA), and Jon Shelley (Principal TPM Manager - Microsoft)



  • Results on Azure’s NDm A100 v4-series virtual machines (VMs) are in line with those of the NVIDIA DGX A100 system.
  • JAX scales efficiently from 1 to 16 VMs.
  • The Azure NDm A100 v4-series is the right VM series for training large deep learning models.


JAX [1] is a new Python machine learning framework initially developed by Google. It combines a NumPy-style array API, automatic differentiation, distributed computation, and just-in-time compilation with operator fusion in a single framework for high-performance machine learning and other scientific workloads, all exposed through a high-level language, Python.
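To make that combination concrete, here is a minimal sketch of how the pieces compose in JAX: a loss written with jax.numpy, differentiated with jax.grad, and compiled with jax.jit. The linear model, the loss, and the toy data are illustrative assumptions, not part of the T5 benchmarks.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # NumPy-style array code: mean squared error of a linear model
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Automatic differentiation: gradient of the loss with respect to its first argument (w)
grad_loss = jax.grad(loss)

# Just-in-time compilation: XLA traces the gradient function once and compiles it
fast_grad = jax.jit(grad_loss)

# Toy data (hypothetical)
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad(w, x, y)
print(g)  # gradient at w = 0 is -x^T y = [-7, -10]
```

Because jax.grad and jax.jit are function transformations, they can be stacked; wrapping the gradient in jit lets XLA compile and fuse the whole backward computation instead of dispatching one operation at a time.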


Many libraries have been built on top of JAX. For instance, Flax [2] implements many neural network primitives on top of JAX and is used by many machine learning models. T5X [3] is one such library: it supports training, evaluating, and running inference with JAX models at many scales, with a focus on Transformer-based language models. T5X has been used successfully to train language models with hundreds of billions of parameters on very large datasets such as the Pile [5].


Reaching new horizons on Azure

We selected the Azure NDm A100 v4-series to run the training benchmarks for the T5 model [4] using the new NVIDIA JAX container. The NDm A100 v4-series is Azure’s flagship GPU offering for AI training and deep learning workloads. Each VM is powered by eight NVIDIA A100 80GB Tensor Core GPUs. These instances have the most GPU memory capacity and network bandwidth available on Azure and are connected with NVIDIA InfiniBand HDR to support both scaling up and scaling out.
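Scaling up within a VM and out across VMs relies on spreading work over many GPUs. As a rough illustration only, the sketch below shows plain data parallelism in JAX with jax.pmap: each local device processes its own shard of the batch, and jax.lax.pmean averages the gradients across devices, which XLA lowers to an all-reduce. This is not the benchmark configuration; T5X uses its own partitioning layer, and the model, the learning rate of 0.1, and the axis name "devices" are hypothetical.

```python
import functools
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Toy linear model with a mean squared error loss (hypothetical)
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    # Each device computes gradients on its own shard of the global batch
    g = jax.grad(loss)(w, x, y)
    # All-reduce: average gradients over every device on the "devices" axis
    g = jax.lax.pmean(g, axis_name="devices")
    # Plain SGD update; every device ends up with identical weights
    return w - 0.1 * g

n = jax.local_device_count()
# Leading axis = one slice per device: replicated weights, sharded data
w = jnp.zeros((n, 2))
x = jnp.ones((n, 4, 2))
y = jnp.ones((n, 4))

w = train_step(w, x, y)
print(w)
```

On a single NDm A100 v4 VM with all GPUs visible, jax.local_device_count() would typically return 8, and the leading axis would be sized accordingly. Multi-node runs additionally require each process to call jax.distributed.initialize() before any computation so that collectives span all VMs; how that maps onto the InfiniBand fabric depends on the cluster setup.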


T5 Large and XLarge model description

The Large T5 model has 770 million parameters, 24 encoder and 24 decoder layers, and 16 attention heads. The XL T5 model is much larger, with 3 billion parameters and 32 attention heads. The results below come from training both models on the wmt_t2t_ende_v003 dataset, a relatively small English-to-German machine translation dataset. Its small size keeps download and preprocessing fast while still being sufficient to extract relevant performance metrics. In practice, larger datasets such as the Pile [5] are used for end-to-end training.



Figure 1: Performance results for training the T5 Large model with the NVIDIA JAX container on Azure



Figure 2: Performance results for training the T5 XLarge model with the NVIDIA JAX container on Azure


The results show good scaling from 1 to 16 nodes for both the Large and XLarge T5 models running with JAX on Azure. The Large T5 model reaches a scaling efficiency of 84% at 16 nodes (128 GPUs), and the XL T5 model reaches 82% at 16 nodes (128 GPUs). Throughput is within 5% of the reported NVIDIA DGX A100 results. Customers can now use JAX on Azure to train large language models (LLMs) with solid scaling performance.
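The post does not spell out how scaling efficiency is computed. A common definition, assumed here, divides the measured throughput on N nodes by N times the single-node throughput, so 100% means perfectly linear scaling. The sketch below applies that definition to hypothetical throughput numbers chosen only for illustration; they are not the measured values behind Figures 1 and 2.

```python
def scaling_efficiency(throughput_1: float, throughput_n: float, n: int) -> float:
    """Fraction of ideal linear speedup achieved on n nodes (assumed definition)."""
    if n <= 0 or throughput_1 <= 0:
        raise ValueError("n and the single-node throughput must be positive")
    ideal = n * throughput_1  # throughput if scaling were perfectly linear
    return throughput_n / ideal

# Hypothetical throughputs in sequences per second, NOT the article's measurements
single_node = 1000.0
sixteen_nodes = 13440.0

eff = scaling_efficiency(single_node, sixteen_nodes, 16)
print(f"{eff:.0%}")  # 84%
```

Under this definition, an 84% efficiency at 16 nodes means the cluster delivers roughly 13.4 times the single-node throughput rather than the ideal 16 times.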


We invite you to learn more about how Azure can help you accelerate your JAX workloads using the links below.



[1] Bradbury, James, et al. "JAX: composable transformations of Python+NumPy programs." (2018).

[2] Heek, Jonathan, et al. "Flax: A Neural Network Library and Ecosystem for JAX." (2020).

[3] Roberts, Adam, et al. "Scaling Up Models and Data with T5X and seqio." arXiv preprint arXiv:2203.17189 (2022).

[4] Raffel, Colin, et al. "Exploring the limits of transfer learning with a unified text-to-text transformer." J. Mach. Learn. Res. 21.140 (2020): 1-67.

[5] Gao, Leo, et al. "The pile: An 800gb dataset of diverse text for language modeling." arXiv preprint arXiv:2101.00027 (2020).

Last update: Feb 07 2023 12:43 PM