jaxvacua.util.jit_with_static_args

jit_with_static_args(func, static_argnums=())

Wrap func with jax.jit, treating positional arguments at the indices in static_argnums as compile-time constants.

Parameters:
  • func (Callable) – The function to be JIT-compiled.

  • static_argnums (Tuple[int, ...]) – Positions of arguments to mark as static. Default ().

Returns:

Callable – JIT-compiled wrapper around func.

Return type:

Callable

Example

from jaxvacua.util import jit_with_static_args

def f(x, y, n):
    return x + y + n
jit_f = jit_with_static_args(f, static_argnums=(2,))  # n (index 2) is static
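
To illustrate why an argument might be marked static, the following is a minimal sketch that uses jax.jit directly. The helper jit_with_static_args_sketch is a hypothetical stand-in written for this example, not the jaxvacua implementation, and double_n_times is an illustrative function that does not appear in the library. It assumes the wrapper simply forwards static_argnums to jax.jit, as the description above suggests.

```python
import jax
import jax.numpy as jnp


def jit_with_static_args_sketch(func, static_argnums=()):
    # Hypothetical stand-in for jit_with_static_args (an assumption, not
    # the library's code): forward static_argnums straight to jax.jit.
    return jax.jit(func, static_argnums=static_argnums)


def double_n_times(x, y, n):
    # n drives a Python-level loop, so its concrete value must be known
    # at trace time; passing it as a traced value would raise an error.
    result = x + y
    for _ in range(n):
        result = result * 2.0
    return result


# Mark the argument at index 2 (n) as a compile-time constant.
jit_f = jit_with_static_args_sketch(double_n_times, static_argnums=(2,))

out = jit_f(jnp.float32(1.0), jnp.float32(2.0), 3)  # (1 + 2) * 2**3
```

With jax.jit, static arguments must be hashable, and each distinct static value triggers a fresh trace and compilation, so static_argnums is best reserved for arguments that take only a few different values.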