When I run the notebook as-is, I get AttributeError: 'tuple' object has no attribute 'dtype' in the "Training Resnets with VeLO" section. The error is raised at line 114 of the second cell under "Training Resnets with VeLO", i.e.:

    state = solver.init_state(params, L2REG, next(test_ds), batch_stats)

The "init_state" method of OptaxSolver returns an OptaxState, constructed as:

    OptaxState(iter_num=jnp.asarray(0),
               value=jnp.asarray(jnp.inf, value.dtype),
               error=jnp.asarray(jnp.inf, dtype=params_dtype),
               aux=aux,
               internal_state=opt_state)

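The error message suggests that "value" is a tuple at the point where "value.dtype" is read. Below is a minimal sketch of how that can happen, assuming the solver unpacks the objective's return value depending on a has_aux-style flag (jaxopt's OptaxSolver does accept a has_aux argument, but init_value_dtype here is a hypothetical simplification, not its actual source). The loss function loss_with_aux is also hypothetical: it returns the loss together with batch statistics, as a BatchNorm ResNet objective might. If such an objective is treated as returning only the loss, the whole (loss, batch_stats) tuple lands in "value".

```python
import numpy as np


def init_value_dtype(fun, params, has_aux=False):
    """Hypothetical, simplified mimic of the quoted init_state logic.

    NOT jaxopt's source: it only shows how the objective's return value
    is unpacked before .dtype is read from it.
    """
    if has_aux:
        value, _ = fun(params)  # objective treated as returning (loss, aux)
    else:
        value = fun(params)     # objective treated as returning only the loss
    return value.dtype          # AttributeError if `value` is still a tuple


def loss_with_aux(params):
    """Hypothetical objective returning (loss, batch_stats)."""
    loss = np.float32(np.sum(params ** 2))
    batch_stats = {"mean": np.zeros(3, np.float32)}
    return loss, batch_stats


params = np.ones(3, np.float32)

try:
    init_value_dtype(loss_with_aux, params, has_aux=False)
except AttributeError as err:
    print(err)  # 'tuple' object has no attribute 'dtype'

print(init_value_dtype(loss_with_aux, params, has_aux=True))  # float32
```

If the notebook's loss function does return batch statistics alongside the loss, it may be worth checking whether the solver was constructed with has_aux=True.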
I can't tell whether "value" is a tuple or a scalar here, or why this error keeps occurring. Any help in solving it would be appreciated.