jaxvacua.util.random_uniform_jit


random_uniform_jit(rns_key, lower_bound, upper_bound, shape=(1,))

JIT-compiled version of random_uniform(). rns_key is a JAX PRNG-key array. shape is treated as a static argument, so the output shape is fixed at trace time, and each new shape value triggers a separate compilation.
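The static-shape behaviour can be illustrated with a minimal sketch. It assumes that random_uniform() wraps jax.random.uniform and that the JIT version is built with jax.jit(..., static_argnames=("shape",)). Neither detail is confirmed by this page, and the names random_uniform_jit_sketch and the DTYPE value used below are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the module constant DTYPE; the real value may differ.
DTYPE = jnp.float32


def random_uniform(rns_key, lower_bound, upper_bound, shape=(1,)):
    """Sketch: draw samples from [lower_bound, upper_bound) with the given shape."""
    return jax.random.uniform(
        rns_key, shape=shape, dtype=DTYPE, minval=lower_bound, maxval=upper_bound
    )


# Marking `shape` static means each distinct shape tuple is traced and compiled
# separately. The bounds remain traced values, so changing them does not recompile.
random_uniform_jit_sketch = jax.jit(random_uniform, static_argnames=("shape",))

# Split the key before sampling so the parent key can be reused for later draws.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
samples = random_uniform_jit_sketch(subkey, -1.0, 1.0, shape=(4, 3))
```

Because shape must be hashable to serve as a static argument, it is passed as a tuple. Calling the sketch with shape=(4, 3) and then shape=(8,) compiles two versions; repeating a shape reuses the cached compilation.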

Parameters:
  • rns_key (Array) – JAX PRNG-key array.

  • lower_bound (float) – Lower edge of the sampling interval.

  • upper_bound (float) – Upper edge of the sampling interval (exclusive).

  • shape (Tuple[int, ...]) – Output shape. Default (1,).

Returns:

Array – Uniform samples with the requested shape and dtype DTYPE.

Return type:

Array