JAX (Python) is a successor to the classic Python/NumPy autograd library. It provides code optimisation, just-in-time (JIT) compilation, automatic differentiation and vectorisation.

So: a numerical library with certain high-performance machine-learning affordances. Note that it is not a deep learning framework per se, but rather the producer species at the lowest trophic level of a deep learning ecosystem. For information on frameworks built upon it, read on to later sections.

The official pitch:

JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

What’s new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python.

Dig a little deeper, and you’ll see that JAX is really an extensible system for composable function transformations. Both grad and jit are instances of such transformations. Another is vmap for automatic vectorization, with more to come.

This is a research project, not an official Google product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
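To make the pitch concrete, here is a minimal sketch of the three transformations it names, assuming a recent JAX release. The functions `f`, `g`, `s`, `loss` and `sq_score`, the synthetic data and the step size are made up for illustration. It shows `grad` through a branch and a loop and up to third order, forward mode (`jacfwd`) composed over reverse mode (`jacrev`) to get a Hessian, `jit` wrapped around a gradient-descent step, and `vmap` turning a per-example gradient into a batched one.

```python
import jax
import jax.numpy as jnp

# --- grad: reverse mode, through branches and loops, to any order ---
def f(x):
    # A Python `if` is fine under grad (though not under jit)
    if x > 0:
        return x ** 3
    return -x

def g(x):
    # A Python loop: squaring three times gives g(x) = x**8
    for _ in range(3):
        x = x * x
    return x

df = jax.grad(f)
d3f = jax.grad(jax.grad(jax.grad(f)))  # derivative of a derivative of a derivative
dg = jax.grad(g)

def s(x):
    return jnp.sum(x ** 3)

# Forward mode composed over reverse mode gives the Hessian
hess = jax.jacfwd(jax.jacrev(s))

# --- jit: compile a function (here, one that calls grad) with XLA ---
def loss(w, x, y):
    # Mean squared error of a linear model
    return jnp.mean((x @ w - y) ** 2)

x = jax.random.normal(jax.random.PRNGKey(0), (100, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = x @ w_true

@jax.jit
def step(w):
    # One gradient-descent step; jit and grad compose
    return w - 0.1 * jax.grad(loss)(w, x, y)

w = jnp.zeros(3)
for _ in range(200):
    w = step(w)

# --- vmap: vectorise a function written for a single example ---
def sq_score(w, xi):
    # Written for one example xi of shape (3,), with no batch axis
    return jnp.dot(w, xi) ** 2

# in_axes=(None, 0): share w, map over the rows of x
per_example_grads = jax.vmap(jax.grad(sq_score), in_axes=(None, 0))(w_true, x)

print(df(2.0), d3f(2.0), df(-1.0), dg(1.0))  # values: 12, 6, -1, 8
print(hess(jnp.array([1.0, 2.0])))            # diag(6 * x)
print(w)                                       # close to w_true
print(per_example_grads.shape)                 # (100, 3)
```

The first example hides a split worth knowing: a Python `if` on a traced value works under `grad`, but under `jit` it raises an error, and `jax.lax.cond` or `jnp.where` is the usual substitute.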

AFAICT the conda installation command is

conda install -c conda-forge jax jaxlib

“You don’t know JAX” is a popular intro.


It has idioms that are not obvious. For me, it was not clear how to use batch vectorisation and the functional-style application of functions over structures.
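A minimal sketch of both idioms, using a made-up linear model: the parameters live in a dict (a "pytree"), `jax.grad` returns gradients with the same structure, `jax.tree_util.tree_map` applies an update leaf by leaf, and `jax.vmap` with `in_axes=(None, 0)` shares the whole parameter dict while mapping over a batch. The names `params`, `predict` and `loss`, the data and the step size are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

# Parameters as a pytree: a dict of arrays (hypothetical linear model)
params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.5)}

def predict(params, xi):
    # Written for a single example xi of shape (2,)
    return jnp.dot(xi, params["w"]) + params["b"]

def loss(params, x, y):
    # vmap over examples; in_axes=None means the whole params dict is shared
    preds = jax.vmap(predict, in_axes=(None, 0))(params, x)
    return jnp.mean((preds - y) ** 2)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

# grad returns a dict with the same keys and shapes as params
grads = jax.grad(loss)(params, x, y)

# tree_map walks matching structures and applies the SGD update leaf by leaf
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
print(new_params["w"], new_params["b"])  # w -> [1.5, -0.3], b -> 0.7
```

The point of the design is that nothing here knows about "layers": any nested dict, list or tuple of arrays can be differentiated, mapped over and updated the same way.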

One thing I see often in examples is

from jax.config import config

Do I need to care about it? tl;dr: omnistaging is good and necessary, and it is switched on by default in recent JAX, so that line is simply being careful and is likely unneeded.
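For reference, another common reason examples touch the config object is to opt in to 64-bit floats. A minimal sketch, assuming a recent JAX release in which the config object is reachable as `jax.config` without the separate import:

```python
import jax
import jax.numpy as jnp

# Same config object, accessed via the top-level jax module.
# Best done at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64 once x64 is enabled (the default is float32)
```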

Deep learning frameworks


Over at DeepMind there is Haiku, which looks nifty. Its documentation is much more complete than Flax’s, so I will be auditioning it for my next project.

Related, at least organisationally, is rlax, the JAX reinforcement-learning library from the same company.


Flax is, I think, the de facto standard deep learning library for JAX. Documentation is sparse, but a recent design doc seems canonical. The documentation is not especially coherent (e.g. why do some modules assume batching and others not? There are no hints), but it can more or less be cargo-culted, and the quirks can be ignored.

See also the following WIP documentation notebooks

Those answered some questions, but I still have questions left over, thanks to various annoying rough edges and non-obvious gotchas. For example, if you omit a parameter needed for a given model, the error is FilteredStackTrace: AssertionError: Need PRNG for "params".

There are some good examples in the repository.

With those caveats about documentation, Flax is still not bad, because the underlying JAX debugging experience is transparent and easy. It is an OK option.

Probabilistic programming frameworks


NumPyro seems to be the dominant probabilistic programming system for JAX. It is a JAX port (or reimplementation, or something) of the PyTorch classic, Pyro.

More fringe but possibly interesting: jax-md does molecular dynamics, and ladax (“LADAX: Layers of distributions using FLAX/JAX”) does something with latent random variables.


The creators of Stheno seem to be Invenia, some of whose staff I am connected to in various indirect ways. Stheno targets JAX as one of several backends via a generic backend library, wesselb/lab (“A generic interface for linear algebra backends”).

Placeholder; details TBD.

Graph networks
