jaxvacua.util.flatten


flatten(arr, as_gen=False, as_np_arr=False)

Recursively flatten an iterable nested to any depth into a single sequence of its elements.

(Modified from Stack Overflow #2158395.)
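The recursive technique behind that answer can be sketched as a depth-first generator. This is a minimal illustration, not jaxvacua's source: the name `flatten_gen` is hypothetical, and treating `str`/`bytes` as leaves is an assumption (without that check, a one-character string iterates to itself and the recursion never terminates).

```python
from collections.abc import Iterable, Iterator
from typing import Any


def flatten_gen(arr: Iterable[Any]) -> Iterator[Any]:
    """Hypothetical sketch: yield the leaves of a nested iterable, depth-first."""
    for item in arr:
        # str/bytes are iterable, but recursing into them would never end,
        # so this sketch treats them as leaves (an assumption).
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten_gen(item)  # descend one nesting level
        else:
            yield item  # a leaf: emit it unchanged


ragged = [1, [2, [3, 4]], (5,), [[[6]]], "ab"]
print(list(flatten_gen(ragged)))  # [1, 2, 3, 4, 5, 6, 'ab']
```

Each nesting level adds a stack frame, so extremely deep inputs (on the order of Python's default recursion limit of 1000) can raise `RecursionError`.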

Parameters:
  • arr (Iterable) – The (possibly ragged, possibly nested) input.

  • as_gen (bool) – Return a lazy generator instead of a materialised list.

  • as_np_arr (bool) – Return a 1-D np.ndarray instead of a list.

Returns:

list | np.ndarray | generator – The flattened elements: a list by default, a generator if as_gen is True, or a 1-D np.ndarray if as_np_arr is True.

Raises:

ValueError – If both as_gen and as_np_arr are True.

Return type:

Union[list, ndarray, Iterable]
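One way the two output flags and the documented ValueError could fit together is sketched below. `flatten_sketch` and `_leaves` are hypothetical names, this is not jaxvacua's implementation, and the `str`/`bytes`-as-leaf rule is an assumption.

```python
from collections.abc import Iterable, Iterator
from typing import Any, Union

import numpy as np


def _leaves(arr: Iterable[Any]) -> Iterator[Any]:
    # Depth-first recursion; str/bytes are treated as leaves (assumption).
    for item in arr:
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from _leaves(item)
        else:
            yield item


def flatten_sketch(
    arr: Iterable[Any], as_gen: bool = False, as_np_arr: bool = False
) -> Union[list, np.ndarray, Iterator[Any]]:
    # The two output modes are mutually exclusive, mirroring the documented ValueError.
    if as_gen and as_np_arr:
        raise ValueError("as_gen and as_np_arr cannot both be True")
    gen = _leaves(arr)
    if as_gen:
        return gen  # lazy: nothing is traversed until the caller iterates
    if as_np_arr:
        return np.array(list(gen))  # 1-D array built from the materialised leaves
    return list(gen)  # default: a plain list


A = np.arange(8).reshape(2, 2, 2)
print(flatten_sketch(A, as_np_arr=True))  # [0 1 2 3 4 5 6 7]
```

Because `flatten_sketch` is an ordinary function that returns a generator, rather than a generator function itself, the ValueError is raised at call time instead of on first iteration. A 0-d NumPy array passes the `Iterable` check but raises `TypeError` when iterated, so this sketch does not handle it.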

Example

import numpy as np
from jaxvacua.util import flatten

A = np.asarray(range(2**3)).reshape(2, 2, 2)
flatten(A)                    # [0, 1, 2, 3, 4, 5, 6, 7]
flatten(A, as_np_arr=True)    # array([0, 1, 2, 3, 4, 5, 6, 7])