jaxvacua.util.progress_bar_jax

progress_bar_jax(arg, transforms)

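Progress printer for use as a JAX host callback.

Designed to be passed as the tap function of jax.experimental.host_callback.id_tap, which calls it as tap_func(arg, transforms). The modern replacement, jax.debug.callback, calls its callback with only the arguments it was given (no transforms), so using it there needs a small adapter. Each call prints a single-line iteration / residual tracker that begins with a carriage return, so successive calls overwrite each other on the terminal instead of scrolling.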
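The carriage-return technique can be sketched in plain Python. This is a hypothetical stand-in named progress_bar_printer, not the jaxvacua source: the exact message format, the 1-based display of the index, and the trailing newline after the last iteration are all assumptions made for illustration.

```python
from typing import Any, Tuple


def progress_bar_printer(arg: Tuple[int, int, float], transforms: Any = None) -> int:
    """Hypothetical stand-in for progress_bar_jax; not the jaxvacua source."""
    i, n_iter, residual = arg
    # Host callbacks may deliver 0-d arrays; convert them to plain Python scalars.
    i, n_iter, residual = int(i), int(n_iter), float(residual)
    # The leading "\r" returns the cursor to column 0 so this line overwrites the
    # previous one; end="" suppresses the newline and flush=True shows it immediately.
    print(f"\rIteration {i + 1}/{n_iter}  residual = {residual:.3e}", end="", flush=True)
    if i + 1 == n_iter:
        # Assumed behaviour: after the final iteration, end the line cleanly.
        print()
    # transforms is ignored; the index is returned unchanged.
    return i
```

Because "\r" only moves the cursor, a shorter line does not erase the tail of a longer previous one; the fixed-width `.3e` format keeps successive lines close to the same length.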

Parameters:
  • arg (Tuple[int, int, float]) – The tuple (i, n_iter, residual): the current iteration index, the total number of iterations, and a scalar residual to display.

  • transforms (Any) – Transformation metadata supplied by JAX's host-callback machinery (unused here; required by the tap-function signature).

Returns:

i – the current iteration index, unchanged (id_tap discards the tap function's return value).

Return type:

int
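Since jax.experimental.host_callback is deprecated in favour of jax.debug.callback, the sketch below shows one way to drive a printer with the documented (arg, transforms) signature from inside a jitted loop. The adapter packs the traced values into the expected tuple and passes transforms=None. The names progress_bar_printer, halve_until_done and run, and the toy iteration x ← x / 2, are hypothetical and stand in for a real solver.

```python
import jax
import jax.numpy as jnp


def progress_bar_printer(arg, transforms=None):
    """Hypothetical stand-in with the documented (arg, transforms) signature."""
    i, n_iter, residual = int(arg[0]), int(arg[1]), float(arg[2])
    print(f"\rIteration {i + 1}/{n_iter}  residual = {residual:.3e}", end="", flush=True)
    return i


def halve_until_done(x0, n_iter):
    """Toy iteration x <- x / 2 that reports its residual from inside jit."""

    def body(i, x):
        x = 0.5 * x
        residual = jnp.linalg.norm(x)

        def report(i_, r_):
            # Adapter: jax.debug.callback supplies no transforms argument, so
            # pack the values into (arg, transforms) form; returns None.
            progress_bar_printer((i_, n_iter, r_), None)

        jax.debug.callback(report, i, residual)
        return x

    return jax.lax.fori_loop(0, n_iter, body, x0)


# n_iter is static: it sets the loop bounds and is closed over by the adapter.
run = jax.jit(halve_until_done, static_argnums=1)
```

Calling run(jnp.ones(3), 5) returns an array equal to 0.5**5 in every entry while updating one status line. jax.debug.callback is unordered by default; passing ordered=True requests that the callbacks run in program order.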