Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 93 additions & 59 deletions tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import logging
import os
import random
import time
from argparse import ArgumentParser
from math import ceil
Expand Down Expand Up @@ -314,57 +315,48 @@ def get_exact_P_T(env: HyperGrid, gflownet: GFlowNet) -> torch.Tensor:
return (u * probabilities[..., -1]).detach().cpu()


def _sample_new_strategy(
args,
agent_group_id: int,
iteration: int,
prev_eps: float,
prev_temp: float,
prev_noisy: int,
) -> dict:
"""Select a new exploration strategy, including noisy layers.
def _sample_new_strategy(args, rng: random.Random) -> dict:
"""Sample a new exploration strategy by independently sampling each parameter.

The strategy only defines exploration-time parameters and the count of
noisy layers to use when building/rebuilding the networks.
Each parameter (epsilon, temperature, n_noisy_layers) is sampled from a
normal distribution with mean and std specified in args. Values are clamped
to valid ranges.

We pick deterministically from a small candidate pool, excluding the
previous configuration when possible, to ensure diversity across
restarts without requiring synchronization.
Args:
args: Argument namespace containing mean/std for each parameter:
- epsilon, strategy_epsilon_std
- temperature, strategy_temperature_std
- n_noisy_layers, strategy_n_noisy_layers_std
- strategy_noisy_std_init (optional, default 0.5)
rng: Random number generator instance to use for sampling.

Returns:
A dict with keys: name, epsilon, temperature, n_noisy_layers,
and noisy_std_init (if present in args, default 0.5 otherwise).
A dict with keys: name, epsilon, temperature, n_noisy_layers, noisy_std_init.
"""
# TODO: Generate a new exploration strategy instead of selecting from a pre-defined
# list.
candidates = [
{"name": "on_policy", "epsilon": 0.0, "temperature": 1.0, "n_noisy_layers": 0},
{"name": "epsilon_0.1", "epsilon": 0.1, "temperature": 1.0, "n_noisy_layers": 0},
{"name": "temp_1.5", "epsilon": 0.0, "temperature": 1.5, "n_noisy_layers": 0},
{"name": "noisy_1", "epsilon": 0.0, "temperature": 1.0, "n_noisy_layers": 1},
{
"name": "noisy_2_temp_1.5",
"epsilon": 0.0,
"temperature": 1.5,
"n_noisy_layers": 2,
},
]
choices = [
c
for c in candidates
if (
c["epsilon"] != prev_eps
or c["temperature"] != prev_temp
or c["n_noisy_layers"] != prev_noisy
)
]
if not choices:
choices = candidates
idx_seed = int(args.seed) + int(agent_group_id) * 7919 + int(iteration) * 104729
idx = idx_seed % len(choices)
strat = choices[idx]
strat["noisy_std_init"] = float(getattr(args, "agent_noisy_std_init", 0.5))
return strat
# Get mean/std from args with sensible defaults.
eps_mean = float(getattr(args, "epsilon", 0.1))
eps_std = float(getattr(args, "strategy_epsilon_std", 0.05))
temp_mean = float(getattr(args, "temperature", 1.5))
temp_std = float(getattr(args, "strategy_temperature_std", 0.5))
noisy_mean = float(getattr(args, "n_noisy_layers", 1.0))
noisy_std = float(getattr(args, "strategy_n_noisy_layers_std", 1.0))
noisy_std_init = float(getattr(args, "noisy_std_init", 0.5))

# Sample from normal distribution and clamp to valid ranges.
epsilon = max(0.0, rng.gauss(eps_mean, eps_std))
temperature = max(0.01, rng.gauss(temp_mean, temp_std)) # temperature > 0
n_noisy_layers = max(0, round(rng.gauss(noisy_mean, noisy_std)))

# Build a descriptive name for the strategy.
name = f"eps_{epsilon:.3f}_temp_{temperature:.3f}_noisy_{n_noisy_layers}"

return {
"name": name,
"epsilon": epsilon,
"temperature": temperature,
"n_noisy_layers": n_noisy_layers,
"noisy_std_init": noisy_std_init,
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Can we fix any of these settings? I.e., only sample temperature, leaving n_noisy_layers=0 and epsilon=0 for all agents?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, exactly, by setting std=0, we only "sample" the mean



def _make_optimizer_for(gflownet, args) -> torch.optim.Optimizer:
Expand Down Expand Up @@ -480,26 +472,21 @@ def set_up_logF_estimator(
return ScalarEstimator(module=module, preprocessor=preprocessor)


def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id):
def set_up_gflownet(
args, env, preprocessor, agent_group_list, my_agent_group_id, strategy_rng
):
"""Returns a GFlowNet complete with the required estimators."""
# Initialize per-agent exploration strategy.
# Default (tests stable): on-policy, no noisy layers.
# When --use_random_strategies is provided, sample a random initial strategy.
if getattr(args, "use_random_strategies", False):
cfg = _sample_new_strategy(
args,
agent_group_id=my_agent_group_id,
iteration=0,
prev_eps=9999.0,
prev_temp=9999.0,
prev_noisy=9999,
)
cfg = _sample_new_strategy(args, strategy_rng)
else:
cfg = {
"epsilon": 0.0,
"temperature": 1.0,
"n_noisy_layers": 0,
"noisy_std_init": 0.5,
"epsilon": args.epsilon,
"temperature": args.temperature,
"n_noisy_layers": args.n_noisy_layers,
"noisy_std_init": args.noisy_std_init,
}

args.agent_epsilon = float(cfg.get("epsilon", 0.0))
Expand Down Expand Up @@ -672,6 +659,10 @@ def main(args) -> dict: # noqa: C901

set_seed(args.seed + distributed_context.my_rank)

# Create RNG for strategy sampling (seeded deterministically per agent group).
agent_group_id = distributed_context.agent_group_id or 0
strategy_rng = random.Random(args.seed + agent_group_id)

# Initialize the environment.
env = HyperGrid(
args.ndim,
Expand Down Expand Up @@ -767,6 +758,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
preprocessor,
distributed_context.agent_groups,
distributed_context.agent_group_id,
strategy_rng,
)
if use_wandb:
import wandb
Expand Down Expand Up @@ -1424,6 +1416,48 @@ def cleanup():
action="store_true",
help="Use a random strategy for the initial gflownet and restarts.",
)
parser.add_argument(
"--epsilon",
type=float,
default=0.0,
help="Mean epsilon for strategy sampling (default: 0.0).",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Mean temperature for strategy sampling (default: 1.0).",
)
parser.add_argument(
"--n_noisy_layers",
type=float,
default=0,
help="Mean number of noisy layers for strategy sampling (default: 0).",
)
parser.add_argument(
"--noisy_std_init",
type=float,
default=0.5,
help="Initial std for noisy layers (default: 0.5).",
)
parser.add_argument(
"--strategy_epsilon_std",
type=float,
default=0.1,
help="Std of epsilon for strategy sampling (default: 0.1).",
)
parser.add_argument(
"--strategy_temperature_std",
type=float,
default=1.0,
help="Std of temperature for strategy sampling (default: 1.0).",
)
parser.add_argument(
"--strategy_n_noisy_layers_std",
type=float,
default=1.0,
help="Std of number of noisy layers for strategy sampling (default: 1.0).",
)
parser.add_argument(
"--use_restarts",
action="store_true",
Expand Down
Loading