Accelerated Python

Python has emerged as the de facto standard language for machine learning, and as a result many of the most cutting-edge techniques for GPU programming are exposed via Python libraries or extensions.

Numba

Numba is a just-in-time (JIT) compiler for Python. It uses LLVM to compile a numerically focused subset of Python and NumPy to machine code, which runs much faster than interpreted Python. It also lets you easily parallelize for loops via its prange construct (essentially in the style of OpenMP for C/C++), which is quite handy given that thread-based parallelism in Python itself is heavily handicapped by the global interpreter lock (a mutex that ensures only one thread runs interpreted Python code at a time).
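
As a rough illustration, here is a minimal sketch of a parallel loop using Numba's njit and prange (the function name and array sizes are arbitrary):

    # Sum each row of a matrix; the prange loop is split across CPU threads,
    # and the compiled code never re-enters the interpreter, so the GIL is
    # not a bottleneck.
    import numpy as np
    from numba import njit, prange

    @njit(parallel=True)
    def row_sums(x):
        out = np.empty(x.shape[0])
        for i in prange(x.shape[0]):   # parallelized loop
            out[i] = x[i, :].sum()
        return out

    a = np.random.rand(1000, 1000)
    print(row_sums(a)[:5])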

It also supports compilation of GPU kernels. The interface is relatively low-level compared to the other acceleration options on this page; it's not that abstracted from CUDA. In particular, you are responsible for index calculations based on the thread and block IDs.
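
For instance, a minimal vector-add kernel might look like the sketch below (the names and sizes are illustrative, and it assumes a CUDA-capable GPU):

    # Each thread computes its own global index from the block and thread IDs
    # (cuda.grid(1) is shorthand for blockIdx.x * blockDim.x + threadIdx.x)
    # and handles one element.
    import numpy as np
    from numba import cuda

    @cuda.jit
    def add_kernel(x, y, out):
        i = cuda.grid(1)          # global thread index
        if i < x.size:            # guard against out-of-range threads
            out[i] = x[i] + y[i]

    n = 1 << 20
    x = np.arange(n, dtype=np.float32)
    y = 2 * x
    out = np.zeros_like(x)
    threads_per_block = 256
    blocks = (n + threads_per_block - 1) // threads_per_block
    add_kernel[blocks, threads_per_block](x, y, out)
    print(out[:5])

Note that the block and grid sizes are chosen explicitly at launch time, just as in CUDA C.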

Taichi

Taichi is also a JIT compiler for Python, centered on the compilation of GPU kernels. It's geared toward numerical simulation and machine vision tasks, and includes sparse spatial data structures to assist with these applications. Its parallelism model is extremely simple: the outermost for loop in each kernel is parallelized automatically. This makes it very easy to get started, but hard to adapt for programs that aren't embarrassingly parallel. It also includes automatic differentiation.
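
A minimal sketch of what this looks like in practice (the field name, shape, and values are arbitrary):

    # The struct-for over the field is the outermost loop in the kernel,
    # so Taichi parallelizes it on the chosen backend.
    import taichi as ti

    ti.init(arch=ti.gpu)   # falls back to the CPU backend if no GPU is found

    n = 1024
    pixels = ti.field(dtype=ti.f32, shape=(n, n))

    @ti.kernel
    def fill():
        for i, j in pixels:            # parallelized over all (i, j)
            pixels[i, j] = (i + j) / (2.0 * n)

    fill()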

JAX

JAX is the library powering Google's newest generation of machine learning infrastructure. At heart it's a JIT compiler with automatic differentiation. The basic interface is essentially that of NumPy. Regular NumPy is already much faster than native Python for array computations, because the operations are implemented in C/C++ and compiled to native machine code ahead of time. JAX's version of the NumPy interface supports compilation for GPU, and it can also fuse multiple NumPy calls together into a single GPU kernel. This gives the compiler the chance to greatly improve efficiency via loop fusion and rematerialization, among other techniques.
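
To make this concrete, here is a minimal sketch of the workflow (the function and shapes are arbitrary): plain NumPy-style code, differentiated and compiled as a unit.

    # The jnp calls inside loss are traced and compiled together, so the
    # compiler can fuse them rather than running each as a separate kernel.
    import jax
    import jax.numpy as jnp

    def loss(w, x, y):
        pred = jnp.dot(x, w)
        return jnp.mean((pred - y) ** 2)

    grad_loss = jax.jit(jax.grad(loss))      # compiled gradient of loss

    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (128, 16))
    w = jnp.zeros(16)
    y = jnp.ones(128)
    print(grad_loss(w, x, y))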

It also features powerful tools for vectorization and parallelism. The vmap primitive lets you automatically vectorize computation over array axes, and pmap lets you parallelize operations across multiple GPUs. Several newer interfaces are being developed for even more advanced parallelism (notably xmap, shmap, and Array parallelism).
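
A small sketch of vmap (pmap has essentially the same interface, but maps the leading axis over devices rather than vectorizing on a single one; the names here are arbitrary):

    # predict is written for a single input vector; vmap turns it into a
    # batched function without any manual reshaping.
    import jax
    import jax.numpy as jnp

    def predict(w, x):
        return jnp.tanh(jnp.dot(w, x))

    batched_predict = jax.vmap(predict, in_axes=(None, 0))  # map over axis 0 of x

    w = jnp.ones(8)
    xs = jnp.ones((32, 8))                 # a batch of 32 inputs
    print(batched_predict(w, xs).shape)    # (32,)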

The main limitation of JAX is that all array shapes must be known at compile time. This is the case for all common neural network architectures, but isn't always true of other scientific code (e.g. neighbor lists in molecular dynamics simulations). JAX also uses a tracing compiler (rather than starting from a Python AST like Numba or Taichi), which means that Python control flow like if, for, or while has to be treated with care. This paradigm (where JAX only compiles portions of your code) can be used to great effect as a metaprogramming interface, but it takes some getting used to.
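
As a concrete example of the care required, here is a minimal sketch (the function names are mine): branching on a traced value with a plain Python if fails under jit, while jax.lax.cond expresses the branch inside the compiled program.

    import jax
    import jax.numpy as jnp
    from jax import lax

    @jax.jit
    def relu_plain(x):
        if x > 0:                          # error: x is an abstract tracer here
            return x
        return jnp.zeros_like(x)

    @jax.jit
    def relu_cond(x):
        # Both branches are traced and baked into the compiled program.
        return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

    print(relu_cond(jnp.float32(-2.0)))    # 0.0
    # relu_plain(jnp.float32(-2.0))        # raises at trace time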