XLA Sharding Configuration Guide
This guide covers XLA sharding configuration when working with jaxDecomp, including Shardy partitioner settings, sharding spec/mesh compatibility requirements, and explicit vs auto axis types.
Auto vs Explicit Axis Types
jaxDecomp Compatibility Warning
Warning
Important (February 2026): jaxDecomp’s custom_partitioning primitives are not directly compatible with AxisType.Explicit mesh axes. This is due to a limitation in JAX’s custom_partitioning mechanism where callbacks receive meshes with axis types converted to Auto, but XLA’s SPMD partitioner still makes decisions based on the original Explicit types.
Option 1: Use Auto Axis Types (Recommended)
The simplest approach is to use AxisType.Auto when working with jaxDecomp:
import jax
from jax.sharding import AxisType

pdims = (2, 4)  # requires 2 * 4 = 8 devices
# Use Auto axis types
mesh_auto = jax.make_mesh(pdims, ('x', 'y'),
                          axis_types=(AxisType.Auto, AxisType.Auto))
Option 2: Use auto_axes Wrapper for Explicit Meshes
If you must use AxisType.Explicit (e.g., for compatibility with other parts of your codebase), you can use JAX’s auto_axes decorator to wrap jaxDecomp functions:
import jax
from jax.sharding import AxisType, auto_axes, reshard
from jax.sharding import PartitionSpec as P
import jaxdecomp as jd

pdims = (2, 4)  # requires 2 * 4 = 8 devices

# Create an Explicit mesh and make it the current mesh
mesh_explicit = jax.make_mesh(pdims, ('x', 'y'),
                              axis_types=(AxisType.Explicit, AxisType.Explicit))
jax.set_mesh(mesh_explicit)

# Set up an array with explicit sharding
arr = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
arr = reshard(arr, P('x', 'y'))

# Get the expected output sharding for the FFT
out_sharding = jd.get_fft_output_sharding(arr.sharding)

# Wrap the jaxDecomp function with auto_axes
@auto_axes
def pfft3d_explicit_safe(x, out_sharding=out_sharding):
    return jd.fft.pfft3d(x)

# Now it works with the Explicit mesh
result = pfft3d_explicit_safe(arr, out_sharding=out_sharding)
Note
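The auto_axes decorator temporarily switches the mesh axes to AxisType.Auto for the duration of the wrapped function. It then reshards the result to the out_sharding passed at call time. Because custom_partitioning sees Auto axes inside the wrapper, the jaxDecomp primitive is partitioned correctly.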
cuDecomp Backend: Transposed Mesh Required
When using the cuDecomp backend, you must build the mesh from the transposed device array and give its axis names in ('y', 'x') order:
from jax.experimental import mesh_utils
from jax.sharding import Mesh, AxisType
pdims = (2, 4)
devices = mesh_utils.create_device_mesh(pdims)
# The cuDecomp backend requires a transposed mesh with ('y', 'x') axis names
mesh = Mesh(devices.T, ('y', 'x'), axis_types=(AxisType.Auto, AxisType.Auto))
# Equivalent, since axis_types defaults to Auto with the Mesh constructor:
# mesh = Mesh(devices.T, ('y', 'x'))