API Reference
FFT
- jaxdecomp.fft.fftfreq3d(array: Union[Array, ndarray, bool, number, bool, int, float, complex], d: float = 1.0) → Array
Compute the 3D FFT frequency vectors.
Note
The input array must be in the frequency domain, meaning it must be complex. The order of the frequency vectors is always X, Y, Z.
- Parameters:
array (ArrayLike) – Input array in the frequency domain.
d (float, optional) – Sample spacing (default is 1.0).
- Returns:
3D FFT frequency vectors.
- Return type:
Array
- Raises:
ValueError – If the input array is not complex.
Example
>>> from jaxdecomp.fft import pfft3d, fftfreq3d
>>> k_array = pfft3d(global_array)
>>> kvec = fftfreq3d(k_array)
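To make the "X, Y, Z" ordering concrete, here is a single-process NumPy sketch of what such frequency vectors look like. It is not jaxdecomp's implementation: the helper `fftfreq3d_reference` is hypothetical, it assumes the input uses the (Y, Z, X) layout described in the pfft3d Note, and it follows NumPy's `np.fft.fftfreq` convention (cycles per unit of `d`), whose scaling may differ from jaxdecomp's.

```python
import numpy as np


def fftfreq3d_reference(k_array, d=1.0):
    """Hypothetical single-process model of fftfreq3d.

    Assumes k_array is stored in the (Y, Z, X) layout produced by pfft3d and
    returns (kx, ky, kz), each shaped to broadcast against k_array. Uses the
    np.fft.fftfreq convention; jaxdecomp's own scaling is not assumed here.
    """
    k_array = np.asarray(k_array)
    if not np.iscomplexobj(k_array):
        raise ValueError("Input array must be complex (frequency domain).")
    ny, nz, nx = k_array.shape
    kx = np.fft.fftfreq(nx, d).reshape(1, 1, nx)
    ky = np.fft.fftfreq(ny, d).reshape(ny, 1, 1)
    kz = np.fft.fftfreq(nz, d).reshape(1, nz, 1)
    return kx, ky, kz  # always X, Y, Z order, as the Note states


# Usage: squared wavenumber magnitude on a transposed k-space grid
field = np.random.default_rng(0).normal(size=(8, 4, 6))  # (X, Y, Z)
k_array = np.transpose(np.fft.fftn(field), (1, 2, 0))    # -> (Y, Z, X)
kx, ky, kz = fftfreq3d_reference(k_array)
k2 = kx**2 + ky**2 + kz**2  # broadcasts to k_array.shape
```

Returning broadcastable vectors rather than full 3D grids keeps memory proportional to the axis lengths, which matters when each device holds only a slab of the spectrum.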
- jaxdecomp.fft.pfft3d(a: Union[Array, ndarray, bool, number, bool, int, float, complex], norm: Optional[str] = 'backward', backend: str = 'JAX') → Array
Perform 3D FFT on the input array.
Note
The returned array is transposed compared to the input array. If the input is of shape (X, Y, Z), the output will be in the shape (Y, Z, X).
- Parameters:
a (ArrayLike) – Input array to transform.
norm (Optional[str], optional) – Type of normalization (“backward”, “ortho”, or “forward”), by default “backward”.
backend (str, optional) – Backend to use (“JAX” or “cudecomp”), by default “JAX”.
- Returns:
Transformed array after 3D FFT.
- Return type:
Array
Example
>>> import jax
>>> jax.distributed.initialize()
>>> rank = jax.process_index()
>>> from jax.experimental import mesh_utils
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> from jaxdecomp.fft import pfft3d
>>> global_shape = (16, 16, 16)
>>> pdims = (4, 4)
>>> local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2])
>>> devices = mesh_utils.create_device_mesh(pdims)
>>> mesh = Mesh(devices.T, axis_names=('z', 'y'))
>>> sharding = NamedSharding(mesh, P('z', 'y'))
>>> global_array = jax.make_array_from_callback(global_shape, sharding, lambda _: jax.random.normal(jax.random.PRNGKey(rank), local_shape))
>>> k_array = pfft3d(global_array)
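The layout change described in the Note can be modeled on one process with NumPy. This sketch assumes the output is exactly `np.fft.fftn` of the input with its axes permuted from (X, Y, Z) to (Y, Z, X); `pfft3d_reference` is a hypothetical helper, and the distributed pfft3d additionally moves data between devices. NumPy's `norm` argument accepts the same three names listed above.

```python
import numpy as np


def pfft3d_reference(a, norm="backward"):
    """Hypothetical single-process model of pfft3d's documented output layout.

    Computes the full 3D FFT, then permutes axes so an (X, Y, Z) input gives
    a (Y, Z, X) output. Assumes the transposition is a pure axis permutation.
    """
    return np.transpose(np.fft.fftn(a, norm=norm), (1, 2, 0))


a = np.random.default_rng(0).normal(size=(16, 8, 4))  # (X, Y, Z)
k = pfft3d_reference(a)
print(k.shape)  # (8, 4, 16)

# Output element [y, z, x] holds the frequency with indices (x, y, z)
full = np.fft.fftn(a)
assert np.isclose(k[3, 2, 5], full[5, 3, 2])
```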
- jaxdecomp.fft.pifft3d(a: Union[Array, ndarray, bool, number, bool, int, float, complex], norm: Optional[str] = 'backward', backend: str = 'JAX') → Array
Perform inverse 3D FFT on the input array.
Note
The returned array has its shape restored to (X, Y, Z) after the inverse FFT.
- Parameters:
a (ArrayLike) – Input array to transform.
norm (Optional[str], optional) – Type of normalization (“backward”, “ortho”, or “forward”), by default “backward”.
backend (str, optional) – Backend to use (“JAX” or “cudecomp”), by default “JAX”.
- Returns:
Transformed array after inverse 3D FFT.
- Return type:
Array
Example
>>> from jaxdecomp.fft import pfft3d, pifft3d
>>> k_array = pfft3d(global_array)
>>> original_array = pifft3d(k_array)
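Under the same single-process assumptions (hypothetical helpers, NumPy FFTs, and a pure (X, Y, Z) ↔ (Y, Z, X) axis permutation), a forward/inverse pair called with the same `norm` recovers the input for all three normalization modes:

```python
import numpy as np


def pfft3d_reference(a, norm="backward"):
    # Hypothetical model: full 3D FFT, then (X, Y, Z) -> (Y, Z, X)
    return np.transpose(np.fft.fftn(a, norm=norm), (1, 2, 0))


def pifft3d_reference(k, norm="backward"):
    # Hypothetical model: (Y, Z, X) -> (X, Y, Z), then the inverse 3D FFT
    return np.fft.ifftn(np.transpose(k, (2, 0, 1)), norm=norm)


a = np.random.default_rng(1).normal(size=(6, 4, 8))  # (X, Y, Z)
for norm in ("backward", "ortho", "forward"):
    restored = pifft3d_reference(pfft3d_reference(a, norm), norm)
    assert restored.shape == a.shape
    assert np.allclose(restored.real, a)
```

In this model, mixing modes (for example "ortho" forward with "backward" inverse) leaves a constant scale factor on the result, so the same `norm` should be passed to both calls.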
- jaxdecomp.fft.rfftfreq3d(array: Union[Array, ndarray, bool, number, bool, int, float, complex], d: float = 1.0) → Array
Compute the 3D real FFT frequency vectors.
Note
The input array must be in the frequency domain, meaning it must be complex. The order of the frequency vectors is always X, Y, Z.
- Parameters:
array (ArrayLike) – Input array in the frequency domain.
d (float, optional) – Sample spacing (default is 1.0).
- Returns:
3D real FFT frequency vectors.
- Return type:
Array
- Raises:
ValueError – If the input array is not complex.
Example
>>> from jaxdecomp.fft import pfft3d, rfftfreq3d
>>> k_array = pfft3d(global_array)
>>> kvec = rfftfreq3d(k_array)
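For a real input, a half-spectrum transform keeps only the non-negative frequencies along one axis. This reference does not state which axis jaxdecomp halves, so the NumPy sketch below uses `np.fft.rfftn`'s convention (the last axis of an (X, Y, Z) array) and assumes the original length along that axis was even; `rfftfreq3d_reference` is a hypothetical helper, not the library's function.

```python
import numpy as np


def rfftfreq3d_reference(k_array, d=1.0):
    """Hypothetical single-process model of rfftfreq3d.

    Assumes k_array is np.fft.rfftn output in (X, Y, Z) order, so only the
    last axis is halved, and that the original Z length was even. Returns
    (kx, ky, kz), each shaped to broadcast against k_array.
    """
    k_array = np.asarray(k_array)
    if not np.iscomplexobj(k_array):
        raise ValueError("Input array must be complex (frequency domain).")
    nx, ny, nz_half = k_array.shape
    nz = 2 * (nz_half - 1)  # assumes an even original length
    kx = np.fft.fftfreq(nx, d).reshape(nx, 1, 1)
    ky = np.fft.fftfreq(ny, d).reshape(1, ny, 1)
    kz = np.fft.rfftfreq(nz, d).reshape(1, 1, nz_half)  # non-negative only
    return kx, ky, kz


real_field = np.random.default_rng(2).normal(size=(4, 6, 8))
k_array = np.fft.rfftn(real_field)  # last axis halved: 8 // 2 + 1 = 5
kx, ky, kz = rfftfreq3d_reference(k_array)
k2 = kx**2 + ky**2 + kz**2          # broadcasts to k_array.shape
```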
Halo Exchange
- jaxdecomp.halo.halo_exchange(x: Array, halo_extents: tuple[int, int], halo_periods: tuple[bool, bool], backend: str = 'jax') → Array
Perform a halo exchange operation using the specified backend.
- Parameters:
x (Array) – Input array for the halo exchange.
halo_extents (HaloExtentType) – Tuple specifying the extents of the halo in each dimension.
halo_periods (Periodicity) – Tuple specifying the periodicity (True or False) in each dimension.
backend (str, optional) – Backend to use for the halo exchange (“jax” or “cudecomp”), by default “jax”.
- Returns:
Array after performing the halo exchange.
- Return type:
Array
- Raises:
ValueError – If an invalid backend is specified.
Example
>>> from functools import partial
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from jax.sharding import PartitionSpec as P
>>> from jax.experimental import mesh_utils
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax import shard_map
>>> from jaxdecomp.halo import halo_exchange
>>> # Initialize distributed mesh and array
>>> jax.distributed.initialize()
>>> rank = jax.process_index()
>>> pdims = (2, 2)
>>> global_shape = (16, 16, 16)
>>> local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2])
>>> devices = mesh_utils.create_device_mesh(pdims)
>>> mesh = Mesh(devices.T, axis_names=('z', 'y'))
>>> sharding = NamedSharding(mesh, P('z', 'y'))
>>> # Create global array with random values
>>> global_array = jax.make_array_from_callback(global_shape, sharding, lambda idx: random.normal(random.PRNGKey(rank), local_shape))
>>> # Define padding for halo exchange
>>> padding = [(1, 1), (1, 1), (0, 0)]
>>> # Pad each local shard with custom padding
>>> @partial(shard_map, mesh=mesh, in_specs=P('z', 'y'), out_specs=P('z', 'y'))
... def pad(arr):
...     return jnp.pad(arr, padding, mode='linear_ramp', end_values=20)
>>> padded_array = pad(global_array)
>>> # Perform halo exchange using JAX backend
>>> halo_extents = (1, 1)
>>> halo_periods = (True, True)
>>> updated_array = halo_exchange(padded_array, halo_extents, halo_periods, backend="jax")
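The exchange itself can be modeled on one process. This NumPy sketch assumes the usual ghost-cell semantics, in which each padded halo is overwritten with the neighboring shard's interior edge, and it models only one periodic axis split into slabs; `periodic_halo_exchange_1d` is a hypothetical helper, not part of jaxdecomp. Under that assumption, the linear_ramp values written by `jnp.pad` in the example are only placeholders.

```python
import numpy as np


def periodic_halo_exchange_1d(blocks, h):
    """Hypothetical single-process model of a periodic halo exchange on axis 0.

    Each block is one shard that already carries h halo rows on each side.
    The low halo is filled with the last h interior rows of the previous
    shard and the high halo with the first h interior rows of the next
    shard, wrapping around because the axis is periodic.
    """
    p = len(blocks)
    out = [b.copy() for b in blocks]
    for i in range(p):
        left, right = blocks[(i - 1) % p], blocks[(i + 1) % p]
        out[i][:h] = left[-2 * h:-h]   # previous shard's upper interior edge
        out[i][-h:] = right[h:2 * h]   # next shard's lower interior edge
    return out


global_array = np.arange(16 * 3, dtype=float).reshape(16, 3)
h, p = 1, 4
n_local = global_array.shape[0] // p

# Split into p slabs and pad each with h placeholder halo rows per side
blocks = [np.pad(s, ((h, h), (0, 0))) for s in np.split(global_array, p)]
exchanged = periodic_halo_exchange_1d(blocks, h)

# Each shard now matches a window cut from the periodically extended array
for i, b in enumerate(exchanged):
    rows = np.arange(i * n_local - h, (i + 1) * n_local + h) % global_array.shape[0]
    assert np.array_equal(b, global_array[rows])
```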
Transpositions
- jaxdecomp.transpose.transposeXtoY(x: Array, backend: str = 'jax') → Array
Transpose the input array from X-pencil to Y-pencil.
Note: Expects input in Z Y X format and returns output in X Z Y format.
- Parameters:
x (Array) – Input array in Z Y X format to transpose.
backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.
- Returns:
Transposed array in X Z Y format.
- Return type:
Array
- Raises:
ValueError – If the backend is invalid.
Example
>>> import jax
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> from jax.experimental import mesh_utils
>>> import jaxdecomp
>>> global_shape = (16, 32, 64)
>>> pdims = (2, 4)
>>> devices = mesh_utils.create_device_mesh(pdims)
>>> mesh = Mesh(devices.T, axis_names=('z', 'y'))
>>> sharding = NamedSharding(mesh, P('z', 'y'))
>>> local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2])
>>> global_array = jax.make_array_from_callback(global_shape, sharding, data_callback=lambda _: jax.random.normal(jax.random.PRNGKey(0), local_shape))
>>> transposed_array = jaxdecomp.transposeXtoY(global_array)
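The layout labels in the Note can be read as an axis permutation. The NumPy sketch below runs on a single process and uses a hypothetical helper, `layout_axes`, to derive the `np.transpose` axes from the label strings. It applies them to an array with the example's global shape, read in "Z Y X" order as Z = 16, Y = 32, X = 64. It models only the relabeling of axes, not the communication between devices.

```python
import numpy as np


def layout_axes(src, dst):
    """Hypothetical helper: np.transpose axes that turn layout src into dst.

    Layouts are strings of axis labels in storage order, e.g. "ZYX".
    """
    return tuple(src.index(label) for label in dst)


# Model of transposeXtoY's documented layout change: Z Y X -> X Z Y
axes = layout_axes("ZYX", "XZY")
a = np.random.default_rng(0).normal(size=(16, 32, 64))  # Z=16, Y=32, X=64
b = np.transpose(a, axes)

print(axes)     # (2, 0, 1)
print(b.shape)  # (64, 16, 32)

# The element at coordinates (z, y, x) is now stored at b[x, z, y]
z, y, x = 3, 7, 11
assert b[x, z, y] == a[z, y, x]
```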
- jaxdecomp.transpose.transposeXtoZ(x: Array, backend: str = 'jax') → Array
Transpose the input array from X-pencil to Z-pencil.
Note: Expects input in Z Y X format and returns output in Z X Y format.
- Parameters:
x (Array) – Input array in Z Y X format to transpose.
backend (str, optional) – Backend to use for the transpose operation, by default “jax”.
- Returns:
Transposed array in Z X Y format.
- Return type:
Array
- Raises:
ValueError – If the backend is invalid.
NotImplementedError – If the backend does not support the operation (e.g., ‘cudecomp’ for x_z).
Example
>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeXtoZ(global_array)
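Following the layouts stated in this Note (Z Y X in, Z X Y out), a single-process NumPy model shows that only the last two axes swap while each element keeps its (z, y, x) coordinates. This is an illustrative sketch of the documented layout, not the library's implementation.

```python
import numpy as np

# Model of the layout stated in transposeXtoZ's Note: Z Y X -> Z X Y
src, dst = "ZYX", "ZXY"
axes = tuple(src.index(label) for label in dst)  # (0, 2, 1)

a = np.random.default_rng(0).normal(size=(16, 32, 64))  # Z=16, Y=32, X=64
b = np.transpose(a, axes)                                # Z=16, X=64, Y=32

# Each element keeps its (z, y, x) coordinates; only the storage order changes
z, y, x = 3, 7, 11
assert b[z, x, y] == a[z, y, x]
```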
- jaxdecomp.transpose.transposeYtoX(x: Array, backend: str = 'jax') → Array
Transpose the input array from Y-pencil to X-pencil.
Note: Expects input in X Z Y format and returns output in Z Y X format.
- Parameters:
x (Array) – Input array in X Z Y format to transpose.
backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.
- Returns:
Transposed array in Z Y X format.
- Return type:
Array
- Raises:
ValueError – If the backend is invalid.
Example
>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeYtoX(global_array)
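Because transposeYtoX takes the X Z Y layout that transposeXtoY produces and returns Z Y X, the two Notes describe inverse permutations. A single-process NumPy check of that reading, using a hypothetical `relabel` helper:

```python
import numpy as np


def relabel(arr, src, dst):
    # Hypothetical helper: permute axes from layout src to layout dst
    return np.transpose(arr, tuple(src.index(label) for label in dst))


a = np.random.default_rng(0).normal(size=(16, 32, 64))  # Z Y X (X-pencil)
y_pencil = relabel(a, "ZYX", "XZY")         # models transposeXtoY
x_pencil = relabel(y_pencil, "XZY", "ZYX")  # models transposeYtoX

assert y_pencil.shape == (64, 16, 32)
assert np.array_equal(x_pencil, a)          # YtoX undoes XtoY
```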
- jaxdecomp.transpose.transposeYtoZ(x: Array, backend: str = 'jax') → Array
Transpose the input array from Y-pencil to Z-pencil.
Note: Expects input in X Z Y format and returns output in Y X Z format.
- Parameters:
x (Array) – Input array in X Z Y format to transpose.
backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.
- Returns:
Transposed array in Y X Z format.
- Return type:
Array
- Raises:
ValueError – If the backend is invalid.
Example
>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeYtoZ(global_array)
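Chaining the Notes for transposeXtoY (Z Y X → X Z Y) and transposeYtoZ (X Z Y → Y X Z) gives a net Z Y X → Y X Z relabeling. This NumPy sketch checks that composition on one process; `relabel` is a hypothetical helper, and the shapes assume the Z = 16, Y = 32, X = 64 array from the transposeXtoY example.

```python
import numpy as np


def relabel(arr, src, dst):
    # Hypothetical helper: permute axes from layout src to layout dst
    return np.transpose(arr, tuple(src.index(label) for label in dst))


a = np.random.default_rng(0).normal(size=(16, 32, 64))  # Z=16, Y=32, X=64
y_pencil = relabel(a, "ZYX", "XZY")         # models transposeXtoY
z_pencil = relabel(y_pencil, "XZY", "YXZ")  # models transposeYtoZ

print(z_pencil.shape)  # (32, 64, 16)
# The two steps compose to a single Z Y X -> Y X Z relabeling
assert np.array_equal(z_pencil, relabel(a, "ZYX", "YXZ"))
```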
- jaxdecomp.transpose.transposeZtoX(x: Array, backend: str = 'jax') → Array
Transpose the input array from Z-pencil to X-pencil.
Note: Expects input in Y Z X format and returns output in X Z Y format.
- Parameters:
x (Array) – Input array in Y Z X format to transpose.
backend (str, optional) – Backend to use for the transpose operation, by default “jax”.
- Returns:
Transposed array in X Z Y format.
- Return type:
Array
- Raises:
ValueError – If the backend is invalid.
NotImplementedError – If the backend does not support the operation (e.g., ‘cudecomp’ for z_x).
Example
>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeZtoX(global_array)
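Taking this Note's stated layouts at face value (Y Z X in, X Z Y out), the relabeling simply reverses the axis order. A single-process NumPy model of that reading, not of the library's implementation:

```python
import numpy as np

# Model of the layout stated in transposeZtoX's Note: Y Z X -> X Z Y
src, dst = "YZX", "XZY"
axes = tuple(src.index(label) for label in dst)  # (2, 1, 0): order reversed

a = np.random.default_rng(0).normal(size=(32, 16, 64))  # Y=32, Z=16, X=64
b = np.transpose(a, axes)                                # X=64, Z=16, Y=32

y, z, x = 7, 3, 11
assert b[x, z, y] == a[y, z, x]
assert np.array_equal(b, a.T)  # for a 3D array, .T also reverses all axes
```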
- jaxdecomp.transpose.transposeZtoY(x: Array, backend: str = 'jax') → Array
Transpose the input array from Z-pencil to Y-pencil.
Note: Expects input in Y X Z format and returns output in X Z Y format.
- Parameters:
x (Array) – Input array in Y X Z format to transpose.
backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.
- Returns:
Transposed array in X Z Y format.
- Return type:
Array
- Raises:
ValueError – If the backend is invalid.
Example
>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeZtoY(global_array)
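The Note for transposeZtoY (Y X Z → X Z Y) is the inverse of the one for transposeYtoZ (X Z Y → Y X Z), so a round trip through the Z-pencil should restore the Y-pencil layout. The NumPy sketch below checks the full X → Y → Z → Y → X cycle described by the Notes on a single process, using a hypothetical `relabel` helper.

```python
import numpy as np


def relabel(arr, src, dst):
    # Hypothetical helper: permute axes from layout src to layout dst
    return np.transpose(arr, tuple(src.index(label) for label in dst))


a = np.random.default_rng(0).normal(size=(16, 32, 64))  # Z Y X (X-pencil)
y_pencil = relabel(a, "ZYX", "XZY")         # models transposeXtoY
z_pencil = relabel(y_pencil, "XZY", "YXZ")  # models transposeYtoZ
y_again = relabel(z_pencil, "YXZ", "XZY")   # models transposeZtoY
x_again = relabel(y_again, "XZY", "ZYX")    # models transposeYtoX

assert np.array_equal(y_again, y_pencil)  # ZtoY undoes YtoZ
assert np.array_equal(x_again, a)         # the full cycle restores the input
```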