jaxvacua.util.vmapping_func

vmapping_func(func, in_axes=None, **kwargs)

Build a JIT-compiled, vmapped wrapper around func, with any additional keyword arguments frozen inside the closure rather than passed through jax.vmap.
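The description above suggests a partial-then-vmap-then-jit composition. The following is a minimal sketch of that idea, not the library's actual source. The name vmapping_func_sketch and the example function energy are hypothetical, and the treatment of in_axes=None (falling back to vmap's default of axis 0, because jax.vmap itself rejects an all-None in_axes) is an assumption.

```python
import functools
from typing import Any, Callable, Tuple, Union

import jax
import jax.numpy as jnp


def vmapping_func_sketch(
    func: Callable,
    in_axes: Union[int, Tuple, None] = None,
    **kwargs: Any,
) -> Callable:
    """Hypothetical re-implementation of vmapping_func, for illustration only."""
    # Freeze the keyword arguments into a closure so that jax.vmap only
    # sees the positional arguments (which are batched or broadcast).
    bound = functools.partial(func, **kwargs)
    # Assumption: None means "use jax.vmap's default", i.e. map axis 0.
    axes = 0 if in_axes is None else in_axes
    return jax.jit(jax.vmap(bound, in_axes=axes))


def energy(x, y, scale=1.0):
    # Hypothetical per-sample function: scaled dot product of two vectors.
    return scale * jnp.sum(x * y)


batched = vmapping_func_sketch(energy, in_axes=(0, 0), scale=0.5)
xs = jnp.ones((4, 3))
ys = jnp.arange(12.0).reshape(4, 3)
result = batched(xs, ys)  # values: 1.5, 6.0, 10.5, 15.0
```

Here scale is fixed at 0.5 for every sample, while xs and ys are both mapped over their leading axis.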

Note

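Each call constructs a fresh jax.jit(jax.vmap(...)) object around a new closure, so JAX's compilation cache cannot be reused across calls: every new wrapper is re-traced and recompiled by XLA the first time it is invoked. Use vmapping_func_cached() if you call vmapping_func repeatedly with the same func, in_axes, and keyword arguments.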
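The simplest way to avoid the recompilation described in the note is to build the wrapper once and reuse it. Where that is awkward, one option is to memoize the builder. The sketch below uses functools.lru_cache for that purpose; cached_vmapping_func_sketch and square_plus are hypothetical names, and this is not a description of how vmapping_func_cached() is actually implemented.

```python
import functools
from typing import Any, Callable, Tuple, Union

import jax
import jax.numpy as jnp


@functools.lru_cache(maxsize=None)
def cached_vmapping_func_sketch(
    func: Callable,
    in_axes: Union[int, Tuple, None] = None,
    **kwargs: Any,
) -> Callable:
    """Hypothetical memoized builder; NOT the library's vmapping_func_cached."""
    bound = functools.partial(func, **kwargs)
    # Assumption: None falls back to jax.vmap's default of axis 0.
    axes = 0 if in_axes is None else in_axes
    # Built once per distinct (func, in_axes, kwargs) key, so the jitted
    # wrapper and its trace/compile cache are reused on later calls.
    return jax.jit(jax.vmap(bound, in_axes=axes))


def square_plus(x, offset=0.0):
    # Hypothetical per-sample function.
    return x * x + offset


f1 = cached_vmapping_func_sketch(square_plus, 0, offset=1.0)
f2 = cached_vmapping_func_sketch(square_plus, 0, offset=1.0)
print(f1 is f2)  # True: same key, same compiled wrapper object

out = f1(jnp.arange(3.0))  # values: 1.0, 2.0, 5.0
```

Two caveats of this approach: lru_cache treats in_axes=0 passed by keyword and 0 passed positionally as different keys, and every keyword argument must be hashable, so JAX arrays cannot be frozen this way (they raise TypeError when hashed).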

Parameters:
  • func (Callable) – Function to be vmapped.

  • in_axes (Union[int, Tuple, None]) – Forwarded to jax.vmap.

  • **kwargs (Any) – Keyword arguments to be bound inside the closure.
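To illustrate how the two parameter kinds interact, a call such as vmapping_func(shifted_norm, in_axes=(0, None), power=1.0) is described as freezing power in the closure and forwarding in_axes to jax.vmap. The sketch below builds that composition directly with jax.vmap and jax.jit, assuming this is roughly equivalent; shifted_norm is a hypothetical function. The None entry in in_axes broadcasts center to every sample instead of mapping over it.

```python
import functools

import jax
import jax.numpy as jnp


def shifted_norm(x, center, power=2.0):
    # x: one sample of shape (d,); center: shared vector of shape (d,)
    return jnp.sum(jnp.abs(x - center) ** power)


# Keyword argument frozen first, then vmap with in_axes, then jit.
manual = jax.jit(
    jax.vmap(functools.partial(shifted_norm, power=1.0), in_axes=(0, None))
)

xs = jnp.array([[1.0, 2.0], [3.0, 5.0]])  # batch of 2 samples, mapped on axis 0
center = jnp.array([1.0, 1.0])  # not mapped: shared by every sample
out = manual(xs, center)  # values: 1.0, 6.0
```

Note that in_axes only describes the positional arguments; the frozen keyword argument never reaches jax.vmap.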

Returns:

Callable – JIT-compiled vmapped function.

Return type:

Callable