Understanding Domain Decomposition in jaxDecomp
jaxDecomp supports both slab and pencil domain decompositions through the pdims argument. This determines how your 3D array is partitioned across devices.
What is pdims?
The pdims parameter defines the decomposition of your domain:
pdims=(1, N) or (N, 1) → Slab decomposition
pdims=(M, N) where both M, N > 1 → Pencil decomposition
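The rule above can be written as a small check. This is only an illustrative sketch: decomposition_kind is a hypothetical helper, not part of jaxDecomp, and it applies the rule literally, so (1, 1) is reported as a slab.

```python
def decomposition_kind(pdims):
    """Classify a (Px, Py) tuple using the rule stated above.

    Hypothetical helper for illustration; not a jaxDecomp function.
    """
    px, py = pdims
    if px < 1 or py < 1:
        raise ValueError(f"pdims entries must be positive, got {pdims}")
    # A slab splits the domain along one mesh axis only.
    if px == 1 or py == 1:
        return "slab"
    # Both mesh axes have more than one device: pencil.
    return "pencil"
```

For example, decomposition_kind((1, 8)) and decomposition_kind((8, 1)) both give "slab", while decomposition_kind((2, 4)) gives "pencil".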
Slab vs Pencil: Tradeoffs
| Feature | Slab Decomposition | Pencil Decomposition |
|---|---|---|
| Faster per FFT call | ✅ Often faster | ❌ Slightly slower |
| Accuracy | ⚠️ Can be slightly lower | ✅ Slightly better |
Slab Decomposition
Slab decomposition is typically faster per FFT and easier to configure. It works well for:
Prototyping
Small-to-medium simulations
GPU-limited environments
pdims = (1, 8) # 8-GPU slab decomposition along the second axis
Pencil Decomposition
Pencil decomposition enables better load balancing and scalability. It is ideal for:
Large-scale simulations (e.g., 2048³ grids)
Production workloads
Higher accuracy in tightly coupled FFT pipelines
pdims = (2, 4) # 8-GPU pencil decomposition
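One way to see the scalability difference is a rough upper bound on how many devices each layout can use. The sketch below assumes a cubic n³ grid, that a slab splits one array axis and a pencil splits two, and that every device must own at least one index along each split axis. max_devices is a hypothetical helper, not part of jaxDecomp, and practical limits are usually lower because of memory and communication costs.

```python
def max_devices(n, kind):
    """Rough upper bound on usable devices for an n x n x n grid.

    Assumption: each device needs at least one index along every
    axis that the decomposition splits.
    """
    if kind == "slab":
        # One split axis of length n: at most n devices.
        return n
    if kind == "pencil":
        # Two split axes of length n each: at most n * n devices.
        return n * n
    raise ValueError(f"unknown decomposition kind: {kind!r}")
```

Under these assumptions a 2048³ grid caps a slab layout at 2048 devices, while a pencil layout has a far higher ceiling of 2048 × 2048.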
Dynamically Generating pdims
You can programmatically compute pdims based on the number of devices available:
import jax
device_count = jax.device_count() # Total number of devices (e.g., GPUs) across all processes
assert device_count % 2 == 0, "Need an even number of devices for 2D mesh"
# Example: 2 rows, N/2 columns
pdims = (2, device_count // 2)
You can experiment with other factorizations depending on your topology. The only requirement is to choose pdims = (Px, Py) such that Px * Py == jax.device_count().
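To explore the options, you can list every valid factorization and pick one. The sketch below is plain Python that takes the device count as an integer (in practice, jax.device_count()). Both function names are hypothetical, and the "most square" choice is just one possible heuristic, not a jaxDecomp recommendation.

```python
def candidate_pdims(device_count):
    """Return every (Px, Py) pair with Px * Py == device_count."""
    return [
        (px, device_count // px)
        for px in range(1, device_count + 1)
        if device_count % px == 0
    ]


def most_square_pdims(device_count):
    """Pick the factorization whose two factors are closest together.

    Ties are broken in favor of the smaller Px. For a prime device
    count this falls back to the slab layout (1, N).
    """
    return min(
        candidate_pdims(device_count),
        key=lambda p: (abs(p[0] - p[1]), p[0]),
    )
```

With 8 devices, candidate_pdims(8) lists (1, 8), (2, 4), (4, 2), and (8, 1), and most_square_pdims(8) selects (2, 4). Benchmarking a few of the candidates on your own hardware is still the most reliable way to choose.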
Recommendation
There is no universal “best” choice, so we recommend trying both. For most scientific simulations, the accuracy difference is small, but pencil decompositions usually scale better as the grid size and device count grow.
Creating the JAX Mesh and Sharding
Once pdims is defined, use it to create a JAX mesh and sharding spec:
from jax.sharding import NamedSharding, PartitionSpec as P
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
This sharding object can then be used with distributed arrays, FFTs, halo exchanges, and more.
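It can help to know how large each device's block will be before allocating anything. With P('x', 'y') on a 3D array, the first array axis is split across mesh axis 'x', the second across 'y', and the third stays whole on every device. The sketch below is plain arithmetic under the assumption that the split axes divide evenly; local_block_shape is a hypothetical helper, not a JAX or jaxDecomp function.

```python
def local_block_shape(global_shape, pdims):
    """Per-device block shape for a 3D array sharded with P('x', 'y').

    Assumes mesh axis 'x' has size pdims[0] and 'y' has size pdims[1],
    and that the first two array axes divide evenly.
    """
    nx, ny, nz = global_shape
    px, py = pdims
    if nx % px != 0 or ny % py != 0:
        raise ValueError(
            f"global shape {global_shape} is not evenly divisible by pdims {pdims}"
        )
    # The third axis is not named in the partition spec, so it is not split.
    return (nx // px, ny // py, nz)
```

For a 2048³ grid, pdims = (2, 4) gives blocks of shape (1024, 512, 2048), and pdims = (1, 8) gives (2048, 256, 2048).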
TL;DR
Use pdims = (1, N) for simpler and faster setups.
Use pdims = (M, N) (M, N > 1) for large simulations that need scalability.
Pencil decompositions require more transposes but enable more parallelism.
Choose pdims based on your hardware and workload; there’s no one-size-fits-all.