Installation


2. cuDecomp Backend (Advanced / HPC)

If you’re working on an HPC cluster and need MPI-based communication for large-scale GPU or CPU FFTs, you can build from source with cuDecomp enabled.

➤ Install with cuDecomp

Make sure your environment provides a CUDA-aware MPI toolchain, such as the NVIDIA HPC SDK.

pip install -U pip
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON

If CMake cannot find the NVHPC toolchain, set:

export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVHPC_ROOT/cmake

Then re-run the installation.
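Before re-running, a quick preflight check can narrow down what CMake is missing. The sketch below is not part of jaxDecomp. The list of required tools (cmake, mpicxx, nvcc) is an assumption about what a typical NVHPC plus CUDA-aware MPI environment exposes, and the function name preflight is hypothetical. It only inspects PATH, NVHPC_ROOT, and CMAKE_PREFIX_PATH.

```python
import os
import shutil

# Assumption: these are the tools the build needs; adjust for your site.
REQUIRED_TOOLS = ("cmake", "mpicxx", "nvcc")


def preflight(env=None, which=shutil.which, tools=REQUIRED_TOOLS):
    """Return a list of human-readable problems; an empty list means OK."""
    env = os.environ if env is None else env
    problems = []

    # Every required tool should be resolvable on PATH.
    for tool in tools:
        if which(tool) is None:
            problems.append(f"'{tool}' not found on PATH")

    # NVHPC_ROOT/cmake should appear in CMAKE_PREFIX_PATH.
    nvhpc_root = env.get("NVHPC_ROOT")
    if not nvhpc_root:
        problems.append("NVHPC_ROOT is not set")
    else:
        cmake_dir = f"{nvhpc_root}/cmake"
        prefixes = [p for p in env.get("CMAKE_PREFIX_PATH", "").split(":") if p]
        if cmake_dir not in prefixes:
            problems.append(f"CMAKE_PREFIX_PATH does not contain {cmake_dir}")
    return problems
```

Calling preflight() with no arguments checks the current shell environment; printing each returned string gives a short list of things to fix before the pip install step.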

Troubleshooting

If JAX complains about an incompatibility with cuSPARSE or any other CUDA library, the easiest solution is to install JAX against the locally installed CUDA libraries using the cuda-local option:

pip install --upgrade "jax[cuda-local]"

Then proceed with installing jaxDecomp with cuDecomp support.

ℹ️ You can read more about cuDecomp setup and tuning at the official cuDecomp GitHub repo.


Machine-Specific Installation Notes

IDRIS Jean Zay HPE SGI 8600 supercomputer

As of February 2026, loading modules in this exact order works:

module load nvidia-compilers/25.1 cuda/12.6.3 openmpi/4.1.6-cuda nccl/2.26.2-1-cuda cudnn cmake
# Install JAX
pip install --upgrade "jax[cuda-local]"

# Install jaxDecomp with cuDecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # sometimes needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON

Note: If using only the pure-JAX backend, you do not need NVHPC.

Important for Jean Zay users: Load the correct architecture module before loading the nvidia-compilers module. For example, for A100, run module load arch/a100 first. If you are using the H100 or A100 partition, or AMD CPUs in general, also set export CXXFLAGS="-tp=zen2 -noswitcherror". More information is available in the Jean Zay documentation.


Backend Selection at Runtime

Most functions in jaxDecomp support dynamic backend selection via a backend keyword argument. For example:

from jaxdecomp.fft import pfft3d

# Use the default (pure JAX)
k_array = pfft3d(x)

# Use cuDecomp (if compiled and available)
k_array = pfft3d(x, backend="cudecomp")

This applies to:

  • jaxdecomp.fft.pfft3d

  • jaxdecomp.fft.pifft3d

  • jaxdecomp.halo_exchange

  • (and other jaxdecomp.fft.* and transposition routines)
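Because the cuDecomp backend is only available when it was compiled in, scripts that run on several machines may want to fall back to the default backend. The following is a minimal sketch, not a jaxDecomp API: call_with_backend is a hypothetical helper, and catching a broad Exception is an assumption, since the exact error raised when cuDecomp is missing is not stated here.

```python
def call_with_backend(fn, *args, preferred="cudecomp", **kwargs):
    """Call fn with backend=preferred; on failure, retry with fn's default.

    Returns (result, backend_used); backend_used is None when the
    function's default backend was used instead.
    """
    try:
        return fn(*args, backend=preferred, **kwargs), preferred
    except Exception as err:  # assumption: exception type is not documented here
        print(f"backend {preferred!r} unavailable ({err}); using default")
        return fn(*args, **kwargs), None


# Usage (requires jaxDecomp):
#   from jaxdecomp.fft import pfft3d
#   k_array, used = call_with_backend(pfft3d, x)
```

Returning the backend that was actually used makes it easy to log which code path a given run took.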


cuDecomp Transpose Communication Backends

If you’re using the cuDecomp backend, you can also manually choose the transpose communication strategy, which may significantly affect performance depending on your cluster hardware and MPI configuration.

Available options:

import jaxdecomp
from jaxdecomp import (
    TRANSPOSE_COMM_NCCL,
    TRANSPOSE_COMM_MPI_A2A,
    TRANSPOSE_COMM_MPI_P2P,
)

# Set the transpose communication backend (default is NCCL).
# Pick one of the following:
jaxdecomp.config.update('transpose_comm_backend', TRANSPOSE_COMM_NCCL)
jaxdecomp.config.update('transpose_comm_backend', TRANSPOSE_COMM_MPI_P2P)
jaxdecomp.config.update('transpose_comm_backend', TRANSPOSE_COMM_MPI_A2A)

ℹ️ These options are described in more detail in the cuDecomp GitHub documentation.
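On a cluster it can be convenient to switch the transpose strategy from the job script rather than editing code. The sketch below resolves a constant name from an environment variable. The variable name JD_TRANSPOSE_COMM and the helper transpose_backend_name are hypothetical and not part of jaxDecomp; only the three constant names come from the import shown above.

```python
import os

# Short names mapped to the constant names imported from jaxdecomp above.
TRANSPOSE_BACKENDS = {
    "nccl": "TRANSPOSE_COMM_NCCL",
    "mpi_a2a": "TRANSPOSE_COMM_MPI_A2A",
    "mpi_p2p": "TRANSPOSE_COMM_MPI_P2P",
}


def transpose_backend_name(env=None, var="JD_TRANSPOSE_COMM", default="nccl"):
    """Resolve the constant name from a (hypothetical) environment variable."""
    env = os.environ if env is None else env
    key = env.get(var, default).strip().lower()
    if key not in TRANSPOSE_BACKENDS:
        valid = ", ".join(sorted(TRANSPOSE_BACKENDS))
        raise ValueError(f"{var}={key!r} is not one of: {valid}")
    return TRANSPOSE_BACKENDS[key]


# Usage (requires jaxDecomp built with cuDecomp):
#   import jaxdecomp
#   name = transpose_backend_name()
#   jaxdecomp.config.update("transpose_comm_backend", getattr(jaxdecomp, name))
```

Failing loudly on an unknown value avoids silently benchmarking the wrong strategy.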


Notes on Performance

Backend performance varies widely depending on your cluster setup (e.g., interconnect type, topology, NCCL version, MPI implementation). We recommend benchmarking both the pure-JAX and cuDecomp backends (and, for cuDecomp, each transpose communication option) on your target workload to determine the best configuration.
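A minimal timing harness for such a comparison might look like the sketch below. It is an illustration, not a jaxDecomp utility: benchmark is a hypothetical helper, and the warm-up and repeat counts are arbitrary defaults. Because JAX dispatches work asynchronously, the helper calls block_until_ready() on the result when that method exists, so the measured time includes the actual computation.

```python
import statistics
import time


def benchmark(fn, *args, warmup=2, repeats=10, **kwargs):
    """Return the median wall-clock time of fn(*args, **kwargs) in seconds."""

    def run():
        out = fn(*args, **kwargs)
        # Wait for asynchronous results (e.g. a jax.Array) before stopping the clock.
        wait = getattr(out, "block_until_ready", None)
        if callable(wait):
            wait()
        return out

    # Warm-up runs absorb one-time costs such as JIT compilation.
    for _ in range(warmup):
        run()

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Usage (requires jaxDecomp):
#   from jaxdecomp.fft import pfft3d
#   t_jax = benchmark(pfft3d, x)
#   t_cud = benchmark(pfft3d, x, backend="cudecomp")
```

The median is used rather than the mean so that an occasional slow iteration, for example from network contention, does not dominate the result.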