API Reference

FFT

jaxdecomp.fft.fftfreq3d(array: Union[Array, ndarray, bool, number, bool, int, float, complex], d: float = 1.0) → Array

Compute the 3D FFT frequency vectors.

Note

The input array must be in the frequency domain, meaning it must be complex. The order of the frequency vectors is always X, Y, Z.

Parameters:
  • array (ArrayLike) – Input array in the frequency domain.

  • d (float, optional) – Sample spacing (default is 1.0).

Returns:

3D FFT frequency vectors.

Return type:

Array

Raises:

ValueError – If the input array is not complex.

Example

>>> from jaxdecomp.fft import fftfreq3d, pfft3d
>>> k_array = pfft3d(global_array)
>>> kvec = fftfreq3d(k_array)
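The frequency-vector computation can be pictured with a single-process NumPy sketch. It assumes NumPy's `fftfreq` convention (frequencies in cycles per unit of `d`, without a factor of 2π) and that axes 0, 1 and 2 of the spectrum hold X, Y and Z. `fftfreq3d_sketch` is a hypothetical name; the library's scaling and its handling of the transposed `pfft3d` layout may differ.

```python
import numpy as np


def fftfreq3d_sketch(k_array, d=1.0):
    """Hypothetical stand-in for fftfreq3d on a single, unsharded array.

    Assumes axes 0, 1, 2 are X, Y, Z and NumPy's fftfreq convention
    (cycles per unit of d, no 2*pi factor).
    """
    if not np.iscomplexobj(k_array):
        raise ValueError("Input must be a complex (frequency-domain) array.")
    nx, ny, nz = k_array.shape
    # Reshape each 1D vector so it broadcasts along its own axis only.
    kx = np.fft.fftfreq(nx, d).reshape(-1, 1, 1)
    ky = np.fft.fftfreq(ny, d).reshape(1, -1, 1)
    kz = np.fft.fftfreq(nz, d).reshape(1, 1, -1)
    return kx, ky, kz


field = np.random.default_rng(0).normal(size=(8, 4, 6))
k_array = np.fft.fftn(field)
kx, ky, kz = fftfreq3d_sketch(k_array)
# Wavenumber magnitude on the full grid, obtained by broadcasting.
k_mag = np.sqrt(kx**2 + ky**2 + kz**2)
```

Returning broadcastable vectors rather than three full 3D grids keeps memory proportional to X + Y + Z instead of X·Y·Z.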

jaxdecomp.fft.pfft3d(a: Union[Array, ndarray, bool, number, bool, int, float, complex], norm: Optional[str] = 'backward', backend: str = 'JAX') → Array

Perform a 3D FFT on the input array.

Note

The returned array is transposed relative to the input array: an input of shape (X, Y, Z) produces an output of shape (Y, Z, X).

Parameters:
  • a (ArrayLike) – Input array to transform.

  • norm (Optional[str], optional) – Type of normalization (“backward”, “ortho”, or “forward”), by default “backward”.

  • backend (str, optional) – Backend to use (“JAX” or “cudecomp”), by default “JAX”.

Returns:

Transformed array after 3D FFT.

Return type:

Array

Example

>>> import jax
>>> from jaxdecomp.fft import pfft3d
>>> jax.distributed.initialize()
>>> rank = jax.process_index()
>>> from jax.experimental import mesh_utils
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> global_shape = (16, 16, 16)
>>> pdims = (4, 4)
>>> local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2])
>>> devices = mesh_utils.create_device_mesh(pdims)
>>> mesh = Mesh(devices.T, axis_names=('z', 'y'))
>>> sharding = NamedSharding(mesh, P('z', 'y'))
>>> global_array = jax.make_array_from_callback(global_shape, sharding, lambda _: jax.random.normal(jax.random.PRNGKey(rank), local_shape))
>>> k_array = pfft3d(global_array)
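On a single process, the documented output layout can be reproduced with NumPy as a reference: a full 3D FFT followed by an axis permutation that sends (X, Y, Z) to (Y, Z, X). `pfft3d_reference` is a hypothetical checking helper, not the library's distributed implementation. It also shows how the three `norm` modes scale the forward transform under NumPy's conventions, which are assumed here to match.

```python
import numpy as np


def pfft3d_reference(a, norm="backward"):
    """Hypothetical single-process reference for pfft3d's output layout.

    Input axes (X, Y, Z) -> output axes (Y, Z, X).
    """
    return np.fft.fftn(a, norm=norm).transpose(1, 2, 0)


a = np.random.default_rng(0).normal(size=(4, 6, 8))  # (X, Y, Z)
k = pfft3d_reference(a)                                # (Y, Z, X)

# "backward" applies no scaling on the forward pass, "forward" divides by N,
# and "ortho" divides by sqrt(N), where N is the total number of elements.
k_forward = pfft3d_reference(a, norm="forward")
k_ortho = pfft3d_reference(a, norm="ortho")
```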

jaxdecomp.fft.pifft3d(a: Union[Array, ndarray, bool, number, bool, int, float, complex], norm: Optional[str] = 'backward', backend: str = 'JAX') → Array

Perform an inverse 3D FFT on the input array.

Note

The inverse FFT undoes the transposition applied by pfft3d, so the returned array has shape (X, Y, Z) again.

Parameters:
  • a (ArrayLike) – Input array to transform.

  • norm (Optional[str], optional) – Type of normalization (“backward”, “ortho”, or “forward”), by default “backward”.

  • backend (str, optional) – Backend to use (“JAX” or “cudecomp”), by default “JAX”.

Returns:

Transformed array after inverse 3D FFT.

Return type:

Array

Example

>>> from jaxdecomp.fft import pfft3d, pifft3d
>>> k_array = pfft3d(global_array)
>>> original_array = pifft3d(k_array)
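A round trip through single-process NumPy references illustrates the inverse: the axis permutation applied on the forward pass is undone before the inverse FFT, and the input is recovered when both directions use the same `norm`. Both helpers are hypothetical stand-ins that model the documented layout, not the library's code, and NumPy's normalization conventions are assumed to match.

```python
import numpy as np


def pfft3d_reference(a, norm="backward"):
    # Hypothetical reference: FFT, then (X, Y, Z) -> (Y, Z, X).
    return np.fft.fftn(a, norm=norm).transpose(1, 2, 0)


def pifft3d_reference(k, norm="backward"):
    # Hypothetical reference: undo the permutation, (Y, Z, X) -> (X, Y, Z),
    # then apply the inverse FFT with the same normalization mode.
    return np.fft.ifftn(k.transpose(2, 0, 1), norm=norm)


a = np.random.default_rng(1).normal(size=(4, 6, 8))
for norm in ("backward", "ortho", "forward"):
    restored = pifft3d_reference(pfft3d_reference(a, norm), norm)
    # Recovered up to floating-point error; for real input the
    # imaginary part of the result is numerically zero.
    assert restored.shape == a.shape
    assert np.allclose(restored, a)
```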

jaxdecomp.fft.rfftfreq3d(array: Union[Array, ndarray, bool, number, bool, int, float, complex], d: float = 1.0) → Array

Compute the 3D real FFT frequency vectors.

Note

The input array must be in the frequency domain, meaning it must be complex. The order of the frequency vectors is always X, Y, Z.

Parameters:
  • array (ArrayLike) – Input array in the frequency domain.

  • d (float, optional) – Sample spacing (default is 1.0).

Returns:

3D real FFT frequency vectors.

Return type:

Array

Raises:

ValueError – If the input array is not complex.

Example

>>> from jaxdecomp.fft import pfft3d, rfftfreq3d
>>> k_array = pfft3d(global_array)
>>> kvec = rfftfreq3d(k_array)
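For a real-to-complex transform, only the non-negative half of the spectrum is stored along one axis. The NumPy sketch below assumes, as in `numpy.fft.rfftn`, that this halved axis is the last one; `rfftfreq3d_sketch` is a hypothetical helper, and which axis jaxdecomp halves is not specified here. It takes the real-space shape rather than the complex array because a half-spectrum axis of length 4 could come from a real axis of length 6 or 7.

```python
import numpy as np


def rfftfreq3d_sketch(real_shape, d=1.0):
    """Hypothetical helper: frequency vectors matching np.fft.rfftn(x).

    Assumes, as in numpy.fft.rfftn, that the last axis holds the
    non-negative half of the spectrum (n // 2 + 1 entries).
    """
    nx, ny, nz = real_shape
    kx = np.fft.fftfreq(nx, d).reshape(-1, 1, 1)
    ky = np.fft.fftfreq(ny, d).reshape(1, -1, 1)
    kz = np.fft.rfftfreq(nz, d).reshape(1, 1, -1)  # only frequencies >= 0
    return kx, ky, kz


x = np.random.default_rng(0).normal(size=(8, 4, 6))
k_array = np.fft.rfftn(x)  # last axis shrinks from 6 to 6 // 2 + 1 = 4
kx, ky, kz = rfftfreq3d_sketch(x.shape)
```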

Halo Exchange

jaxdecomp.halo.halo_exchange(x: Array, halo_extents: tuple[int, int], halo_periods: tuple[bool, bool], backend: str = 'jax') → Array

Perform a halo exchange operation using the specified backend.

Parameters:
  • x (Array) – Input array for the halo exchange.

  • halo_extents (HaloExtentType) – Tuple specifying the extents of the halo in each dimension.

  • halo_periods (Periodicity) – Tuple specifying the periodicity (True or False) in each dimension.

  • backend (str, optional) – Backend to use for the halo exchange (“jax” or “cudecomp”), by default “jax”.

Returns:

Array after performing the halo exchange.

Return type:

Array

Raises:

ValueError – If an invalid backend is specified.

Example

>>> from functools import partial
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random, shard_map
>>> from jax.experimental import mesh_utils
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> from jaxdecomp.halo import halo_exchange
>>> # Initialize the distributed runtime, the device mesh, and the sharding
>>> jax.distributed.initialize()
>>> rank = jax.process_index()
>>> pdims = (2, 2)
>>> global_shape = (16, 16, 16)
>>> local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2])
>>> devices = mesh_utils.create_device_mesh(pdims)
>>> mesh = Mesh(devices.T, axis_names=('z', 'y'))
>>> sharding = NamedSharding(mesh, P('z', 'y'))
>>> # Create a global array filled with random values
>>> global_array = jax.make_array_from_callback(global_shape, sharding, lambda idx: random.normal(random.PRNGKey(rank), local_shape))
>>> # Pad each local block by one cell on both sides of the two sharded axes
>>> padding = [(1, 1), (1, 1), (0, 0)]
>>> @partial(shard_map, mesh=mesh, in_specs=P('z', 'y'), out_specs=P('z', 'y'))
... def pad(arr):
...     return jnp.pad(arr, padding, mode='linear_ramp', end_values=20)
>>> padded_array = pad(global_array)
>>> # Perform the halo exchange with the JAX backend
>>> halo_extents = (1, 1)
>>> halo_periods = (True, True)
>>> updated_array = halo_exchange(padded_array, halo_extents, halo_periods, backend="jax")
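The effect of a periodic halo exchange along one axis can be modeled on a single process. The sketch assumes conventional halo-update semantics: each halo cell is overwritten with the neighboring block's interior data, wrapping around at the ends when the dimension is periodic. `halo_update_periodic_1d` is a hypothetical helper, and a list of NumPy arrays stands in for the per-device blocks; it is not the library's implementation.

```python
import numpy as np


def halo_update_periodic_1d(blocks, h):
    """Hypothetical model of a periodic halo update along one axis.

    `blocks` stands in for per-device local arrays; each is already padded
    with `h` halo cells on both ends (interior length must be >= h).
    """
    n = len(blocks)
    updated = [b.copy() for b in blocks]
    for i, b in enumerate(updated):
        left = blocks[(i - 1) % n]   # periodic: block 0's left neighbor is the last block
        right = blocks[(i + 1) % n]
        b[:h] = left[-2 * h:-h]      # last h interior cells of the left neighbor
        b[-h:] = right[h:2 * h]      # first h interior cells of the right neighbor
    return updated


h = 1
global_1d = np.arange(12.0)
# Split across 3 "devices" and pad each block with placeholder halo cells,
# similar to the jnp.pad step in the example above.
blocks = [np.pad(chunk, h, constant_values=-1.0) for chunk in np.split(global_1d, 3)]
updated = halo_update_periodic_1d(blocks, h)
```

Under this model, each updated block equals the matching slice of the global array extended by `h` cells on each side with wrap-around, which is the neighborhood a stencil operation on the local block needs.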

Transpositions

jaxdecomp.transpose.transposeXtoY(x: Array, backend: str = 'jax') → Array

Transpose the input array from X-pencil to Y-pencil.

Note: Expects input in Z Y X format and returns output in X Z Y format.

Parameters:
  • x (Array) – Input array in Z Y X format to transpose.

  • backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.

Returns:

Transposed array in X Z Y format.

Return type:

Array

Raises:

ValueError – If the backend is invalid.

Example

>>> import jax
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> from jax.experimental import mesh_utils
>>> import jaxdecomp
>>> global_shape = (16, 32, 64)
>>> pdims = (2, 4)
>>> devices = mesh_utils.create_device_mesh(pdims)
>>> mesh = Mesh(devices.T, axis_names=('z', 'y'))
>>> sharding = NamedSharding(mesh, P('z', 'y'))
>>> local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2])
>>> global_array = jax.make_array_from_callback(global_shape,
...                                             sharding,
...                                             data_callback=lambda _: jax.random.normal(jax.random.PRNGKey(0), local_shape))
>>> transposed_array = jaxdecomp.transposeXtoY(global_array)
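A single-process NumPy model of the documented axis orders: applied to the gathered global array, `transposeXtoY` is assumed to act like the cyclic axis permutation (Z, Y, X) → (X, Z, Y), and `transposeYtoX` like its inverse. The helper names are hypothetical, and the distributed versions also move data between devices, which this model does not show.

```python
import numpy as np


def transpose_x_to_y_ref(a):
    # Hypothetical reference: (Z, Y, X) -> (X, Z, Y).
    return a.transpose(2, 0, 1)


def transpose_y_to_x_ref(a):
    # Hypothetical reference: (X, Z, Y) -> (Z, Y, X), the inverse permutation.
    return a.transpose(1, 2, 0)


# Same global_shape as the example, read as (Z, Y, X).
a = np.arange(16 * 32 * 64).reshape(16, 32, 64)
y_pencil = transpose_x_to_y_ref(a)
x_pencil = transpose_y_to_x_ref(y_pencil)
```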

jaxdecomp.transpose.transposeXtoZ(x: Array, backend: str = 'jax') → Array

Transpose the input array from X-pencil to Z-pencil.

Note: Expects input in Z Y X format and returns output in Z X Y format.

Parameters:
  • x (Array) – Input array in Z Y X format to transpose.

  • backend (str, optional) – Backend to use for the transpose operation, by default “jax”.

Returns:

Transposed array in Z X Y format.

Return type:

Array

Raises:
  • ValueError – If the backend is invalid.

  • NotImplementedError – If the backend does not support the operation (e.g., ‘cudecomp’ for x_z).

Example

>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeXtoZ(global_array)

jaxdecomp.transpose.transposeYtoX(x: Array, backend: str = 'jax') → Array

Transpose the input array from Y-pencil to X-pencil.

Note: Expects input in X Z Y format and returns output in Z Y X format.

Parameters:
  • x (Array) – Input array in X Z Y format to transpose.

  • backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.

Returns:

Transposed array in Z Y X format.

Return type:

Array

Raises:

ValueError – If the backend is invalid.

Example

>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeYtoX(global_array)

jaxdecomp.transpose.transposeYtoZ(x: Array, backend: str = 'jax') → Array

Transpose the input array from Y-pencil to Z-pencil.

Note: Expects input in X Z Y format and returns output in Y X Z format.

Parameters:
  • x (Array) – Input array in X Z Y format to transpose.

  • backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.

Returns:

Transposed array in Y X Z format.

Return type:

Array

Raises:

ValueError – If the backend is invalid.

Example

>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeYtoZ(global_array)

jaxdecomp.transpose.transposeZtoX(x: Array, backend: str = 'jax') → Array

Transpose the input array from Z-pencil to X-pencil.

Note: Expects input in Y Z X format and returns output in X Z Y format.

Parameters:
  • x (Array) – Input array in Y Z X format to transpose.

  • backend (str, optional) – Backend to use for the transpose operation, by default “jax”.

Returns:

Transposed array in X Z Y format.

Return type:

Array

Raises:
  • ValueError – If the backend is invalid.

  • NotImplementedError – If the backend does not support the operation (e.g., ‘cudecomp’ for z_x).

Example

>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeZtoX(global_array)

jaxdecomp.transpose.transposeZtoY(x: Array, backend: str = 'jax') → Array

Transpose the input array from Z-pencil to Y-pencil.

Note: Expects input in Y X Z format and returns output in X Z Y format.

Parameters:
  • x (Array) – Input array in Y X Z format to transpose.

  • backend (str, optional) – Backend to use for the transpose operation (“jax” or “cudecomp”), by default “jax”.

Returns:

Transposed array in X Z Y format.

Return type:

Array

Raises:

ValueError – If the backend is invalid.

Example

>>> import jaxdecomp
>>> transposed_array = jaxdecomp.transposeZtoY(global_array)
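The axis orders documented for transposeXtoY, transposeYtoZ, transposeZtoY and transposeYtoX can be chained and checked with a single-process NumPy model. Forward steps are assumed to be the cyclic permutation (2, 0, 1) on the gathered global array and backward steps its inverse (1, 2, 0). This illustrates the bookkeeping only, not the library's implementation; transposeXtoZ and transposeZtoX document different axis orders and are not modeled here.

```python
import numpy as np

# Hypothetical single-process references for the documented axis orders.
FORWARD = (2, 0, 1)   # X->Y: (Z, Y, X) -> (X, Z, Y); Y->Z: (X, Z, Y) -> (Y, X, Z)
BACKWARD = (1, 2, 0)  # Z->Y: (Y, X, Z) -> (X, Z, Y); Y->X: (X, Z, Y) -> (Z, Y, X)

a = np.random.default_rng(0).normal(size=(16, 32, 64))  # (Z, Y, X)
y_pencil = a.transpose(FORWARD)          # models transposeXtoY
z_pencil = y_pencil.transpose(FORWARD)   # models transposeYtoZ
back_y = z_pencil.transpose(BACKWARD)    # models transposeZtoY
back_x = back_y.transpose(BACKWARD)      # models transposeYtoX
```

Two forward steps compose to the single permutation (1, 2, 0), so in this model the Z-pencil layout is (Y, X, Z), matching the output format documented for transposeYtoZ.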