JAX and XLA


Some basics about JAX and XLA.

1 JAX and XLA

Both are from Google. JAX: autograd plus composable function transformations for NumPy-style code. XLA: Accelerated Linear Algebra, the compiler that serves as JAX's primary backend.

JAX is an OSS Python library for large-scale machine learning and high-performance scientific computing. It auto-differentiates NumPy and Python functions and supports composable transformations of computations, including automatic differentiation (grad), just-in-time compilation (jit), automatic vectorization (vmap), and parallelization (pmap).
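A minimal sketch of the first three transformations; the function and data are illustrative, not from any particular workload:

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style Python function.
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

w = jnp.ones((3,))
x = jnp.arange(6.0).reshape(2, 3)

# grad: automatic differentiation with respect to the first argument.
grad_loss = jax.grad(loss)
print(grad_loss(w, x))

# jit: just-in-time compile the gradient computation via XLA.
fast_grad = jax.jit(grad_loss)
print(fast_grad(w, x))

# vmap: vectorize over a batch of weight vectors without writing a loop.
batched = jax.vmap(loss, in_axes=(0, None))
print(batched(jnp.stack([w, 2 * w]), x))
```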

The XLA compiler helps scale workloads with minimal code changes across a cluster of rack-scale GPU platforms.

XLA supports key accelerators: GPU, TPU, and AWS Trainium.
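Because JAX delegates compilation and execution to XLA, the same program runs on whichever backend is present. A small sketch that reports the local XLA devices and runs a jitted matmul on them (the output depends on the machine):

```python
import jax
import jax.numpy as jnp

# Which XLA backend and devices did JAX pick up (cpu, gpu, or tpu)?
print(jax.default_backend())
print(jax.devices())

@jax.jit
def matmul(a, b):
    return a @ b

a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))

# The first call compiles the function with XLA for the local devices;
# subsequent calls reuse the compiled executable.
print(matmul(a, b).sum())
```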

2 NVIDIA Goals for JAX & XLA

Enable/accelerate large-scale ML workloads for pre-training & post-training workflows on rack-scale/datacenter-grade hardware.

Enable workstation/auto/robotic hardware for scientific/simulation/inference workloads.

Address high-profile customers’ requests by contributing to OSS or shared NDA repos.

Support/expand the JAX/XLA OSS community through tight collaboration with Google.

Software Assets/Deliverables

OSS repos: MaxText, MaxDiffusion, JAX, XLA

NGC Containers: MaxText 25.10, JAX 25.10

NVIDIA-owned OSS repos: JAX Toolbox, JaxPP, MLIR-TRT

JAX Toolbox: nightly containers, recipes, examples, docs

JaxPP: MPMD-based Pipeline Parallelism

MLIR-TRT: compiler/converter that exports StableHLO programs to TensorRT (see the sketch after this list)
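The StableHLO that such a converter consumes can be obtained directly from JAX by lowering a jitted function. A minimal sketch; only the JAX lowering calls below are standard, and the downstream MLIR-TRT step is not shown:

```python
import jax
import jax.numpy as jnp

# An illustrative model function.
def model(x):
    return jnp.tanh(x @ jnp.ones((4, 4)))

# Lower the jitted function for a concrete input shape/dtype.
lowered = jax.jit(model).lower(jnp.zeros((8, 4)))

# Textual StableHLO module: the kind of portable IR a
# StableHLO-to-TensorRT converter such as MLIR-TRT would take as input.
print(lowered.as_text())
```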

3 MaxText and MaxDiffusion

(figure: MaxText and MaxDiffusion)
