jaxvacua.util.vmapping_func
- vmapping_func(func, in_axes=None, **kwargs)
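Build a JIT-compiled, vmapped wrapper around func, with optional keyword arguments frozen inside the closure.
Note
Each call constructs a fresh jax.jit(jax.vmap(...)) object. Because every wrapper is a new function object, JAX cannot reuse the traced and compiled result of an earlier call, so the first invocation of each new wrapper triggers retracing and XLA recompilation. Use vmapping_func_cached() if you call the same combination of func, in_axes and keyword arguments repeatedly.
- Parameters: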
func (Callable) – Function to be vmapped.
in_axes (Union[int, Tuple, None]) – Forwarded to jax.vmap.
**kwargs (Any) – Keyword arguments to be bound inside the closure.
- Returns:
Callable – JIT-compiled vmapped function.
- Return type:
Callable
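The following is a minimal sketch, not the jaxvacua source, of how a wrapper with the behavior described above could be built, and of why a cached variant avoids repeated compilation. It assumes functools.partial as the closure that freezes **kwargs, maps in_axes=None to jax.vmap's default of 0 (the documentation above only says the value is forwarded), and memoizes with functools.lru_cache. The names vmapping_func_sketch, vmapping_func_cached_sketch and scaled_norm are hypothetical, and the real vmapping_func_cached() may be implemented differently.

```python
import functools
from typing import Any, Callable, Tuple, Union

import jax
import jax.numpy as jnp


def vmapping_func_sketch(
    func: Callable,
    in_axes: Union[int, Tuple, None] = None,
    **kwargs: Any,
) -> Callable:
    """Hypothetical re-creation of vmapping_func; not the jaxvacua source."""
    # Freeze the keyword arguments; partial stands in for the closure.
    bound = functools.partial(func, **kwargs)
    # ASSUMPTION: None falls back to jax.vmap's own default of 0, because
    # jax.vmap rejects an in_axes that leaves every input unmapped.
    axes = 0 if in_axes is None else in_axes
    # A brand-new jit(vmap(...)) object is returned on every call.
    return jax.jit(jax.vmap(bound, in_axes=axes))


@functools.lru_cache(maxsize=None)
def _cached_builder(func: Callable, in_axes, frozen_kwargs: Tuple) -> Callable:
    # Built once per distinct (func, in_axes, frozen_kwargs) key.
    return vmapping_func_sketch(func, in_axes, **dict(frozen_kwargs))


def vmapping_func_cached_sketch(func, in_axes=None, **kwargs):
    """Hypothetical memoized variant; keyword values must be hashable."""
    # Sort by name so the cache key does not depend on keyword order.
    return _cached_builder(func, in_axes, tuple(sorted(kwargs.items())))


def scaled_norm(x, scale=1.0):
    # Per-example function: x has shape (n,), result is a scalar.
    return scale * jnp.sqrt(jnp.sum(x ** 2))


xs = jnp.array([[3.0, 4.0], [6.0, 8.0]])

f = vmapping_func_sketch(scaled_norm, in_axes=0, scale=2.0)
print(f(xs))  # [10. 20.]

# Uncached: two builds give two distinct jitted objects.
print(vmapping_func_sketch(scaled_norm, 0, scale=2.0)
      is vmapping_func_sketch(scaled_norm, 0, scale=2.0))  # False

# Cached: the same (func, in_axes, kwargs) returns the identical object,
# so its compiled executable is reused.
print(vmapping_func_cached_sketch(scaled_norm, 0, scale=2.0)
      is vmapping_func_cached_sketch(scaled_norm, 0, scale=2.0))  # True
```

Because lru_cache builds its key from the arguments, every frozen keyword value must be hashable in this sketch: Python scalars and strings work, but passing a jax.Array raises TypeError. The cache also holds strong references to func and its keyword values for the lifetime of the process.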