Installation
1. Pure JAX Version (Easy / Recommended)
The easiest way to get started with jaxDecomp is to install it from PyPI with the pure JAX backend; no MPI or GPU-specific setup is required.
➤ Step-by-step
Install the appropriate JAX wheel:
GPU:
pip install --upgrade "jax[cuda]"
CPU:
pip install --upgrade "jax[cpu]"
Install jaxDecomp:
pip install jaxdecomp
This setup uses the JAX backend by default and is ideal for experimentation, development, and most common research workflows.
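A quick way to confirm that the steps above worked is to check that both packages can be found by Python. The following is a minimal sketch using only the standard library; the helper name check_install is hypothetical and not part of jaxDecomp.

```python
import importlib.util


def check_install(packages=("jax", "jaxdecomp")):
    """Return a dict mapping each package name to True if it can be found."""
    status = {}
    for name in packages:
        # find_spec returns None when a top-level package is not installed.
        status[name] = importlib.util.find_spec(name) is not None
    return status


if __name__ == "__main__":
    for name, found in check_install().items():
        print(f"{name}: {'installed' if found else 'missing'}")
```

If jax is reported as installed, calling jax.devices() in the same environment shows whether the GPU or CPU wheel is active.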
2. cuDecomp Backend (Advanced / HPC)
If you’re working on an HPC cluster and need MPI-based communication for large-scale GPU or CPU FFTs, you can build from source with cuDecomp enabled.
➤ Install with cuDecomp
Make sure your environment provides a CUDA-aware MPI toolchain, such as the NVIDIA HPC SDK.
pip install -U pip
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
If CMake cannot find the NVHPC toolchain, set:
export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVHPC_ROOT/cmake
Then re-run the installation.
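If you script the installation, the same path fix can be prepared in Python. This is only a sketch: the helper name with_nvhpc_cmake_path is hypothetical, and it assumes NVHPC_ROOT is set by your NVHPC module and that paths are separated by ":" as on Linux. Unlike the plain export, it avoids a leading ":" when CMAKE_PREFIX_PATH is initially empty.

```python
import os


def with_nvhpc_cmake_path(env=None):
    """Return a copy of env with $NVHPC_ROOT/cmake appended to CMAKE_PREFIX_PATH."""
    env = dict(os.environ if env is None else env)
    nvhpc_root = env.get("NVHPC_ROOT")
    if not nvhpc_root:
        raise KeyError("NVHPC_ROOT is not set; load the NVHPC module first")
    cmake_dir = f"{nvhpc_root.rstrip('/')}/cmake"
    # Drop empty entries so an unset variable does not produce a leading ':'.
    existing = [p for p in env.get("CMAKE_PREFIX_PATH", "").split(":") if p]
    if cmake_dir not in existing:
        existing.append(cmake_dir)
    env["CMAKE_PREFIX_PATH"] = ":".join(existing)
    return env
```

The returned dictionary can be passed as the env argument of subprocess.run when invoking the pip command above.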
Troubleshooting
If JAX reports an incompatibility with cuSPARSE or another CUDA library, the easiest solution is to install JAX with the cuda-local option, which uses the CUDA libraries already installed on your system:
pip install --upgrade "jax[cuda-local]"
Then proceed with installing jaxDecomp with cuDecomp support.
ℹ️ You can read more about cuDecomp setup and tuning at the official cuDecomp GitHub repo.
Machine-Specific Installation Notes
IDRIS Jean Zay HPE SGI 8600 supercomputer
As of February 2026, loading modules in this exact order works:
module load nvidia-compilers/25.1 cuda/12.6.3 openmpi/4.1.6-cuda nccl/2.26.2-1-cuda cudnn cmake
# Install JAX
pip install --upgrade "jax[cuda-local]"
# Install jaxDecomp with cuDecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # sometimes needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
Note: If using only the pure-JAX backend, you do not need NVHPC.
Important for Jean Zay users: Make sure to load the correct architecture module before loading the nvidia-compilers module. For example, for A100 you need to run module load arch/a100 first. You also need to set CXXFLAGS with export CXXFLAGS="-tp=zen2 -noswitcherror" if you are using the H100 or A100 partition, or AMD CPUs in general. More info in the Jean Zay documentation.
Backend Selection at Runtime
Most functions in jaxDecomp support dynamic backend selection via a backend keyword argument. For example:
from jaxdecomp.fft import pfft3d
# x is a 3D JAX array defined elsewhere
# Use the default (pure JAX)
k_array = pfft3d(x)
# Use cuDecomp (if compiled and available)
k_array = pfft3d(x, backend="cudecomp")
This applies to:
jaxdecomp.fft.pfft3d
jaxdecomp.fft.pifft3d
jaxdecomp.halo_exchange
(and other jaxdecomp.fft.* and transposition routines)
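Because the backend is chosen per call, a script can try cuDecomp first and fall back to the default backend when it is unavailable. The wrapper below is a hedged sketch: call_with_fallback is a hypothetical helper, and catching a broad Exception is an assumption, since the error raised when cuDecomp is not compiled in is not specified here.

```python
import warnings


def call_with_fallback(fn, *args, preferred="cudecomp", **kwargs):
    """Call fn with backend=preferred; on failure, retry with fn's default backend."""
    try:
        return fn(*args, backend=preferred, **kwargs)
    except Exception as err:  # assumption: the exact exception type is not known
        warnings.warn(
            f"backend {preferred!r} failed ({err}); using the default backend"
        )
        return fn(*args, **kwargs)
```

For example, call_with_fallback(pfft3d, x) would request the cuDecomp backend and otherwise run the pure JAX path. Errors that only surface after the call returns are not caught by this wrapper.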
cuDecomp Transpose Communication Backends
If you’re using the cuDecomp backend, you can also manually choose the transpose communication strategy, which may significantly affect performance depending on your cluster hardware and MPI configuration.
Available options:
import jaxdecomp
from jaxdecomp import (
    TRANSPOSE_COMM_NCCL,
    TRANSPOSE_COMM_MPI_A2A,
    TRANSPOSE_COMM_MPI_P2P,
)
# Set transpose communication backend (default is NCCL).
# Choose one of the following calls:
jaxdecomp.config.update('transpose_comm_backend', TRANSPOSE_COMM_NCCL)
jaxdecomp.config.update('transpose_comm_backend', TRANSPOSE_COMM_MPI_P2P)
jaxdecomp.config.update('transpose_comm_backend', TRANSPOSE_COMM_MPI_A2A)
ℹ️ These options are described in more detail in the cuDecomp GitHub documentation.
Notes on Performance
Backend performance varies widely depending on your cluster setup (e.g., interconnect type, topology, NCCL version, MPI implementation). We recommend benchmarking both backends on your target workload to determine the best configuration.
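One simple way to run such a comparison is a small timing helper. This is a generic sketch, not part of jaxDecomp: the name benchmark and its sync parameter are hypothetical. The first call is treated as a warm-up so that compilation time is not counted.

```python
import time


def benchmark(fn, *args, repeats=5, sync=lambda out: out, **kwargs):
    """Return the best wall-clock time in seconds over `repeats` calls of fn."""
    sync(fn(*args, **kwargs))  # warm-up call, excluded from timing
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args, **kwargs))  # sync should block until the result is ready
        best = min(best, time.perf_counter() - start)
    return best
```

With JAX, pass sync=jax.block_until_ready, because JAX dispatches work asynchronously and the call would otherwise return before the computation finishes. Comparing benchmark(pfft3d, x, sync=jax.block_until_ready) with the same call plus backend="cudecomp" gives a first estimate on your own hardware.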