jaxvacua.util

General-purpose utility functions used across JAXVacua.

Purpose

Collect package-wide helpers that are intentionally independent of the physics class hierarchy, including random-number handling, batching, serialisation, diagnostics and integer-lattice routines.

Main public API

  • Random helpers: PRNGSequence, random_uniform, random_integer and related sampling utilities.

  • JAX helpers: vectorisation, caching, flatten/unflatten adapters and small wrappers for compiled workflows.

  • IO and diagnostics: pickle helpers, timing/timeout utilities and progress helpers used by search scripts.

  • Integer and lattice helpers such as extended_euclidean and orthogonal_lattice.

Design notes

The file is organised into thematic sections separated by banners. Functions kept here should remain generic enough to be reused without importing the large geometry or flux-EFT modules.

PRNG / random sampling

PRNGSequence([seed])

Splittable JAX PRNG key generator. Adapted from CYJax.
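Here is a minimal sketch of a splittable key generator of the kind described above. It assumes the generator is an iterator that splits its internal key on every draw. The class body is illustrative only and is not the JAXVacua or CYJax source.

```python
import jax


class PRNGSequence:
    """Illustrative splittable key iterator (sketch, not the library source)."""

    def __init__(self, seed=0):
        # Accept either an integer seed or an existing PRNG key array.
        self._key = jax.random.PRNGKey(seed) if isinstance(seed, int) else seed

    def __iter__(self):
        return self

    def __next__(self):
        # Split once: keep one half as the new internal state, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey
```

Each call to next() returns a fresh subkey. The same seed always reproduces the same sequence, so a search script can be rerun deterministically.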

random_uniform(lower_bound, upper_bound[, ...])

Sample uniformly distributed real numbers on [lower_bound, upper_bound).

random_integer(lower_bound, upper_bound[, ...])

Sample uniformly distributed integers on [lower_bound, upper_bound], inclusive at both ends. This is equivalent to calling jax.random.randint with maxval=upper_bound + 1, since randint treats maxval as exclusive.
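Both sampling conventions can be sketched directly on top of jax.random. The keyword names shape and rns_key are assumptions that stand in for the elided optional arguments. The key detail is the +1 shift, which turns randint's exclusive maxval into an inclusive upper_bound.

```python
import jax


def random_uniform(lower_bound, upper_bound, *, shape=(), rns_key):
    # Real numbers on the half-open interval [lower_bound, upper_bound).
    return jax.random.uniform(rns_key, shape, minval=lower_bound, maxval=upper_bound)


def random_integer(lower_bound, upper_bound, *, shape=(), rns_key):
    # randint's maxval is exclusive, so shift by one to include upper_bound.
    return jax.random.randint(rns_key, shape, lower_bound, upper_bound + 1)
```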

random_uniform_jit(rns_key, lower_bound, ...)

JIT-compiled version of random_uniform(). rns_key is a JAX PRNG-key array; shape is treated as a static argument so the output shape is fixed at trace time.

random_integer_jit(rns_key, lower_bound, ...)

JIT-compiled version of random_integer().
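The sketch below shows one way to build the JIT variants. It assumes the signature (rns_key, lower_bound, upper_bound, shape), with shape passed as a hashable tuple. Because shape is marked static via static_argnames, each distinct shape triggers one compilation. Later calls with that shape reuse the compiled kernel.

```python
from functools import partial

import jax


@partial(jax.jit, static_argnames=("shape",))
def random_uniform_jit(rns_key, lower_bound, upper_bound, shape=()):
    # shape is a compile-time constant; the bounds are traced values.
    return jax.random.uniform(rns_key, shape, minval=lower_bound, maxval=upper_bound)


@partial(jax.jit, static_argnames=("shape",))
def random_integer_jit(rns_key, lower_bound, upper_bound, shape=()):
    # Inclusive upper bound, as in the non-JIT version.
    return jax.random.randint(rns_key, shape, lower_bound, upper_bound + 1)
```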

JIT / vmap helpers

vmapping_func(func[, in_axes])

Build a JIT-compiled, vmapped wrapper around func with optional keyword arguments frozen inside the closure.

vmapping_func_cached(func[, in_axes])

Cached variant of vmapping_func(). Returns a JIT-compiled vmapped function, reusing the previously compiled XLA kernel whenever (func, in_axes, kwargs) match a prior call.
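One way to get the frozen-keyword, vmapped and cached behaviour is a module-level dict keyed on (func, in_axes, kwargs). The cache key must be hashable, so in_axes should be an int or tuple (not a list) and keyword values must be hashable. The names _CACHE and scaled_norm are hypothetical. Later calls skip recompilation for previously seen input shapes because the cache returns the identical jax.jit object.

```python
from functools import partial

import jax
import jax.numpy as jnp


def vmapping_func(func, in_axes=0, **kwargs):
    # Freeze keyword arguments in a closure, then vectorise and compile.
    return jax.jit(jax.vmap(partial(func, **kwargs), in_axes=in_axes))


_CACHE = {}  # hypothetical module-level cache


def vmapping_func_cached(func, in_axes=0, **kwargs):
    # Sorted kwargs items make the key independent of keyword order.
    key = (func, in_axes, tuple(sorted(kwargs.items())))
    if key not in _CACHE:
        _CACHE[key] = vmapping_func(func, in_axes, **kwargs)
    return _CACHE[key]


def scaled_norm(x, scale=1.0):
    return scale * jnp.sqrt(jnp.sum(x ** 2))
```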

jit_with_static_args(func[, static_argnums])

Wrap func with jax.jit, treating positional arguments at the indices in static_argnums as compile-time constants.

jit_with_dynamic_static_args(func)

Build a wrapper that re-JITs func on each call, using is_static() to decide dynamically which positional arguments are static. Convenient for prototyping; in production code prefer jit_with_static_args(), whose single compiled wrapper lets repeated calls hit the trace cache.

is_static(arg)

Heuristic test for whether arg should be treated as a JAX static argument: True iff arg is not a jnp.ndarray.
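The static-argument heuristic and the two wrappers can be sketched as follows. The sketch assumes is_static checks against jnp.ndarray exactly as described. Under this heuristic a NumPy array counts as static, and jax.jit requires static arguments to be hashable. Passing a NumPy array through jit_with_dynamic_static_args would therefore fail, so convert such inputs with jnp.asarray first.

```python
import jax
import jax.numpy as jnp


def is_static(arg):
    # Anything that is not a JAX array is treated as a compile-time constant.
    return not isinstance(arg, jnp.ndarray)


def jit_with_static_args(func, static_argnums=()):
    # One jitted wrapper, reused across calls.
    return jax.jit(func, static_argnums=static_argnums)


def jit_with_dynamic_static_args(func):
    def wrapper(*args):
        # Decide per call which positions are static, then jit and invoke.
        static = tuple(i for i, a in enumerate(args) if is_static(a))
        return jax.jit(func, static_argnums=static)(*args)
    return wrapper


def power(x, n):
    return x ** n
```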

Array / numerical helpers

subsets(iterable, n[, as_list])

All size-n subsets of an iterable. Thin wrapper around itertools.combinations with an optional list-eager flag.
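Assuming the optional flag simply materialises the iterator, subsets reduces to:

```python
from itertools import combinations


def subsets(iterable, n, as_list=False):
    # combinations yields n-tuples in lexicographic order of input positions.
    combos = combinations(iterable, n)
    return list(combos) if as_list else combos
```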

flatten(arr[, as_gen, as_np_arr])

Recursively flatten an arbitrarily deeply nested iterable.
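Here is a minimal recursive sketch. It assumes strings and bytes are atoms rather than sequences to be split into characters. The as_np_arr option is omitted.

```python
from collections.abc import Iterable


def flatten(arr, as_gen=False):
    def _gen(x):
        for item in x:
            # Descend into containers, but keep strings/bytes whole.
            if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
                yield from _gen(item)
            else:
                yield item

    gen = _gen(arr)
    return gen if as_gen else list(gen)
```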

flatten_top(arr[, as_list, N])

Flatten the top N levels (axis 0) of a nested iterable.
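The description leaves the exact semantics open. The sketch below assumes flatten_top removes N levels of nesting, starting from the outermost, and that as_list=False returns a lazy iterator. Both this reading and the defaults are assumptions.

```python
from itertools import chain


def flatten_top(arr, as_list=True, N=1):
    out = arr
    for _ in range(N):
        # Each pass concatenates the children of the current top level.
        out = chain.from_iterable(out)
    return list(out) if as_list else out
```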

check_nan(x)

True iff any element of x is NaN.

compute_evs_hermitian(x)

Eigenvalues of a Hermitian / symmetric matrix.

rank_matrix(x[, tolerance])

Matrix rank under a numerical tolerance.
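The three numerical checks map onto standard linear-algebra routines. This sketch uses NumPy; jax.numpy provides isnan, linalg.eigvalsh and linalg.svd under the same names. It assumes tolerance is an absolute threshold on the singular values, with a hypothetical default of 1e-10.

```python
import numpy as np


def check_nan(x):
    return bool(np.isnan(x).any())


def compute_evs_hermitian(x):
    # eigvalsh reads one triangle and returns real eigenvalues in ascending order.
    return np.linalg.eigvalsh(x)


def rank_matrix(x, tolerance=1e-10):
    # Count singular values above the (assumed absolute) tolerance.
    s = np.linalg.svd(x, compute_uv=False)
    return int(np.sum(s > tolerance))
```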

Pickle I/O

load_pickle(filen)

Load and return the contents of a (plain, uncompressed) pickle file.

load_zipped_pickle(filen)

Load and return the contents of a gzip-compressed pickle file.

save_zipped_pickle(obj, filen[, protocol])

Pickle obj and write it gzip-compressed to filen.
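Here is a stdlib-only sketch of the three pickle helpers. It assumes a default of pickle.HIGHEST_PROTOCOL for protocol, since the actual default is not stated above.

```python
import gzip
import pickle


def load_pickle(filen):
    with open(filen, "rb") as f:
        return pickle.load(f)


def load_zipped_pickle(filen):
    # gzip.open transparently decompresses the stream for pickle.
    with gzip.open(filen, "rb") as f:
        return pickle.load(f)


def save_zipped_pickle(obj, filen, protocol=pickle.HIGHEST_PROTOCOL):
    with gzip.open(filen, "wb") as f:
        pickle.dump(obj, f, protocol=protocol)
```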

Dict / DataFrame helpers

mergeDictionary(dict_1, dict_2)

Merge two dictionaries. For keys present in both, values are concatenated along axis 0 via np.append.
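The sketch below shows the merge rule. It assumes values are NumPy arrays with at least one dimension, because np.append with axis=0 rejects zero-dimensional inputs. It also assumes keys found in only one dictionary are copied through unchanged.

```python
import numpy as np


def mergeDictionary(dict_1, dict_2):
    # Start from the union of keys; shared keys are overwritten below.
    merged = {**dict_1, **dict_2}
    for key in dict_1.keys() & dict_2.keys():
        merged[key] = np.append(dict_1[key], dict_2[key], axis=0)
    return merged
```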

is_outlier(data[, column, percentile_cut])

Boolean outlier mask for data based on a symmetric percentile cut.
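The sketch below assumes percentile_cut is given in percent. Values below the percentile_cut-th percentile or above the (100 - percentile_cut)-th percentile are flagged. When column is given, it selects data[column], so a pandas DataFrame or a dict of arrays both work. The default of 1.0 is hypothetical.

```python
import numpy as np


def is_outlier(data, column=None, percentile_cut=1.0):
    values = np.asarray(data if column is None else data[column], dtype=float)
    # Symmetric cut: same percentage trimmed from each tail.
    lo, hi = np.percentile(values, [percentile_cut, 100.0 - percentile_cut])
    return (values < lo) | (values > hi)
```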

Timeout / progress

progress_bar_jax(arg, transforms)

Progress printer invoked from JAX-transformed code through a host callback.

quit_function(fn_name)

Hard-interrupt the main thread. Used by exit_after() as the timer callback when the wrapped function exceeds its budget.

exit_after(s)

Decorator factory: wrap a function to abort if its execution exceeds s seconds. Internally arms a threading.Timer that calls quit_function() to interrupt the main thread on expiry.
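The timer-plus-interrupt pattern described above can be sketched with threading.Timer and _thread.interrupt_main, which raises KeyboardInterrupt in the main thread. This is a sketch under that assumption, not the JAXVacua source. The message text is illustrative.

```python
import _thread
import functools
import sys
import threading


def quit_function(fn_name):
    print(f"{fn_name} took too long", file=sys.stderr)
    sys.stderr.flush()
    _thread.interrupt_main()  # raises KeyboardInterrupt in the main thread


def exit_after(s):
    def outer(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            timer = threading.Timer(s, quit_function, args=[fn.__name__])
            timer.start()
            try:
                return fn(*args, **kwargs)
            finally:
                timer.cancel()  # disarm the timer if fn finished in time
        return inner
    return outer
```

The interrupt is delivered to the main thread, so the decorator only aborts functions running there. A call blocked inside long-running C code is interrupted only once control returns to the interpreter.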

Model-data I/O

save_model_data(data, fname, model_ID, h12)

Write data to files_dir/h12_<h12>/<fname> as a gzip-compressed pickle, creating intermediate directories as needed. If the file already exists, the user is prompted before it is overwritten.
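The sketch below covers the write path. It assumes files_dir is passed in, since the description does not say where it comes from. It also adds a hypothetical confirm callback so scripts can replace the interactive overwrite prompt. model_ID is kept for signature parity but is unused, because the stated path does not include it.

```python
import gzip
import os
import pickle


def save_model_data(data, fname, model_ID, h12, files_dir="data", confirm=None):
    # model_ID is accepted but unused in this sketch.
    target_dir = os.path.join(files_dir, f"h12_{h12}")
    os.makedirs(target_dir, exist_ok=True)
    path = os.path.join(target_dir, fname)
    if os.path.exists(path):
        # Default to an interactive prompt; confirm is a hypothetical override.
        ask = confirm or (
            lambda p: input(f"{p} exists. Overwrite? [y/N] ").strip().lower() == "y"
        )
        if not ask(path):
            return None
    with gzip.open(path, "wb") as f:
        pickle.dump(data, f)
    return path
```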

Pytree flatten / unflatten

Generic flatten / unflatten functions used by register_pytree_node for the project’s pytree-registered classes (periods, css, FluxEFT, Conifold).

flatten_func(obj)

Flatten obj for the JAX pytree protocol.

unflatten_func_class(aux_data, children, myclass)

Inverse of flatten_func() for a specific class myclass.
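The sketch below shows the flatten/unflatten pair. It assumes array-valued attributes become pytree children and all other attributes go into aux_data, which must then be hashable. Binding myclass with functools.partial produces the two-argument unflatten callable that jax.tree_util.register_pytree_node expects. Toy is a hypothetical class used only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


def flatten_func(obj):
    # Hypothetical split: JAX arrays are leaves, everything else is static.
    items = sorted(vars(obj).items())
    names = tuple(k for k, v in items if isinstance(v, jax.Array))
    children = tuple(v for _, v in items if isinstance(v, jax.Array))
    static = tuple((k, v) for k, v in items if not isinstance(v, jax.Array))
    return children, (names, static)


def unflatten_func_class(aux_data, children, myclass):
    names, static = aux_data
    obj = object.__new__(myclass)  # rebuild without calling __init__
    for k, v in static:
        setattr(obj, k, v)
    for k, v in zip(names, children):
        setattr(obj, k, v)
    return obj


class Toy:
    def __init__(self, x, label):
        self.x = x
        self.label = label


jax.tree_util.register_pytree_node(
    Toy, flatten_func, partial(unflatten_func_class, myclass=Toy)
)
```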

Number-theoretic / lattice helpers

extended_euclidean(w)

Computes the Bézout coefficients and a unimodular integer basis transformation for an integer array w.
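Here is a pure-Python sketch. It assumes the convention U @ w = (g, 0, ..., 0), with g = gcd(w) and U unimodular; the actual orientation and return format are not stated above. It runs Euclid's algorithm between the first entry and each other entry, applying every row operation to an identity matrix as well. The first row of U therefore holds the Bézout coefficients.

```python
def extended_euclidean(w):
    """Return (g, U): g = gcd(w) >= 0, U unimodular with U @ w = [g, 0, ..., 0]."""
    n = len(w)
    v = [int(a) for a in w]  # invariant: v == U @ w
    U = [[int(i == j) for j in range(n)] for i in range(n)]
    for i in range(1, n):
        # Euclid on (v[0], v[i]), mirroring every row operation in U.
        while v[i] != 0:
            q = v[0] // v[i]
            v[0] -= q * v[i]
            U[0] = [a - q * b for a, b in zip(U[0], U[i])]
            v[0], v[i] = v[i], v[0]
            U[0], U[i] = U[i], U[0]
    if v[0] < 0:  # normalise the sign of the gcd
        v[0] = -v[0]
        U[0] = [-a for a in U[0]]
    return v[0], U
```

U is unimodular and its last n - 1 rows are orthogonal to w. When w is nonzero, those rows therefore form a basis of the integer vectors orthogonal to w.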

orthogonal_lattice(gens_in)

Returns generators of the integer lattice orthogonal to the lattice spanned by gens_in.
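The sketch below handles the general case. It assumes gens_in is an integer matrix whose rows are the generators, and that the result is returned as rows. It column-reduces gens_in with integer Euclid steps while recording the column operations in a unimodular matrix U. The columns of U that end up mapped to zero span the integer kernel, which is exactly the orthogonal lattice.

```python
def orthogonal_lattice(gens_in):
    """Return a basis (as rows) of {x in Z^n : g . x = 0 for every row g of gens_in}."""
    G = [[int(a) for a in row] for row in gens_in]  # k x n generator matrix
    n = len(G[0])
    U = [[int(i == j) for j in range(n)] for i in range(n)]  # G_current = G_orig @ U

    def col_sub(dst, src, q):  # column dst -= q * column src, in G and U
        for M in (G, U):
            for row in M:
                row[dst] -= q * row[src]

    def col_swap(a, b):  # swap columns a and b, in G and U
        for M in (G, U):
            for row in M:
                row[a], row[b] = row[b], row[a]

    piv = 0
    for r in range(len(G)):
        if piv == n:
            break
        # Zero row r to the right of the pivot column with column-wise Euclid.
        for j in range(piv + 1, n):
            while G[r][j] != 0:
                col_sub(piv, j, G[r][piv] // G[r][j])
                col_swap(piv, j)
        if G[r][piv] != 0:
            piv += 1
    # Columns piv..n-1 of U are mapped to zero and span the integer kernel.
    return [[U[i][c] for i in range(n)] for c in range(piv, n)]
```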