XLA Sharding Configuration Guide

This guide covers XLA sharding configuration when working with jaxDecomp, including Shardy partitioner settings, sharding spec/mesh compatibility requirements, and explicit vs auto axis types.

Shardy Partitioner Configuration

Activating/Deactivating Shardy

The Shardy partitioner is the default in JAX 0.7.0+. You can control it via JAX configuration:

import jax

# Activate Shardy partitioner (default in JAX 0.7.0+)
jax.config.update('jax_use_shardy_partitioner', True)

# Deactivate Shardy partitioner (use legacy GSPMD)
jax.config.update('jax_use_shardy_partitioner', False)

Sharding Spec Must Match Mesh

Warning: PartitionSpec Must Directly Use Mesh Axis Names

When creating a sharding for use with jaxDecomp, each PartitionSpec entry must name the mesh axis at the same position. Using None in place of a mesh axis whose size is greater than 1 is NOT valid.

Example - Correct vs Incorrect:

# Simulate 8 CPU devices. Set these variables before importing jax.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

pdims = (2, 4)
pdim_x, pdim_y = pdims

mesh_2d = jax.make_mesh(pdims, ('x', 'y'))

# OK - spec directly uses mesh axis names
sharding = NamedSharding(mesh_2d, P('x', 'y'))

# NOT OK - 'x' has size 2, cannot use None
sharding = NamedSharding(mesh_2d, P(None, 'y'))

Workaround: Create a Specific Mesh for Partial Sharding

If you need to use None in your spec, create a mesh where that axis has size 1:

# Create a mesh with size 1 on the first axis
mesh_1 = jax.make_mesh((1, pdim_y), ('UNUSED', 'y'), devices=jax.devices()[::pdim_x])

# OK now - UNUSED axis has size 1
sharding = NamedSharding(mesh_1, P(None, 'y'))

Validating Spec/Mesh Compatibility

Use jaxdecomp.validate_spec_matches_mesh to check compatibility:

from jaxdecomp import validate_spec_matches_mesh
from jax.sharding import PartitionSpec as P

def check(spec, mesh, name):
    try:
        validate_spec_matches_mesh(spec, mesh)
        print(f"Spec {spec} is VALID for mesh {name}")
    except Exception as e:
        print(f"Spec {spec} is INVALID for mesh {name}: {e}")

check(P('x', 'y'), mesh_2d, "mesh_2d")   # VALID
check(P(None, 'y'), mesh_2d, "mesh_2d")  # INVALID
check(P(None, 'y'), mesh_1, "mesh_1")    # VALID
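
The rule being checked can be illustrated without JAX. The following is a minimal sketch of one reading of that rule, not jaxDecomp's actual implementation: the helper name spec_matches_mesh is hypothetical, the spec is a plain tuple standing in for a PartitionSpec, and the mesh is a plain dict of axis sizes standing in for mesh.shape.

```python
from typing import Mapping, Optional, Sequence


def spec_matches_mesh(spec: Sequence[Optional[str]],
                      mesh_shape: Mapping[str, int]) -> bool:
    """Hypothetical check: each mesh axis must be named at the same position
    in the spec, or be left as None only when that mesh axis has size 1."""
    axes = list(mesh_shape.items())
    if len(spec) > len(axes):
        return False
    for i, (name, size) in enumerate(axes):
        # Trailing spec entries that are omitted are treated as None.
        entry = spec[i] if i < len(spec) else None
        if entry == name:
            continue
        if entry is None and size == 1:
            continue
        return False
    return True


# Plain-dict stand-ins for mesh_2d.shape and mesh_1.shape (assumed layout).
mesh_2d_shape = {'x': 2, 'y': 4}
mesh_1_shape = {'UNUSED': 1, 'y': 4}

print(spec_matches_mesh(('x', 'y'), mesh_2d_shape))   # True
print(spec_matches_mesh((None, 'y'), mesh_2d_shape))  # False
print(spec_matches_mesh((None, 'y'), mesh_1_shape))   # True
```

The three calls mirror the three check(...) cases above. For real sharding objects, prefer validate_spec_matches_mesh, since it may enforce conditions this sketch does not.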

Auto vs Explicit Axis Types

jaxDecomp Compatibility Warning

Warning

Important (February 2026): jaxDecomp’s custom_partitioning primitives are not directly compatible with AxisType.Explicit mesh axes. The cause is a limitation in JAX’s custom_partitioning mechanism: callbacks receive meshes whose axis types have been converted to Auto, but XLA’s SPMD partitioner still makes decisions based on the original Explicit types.

Option 2: Use auto_axes Wrapper for Explicit Meshes

If you must use AxisType.Explicit (e.g., for compatibility with other parts of your codebase), you can use JAX’s auto_axes decorator to wrap jaxDecomp functions:

import jax
from jax.sharding import AxisType, auto_axes, reshard
from jax.sharding import PartitionSpec as P
import jaxdecomp as jd

pdims = (2, 4)

# Create an Explicit mesh
mesh_explicit = jax.make_mesh(pdims, ('x', 'y'),
                               axis_types=(AxisType.Explicit, AxisType.Explicit))

# Set up array with explicit sharding
arr = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
jax.set_mesh(mesh_explicit)
arr = reshard(arr, P('x', 'y'))

# Get the expected output sharding for the FFT
out_sharding = jd.get_fft_output_sharding(arr.sharding)

# Wrap the jaxDecomp function with auto_axes. The out_sharding keyword passed
# at call time tells auto_axes how the result should be sharded.
@auto_axes
def pfft3d_explicit_safe(x, out_sharding=out_sharding):
    return jd.fft.pfft3d(x)

# Now it works with the Explicit mesh
result = pfft3d_explicit_safe(arr, out_sharding=out_sharding)

Note

The auto_axes decorator temporarily converts the mesh to AxisType.Auto for the duration of the wrapped function, allowing custom_partitioning to work correctly.

cuDecomp Backend: Transposed Mesh Required

When using the cuDecomp backend, you must create a transposed mesh: transpose the device array and name the axes ('y', 'x'):

from jax.experimental import mesh_utils
from jax.sharding import Mesh, AxisType

pdims = (2, 4)
devices = mesh_utils.create_device_mesh(pdims)

# cuDecomp backend requires transposed mesh with ('y', 'x') axis names
mesh = Mesh(devices.T, ('y', 'x'), axis_types=(AxisType.Auto, AxisType.Auto))

# Note: axis_types defaults to Auto when using the Mesh constructor
mesh = Mesh(devices.T, ('y', 'x'))  # Auto by default
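
To see what the transpose does to the axis sizes, the device grid can be mimicked with NumPy. This is only an illustrative sketch: the integer ids stand in for real devices, and the actual ordering returned by mesh_utils.create_device_mesh may differ depending on the hardware topology.

```python
import numpy as np

pdims = (2, 4)  # (x, y)

# Stand-in for mesh_utils.create_device_mesh(pdims): integer ids, not devices.
grid = np.arange(pdims[0] * pdims[1]).reshape(pdims)

# Transposed grid, to be labelled with axis names ('y', 'x').
grid_t = grid.T

# Axis sizes as the transposed mesh would report them.
axis_sizes = dict(zip(('y', 'x'), grid_t.shape))
print(grid_t.shape)  # (4, 2)
print(axis_sizes)    # {'y': 4, 'x': 2}
```

After the transpose, 'y' still has size 4 and 'x' still has size 2. Only the order of the mesh axes changes, so the entry at grid position (i, j) moves to position (j, i).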