jaxvacua.util.jit_with_dynamic_static_args

jit_with_dynamic_static_args(func)

Build a wrapper around func that re-applies JIT compilation on every call, using is_static() to decide at call time which positional arguments are treated as static. This is convenient for prototyping; in production code, prefer jit_with_static_args() so that the trace cache can be reused.

Parameters:

func (Callable) – The function to be JIT-compiled.

Returns:

Callable – Wrapper that rebuilds the JIT plan per call.

Return type:

Callable
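
The per-call pattern described above can be sketched roughly as follows. This is not the jaxvacua implementation: is_static_guess and jit_with_dynamic_static_args_sketch are hypothetical stand-ins, and the rule used here (Python bools, ints, strings, and None are static; everything else is traced) is an assumption made only for illustration. The sketch relies only on jax.jit and its static_argnums parameter.

```python
import jax
import jax.numpy as jnp


def is_static_guess(arg):
    """Hypothetical stand-in for is_static(): treat Python bools, ints,
    strings, and None as static; everything else (e.g. arrays) is traced."""
    return isinstance(arg, (bool, int, str, type(None)))


def jit_with_dynamic_static_args_sketch(func):
    """Illustrative wrapper: pick static_argnums on each call, then jit and run."""
    def wrapper(*args):
        # Inspect the positional arguments of this particular call.
        static_argnums = tuple(
            i for i, a in enumerate(args) if is_static_guess(a)
        )
        # A new jax.jit wrapper is constructed on every call.
        return jax.jit(func, static_argnums=static_argnums)(*args)
    return wrapper


def scaled_power(x, n, mode):
    # mode drives Python-level control flow, so it has to be static.
    y = x ** n
    return jnp.sum(y) if mode == "sum" else jnp.mean(y)


f = jit_with_dynamic_static_args_sketch(scaled_power)
x = jnp.arange(4.0)           # [0., 1., 2., 3.]
total = f(x, 2, "sum")        # n and mode are detected as static
average = f(x, 2, "mean")     # same wrapper, different static value
```

Because the static positions are chosen from the argument values at call time, the caller does not have to declare them up front, which is what makes this style handy for prototyping.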