jaxvacua.util.vmapping_func_cached
- vmapping_func_cached(func, in_axes=None, **kwargs)
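Cached variant of vmapping_func(). Returns a JIT-compiled vmapped function and reuses the previously compiled XLA kernel whenever (func, in_axes, kwargs) match a prior call. As with any jax.jit-compiled function, calling the returned function with new input shapes or dtypes still triggers a fresh trace and compilation.
- Parameters: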
func (Callable) – Function to be vmapped. Must be hashable; Python bound methods satisfy this requirement.
in_axes (Union[int, Tuple, None]) – Forwarded to jax.vmap. Must be hashable (an int, None, or a tuple of ints or None). Default None.
**kwargs (Any) – Keyword arguments bound inside the closure. All values must be hashable, since they form part of the cache key.
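Because func, in_axes, and the keyword arguments together form the cache key, each must be hashable. The plain-Python check below is a sketch of what that requirement means in practice. `Model`, `energy`, and `is_hashable` are hypothetical names used only for illustration, and freezing kwargs into a sorted tuple is an assumed technique, not necessarily how jaxvacua builds its key.

```python
from typing import Any


def is_hashable(obj: Any) -> bool:
    """Return True if `obj` can be used as (part of) a cache key."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


class Model:
    # Hypothetical class standing in for any object whose method you want to vmap.
    def energy(self, x, coupling=1.0):
        return coupling * x


m = Model()

# Each attribute access creates a new bound-method object, but the objects
# compare equal and hash equally, so they map to the same cache entry.
assert m.energy is not m.energy
assert is_hashable(m.energy)
assert m.energy == m.energy and hash(m.energy) == hash(m.energy)

# in_axes: an int, None, or a tuple is hashable; a list is not.
assert is_hashable(0) and is_hashable(None) and is_hashable((0, None))
assert not is_hashable([0, None])

# kwargs: a dict is not hashable, but its items can be frozen into a tuple,
# which is hashable only if every value is.
assert is_hashable(tuple(sorted({"coupling": 2.0}.items())))
assert not is_hashable(tuple(sorted({"coupling": [2.0]}.items())))
```

In practice this means a list for in_axes, or a list or array as a keyword value, would likely fail when the cache key is built; tuples are the usual hashable substitute.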
- Returns:
Callable – JIT-compiled vmapped function, reused from the cache when possible.
- Return type:
Callable
Note
The backing LRU cache is module-level (capacity 256) and persists for the lifetime of the Python process.