diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index aa6ade25..25722dec 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -807,7 +807,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
     gflownet = gflownet.to(device)
 
     n_iterations = ceil(args.n_trajectories / args.batch_size)
-    per_node_batch_size = args.batch_size // distributed_context.world_size
+    per_node_batch_size = args.batch_size // distributed_context.num_training_ranks
     modes_found = set()
     # n_pixels_per_mode = round(env.height / 10) ** env.ndim
     # Note: on/off-policy depends on the current strategy; recomputed inside the loop.
@@ -828,14 +828,6 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
         )
         prof.start()
 
-    if args.distributed:
-        # Create and start error handler.
-        def cleanup():
-            logger.info("Process %d: Cleaning up...", rank)
-
-        rank = torch.distributed.get_rank()
-        torch.distributed.get_world_size()
-
     # Initialize some variables before the training loop.
     timing = {}
     time_start = time.time()
@@ -897,7 +889,7 @@ def cleanup():
         )
         trajectories = gflownet.sample_trajectories(
            env,
-            n=args.batch_size,
+            n=per_node_batch_size,
            save_logprobs=is_on_policy_iter,  # Reuse on-policy log-probs.
            save_estimator_outputs=not is_on_policy_iter,  # Off-policy caches estimator outputs.
            epsilon=float(getattr(args, "agent_epsilon", 0.0)),
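
Note (not part of the patch): a minimal sketch of the batch-splitting intent behind this change, assuming distributed_context.num_training_ranks counts only the ranks that actually run the sampling/training loop, whereas world_size may also include non-training ranks. The helper name below is hypothetical and for illustration only.

    # Hypothetical sketch of the per-rank batch split; names are illustrative,
    # not part of the repository's API.
    def per_rank_batch_size(global_batch_size: int, num_training_ranks: int) -> int:
        """Each training rank samples an equal share of the global batch."""
        assert num_training_ranks > 0, "need at least one training rank"
        # Integer division mirrors the patch: any remainder is dropped, so the
        # effective global batch becomes per_rank * num_training_ranks.
        return global_batch_size // num_training_ranks

    # Example: 1024 trajectories split across 4 training ranks -> 256 sampled per rank,
    # which is the value passed as n= to gflownet.sample_trajectories on each rank.
    assert per_rank_batch_size(1024, 4) == 256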