From 7cbea563b5a7e3b838edef79cc8b6af608735b38 Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Wed, 4 Feb 2026 12:56:16 +0100
Subject: [PATCH 1/2] fix minibatch size

---
 tutorials/examples/train_hypergrid.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index aa6ade25..0818c013 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -807,7 +807,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
         gflownet = gflownet.to(device)
 
     n_iterations = ceil(args.n_trajectories / args.batch_size)
-    per_node_batch_size = args.batch_size // distributed_context.world_size
+    per_node_batch_size = args.batch_size // distributed_context.num_training_ranks
     modes_found = set()
     # n_pixels_per_mode = round(env.height / 10) ** env.ndim
     # Note: on/off-policy depends on the current strategy; recomputed inside the loop.
@@ -828,14 +828,6 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
         )
         prof.start()
 
-    if args.distributed:
-        # Create and start error handler.
-        def cleanup():
-            logger.info("Process %d: Cleaning up...", rank)
-
-        rank = torch.distributed.get_rank()
-        torch.distributed.get_world_size()
-
     # Initialize some variables before the training loop.
     timing = {}
     time_start = time.time()

From 45e2166493235b6b324fabb48409453c9c4a946c Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Thu, 5 Feb 2026 16:51:39 +0100
Subject: [PATCH 2/2] second fix

---
 tutorials/examples/train_hypergrid.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index 0818c013..25722dec 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -889,7 +889,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
         )
         trajectories = gflownet.sample_trajectories(
             env,
-            n=args.batch_size,
+            n=per_node_batch_size,
             save_logprobs=is_on_policy_iter,  # Reuse on-policy log-probs.
             save_estimator_outputs=not is_on_policy_iter,  # Off-policy caches estimator outputs.
             epsilon=float(getattr(args, "agent_epsilon", 0.0)),
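
Taken together, the two commits make each rank sample its per-rank share of the global batch rather than the full args.batch_size. Below is a minimal sketch of that arithmetic, kept separate from the patch itself: num_training_ranks is the attribute the patch switches to, while the per_rank_batch_size helper and the example numbers are assumptions for illustration only.

from math import ceil


def per_rank_batch_size(global_batch_size: int, num_training_ranks: int) -> int:
    # Each training rank samples an equal share of the global batch.
    # Assumption: global_batch_size is divisible by num_training_ranks;
    # otherwise the remainder is dropped, as integer division does in the patch.
    return global_batch_size // num_training_ranks


# Illustrative numbers only: a global batch of 1024 split over 4 training ranks.
local_batch = per_rank_batch_size(1024, 4)  # 256 trajectories sampled per rank
n_iterations = ceil(100_000 / 1024)         # iteration count still uses the global batch size
assert local_batch * 4 == 1024              # the global batch size is preserved across ranks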