jaxvacua.util.flatten_func
flatten_func(obj)
Flatten obj for the JAX pytree protocol.

Splits obj.__dict__ into:
- children: values that are JAX arrays or pytrees and should be traced;
- aux_data: the names of those children, plus a flat list of (key, value) pairs for static (non-traced) attributes.
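The split can be sketched in plain Python as follows, applying the classification rule given in this entry. This is a minimal, hypothetical re-implementation, not jaxvacua's code. The contents of `_PYTREE_IGNORE` and `_STATIC_KEYS`, the `Model` class, and the exact `aux_data` layout `(child_names, static_pairs)` are all assumptions. The sketch reads "its Python type is str/bool" as an exact `type()` check.

```python
from typing import Any, Tuple

# Hypothetical contents; the real module defines its own sets.
_PYTREE_IGNORE = {"_cache"}
_STATIC_KEYS = {"n_fields"}


def flatten_sketch(obj: Any) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Split obj.__dict__ into traced children and static aux_data."""
    child_names, children, static_items = [], [], []
    for key, value in obj.__dict__.items():  # insertion order is kept
        if key in _PYTREE_IGNORE:
            continue  # cache / scratch state: not part of the pytree
        if type(value) in (str, bool) or key in _STATIC_KEYS:
            static_items.append((key, value))  # static: stored in aux_data
        else:
            child_names.append(key)
            children.append(value)  # traced child
    # Assumed aux_data layout: (names of children, static (key, value) pairs).
    return tuple(children), (tuple(child_names), tuple(static_items))


class Model:  # hypothetical example class
    def __init__(self):
        self.mass = 1.5          # traced child
        self.label = "vacuum"    # str: static
        self.verbose = True      # bool: static
        self.n_fields = 3        # listed in _STATIC_KEYS: static
        self._cache = {}         # listed in _PYTREE_IGNORE: dropped


children, aux_data = flatten_sketch(Model())
# children == (1.5,)
# aux_data == (("mass",), (("label", "vacuum"), ("verbose", True), ("n_fields", 3)))
```

Keeping only hashable values (names, strings, booleans, small static settings) in `aux_data` matters because JAX compares tree structures, including `aux_data`, when it caches compiled functions.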
Each attribute is classified as follows: a key in _PYTREE_IGNORE is dropped entirely (cache or scratch state, not part of the pytree); otherwise the value is static if and only if its Python type is str or bool, or its key appears in _STATIC_KEYS; every other value becomes a traced child.

Parameters:
obj (Any) – Instance to flatten; may be an instance of any class registered with jax.tree_util.register_pytree_node().
Returns:
tuple – (children, aux_data), the standard pytree flatten output.
Return type:
Tuple[Tuple[Any, ...], Tuple[Any, ...]]
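As a usage illustration, the sketch below registers a class whose flatten follows the rule described above, so that instances can be passed straight into `jax.jit` and `jax.tree_util.tree_map`. This is a hedged example under stated assumptions: `Potential`, `_flatten`, `_unflatten` and the `_PYTREE_IGNORE` contents are hypothetical. This entry documents only the flatten side, so the unflatten counterpart (and rebuilding ignored state as an empty dict) is assumed.

```python
import jax
import jax.numpy as jnp

_PYTREE_IGNORE = {"_cache"}  # hypothetical ignore set


class Potential:  # hypothetical class, not part of jaxvacua
    def __init__(self, coupling, name="quartic", use_tanh=False):
        self.coupling = coupling  # array: traced child
        self.name = name          # str: static
        self.use_tanh = use_tanh  # bool: static, safe in Python control flow
        self._cache = {}          # ignored scratch state

    def __call__(self, phi):
        base = self.coupling * phi**4
        # use_tanh is static, so this Python `if` works under jax.jit.
        return jnp.tanh(base) if self.use_tanh else base


def _flatten(obj):
    names, children, static = [], [], []
    for key, value in obj.__dict__.items():
        if key in _PYTREE_IGNORE:
            continue
        if type(value) in (str, bool):
            static.append((key, value))
        else:
            names.append(key)
            children.append(value)
    return tuple(children), (tuple(names), tuple(static))


def _unflatten(aux_data, children):
    names, static = aux_data
    obj = object.__new__(Potential)  # bypass __init__
    obj.__dict__.update(zip(names, children))
    obj.__dict__.update(static)
    obj._cache = {}  # assumption: ignored state is rebuilt empty
    return obj


jax.tree_util.register_pytree_node(Potential, _flatten, _unflatten)


@jax.jit
def energy(pot, phi):
    return pot(phi)


pot = Potential(jnp.asarray(0.25))
print(energy(pot, 2.0))  # evaluates 0.25 * 2.0**4
```

Only `coupling` becomes a leaf here. `name` and `use_tanh` travel in the tree structure, so changing either one triggers a retrace rather than being traced as data.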