jaxvacua.util.flatten_func

flatten_func(obj)

Flatten obj for the JAX pytree protocol.

Splits obj.__dict__ into:
  • children: the attribute values that JAX should trace (typically JAX arrays or nested pytrees);

  • aux_data: the names of those children, plus a flat list of (key, value) pairs for the static (non-traced) attributes.

Each key in obj.__dict__ is classified as follows: a key listed in _PYTREE_IGNORE is dropped entirely (cache or scratch state that is not part of the pytree); otherwise, the value is static if and only if its Python type is str or bool, or its key appears in _STATIC_KEYS; every remaining value is a traced child.
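The classification rule above can be sketched in plain Python as follows. This is a hypothetical re-implementation for illustration, not the jaxvacua source: flatten_sketch and Params are made-up names, the contents of _PYTREE_IGNORE and _STATIC_KEYS are invented, and the aux_data layout (child_names, static_items), with static_items as a tuple of (key, value) pairs, is an assumption about what "names plus a flat list of (key, value) pairs" means.

```python
from typing import Any, Tuple

# Hypothetical key sets for illustration only; the real values of
# jaxvacua.util._PYTREE_IGNORE and _STATIC_KEYS are not shown on this page.
_PYTREE_IGNORE = frozenset({"_cache"})
_STATIC_KEYS = frozenset({"n_fields"})


def flatten_sketch(obj: Any) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Hypothetical re-implementation of the documented classification.

    Assumed aux_data layout: (child_names, static_items), where
    static_items is a tuple of (key, value) pairs.
    """
    child_names, children, static_items = [], [], []
    for key, value in obj.__dict__.items():
        if key in _PYTREE_IGNORE:
            continue  # cache / scratch state: dropped entirely
        if type(value) in (str, bool) or key in _STATIC_KEYS:
            static_items.append((key, value))  # static: kept in aux_data
        else:
            child_names.append(key)  # traced: returned as a child
            children.append(value)
    return tuple(children), (tuple(child_names), tuple(static_items))


class Params:
    """Hypothetical example object covering each branch of the rule."""

    def __init__(self):
        self.name = "toy"  # str -> static
        self.verbose = False  # bool -> static
        self.n_fields = 2  # key in _STATIC_KEYS -> static
        self.masses = [1.0, 2.0]  # anything else -> traced child
        self._cache = {}  # key in _PYTREE_IGNORE -> dropped


children, aux_data = flatten_sketch(Params())
# children == ([1.0, 2.0],)
# aux_data == (("masses",), (("name", "toy"), ("verbose", False), ("n_fields", 2)))
```

Keeping the static values inside aux_data as tuples keeps aux_data hashable, which matters because jax.jit compares aux_data when deciding whether it can reuse a compiled function.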

Parameters:

obj (Any) – Instance to flatten: an instance of any class registered with jax.tree_util.register_pytree_node().

Returns:

A (children, aux_data) tuple, the standard output of a pytree flatten function.

Return type:

Tuple[Tuple[Any, ...], Tuple[Any, ...]]
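For context, a function with this signature is what jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func) expects as its flatten argument, paired with an unflatten function that rebuilds the object from (aux_data, children). The sketch below is hypothetical: flatten_sketch restates the documented rule with invented key sets, and unflatten_sketch and the Field class are illustrative names that do not come from jaxvacua. How jaxvacua itself pairs flatten_func with an unflatten function is not documented on this page.

```python
from typing import Any, Tuple

import jax
import jax.numpy as jnp

# Hypothetical key sets; the real jaxvacua.util values are not documented here.
_PYTREE_IGNORE = frozenset({"_cache"})
_STATIC_KEYS = frozenset()


def flatten_sketch(obj: Any) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    # Documented rule: ignored keys dropped; str/bool or _STATIC_KEYS -> static.
    names, children, static = [], [], []
    for key, value in obj.__dict__.items():
        if key in _PYTREE_IGNORE:
            continue
        if type(value) in (str, bool) or key in _STATIC_KEYS:
            static.append((key, value))
        else:
            names.append(key)
            children.append(value)
    return tuple(children), (tuple(names), tuple(static))


def unflatten_sketch(cls, aux_data, children):
    # Hypothetical inverse: rebuild the instance without calling __init__.
    names, static = aux_data
    obj = object.__new__(cls)
    obj.__dict__.update(zip(names, children))  # traced values
    obj.__dict__.update(static)  # static (key, value) pairs
    return obj  # ignored keys such as _cache are not restored


class Field:
    def __init__(self, phi, label="phi"):
        self.phi = phi  # array -> traced child
        self.label = label  # str -> static aux_data
        self._cache = None  # in _PYTREE_IGNORE -> dropped


jax.tree_util.register_pytree_node(
    Field,
    flatten_sketch,
    lambda aux, children: unflatten_sketch(Field, aux, children),
)


@jax.jit
def double(f):
    # f.phi is a tracer here; f.label is still a plain Python str.
    return Field(2.0 * f.phi, f.label)


out = double(Field(jnp.arange(3.0)))
# out.phi -> [0., 2., 4.], out.label -> "phi"
```

Because label is a str, it travels in aux_data and never becomes a tracer inside jax.jit, while phi is traced. Since the _cache key is dropped during flattening, this unflatten sketch returns objects without a _cache attribute.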