JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
What’s new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python.
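A minimal sketch of that composition: differentiate a plain NumPy-style function with grad, then compile the gradient function with jit (the function and values here are just illustrative).

```python
import jax
import jax.numpy as jnp

def loss(x):
    # A simple scalar function of a vector input.
    return jnp.sum(jnp.tanh(x) ** 2)

# Reverse-mode gradient of the loss, then JIT-compile the result.
# The transformations compose: jit(grad(f)) is itself a function.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.arange(3.0)
print(grad_loss(x))  # gradient of the loss evaluated at x
```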
Dig a little deeper, and you’ll see that JAX is really an extensible system for composable function transformations. Both grad and jit are instances of such transformations. Another is vmap, for automatic vectorization, with more to come.
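To illustrate vmap, here is a small sketch: a single-example function is vectorized over a batch axis without rewriting it (the function and shapes are my own illustration, not from the JAX docs).

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Single-example computation: dot product then nonlinearity.
    return jnp.tanh(jnp.dot(w, x))

# Vectorize over axis 0 of x while broadcasting w unchanged.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.ones(4)
xs = jnp.ones((8, 4))  # a batch of 8 examples
print(batched_predict(w, xs).shape)  # (8,)
```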
This is a research project, not an official Google product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
AFAICT the conda installation command is
conda install -c conda-forge jaxlib
*You don’t know JAX* is a popular intro.
JAX has various things built in whose use is not obvious, including batch vectorization and functional-style idioms.
One thing I see often is
from jax.config import config
config.enable_omnistaging()
Do I need to care about it? What omnistaging does is not obvious. tl;dr: omnistaging is good and necessary, and it is switched on by default in recent JAX, so that line is simply being careful.
Over at DeepMind there is Haiku, which looks nifty. Its documentation is much more complete than Flax’s. Related, at least organisationally, is rlax, the reinforcement-learning library from the same company.
There is a fully-featured deep learning library called Flax. Documentation is sparse, but the current design doc seems canonical. The documentation is not especially coherent (e.g. why do some modules assume batching and others not? No hints.), but it can more or less be cargo-culted.
See also the following WIP documentation notebooks
- flax/flax_basics.ipynb at 1e972542a92fa69f78d78dc6a07d6acaa7c5eb01 · google/flax
- Flax 2 ("Linen") - Colaboratory
They answered some questions, but I still had questions left over due to various annoying rough edges and non-obvious gotchas.
For example, if you omit the PRNG key needed to initialise a given model’s parameters, the error is
FilteredStackTrace: AssertionError: Need PRNG for "params".
There are some good examples in the repository.