9.4 Flexible Vectorization in Xarray Using JAX

Wednesday, 31 January 2024: 9:30 AM
324 (The Baltimore Convention Center)
Sam Levang, Salient Predictions, Cambridge, MA; and F. Bartolić

The Python package xarray has become a core tool for analysis and computation with multi-dimensional weather and climate data. Xarray wraps a wide range of methods from lower-level libraries such as NumPy and SciPy, adding labeled dimensions and metadata. For computations not implemented in the core API, xarray provides the generalized "apply_ufunc" wrapper for custom functions. "apply_ufunc" offers powerful flexibility while maintaining all the benefits of working with xarray objects. However, it can be prohibitively slow on large arrays when the underlying function is not itself vectorized and we are forced to use the "vectorize" option, which loops over the array at the Python level.
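As a minimal illustration of the slow path described above, the sketch below applies a function that only understands a single 1-D series across a small gridded array using "apply_ufunc" with vectorize=True. The function detrend_1d, the dimension names, and the random data are hypothetical choices made for this example, not taken from the presentation.

```python
import numpy as np
import xarray as xr


def detrend_1d(y):
    """Remove a least-squares linear trend from one 1-D series (illustrative function)."""
    t = np.arange(y.size)
    slope, intercept = np.polyfit(t, y, 1)
    return y - (slope * t + intercept)


# Small synthetic array; dimension names are assumptions for this example.
rng = np.random.default_rng(0)
da = xr.DataArray(rng.normal(size=(4, 5, 30)), dims=("lat", "lon", "time"))

# vectorize=True wraps detrend_1d with numpy.vectorize, so it is called once
# per (lat, lon) point in a Python-level loop.
detrended = xr.apply_ufunc(
    detrend_1d,
    da,
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
)
```

The result keeps its labeled dimensions, but the number of Python-level calls grows with the number of grid points, which is the cost that becomes prohibitive on large arrays.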

For these cases and other more advanced uses, we present a pattern using JAX as an alternative computational backend, which offers the following benefits:

1. flexible and performant vectorization using the "vmap" and "jit" transformations, which batch Python functions over array axes and compile them to fast XLA kernels
2. seamless dispatching to GPUs and TPUs
3. automatic differentiation, which opens up optimization and other ML applications
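One way the three points above might fit together is sketched below; this is an assumption-laden illustration, not the authors' implementation. A per-series function is written with jax.numpy, batched with jax.vmap, compiled with jax.jit, and handed to "apply_ufunc" without the vectorize option. The names detrend_1d, detrend_batched, and detrend_xr, the dimension names, and the data are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr


def detrend_1d(y):
    """Remove a least-squares linear trend from one 1-D series (illustrative function)."""
    t = jnp.arange(y.shape[0], dtype=y.dtype)
    t_c = t - t.mean()
    slope = jnp.sum(t_c * (y - y.mean())) / jnp.sum(t_c**2)
    return y - y.mean() - slope * t_c


@jax.jit
def detrend_batched(x):
    """Apply detrend_1d along the last axis of an array with any leading shape."""
    flat = x.reshape(-1, x.shape[-1])  # collapse all leading dims into one batch axis
    return jax.vmap(detrend_1d)(flat).reshape(x.shape)


def detrend_xr(da, dim="time"):
    """Hypothetical xarray wrapper: apply_ufunc moves `dim` to the last axis."""
    return xr.apply_ufunc(
        lambda x: np.asarray(detrend_batched(x)),  # convert the JAX result back to NumPy
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
    )


rng = np.random.default_rng(0)
da = xr.DataArray(rng.normal(size=(4, 5, 30)), dims=("lat", "lon", "time"))
result = detrend_xr(da)

# Automatic differentiation of a scalar loss built from the same function.
grad_loss = jax.grad(lambda y: jnp.sum(detrend_1d(y) ** 2))
g = grad_loss(jnp.asarray(da.values[0, 0]))
```

Note that JAX computes in 32-bit floats by default, so results agree with a NumPy reference only to single precision unless 64-bit mode is enabled. The same compiled function runs on a GPU or TPU when JAX is installed with the corresponding backend.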

With a few simple lines of code, we can use JAX to bridge the gap between the convenience of xarray objects and highly performant code for complex tasks. We show several examples of how we have applied this pattern to weather forecasting data, present associated performance benchmarks, and discuss possibilities for further development into a standalone open-source package.
