jaxvacua.util.PRNGSequence


class PRNGSequence(seed=42)

Bases: object

Splittable JAX PRNG key generator. Adapted from CYJax.

Each call to next(rns) splits the internal key in two, returns one half, and keeps the other half as the new internal state. Use it whenever a function needs a fresh subkey: for a given seed the keys are produced in a deterministic order, and no key is handed out twice, so independent random streams can be drawn from a single instance for its whole lifetime without reusing keys.
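The split-and-return behaviour described above can be sketched in plain JAX. This is a hypothetical minimal class (`SplitSequence` is an illustrative name, not part of jaxvacua or CYJax) and not the library's source. It accepts only an int seed, and keeping the first half of each split as the new state is an arbitrary choice made for this sketch.

```python
import jax
import jax.numpy as jnp


class SplitSequence:
    """Hypothetical stand-in for a split-and-return key iterator (not jaxvacua code)."""

    def __init__(self, seed: int = 42):
        # Simplification: this sketch only accepts an int seed.
        self._key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Split the current key in two: keep the first half as the new
        # internal state and hand the second half to the caller.
        self._key, subkey = jax.random.split(self._key)
        return subkey


rns = SplitSequence(0)
# Every draw consumes its own fresh subkey, so no key is reused.
samples = jnp.stack([jax.random.normal(next(rns), ()) for _ in range(4)])
```

Because the state only ever moves forward through splits, two instances built from the same seed yield identical key sequences, which is what makes runs reproducible.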

__init__(seed=42)
Parameters:

seed (Union[int, Array]) – Either a Python int (used as a seed for jax.random.PRNGKey) or an existing JAX PRNG key that becomes the initial internal state. Defaults to 42.
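A minimal sketch of how the two accepted seed forms can be normalised to a single initial key. The helper `as_prng_key` is hypothetical and only mirrors the contract stated above; it is not jaxvacua's actual code, and it ignores edge cases such as NumPy integer seeds.

```python
from typing import Union

import jax
import jax.numpy as jnp


def as_prng_key(seed: Union[int, jax.Array]) -> jax.Array:
    """Hypothetical helper following the documented `seed` contract.

    A Python int is turned into a key via jax.random.PRNGKey; anything else
    is assumed to already be a JAX PRNG key and is returned unchanged.
    """
    if isinstance(seed, int):
        return jax.random.PRNGKey(seed)
    return seed


# Both spellings give the same initial state in this sketch.
from_int = as_prng_key(42)
from_key = as_prng_key(jax.random.PRNGKey(42))
same = bool(jnp.array_equal(from_int, from_key))
```

Under the documented contract, PRNGSequence(42) and PRNGSequence(jax.random.PRNGKey(42)) should therefore start from the same state and produce the same keys.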

Example

from jaxvacua.util import PRNGSequence

rns = PRNGSequence(42)
key = next(rns)  # a fresh subkey; rns keeps the other half as its new state

Methods

__init__([seed])