jaxvacua.util.flatten_top

flatten_top(arr, as_list=True, N=1)

Flatten the top N levels of a nested iterable, starting from the outermost level (axis 0), so that the first N + 1 levels are merged into one.
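As a rough illustration of what "flattening the top N levels" means, the pure-Python sketch below removes one level of nesting per pass using itertools.chain.from_iterable. The function name flatten_top_levels is hypothetical and is not the jaxvacua implementation; it only mirrors the documented behaviour for list output, including ragged input.

```python
from itertools import chain


def flatten_top_levels(arr, N=1):
    """Hypothetical sketch: merge the first N + 1 nesting levels of `arr` into one."""
    out = list(arr)
    for _ in range(N):
        # Each pass concatenates the children of every top-level item,
        # removing exactly one level of nesting; deeper items are untouched.
        out = list(chain.from_iterable(out))
    return out


ragged = [[[1], [2, 3]], [[4, 5, 6]]]
print(flatten_top_levels(ragged))       # [[1], [2, 3], [4, 5, 6]]
print(flatten_top_levels(ragged, N=2))  # [1, 2, 3, 4, 5, 6]
```

Because each pass only looks one level down, inner sequences of different lengths are carried along unchanged, which is why this approach tolerates ragged input.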

Parameters:
  • arr (Iterable) – The input (can be ragged).

  • as_list (bool) – If True (default), return a list; otherwise return a np.ndarray.

  • N (int) – Number of top levels to flatten. Default 1.

Returns:

list | np.ndarray – arr with its top N levels flattened.

Return type:

Union[list, ndarray]
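When as_list=False and the flattened result is ragged, a plain np.asarray call cannot build a rectangular array. The sketch below shows one way to handle this, assuming a fallback to an object array is acceptable; the helper name as_ndarray is hypothetical, and the documentation does not state how flatten_top itself treats ragged results in this mode.

```python
import numpy as np


def as_ndarray(items: list) -> np.ndarray:
    """Hypothetical helper: turn a flattened list into an ndarray.

    Rectangular input becomes an ordinary numeric array; ragged input
    falls back to an array with dtype=object.
    """
    try:
        return np.asarray(items)
    except ValueError:
        # NumPy >= 1.24 raises ValueError for ragged nested sequences
        # unless dtype=object is requested explicitly.
        return np.asarray(items, dtype=object)


print(as_ndarray([[0, 1], [2, 3]]).shape)          # (2, 2)
print(as_ndarray([[1], [2, 3], [4, 5, 6]]).shape)  # (3,)
```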

Example

import numpy as np
from jaxvacua.util import flatten_top

A = np.arange(2**3).reshape(2, 2, 2)
flatten_top(A.tolist())       # [[0, 1], [2, 3], [4, 5], [6, 7]]
flatten_top(A.tolist(), N=2)  # [0, 1, 2, 3, 4, 5, 6, 7]
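For rectangular data such as the array A above, flattening the top N levels amounts to merging axes 0 through N, which a single reshape can express. The sketch below assumes a regular np.ndarray input; the name flatten_top_rect is hypothetical and it does not handle ragged input, which is the case flatten_top is documented to support.

```python
import numpy as np

A = np.arange(2**3).reshape(2, 2, 2)


def flatten_top_rect(a: np.ndarray, N: int = 1) -> np.ndarray:
    """Hypothetical shortcut for rectangular arrays: merge axes 0..N into one."""
    # Keep the trailing axes a.shape[N + 1:] and let -1 absorb the merged ones.
    return a.reshape(-1, *a.shape[N + 1:])


print(flatten_top_rect(A).tolist())       # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(flatten_top_rect(A, N=2).tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
```

On contiguous arrays reshape typically returns a view rather than a copy, so this shortcut avoids building intermediate Python lists when the data are known to be regular.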