Transpositions in jaxDecomp
Transpositions are a core operation in jaxDecomp, enabling distributed 3D FFTs by realigning data across devices so that each axis can be processed locally. These are global transposes: they reshuffle slices of data between GPUs according to the domain decomposition layout.
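To make "reshuffle slices of data between GPUs" concrete, here is a minimal NumPy sketch of a global transpose. It runs in a single process and simulates each device as one entry of a Python list, which is an assumption for illustration only; jaxDecomp performs the real exchange across GPUs. The 2D array starts split by rows and ends split by columns through an all-to-all exchange. Every element survives; only its owner changes.

```python
import numpy as np

P = 4          # number of simulated devices (illustrative assumption)
N, M = 8, 12   # global shape; both axes divisible by P

def scatter_rows(a, p):
    """Split a along axis 0: device d holds the d-th row slab."""
    return np.split(a, p, axis=0)

def global_transpose(blocks):
    """Redistribute row slabs into column slabs with an all-to-all exchange.

    Device src cuts its row slab into p column chunks and "sends" chunk dst
    to device dst; device dst stacks the p pieces it receives in source order.
    """
    p = len(blocks)
    # send[src][dst] is the piece that device src sends to device dst
    send = [np.split(b, p, axis=1) for b in blocks]
    return [np.concatenate([send[src][dst] for src in range(p)], axis=0)
            for dst in range(p)]

a = np.arange(N * M, dtype=float).reshape(N, M)
rows = scatter_rows(a, P)        # each block has shape (N/P, M)
cols = global_transpose(rows)    # each block has shape (N, M/P)
```

After the exchange, device `d` holds the full extent of axis 0 for its own column slab, so any operation along axis 0 can now run locally.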
What is a Global Transpose?
In a distributed 3D FFT, the algorithm applies a series of 1D FFTs along different axes. Between consecutive FFTs, the array must be transposed so that the next axis becomes undistributed and locally contiguous.
For example, a 3D FFT proceeds in five steps:
1. FFT along Z
2. Transpose Z → Y
3. FFT along Y
4. Transpose Y → X
5. FFT along X
These transpositions change the mapping of the distributed axes while preserving the global data: elements are moved between devices, never modified.
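The five steps above can be simulated in a single process to check that they really compute a 3D FFT. The NumPy sketch below assumes a hypothetical 2 × 2 process grid and stores each device's block in a dict keyed by its grid coordinates; the helper names (`chunk`, `transpose_z_to_y`, `transpose_y_to_x`, and so on) are illustrative stand-ins, not jaxDecomp functions. The global axis order (X, Y, Z) is kept throughout, i.e. the sketch mirrors the non-contiguous style. In this layout each transpose is an all-to-all among the devices of one grid row or one grid column, not among all devices.

```python
import numpy as np

PX, PY = 2, 2      # simulated 2D process grid (illustrative assumption)
X, Y, Z = 4, 6, 8  # global shape; chosen so every split below is even

def chunk(n, p, i):
    """Slice for the i-th of p equal chunks of an axis of length n."""
    c = n // p
    return slice(i * c, (i + 1) * c)

def scatter_z_pencils(a):
    # Z is local; X is split over grid rows i, Y over grid columns j.
    return {(i, j): a[chunk(X, PX, i), chunk(Y, PY, j), :]
            for i in range(PX) for j in range(PY)}

def transpose_z_to_y(blocks):
    # All-to-all inside each grid row i: Y becomes local, Z is split over j.
    return {(i, j): np.concatenate(
                [blocks[(i, k)][:, :, chunk(Z, PY, j)] for k in range(PY)], axis=1)
            for i in range(PX) for j in range(PY)}

def transpose_y_to_x(blocks):
    # All-to-all inside each grid column j: X becomes local, Y is split over i.
    return {(i, j): np.concatenate(
                [blocks[(k, j)][:, chunk(Y, PX, i), :] for k in range(PX)], axis=0)
            for i in range(PX) for j in range(PY)}

def gather_x_pencils(blocks):
    # Reassemble the global array from X-pencils (for checking only).
    out = np.empty((X, Y, Z), dtype=complex)
    for (i, j), b in blocks.items():
        out[:, chunk(Y, PX, i), chunk(Z, PY, j)] = b
    return out

def pencil_fft3d(a):
    blocks = scatter_z_pencils(a)
    blocks = {k: np.fft.fft(b, axis=2) for k, b in blocks.items()}  # 1. FFT along Z
    blocks = transpose_z_to_y(blocks)                               # 2. Z -> Y
    blocks = {k: np.fft.fft(b, axis=1) for k, b in blocks.items()}  # 3. FFT along Y
    blocks = transpose_y_to_x(blocks)                               # 4. Y -> X
    blocks = {k: np.fft.fft(b, axis=0) for k, b in blocks.items()}  # 5. FFT along X
    return gather_x_pencils(blocks)

rng = np.random.default_rng(0)
a = rng.standard_normal((X, Y, Z))
result = pencil_fft3d(a)
```

Because the 3D DFT is separable, applying 1D FFTs along Z, Y and X in turn gives the same result as a full `np.fft.fftn`; the transposes only make sure each axis is complete on one device when its turn comes.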
Visual Illustration
The animation below shows how distributed pencils are rotated during a round-trip FFT. Each step reorients the domain decomposition for the next FFT axis.

Contiguous vs Non-Contiguous Transpositions
jaxDecomp supports two modes of transposition:
Contiguous: the layout is physically reshuffled (e.g., changing from ZXY to YZX).
Non-contiguous: the global axis order is preserved, but the device mapping changes.
In most cases, both perform similarly. Non-contiguous transposes are useful when the logical layout (e.g., for halo exchange or diagnostics) should remain unchanged.
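Contiguous mode is the default. To switch to non-contiguous transposes, set the `transpose_axis_contiguous` option to `False`:

```python
import jaxdecomp

jaxdecomp.config.update('transpose_axis_contiguous', False)
```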
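The difference between the two modes is about memory layout, not data. The single-process NumPy sketch below works on one local block and assumes axis labels are listed outermost first, so that in ZXY the last axis (Y) is the one contiguous in memory; that reading of the ZXY → YZX example is an illustrative assumption, not jaxDecomp's documented internal convention. Both variants produce the same numbers; the contiguous one pays for a local copy so that the next FFT runs over unit-stride data.

```python
import numpy as np

# A local block with logical axes (Z, X, Y). In C order the last axis, Y,
# is contiguous in memory. (Axis-ordering convention is an assumption.)
zxy = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)

# Non-contiguous style: keep the axis order and FFT along X (axis 1),
# which is strided in memory.
fft_x_noncontig = np.fft.fft(zxy, axis=1)

# Contiguous style: physically rotate (Z, X, Y) -> (Y, Z, X) so that X becomes
# the last, memory-contiguous axis, then FFT along that last axis.
yzx = np.ascontiguousarray(np.transpose(zxy, (2, 0, 1)))
fft_x_contig = np.fft.fft(yzx, axis=-1)
```

Rotating the non-contiguous result the same way reproduces the contiguous result exactly, which is why the choice between modes is mostly about which logical layout the rest of your code expects.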
API Example
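```python
import jaxdecomp

# Manually transpose a distributed array.
# x_pencil is assumed to be a JAX array already sharded across devices.
y_pencil = jaxdecomp.transposeXtoY(x_pencil)
z_pencil = jaxdecomp.transposeYtoZ(y_pencil)
```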
Note: These functions are already called internally by `pfft3d` and `pifft3d`. You only need to use them directly for custom workflows, such as I/O reordering, diagnostics, or algorithms requiring specific axis alignments.
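As an example of such a custom workflow, the NumPy sketch below applies a 1D FFT along Y only: it transposes Z-pencils to Y-pencils, transforms the now-local Y axis, and transposes back so the caller gets the original layout. It simulates a hypothetical 2 × 2 process grid in one process; `transpose_z_to_y`, `transpose_y_to_z` and `fft_along_y` are illustrative stand-ins, not jaxDecomp functions.

```python
import numpy as np

PX, PY = 2, 2      # simulated process grid (illustrative assumption)
X, Y, Z = 4, 6, 8  # global shape; Y and Z divisible by PY, X by PX

def chunk(n, p, i):
    """Slice for the i-th of p equal chunks of an axis of length n."""
    c = n // p
    return slice(i * c, (i + 1) * c)

def transpose_z_to_y(blocks):
    # Z-pencil (x_i, y_j, :) -> Y-pencil (x_i, :, z_j): all-to-all within grid row i.
    return {(i, j): np.concatenate(
                [blocks[(i, k)][:, :, chunk(Z, PY, j)] for k in range(PY)], axis=1)
            for i in range(PX) for j in range(PY)}

def transpose_y_to_z(blocks):
    # Inverse exchange: Y-pencil (x_i, :, z_j) -> Z-pencil (x_i, y_j, :).
    return {(i, j): np.concatenate(
                [blocks[(i, k)][:, chunk(Y, PY, j), :] for k in range(PY)], axis=2)
            for i in range(PX) for j in range(PY)}

def fft_along_y(a):
    z_blocks = {(i, j): a[chunk(X, PX, i), chunk(Y, PY, j), :]
                for i in range(PX) for j in range(PY)}
    y_blocks = transpose_z_to_y(z_blocks)                      # make Y local
    y_blocks = {k: np.fft.fft(b, axis=1) for k, b in y_blocks.items()}
    z_blocks = transpose_y_to_z(y_blocks)                      # restore the original layout
    out = np.empty(a.shape, dtype=complex)                     # gather for checking only
    for (i, j), b in z_blocks.items():
        out[chunk(X, PX, i), chunk(Y, PY, j), :] = b
    return out

rng = np.random.default_rng(1)
a = rng.standard_normal((X, Y, Z))
y_fft = fft_along_y(a)
```

Pairing each transpose with its inverse keeps the data layout unchanged for code outside the workflow, at the cost of two all-to-all exchanges.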
Summary
Transpositions are required to align each axis for local 1D FFTs in a distributed array.
jaxDecomp provides high-level primitives for axis-aligned transpositions.
Both contiguous and non-contiguous modes are supported and efficient.
The transpose API is fully differentiable and JAX-compatible.
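Differentiability comes naturally because a global transpose only moves elements: it is a linear permutation, so its adjoint (what reverse-mode differentiation propagates cotangents through) is the inverse transpose. The NumPy sketch below checks this with the standard dot-product test ⟨T x, y⟩ = ⟨x, T⁻¹ y⟩ on two simulated devices. It does not use JAX, the helper names are hypothetical, and it says nothing about how jaxDecomp registers its gradients internally.

```python
import numpy as np

X, Y = 4, 6
P = 2  # simulated devices (illustrative assumption)

def rows_to_cols(blocks):
    """Global transpose T: row-slab layout -> column-slab layout (all-to-all)."""
    p = len(blocks)
    send = [np.split(b, p, axis=1) for b in blocks]
    return [np.concatenate([send[s][d] for s in range(p)], axis=0) for d in range(p)]

def cols_to_rows(blocks):
    """Inverse transpose T^-1: column-slab layout -> row-slab layout."""
    p = len(blocks)
    send = [np.split(b, p, axis=0) for b in blocks]
    return [np.concatenate([send[s][d] for s in range(p)], axis=1) for d in range(p)]

def inner(u_blocks, v_blocks):
    """Inner product of two distributed arrays stored in the same layout."""
    return sum(float(np.vdot(u, v)) for u, v in zip(u_blocks, v_blocks))

rng = np.random.default_rng(2)
x = np.split(rng.standard_normal((X, Y)), P, axis=0)  # row layout
y = np.split(rng.standard_normal((X, Y)), P, axis=1)  # column layout
lhs = inner(rows_to_cols(x), y)   # <T x, y>
rhs = inner(x, cols_to_rows(y))   # <x, T^-1 y>
```

Since `lhs` equals `rhs` for arbitrary `x` and `y`, the backward pass of a transpose is simply the transpose in the opposite direction, which costs the same communication as the forward pass.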
🔄 See Distributed FFT for how these transposes are used in `pfft3d`.
🧱 See Domain Decomposition to understand how arrays are partitioned across GPUs.