Transpositions in jaxDecomp

Transpositions are a core operation in jaxDecomp, enabling distributed 3D FFTs by realigning data across devices so that each axis can be processed locally. These are global transposes: they reshuffle slices of data between GPUs according to the domain decomposition layout.


What is a Global Transpose?

In a distributed 3D FFT, the algorithm applies a series of 1D FFTs, one along each axis. Between consecutive FFTs, the array must be transposed so that the next axis becomes undistributed and locally contiguous.

For example:

Start → FFT along Z
Transpose Z → Y
FFT along Y
Transpose Y → X
FFT along X

These transpositions change the mapping of the distributed axes while preserving the global data shape.
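The sequence above can be mimicked on a single process with plain NumPy. This is only a sketch of the idea, not jaxDecomp code: the helper name `fft3d_by_pencils` and the starting axis order (X, Y, Z) are assumptions made for illustration, and the local `transpose` calls stand in for the global, cross-device transposes.

```python
import numpy as np

def fft3d_by_pencils(a):
    """3D FFT built from 1D FFTs along the last axis, with axis rotations between them.

    `a` is assumed to be laid out as (X, Y, Z), so Z starts as the local axis.
    """
    step = np.fft.fft(a, axis=-1)                          # FFT along Z
    step = np.ascontiguousarray(step.transpose(2, 0, 1))   # "Transpose Z -> Y": (X, Y, Z) -> (Z, X, Y)
    step = np.fft.fft(step, axis=-1)                       # FFT along Y
    step = np.ascontiguousarray(step.transpose(2, 0, 1))   # "Transpose Y -> X": (Z, X, Y) -> (Y, Z, X)
    step = np.fft.fft(step, axis=-1)                       # FFT along X
    # Rotate once more so the result is back in (X, Y, Z) order for comparison.
    return step.transpose(2, 0, 1)

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 6, 8)) + 1j * rng.standard_normal((4, 6, 8))
matches = np.allclose(fft3d_by_pencils(a), np.fft.fftn(a))
```

Each FFT only ever touches the last axis, which is why the axis to be transformed must first be made local and contiguous.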

Visual Illustration

The animation below shows how distributed pencils are rotated during a round-trip FFT. Each step reorients the domain decomposition for the next FFT axis.

[Animation: distributed transpositions in jaxDecomp]

Contiguous vs Non-Contiguous Transpositions

jaxDecomp supports two modes of transposition:

  • Contiguous: The layout is physically reshuffled (e.g., changing from ZXY to YZX).

  • Non-contiguous: The global axis order is preserved, but the device mapping changes.

In most cases, both perform similarly. Non-contiguous transposes are useful when the logical layout (e.g., for halo exchange or diagnostics) should remain unchanged.

The mode is controlled by the transpose_axis_contiguous configuration option. To select non-contiguous transposes, set it to False:

jaxdecomp.config.update('transpose_axis_contiguous', False)
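To make the difference between the two modes concrete, here is a single-process NumPy simulation of one transpose among the devices of a single mesh row. It is a sketch under stated assumptions, not jaxDecomp's implementation: `exchange_x_to_y` is a hypothetical helper, the devices are represented by a Python list of local blocks, and the global axis order is assumed to be (Z, Y, X).

```python
import numpy as np

def exchange_x_to_y(row_blocks, contiguous):
    """Simulate an X -> Y transpose among the devices of one mesh row (hypothetical helper).

    row_blocks[j] is device j's local block of shape (nz_local, ny / p, nx):
    it holds the full X axis and a 1/p slice of Y.
    """
    p = len(row_blocks)
    # Each device cuts its block into p chunks along X, one per destination device.
    chunks = [np.split(block, p, axis=2) for block in row_blocks]
    out = []
    for dest in range(p):
        # Device `dest` gathers chunk `dest` from every source and stacks them along Y.
        local = np.concatenate([chunks[src][dest] for src in range(p)], axis=1)
        if contiguous:
            # Physically rotate the local axes: (Z, Y, X) -> (X, Z, Y).
            local = np.ascontiguousarray(local.transpose(2, 0, 1))
        out.append(local)
    return out

# One mesh row that owns z in [0, 2) of a global (Z, Y, X) = (4, 6, 8) array.
global_array = np.arange(4 * 6 * 8).reshape(4, 6, 8)
row = global_array[:2]
x_pencils = np.split(row, 2, axis=1)   # two devices, each holding a (2, 3, 8) block

non_contig = exchange_x_to_y(x_pencils, contiguous=False)
contig = exchange_x_to_y(x_pencils, contiguous=True)
print(non_contig[0].shape)  # (2, 6, 4): axis order kept, Y now local, X split
print(contig[0].shape)      # (4, 2, 6): same data, Y moved to the last axis
```

Both results hold the same values on each device. The only difference is whether the local axes are physically reordered after the exchange.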

API Example

# Manually transpose a distributed array
y_pencil = jaxdecomp.transposeXtoY(x_pencil)
z_pencil = jaxdecomp.transposeYtoZ(y_pencil)

Note: pfft3d and pifft3d already call these functions internally. You only need to call them directly for custom workflows, such as I/O reordering, diagnostics, or algorithms that require a specific axis alignment.


Summary

  • Transpositions are required to align each axis for local 1D FFTs in a distributed array.

  • jaxDecomp provides high-level primitives for axis-aligned transpositions.

  • Both contiguous and non-contiguous modes are supported and efficient.

  • The transpose API is fully differentiable and JAX-compatible.

🔄 See Distributed FFT for how these transposes are used in pfft3d.

🧱 See Domain Decomposition to understand how arrays are partitioned across GPUs.