For these cases and other more advanced uses, we present a pattern using JAX as an alternative computational backend, which offers the following benefits:
1. flexible and performant vectorization using the "vmap" transformation, combined with "jit", which compiles Python operations into fast XLA kernels (see the sketch after this list)
2. seamless dispatching to GPUs and TPUs
3. automatic differentiation, which opens up optimization and other ML applications
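As a concrete illustration, the snippet below is a minimal sketch of our own (not code from the paper): it shows how "vmap" and "jit" combine to vectorize a simple per-column metric over a grid and compile it into a single XLA kernel. The function, array shapes, and names are hypothetical.

```python
# Minimal sketch (illustrative only): vectorize a per-column computation
# with jax.vmap and compile the result to an XLA kernel with jax.jit.
import jax
import jax.numpy as jnp

def column_rmse(forecast, observation):
    # Root-mean-square error along a single vertical column.
    return jnp.sqrt(jnp.mean((forecast - observation) ** 2))

# vmap maps the function over the leading (e.g. horizontal grid-point) axis;
# jit compiles the whole mapped computation into one fused XLA kernel.
batched_rmse = jax.jit(jax.vmap(column_rmse))

forecast = jnp.ones((1000, 50))       # (grid points, vertical levels)
observation = jnp.zeros((1000, 50))
print(batched_rmse(forecast, observation).shape)  # (1000,)
```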
With a few simple lines of code, we can use JAX to bridge the gap between the convenience of xarray objects and highly performant code for complex tasks, as sketched below. We show several examples of how we have applied this pattern to weather forecasting data, present associated performance benchmarks, and discuss possibilities for further development into a standalone open-source package.
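One way this bridging pattern can look in practice is sketched here; the variable names, the "anomaly" function, and the example data are hypothetical, and the snippet only illustrates moving data out of an xarray.DataArray, through a jit-compiled JAX function, and back into a labelled array.

```python
# Hypothetical sketch of the bridging pattern: pull the raw array out of an
# xarray.DataArray, run a jit-compiled JAX function on it, and wrap the
# result back into a labelled DataArray. Names and data are illustrative.
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr

temperature = xr.DataArray(
    np.random.rand(24, 90, 180).astype(np.float32),
    dims=("time", "lat", "lon"),
    name="temperature",
)

@jax.jit
def anomaly(values):
    # Deviation from the time mean, computed via a compiled XLA kernel.
    return values - jnp.mean(values, axis=0, keepdims=True)

result = xr.DataArray(
    np.asarray(anomaly(jnp.asarray(temperature.values))),
    coords=temperature.coords,
    dims=temperature.dims,
    name="temperature_anomaly",
)
```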