

Learning JAX: JAX and Flax NNX Tutorials

This is a collection of JAX and Flax NNX tutorials. Working through them shows you how to build high-performance ML models with the JAX ecosystem.

The material comes from Google’s Learning-JAX project and covers topics from fundamentals to advanced workflows.

✨ JAX Tutorial Series

About JAX

JAX is a Python library for high-performance numerical computing that is well suited to machine learning research. It provides:

  • Automatic differentiation via jax.grad
  • JIT compilation via jax.jit
  • Automatic vectorization via jax.vmap
  • Parallelism across multiple devices (GPUs and TPUs), which simplifies distributed training
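The sketch below shows the first three transformations on a toy linear model. The function names (`loss`, `per_example_loss`) and the data are hypothetical illustrations, not taken from the tutorials. Multi-device parallelism is not shown.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: mean squared error of a linear model pred = x @ w.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function computing d(loss)/d(w) (first argument).
grad_loss = jax.grad(loss)

# jax.jit compiles that function with XLA; the first call traces and compiles it.
fast_grad = jax.jit(grad_loss)

# A per-example loss written for a single input row x_i and target y_i.
def per_example_loss(w, x_i, y_i):
    return (jnp.dot(x_i, w) - y_i) ** 2

# jax.vmap maps per_example_loss over axis 0 of x and y, sharing w (None).
batched_loss = jax.vmap(per_example_loss, in_axes=(None, 0, 0))

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 1.0, 1.0])

# Analytically, the gradient here is [4/3, 2] and the per-example losses are [0, 1, 4].
print(fast_grad(w, x, y))
print(batched_loss(w, x, y))
```

The transformations return ordinary functions, so they compose freely. For example, `jax.jit(jax.grad(loss))` is what `fast_grad` builds in two steps.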

Key Components

  • JAX: core numerical library
  • Flax NNX: modern neural network library
  • Optax: optimizer library
  • Orbax: checkpointing library
  • Grain: data loading toolkit
  • Chex: testing and reliability tools

Resources