jaxvacua.util.vmapping_func_cached


vmapping_func_cached(func, in_axes=None, **kwargs)

Cached variant of vmapping_func(). Returns a JIT-compiled vmapped function and reuses the previously compiled XLA kernel whenever (func, in_axes, kwargs) match those of a prior call.

Parameters:
  • func (Callable) – Function to be vmapped. Must be hashable; Python bound methods satisfy this requirement.

  • in_axes (Union[int, Tuple, None]) – Forwarded to jax.vmap. Must be hashable, i.e. an int, None, or a tuple whose entries are ints or None. Default None.

  • **kwargs (Any) – Keyword arguments to be bound inside the closure. All values must be hashable.
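Since every argument becomes part of the cache key, one quick way to vet a candidate value is to try hashing it. The helper `is_hashable` and the class `Model` below are hypothetical and not part of jaxvacua; they only illustrate which typical values qualify. NumPy arrays, like lists and dicts, are unhashable and would be rejected as keyword values.

```python
class Model:
    def energy(self, x, shift):
        return x + shift


def is_hashable(value):
    """Return True if value could serve as part of a cache key."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


m = Model()
print(is_hashable((0, None)))  # True: tuple of ints/None is a valid in_axes key
print(is_hashable([0, None]))  # False: lists are unhashable, convert to a tuple
print(is_hashable(m.energy))   # True: bound method of a plain instance
print(is_hashable({"a": 1}))   # False: dicts are unhashable
print(m.energy == m.energy)    # True: each access makes a new method object, but they compare equal
```

The last line is why bound methods work well as cache keys: two separate `m.energy` lookups produce distinct objects that nevertheless compare and hash equal, so they resolve to the same cache entry.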

Returns:
  • Callable – JIT-compiled vmapped function, reused from the cache when possible.

Return type:

Callable

Note

The backing LRU cache is module-level (capacity 256) and persists for the lifetime of the Python process.
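To illustrate what a bounded LRU cache implies, the sketch below uses `functools.lru_cache` with a deliberately tiny `maxsize=2` in place of 256; the hypothetical `build` function and its `builds` log stand in for an expensive XLA compilation. The eviction order shown is standard least-recently-used behaviour and is assumed, not confirmed, to match the jaxvacua cache.

```python
import functools

builds = []  # records every (simulated) compilation


@functools.lru_cache(maxsize=2)  # tiny capacity so eviction is easy to observe
def build(key):
    builds.append(key)
    return f"kernel-{key}"


build("a")  # miss: compiled
build("b")  # miss: compiled
build("a")  # hit: "a" becomes the most recently used entry
build("c")  # miss: cache is full, so the least recently used "b" is evicted
build("b")  # miss again: "b" must be rebuilt
print(builds)             # ['a', 'b', 'c', 'b']
print(build.cache_info())  # CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)
```

Because such a cache lives for the whole process, its entries (and the compiled kernels they reference) are released only by eviction. A plain `functools.lru_cache` also offers `cache_clear()`, though whether jaxvacua exposes an equivalent hook is not stated here.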