jaxvacua.util.jit_with_static_args
- jit_with_static_args(func, static_argnums=())
Wrap func with jax.jit, treating positional arguments at the indices in static_argnums as compile-time constants.
- Parameters:
func (Callable) – The function to be JIT-compiled.
static_argnums (Tuple[int, ...]) – Positions of arguments to mark as static. Default: ().
- Returns:
Callable – JIT-compiled wrapper around func.
- Return type:
Callable
Example
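from jaxvacua.util import jit_with_static_args

def f(x, y, n):
    return x + y + n

# The argument at position 2 (n) is treated as a compile-time constant.
jit_f = jit_with_static_args(f, static_argnums=(2,))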
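To make "compile-time constant" concrete, here is a minimal sketch of how static arguments behave under jax.jit. It does not use the library's own implementation: jit_with_static_args_sketch is a hypothetical stand-in that assumes the wrapper simply forwards to jax.jit(func, static_argnums=...), and scaled_sum and trace_count are illustrative names, not part of jaxvacua.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for jaxvacua.util.jit_with_static_args.
# Assumption: the real helper forwards to jax.jit; its internals may differ.
def jit_with_static_args_sketch(func, static_argnums=()):
    return jax.jit(func, static_argnums=static_argnums)


trace_count = 0  # counts how often JAX traces the Python function


def scaled_sum(x, y, n):
    global trace_count
    trace_count += 1  # executes only during tracing, not on cached calls
    total = x + y
    # n is static, so it is an ordinary Python int here and can drive
    # Python-level control flow such as range().
    for _ in range(n):
        total = total * 2.0
    return total


jit_scaled_sum = jit_with_static_args_sketch(scaled_sum, static_argnums=(2,))

a = jit_scaled_sum(jnp.array(1.0), jnp.array(2.0), 2)  # first call: traces for n=2
b = jit_scaled_sum(jnp.array(3.0), jnp.array(4.0), 2)  # same n: reuses the compiled code
c = jit_scaled_sum(jnp.array(1.0), jnp.array(2.0), 3)  # new n: traces and compiles again
```

In this sketch the function is traced twice across the three calls, once per distinct value of n. Static arguments must be hashable, and a static argument that takes many distinct values triggers a recompilation for each one, so they suit small sets of configuration-like values rather than data.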