jaxvacua.util.progress_bar_jax
- progress_bar_jax(arg, transforms)
JAX host-callback progress printer.
Designed to be used as the tap function of jax.experimental.host_callback.id_tap (or its modern jax.debug.callback replacement). Prints a single-line iteration/residual tracker using a carriage return, so successive calls overwrite each other on the terminal.
- Parameters:
  - arg (Tuple[int, int, float]) – (i, n_iter, residual): the current iteration index, the total iteration count, and a scalar residual to display.
  - transforms (Any) – JAX-supplied transformation metadata (unused here; required by the host-callback tap-function signature).
- Returns:
  int – i, the current iteration index, unchanged.
- Return type:
  int
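A minimal sketch of how a printer with this contract could be written and driven from inside a jitted loop. It assumes the output format, and the names progress_bar_sketch, run_toy_loop, report and the toy halving update are hypothetical illustrations, not jaxvacua's actual code. It uses jax.debug.callback rather than id_tap, since jax.experimental.host_callback is deprecated in recent JAX releases. Note that jax.debug.callback calls the callback with only the arguments you pass, so no transforms value is supplied and the wrapper passes a placeholder.

```python
import sys
from typing import Any, Tuple

import jax
import jax.numpy as jnp


def progress_bar_sketch(arg: Tuple[int, int, float], transforms: Any = None) -> int:
    """Hypothetical stand-in following the documented contract of
    progress_bar_jax; not jaxvacua's actual implementation."""
    i, n_iter, residual = arg
    # Host callbacks receive NumPy scalars / 0-d arrays, so convert before formatting.
    i, n_iter, residual = int(i), int(n_iter), float(residual)
    # "\r" moves the cursor back to the start of the line, so the next call
    # overwrites this one instead of starting a new line.
    sys.stdout.write(f"\riter {i}/{n_iter}  residual = {residual:.3e}")
    sys.stdout.flush()
    return i  # returned unchanged, as documented


def run_toy_loop(n_iter: int = 5) -> jax.Array:
    """Hypothetical toy iteration that reports progress from inside jit."""

    def report(i, residual) -> None:
        # jax.debug.callback does not pass `transforms`, so supply a placeholder.
        # This wrapper returns None; the printer's int result is discarded.
        progress_bar_sketch((i, n_iter, residual), None)

    def body(i, x):
        x = 0.5 * x  # toy update: the residual halves every step
        jax.debug.callback(report, i, jnp.abs(x))
        return x

    @jax.jit
    def solve(x0):
        return jax.lax.fori_loop(0, n_iter, body, x0)

    x_final = solve(jnp.float32(1.0))
    jax.effects_barrier()  # wait until all pending callbacks have printed
    sys.stdout.write("\n")  # move past the progress line once the loop is done
    return x_final
```

Here n_iter is a plain Python int captured by the closure, so only the traced values (the loop index and the residual) cross the host-callback boundary.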