From 03be5ab0b900a243df24623adde54467c043c721 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 2 Feb 2026 16:03:05 +0100 Subject: [PATCH 1/2] add validate_environment --- tutorials/examples/test_scripts.py | 1 + tutorials/examples/train_hypergrid.py | 59 +++++++++++++++------------ 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 0345ec07..5933d08b 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -108,6 +108,7 @@ class HypergridArgs(CommonArgs): timing: bool = True half_precision: bool = False remote_buffer_freq = 1 + validate_environment: bool = True @dataclass diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index aa6ade25..862c5e43 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -1010,33 +1010,35 @@ def cleanup(): ) # If we are on the master node, calculate the validation metrics. - with Timer(timing, "validation", enabled=args.timing): - assert visited_terminating_states is not None - all_visited_terminating_states.extend(visited_terminating_states) - to_log = { - "loss": loss.item(), - "sample_time": sample_timer.elapsed, - "to_train_samples_time": to_train_samples_timer.elapsed, - "loss_time": loss_timer.elapsed, - "loss_backward_time": loss_backward_timer.elapsed, - "opt_time": opt_timer.elapsed, - "model_averaging_time": model_averaging_timer.elapsed, - "rest_time": rest_time, - "l1_dist": None, # only logged if calculate_partition. - } - to_log.update(averaging_info) - if score_dict is not None: - to_log.update(score_dict) - - if log_this_iter: - validation_info, all_visited_terminating_states = env.validate( - gflownet, - args.validation_samples, - all_visited_terminating_states, - ) - assert all_visited_terminating_states is not None - to_log.update(validation_info) + assert visited_terminating_states is not None + all_visited_terminating_states.extend(visited_terminating_states) + to_log = { + "loss": loss.item(), + "sample_time": sample_timer.elapsed, + "to_train_samples_time": to_train_samples_timer.elapsed, + "loss_time": loss_timer.elapsed, + "loss_backward_time": loss_backward_timer.elapsed, + "opt_time": opt_timer.elapsed, + "model_averaging_time": model_averaging_timer.elapsed, + "rest_time": rest_time, + "l1_dist": None, # only logged if calculate_partition. + } + to_log.update(averaging_info) + if score_dict is not None: + to_log.update(score_dict) + + if log_this_iter: + if args.validate_environment: + with Timer(timing, "validation", enabled=args.timing): + validation_info, all_visited_terminating_states = env.validate( + gflownet, + args.validation_samples, + all_visited_terminating_states, + ) + assert all_visited_terminating_states is not None + to_log.update(validation_info) + with Timer(timing, "log", enabled=args.timing): if distributed_context.my_rank == 0: if args.distributed: manager_rank = distributed_context.assigned_buffer @@ -1340,6 +1342,11 @@ def cleanup(): ) # Validation settings. + parser.add_argument( + "--validate_environment", + action="store_true", + help="Validate the environment at the end of training", + ) parser.add_argument( "--validation_interval", type=int, From da27d8bbefcf8fd754e608b712ad20197b82f8b9 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 2 Feb 2026 16:07:31 +0100 Subject: [PATCH 2/2] fixes --- tutorials/examples/test_scripts.py | 2 -- tutorials/examples/train_hypergrid.py | 18 ++---------------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 5933d08b..217d8546 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -88,8 +88,6 @@ class DiscreteEBMArgs(CommonArgs): @dataclass class HypergridArgs(CommonArgs): back_ratio: float = 0.5 - store_all_states: bool = True - calculate_partition: bool = True distributed: bool = False diverse_replay_buffer: bool = False epsilon: float = 0.1 diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 862c5e43..ad3ef858 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -683,8 +683,8 @@ def main(args) -> dict: # noqa: C901 "R1": args.R1, "R2": args.R2, }, - calculate_partition=args.calculate_partition, - store_all_states=args.store_all_states, + calculate_partition=args.validate_environment, + store_all_states=args.validate_environment, debug=__debug__, ) @@ -1378,20 +1378,6 @@ def cleanup(): action="store_true", help="Stores wandb results locally, to be uploaded later.", ) - - # Settings relevant to the problem size -- toggle off for larger problems. - parser.add_argument( - "--store_all_states", - action="store_true", - default=False, - help="Whether to store all states.", - ) - parser.add_argument( - "--calculate_partition", - action="store_true", - default=False, - help="Whether to calculate the true partition function.", - ) parser.add_argument( "--profile", action="store_true",