3 changes: 1 addition & 2 deletions in tutorials/examples/test_scripts.py

@@ -88,8 +88,6 @@ class DiscreteEBMArgs(CommonArgs):
 @dataclass
 class HypergridArgs(CommonArgs):
     back_ratio: float = 0.5
-    store_all_states: bool = True
-    calculate_partition: bool = True
     distributed: bool = False
     diverse_replay_buffer: bool = False
     epsilon: float = 0.1
@@ -108,6 +106,7 @@ class HypergridArgs(CommonArgs):
     timing: bool = True
     half_precision: bool = False
     remote_buffer_freq = 1
+    validate_environment: bool = True


 @dataclass
77 changes: 35 additions & 42 deletions in tutorials/examples/train_hypergrid.py

@@ -683,8 +683,8 @@ def main(args) -> dict:  # noqa: C901
             "R1": args.R1,
             "R2": args.R2,
         },
-        calculate_partition=args.calculate_partition,
-        store_all_states=args.store_all_states,
+        calculate_partition=args.validate_environment,
+        store_all_states=args.validate_environment,
         debug=__debug__,
     )

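Both diagnostics now hang off the single validate_environment switch. For intuition: store_all_states is what keeps the visited terminating states around, and calculate_partition is what enables exact metrics such as the l1_dist logged in the next hunk, since knowing the true partition function Z lets the empirical distribution of terminating states be compared against the target R(x)/Z. A minimal sketch of that comparison, with illustrative names and an assumed convention (mean absolute difference), not the repo's actual validate internals:

# Illustrative sketch only: an L1 distance between the empirical
# terminating-state distribution and the true target R(x)/Z, computable
# once Z is available. Names and the exact convention are assumptions.
import torch

def l1_dist(visited: torch.Tensor, rewards: torch.Tensor) -> float:
    # visited: flat indices of visited terminating states, shape (N,).
    # rewards: R(x) for every state x of the (enumerable) grid, shape (S,).
    true_dist = rewards / rewards.sum()  # R(x) / Z
    empirical = torch.bincount(visited, minlength=rewards.numel()).float()
    empirical = empirical / empirical.sum()
    return (true_dist - empirical).abs().mean().item()
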
@@ -1010,33 +1010,35 @@ def cleanup():
         )

         # If we are on the master node, calculate the validation metrics.
-        with Timer(timing, "validation", enabled=args.timing):
-            assert visited_terminating_states is not None
-            all_visited_terminating_states.extend(visited_terminating_states)
-            to_log = {
-                "loss": loss.item(),
-                "sample_time": sample_timer.elapsed,
-                "to_train_samples_time": to_train_samples_timer.elapsed,
-                "loss_time": loss_timer.elapsed,
-                "loss_backward_time": loss_backward_timer.elapsed,
-                "opt_time": opt_timer.elapsed,
-                "model_averaging_time": model_averaging_timer.elapsed,
-                "rest_time": rest_time,
-                "l1_dist": None,  # only logged if calculate_partition.
-            }
-            to_log.update(averaging_info)
-            if score_dict is not None:
-                to_log.update(score_dict)
-
-            if log_this_iter:
-                validation_info, all_visited_terminating_states = env.validate(
-                    gflownet,
-                    args.validation_samples,
-                    all_visited_terminating_states,
-                )
-                assert all_visited_terminating_states is not None
-                to_log.update(validation_info)
+        assert visited_terminating_states is not None
+        all_visited_terminating_states.extend(visited_terminating_states)
+        to_log = {
+            "loss": loss.item(),
+            "sample_time": sample_timer.elapsed,
+            "to_train_samples_time": to_train_samples_timer.elapsed,
+            "loss_time": loss_timer.elapsed,
+            "loss_backward_time": loss_backward_timer.elapsed,
+            "opt_time": opt_timer.elapsed,
+            "model_averaging_time": model_averaging_timer.elapsed,
+            "rest_time": rest_time,
+            "l1_dist": None,  # only logged if calculate_partition.
+        }
+        to_log.update(averaging_info)
+        if score_dict is not None:
+            to_log.update(score_dict)
+
+        if log_this_iter:
+            if args.validate_environment:
+                with Timer(timing, "validation", enabled=args.timing):
+                    validation_info, all_visited_terminating_states = env.validate(
+                        gflownet,
+                        args.validation_samples,
+                        all_visited_terminating_states,
+                    )
+                assert all_visited_terminating_states is not None
+                to_log.update(validation_info)

         with Timer(timing, "log", enabled=args.timing):
             if distributed_context.my_rank == 0:
                 if args.distributed:
                     manager_rank = distributed_context.assigned_buffer

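The point of the reshuffle above: the "validation" Timer previously wrapped the entire logging-preparation block, so assembling to_log was billed as validation time; it now wraps only the env.validate call, and only when --validate_environment is set. A minimal sketch of a context-manager timer with the same call shape as Timer(timing, name, enabled=...) used here; the repo's actual Timer may record results differently:

# Sketch of a context-manager timer matching the Timer(timing, name,
# enabled=...) call shape above. Assumes `timing` is a dict of results;
# the repo's implementation may differ.
import time

class Timer:
    def __init__(self, timing: dict, name: str, enabled: bool = True):
        self.timing = timing
        self.name = name
        self.enabled = enabled
        self.elapsed = None

    def __enter__(self):
        if self.enabled:
            self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.enabled:
            self.elapsed = time.perf_counter() - self._start
            self.timing[self.name] = self.elapsed
        return False  # never swallow exceptions
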
@@ -1340,6 +1342,11 @@ def cleanup():
     )

     # Validation settings.
+    parser.add_argument(
+        "--validate_environment",
+        action="store_true",
+        help="Whether to validate the environment (stores all states and calculates the true partition function).",
+    )
     parser.add_argument(
         "--validation_interval",
         type=int,
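
Like the two flags it replaces, the new switch uses action="store_true", so it defaults to False when omitted (the explicit default=False on the removed flags below was redundant). A self-contained sketch of that behavior, separate from the script's real parser:

# Self-contained sketch of the flag's default behavior; this is not the
# script's actual parser, just the same argparse pattern.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validate_environment", action="store_true")

assert parser.parse_args([]).validate_environment is False
assert parser.parse_args(["--validate_environment"]).validate_environment is True

Note the asymmetry with the test harness: the HypergridArgs dataclass in test_scripts.py defaults validate_environment to True, so the test suite exercises validation even though the CLI defaults to off.
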
@@ -1371,20 +1378,6 @@
         action="store_true",
         help="Stores wandb results locally, to be uploaded later.",
     )
-
-    # Settings relevant to the problem size -- toggle off for larger problems.
-    parser.add_argument(
-        "--store_all_states",
-        action="store_true",
-        default=False,
-        help="Whether to store all states.",
-    )
-    parser.add_argument(
-        "--calculate_partition",
-        action="store_true",
-        default=False,
-        help="Whether to calculate the true partition function.",
-    )
     parser.add_argument(
         "--profile",
         action="store_true",