JAX and XLA
Some basics about JAX and XLA
1 JAX and XLA
Both are from Google.
JAX: brings Autograd-style automatic differentiation together with XLA compilation.
XLA: Accelerated Linear Algebra, the primary compiler backend of JAX.
JAX is an OSS Python library for large-scale machine learning and high-performance scientific computing. It automatically differentiates NumPy and Python functions and supports composable function transformations, including automatic differentiation (grad), just-in-time compilation (jit), automatic vectorization (vmap), and parallelization (pmap).
The XLA compiler helps scale workloads across clusters of rack-scale GPU platforms with minimal code changes.
XLA supports key accelerators – GPU, TPU, Trainium
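The four transformations named above compose freely: a gradient function can itself be jitted, vmapped, or pmapped. The sketch below assumes a toy linear-regression loss; the function `loss`, the data, and all variable names are illustrative and not taken from any Google or NVIDIA example.

```python
import jax
import jax.numpy as jnp

# Toy mean-squared-error loss for a linear model pred = w * x + b (illustrative only).
def loss(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

# grad: differentiate loss with respect to its first argument, params.
grad_loss = jax.grad(loss)

# jit: trace once and compile the gradient function with XLA.
fast_grad = jax.jit(grad_loss)

params = (jnp.array(0.0), jnp.array(0.0))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([3.0, 5.0, 7.0])

dw, db = fast_grad(params, x, y)  # analytically dw = -68/3, db = -10

# vmap: per-example gradients, mapping over axis 0 of x and y while
# broadcasting params (in_axes=None) to every example.
per_example_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))
pw, pb = per_example_grad(params, x, y)  # each has shape (3,)

# pmap: run the same function on every local device; the leading axis of the
# input must equal the device count (1 on a plain CPU install).
n_dev = jax.local_device_count()
batch = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)
doubled = jax.pmap(lambda v: 2.0 * v)(batch)

print(dw, db)
print(pw, pb)
print(doubled)
```

Because the loss is a mean over examples, averaging the per-example gradient `pw` over the batch recovers the batched `dw`, which is a cheap sanity check. Newer JAX code often expresses multi-device parallelism through sharding with `jax.jit` rather than `pmap`; `pmap` is used here only because the text names it.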
2 NVIDIA Goals for JAX & XLA
Enable and accelerate large-scale ML workloads for pre-training and post-training on rack-scale, datacenter-grade hardware.
Enable workstation, automotive, and robotics hardware for scientific, simulation, and inference workloads.
Address requests from high-profile customers by contributing to OSS or shared NDA repos.
Support and expand the JAX/XLA OSS community through close collaboration with Google.
Software Assets/Deliverables
OSS repos: MaxText, MaxDiffusion, JAX, XLA
NGC Containers: MaxText 25.10, JAX 25.10
NVIDIA-owned OSS repos: JAX Toolbox, JaxPP, MLIR-TRT
JAX Toolbox: nightly containers, recipes, examples, docs
JaxPP: MPMD-based Pipeline Parallelism
MLIR-TRT: compiler/converter that exports StableHLO programs to TensorRT (TRT)
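MLIR-TRT starts from StableHLO, which is also what JAX lowers a jitted function to before XLA compiles it. The sketch below shows only the JAX side: obtaining the StableHLO module text via ahead-of-time lowering, then letting XLA compile and run it. The function `dense_relu` and its inputs are hypothetical, and the MLIR-TRT invocation itself is not shown.

```python
import jax
import jax.numpy as jnp

# A single dense layer with ReLU (illustrative only).
def dense_relu(w, x):
    return jax.nn.relu(x @ w)

w = jnp.ones((4, 2), dtype=jnp.float32)
x = jnp.ones((8, 4), dtype=jnp.float32)

# Stage 1: trace and lower to StableHLO without compiling.
lowered = jax.jit(dense_relu).lower(w, x)
stablehlo_text = lowered.as_text()  # MLIR text in the StableHLO dialect
print(stablehlo_text)

# Stage 2: let XLA compile the lowered module for the default backend and run it.
compiled = lowered.compile()
out = compiled(w, x)  # every entry is relu(1*1 + 1*1 + 1*1 + 1*1) = 4.0
print(out.shape)
```

Splitting lowering from compilation is what makes a separate consumer possible: the same StableHLO text could in principle be handed to another compiler instead of (or in addition to) XLA.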
3 MaxText and MaxDiffusion
