From ee17515fa5b0293b9f8ea4da4087b39474ca04b2 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 15 Dec 2025 07:23:25 -0500 Subject: [PATCH 01/26] rtb prototype --- src/gfn/gflownet/__init__.py | 7 +- src/gfn/gflownet/base.py | 32 ++++- src/gfn/gflownet/trajectory_balance.py | 131 +++++++++++++++++++ src/gfn/gym/diffusion_sampling.py | 170 ++++++++++++++++++++++++- src/gfn/gym/helpers/diffusion_utils.py | 3 + 5 files changed, 339 insertions(+), 4 deletions(-) diff --git a/src/gfn/gflownet/__init__.py b/src/gfn/gflownet/__init__.py index 77fcf893..108c08c1 100644 --- a/src/gfn/gflownet/__init__.py +++ b/src/gfn/gflownet/__init__.py @@ -2,7 +2,11 @@ from .detailed_balance import DBGFlowNet, ModifiedDBGFlowNet from .flow_matching import FMGFlowNet from .sub_trajectory_balance import SubTBGFlowNet -from .trajectory_balance import LogPartitionVarianceGFlowNet, TBGFlowNet +from .trajectory_balance import ( + LogPartitionVarianceGFlowNet, + RelativeTrajectoryBalanceGFlowNet, + TBGFlowNet, +) __all__ = [ "GFlowNet", @@ -13,5 +17,6 @@ "FMGFlowNet", "SubTBGFlowNet", "LogPartitionVarianceGFlowNet", + "RelativeTrajectoryBalanceGFlowNet", "TBGFlowNet", ] diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 5542019b..24a976b8 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -11,7 +11,11 @@ from gfn.estimators import Estimator from gfn.samplers import Sampler from gfn.states import States -from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs +from gfn.utils.prob_calculations import ( + get_trajectory_pbs, + get_trajectory_pfs, + get_trajectory_pfs_and_pbs, +) TrainingSampleType = TypeVar("TrainingSampleType", bound=Container) @@ -343,6 +347,32 @@ def get_pfs_and_pbs( recalculate_all_logprobs, ) + def trajectory_log_probs_forward( + self, + trajectories: Trajectories, + fill_value: float = 0.0, + recalculate_all_logprobs: bool = True, + ) -> torch.Tensor: + """Evaluates forward logprobs only for each trajectory in the batch.""" + return get_trajectory_pfs( + self.pf, + trajectories, + fill_value=fill_value, + recalculate_all_logprobs=recalculate_all_logprobs, + ) + + def trajectory_log_probs_backward( + self, + trajectories: Trajectories, + fill_value: float = 0.0, + ) -> torch.Tensor: + """Evaluates backward logprobs only for each trajectory in the batch.""" + return get_trajectory_pbs( + self.pb, + trajectories, + fill_value=fill_value, + ) + def get_scores( self, trajectories: Trajectories, diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 8e2d6e6b..9a3a04a1 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -3,6 +3,7 @@ and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446). """ +import math from typing import cast import torch @@ -16,6 +17,7 @@ is_callable_exception_handler, warn_about_recalculating_logprobs, ) +from gfn.utils.prob_calculations import get_trajectory_pfs class TBGFlowNet(TrajectoryBasedGFlowNet): @@ -132,6 +134,135 @@ def loss( return loss +class RelativeTrajectoryBalanceGFlowNet(TrajectoryBasedGFlowNet): + r"""GFlowNet for the Relative Trajectory Balance (RTB) loss. + + This objective matches a posterior sampler to a prior diffusion (or other + sequential) model by minimizing + + .. math:: + + \left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) + - \beta \log r(x_T)\right)^2, + + where :math:`p_\theta` is a fixed prior process, :math:`p_\phi` is the + learnable posterior, :math:`r` is a positive reward/constraint on the + terminal state :math:`x_T`, and :math:`\log Z_\phi` is a learned scalar + normalizer. + """ + + def __init__( + self, + pf: Estimator, + prior_pf: Estimator, + *, + logZ: nn.Parameter | ScalarEstimator | None = None, + init_logZ: float = 0.0, + beta: float = 1.0, + log_reward_clip_min: float = -float("inf"), + debug: bool = False, + ): + """Initializes an RTB GFlowNet. + + Args: + pf: Posterior forward policy estimator :math:`p_\\phi`. + prior_pf: Fixed prior forward policy estimator :math:`p_\\theta`. + logZ: Learnable log-partition parameter or ScalarEstimator for + conditional settings. Defaults to a scalar parameter. + init_logZ: Initial value for logZ if ``logZ`` is None. + beta: Optional scaling applied to the terminal log-reward. + log_reward_clip_min: If finite, clips terminal log-rewards. + debug: if True, enables extra checks at the cost of execution speed. + """ + super().__init__( + pf=pf, + pb=None, + constant_pb=True, + log_reward_clip_min=log_reward_clip_min, + ) + self.prior_pf = prior_pf + self.beta = beta + self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) + self.debug = debug # TODO: to be passed to base classes. + + def logz_named_parameters(self) -> dict[str, torch.Tensor]: + """Returns named parameters containing 'logZ'.""" + return {k: v for k, v in dict(self.named_parameters()).items() if "logZ" in k} + + def logz_parameters(self) -> list[torch.Tensor]: + """Returns parameters containing 'logZ'.""" + return [v for k, v in dict(self.named_parameters()).items() if "logZ" in k] + + def _prior_log_pf( + self, + trajectories: Trajectories, + *, + fill_value: float = 0.0, + recalculate_all_logprobs: bool = True, + ) -> torch.Tensor: + """Computes prior forward log-probs along provided trajectories.""" + # The prior is fixed; evaluate it without tracking gradients to keep its + # parameters out of the RTB optimization graph. + with torch.no_grad(): + log_pf = get_trajectory_pfs( + self.prior_pf, + trajectories, + fill_value=fill_value, + recalculate_all_logprobs=recalculate_all_logprobs, + ) + return log_pf.sum(dim=0) + + def loss( + self, + env: Env, + trajectories: Trajectories, + recalculate_all_logprobs: bool = True, + reduction: str = "mean", + ) -> torch.Tensor: + """Computes the RTB loss on a batch of trajectories.""" + del env # unused + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + + # Posterior log-probs (forward; backward ignored in RTB score). + log_pf_post = self.trajectory_log_probs_forward( + trajectories, + recalculate_all_logprobs=recalculate_all_logprobs, + ) + if self.debug: + assert log_pf_post is not None + + total_log_pf_post = log_pf_post.sum(dim=0) + + # Prior log-probs along the same trajectories. + total_log_pf_prior = self._prior_log_pf( + trajectories, + recalculate_all_logprobs=recalculate_all_logprobs, + ) + + log_rewards = trajectories.log_rewards + if self.debug: + assert log_rewards is not None + if math.isfinite(self.log_reward_clip_min): + log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) # type: ignore + + if trajectories.conditions is not None: + with is_callable_exception_handler("logZ", self.logZ): + assert isinstance(self.logZ, ScalarEstimator) + logZ = self.logZ(trajectories.conditions) + else: + logZ = self.logZ + logZ = cast(torch.Tensor, logZ).squeeze() + + scores = ( + logZ + total_log_pf_post - total_log_pf_prior - self.beta * log_rewards.squeeze() # type: ignore + ).pow(2) + loss = loss_reduce(scores, reduction) + if torch.isnan(loss).any(): + raise ValueError("loss is nan") + + return loss + + class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet): """GFlowNet for the Log Partition Variance loss. diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 68fc5385..d4b8b72c 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -405,6 +405,163 @@ def visualize( plt.close() +class Grid25GaussianMixture(BaseTarget): + """Fixed 5x5 Gaussian mixture prior used for RTB demos.""" + + def __init__( + self, + device: torch.device, + dim: int = 2, + scale: float = math.sqrt(0.3), + plot_border: float = 15.0, + seed: int = 0, + ) -> None: + assert dim == 2, "Grid25GaussianMixture is defined for 2D." + self.locs = torch.tensor( + [(a, b) for a in [-10, -5, 0, 5, 10] for b in [-10, -5, 0, 5, 10]], + device=device, + dtype=torch.get_default_dtype(), + ) + mix = D.Categorical( + probs=torch.full( + (self.locs.shape[0],), 1.0 / self.locs.shape[0], device=device + ) + ) + comp = D.Independent(D.Normal(self.locs, scale * torch.ones_like(self.locs)), 1) + self.gmm = D.MixtureSameFamily(mix, comp) + + super().__init__( + device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border + ) + + def log_reward(self, x: torch.Tensor) -> torch.Tensor: + return self.gmm.log_prob(x).flatten() + + def sample(self, batch_size: int, seed: int | None = None) -> torch.Tensor: + ctx = nullcontext() + if seed is not None: + ctx = temporarily_set_seed(seed) + with ctx: + return self.gmm.sample((batch_size,)) + + def gt_logz(self) -> float: + return 0.0 + + def visualize( + self, + samples: torch.Tensor | None = None, + show: bool = False, + prefix: str = "", + linspace_n_steps: int = 100, + max_n_samples: int = 1000, + ) -> None: + assert self.plot_border is not None, "Visualization requires a plot border." + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + viz_2d_slice( + ax, + self, + (0, 1), + samples, + plot_border=self.plot_border, + use_log_reward=True, + # linspace_n_steps=linspace_n_steps, + max_n_samples=max_n_samples, + ) + plt.tight_layout() + if show: + plt.show() + else: + os.makedirs("viz", exist_ok=True) + fig.savefig(f"viz/{prefix}gmm25.png") + plt.close() + + +class Posterior9of25GaussianMixture(BaseTarget): + """Posterior reward for the 25→9 GMM RTB demo.""" + + def __init__( + self, + device: torch.device, + dim: int = 2, + scale: float = math.sqrt(0.3), + plot_border: float = 15.0, + seed: int = 0, + ) -> None: + assert dim == 2, "Posterior9of25GaussianMixture is defined for 2D." + self.prior = Grid25GaussianMixture( + device=device, dim=dim, scale=scale, plot_border=plot_border, seed=seed + ) + + mean_ls = [ + [-10.0, -5.0], + [-5.0, -10.0], + [-5.0, 0.0], + [10.0, -5.0], + [0.0, 0.0], + [0.0, 5.0], + [5.0, -5.0], + [5.0, 0.0], + [5.0, 10.0], + ] + locs = torch.tensor(mean_ls, device=device, dtype=torch.get_default_dtype()) + weights = torch.tensor( + [4, 10, 4, 5, 10, 5, 4, 15, 4], + device=device, + dtype=torch.get_default_dtype(), + ) + weights = weights / weights.sum() + + mix = D.Categorical(probs=weights) + comp = D.Independent(D.Normal(locs, scale * torch.ones_like(locs)), 1) + self.posterior = D.MixtureSameFamily(mix, comp) + + super().__init__( + device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border + ) + + def log_reward(self, x: torch.Tensor) -> torch.Tensor: + # r(x) = p_post(x) / p_prior(x) + return self.posterior.log_prob(x).flatten() - self.prior.log_reward(x) + + def sample(self, batch_size: int, seed: int | None = None) -> torch.Tensor: + ctx = nullcontext() + if seed is not None: + ctx = temporarily_set_seed(seed) + with ctx: + return self.posterior.sample((batch_size,)) + + def gt_logz(self) -> float: + return 0.0 + + def visualize( + self, + samples: torch.Tensor | None = None, + show: bool = False, + prefix: str = "", + grid_width_n_points: int = 100, + max_n_samples: int = 1000, + ) -> None: + assert self.plot_border is not None, "Visualization requires a plot border." + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + viz_2d_slice( + ax, + self, + (0, 1), + samples, + plot_border=self.plot_border, + use_log_reward=True, + grid_width_n_points=grid_width_n_points, + max_n_samples=max_n_samples, + ) + plt.tight_layout() + if show: + plt.show() + else: + os.makedirs("viz", exist_ok=True) + fig.savefig(f"viz/{prefix}posterior9of25.png") + plt.close() + + class Funnel(BaseTarget): """Neal's funnel distribution target. @@ -481,7 +638,7 @@ def visualize( samples: torch.Tensor | None = None, show: bool = False, prefix: str = "", - linspace_n_steps: int = 100, + grid_width_n_points: int = 100, max_n_samples: int = 500, ) -> None: """Visualize only supported for 2D (x0, x1).""" @@ -497,6 +654,8 @@ def visualize( samples, plot_border=self.plot_border, use_log_reward=True, + grid_width_n_points=grid_width_n_points, + max_n_samples=max_n_samples, ) plt.tight_layout() @@ -640,7 +799,7 @@ def visualize( samples: torch.Tensor | None = None, show: bool = False, prefix: str = "", - linspace_n_steps: int = 100, + grid_width_n_points: int = 100, max_n_samples: int = 500, ) -> None: assert self.plot_border is not None, "Visualization requires a plot border." @@ -655,6 +814,8 @@ def visualize( samples, plot_border=self.plot_border, use_log_reward=True, + grid_width_n_points=grid_width_n_points, + max_n_samples=max_n_samples, ) plt.tight_layout() @@ -685,6 +846,11 @@ class DiffusionSampling(Env): "gmm2": (SimpleGaussianMixture, {"num_components": 2}), # 2D "gmm4": (SimpleGaussianMixture, {"num_components": 4}), # 2D "gmm8": (SimpleGaussianMixture, {"num_components": 8}), # 2D + "gmm25_prior": (Grid25GaussianMixture, {}), # 2D, fixed 25-mode grid + "gmm25_posterior9": ( + Posterior9of25GaussianMixture, + {}, + ), # 2D, 9-mode posterior reward "easy_funnel": (Funnel, {"std": 1.0}), # 10D "hard_funnel": (Funnel, {"std": 3.0}), # 10D "many_well": (ManyWell, {}), # 32D diff --git a/src/gfn/gym/helpers/diffusion_utils.py b/src/gfn/gym/helpers/diffusion_utils.py index 8e39067a..4d7274f8 100644 --- a/src/gfn/gym/helpers/diffusion_utils.py +++ b/src/gfn/gym/helpers/diffusion_utils.py @@ -27,6 +27,7 @@ def viz_2d_slice( grid_width_n_points=200, log_reward_clamp_min=-10000.0, use_log_reward=False, + max_n_samples: int | None = None, ) -> None: x_points_dim1 = torch.linspace(plot_border[0], plot_border[1], grid_width_n_points) x_points_dim2 = torch.linspace(plot_border[2], plot_border[3], grid_width_n_points) @@ -46,6 +47,8 @@ def viz_2d_slice( ax.contour(x_points_dim1, x_points_dim2, log_r_x, levels=n_contour_levels) if samples is not None: + if max_n_samples is not None: + samples = samples[:max_n_samples] samples = samples[:, dims].detach().cpu() samples[:, 0] = torch.clamp(samples[:, 0], plot_border[0], plot_border[1]) samples[:, 1] = torch.clamp(samples[:, 1], plot_border[2], plot_border[3]) From 85190b0ea8d09d5c2ca93137544437794059c5bd Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 15 Dec 2025 20:51:38 -0500 Subject: [PATCH 02/26] sync of RTB prototype --- src/gfn/estimators.py | 31 +- src/gfn/gflownet/trajectory_balance.py | 51 ++-- src/gfn/gym/diffusion_sampling.py | 4 +- testing/gym/test_diffusion_sampling_rtb.py | 43 +++ testing/test_rtb.py | 82 ++++++ tutorials/examples/train_diffusion_rtb.py | 321 +++++++++++++++++++++ 6 files changed, 495 insertions(+), 37 deletions(-) create mode 100644 testing/gym/test_diffusion_sampling_rtb.py create mode 100644 testing/test_rtb.py create mode 100644 tutorials/examples/train_diffusion_rtb.py diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 9cb6839d..e06f47f8 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1329,7 +1329,6 @@ def to_probability_distribution( states: States, module_output: torch.Tensor, **policy_kwargs: Any, - # TODO: add epsilon-noisy exploration ) -> IsotropicGaussian: """Transform the output of the module into a IsotropicGaussian distribution, which is the distribution of the next states under the pinned Brownian motion @@ -1339,7 +1338,14 @@ def to_probability_distribution( states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). module_output: The output of the module (actions), as a tensor of shape (*batch_shape, s_dim). - **policy_kwargs: Keyword arguments to modify the distribution. + **policy_kwargs: Keyword arguments to modify the distribution. Supported + keys: + - exploration_std: Optional callable or float controlling extra + exploration noise on top of the base diffusion std. The callable + should accept an integer step index and return a non-negative + standard deviation in state space. When provided, the extra noise + is combined in variance-space (logaddexp) with the base diffusion + variance; non-positive exploration is ignored. Returns: A IsotropicGaussian distribution (distribution of the next states) @@ -1357,6 +1363,27 @@ def to_probability_distribution( fwd_mean = self.dt * module_output fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=fwd_mean.device) fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1) + + # Optional exploration noise: combine variances (quadrature/logaddexp). + exploration_std = policy_kwargs.pop("exploration_std", None) + exploration_std_t = torch.as_tensor( + exploration_std if exploration_std is not None else 0.0, + device=fwd_std.device, + dtype=fwd_std.dtype, + ).clamp(min=0.0) + + # Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2: + # σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly. + base_log_var = 2 * fwd_std.log() # log(σ_base^2) + extra_log_var = 2 * exploration_std_t.clamp(min=1e-12).log() # log(σ_expl^2) + extra_log_var_tensor = extra_log_var.expand_as(base_log_var) + combined_log_var = torch.logaddexp(base_log_var, extra_log_var_tensor) + fwd_std = torch.where( + exploration_std_t > 0, + torch.exp(0.5 * combined_log_var), + fwd_std, + ) + return IsotropicGaussian(fwd_mean, fwd_std) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 9a3a04a1..7a2199bd 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -181,7 +181,7 @@ def __init__( log_reward_clip_min=log_reward_clip_min, ) self.prior_pf = prior_pf - self.beta = beta + self.beta = torch.tensor(beta) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) self.debug = debug # TODO: to be passed to base classes. @@ -193,25 +193,6 @@ def logz_parameters(self) -> list[torch.Tensor]: """Returns parameters containing 'logZ'.""" return [v for k, v in dict(self.named_parameters()).items() if "logZ" in k] - def _prior_log_pf( - self, - trajectories: Trajectories, - *, - fill_value: float = 0.0, - recalculate_all_logprobs: bool = True, - ) -> torch.Tensor: - """Computes prior forward log-probs along provided trajectories.""" - # The prior is fixed; evaluate it without tracking gradients to keep its - # parameters out of the RTB optimization graph. - with torch.no_grad(): - log_pf = get_trajectory_pfs( - self.prior_pf, - trajectories, - fill_value=fill_value, - recalculate_all_logprobs=recalculate_all_logprobs, - ) - return log_pf.sum(dim=0) - def loss( self, env: Env, @@ -223,28 +204,33 @@ def loss( del env # unused warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) - # Posterior log-probs (forward; backward ignored in RTB score). + # Posterior log-probs. log_pf_post = self.trajectory_log_probs_forward( trajectories, recalculate_all_logprobs=recalculate_all_logprobs, ) - if self.debug: - assert log_pf_post is not None - - total_log_pf_post = log_pf_post.sum(dim=0) + log_pf_post = log_pf_post.sum(dim=0) # Sum along trajectory length. # Prior log-probs along the same trajectories. - total_log_pf_prior = self._prior_log_pf( - trajectories, - recalculate_all_logprobs=recalculate_all_logprobs, - ) + # The prior is fixed; evaluate it without tracking gradients to keep its + # parameters out of the RTB optimization graph. + with torch.no_grad(): + log_pf_prior = get_trajectory_pfs( + self.prior_pf, + trajectories, + fill_value=0.0, + recalculate_all_logprobs=True, + ) + log_pf_prior = log_pf_prior.sum(dim=0) # Sum along trajectory length. + # Get the rewards. log_rewards = trajectories.log_rewards if self.debug: assert log_rewards is not None if math.isfinite(self.log_reward_clip_min): log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) # type: ignore + # Get logZ. if trajectories.conditions is not None: with is_callable_exception_handler("logZ", self.logZ): assert isinstance(self.logZ, ScalarEstimator) @@ -253,10 +239,9 @@ def loss( logZ = self.logZ logZ = cast(torch.Tensor, logZ).squeeze() - scores = ( - logZ + total_log_pf_post - total_log_pf_prior - self.beta * log_rewards.squeeze() # type: ignore - ).pow(2) - loss = loss_reduce(scores, reduction) + scores = 0.5 * (log_pf_post + logZ - log_pf_prior - self.beta * log_rewards).pow(2) # type: ignore + + loss = loss_reduce(scores, reduction) # Reduce across batch dimension. if torch.isnan(loss).any(): raise ValueError("loss is nan") diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index d4b8b72c..9bea3423 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -452,7 +452,7 @@ def visualize( samples: torch.Tensor | None = None, show: bool = False, prefix: str = "", - linspace_n_steps: int = 100, + grid_width_n_points: int = 100, max_n_samples: int = 1000, ) -> None: assert self.plot_border is not None, "Visualization requires a plot border." @@ -464,7 +464,7 @@ def visualize( samples, plot_border=self.plot_border, use_log_reward=True, - # linspace_n_steps=linspace_n_steps, + grid_width_n_points=grid_width_n_points, max_n_samples=max_n_samples, ) plt.tight_layout() diff --git a/testing/gym/test_diffusion_sampling_rtb.py b/testing/gym/test_diffusion_sampling_rtb.py new file mode 100644 index 00000000..8b4c9146 --- /dev/null +++ b/testing/gym/test_diffusion_sampling_rtb.py @@ -0,0 +1,43 @@ +import torch + +from gfn.gym.diffusion_sampling import ( + DiffusionSampling, + Grid25GaussianMixture, + Posterior9of25GaussianMixture, +) + + +def test_gmm25_prior_basic_sampling_and_log_reward(): + env = DiffusionSampling( + target_str="gmm25_prior", + target_kwargs=None, + num_discretization_steps=8, + device=torch.device("cpu"), + debug=True, + ) + assert isinstance(env.target, Grid25GaussianMixture) + x = env.target.sample(batch_size=16) + assert x.shape == (16, env.dim) + log_r = env.target.log_reward(x) + assert log_r.shape == (16,) + assert torch.isfinite(log_r).all() + + +def test_gmm25_posterior9_log_reward_matches_ratio(): + env = DiffusionSampling( + target_str="gmm25_posterior9", + target_kwargs=None, + num_discretization_steps=8, + device=torch.device("cpu"), + debug=True, + ) + assert isinstance(env.target, Posterior9of25GaussianMixture) + x = env.target.sample(batch_size=8) + assert x.shape == (8, env.dim) + + log_r = env.target.log_reward(x) + posterior_log = env.target.posterior.log_prob(x).flatten() + prior_log = env.target.prior.log_reward(x) + + assert torch.allclose(log_r, posterior_log - prior_log, atol=1e-5) + assert torch.isfinite(log_r).all() diff --git a/testing/test_rtb.py b/testing/test_rtb.py new file mode 100644 index 00000000..28427835 --- /dev/null +++ b/testing/test_rtb.py @@ -0,0 +1,82 @@ +import torch + +from gfn.estimators import DiscretePolicyEstimator +from gfn.gflownet import RelativeTrajectoryBalanceGFlowNet +from gfn.gym import HyperGrid +from gfn.preprocessors import KHotPreprocessor +from gfn.samplers import Sampler +from gfn.utils.modules import MLP + + +def _make_hypergrid_estimators(): + """Build simple forward policies for HyperGrid prior/posterior.""" + env = HyperGrid(ndim=2, height=4) + preproc = KHotPreprocessor(env.height, env.ndim) + assert isinstance(preproc.output_dim, int) + + pf_module_post = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) + pf_module_prior = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) + + pf_post = DiscretePolicyEstimator( + module=pf_module_post, + n_actions=env.n_actions, + preprocessor=preproc, + is_backward=False, + ) + pf_prior = DiscretePolicyEstimator( + module=pf_module_prior, + n_actions=env.n_actions, + preprocessor=preproc, + is_backward=False, + ) + return env, pf_post, pf_prior + + +def test_rtb_loss_backward_and_grads(): + torch.manual_seed(0) + env, pf_post, pf_prior = _make_hypergrid_estimators() + + gfn = RelativeTrajectoryBalanceGFlowNet( + pf=pf_post, + prior_pf=pf_prior, + init_logZ=0.0, + beta=1.0, + ) + sampler = Sampler(estimator=pf_post) + trajectories = sampler.sample_trajectories( + env, n=8, save_logprobs=True, save_estimator_outputs=False + ) + + loss = gfn.loss(env, trajectories, recalculate_all_logprobs=True) + assert torch.isfinite(loss) + + loss.backward() + + # Posterior parameters and logZ should receive gradients. + assert any(p.grad is not None for p in pf_post.parameters()) + assert any(p.grad is not None for p in gfn.logz_parameters()) + + # Prior parameters are not part of the RTB graph and should have no grads. + assert all(p.grad is None for p in pf_prior.parameters()) + + +def test_rtb_loss_forward_only_path(): + """Ensure RTB loss works with recalculate_all_logprobs=False.""" + torch.manual_seed(1) + env, pf_post, pf_prior = _make_hypergrid_estimators() + + gfn = RelativeTrajectoryBalanceGFlowNet( + pf=pf_post, + prior_pf=pf_prior, + init_logZ=0.0, + beta=0.5, + ) + sampler = Sampler(estimator=pf_post) + trajectories = sampler.sample_trajectories( + env, n=4, save_logprobs=True, save_estimator_outputs=False + ) + + # Use cached log_probs; should not rely on any backward policy. + loss = gfn.loss(env, trajectories, recalculate_all_logprobs=False) + assert torch.isfinite(loss) + loss.backward() diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py new file mode 100644 index 00000000..a045738d --- /dev/null +++ b/tutorials/examples/train_diffusion_rtb.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python +""" +Minimal end-to-end Relative Trajectory Balance (RTB) training script for diffusion. + +Uses the 25→9 GMM posterior target (`gmm25_posterior9`) with a learnable +posterior forward policy and a fixed prior forward policy. Loss is RTB +(no backward policy). At the end of training, saves a scatter plot of sampled +states to the user's home directory. +""" + +import argparse +import os + +import matplotlib.pyplot as plt +import torch +from tqdm import tqdm + +from gfn.estimators import PinnedBrownianMotionForward +from gfn.gflownet import RelativeTrajectoryBalanceGFlowNet +from gfn.gym.diffusion_sampling import DiffusionSampling +from gfn.gym.helpers.diffusion_utils import viz_2d_slice +from gfn.samplers import Sampler +from gfn.utils.common import set_seed +from gfn.utils.modules import DiffusionPISGradNetForward + + +def get_exploration_std( + iteration: int, + exploration_factor: float = 0.1, + warm_down_start: int = 500, + warm_down_end: int = 4500, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Return a callable exploration std schedule for state-space noise. + + When exploration is enabled, return a step-index function that emits a fixed + std for the current training iteration, optionally linearly warmed down + after warm_down_start iters toward 0 by warm_down_end iters. + """ + device = device or torch.get_default_device() + dtype = dtype or torch.get_default_dtype() + + # Tensor ops only (torch.compile-friendly): no Python branching on iteration. + iter_t = torch.tensor(iteration, device=device, dtype=dtype) + # Clamp negatives to zero to avoid Python-side checks/overhead. + factor_t = torch.clamp( + torch.tensor(exploration_factor, device=device, dtype=dtype), min=0.0 + ) + start_t = torch.tensor(warm_down_start, device=device, dtype=dtype) + end_t = torch.tensor(warm_down_end, device=device, dtype=dtype) + + # Phase indicator: 1 before warm_down_start, linear decay afterward. + progress = torch.clamp(iter_t / end_t, min=0.0, max=1.0) + decay = torch.where( + iter_t < start_t, torch.ones_like(progress), torch.clamp(1.0 - progress, min=0.0) + ) + exploration_std = factor_t * decay + + return exploration_std + + +def build_forward_estimator( + s_dim: int, + num_steps: int, + sigma: float, + harmonics_dim: int, + t_emb_dim: int, + s_emb_dim: int, + hidden_dim: int, + joint_layers: int, + zero_init: bool, + device: torch.device, +) -> PinnedBrownianMotionForward: + pf_module = DiffusionPISGradNetForward( + s_dim=s_dim, + harmonics_dim=harmonics_dim, + t_emb_dim=t_emb_dim, + s_emb_dim=s_emb_dim, + hidden_dim=hidden_dim, + joint_layers=joint_layers, + zero_init=zero_init, + ) + return PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=sigma, + num_discretization_steps=num_steps, + ).to(device) + + +def plot_samples( + xs: torch.Tensor, + target, + save_path: str, + return_fig: bool = False, +): + """Contour + scatter plot of samples against the posterior density.""" + + assert target.plot_border is not None, "Target must define plot_border for plotting." + + # If target exposes a posterior density, build a lightweight shim with the same + # interface that viz_2d_slice expects (log_reward, dim, device, plot_border). + if hasattr(target, "posterior"): + # Use a shallow copy and replace log_reward to return posterior density + viz_target = target + + def _posterior_log_reward(x: torch.Tensor) -> torch.Tensor: + return viz_target.posterior.log_prob(x).flatten() + + viz_target.log_reward = _posterior_log_reward # type: ignore[attr-defined] + else: + viz_target = target + + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + viz_2d_slice( + ax, + viz_target, + (0, 1), + samples=xs, + plot_border=viz_target.plot_border, + use_log_reward=True, + grid_width_n_points=200, + max_n_samples=2000, + ) + ax.set_title("RTB posterior samples") + fig.tight_layout() + dirpath = os.path.dirname(save_path) + if dirpath: + os.makedirs(dirpath, exist_ok=True) + fig.savefig(save_path) + if return_fig: + return fig + plt.close(fig) + return None + + +def main(args: argparse.Namespace) -> None: + set_seed(args.seed) + device = torch.device( + "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + ) + + # Environment / target + env = DiffusionSampling( + target_str=args.target, + target_kwargs=None, + num_discretization_steps=args.num_steps, + device=device, + debug=__debug__, + ) + s_dim = env.dim + + # Posterior forward (trainable) + pf_post = build_forward_estimator( + s_dim=s_dim, + num_steps=args.num_steps, + sigma=args.sigma, + harmonics_dim=args.harmonics_dim, + t_emb_dim=args.t_emb_dim, + s_emb_dim=args.s_emb_dim, + hidden_dim=args.hidden_dim, + joint_layers=args.joint_layers, + zero_init=args.zero_init, + device=device, + ) + + # Prior forward (fixed, no grad) + pf_prior = build_forward_estimator( + s_dim=s_dim, + num_steps=args.num_steps, + sigma=args.sigma, + harmonics_dim=args.harmonics_dim, + t_emb_dim=args.t_emb_dim, + s_emb_dim=args.s_emb_dim, + hidden_dim=args.hidden_dim, + joint_layers=args.joint_layers, + zero_init=args.zero_init, + device=device, + ) + pf_prior.eval() + for p in pf_prior.parameters(): + p.requires_grad_(False) + + gflownet = RelativeTrajectoryBalanceGFlowNet( + pf=pf_post, + prior_pf=pf_prior, + init_logZ=0.0, + beta=args.beta, + log_reward_clip_min=args.log_reward_clip_min, + ).to(device) + + sampler = Sampler(estimator=pf_post) + optimizer = torch.optim.Adam( + [ + {"params": gflownet.pf_pb_parameters(), "lr": args.lr}, + {"params": gflownet.logz_parameters(), "lr": args.lr_logz}, + ] + ) + + for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)): + trajectories = sampler.sample_trajectories( + env, + n=args.batch_size, + save_logprobs=False, # if args.exploration_factor > 0 else True, + save_estimator_outputs=False, + # Extra exploration noise (combined with base PF variance in estimator). + exploration_std=get_exploration_std( + iteration=it, + exploration_factor=args.exploration_factor, + warm_down_start=args.exploration_warm_down_start, + warm_down_end=args.exploration_warm_down_end, + ), + ) + + optimizer.zero_grad() + loss = gflownet.loss(env, trajectories, recalculate_all_logprobs=True) + loss.backward() + optimizer.step() + + if (it + 1) % args.log_interval == 0 or it == args.n_iterations - 1: + with torch.no_grad(): + term_states = gflownet.sample_terminating_states(env, n=args.eval_n) + rewards = env.target.log_reward(term_states.tensor[:, :-1]) + avg_reward = rewards.mean().item() + pbar.set_postfix({"loss": float(loss.item()), "avg_reward": avg_reward}) + else: + pbar.set_postfix({"loss": float(loss.item())}) + + # Final visualization + with torch.no_grad(): + samples_states = gflownet.sample_terminating_states(env, n=args.vis_n) + xs = samples_states.tensor[:, :-1] + save_path = os.path.expanduser(args.save_fig_path) + plot_samples( + xs, + env.target, + save_path, + return_fig=False, + ) + print(f"Saved final samples scatter to {save_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # System + parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + + # Target / environment + parser.add_argument( + "--target", + type=str, + default="gmm25_posterior9", + help="Diffusion target (default: gmm25_posterior9)", + ) + parser.add_argument( + "--num_steps", type=int, default=256, help="number of discretization steps" + ) + parser.add_argument( + "--sigma", + type=float, + default=2.0, + help="diffusion coefficient for the pinned Brownian motion", + ) + + # Model (DiffusionPISGradNetForward) + parser.add_argument("--harmonics_dim", type=int, default=64) + parser.add_argument("--t_emb_dim", type=int, default=64) + parser.add_argument("--s_emb_dim", type=int, default=64) + parser.add_argument("--hidden_dim", type=int, default=128) + parser.add_argument("--joint_layers", type=int, default=2) + parser.add_argument("--zero_init", action="store_true") + + # Training + parser.add_argument("--n_iterations", type=int, default=5000) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--lr_logz", type=float, default=1e-1) + parser.add_argument("--beta", type=float, default=1.0, help="RTB beta multiplier") + parser.add_argument( + "--log_reward_clip_min", + type=float, + default=-float("inf"), + help="Min clip for log reward", + ) + # Exploration noise (state-space Gaussian added in quadrature to PF std) + parser.add_argument( + "--exploration_factor", + type=float, + default=5.0, + help="Base exploration std applied per step when exploratory is enabled", + ) + parser.add_argument( + "--exploration_warm_down_start", + type=float, + default=0, + help="Linearly warm down exploration after n iters (to 0 by exploration_warm_down_end iters)", + ) + parser.add_argument( + "--exploration_warm_down_end", + type=float, + default=3000, + help="Linearly warm down exploration after n iters (to 0 by exploration_warm_down_end iters)", + ) + + # Logging / eval + parser.add_argument("--log_interval", type=int, default=100) + parser.add_argument("--eval_n", type=int, default=500) + parser.add_argument( + "--vis_n", type=int, default=2000, help="Number of samples for final plot" + ) + parser.add_argument( + "--save_fig_path", + type=str, + default="~/rtb_final_samples.png", + help="Path to save final samples plot", + ) + + args = parser.parse_args() + main(args) From 6ecd0ca4e5691e5c20e07069b743f42c129ac2ef Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 15 Dec 2025 21:06:38 -0500 Subject: [PATCH 03/26] ignore outputs --- tutorials/examples/output/.gitignore | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tutorials/examples/output/.gitignore diff --git a/tutorials/examples/output/.gitignore b/tutorials/examples/output/.gitignore new file mode 100644 index 00000000..3dca272e --- /dev/null +++ b/tutorials/examples/output/.gitignore @@ -0,0 +1,4 @@ +*.pt +*.jpg +*.jpeg +*.png \ No newline at end of file From 08258adb819376ebd78d162d3701d1a1f70f2857 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 16 Dec 2025 02:38:55 -0500 Subject: [PATCH 04/26] rtb finetune first pass --- src/gfn/estimators.py | 72 ++- src/gfn/gym/diffusion_sampling.py | 2 +- src/gfn/utils/modules.py | 106 ++++- tutorials/examples/train_diffusion_rtb.py | 545 +++++++++++++++++++++- 4 files changed, 682 insertions(+), 43 deletions(-) diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index e06f47f8..6c93314b 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1,3 +1,4 @@ +import math from abc import ABC, abstractmethod from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Protocol, cast, runtime_checkable @@ -1290,6 +1291,7 @@ def __init__( pf_module: nn.Module, sigma: float, num_discretization_steps: int, + n_variance_outputs: int = 0, ): """Initialize the PinnedBrownianMotionForward. @@ -1305,6 +1307,12 @@ def __init__( self.sigma = sigma self.num_discretization_steps = num_discretization_steps self.dt = 1.0 / self.num_discretization_steps + self.n_variance_outputs = n_variance_outputs + + @property + def expected_output_dim(self) -> int: + # Drift (s_dim) plus optional variance outputs. + return self.s_dim + self.n_variance_outputs def forward(self, input: States) -> torch.Tensor: """Forward pass of the module. @@ -1351,18 +1359,32 @@ def to_probability_distribution( A IsotropicGaussian distribution (distribution of the next states) """ assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" - s_curr = states.tensor[:, :-1] + # s_curr = states.tensor[:, :-1] t_curr = states.tensor[:, [-1]] module_output = torch.where( (1.0 - t_curr) < self.dt * 1e-2, # sf case; when t_curr is 1.0 - torch.full_like(s_curr, -float("inf")), # This is the exit action + # torch.full_like(s_curr, -float("inf")), # This is the exit action + torch.full_like(module_output, -float("inf")), # This is the exit action module_output, ) - fwd_mean = self.dt * module_output - fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=fwd_mean.device) - fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1) + drift = module_output[..., : self.s_dim] + if self.n_variance_outputs > 0: + var_part = module_output[..., self.s_dim :] + # Reduce extra variance dims to a single scalar (isotropic for now). + log_std = var_part.mean(dim=-1, keepdim=True) + fwd_std = torch.exp(log_std) * math.sqrt(self.dt) + else: + fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=drift.device) + fwd_std = fwd_std.repeat(drift.shape[0], 1) + + # Match reference behavior: scale diffusion noise (not drift) by t_scale if present. + t_scale_factor = getattr(self.module, "t_scale", 1.0) + if t_scale_factor != 1.0: + fwd_std = fwd_std * math.sqrt(t_scale_factor) + + fwd_mean = self.dt * drift # Optional exploration noise: combine variances (quadrature/logaddexp). exploration_std = policy_kwargs.pop("exploration_std", None) @@ -1394,30 +1416,34 @@ def __init__( pb_module: nn.Module, sigma: float, num_discretization_steps: int, + n_variance_outputs: int = 0, + pb_scale_range: float = 0.1, ): - """Initialize the PinnedBrownianMotionForward. + """Initialize the PinnedBrownianMotionBackward. Args: s_dim: The dimension of the states. pb_module: The neural network module to use for the backward policy. sigma: The diffusion coefficient parameter for the pinned Brownian motion. num_discretization_steps: The number of discretization steps. + n_variance_outputs: Number of variance outputs (0=fixed, 1=learned corr). + pb_scale_range: Scaling applied to learned corrections (tanh-bounded). """ super().__init__(s_dim=s_dim, module=pb_module, is_backward=True) # Pinned Brownian Motion related self.sigma = sigma self.dt = 1.0 / num_discretization_steps + self.n_variance_outputs = n_variance_outputs + self.pb_scale_range = pb_scale_range - def forward(self, input: States) -> torch.Tensor: - """Forward pass of the module. - - Args: - input: The input to the module as states. + @property + def expected_output_dim(self) -> int: + # Drift correction (s_dim) plus optional variance correction outputs. + return self.s_dim + self.n_variance_outputs - Returns: - The output of the module, as a tensor of shape (*batch_shape, output_dim). - """ + def forward(self, input: States) -> torch.Tensor: + """Forward pass of the module.""" out = self.module(self.preprocessor(input)) if self.expected_output_dim is not None: @@ -1438,6 +1464,7 @@ def to_probability_distribution( which is the distribution of the previous states under the pinned Brownian motion process, possibly controlled by the output of the backward module. If the module is a fixed backward module, the `module_output` is a zero vector (no control). + Includes optional learned corrections. Args: states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). @@ -1453,14 +1480,27 @@ def to_probability_distribution( t_curr = states.tensor[:, [-1]] # shape: (*batch_shape,) is_s0 = (t_curr - self.dt) < self.dt * 1e-2 # s0 case; when t_curr - dt is 0.0 - bwd_mean = torch.where( + # Analytic Brownian bridge base + base_mean = torch.where( is_s0, s_curr, s_curr * self.dt / t_curr, ) - bwd_std = torch.where( + base_std = torch.where( is_s0, torch.zeros_like(t_curr), self.sigma * (self.dt * (t_curr - self.dt) / t_curr).sqrt(), ) + + # Optional learned corrections (tanh-bounded); when n_variance_outputs==0, only mean corr. + mean_corr = module_output[..., : self.s_dim] * self.pb_scale_range + if self.n_variance_outputs > 0 and module_output.shape[-1] >= self.s_dim + 1: + log_std_corr = module_output[..., [-1]] * self.pb_scale_range + corr_std = torch.exp(log_std_corr) + else: + corr_std = torch.zeros_like(base_std) + + bwd_mean = base_mean + mean_corr + bwd_std = (base_std**2 + corr_std**2).sqrt() + return IsotropicGaussian(bwd_mean, bwd_std) diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 9bea3423..c780c5df 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -829,7 +829,7 @@ def visualize( ###################################### -### Diffusion Sampling Environment ### +# Diffusion Sampling Environment # ###################################### diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 72697457..8fec8539 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1667,6 +1667,11 @@ def __init__( hidden_dim: int = 64, joint_layers: int = 2, zero_init: bool = False, + clipping: bool = False, + gfn_clip: float = 1e4, + t_scale: float = 1.0, + log_var_range: float = 4.0, # kept for parity with learned-var subclass + learn_variance: bool = False, # predict_flow: bool, # TODO: support predict flow for db or subtb # share_embeddings: bool = False, # flow_harmonics_dim: int = 64, @@ -1680,7 +1685,6 @@ def __init__( # clipping: bool = False, # TODO: support clipping # out_clip: float = 1e4, # lp_clip: float = 1e2, - # learn_variance: bool = True, # TODO: support learnable variance # log_var_range: float = 4.0, ): """Initialize the PISGradNetForward. @@ -1703,7 +1707,12 @@ def __init__( self.hidden_dim = hidden_dim self.joint_layers = joint_layers self.zero_init = zero_init - self.out_dim = s_dim # 2 * out_dim if learn_variance is True + self.learn_variance = learn_variance + self.out_dim = s_dim + 1 if self.learn_variance else s_dim + self.clipping = clipping + self.gfn_clip = gfn_clip + self.t_scale = t_scale + self.log_var_range = log_var_range assert ( self.s_emb_dim == self.t_emb_dim @@ -1740,10 +1749,19 @@ def forward( t_emb = self.t_model(t) out = self.joint_model(s_emb, t_emb) + if self.learn_variance: + drift, raw_log_std = out[..., :-1], out[..., [-1]] + if self.clipping: + drift = torch.clamp(drift, -self.gfn_clip, self.gfn_clip) + log_std = torch.tanh(raw_log_std) * self.log_var_range + out = torch.cat([drift, log_std], dim=-1) + else: + if self.clipping: + out = torch.clamp(out, -self.gfn_clip, self.gfn_clip) + # TODO: learn variance, lp, clipping, ... if torch.isnan(out).any(): - print("+ out has {} nans".format(torch.isnan(out).sum())) - out = torch.nan_to_num(out) + raise ValueError("DiffusionPISGradNetForward produced NaNs") return out @@ -1774,3 +1792,83 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: The output of the module (shape: (*batch_shape, s_dim)). """ return torch.zeros_like(preprocessed_states[..., :-1]) + + +class DiffusionPISGradNetBackward(nn.Module): + """Learnable backward correction module (PIS-style) for diffusion. + + Produces mean and optional log-std corrections that are tanh-scaled by + `pb_scale_range` to stay close to the analytic Brownian bridge. + """ + + def __init__( + self, + s_dim: int, + harmonics_dim: int = 64, + t_emb_dim: int = 64, + s_emb_dim: int = 64, + hidden_dim: int = 64, + joint_layers: int = 2, + zero_init: bool = False, + clipping: bool = False, + gfn_clip: float = 1e4, + pb_scale_range: float = 0.1, + log_var_range: float = 4.0, + learn_variance: bool = True, + ) -> None: + super().__init__() + self.s_dim = s_dim + self.out_dim = s_dim + (1 if learn_variance else 0) + self.harmonics_dim = harmonics_dim + self.t_emb_dim = t_emb_dim + self.s_emb_dim = s_emb_dim + self.hidden_dim = hidden_dim + self.joint_layers = joint_layers + self.zero_init = zero_init + self.clipping = clipping + self.gfn_clip = gfn_clip + self.pb_scale_range = pb_scale_range + self.log_var_range = log_var_range + self.learn_variance = learn_variance + + assert ( + self.s_emb_dim == self.t_emb_dim + ), "Dimensionality of state embedding and time embedding should be the same!" + + self.t_model = DiffusionPISTimeEncoding( + self.harmonics_dim, self.t_emb_dim, self.hidden_dim + ) + self.s_model = DiffusionPISStateEncoding(self.s_dim, self.s_emb_dim) + self.joint_model = DiffusionPISJointPolicy( + self.s_emb_dim, + self.hidden_dim, + self.out_dim, + self.joint_layers, + self.zero_init, + ) + + def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: + s = preprocessed_states[..., :-1] + t = preprocessed_states[..., -1] + s_emb = self.s_model(s) + t_emb = self.t_model(t) + out = self.joint_model(s_emb, t_emb) + + if self.clipping: + out = torch.clamp(out, -self.gfn_clip, self.gfn_clip) + + # Tanh-scale to stay near Brownian bridge; last dim (if present) is log-std corr. + drift_corr = torch.tanh(out[..., : self.s_dim]) * self.pb_scale_range + if self.learn_variance and out.shape[-1] == self.s_dim + 1: + log_std_corr = torch.tanh(out[..., [-1]]) * self.pb_scale_range + log_std_corr = torch.clamp( + log_std_corr, -self.log_var_range, self.log_var_range + ) + out = torch.cat([drift_corr, log_std_corr], dim=-1) + else: + out = drift_corr + + if torch.isnan(out).any(): + raise ValueError("DiffusionPISGradNetBackward produced NaNs") + + return out diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index a045738d..23ff9633 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -2,26 +2,34 @@ """ Minimal end-to-end Relative Trajectory Balance (RTB) training script for diffusion. -Uses the 25→9 GMM posterior target (`gmm25_posterior9`) with a learnable -posterior forward policy and a fixed prior forward policy. Loss is RTB -(no backward policy). At the end of training, saves a scatter plot of sampled -states to the user's home directory. +Now includes: +- Optional prior pretraining (auto-runs if the prior checkpoint is missing), so + finetuning starts from the same learned prior used in the reference scripts. +- An optimizer helper that mirrors the reference param grouping (policy vs. logZ). +- Hooks to add additional posterior targets (keep existing defaults). + +Uses the 25→9 GMM posterior target (`gmm25_posterior9`) by default with a learnable +posterior forward policy and a fixed prior forward policy. Loss is RTB (no backward +policy). At the end of training, saves a scatter plot of sampled states to the user's +home directory. """ import argparse +import math import os +from pathlib import Path import matplotlib.pyplot as plt import torch from tqdm import tqdm -from gfn.estimators import PinnedBrownianMotionForward +from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward from gfn.gflownet import RelativeTrajectoryBalanceGFlowNet from gfn.gym.diffusion_sampling import DiffusionSampling from gfn.gym.helpers.diffusion_utils import viz_2d_slice from gfn.samplers import Sampler from gfn.utils.common import set_seed -from gfn.utils.modules import DiffusionPISGradNetForward +from gfn.utils.modules import DiffusionPISGradNetBackward, DiffusionPISGradNetForward def get_exploration_std( @@ -70,6 +78,11 @@ def build_forward_estimator( hidden_dim: int, joint_layers: int, zero_init: bool, + learn_variance: bool, + clipping: bool, + gfn_clip: float, + t_scale: float, + log_var_range: float, device: torch.device, ) -> PinnedBrownianMotionForward: pf_module = DiffusionPISGradNetForward( @@ -80,14 +93,310 @@ def build_forward_estimator( hidden_dim=hidden_dim, joint_layers=joint_layers, zero_init=zero_init, + clipping=clipping, + gfn_clip=gfn_clip, + t_scale=t_scale, + log_var_range=log_var_range, + learn_variance=learn_variance, ) + return PinnedBrownianMotionForward( s_dim=s_dim, pf_module=pf_module, sigma=sigma, num_discretization_steps=num_steps, + n_variance_outputs=1 if learn_variance else 0, + ).to(device) + + +def build_backward_estimator( + s_dim: int, + num_steps: int, + sigma: float, + harmonics_dim: int, + t_emb_dim: int, + s_emb_dim: int, + hidden_dim: int, + joint_layers: int, + zero_init: bool, + learn_variance: bool, + clipping: bool, + gfn_clip: float, + pb_scale_range: float, + log_var_range: float, + device: torch.device, +) -> PinnedBrownianMotionBackward: + """Build learnable backward policy (pb) with optional variance correction.""" + pb_module = DiffusionPISGradNetBackward( + s_dim=s_dim, + harmonics_dim=harmonics_dim, + t_emb_dim=t_emb_dim, + s_emb_dim=s_emb_dim, + hidden_dim=hidden_dim, + joint_layers=joint_layers, + zero_init=zero_init, + clipping=clipping, + gfn_clip=gfn_clip, + pb_scale_range=pb_scale_range, + log_var_range=log_var_range, + learn_variance=learn_variance, + ) + return PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=sigma, + num_discretization_steps=num_steps, + n_variance_outputs=1 if learn_variance else 0, + pb_scale_range=pb_scale_range, + ).to(device) + + +def _backward_mle_loss( + pf: PinnedBrownianMotionForward, + pb: PinnedBrownianMotionBackward, + samples: torch.Tensor, + num_steps: int, + sigma: float, + t_scale: float, + exploration_std: float = 0.0, + debug: bool = False, +) -> torch.Tensor: + """ + Backward MLE: + 1) Sample backward path via Brownian bridge + optional learned pb corrections. + 2) Evaluate forward log-prob of observed increments under pf (with learned var). + 3) Minimize negative sum of logpf. + """ + device = samples.device + dtype = samples.dtype + bsz, dim = samples.shape + dt = 1.0 / num_steps + base_std_fixed = sigma * math.sqrt(dt) * math.sqrt(t_scale) + log_2pi = math.log(2 * math.pi) + + # Start from terminal states (data samples). + s_curr = samples + logpf_sum = torch.zeros(bsz, device=device, dtype=dtype) + + exploration_std_t = torch.as_tensor( + exploration_std, device=device, dtype=dtype + ).clamp(min=0.0) + + for i in range(num_steps): + # Forward time index for transition s_prev -> s_curr. + t_fwd = torch.full((bsz, 1), 1.0 - (i + 1) * dt, device=device, dtype=dtype) + t_curr = torch.full((bsz, 1), 1.0 - i * dt, device=device, dtype=dtype) + + # Backward sampler (Brownian bridge base + optional corrections). + pb_inp = torch.cat([s_curr, t_curr], dim=1) + pb_out = pb.module(pb_inp) + + is_s0 = (t_curr - dt) < dt * 1e-2 + base_mean = torch.where(is_s0, s_curr, s_curr * dt / t_curr) + base_std = torch.where( + is_s0, + torch.zeros_like(t_curr), + sigma * (dt * (t_curr - dt) / t_curr).sqrt(), + ) + + mean_corr = pb_out[..., :dim] * pb.pb_scale_range + # Learned variance case. + if pb_out.shape[-1] == dim + 1: + log_std_corr = pb_out[..., [-1]] * pb.pb_scale_range + corr_std = torch.exp(log_std_corr) + else: + corr_std = torch.zeros_like(base_std) + + # Combine bridge variance with optional learned correction. + bwd_std = (base_std**2 + corr_std**2).sqrt() + noise = torch.randn_like(s_curr, device=device, dtype=dtype) + s_prev = base_mean + mean_corr + bwd_std * noise + + # Forward log-prob under model for observed increment (s_prev -> s_curr). + model_inp = torch.cat([s_prev, t_fwd], dim=1) + module_out = pf.module(model_inp) + increment = s_curr - s_prev + + # Forward log p(s_prev -> s_curr). + # If model predicts variance (s_dim + 1 output): σ_i = exp(log_std_i)*sqrt(dt*t_scale) + # log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ_i)^2 + 2 log σ_i + log 2π ] + if module_out.shape[-1] == dim + 1: + drift = module_out[..., :dim] + log_std = module_out[..., [-1]] + std = torch.exp(log_std) * math.sqrt(dt) * math.sqrt(t_scale) + if exploration_std_t.item() > 0: + std = torch.sqrt(std**2 + exploration_std_t**2) + diff = increment - dt * drift + logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum(dim=1) + else: + # Fixed variance: σ = sigma*sqrt(dt*t_scale); same log p form with shared σ. + drift = module_out + std = base_std_fixed + if exploration_std_t.item() > 0: + std = math.sqrt(base_std_fixed**2 + float(exploration_std_t.item()) ** 2) + diff = increment - dt * drift + logpf_step = -0.5 * ((diff / std) ** 2).sum(dim=1) - 0.5 * dim * ( + log_2pi + 2 * math.log(std) + ) + + logpf_sum += logpf_step + s_curr = s_prev + + # Negative log-likelihood (mean over batch). + if debug and torch.isnan(logpf_sum).any(): + raise ValueError("NaNs in logpf_sum during pretrain loss.") + + return -(logpf_sum.mean()) + + +def pretrain_prior_if_needed( + args: argparse.Namespace, + device: torch.device, + s_dim: int, +) -> Path: + """ + Auto-pretrain the prior if the checkpoint is missing. + Saves to args.prior_ckpt_path and returns the resolved path. + """ + ckpt_path = Path(os.path.expanduser(args.prior_ckpt_path)) + if ckpt_path.exists() or not args.pretrain_if_missing: + return ckpt_path + + print(f"[pretrain] Prior checkpoint missing at {ckpt_path}, starting pretraining...") + + env_prior = DiffusionSampling( + target_str=args.pretrain_target, + target_kwargs=None, + num_discretization_steps=args.pretrain_num_steps, + device=device, + debug=__debug__, + ) + + pf_prior = build_forward_estimator( + s_dim=s_dim, + num_steps=args.pretrain_num_steps, + sigma=args.pretrain_sigma, + harmonics_dim=args.harmonics_dim, + t_emb_dim=args.t_emb_dim, + s_emb_dim=args.s_emb_dim, + hidden_dim=args.hidden_dim, + joint_layers=args.joint_layers, + zero_init=args.zero_init, + learn_variance=args.learn_variance, + clipping=args.clipping, + gfn_clip=args.gfn_clip, + t_scale=args.t_scale, + log_var_range=args.log_var_range, + device=device, + ) + + # Build backward estimator. + pb_module = DiffusionPISGradNetBackward( + s_dim=s_dim, + harmonics_dim=args.harmonics_dim, + t_emb_dim=args.t_emb_dim, + s_emb_dim=args.s_emb_dim, + hidden_dim=args.hidden_dim, + joint_layers=args.joint_layers, + zero_init=args.zero_init, + clipping=args.clipping, + gfn_clip=args.gfn_clip, + pb_scale_range=args.pb_scale_range, + log_var_range=args.log_var_range, + learn_variance=args.learn_variance, + ) + pb_prior = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=args.sigma, + num_discretization_steps=args.num_steps, + n_variance_outputs=1 if args.learn_variance else 0, + pb_scale_range=args.pb_scale_range, ).to(device) + optim_params = [{"params": pf_prior.parameters(), "lr": args.pretrain_lr}] + if args.learn_pb: + optim_params.append( + {"params": pb_prior.parameters(), "lr": args.pretrain_lr_back} + ) + optimizer = torch.optim.Adam( + optim_params, + lr=args.pretrain_lr, + weight_decay=args.pretrain_weight_decay, + ) + + pf_prior.train() + pbar = tqdm(range(args.pretrain_steps), dynamic_ncols=True, desc="pretrain_prior") + + for it in pbar: + with torch.no_grad(): + batch = env_prior.target.sample(args.pretrain_batch_size) + optimizer.zero_grad() + loss = _backward_mle_loss( + pf_prior, + pb_prior, + batch, + num_steps=args.pretrain_num_steps, + sigma=args.pretrain_sigma, + t_scale=args.t_scale, + exploration_std=args.pretrain_exploration_factor, + debug=args.debug_pretrain, + ) + loss.backward() + if args.debug_pretrain: + grad_list = [ + p.grad.norm() for p in pf_prior.parameters() if p.grad is not None + ] + total_norm = ( + torch.norm(torch.stack(grad_list)) if grad_list else torch.tensor(0.0) + ) + print( + f"[pretrain][debug] step={it} loss={loss.item():.4e} grad_norm={total_norm.item():.4e}" + ) + if torch.isnan(total_norm): + raise ValueError("NaN grad norm in pretrain.") + + optimizer.step() + + def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): + ckpt_path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "pf_state_dict": pf_prior.state_dict(), + "pb_state_dict": pb_prior.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "step": it + 1, + }, + ckpt_path, + ) + + if (it + 1) % args.pretrain_log_interval == 0 or it == args.pretrain_steps - 1: + pbar.set_postfix({"loss": float(loss.item())}) + if ( + it + 1 + ) % args.pretrain_ckpt_interval == 0 or it == args.pretrain_steps - 1: + _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path) + + _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path) + print(f"[pretrain] Saved prior to {ckpt_path}") + + # Quick visual check of the learned prior. + with torch.no_grad(): + sampler_prior = Sampler(estimator=pf_prior) + term_states = sampler_prior.sample_terminating_states( + env_prior, n=args.pretrain_vis_n + ) + xs = term_states.tensor[:, :-1] + plot_samples( + xs, + env_prior.target, + os.path.expanduser(args.pretrain_save_fig_path), + return_fig=False, + ) + print(f"[pretrain] Saved prior samples plot to {args.pretrain_save_fig_path}") + + return ckpt_path + def plot_samples( xs: torch.Tensor, @@ -96,7 +405,6 @@ def plot_samples( return_fig: bool = False, ): """Contour + scatter plot of samples against the posterior density.""" - assert target.plot_border is not None, "Target must define plot_border for plotting." # If target exposes a posterior density, build a lightweight shim with the same @@ -123,23 +431,26 @@ def _posterior_log_reward(x: torch.Tensor) -> torch.Tensor: grid_width_n_points=200, max_n_samples=2000, ) + ax.set_title("RTB posterior samples") fig.tight_layout() dirpath = os.path.dirname(save_path) + if dirpath: os.makedirs(dirpath, exist_ok=True) fig.savefig(save_path) + if return_fig: return fig + plt.close(fig) + return None def main(args: argparse.Namespace) -> None: set_seed(args.seed) - device = torch.device( - "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" - ) + device = torch.device(args.device) # Environment / target env = DiffusionSampling( @@ -162,10 +473,15 @@ def main(args: argparse.Namespace) -> None: hidden_dim=args.hidden_dim, joint_layers=args.joint_layers, zero_init=args.zero_init, + learn_variance=args.learn_variance, + clipping=args.clipping, + gfn_clip=args.gfn_clip, + t_scale=args.t_scale, + log_var_range=args.log_var_range, device=device, ) - # Prior forward (fixed, no grad) + # Prior forward (fixed, no grad). Will be loaded from checkpoint if available. pf_prior = build_forward_estimator( s_dim=s_dim, num_steps=args.num_steps, @@ -176,8 +492,30 @@ def main(args: argparse.Namespace) -> None: hidden_dim=args.hidden_dim, joint_layers=args.joint_layers, zero_init=args.zero_init, + learn_variance=args.learn_variance, + clipping=args.clipping, + gfn_clip=args.gfn_clip, + t_scale=args.t_scale, + log_var_range=args.log_var_range, device=device, ) + + # Pretrain prior if needed, then load weights into both prior and posterior so + # finetuning starts from the learned prior (mirrors reference behavior). + prior_ckpt_path = pretrain_prior_if_needed(args, device, s_dim) + if prior_ckpt_path.exists(): + ckpt = torch.load(prior_ckpt_path, map_location=device) + state = ckpt.get("pf_state_dict", ckpt) + missing, unexpected = pf_prior.load_state_dict(state, strict=False) + if missing or unexpected: + print(f"[warn] prior load missing={missing}, unexpected={unexpected}") + # Initialize posterior from the same prior weights. + pf_post.load_state_dict(pf_prior.state_dict(), strict=False) + else: + raise Exception( + f"pretrained weights not found at {prior_ckpt_path}, pretraining failed" + ) + pf_prior.eval() for p in pf_prior.parameters(): p.requires_grad_(False) @@ -191,11 +529,13 @@ def main(args: argparse.Namespace) -> None: ).to(device) sampler = Sampler(estimator=pf_post) + + param_groups = [ + {"params": gflownet.pf.parameters(), "lr": args.lr}, + {"params": gflownet.logz_parameters(), "lr": args.lr_logz}, + ] optimizer = torch.optim.Adam( - [ - {"params": gflownet.pf_pb_parameters(), "lr": args.lr}, - {"params": gflownet.logz_parameters(), "lr": args.lr_logz}, - ] + param_groups, lr=args.lr, weight_decay=args.weight_decay ) for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)): @@ -245,6 +585,13 @@ def main(args: argparse.Namespace) -> None: parser = argparse.ArgumentParser() # System parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda", "mps"], + help="Device for training.", + ) parser.add_argument("--seed", type=int, default=0, help="Random seed") # Target / environment @@ -255,7 +602,10 @@ def main(args: argparse.Namespace) -> None: help="Diffusion target (default: gmm25_posterior9)", ) parser.add_argument( - "--num_steps", type=int, default=256, help="number of discretization steps" + "--num_steps", + type=int, + default=100, + help="number of discretization steps (reference=100)", ) parser.add_argument( "--sigma", @@ -268,13 +618,43 @@ def main(args: argparse.Namespace) -> None: parser.add_argument("--harmonics_dim", type=int, default=64) parser.add_argument("--t_emb_dim", type=int, default=64) parser.add_argument("--s_emb_dim", type=int, default=64) - parser.add_argument("--hidden_dim", type=int, default=128) + parser.add_argument("--hidden_dim", type=int, default=64) parser.add_argument("--joint_layers", type=int, default=2) - parser.add_argument("--zero_init", action="store_true") + parser.add_argument("--zero_init", action="store_true", default=True) + parser.add_argument( + "--learn_variance", + action=argparse.BooleanOptionalAction, + default=False, + help="Use learned scalar variance in the diffusion forward policy (ref default: off)", + ) + parser.add_argument( + "--clipping", + action=argparse.BooleanOptionalAction, + default=False, + help="Clip model outputs (reference default: off)", + ) + parser.add_argument( + "--gfn_clip", + type=float, + default=1e4, + help="Clipping value for drift outputs (reference: 1e4)", + ) + parser.add_argument( + "--t_scale", + type=float, + default=5.0, + help="Scale diffusion std to mirror reference (reference: 5.0)", + ) + parser.add_argument( + "--log_var_range", + type=float, + default=4.0, + help="Range to bound learned log-std when learn_variance is enabled (reference: 4.0)", + ) # Training parser.add_argument("--n_iterations", type=int, default=5000) - parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--batch_size", type=int, default=500) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--lr_logz", type=float, default=1e-1) parser.add_argument("--beta", type=float, default=1.0, help="RTB beta multiplier") @@ -288,8 +668,8 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--exploration_factor", type=float, - default=5.0, - help="Base exploration std applied per step when exploratory is enabled", + default=0.5, + help="Base exploration std applied per step when exploratory is enabled (reference ~0.5)", ) parser.add_argument( "--exploration_warm_down_start", @@ -300,7 +680,7 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--exploration_warm_down_end", type=float, - default=3000, + default=4500, help="Linearly warm down exploration after n iters (to 0 by exploration_warm_down_end iters)", ) @@ -317,5 +697,126 @@ def main(args: argparse.Namespace) -> None: help="Path to save final samples plot", ) + # Prior pretraining / loading + parser.add_argument( + "--prior_ckpt_path", + type=str, + default="output/prior.pt", + help="Path to save/load the pretrained prior checkpoint", + ) + parser.add_argument( + "--pretrain_if_missing", + action=argparse.BooleanOptionalAction, + default=True, + help="Auto-run prior pretraining if the checkpoint is missing", + ) + parser.add_argument( + "--pretrain_use_bwd_mle", + action=argparse.BooleanOptionalAction, + default=True, + help="Use exact backward MLE (reference) instead of surrogate bridge loss", + ) + parser.add_argument( + "--learn_pb", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable learned backward policy corrections (pb) during pretrain", + ) + parser.add_argument( + "--pb_scale_range", + type=float, + default=0.1, + help="Tanh scaling for backward mean/var corrections (reference: 0.1)", + ) + parser.add_argument( + "--pretrain_target", + type=str, + default="gmm25_prior", + help="Target used for prior pretraining (matches reference prior)", + ) + parser.add_argument( + "--pretrain_num_steps", + type=int, + default=100, + help="Discretization steps for prior pretraining (reference=100)", + ) + parser.add_argument( + "--pretrain_sigma", + type=float, + default=2.0, + help="Diffusion coefficient for prior pretraining", + ) + parser.add_argument( + "--pretrain_exploration_factor", + type=float, + default=0.0, + help="Exploration std for pretrain backward MLE (reference: off by default)", + ) + parser.add_argument( + "--pretrain_batch_size", + type=int, + default=500, + help="Batch size for prior pretraining", + ) + parser.add_argument( + "--pretrain_steps", + type=int, + default=10000, + help="Training steps for prior pretraining", + ) + parser.add_argument( + "--pretrain_lr", type=float, default=1e-3, help="LR for prior pretraining" + ) + parser.add_argument( + "--pretrain_lr_back", + type=float, + default=1e-3, + help="LR for backward policy during pretrain", + ) + parser.add_argument( + "--pretrain_weight_decay", + type=float, + default=0.0, + help="Weight decay for prior pretraining", + ) + parser.add_argument( + "--pretrain_log_interval", + type=int, + default=100, + help="Logging interval (steps) during prior pretraining (reference: 100)", + ) + parser.add_argument( + "--pretrain_ckpt_interval", + type=int, + default=1000, + help="Checkpoint interval during prior pretraining (reference: 1000)", + ) + parser.add_argument( + "--pretrain_vis_n", + type=int, + default=2000, + help="Number of samples to plot after prior pretraining", + ) + parser.add_argument( + "--pretrain_save_fig_path", + type=str, + default="output/prior_pretrain.png", + help="Path to save prior samples plot after pretraining", + ) + parser.add_argument( + "--debug_pretrain", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable extra NaN/grad checks during pretrain loss", + ) + + # Optimizer extras + parser.add_argument( + "--weight_decay", + type=float, + default=0.0, + help="Weight decay for the RTB optimizer (policy/logZ)", + ) + args = parser.parse_args() main(args) From 6ea54cdf6995b700eb9e701a43e429bcc45dff81 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 16 Dec 2025 03:15:44 -0500 Subject: [PATCH 05/26] loss bugfixes --- tutorials/examples/train_diffusion_rtb.py | 73 ++++++++++++++--------- 1 file changed, 46 insertions(+), 27 deletions(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index 23ff9633..b1f57a80 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -29,7 +29,11 @@ from gfn.gym.helpers.diffusion_utils import viz_2d_slice from gfn.samplers import Sampler from gfn.utils.common import set_seed -from gfn.utils.modules import DiffusionPISGradNetBackward, DiffusionPISGradNetForward +from gfn.utils.modules import ( + DiffusionFixedBackwardModule, + DiffusionPISGradNetBackward, + DiffusionPISGradNetForward, +) def get_exploration_std( @@ -192,7 +196,12 @@ def _backward_mle_loss( pb_out = pb.module(pb_inp) is_s0 = (t_curr - dt) < dt * 1e-2 - base_mean = torch.where(is_s0, s_curr, s_curr * dt / t_curr) + # Brownian bridge: at t_prev=0 we must hit 0 (base_mean=0, std=0). + base_mean = torch.where( + is_s0, + torch.zeros_like(s_curr), + s_curr * (1.0 - dt / t_curr), + ) base_std = torch.where( is_s0, torch.zeros_like(t_curr), @@ -200,6 +209,7 @@ def _backward_mle_loss( ) mean_corr = pb_out[..., :dim] * pb.pb_scale_range + # Learned variance case. if pb_out.shape[-1] == dim + 1: log_std_corr = pb_out[..., [-1]] * pb.pb_scale_range @@ -207,8 +217,8 @@ def _backward_mle_loss( else: corr_std = torch.zeros_like(base_std) - # Combine bridge variance with optional learned correction. - bwd_std = (base_std**2 + corr_std**2).sqrt() + # Combine bridge variance with optional learned correction; match forward scaling via t_scale. + bwd_std = (base_std**2 + corr_std**2).sqrt() * math.sqrt(t_scale) noise = torch.randn_like(s_curr, device=device, dtype=dtype) s_prev = base_mean + mean_corr + bwd_std * noise @@ -290,28 +300,36 @@ def pretrain_prior_if_needed( device=device, ) - # Build backward estimator. - pb_module = DiffusionPISGradNetBackward( - s_dim=s_dim, - harmonics_dim=args.harmonics_dim, - t_emb_dim=args.t_emb_dim, - s_emb_dim=args.s_emb_dim, - hidden_dim=args.hidden_dim, - joint_layers=args.joint_layers, - zero_init=args.zero_init, - clipping=args.clipping, - gfn_clip=args.gfn_clip, - pb_scale_range=args.pb_scale_range, - log_var_range=args.log_var_range, - learn_variance=args.learn_variance, - ) + # Build backward estimator: learned pb if enabled, else fixed Brownian bridge. + if args.learn_pb: + pb_module = DiffusionPISGradNetBackward( + s_dim=s_dim, + harmonics_dim=args.harmonics_dim, + t_emb_dim=args.t_emb_dim, + s_emb_dim=args.s_emb_dim, + hidden_dim=args.hidden_dim, + joint_layers=args.joint_layers, + zero_init=args.zero_init, + clipping=args.clipping, + gfn_clip=args.gfn_clip, + pb_scale_range=args.pb_scale_range, + log_var_range=args.log_var_range, + learn_variance=args.learn_variance, + ) + n_var_outputs = 1 if args.learn_variance else 0 + pb_scale_range = args.pb_scale_range + else: + pb_module = DiffusionFixedBackwardModule(s_dim) + n_var_outputs = 0 + pb_scale_range = 0.0 + pb_prior = PinnedBrownianMotionBackward( s_dim=s_dim, pb_module=pb_module, - sigma=args.sigma, - num_discretization_steps=args.num_steps, - n_variance_outputs=1 if args.learn_variance else 0, - pb_scale_range=args.pb_scale_range, + sigma=args.pretrain_sigma, + num_discretization_steps=args.pretrain_num_steps, + n_variance_outputs=n_var_outputs, + pb_scale_range=pb_scale_range, ).to(device) optim_params = [{"params": pf_prior.parameters(), "lr": args.pretrain_lr}] @@ -383,10 +401,11 @@ def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): # Quick visual check of the learned prior. with torch.no_grad(): sampler_prior = Sampler(estimator=pf_prior) - term_states = sampler_prior.sample_terminating_states( - env_prior, n=args.pretrain_vis_n + trajectories = sampler_prior.sample_trajectories( + env=env_prior, + n=args.pretrain_vis_n, ) - xs = term_states.tensor[:, :-1] + xs = trajectories.terminating_states.tensor[:, :-1] plot_samples( xs, env_prior.target, @@ -642,7 +661,7 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--t_scale", type=float, - default=5.0, + default=1.0, # 5.0 help="Scale diffusion std to mirror reference (reference: 5.0)", ) parser.add_argument( From a0666da974ceefc3c8bb229c0b7f8cfc2ebccf7f Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 16 Dec 2025 13:09:17 -0500 Subject: [PATCH 06/26] prior learning works, still working on finetune step --- tutorials/examples/train_diffusion_rtb.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index b1f57a80..d4c2fe65 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -196,7 +196,10 @@ def _backward_mle_loss( pb_out = pb.module(pb_inp) is_s0 = (t_curr - dt) < dt * 1e-2 - # Brownian bridge: at t_prev=0 we must hit 0 (base_mean=0, std=0). + # Brownian bridge (t_prev = t_curr - dt), conditioned to hit 0 at t=0: + # mean_bb = s_curr * (1 - dt / t_curr) + # std_bb = sigma * sqrt(dt * (t_curr - dt) / t_curr) + # At t_prev=0, both mean and std collapse to 0. base_mean = torch.where( is_s0, torch.zeros_like(s_curr), @@ -217,8 +220,8 @@ def _backward_mle_loss( else: corr_std = torch.zeros_like(base_std) - # Combine bridge variance with optional learned correction; match forward scaling via t_scale. - bwd_std = (base_std**2 + corr_std**2).sqrt() * math.sqrt(t_scale) + # Combine bridge variance with optional learned correction (no t_scale here; forward handles it). + bwd_std = (base_std**2 + corr_std**2).sqrt() noise = torch.randn_like(s_curr, device=device, dtype=dtype) s_prev = base_mean + mean_corr + bwd_std * noise From f771c2c3c1a1f82c5d17a1f98963ae435aa6b325 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 16 Dec 2025 13:58:06 -0500 Subject: [PATCH 07/26] RTB is working - next step is to factorize --- tutorials/examples/train_diffusion_rtb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index d4c2fe65..7be8aa25 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -473,6 +473,7 @@ def _posterior_log_reward(x: torch.Tensor) -> torch.Tensor: def main(args: argparse.Namespace) -> None: set_seed(args.seed) device = torch.device(args.device) + torch.set_default_device(device) # Environment / target env = DiffusionSampling( @@ -715,7 +716,7 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--save_fig_path", type=str, - default="~/rtb_final_samples.png", + default="output/rtb_final_samples.png", help="Path to save final samples plot", ) From 51ae63d51ee2825219d2c19f4298a3fc63d76db2 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 00:50:31 -0500 Subject: [PATCH 08/26] refactored MLE pipeline --- src/gfn/estimators.py | 7 +- testing/gflownet/test_mle_diffusion.py | 164 ++++++++++++++++ tutorials/examples/output/.gitignore | 3 +- tutorials/examples/train_diffusion_rtb.py | 222 ++++++---------------- 4 files changed, 232 insertions(+), 164 deletions(-) create mode 100644 testing/gflownet/test_mle_diffusion.py diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 6c93314b..b7c6f9e0 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1481,10 +1481,13 @@ def to_probability_distribution( is_s0 = (t_curr - self.dt) < self.dt * 1e-2 # s0 case; when t_curr - dt is 0.0 # Analytic Brownian bridge base + # Brownian bridge mean toward 0 at t=0: + # E[s_{t-dt} | s_t] = s_t * (1 - dt / t) and collapses to 0 at the start. + # Shapes: s_curr (batch, s_dim), t_curr (batch, 1), dt is scalar. base_mean = torch.where( is_s0, - s_curr, - s_curr * self.dt / t_curr, + torch.zeros_like(s_curr), + s_curr * (1.0 - self.dt / t_curr), ) base_std = torch.where( is_s0, diff --git a/testing/gflownet/test_mle_diffusion.py b/testing/gflownet/test_mle_diffusion.py new file mode 100644 index 00000000..2faccb36 --- /dev/null +++ b/testing/gflownet/test_mle_diffusion.py @@ -0,0 +1,164 @@ +import math + +import torch + +from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward +from gfn.gflownet.mle import MLEDiffusion +from gfn.utils.modules import DiffusionFixedBackwardModule + + +class ZeroDriftModule(torch.nn.Module): + """Returns zero drift (and optional zero log-std if learn_variance).""" + + def __init__(self, s_dim: int, learn_variance: bool = False): + super().__init__() + self.s_dim = s_dim + self.learn_variance = learn_variance + # Required by IdentityPreprocessor in estimators. + self.input_dim = s_dim + 1 # state dim + time + + def forward(self, x: torch.Tensor) -> torch.Tensor: # x shape: (B, s_dim + 1) + batch = x.shape[0] + if self.learn_variance: + return torch.zeros(batch, self.s_dim + 1, device=x.device, dtype=x.dtype) + return torch.zeros(batch, self.s_dim, device=x.device, dtype=x.dtype) + + +def _build_estimators(s_dim: int, learn_variance: bool, num_steps: int = 1): + """Helper to build deterministic PF/PB for tests.""" + pf_module = ZeroDriftModule(s_dim=s_dim, learn_variance=learn_variance) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + n_variance_outputs=1 if learn_variance else 0, + ) + pb = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=DiffusionFixedBackwardModule(s_dim), + sigma=1.0, + num_discretization_steps=num_steps, + n_variance_outputs=0, + pb_scale_range=0.1, + ) + return pf, pb + + +def test_mle_loss_fixed_variance_zero_terminal(): + """ + With zero drift, fixed variance (sigma=1), num_steps=1, and terminal states at 0, + the loss is deterministic: log(2π) per dimension /2 summed over dim -> log(2π). + """ + torch.manual_seed(0) + s_dim = 2 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=False, num_steps=1) + trainer = MLEDiffusion( + pf=pf, + pb=pb, + num_steps=1, + sigma=1.0, + t_scale=1.0, + pb_scale_range=0.1, + learn_variance=False, + ) + + batch = torch.zeros(4, s_dim) # terminal states near (0,0) + loss = trainer.loss(batch, exploration_std=0.0) + + expected_logp = -0.5 * s_dim * math.log(2 * math.pi) # log p for zero increment + expected_loss = -expected_logp # num_steps=1, loss = -logpf_sum.mean() + assert torch.isfinite(loss) + assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-6) + + +def test_mle_loss_learned_variance_zero_terminal(): + """ + Learned variance head returning log_std=0 should match the fixed-variance case + (std = exp(0)*sqrt(dt)*sqrt(t_scale) = 1 when num_steps=1, t_scale=1). + """ + torch.manual_seed(0) + s_dim = 2 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=True, num_steps=1) + trainer = MLEDiffusion( + pf=pf, + pb=pb, + num_steps=1, + sigma=1.0, + t_scale=1.0, + pb_scale_range=0.1, + learn_variance=True, + ) + + batch = torch.zeros(3, s_dim) + loss = trainer.loss(batch, exploration_std=0.0) + + expected_logp = -0.5 * s_dim * math.log(2 * math.pi) + expected_loss = -expected_logp + assert torch.isfinite(loss) + assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-6) + + +def test_backward_bridge_mean_std_match_formula(): + """ + Validate Brownian bridge mean/std against closed form for num_steps=2 at t=1. + For s_curr=0, mean should be 0, std should be sigma*sqrt(dt*(t-dt)/t). + """ + s_dim = 2 + num_steps = 2 + sigma = 1.0 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=False, num_steps=num_steps) + + # Manually run the PB module once at t=1. + dt = 1.0 / num_steps + bsz = 3 + s_curr = torch.zeros(bsz, s_dim) + t_curr = torch.full((bsz, 1), 1.0) + pb_inp = torch.cat([s_curr, t_curr], dim=1) + pb_out = pb.module(pb_inp) + + is_s0 = (t_curr - dt) < dt * 1e-2 + base_mean = torch.where( + is_s0, + torch.zeros_like(s_curr), + s_curr * (1.0 - dt / t_curr), + ) + base_std = torch.where( + is_s0, + torch.zeros_like(t_curr), + sigma * (dt * (t_curr - dt) / t_curr).sqrt(), + ) + + # For zero corrections, mean_corr=0, corr_std=0. + mean_corr = pb_out[..., :s_dim] * pb.pb_scale_range + assert torch.allclose(mean_corr, torch.zeros_like(mean_corr)) + assert torch.allclose(base_mean, torch.zeros_like(base_mean)) + expected_std = sigma * math.sqrt(dt * (1.0 - dt) / 1.0) + assert torch.allclose(base_std.squeeze(-1), torch.full((bsz,), expected_std)) + + +def test_forward_logprob_zero_increment_matches_formula(): + """ + For PF with zero drift/log_std=0, num_steps=1, t_scale=1, increment=0, + the log-prob per dim is -0.5*log(2π); total logp = that * s_dim. + """ + s_dim = 2 + pf, pb = _build_estimators(s_dim=s_dim, learn_variance=True, num_steps=1) + trainer = MLEDiffusion( + pf=pf, + pb=pb, + num_steps=1, + sigma=1.0, + t_scale=1.0, + pb_scale_range=0.1, + learn_variance=True, + ) + + batch = torch.zeros(2, s_dim) + # Manually compute expected logp for zero increment: + # std = exp(0) * sqrt(dt) * sqrt(t_scale) = 1; logp = -0.5 * s_dim * log(2π) + expected_logp = -0.5 * s_dim * math.log(2 * math.pi) + expected_loss = -expected_logp + loss = trainer.loss(batch, exploration_std=0.0) + assert torch.isfinite(loss) + assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-6) diff --git a/tutorials/examples/output/.gitignore b/tutorials/examples/output/.gitignore index 3dca272e..e5513ab3 100644 --- a/tutorials/examples/output/.gitignore +++ b/tutorials/examples/output/.gitignore @@ -1,4 +1,5 @@ *.pt *.jpg *.jpeg -*.png \ No newline at end of file +*.png +*.zip \ No newline at end of file diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index 7be8aa25..463ebf5c 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -15,7 +15,6 @@ """ import argparse -import math import os from pathlib import Path @@ -25,6 +24,7 @@ from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward from gfn.gflownet import RelativeTrajectoryBalanceGFlowNet +from gfn.gflownet.mle import MLEDiffusion from gfn.gym.diffusion_sampling import DiffusionSampling from gfn.gym.helpers.diffusion_utils import viz_2d_slice from gfn.samplers import Sampler @@ -36,6 +36,17 @@ ) +def get_debug_metrics(estimator: torch.nn.Module) -> tuple[torch.Tensor, bool]: + """Compute gradient norm for a module; return (total_norm, has_nan).""" + grad_list = [p.grad.norm() for p in estimator.parameters() if p.grad is not None] + if grad_list: + total_norm = torch.norm(torch.stack(grad_list)) + else: + total_norm = torch.tensor(0.0, device=next(estimator.parameters()).device) + has_nan = torch.isnan(total_norm) + return total_norm, bool(has_nan) + + def get_exploration_std( iteration: int, exploration_factor: float = 0.1, @@ -155,113 +166,6 @@ def build_backward_estimator( ).to(device) -def _backward_mle_loss( - pf: PinnedBrownianMotionForward, - pb: PinnedBrownianMotionBackward, - samples: torch.Tensor, - num_steps: int, - sigma: float, - t_scale: float, - exploration_std: float = 0.0, - debug: bool = False, -) -> torch.Tensor: - """ - Backward MLE: - 1) Sample backward path via Brownian bridge + optional learned pb corrections. - 2) Evaluate forward log-prob of observed increments under pf (with learned var). - 3) Minimize negative sum of logpf. - """ - device = samples.device - dtype = samples.dtype - bsz, dim = samples.shape - dt = 1.0 / num_steps - base_std_fixed = sigma * math.sqrt(dt) * math.sqrt(t_scale) - log_2pi = math.log(2 * math.pi) - - # Start from terminal states (data samples). - s_curr = samples - logpf_sum = torch.zeros(bsz, device=device, dtype=dtype) - - exploration_std_t = torch.as_tensor( - exploration_std, device=device, dtype=dtype - ).clamp(min=0.0) - - for i in range(num_steps): - # Forward time index for transition s_prev -> s_curr. - t_fwd = torch.full((bsz, 1), 1.0 - (i + 1) * dt, device=device, dtype=dtype) - t_curr = torch.full((bsz, 1), 1.0 - i * dt, device=device, dtype=dtype) - - # Backward sampler (Brownian bridge base + optional corrections). - pb_inp = torch.cat([s_curr, t_curr], dim=1) - pb_out = pb.module(pb_inp) - - is_s0 = (t_curr - dt) < dt * 1e-2 - # Brownian bridge (t_prev = t_curr - dt), conditioned to hit 0 at t=0: - # mean_bb = s_curr * (1 - dt / t_curr) - # std_bb = sigma * sqrt(dt * (t_curr - dt) / t_curr) - # At t_prev=0, both mean and std collapse to 0. - base_mean = torch.where( - is_s0, - torch.zeros_like(s_curr), - s_curr * (1.0 - dt / t_curr), - ) - base_std = torch.where( - is_s0, - torch.zeros_like(t_curr), - sigma * (dt * (t_curr - dt) / t_curr).sqrt(), - ) - - mean_corr = pb_out[..., :dim] * pb.pb_scale_range - - # Learned variance case. - if pb_out.shape[-1] == dim + 1: - log_std_corr = pb_out[..., [-1]] * pb.pb_scale_range - corr_std = torch.exp(log_std_corr) - else: - corr_std = torch.zeros_like(base_std) - - # Combine bridge variance with optional learned correction (no t_scale here; forward handles it). - bwd_std = (base_std**2 + corr_std**2).sqrt() - noise = torch.randn_like(s_curr, device=device, dtype=dtype) - s_prev = base_mean + mean_corr + bwd_std * noise - - # Forward log-prob under model for observed increment (s_prev -> s_curr). - model_inp = torch.cat([s_prev, t_fwd], dim=1) - module_out = pf.module(model_inp) - increment = s_curr - s_prev - - # Forward log p(s_prev -> s_curr). - # If model predicts variance (s_dim + 1 output): σ_i = exp(log_std_i)*sqrt(dt*t_scale) - # log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ_i)^2 + 2 log σ_i + log 2π ] - if module_out.shape[-1] == dim + 1: - drift = module_out[..., :dim] - log_std = module_out[..., [-1]] - std = torch.exp(log_std) * math.sqrt(dt) * math.sqrt(t_scale) - if exploration_std_t.item() > 0: - std = torch.sqrt(std**2 + exploration_std_t**2) - diff = increment - dt * drift - logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum(dim=1) - else: - # Fixed variance: σ = sigma*sqrt(dt*t_scale); same log p form with shared σ. - drift = module_out - std = base_std_fixed - if exploration_std_t.item() > 0: - std = math.sqrt(base_std_fixed**2 + float(exploration_std_t.item()) ** 2) - diff = increment - dt * drift - logpf_step = -0.5 * ((diff / std) ** 2).sum(dim=1) - 0.5 * dim * ( - log_2pi + 2 * math.log(std) - ) - - logpf_sum += logpf_step - s_curr = s_prev - - # Negative log-likelihood (mean over batch). - if debug and torch.isnan(logpf_sum).any(): - raise ValueError("NaNs in logpf_sum during pretrain loss.") - - return -(logpf_sum.mean()) - - def pretrain_prior_if_needed( args: argparse.Namespace, device: torch.device, @@ -272,8 +176,13 @@ def pretrain_prior_if_needed( Saves to args.prior_ckpt_path and returns the resolved path. """ ckpt_path = Path(os.path.expanduser(args.prior_ckpt_path)) - if ckpt_path.exists() or not args.pretrain_if_missing: - return ckpt_path + + if ckpt_path.exists(): + if args.clobber_pretrained_prior: + print(f"[pretrain] Clobbering existing prior checkpoint at {ckpt_path}") + ckpt_path.unlink() + else: + return ckpt_path print(f"[pretrain] Prior checkpoint missing at {ckpt_path}, starting pretraining...") @@ -304,7 +213,7 @@ def pretrain_prior_if_needed( ) # Build backward estimator: learned pb if enabled, else fixed Brownian bridge. - if args.learn_pb: + if args.pretrain_learn_pb: pb_module = DiffusionPISGradNetBackward( s_dim=s_dim, harmonics_dim=args.harmonics_dim, @@ -336,7 +245,7 @@ def pretrain_prior_if_needed( ).to(device) optim_params = [{"params": pf_prior.parameters(), "lr": args.pretrain_lr}] - if args.learn_pb: + if args.pretrain_learn_pb: optim_params.append( {"params": pb_prior.parameters(), "lr": args.pretrain_lr_back} ) @@ -346,6 +255,30 @@ def pretrain_prior_if_needed( weight_decay=args.pretrain_weight_decay, ) + # MLE trainer (uses forward PF and optional PB). + mle_trainer = MLEDiffusion( + pf=pf_prior, + pb=pb_prior, + num_steps=args.pretrain_num_steps, + sigma=args.pretrain_sigma, + t_scale=args.t_scale, + pb_scale_range=args.pb_scale_range, + learn_variance=args.learn_variance, + debug=args.debug_pretrain, + ) + + def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): + ckpt_path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "pf_state_dict": pf_prior.state_dict(), + "pb_state_dict": pb_prior.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "step": it + 1, + }, + ckpt_path, + ) + pf_prior.train() pbar = tqdm(range(args.pretrain_steps), dynamic_ncols=True, desc="pretrain_prior") @@ -353,51 +286,23 @@ def pretrain_prior_if_needed( with torch.no_grad(): batch = env_prior.target.sample(args.pretrain_batch_size) optimizer.zero_grad() - loss = _backward_mle_loss( - pf_prior, - pb_prior, - batch, - num_steps=args.pretrain_num_steps, - sigma=args.pretrain_sigma, - t_scale=args.t_scale, - exploration_std=args.pretrain_exploration_factor, - debug=args.debug_pretrain, - ) + loss = mle_trainer.loss(batch, exploration_std=args.pretrain_exploration_factor) loss.backward() if args.debug_pretrain: - grad_list = [ - p.grad.norm() for p in pf_prior.parameters() if p.grad is not None - ] - total_norm = ( - torch.norm(torch.stack(grad_list)) if grad_list else torch.tensor(0.0) - ) + total_norm, has_nan = get_debug_metrics(pf_prior) print( f"[pretrain][debug] step={it} loss={loss.item():.4e} grad_norm={total_norm.item():.4e}" ) - if torch.isnan(total_norm): + if has_nan: raise ValueError("NaN grad norm in pretrain.") optimizer.step() - def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): - ckpt_path.parent.mkdir(parents=True, exist_ok=True) - torch.save( - { - "pf_state_dict": pf_prior.state_dict(), - "pb_state_dict": pb_prior.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - "step": it + 1, - }, - ckpt_path, - ) - + # Log progress only. if (it + 1) % args.pretrain_log_interval == 0 or it == args.pretrain_steps - 1: pbar.set_postfix({"loss": float(loss.item())}) - if ( - it + 1 - ) % args.pretrain_ckpt_interval == 0 or it == args.pretrain_steps - 1: - _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path) + # Final checkpoint after pretraining (no intermediate resume support). _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path) print(f"[pretrain] Saved prior to {ckpt_path}") @@ -471,6 +376,7 @@ def _posterior_log_reward(x: torch.Tensor) -> torch.Tensor: def main(args: argparse.Namespace) -> None: + """Runs the posterio finetuning pipeline, including prior tuning if required.""" set_seed(args.seed) device = torch.device(args.device) torch.set_default_device(device) @@ -504,7 +410,7 @@ def main(args: argparse.Namespace) -> None: device=device, ) - # Prior forward (fixed, no grad). Will be loaded from checkpoint if available. + # Prior forward. pf_prior = build_forward_estimator( s_dim=s_dim, num_steps=args.num_steps, @@ -524,7 +430,7 @@ def main(args: argparse.Namespace) -> None: ) # Pretrain prior if needed, then load weights into both prior and posterior so - # finetuning starts from the learned prior (mirrors reference behavior). + # finetuning starts from the learned prior. prior_ckpt_path = pretrain_prior_if_needed(args, device, s_dim) if prior_ckpt_path.exists(): ckpt = torch.load(prior_ckpt_path, map_location=device) @@ -539,6 +445,7 @@ def main(args: argparse.Namespace) -> None: f"pretrained weights not found at {prior_ckpt_path}, pretraining failed" ) + # During finetuning, the prior is fixed, no grad, pf_prior.eval() for p in pf_prior.parameters(): p.requires_grad_(False) @@ -607,7 +514,6 @@ def main(args: argparse.Namespace) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser() # System - parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") parser.add_argument( "--device", type=str, @@ -646,7 +552,7 @@ def main(args: argparse.Namespace) -> None: parser.add_argument("--zero_init", action="store_true", default=True) parser.add_argument( "--learn_variance", - action=argparse.BooleanOptionalAction, + action="store_true", default=False, help="Use learned scalar variance in the diffusion forward policy (ref default: off)", ) @@ -728,20 +634,14 @@ def main(args: argparse.Namespace) -> None: help="Path to save/load the pretrained prior checkpoint", ) parser.add_argument( - "--pretrain_if_missing", - action=argparse.BooleanOptionalAction, - default=True, - help="Auto-run prior pretraining if the checkpoint is missing", - ) - parser.add_argument( - "--pretrain_use_bwd_mle", - action=argparse.BooleanOptionalAction, - default=True, - help="Use exact backward MLE (reference) instead of surrogate bridge loss", + "--clobber_pretrained_prior", + action="store_true", + default=False, + help="Overwrite existing prior checkpoint and re-run pretraining", ) parser.add_argument( - "--learn_pb", - action=argparse.BooleanOptionalAction, + "--pretrain_learn_pb", + action="store_true", default=False, help="Enable learned backward policy corrections (pb) during pretrain", ) From 4aebf208a32b91436057c0bf112d2b5f272ef4c2 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 01:04:34 -0500 Subject: [PATCH 09/26] added the MLE trainer --- src/gfn/gflownet/mle.py | 246 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 src/gfn/gflownet/mle.py diff --git a/src/gfn/gflownet/mle.py b/src/gfn/gflownet/mle.py new file mode 100644 index 00000000..4dc9523e --- /dev/null +++ b/src/gfn/gflownet/mle.py @@ -0,0 +1,246 @@ +""" +MLE loss for diffusion GFlowNets (forward PF with optional PB). + +Key equations (per time step, shapes in comments): + - Backward bridge (s_t -> s_{t-dt}): + mean_bb = s_t * (1 - dt / t) # (B, s_dim) + std_bb = sigma * sqrt(dt*(t-dt)/t) # (B, 1) broadcast + With learned PB corrections: + mean = mean_bb + mean_corr + std = sqrt(std_bb^2 + corr_std^2) + - Forward PF log-prob for increment Δ = s_t - s_{t-dt}: + If PF predicts log_std: + σ = exp(log_std) * sqrt(dt) * sqrt(t_scale); optionally combine exploration + log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ_i)^2 + 2 log σ_i + log 2π ] + Else (fixed variance): + σ = sigma * sqrt(dt) * sqrt(t_scale); optionally combine exploration + log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ)^2 + log(2π σ^2) ] + - Loss = -mean over batch of Σ_t log p_t + +Tensor conventions: + - terminal_states: (B, s_dim) or (B, s_dim + 1) with last dim an extra + terminal indicator column; we drop the last dim if present. + - Times: scalar dt = 1/num_steps; t_curr = 1 - i*dt; t_fwd = 1 - (i+1)*dt. + +Usage (user owns optimizer/loop): +```python +gfn = MLEDiffusion(pf=pf, pb=None, num_steps=100, sigma=2.0, t_scale=1.0) +opt = torch.optim.Adam(gfn.parameters(), lr=1e-3) +for it in n_iterations: + # Sample a batch of terminal states. + batch = env.sample(batch_size) # batch shape (B, s_dim) + opt.zero_grad() + # Calculate the MLE loss under the backward / forward diffusion process. + loss = gfn.loss(batch, exploration_std=0.0) + loss.backward() + opt.step() +``` +""" + +from __future__ import annotations + +import math +from typing import Any, Optional + +import torch + +from gfn.env import Env +from gfn.estimators import ( + PinnedBrownianMotionBackward, + PinnedBrownianMotionForward, +) +from gfn.gflownet.base import GFlowNet +from gfn.samplers import Sampler +from gfn.states import States +from gfn.utils.modules import DiffusionFixedBackwardModule + + +class MLEDiffusion(GFlowNet): + """ + Maximum-likelihood diffusion GFlowNet (PF with optional PB). + + The caller owns the training loop; this class provides: + - sampling via the forward PF (for API compatibility) + - `.loss(env, terminal_states, ...)` computing the MLE objective + """ + + def __init__( + self, + pf: PinnedBrownianMotionForward, + pb: Optional[PinnedBrownianMotionBackward] = None, + *, + num_steps: int, + sigma: float, + t_scale: float = 1.0, + pb_scale_range: float = 0.1, + learn_variance: bool = False, + reduction: str = "mean", + debug: bool = False, + ) -> None: + super().__init__() + self.pf = pf + if pb is None: + # Constant PB estimator (no learned parameters) + pb = PinnedBrownianMotionBackward( + s_dim=pf.s_dim, + pb_module=DiffusionFixedBackwardModule(pf.s_dim), + sigma=sigma, + num_discretization_steps=num_steps, + n_variance_outputs=0, + pb_scale_range=pb_scale_range, + ).to(next(pf.parameters()).device) + self.pb = pb + self.s_dim = pf.s_dim + self.num_steps = num_steps + self.dt = 1.0 / num_steps + self.sigma = sigma + self.t_scale = t_scale + self.pb_scale_range = pb_scale_range + self.learn_variance = learn_variance + self.reduction = reduction + self.debug = debug + + # Sampler for base-class API (sample_trajectories). + self.sampler = Sampler(estimator=self.pf) + + def sample_trajectories( + self, + env: Env, + n: int, + conditions: torch.Tensor | None = None, + save_logprobs: bool = False, + save_estimator_outputs: bool = False, + **policy_kwargs: Any, + ): + return self.sampler.sample_trajectories( + env, + n, + conditions=conditions, + save_logprobs=save_logprobs, + save_estimator_outputs=save_estimator_outputs, + **policy_kwargs, + ) + + def to_training_samples(self, trajectories): + return trajectories + + def loss( + self, + terminal_states: Any, + *, + exploration_std: float | torch.Tensor = 0.0, + ) -> torch.Tensor: + """ + Compute the MLE objective given terminal states sampled from the target. + + Args: + terminal_states: torch.Tensor or States; shape (B, s_dim) or (B, s_dim+1). + exploration_std: extra state-space noise (combined in quadrature with PF std). + Returns: + Scalar loss (mean reduction). + """ + device, dtype, s_curr = self._extract_samples(terminal_states) + + bsz, dim = s_curr.shape + assert dim == self.s_dim, f"Expected s_dim={self.s_dim}, got {dim}" + dt = self.dt + base_std_fixed = self.sigma * math.sqrt(dt) * math.sqrt(self.t_scale) + log_2pi = math.log(2 * math.pi) + + logpf_sum = torch.zeros(bsz, device=device, dtype=dtype) + exploration_std_t = torch.as_tensor( + exploration_std, device=device, dtype=dtype + ).clamp(min=0.0) + + for i in range(self.num_steps): + # Times: forward transition index t_fwd corresponds to s_prev -> s_curr. + t_fwd = torch.full((bsz, 1), 1.0 - (i + 1) * dt, device=device, dtype=dtype) + t_curr = torch.full((bsz, 1), 1.0 - i * dt, device=device, dtype=dtype) + + # Backward sampler: Brownian bridge base + optional PB corrections. + pb_inp = torch.cat([s_curr, t_curr], dim=1) + pb_out = self.pb.module(pb_inp) + + is_s0 = (t_curr - dt) < dt * 1e-2 + # Base Brownian bridge mean/std toward 0 at t=0. + base_mean = torch.where( + is_s0, + torch.zeros_like(s_curr), + s_curr * (1.0 - dt / t_curr), + ) + base_std = torch.where( + is_s0, + torch.zeros_like(t_curr), + self.sigma * (dt * (t_curr - dt) / t_curr).sqrt(), + ) + + # Learned corrections (PB): mean_corr, optional log-std corr. + mean_corr = pb_out[..., :dim] * self.pb.pb_scale_range + if pb_out.shape[-1] == dim + 1: + log_std_corr = pb_out[..., [-1]] * self.pb.pb_scale_range + corr_std = torch.exp(log_std_corr) + else: + corr_std = torch.zeros_like(base_std) + + bwd_std = (base_std**2 + corr_std**2).sqrt() + noise = torch.randn_like(s_curr, device=device, dtype=dtype) + s_prev = base_mean + mean_corr + bwd_std * noise + + # Forward log-prob under PF for the observed increment (s_prev -> s_curr). + model_inp = torch.cat([s_prev, t_fwd], dim=1) + module_out = self.pf.module(model_inp) + increment = s_curr - s_prev + + # Case where module outputs learned variance. + if module_out.shape[-1] == dim + 1: + drift = module_out[..., :dim] + log_std = module_out[..., [-1]] + std = torch.exp(log_std) * math.sqrt(dt) * math.sqrt(self.t_scale) + if exploration_std_t.item() > 0: + std = torch.sqrt(std**2 + exploration_std_t**2) + diff = increment - dt * drift + logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum( + dim=1 + ) + # Fixed variance case. + else: + drift = module_out + std = base_std_fixed + if exploration_std_t.item() > 0: + std = math.sqrt( + base_std_fixed**2 + float(exploration_std_t.item()) ** 2 + ) + diff = increment - dt * drift + logpf_step = -0.5 * ((diff / std) ** 2).sum(dim=1) - 0.5 * dim * ( + log_2pi + 2 * math.log(std) + ) + + logpf_sum += logpf_step + s_curr = s_prev + + if self.debug and torch.isnan(logpf_sum).any(): + raise ValueError("NaNs in logpf_sum during MLE loss.") + + # TODO: Use included loss reduction helpers. + loss = -(logpf_sum.mean() if self.reduction == "mean" else logpf_sum.sum()) + return loss + + def _extract_samples( + self, terminal_states: Any + ) -> tuple[torch.device, torch.dtype, torch.Tensor]: + """ + Normalize input to a (B, s_dim) tensor. + Accepts torch.Tensor or States; drops a final column if size matches s_dim+1. + """ + if isinstance(terminal_states, States): + tensor = terminal_states.tensor + elif torch.is_tensor(terminal_states): + tensor = terminal_states + else: + raise TypeError(f"Unsupported terminal_states type: {type(terminal_states)}") + + if tensor.shape[-1] == self.s_dim + 1: + tensor = tensor[..., :-1] + device = tensor.device + dtype = tensor.dtype + return device, dtype, tensor From f685b86e6d5680fe5b51f711ceb1e02409162930 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 01:37:46 -0500 Subject: [PATCH 10/26] cleaned up training script --- tutorials/examples/train_diffusion_rtb.py | 217 ++++++---------------- 1 file changed, 60 insertions(+), 157 deletions(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index 463ebf5c..df084087 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -1,21 +1,19 @@ #!/usr/bin/env python """ -Minimal end-to-end Relative Trajectory Balance (RTB) training script for diffusion. - -Now includes: -- Optional prior pretraining (auto-runs if the prior checkpoint is missing), so - finetuning starts from the same learned prior used in the reference scripts. -- An optimizer helper that mirrors the reference param grouping (policy vs. logZ). -- Hooks to add additional posterior targets (keep existing defaults). - -Uses the 25→9 GMM posterior target (`gmm25_posterior9`) by default with a learnable -posterior forward policy and a fixed prior forward policy. Loss is RTB (no backward -policy). At the end of training, saves a scatter plot of sampled states to the user's -home directory. +Minimal end-to-end Relative Trajectory Balance (RTB) fine-tuning training script for +diffusion models. + +- Prior is pre-trained (auto-runs if the prior checkpoint is missing), so + finetuning starts from a learned prior. +- Posterior is fine-tuned from this prior (pf). + +By default, uses the 25→9 GMM posterior target (`gmm25_posterior9`) by default with a +learnable posterior forward policy and a fixed prior forward policy. Loss is RTB (no +backward policy). This script outputs the prior weights alongside plots of samples +from both the prior and posterior distributions. """ import argparse -import os from pathlib import Path import matplotlib.pyplot as plt @@ -36,6 +34,23 @@ ) +def resolve_output_paths(args: argparse.Namespace) -> argparse.Namespace: + """Resolve all output paths relative to this script's directory.""" + script_dir = Path(__file__).resolve().parent + output_dir = Path(args.output_dir) + if not output_dir.is_absolute(): + output_dir = script_dir / output_dir + output_dir = output_dir.expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + args.output_dir = output_dir + args.prior_ckpt_path = output_dir / "train_diffusion_rtb_prior_ckpt.pt" + args.pretrain_save_fig_path = output_dir / "train_diffusion_rtb_prior_samples.png" + args.save_fig_path = output_dir / "train_diffusion_rtb_posterior_samples.png" + + return args + + def get_debug_metrics(estimator: torch.nn.Module) -> tuple[torch.Tensor, bool]: """Compute gradient norm for a module; return (total_norm, has_nan).""" grad_list = [p.grad.norm() for p in estimator.parameters() if p.grad is not None] @@ -124,65 +139,19 @@ def build_forward_estimator( ).to(device) -def build_backward_estimator( - s_dim: int, - num_steps: int, - sigma: float, - harmonics_dim: int, - t_emb_dim: int, - s_emb_dim: int, - hidden_dim: int, - joint_layers: int, - zero_init: bool, - learn_variance: bool, - clipping: bool, - gfn_clip: float, - pb_scale_range: float, - log_var_range: float, - device: torch.device, -) -> PinnedBrownianMotionBackward: - """Build learnable backward policy (pb) with optional variance correction.""" - pb_module = DiffusionPISGradNetBackward( - s_dim=s_dim, - harmonics_dim=harmonics_dim, - t_emb_dim=t_emb_dim, - s_emb_dim=s_emb_dim, - hidden_dim=hidden_dim, - joint_layers=joint_layers, - zero_init=zero_init, - clipping=clipping, - gfn_clip=gfn_clip, - pb_scale_range=pb_scale_range, - log_var_range=log_var_range, - learn_variance=learn_variance, - ) - return PinnedBrownianMotionBackward( - s_dim=s_dim, - pb_module=pb_module, - sigma=sigma, - num_discretization_steps=num_steps, - n_variance_outputs=1 if learn_variance else 0, - pb_scale_range=pb_scale_range, - ).to(device) - - -def pretrain_prior_if_needed( - args: argparse.Namespace, - device: torch.device, - s_dim: int, -) -> Path: +def pretrain_prior(args: argparse.Namespace, device: torch.device, s_dim: int) -> None: """ Auto-pretrain the prior if the checkpoint is missing. Saves to args.prior_ckpt_path and returns the resolved path. """ - ckpt_path = Path(os.path.expanduser(args.prior_ckpt_path)) + ckpt_path = Path(args.prior_ckpt_path) if ckpt_path.exists(): if args.clobber_pretrained_prior: print(f"[pretrain] Clobbering existing prior checkpoint at {ckpt_path}") ckpt_path.unlink() else: - return ckpt_path + return print(f"[pretrain] Prior checkpoint missing at {ckpt_path}, starting pretraining...") @@ -244,15 +213,13 @@ def pretrain_prior_if_needed( pb_scale_range=pb_scale_range, ).to(device) - optim_params = [{"params": pf_prior.parameters(), "lr": args.pretrain_lr}] + optim_params = [{"params": pf_prior.parameters(), "lr": args.lr}] if args.pretrain_learn_pb: - optim_params.append( - {"params": pb_prior.parameters(), "lr": args.pretrain_lr_back} - ) + optim_params.append({"params": pb_prior.parameters(), "lr": args.lr}) optimizer = torch.optim.Adam( optim_params, - lr=args.pretrain_lr, - weight_decay=args.pretrain_weight_decay, + lr=args.lr, + weight_decay=args.weight_decay, ) # MLE trainer (uses forward PF and optional PB). @@ -264,7 +231,7 @@ def pretrain_prior_if_needed( t_scale=args.t_scale, pb_scale_range=args.pb_scale_range, learn_variance=args.learn_variance, - debug=args.debug_pretrain, + debug=__debug__, ) def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): @@ -284,11 +251,11 @@ def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): for it in pbar: with torch.no_grad(): - batch = env_prior.target.sample(args.pretrain_batch_size) + batch = env_prior.target.sample(args.batch_size) optimizer.zero_grad() loss = mle_trainer.loss(batch, exploration_std=args.pretrain_exploration_factor) loss.backward() - if args.debug_pretrain: + if __debug__: total_norm, has_nan = get_debug_metrics(pf_prior) print( f"[pretrain][debug] step={it} loss={loss.item():.4e} grad_norm={total_norm.item():.4e}" @@ -317,18 +284,18 @@ def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): plot_samples( xs, env_prior.target, - os.path.expanduser(args.pretrain_save_fig_path), + "RTB Prior Samples", + args.pretrain_save_fig_path, return_fig=False, ) print(f"[pretrain] Saved prior samples plot to {args.pretrain_save_fig_path}") - return ckpt_path - def plot_samples( xs: torch.Tensor, target, - save_path: str, + title: str, + save_path: Path | str, return_fig: bool = False, ): """Contour + scatter plot of samples against the posterior density.""" @@ -359,12 +326,10 @@ def _posterior_log_reward(x: torch.Tensor) -> torch.Tensor: max_n_samples=2000, ) - ax.set_title("RTB posterior samples") + ax.set_title(title) fig.tight_layout() - dirpath = os.path.dirname(save_path) - - if dirpath: - os.makedirs(dirpath, exist_ok=True) + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path) if return_fig: @@ -376,7 +341,8 @@ def _posterior_log_reward(x: torch.Tensor) -> torch.Tensor: def main(args: argparse.Namespace) -> None: - """Runs the posterio finetuning pipeline, including prior tuning if required.""" + """Runs the posterior finetuning pipeline, including prior pretraining if required.""" + args = resolve_output_paths(args) set_seed(args.seed) device = torch.device(args.device) torch.set_default_device(device) @@ -431,9 +397,10 @@ def main(args: argparse.Namespace) -> None: # Pretrain prior if needed, then load weights into both prior and posterior so # finetuning starts from the learned prior. - prior_ckpt_path = pretrain_prior_if_needed(args, device, s_dim) - if prior_ckpt_path.exists(): - ckpt = torch.load(prior_ckpt_path, map_location=device) + pretrain_prior(args, device, s_dim) + + if args.prior_ckpt_path.exists(): + ckpt = torch.load(args.prior_ckpt_path, map_location=device) state = ckpt.get("pf_state_dict", ckpt) missing, unexpected = pf_prior.load_state_dict(state, strict=False) if missing or unexpected: @@ -442,7 +409,7 @@ def main(args: argparse.Namespace) -> None: pf_post.load_state_dict(pf_prior.state_dict(), strict=False) else: raise Exception( - f"pretrained weights not found at {prior_ckpt_path}, pretraining failed" + f"pretrained weights not found at {args.prior_ckpt_path}, pretraining failed" ) # During finetuning, the prior is fixed, no grad, @@ -455,7 +422,6 @@ def main(args: argparse.Namespace) -> None: prior_pf=pf_prior, init_logZ=0.0, beta=args.beta, - log_reward_clip_min=args.log_reward_clip_min, ).to(device) sampler = Sampler(estimator=pf_post) @@ -501,14 +467,14 @@ def main(args: argparse.Namespace) -> None: with torch.no_grad(): samples_states = gflownet.sample_terminating_states(env, n=args.vis_n) xs = samples_states.tensor[:, :-1] - save_path = os.path.expanduser(args.save_fig_path) plot_samples( xs, env.target, - save_path, + "RTB Posterior Samples", + args.save_fig_path, return_fig=False, ) - print(f"Saved final samples scatter to {save_path}") + print(f"Saved final samples scatter to {args.save_fig_path}") if __name__ == "__main__": @@ -571,7 +537,7 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--t_scale", type=float, - default=1.0, # 5.0 + default=5.0, help="Scale diffusion std to mirror reference (reference: 5.0)", ) parser.add_argument( @@ -587,12 +553,6 @@ def main(args: argparse.Namespace) -> None: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--lr_logz", type=float, default=1e-1) parser.add_argument("--beta", type=float, default=1.0, help="RTB beta multiplier") - parser.add_argument( - "--log_reward_clip_min", - type=float, - default=-float("inf"), - help="Min clip for log reward", - ) # Exploration noise (state-space Gaussian added in quadrature to PF std) parser.add_argument( "--exploration_factor", @@ -603,7 +563,7 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--exploration_warm_down_start", type=float, - default=0, + default=500, help="Linearly warm down exploration after n iters (to 0 by exploration_warm_down_end iters)", ) parser.add_argument( @@ -620,19 +580,13 @@ def main(args: argparse.Namespace) -> None: "--vis_n", type=int, default=2000, help="Number of samples for final plot" ) parser.add_argument( - "--save_fig_path", + "--output_dir", type=str, - default="output/rtb_final_samples.png", - help="Path to save final samples plot", + default="output", + help="Base output dir (resolved relative to this script)", ) # Prior pretraining / loading - parser.add_argument( - "--prior_ckpt_path", - type=str, - default="output/prior.pt", - help="Path to save/load the pretrained prior checkpoint", - ) parser.add_argument( "--clobber_pretrained_prior", action="store_true", @@ -675,63 +629,12 @@ def main(args: argparse.Namespace) -> None: default=0.0, help="Exploration std for pretrain backward MLE (reference: off by default)", ) - parser.add_argument( - "--pretrain_batch_size", - type=int, - default=500, - help="Batch size for prior pretraining", - ) parser.add_argument( "--pretrain_steps", type=int, default=10000, help="Training steps for prior pretraining", ) - parser.add_argument( - "--pretrain_lr", type=float, default=1e-3, help="LR for prior pretraining" - ) - parser.add_argument( - "--pretrain_lr_back", - type=float, - default=1e-3, - help="LR for backward policy during pretrain", - ) - parser.add_argument( - "--pretrain_weight_decay", - type=float, - default=0.0, - help="Weight decay for prior pretraining", - ) - parser.add_argument( - "--pretrain_log_interval", - type=int, - default=100, - help="Logging interval (steps) during prior pretraining (reference: 100)", - ) - parser.add_argument( - "--pretrain_ckpt_interval", - type=int, - default=1000, - help="Checkpoint interval during prior pretraining (reference: 1000)", - ) - parser.add_argument( - "--pretrain_vis_n", - type=int, - default=2000, - help="Number of samples to plot after prior pretraining", - ) - parser.add_argument( - "--pretrain_save_fig_path", - type=str, - default="output/prior_pretrain.png", - help="Path to save prior samples plot after pretraining", - ) - parser.add_argument( - "--debug_pretrain", - action=argparse.BooleanOptionalAction, - default=False, - help="Enable extra NaN/grad checks during pretrain loss", - ) # Optimizer extras parser.add_argument( From e64787792980326f9dedb2454e4c117a939da497 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 03:33:29 -0500 Subject: [PATCH 11/26] shrunk script for clarity --- tutorials/examples/train_diffusion_rtb.py | 312 ++++++++-------------- 1 file changed, 105 insertions(+), 207 deletions(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index df084087..9f02a949 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -51,6 +51,32 @@ def resolve_output_paths(args: argparse.Namespace) -> argparse.Namespace: return args +def forward_kwargs( + args: argparse.Namespace, + s_dim: int, + num_steps: int, + sigma: float, + device: torch.device, +) -> dict: + return dict( + s_dim=s_dim, + num_steps=num_steps, + sigma=sigma, + harmonics_dim=args.harmonics_dim, + t_emb_dim=args.t_emb_dim, + s_emb_dim=args.s_emb_dim, + hidden_dim=args.hidden_dim, + joint_layers=args.joint_layers, + zero_init=args.zero_init, + learn_variance=args.learn_var, + clipping=args.clipping, + gfn_clip=args.gfn_clip, + t_scale=args.t_scale, + log_var_range=args.log_var_range, + device=device, + ) + + def get_debug_metrics(estimator: torch.nn.Module) -> tuple[torch.Tensor, bool]: """Compute gradient norm for a module; return (total_norm, has_nan).""" grad_list = [p.grad.norm() for p in estimator.parameters() if p.grad is not None] @@ -164,21 +190,13 @@ def pretrain_prior(args: argparse.Namespace, device: torch.device, s_dim: int) - ) pf_prior = build_forward_estimator( - s_dim=s_dim, - num_steps=args.pretrain_num_steps, - sigma=args.pretrain_sigma, - harmonics_dim=args.harmonics_dim, - t_emb_dim=args.t_emb_dim, - s_emb_dim=args.s_emb_dim, - hidden_dim=args.hidden_dim, - joint_layers=args.joint_layers, - zero_init=args.zero_init, - learn_variance=args.learn_variance, - clipping=args.clipping, - gfn_clip=args.gfn_clip, - t_scale=args.t_scale, - log_var_range=args.log_var_range, - device=device, + **forward_kwargs( + args, + s_dim=s_dim, + num_steps=args.pretrain_num_steps, + sigma=args.pretrain_sigma, + device=device, + ) ) # Build backward estimator: learned pb if enabled, else fixed Brownian bridge. @@ -195,9 +213,9 @@ def pretrain_prior(args: argparse.Namespace, device: torch.device, s_dim: int) - gfn_clip=args.gfn_clip, pb_scale_range=args.pb_scale_range, log_var_range=args.log_var_range, - learn_variance=args.learn_variance, + learn_variance=args.learn_var, ) - n_var_outputs = 1 if args.learn_variance else 0 + n_var_outputs = 1 if args.learn_var else 0 pb_scale_range = args.pb_scale_range else: pb_module = DiffusionFixedBackwardModule(s_dim) @@ -230,7 +248,7 @@ def pretrain_prior(args: argparse.Namespace, device: torch.device, s_dim: int) - sigma=args.pretrain_sigma, t_scale=args.t_scale, pb_scale_range=args.pb_scale_range, - learn_variance=args.learn_variance, + learn_variance=args.learn_var, debug=__debug__, ) @@ -359,40 +377,16 @@ def main(args: argparse.Namespace) -> None: # Posterior forward (trainable) pf_post = build_forward_estimator( - s_dim=s_dim, - num_steps=args.num_steps, - sigma=args.sigma, - harmonics_dim=args.harmonics_dim, - t_emb_dim=args.t_emb_dim, - s_emb_dim=args.s_emb_dim, - hidden_dim=args.hidden_dim, - joint_layers=args.joint_layers, - zero_init=args.zero_init, - learn_variance=args.learn_variance, - clipping=args.clipping, - gfn_clip=args.gfn_clip, - t_scale=args.t_scale, - log_var_range=args.log_var_range, - device=device, + **forward_kwargs( + args, s_dim=s_dim, num_steps=args.num_steps, sigma=args.sigma, device=device + ) ) # Prior forward. pf_prior = build_forward_estimator( - s_dim=s_dim, - num_steps=args.num_steps, - sigma=args.sigma, - harmonics_dim=args.harmonics_dim, - t_emb_dim=args.t_emb_dim, - s_emb_dim=args.s_emb_dim, - hidden_dim=args.hidden_dim, - joint_layers=args.joint_layers, - zero_init=args.zero_init, - learn_variance=args.learn_variance, - clipping=args.clipping, - gfn_clip=args.gfn_clip, - t_scale=args.t_scale, - log_var_range=args.log_var_range, - device=device, + **forward_kwargs( + args, s_dim=s_dim, num_steps=args.num_steps, sigma=args.sigma, device=device + ) ) # Pretrain prior if needed, then load weights into both prior and posterior so @@ -478,171 +472,75 @@ def main(args: argparse.Namespace) -> None: if __name__ == "__main__": + + def add_arg_group( + parser: argparse.ArgumentParser, + specs: list[tuple[tuple[str, ...], dict]], + ) -> None: + for args, kwargs in specs: + parser.add_argument(*args, **kwargs) + + # fmt: off parser = argparse.ArgumentParser() - # System - parser.add_argument( - "--device", - type=str, - default="cpu", - choices=["cpu", "cuda", "mps"], - help="Device for training.", - ) - parser.add_argument("--seed", type=int, default=0, help="Random seed") - - # Target / environment - parser.add_argument( - "--target", - type=str, - default="gmm25_posterior9", - help="Diffusion target (default: gmm25_posterior9)", - ) - parser.add_argument( - "--num_steps", - type=int, - default=100, - help="number of discretization steps (reference=100)", - ) - parser.add_argument( - "--sigma", - type=float, - default=2.0, - help="diffusion coefficient for the pinned Brownian motion", - ) + system_args = [ + (("--device",), {"type": str, "default": "cpu", "choices": ["cpu", "cuda", "mps"], "help": "Device for training."}), + (("--seed",), {"type": int, "default": 0, "help": "Random seed"}), + ] - # Model (DiffusionPISGradNetForward) - parser.add_argument("--harmonics_dim", type=int, default=64) - parser.add_argument("--t_emb_dim", type=int, default=64) - parser.add_argument("--s_emb_dim", type=int, default=64) - parser.add_argument("--hidden_dim", type=int, default=64) - parser.add_argument("--joint_layers", type=int, default=2) - parser.add_argument("--zero_init", action="store_true", default=True) - parser.add_argument( - "--learn_variance", - action="store_true", - default=False, - help="Use learned scalar variance in the diffusion forward policy (ref default: off)", - ) - parser.add_argument( - "--clipping", - action=argparse.BooleanOptionalAction, - default=False, - help="Clip model outputs (reference default: off)", - ) - parser.add_argument( - "--gfn_clip", - type=float, - default=1e4, - help="Clipping value for drift outputs (reference: 1e4)", - ) - parser.add_argument( - "--t_scale", - type=float, - default=5.0, - help="Scale diffusion std to mirror reference (reference: 5.0)", - ) - parser.add_argument( - "--log_var_range", - type=float, - default=4.0, - help="Range to bound learned log-std when learn_variance is enabled (reference: 4.0)", - ) + finetune_args = [ + (("--target",), {"type": str, "default": "gmm25_posterior9", "help": "Diffusion target"}), + (("--num_steps",), {"type": int, "default": 100, "help": "Discretization steps"}), + (("--sigma",), {"type": float, "default": 2.0, "help": "Pinned Brownian motion sigma"}), + (("--harmonics_dim",), {"type": int, "default": 64}), + (("--t_emb_dim",), {"type": int, "default": 64}), + (("--s_emb_dim",), {"type": int, "default": 64}), + (("--hidden_dim",), {"type": int, "default": 64}), + (("--joint_layers",), {"type": int, "default": 2}), + (("--zero_init",), {"action": "store_true", "default": True}), + (("--learn_var",), {"action": "store_true", "default": False, "help": "Learned variance"}), + (("--clipping",), {"action": argparse.BooleanOptionalAction, "default": False, "help": "Clip model outputs"}), + (("--gfn_clip",), {"type": float, "default": 1e4, "help": "Drift clip value"}), + (("--t_scale",), {"type": float, "default": 5.0, "help": "Diffusion std scale"}), + (("--log_var_range",), {"type": float, "default": 4.0, "help": "Bound for learned log-std"}), + ] - # Training - parser.add_argument("--n_iterations", type=int, default=5000) - parser.add_argument("--batch_size", type=int, default=500) - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--lr_logz", type=float, default=1e-1) - parser.add_argument("--beta", type=float, default=1.0, help="RTB beta multiplier") - # Exploration noise (state-space Gaussian added in quadrature to PF std) - parser.add_argument( - "--exploration_factor", - type=float, - default=0.5, - help="Base exploration std applied per step when exploratory is enabled (reference ~0.5)", - ) - parser.add_argument( - "--exploration_warm_down_start", - type=float, - default=500, - help="Linearly warm down exploration after n iters (to 0 by exploration_warm_down_end iters)", - ) - parser.add_argument( - "--exploration_warm_down_end", - type=float, - default=4500, - help="Linearly warm down exploration after n iters (to 0 by exploration_warm_down_end iters)", - ) + pretrain_args = [ + (("--clobber_pretrained_prior",), {"action": "store_true", "default": False, "help": "Overwrite existing prior"}), + (("--pretrain_learn_pb",), {"action": "store_true", "default": False, "help": "Enable learned backward policy"}), + (("--pb_scale_range",), {"type": float, "default": 0.1, "help": "Tanh scaling for pb"}), + (("--pretrain_target",), {"type": str, "default": "gmm25_prior", "help": "Target used for pretraining"}), + (("--pretrain_num_steps",), {"type": int, "default": 100, "help": "Pretrain discretization steps"}), + (("--pretrain_sigma",), {"type": float, "default": 2.0, "help": "Pretrain diffusion sigma"}), + (("--pretrain_exploration_factor",), {"type": float, "default": 0.0, "help": "Pretrain std expansion"}), + (("--pretrain_steps",), {"type": int, "default": 10000, "help": "Pretrain steps"}), + (("--pretrain_log_interval",), {"type": int, "default": 100, "help": "Pretrain log interval"}), + (("--pretrain_vis_n",), {"type": int, "default": 2000, "help": "Pretrain samples to plot"}), + ] - # Logging / eval - parser.add_argument("--log_interval", type=int, default=100) - parser.add_argument("--eval_n", type=int, default=500) - parser.add_argument( - "--vis_n", type=int, default=2000, help="Number of samples for final plot" - ) - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Base output dir (resolved relative to this script)", - ) + train_args = [ + (("--n_iterations",), {"type": int, "default": 5000}), + (("--batch_size",), {"type": int, "default": 500}), + (("--lr",), {"type": float, "default": 1e-3}), + (("--lr_logz",), {"type": float, "default": 1e-1}), + (("--weight_decay",), {"type": float, "default": 0.0, "help": "Weight decay"}), + (("--beta",), {"type": float, "default": 1.0, "help": "RTB beta"}), + (("--exploration_factor",), {"type": float, "default": 0.5, "help": "Step-wise std expansion"}), + (("--exploration_warm_down_start",), {"type": float, "default": 500, "help": "Warmdown start iter"}), + (("--exploration_warm_down_end",), {"type": float, "default": 4500, "help": "Warmdown end iter"}), + ] - # Prior pretraining / loading - parser.add_argument( - "--clobber_pretrained_prior", - action="store_true", - default=False, - help="Overwrite existing prior checkpoint and re-run pretraining", - ) - parser.add_argument( - "--pretrain_learn_pb", - action="store_true", - default=False, - help="Enable learned backward policy corrections (pb) during pretrain", - ) - parser.add_argument( - "--pb_scale_range", - type=float, - default=0.1, - help="Tanh scaling for backward mean/var corrections (reference: 0.1)", - ) - parser.add_argument( - "--pretrain_target", - type=str, - default="gmm25_prior", - help="Target used for prior pretraining (matches reference prior)", - ) - parser.add_argument( - "--pretrain_num_steps", - type=int, - default=100, - help="Discretization steps for prior pretraining (reference=100)", - ) - parser.add_argument( - "--pretrain_sigma", - type=float, - default=2.0, - help="Diffusion coefficient for prior pretraining", - ) - parser.add_argument( - "--pretrain_exploration_factor", - type=float, - default=0.0, - help="Exploration std for pretrain backward MLE (reference: off by default)", - ) - parser.add_argument( - "--pretrain_steps", - type=int, - default=10000, - help="Training steps for prior pretraining", - ) + log_args = [ + (("--log_interval",), {"type": int, "default": 100}), + (("--eval_n",), {"type": int, "default": 500}), + (("--vis_n",), {"type": int, "default": 2000, "help": "Samples for final plot"}), + (("--output_dir",), {"type": str, "default": "output", "help": "relative output dir"}), + ] - # Optimizer extras - parser.add_argument( - "--weight_decay", - type=float, - default=0.0, - help="Weight decay for the RTB optimizer (policy/logZ)", - ) + add_arg_group(parser, system_args) + add_arg_group(parser, finetune_args) + add_arg_group(parser, pretrain_args) + add_arg_group(parser, train_args) + add_arg_group(parser, log_args) args = parser.parse_args() main(args) From 06ccc10ada87a0c23900a19dc4b6b84286a4d284 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 03:34:19 -0500 Subject: [PATCH 12/26] no change --- .flake8 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 29944067..968d55a9 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,4 @@ ignore = E203, E266, E501, W503, F403, F401, F821 max-line-length = 89 max-complexity = 18 -select = B,C,E,F,W,T4,B9 \ No newline at end of file +select = B,C,E,F,W,T4,B9 From bfbbc22f5dc6a54e363e9d014ac1ef4fc18d0bc9 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 19:24:02 -0500 Subject: [PATCH 13/26] fixed backward bug --- .flake8 | 2 +- pyproject.toml | 4 +- src/gfn/env.py | 96 +++------ src/gfn/estimators.py | 76 +++---- src/gfn/gflownet/mle.py | 86 +++++--- src/gfn/gym/diffusion_sampling.py | 137 ++++++------- testing/gflownet/test_mle_diffusion.py | 16 +- testing/test_environments.py | 232 ++++++++++++---------- tutorials/examples/test_scripts.py | 46 ++--- tutorials/examples/train_diffusion_rtb.py | 10 +- 10 files changed, 348 insertions(+), 357 deletions(-) diff --git a/.flake8 b/.flake8 index 968d55a9..bd121e67 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] ignore = E203, E266, E501, W503, F403, F401, F821 -max-line-length = 89 +max-line-length = 100 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/pyproject.toml b/pyproject.toml index b226cba3..e6ebe9d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,7 @@ all = [ [tool.black] target_version = ["py310"] -line_length = 89 +line_length = 100 include = '\.pyi?$' extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g''' @@ -180,7 +180,7 @@ commands = pytest -s # Black-compatibility enforced. [tool.isort] profile = "black" -line_length = 89 +line_length = 100 multi_line_output = 3 include_trailing_comma = true force_grid_wrap = 0 diff --git a/src/gfn/env.py b/src/gfn/env.py index fc88ddf3..0b5e096e 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -69,9 +69,7 @@ def __init__( if sf is None: assert isinstance(s0, torch.Tensor), "When sf is None, s0 must be a Tensor" - sf = torch.full( - s0.shape, default_fill_value_for_dtype(s0.dtype), dtype=s0.dtype - ) + sf = torch.full(s0.shape, default_fill_value_for_dtype(s0.dtype), dtype=s0.dtype) self.sf = sf.to(s0.device) # pyright: ignore - torch_geometric type hint fix assert self.s0.shape == self.sf.shape == state_shape @@ -144,9 +142,7 @@ def actions_from_batch_shape(self, batch_shape: Tuple) -> Actions: Returns: A batch of dummy actions. """ - return self.Actions.make_dummy_actions( - batch_shape, device=self.device, debug=self.debug - ) + return self.Actions.make_dummy_actions(batch_shape, device=self.device, debug=self.debug) @abstractmethod def step(self, states: States, actions: Actions) -> States: @@ -198,9 +194,7 @@ def is_action_valid( True if all actions are valid in the given states, False otherwise. """ - def make_random_states( - self, batch_shape: Tuple, device: torch.device | None = None - ) -> States: + def make_random_states(self, batch_shape: Tuple, device: torch.device | None = None) -> States: """Optional method to return a batch of random states. Args: @@ -280,9 +274,7 @@ def reset( batch_shape = (batch_shape,) elif isinstance(batch_shape, list): batch_shape = tuple(batch_shape) - return self.states_from_batch_shape( - batch_shape=batch_shape, random=random, sink=sink - ) + return self.states_from_batch_shape(batch_shape=batch_shape, random=random, sink=sink) def _step(self, states: States, actions: Actions) -> States: """Wrapper for the user-defined `step` function. @@ -301,9 +293,7 @@ def _step(self, states: States, actions: Actions) -> States: if self.debug: # Debug-only guards to avoid graph breaks in compiled runs. assert states.batch_shape == actions.batch_shape - assert ( - len(states.batch_shape) == 1 - ), "Batch shape must be 1 for the step method." + assert len(states.batch_shape) == 1, "Batch shape must be 1 for the step method." valid_states_idx: torch.Tensor = ~states.is_sink_state if self.debug: @@ -322,7 +312,7 @@ def _step(self, states: States, actions: Actions) -> States: # We only step on states that are not sink states. # Note that exit actions directly set the states to the sink state, so they # are not included in the valid_states_idx. - new_valid_states_idx = valid_states_idx & ~actions.is_exit + new_valid_states_idx = valid_states_idx & ~actions.is_exit # boolean mask. # IMPORTANT: .clone() is used to ensure that the new states are a # distinct object from the old states. This is important for the sampler to @@ -330,7 +320,7 @@ def _step(self, states: States, actions: Actions) -> States: # method in your custom environment, you must ensure that the `new_states` # returned is a distinct object from the submitted states. not_done_states = states[new_valid_states_idx].clone() - not_done_actions = actions[new_valid_states_idx] + not_done_actions = actions[new_valid_states_idx] # NOTE: boolean indexing creates a copy! not_done_states = self.step(not_done_states, not_done_actions) assert isinstance( @@ -341,9 +331,7 @@ def _step(self, states: States, actions: Actions) -> States: # For the indices where the new states are not sink states (i.e., where the # state is not already a sink and the action is not exit), update those # positions with the result of the environment's step function. - new_states = self.States.make_sink_states( - states.batch_shape, device=states.device - ) + new_states = self.States.make_sink_states(states.batch_shape, device=states.device) new_states[new_valid_states_idx] = not_done_states return new_states @@ -378,9 +366,7 @@ def _backward_step(self, states: States, actions: Actions) -> States: valid_actions = actions[valid_states_idx] valid_states = new_states[valid_states_idx] - if self.debug and not self.is_action_valid( - valid_states, valid_actions, backward=True - ): + if self.debug and not self.is_action_valid(valid_states, valid_actions, backward=True): raise NonValidActionsError( "Some actions are not valid in the given states. See `is_action_valid`." ) @@ -423,9 +409,7 @@ def log_partition(self) -> float: Returns: The log partition function. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") @property def true_dist(self) -> torch.Tensor: @@ -434,9 +418,7 @@ def true_dist(self) -> torch.Tensor: Returns: The true distribution as a 1-dimensional tensor. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") class DiscreteEnv(Env, ABC): @@ -706,9 +688,7 @@ def get_terminating_state_dist(self, states: DiscreteStates) -> torch.Tensor: A 1D tensor of shape `(n_terminating_states,)` with empirical frequencies. """ try: - states_indices = ( - self.get_terminating_states_indices(states).cpu().numpy().tolist() - ) + states_indices = self.get_terminating_states_indices(states).cpu().numpy().tolist() except NotImplementedError as e: warnings.warn( "Environment does not implement state enumeration required for\n" @@ -737,9 +717,7 @@ def get_terminating_state_dist(self, states: DiscreteStates) -> torch.Tensor: "No terminating states provided to compute empirical distribution.", UserWarning, ) - return torch.zeros( - (self.n_terminating_states,), dtype=torch.get_default_dtype() - ) + return torch.zeros((self.n_terminating_states,), dtype=torch.get_default_dtype()) return torch.tensor(counter_list, dtype=torch.get_default_dtype()) / denom @@ -798,23 +776,17 @@ def validate( ) assert isinstance(sampled_terminating_states, DiscreteStates) else: - sampled_terminating_states = visited_terminating_states[ - -n_validation_samples: - ] + sampled_terminating_states = visited_terminating_states[-n_validation_samples:] # Compute empirical distribution; may require enumeration support. try: - final_states_dist = self.get_terminating_state_dist( - sampled_terminating_states - ) + final_states_dist = self.get_terminating_state_dist(sampled_terminating_states) except NotImplementedError: # Already warned in helper; return gracefully. return {}, sampled_terminating_states if final_states_dist.numel() == 0: - warnings.warn( - "Empirical distribution is empty (no terminating samples).", UserWarning - ) + warnings.warn("Empirical distribution is empty (no terminating samples).", UserWarning) return {}, sampled_terminating_states l1_dist = (final_states_dist - true_dist).abs().mean().item() @@ -822,9 +794,7 @@ def validate( # Report logZ difference if both sides are available. learned_logZ: float | None = None - if hasattr(gflownet, "logZ") and isinstance( - getattr(gflownet, "logZ"), torch.Tensor - ): + if hasattr(gflownet, "logZ") and isinstance(getattr(gflownet, "logZ"), torch.Tensor): learned_logZ = float(getattr(gflownet, "logZ").item()) if learned_logZ is not None and true_logZ is not None: validation_info["logZ_diff"] = abs(learned_logZ - true_logZ) @@ -840,9 +810,7 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: Returns: Tensor of shape (*batch_shape) containing the indices of the states. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: """Optional method to return the indices of the terminating states in the @@ -855,9 +823,7 @@ def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor Tensor of shape (*batch_shape) containing the indices of the terminating states. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") @property def n_states(self) -> int: @@ -866,9 +832,7 @@ def n_states(self) -> int: Returns: The number of states. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") @property def n_terminating_states(self) -> int: @@ -877,9 +841,7 @@ def n_terminating_states(self) -> int: Returns: The number of terminating states. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") @property def all_states(self) -> DiscreteStates: @@ -892,9 +854,7 @@ def all_states(self) -> DiscreteStates: self.get_states_indices(self.all_states) and torch.arange(self.n_states) should be equivalent. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") @property def terminating_states(self) -> DiscreteStates: @@ -907,9 +867,7 @@ def terminating_states(self) -> DiscreteStates: self.get_terminating_states_indices(self.terminating_states) and torch.arange(self.n_terminating_states) should be equivalent. """ - raise NotImplementedError( - "The environment does not support enumeration of states" - ) + raise NotImplementedError("The environment does not support enumeration of states") class GraphEnv(Env): @@ -971,12 +929,8 @@ def __init__( self.States = self.make_states_class() self.Actions = self.make_actions_class() - self.dummy_action = self.Actions.make_dummy_actions( - (1,), device=self.device - ).tensor - self.exit_action = self.Actions.make_exit_actions( - (1,), device=self.device - ).tensor + self.dummy_action = self.Actions.make_dummy_actions((1,), device=self.device).tensor + self.exit_action = self.Actions.make_exit_actions((1,), device=self.device).tensor @property def device(self) -> torch.device: diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index b7c6f9e0..e9db179b 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -27,6 +27,12 @@ "prod": torch.prod, } +# Relative tolerance for detecting terminal time in diffusion estimators. +# Must match TERMINAL_TIME_EPS in gfn.gym.diffusion_sampling to ensure consistent +# exit action detection between the estimator and environment. TODO: we should handle this +# centrally somewhere. +_DIFFUSION_TERMINAL_TIME_EPS = 1e-2 + class RolloutContext: """Structured per‑rollout state owned by estimators. @@ -124,9 +130,7 @@ def init_context( initializes empty buffers for per-step artifacts. """ - return RolloutContext( - batch_size=batch_size, device=device, conditions=conditions - ) + return RolloutContext(batch_size=batch_size, device=device, conditions=conditions) def compute_dist( self, @@ -184,9 +188,7 @@ def compute_dist( estimator_outputs = self(states_active) # type: ignore[misc] # Build the distribution. - dist = self.to_probability_distribution( - states_active, estimator_outputs, **policy_kwargs - ) + dist = self.to_probability_distribution(states_active, estimator_outputs, **policy_kwargs) # Save current estimator output only when requested. if save_estimator_outputs: @@ -626,26 +628,20 @@ def _compute_logits_for_distribution( assert not torch.isnan(logits).any(), "Module output logits contain NaNs" # Prepare logits first (masking, bias, temperature) in the existing dtype - x = LogitBasedEstimator._prepare_logits( - logits, masks, sf_index, sf_bias, temperature - ) + x = LogitBasedEstimator._prepare_logits(logits, masks, sf_index, sf_bias, temperature) assert not torch.isnan(x).any(), "Prepared logits contain NaNs" # Perform numerically sensitive ops in float32 when inputs are low-precision orig_dtype = x.dtype compute_dtype = ( - torch.float32 - if orig_dtype in (torch.float16, torch.bfloat16) - else orig_dtype + torch.float32 if orig_dtype in (torch.float16, torch.bfloat16) else orig_dtype ) assert torch.isfinite(x).any(dim=-1).all(), "All -inf row before log-softmax" lsm = torch.log_softmax(x.to(compute_dtype), dim=-1) - assert ( - torch.isfinite(lsm).any(dim=-1).all() - ), "Invalid log-probs after log_softmax" + assert torch.isfinite(lsm).any(dim=-1).all(), "Invalid log-probs after log_softmax" if epsilon == 0.0: return lsm.to(orig_dtype) if lsm.dtype != orig_dtype else lsm @@ -903,9 +899,9 @@ def __init__( preprocessor=preprocessor, is_backward=False, ) - assert ( - reduction in REDUCTION_FUNCTIONS - ), "reduction function not one of {}".format(REDUCTION_FUNCTIONS.keys()) + assert reduction in REDUCTION_FUNCTIONS, "reduction function not one of {}".format( + REDUCTION_FUNCTIONS.keys() + ) self.reduction_function = REDUCTION_FUNCTIONS[reduction] def forward(self, states: States, conditions: torch.Tensor) -> torch.Tensor: @@ -1072,16 +1068,14 @@ def to_probability_distribution( ) # Logit transformations allow for off-policy exploration. - transformed_logits[key] = ( - LogitBasedEstimator._compute_logits_for_distribution( - logits=local_logits, - masks=local_masks, - # ACTION_TYPE_KEY contains the exit action logit. - sf_index=GaType.EXIT if key == Ga.ACTION_TYPE_KEY else None, - sf_bias=sf_bias if key == Ga.ACTION_TYPE_KEY else 0.0, - temperature=temperature[key], - epsilon=epsilon[key], - ) + transformed_logits[key] = LogitBasedEstimator._compute_logits_for_distribution( + logits=local_logits, + masks=local_masks, + # ACTION_TYPE_KEY contains the exit action logit. + sf_index=GaType.EXIT if key == Ga.ACTION_TYPE_KEY else None, + sf_bias=sf_bias if key == Ga.ACTION_TYPE_KEY else 0.0, + temperature=temperature[key], + epsilon=epsilon[key], ) return GraphActionDistribution( @@ -1182,9 +1176,7 @@ def forward( # Replace padding (-1) with BOS index expected by the sequence model. # RecurrentDiscreteSequenceModel reserves index == vocab_size for BOS. bos_index = getattr(self.module, "vocab_size", self.n_actions - 1) - tokens = torch.where( - tokens < 0, torch.as_tensor(bos_index, device=tokens.device), tokens - ) + tokens = torch.where(tokens < 0, torch.as_tensor(bos_index, device=tokens.device), tokens) # Determine a common prefix length across the (active) batch. # Active rows in a rollout step share the same length; use max for safety. @@ -1220,9 +1212,7 @@ def init_carry( ) -> dict[str, torch.Tensor]: init_carry = getattr(self.module, "init_carry", None) if not callable(init_carry): - raise NotImplementedError( - "Module does not implement init_carry(batch_size, device)." - ) + raise NotImplementedError("Module does not implement init_carry(batch_size, device).") init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) return init_carry_fn(batch_size, device) @@ -1362,9 +1352,18 @@ def to_probability_distribution( # s_curr = states.tensor[:, :-1] t_curr = states.tensor[:, [-1]] + # Check if the NEXT step would reach terminal time, not if we're already there. + # This matches the exit condition in DiffusionSampling.step() and ensures the + # sampled action is marked as an exit action (-inf) so trajectory masks align + # correctly in get_trajectory_pbs. + eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS + is_final_step = (t_curr + self.dt) >= (1.0 - eps) + # TODO: The old code followed this convention (below). I believe the change + # is slightly more correct, but I'd like to check this during review. + # (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0 + module_output = torch.where( - (1.0 - t_curr) < self.dt * 1e-2, # sf case; when t_curr is 1.0 - # torch.full_like(s_curr, -float("inf")), # This is the exit action + is_final_step, torch.full_like(module_output, -float("inf")), # This is the exit action module_output, ) @@ -1483,11 +1482,12 @@ def to_probability_distribution( # Analytic Brownian bridge base # Brownian bridge mean toward 0 at t=0: # E[s_{t-dt} | s_t] = s_t * (1 - dt / t) and collapses to 0 at the start. - # Shapes: s_curr (batch, s_dim), t_curr (batch, 1), dt is scalar. + # Here, we calculcate the *action* which moves the state in expectation toward 0 + # at t=0, so we scale s_curr by our distance to t=0. base_mean = torch.where( is_s0, torch.zeros_like(s_curr), - s_curr * (1.0 - self.dt / t_curr), + s_curr * self.dt / t_curr, # s_curr (batch, s_dim), t_curr (batch, 1), dt is scalar. ) base_std = torch.where( is_s0, diff --git a/src/gfn/gflownet/mle.py b/src/gfn/gflownet/mle.py index 4dc9523e..37946b49 100644 --- a/src/gfn/gflownet/mle.py +++ b/src/gfn/gflownet/mle.py @@ -44,6 +44,14 @@ import torch +try: # torch._dynamo may be absent or flagged private by linters + from torch._dynamo import disable as dynamo_disable +except Exception: # pragma: no cover + + def dynamo_disable(fn): # type: ignore[return-type] + return fn + + from gfn.env import Env from gfn.estimators import ( PinnedBrownianMotionBackward, @@ -54,6 +62,11 @@ from gfn.states import States from gfn.utils.modules import DiffusionFixedBackwardModule +# Relative tolerance for detecting initial/terminal states in diffusion trajectories. +# Must be synchronized with TERMINAL_TIME_EPS in gfn.gym.diffusion_sampling and +# _DIFFUSION_TERMINAL_TIME_EPS in gfn.estimators. +_DIFFUSION_TERMINAL_TIME_EPS = 1e-2 + class MLEDiffusion(GFlowNet): """ @@ -126,7 +139,9 @@ def to_training_samples(self, trajectories): def loss( self, + env: Env, terminal_states: Any, + recalculate_all_logprobs: bool = True, *, exploration_std: float | torch.Tensor = 0.0, ) -> torch.Tensor: @@ -139,44 +154,52 @@ def loss( Returns: Scalar loss (mean reduction). """ + del env # unused + del recalculate_all_logprobs # unused device, dtype, s_curr = self._extract_samples(terminal_states) bsz, dim = s_curr.shape assert dim == self.s_dim, f"Expected s_dim={self.s_dim}, got {dim}" dt = self.dt - base_std_fixed = self.sigma * math.sqrt(dt) * math.sqrt(self.t_scale) + + # Tolerance for detecting initial state (t ≈ 0). Uses the module-level constant + # which must stay synchronized with TERMINAL_TIME_EPS in diffusion_sampling.py + # and _DIFFUSION_TERMINAL_TIME_EPS in estimators.py. + eps_s0 = dt * _DIFFUSION_TERMINAL_TIME_EPS + + sqrt_dt_t_scale = math.sqrt(dt * self.t_scale) + base_std_fixed = self.sigma * sqrt_dt_t_scale log_2pi = math.log(2 * math.pi) logpf_sum = torch.zeros(bsz, device=device, dtype=dtype) - exploration_std_t = torch.as_tensor( - exploration_std, device=device, dtype=dtype - ).clamp(min=0.0) + exploration_std_t = torch.as_tensor(exploration_std, device=device, dtype=dtype).clamp( + min=0.0 + ) + exploration_var = exploration_std_t**2 + + # Precompute time grids to avoid per-step allocations. + all_t_fwd = torch.linspace(1.0 - dt, 0.0, self.num_steps, device=device, dtype=dtype) + all_t_curr = torch.linspace(1.0, dt, self.num_steps, device=device, dtype=dtype) for i in range(self.num_steps): # Times: forward transition index t_fwd corresponds to s_prev -> s_curr. - t_fwd = torch.full((bsz, 1), 1.0 - (i + 1) * dt, device=device, dtype=dtype) - t_curr = torch.full((bsz, 1), 1.0 - i * dt, device=device, dtype=dtype) + t_fwd = all_t_fwd[i].expand(bsz, 1) + t_curr = all_t_curr[i].expand(bsz, 1) # Backward sampler: Brownian bridge base + optional PB corrections. pb_inp = torch.cat([s_curr, t_curr], dim=1) pb_out = self.pb.module(pb_inp) - is_s0 = (t_curr - dt) < dt * 1e-2 # Base Brownian bridge mean/std toward 0 at t=0. - base_mean = torch.where( - is_s0, - torch.zeros_like(s_curr), - s_curr * (1.0 - dt / t_curr), - ) - base_std = torch.where( - is_s0, - torch.zeros_like(t_curr), - self.sigma * (dt * (t_curr - dt) / t_curr).sqrt(), - ) + is_s0 = (t_curr - dt) < eps_s0 + not_s0 = (~is_s0).float() + + base_mean = s_curr * (1.0 - dt / t_curr) * not_s0 + base_std = self.sigma * (dt * (t_curr - dt) / t_curr).sqrt() * not_s0 # Learned corrections (PB): mean_corr, optional log-std corr. mean_corr = pb_out[..., :dim] * self.pb.pb_scale_range - if pb_out.shape[-1] == dim + 1: + if self.pb.n_variance_outputs > 0: log_std_corr = pb_out[..., [-1]] * self.pb.pb_scale_range corr_std = torch.exp(log_std_corr) else: @@ -192,27 +215,20 @@ def loss( increment = s_curr - s_prev # Case where module outputs learned variance. - if module_out.shape[-1] == dim + 1: + if self.pf.n_variance_outputs > 0: drift = module_out[..., :dim] log_std = module_out[..., [-1]] - std = torch.exp(log_std) * math.sqrt(dt) * math.sqrt(self.t_scale) - if exploration_std_t.item() > 0: - std = torch.sqrt(std**2 + exploration_std_t**2) + std = torch.exp(log_std) * sqrt_dt_t_scale + std = torch.sqrt(std**2 + exploration_var) diff = increment - dt * drift - logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum( - dim=1 - ) + logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum(dim=1) # Fixed variance case. else: drift = module_out - std = base_std_fixed - if exploration_std_t.item() > 0: - std = math.sqrt( - base_std_fixed**2 + float(exploration_std_t.item()) ** 2 - ) + std = torch.sqrt(base_std_fixed**2 + exploration_var) diff = increment - dt * drift logpf_step = -0.5 * ((diff / std) ** 2).sum(dim=1) - 0.5 * dim * ( - log_2pi + 2 * math.log(std) + log_2pi + 2 * torch.log(std) ) logpf_sum += logpf_step @@ -223,8 +239,16 @@ def loss( # TODO: Use included loss reduction helpers. loss = -(logpf_sum.mean() if self.reduction == "mean" else logpf_sum.sum()) + if self.debug: + self._assert_no_nan(logpf_sum) return loss + @dynamo_disable + def _assert_no_nan(self, logpf_sum: torch.Tensor) -> None: + if torch.isnan(logpf_sum).any(): + raise ValueError("NaNs in logpf_sum during MLE loss.") + + @dynamo_disable def _extract_samples( self, terminal_states: Any ) -> tuple[torch.device, torch.dtype, torch.Tensor]: diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index c780c5df..5d4cf23d 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -19,6 +19,14 @@ # Lightweight typing alias for the target registry entries. TargetEntry = tuple[type["BaseTarget"], dict[str, Any]] +# Relative tolerance (scaled by dt) for detecting initial/terminal states in diffusion +# trajectories. This ensures consistent boundary detection across the environment, +# estimators, and probability calculations. The tolerance is applied as: +# - Initial state: t < dt * TERMINAL_TIME_EPS +# - Terminal state: t >= 1.0 - dt * TERMINAL_TIME_EPS +# - Exit action trigger: t + dt >= 1.0 - dt * TERMINAL_TIME_EPS (next step reaches terminal) +TERMINAL_TIME_EPS = 1e-2 + ############################### ### Target energy functions ### @@ -195,9 +203,7 @@ def __init__( rng = np.random.default_rng(seed) if locs is None: - locs = rng.uniform( - mean_val_range[0], mean_val_range[1], size=(num_components, dim) - ) + locs = rng.uniform(mean_val_range[0], mean_val_range[1], size=(num_components, dim)) elif isinstance(locs, np.ndarray): assert locs.shape == (num_components, dim) assert (locs >= mean_val_range[0]).all() and ( @@ -219,12 +225,8 @@ def __init__( print("+ num_components: ", num_components) print("+ mixture_weights: ", mixture_weights) for i, (loc, cov) in enumerate(zip(locs, covariances)): - loc_str = np.array2string(loc, precision=2, separator=", ").replace( - "\n", " " - ) - cov_str = np.array2string(cov, precision=2, separator=", ").replace( - "\n", " " - ) + loc_str = np.array2string(loc, precision=2, separator=", ").replace("\n", " ") + cov_str = np.array2string(cov, precision=2, separator=", ").replace("\n", " ") print(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}") # Convert to torch tensors @@ -316,9 +318,7 @@ def visualize( assert self.plot_border is not None, "Visualization requires a plot border." if self.dim != 2: - raise ValueError( - f"Visualization is only supported for 2D, but got {self.dim}D" - ) + raise ValueError(f"Visualization is only supported for 2D, but got {self.dim}D") fig = plt.figure() ax = fig.add_subplot() @@ -343,24 +343,16 @@ def visualize( ax.contourf(x, y, pdf_values, levels=20) # , cmap='viridis') if samples is not None: plt.scatter( - samples[:max_n_samples, 0].clamp( - self.plot_border[0], self.plot_border[1] - ), - samples[:max_n_samples, 1].clamp( - self.plot_border[2], self.plot_border[3] - ), + samples[:max_n_samples, 0].clamp(self.plot_border[0], self.plot_border[1]), + samples[:max_n_samples, 1].clamp(self.plot_border[2], self.plot_border[3]), c="r", alpha=0.5, marker="x", ) # Add dashed lines at 0 - ax.axhline( - y=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="y=0" - ) - ax.axvline( - x=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="x=0" - ) + ax.axhline(y=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="y=0") + ax.axvline(x=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="x=0") # Add dashed lines at each mode modes = self.distribution.component_distribution.loc @@ -399,8 +391,8 @@ def visualize( if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - plt.savefig(f"viz/{prefix}simple_gmm.png") + os.makedirs("output", exist_ok=True) + plt.savefig(f"output/{prefix}simple_gmm.png") plt.close() @@ -423,16 +415,12 @@ def __init__( dtype=torch.get_default_dtype(), ) mix = D.Categorical( - probs=torch.full( - (self.locs.shape[0],), 1.0 / self.locs.shape[0], device=device - ) + probs=torch.full((self.locs.shape[0],), 1.0 / self.locs.shape[0], device=device) ) comp = D.Independent(D.Normal(self.locs, scale * torch.ones_like(self.locs)), 1) self.gmm = D.MixtureSameFamily(mix, comp) - super().__init__( - device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border - ) + super().__init__(device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border) def log_reward(self, x: torch.Tensor) -> torch.Tensor: return self.gmm.log_prob(x).flatten() @@ -471,8 +459,8 @@ def visualize( if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - fig.savefig(f"viz/{prefix}gmm25.png") + os.makedirs("output", exist_ok=True) + fig.savefig(f"output/{prefix}gmm25.png") plt.close() @@ -515,9 +503,7 @@ def __init__( comp = D.Independent(D.Normal(locs, scale * torch.ones_like(locs)), 1) self.posterior = D.MixtureSameFamily(mix, comp) - super().__init__( - device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border - ) + super().__init__(device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border) def log_reward(self, x: torch.Tensor) -> torch.Tensor: # r(x) = p_post(x) / p_prior(x) @@ -557,8 +543,8 @@ def visualize( if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - fig.savefig(f"viz/{prefix}posterior9of25.png") + os.makedirs("output", exist_ok=True) + fig.savefig(f"output/{prefix}posterior9of25.png") plt.close() @@ -586,9 +572,7 @@ def __init__( torch.tensor([0.0], device=device, dtype=dtype), torch.tensor([std], device=device, dtype=dtype), ) - super().__init__( - device=device, dim=dim, n_gt_xs=10_000, plot_border=10.0, seed=seed - ) + super().__init__(device=device, dim=dim, n_gt_xs=10_000, plot_border=10.0, seed=seed) def log_reward(self, x: torch.Tensor) -> torch.Tensor: """Log-density of Neal's funnel distribution. @@ -604,9 +588,7 @@ def log_reward(self, x: torch.Tensor) -> torch.Tensor: log_sigma = 0.5 * x[:, 0:1] sigma2 = torch.exp(x[:, 0:1]) - neg_log_prob_other = ( - 0.5 * np.log(2 * np.pi) + log_sigma + 0.5 * x[:, 1:] ** 2 / sigma2 - ) + neg_log_prob_other = 0.5 * np.log(2 * np.pi) + log_sigma + 0.5 * x[:, 1:] ** 2 / sigma2 log_prob_other = torch.sum(-neg_log_prob_other, dim=-1) log_prob = log_prob_x0 + log_prob_other @@ -662,8 +644,8 @@ def visualize( if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - fig.savefig(f"viz/{prefix}funnel.png") + os.makedirs("output", exist_ok=True) + fig.savefig(f"output/{prefix}funnel.png") plt.close() @@ -686,9 +668,7 @@ def __init__( device: torch.device = torch.device("cpu"), seed: int = 0, ) -> None: - assert ( - dim % 2 == 0 - ), "ManyWellTarget requires an even dimension (pairs of coordinates)." + assert dim % 2 == 0, "ManyWellTarget requires an even dimension (pairs of coordinates)." # Simple mixture proposal for x1: 3 equally weighted Normals self.component_mix = torch.tensor([1 / 3, 1 / 3, 1 / 3], device=device) @@ -749,9 +729,7 @@ def _compute_envelope_k(self, proposal: D.Distribution) -> float: return float(1.2 * k) # small safety margin @staticmethod - def _rejection_sampling_x1( - n_samples: int, proposal: D.Distribution, k: float - ) -> torch.Tensor: + def _rejection_sampling_x1(n_samples: int, proposal: D.Distribution, k: float) -> torch.Tensor: # Basic rejection sampler with vectorized batches and refill loop collected: list[torch.Tensor] = [] remaining = n_samples @@ -822,8 +800,8 @@ def visualize( if show: plt.show() else: - os.makedirs("viz", exist_ok=True) - fig.savefig(f"viz/{prefix}manywell.png") + os.makedirs("output", exist_ok=True) + fig.savefig(f"output/{prefix}manywell.png") plt.close() @@ -932,7 +910,24 @@ def is_initial_state(self) -> torch.Tensor: When time is close enought to 0.0 (considering floating point errors), the state is s0. """ - return (self.tensor[..., -1] - 0.0) < env.dt * 1e-2 + eps = env.dt * TERMINAL_TIME_EPS + return self.tensor[..., -1] < eps + + @property + def is_sink_state(self) -> torch.Tensor: + """Return True when time is effectively 1.0 or the sink padding. + + We treat two cases as sink: + - Physical terminal time: t >= 1.0 - eps. + - Padding/exit sink states produced by `make_sink_states`, which use + non-finite sentinel values (e.g., -inf). Using non-finite check keeps + masks aligned for padded rows. + """ + time = self.tensor[..., -1] + eps = env.dt * TERMINAL_TIME_EPS + is_terminal_time = time >= (1.0 - eps) + is_padding_sink = ~torch.isfinite(time) + return is_terminal_time | is_padding_sink return DiffusionSamplingStates @@ -962,6 +957,19 @@ def step(self, states: States, actions: Actions) -> States: Returns: The next states. """ + if self.debug: + + eps = self.dt * TERMINAL_TIME_EPS + # Force exit when the next step would reach/exceed terminal time. + terminal_mask = (states.tensor[..., -1] + self.dt) >= (1.0 - eps) + if terminal_mask.any(): + raise AssertionError( + f"Estimator failed to output exit actions for {terminal_mask.sum().item()} " + f"states at terminal time. This will cause mask misalignment in " + f"get_trajectory_pbs(). Fix the estimator's exit condition to match " + f"TERMINAL_TIME_EPS={TERMINAL_TIME_EPS}." + ) + next_states_tensor = states.tensor.clone() next_states_tensor[..., :-1] = next_states_tensor[..., :-1] + actions.tensor next_states_tensor[..., -1] = next_states_tensor[..., -1] + self.dt @@ -998,15 +1006,16 @@ def is_action_valid( True if the actions are valid, False otherwise. """ time = states.tensor[..., -1].flatten()[0].item() - # TODO: support randomized discretization + eps = self.dt * TERMINAL_TIME_EPS + # TODO: support randomized discretization. assert ( states.tensor[..., -1] == time ).all(), "Time must be the same for all states in the batch" - if not backward and time == 1.0: # Terminate if time == 1.0 for forward steps + if not backward and time >= (1.0 - eps): # Terminate if near 1.0 for forward steps sf = cast(torch.Tensor, self.sf) return bool((actions.tensor == sf[:-1]).all().item()) - elif backward and time == 0.0: # Return to s0 if time == 0.0 for backward steps + elif backward and time <= eps: # Return to s0 when near 0.0 for backward steps s0 = cast(torch.Tensor, self.s0) return bool((actions.tensor == s0[:-1]).all().item()) else: @@ -1044,17 +1053,11 @@ def density_metrics( elbo = log_weights.mean().item() # EUBO, if the ground truth samples are available - if ( - bwd_log_rewards is not None - and bwd_log_pfs is not None - and bwd_log_pbs is not None - ): + if bwd_log_rewards is not None and bwd_log_pfs is not None and bwd_log_pbs is not None: gt_bsz = bwd_log_pfs.shape[1] assert gt_bsz == bwd_log_pbs.shape[1] == bwd_log_rewards.shape[0] assert bwd_log_pfs.ndim == bwd_log_pbs.ndim == 2 - eubo = ( - (bwd_log_rewards + bwd_log_pbs.sum(0) - bwd_log_pfs.sum(0)).mean().item() - ) + eubo = (bwd_log_rewards + bwd_log_pbs.sum(0) - bwd_log_pfs.sum(0)).mean().item() else: eubo = float("nan") diff --git a/testing/gflownet/test_mle_diffusion.py b/testing/gflownet/test_mle_diffusion.py index 2faccb36..d45755ef 100644 --- a/testing/gflownet/test_mle_diffusion.py +++ b/testing/gflownet/test_mle_diffusion.py @@ -4,8 +4,17 @@ from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward from gfn.gflownet.mle import MLEDiffusion +from gfn.gym.diffusion_sampling import DiffusionSampling from gfn.utils.modules import DiffusionFixedBackwardModule +ENV = DiffusionSampling( + target_str="gmm2", + target_kwargs=None, + num_discretization_steps=100, + device=torch.device("cpu"), + debug=True, +) + class ZeroDriftModule(torch.nn.Module): """Returns zero drift (and optional zero log-std if learn_variance).""" @@ -64,7 +73,7 @@ def test_mle_loss_fixed_variance_zero_terminal(): ) batch = torch.zeros(4, s_dim) # terminal states near (0,0) - loss = trainer.loss(batch, exploration_std=0.0) + loss = trainer.loss(ENV, batch, exploration_std=0.0) expected_logp = -0.5 * s_dim * math.log(2 * math.pi) # log p for zero increment expected_loss = -expected_logp # num_steps=1, loss = -logpf_sum.mean() @@ -89,9 +98,8 @@ def test_mle_loss_learned_variance_zero_terminal(): pb_scale_range=0.1, learn_variance=True, ) - batch = torch.zeros(3, s_dim) - loss = trainer.loss(batch, exploration_std=0.0) + loss = trainer.loss(ENV, batch, exploration_std=0.0) expected_logp = -0.5 * s_dim * math.log(2 * math.pi) expected_loss = -expected_logp @@ -159,6 +167,6 @@ def test_forward_logprob_zero_increment_matches_formula(): # std = exp(0) * sqrt(dt) * sqrt(t_scale) = 1; logp = -0.5 * s_dim * log(2π) expected_logp = -0.5 * s_dim * math.log(2 * math.pi) expected_loss = -expected_logp - loss = trainer.loss(batch, exploration_std=0.0) + loss = trainer.loss(ENV, batch, exploration_std=0.0) assert torch.isfinite(loss) assert torch.allclose(loss, torch.tensor(expected_loss), atol=1e-6) diff --git a/testing/test_environments.py b/testing/test_environments.py index 81f67428..55285e6e 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -7,12 +7,16 @@ from gfn.actions import GraphActions, GraphActionType from gfn.env import Env, NonValidActionsError +from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.diffusion_sampling import DiffusionSampling from gfn.gym.graph_building import GraphBuilding from gfn.gym.perfect_tree import PerfectBinaryTree from gfn.gym.set_addition import SetAddition from gfn.preprocessors import IdentityPreprocessor, KHotPreprocessor, OneHotPreprocessor +from gfn.samplers import Sampler from gfn.states import GraphStates +from gfn.utils.modules import DiffusionFixedBackwardModule, DiffusionPISGradNetForward # Utilities. @@ -138,9 +142,7 @@ def test_DiscreteEBM_fwd_step(): BATCH_SIZE = 4 env = DiscreteEBM(ndim=NDIM, debug=True) - states = env.reset( - batch_shape=BATCH_SIZE, seed=1234 - ) # Instantiate a batch of initial states + states = env.reset(batch_shape=BATCH_SIZE, seed=1234) # Instantiate a batch of initial states assert (states.batch_shape[0], states.state_shape[0]) == (BATCH_SIZE, NDIM) # Trying the step function starting from 3 instances of s_0 @@ -208,9 +210,7 @@ def test_box_fwd_step(delta: float): ] for failing_actions_list in failing_actions_lists_at_s0: - actions = env.actions_from_tensor( - format_tensor(failing_actions_list, discrete=False) - ) + actions = env.actions_from_tensor(format_tensor(failing_actions_list, discrete=False)) with pytest.raises(NonValidActionsError): states = env._step(states, actions) @@ -230,9 +230,7 @@ def test_box_fwd_step(delta: float): actions_tensor = torch.tensor([0.2, 0.3, 0.4]) * (B - A) + A actions_tensor *= np.pi / 2 actions_tensor = ( - torch.stack( - [torch.cos(actions_tensor), torch.sin(actions_tensor)], dim=1 - ) + torch.stack([torch.cos(actions_tensor), torch.sin(actions_tensor)], dim=1) * env.delta ) actions_tensor[B - A < 0] = torch.tensor([-float("inf"), -float("inf")]) @@ -350,9 +348,7 @@ def test_graph_env(): GraphActions.NODE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), - GraphActions.NODE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), @@ -402,9 +398,7 @@ def test_graph_env(): GraphActions.NODE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), - GraphActions.NODE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), GraphActions.EDGE_INDEX_KEY: torch.tensor([i] * BATCH_SIZE), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long @@ -418,21 +412,11 @@ def test_graph_env(): actions = action_cls.from_tensor_dict( TensorDict( { - GraphActions.ACTION_TYPE_KEY: torch.full( - (BATCH_SIZE,), GraphActionType.EXIT - ), - GraphActions.NODE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.NODE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.ACTION_TYPE_KEY: torch.full((BATCH_SIZE,), GraphActionType.EXIT), + GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), }, batch_size=BATCH_SIZE, ) @@ -451,12 +435,8 @@ def test_graph_env(): GraphActions.ACTION_TYPE_KEY: torch.full( (BATCH_SIZE,), GraphActionType.ADD_EDGE ), - GraphActions.NODE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.NODE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), @@ -475,12 +455,8 @@ def test_graph_env(): GraphActions.ACTION_TYPE_KEY: torch.full( (BATCH_SIZE,), GraphActionType.ADD_EDGE ), - GraphActions.NODE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.NODE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), @@ -501,16 +477,10 @@ def test_graph_env(): GraphActions.ACTION_TYPE_KEY: torch.full( (BATCH_SIZE,), GraphActionType.ADD_NODE ), - GraphActions.NODE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), GraphActions.NODE_INDEX_KEY: torch.tensor([i] * BATCH_SIZE), - GraphActions.EDGE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), }, batch_size=BATCH_SIZE, ) @@ -523,21 +493,11 @@ def test_graph_env(): actions = action_cls.from_tensor_dict( TensorDict( { - GraphActions.ACTION_TYPE_KEY: torch.full( - (BATCH_SIZE,), GraphActionType.ADD_NODE - ), - GraphActions.NODE_CLASS_KEY: torch.randint( - 0, 10, (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.NODE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.ACTION_TYPE_KEY: torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + GraphActions.NODE_CLASS_KEY: torch.randint(0, 10, (BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), }, batch_size=BATCH_SIZE, ) @@ -555,15 +515,9 @@ def test_graph_env(): GraphActions.NODE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), - GraphActions.NODE_INDEX_KEY: torch.ones( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.NODE_INDEX_KEY: torch.ones((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), }, batch_size=BATCH_SIZE, ) @@ -574,21 +528,11 @@ def test_graph_env(): actions = action_cls.from_tensor_dict( TensorDict( { - GraphActions.ACTION_TYPE_KEY: torch.full( - (BATCH_SIZE,), GraphActionType.ADD_NODE - ), - GraphActions.NODE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.NODE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_CLASS_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), - GraphActions.EDGE_INDEX_KEY: torch.zeros( - (BATCH_SIZE,), dtype=torch.long - ), + GraphActions.ACTION_TYPE_KEY: torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), }, batch_size=BATCH_SIZE, ) @@ -611,17 +555,13 @@ def test_set_addition_fwd_step(): # Add item 0 and 1 actions = env.actions_from_tensor(format_tensor([0, 1])) states = env._step(states, actions) - expected_states = torch.tensor( - [[1, 0, 0, 0], [0, 1, 0, 0]], dtype=torch.get_default_dtype() - ) + expected_states = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=torch.get_default_dtype()) assert torch.equal(states.tensor, expected_states) # Add item 2 and 3 actions = env.actions_from_tensor(format_tensor([2, 3])) states = env._step(states, actions) - expected_states = torch.tensor( - [[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.get_default_dtype() - ) + expected_states = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.get_default_dtype()) assert torch.equal(states.tensor, expected_states) # Try adding existing items (invalid) @@ -632,9 +572,7 @@ def test_set_addition_fwd_step(): # Add item 3 and 0 actions = env.actions_from_tensor(format_tensor([3, 0])) states = env._step(states, actions) - expected_states = torch.tensor( - [[1, 0, 1, 1], [1, 1, 0, 1]], dtype=torch.get_default_dtype() - ) + expected_states = torch.tensor([[1, 0, 1, 1], [1, 1, 0, 1]], dtype=torch.get_default_dtype()) assert torch.equal(states.tensor, expected_states) # Now has 3 items # Try adding another item (invalid, max_items reached) @@ -799,9 +737,7 @@ def step(self, states, actions): # pragma: no cover - not used in this test def backward_step(self, states, actions): # pragma: no cover - not used return states - def is_action_valid( - self, states, actions, backward: bool = False - ) -> bool: # noqa: ARG002 + def is_action_valid(self, states, actions, backward: bool = False) -> bool: # noqa: ARG002 return True @@ -893,3 +829,95 @@ def test_env_default_sf_bool_dtype(): assert env.sf.dtype == torch.bool assert isinstance(env.sf, torch.Tensor) assert torch.equal(env.sf, torch.zeros(state_shape, dtype=torch.bool)) + + +def test_diffusion_trajectory_mask_alignment(): + """Test that diffusion trajectory masks align correctly for PB calculation. + + This verifies that the estimator's exit action detection matches the environment's + terminal state detection, ensuring valid_states and valid_actions have the same + count in get_trajectory_pbs. A mismatch would cause an AssertionError. + + The key invariant is: for each trajectory step where we compute PB, we need + exactly one valid state (at t+1) and one valid action (at t). Exit actions + must be properly marked so they're excluded from the action mask. + """ + # Use small config for fast testing. + num_steps = 8 + batch_size = 16 + s_dim = 2 + + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 42}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + ) + + pf_module = DiffusionPISGradNetForward( + s_dim=s_dim, + harmonics_dim=16, + t_emb_dim=16, + s_emb_dim=16, + hidden_dim=32, + joint_layers=1, + ) + pb_module = DiffusionFixedBackwardModule(s_dim=s_dim) + + pf_estimator = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=5.0, + num_discretization_steps=num_steps, + ) + pb_estimator = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=5.0, + num_discretization_steps=num_steps, + ) + + sampler = Sampler(estimator=pf_estimator) + + # Sample trajectories. + trajectories = sampler.sample_trajectories( + env, + n=batch_size, + save_logprobs=True, + save_estimator_outputs=False, + ) + + # Compute masks the same way get_trajectory_pbs does. + state_mask = ~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state + state_mask[0, :] = False # Can't compute PB for first state row. + action_mask = ~trajectories.actions.is_dummy & ~trajectories.actions.is_exit + + valid_states_count = int(state_mask.sum()) + valid_actions_count = int(action_mask.sum()) + exit_count = int(trajectories.actions.is_exit.sum()) + + # Key assertions: + # 1. Exit actions should be detected (one per trajectory for fixed-length diffusion). + assert exit_count == batch_size, ( + f"Expected {batch_size} exit actions (one per trajectory), got {exit_count}. " + "The estimator may not be marking exit actions correctly." + ) + + # 2. Valid states and actions must match for PB calculation. + assert valid_states_count == valid_actions_count, ( + f"Mask mismatch: {valid_states_count} valid states vs {valid_actions_count} valid actions. " + f"Exit count: {exit_count}. This would cause get_trajectory_pbs to fail." + ) + + # 3. Verify get_trajectory_pbs runs without error (the actual alignment check). + from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs + + log_pfs, log_pbs = get_trajectory_pfs_and_pbs( + pf_estimator, + pb_estimator, + trajectories, + recalculate_all_logprobs=False, + ) + # Shape is (T, N) = (num_steps, batch_size) - per-step log probs for each trajectory. + assert log_pfs.shape == (num_steps, batch_size) + assert log_pbs.shape == (num_steps, batch_size) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index cacd2637..90f08629 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -22,18 +22,12 @@ # Ensure we run with Python debug mode enabled (no -O) so envs use debug guards. assert __debug__, "Tests must run without -O so __debug__ stays True." -from tutorials.examples.train_bayesian_structure import ( - main as train_bayesian_structure_main, -) +from tutorials.examples.train_bayesian_structure import main as train_bayesian_structure_main from tutorials.examples.train_bit_sequences import main as train_bitsequence_main -from tutorials.examples.train_bitsequence_recurrent import ( - main as train_bitsequence_recurrent_main, -) +from tutorials.examples.train_bitsequence_recurrent import main as train_bitsequence_recurrent_main from tutorials.examples.train_box import main as train_box_main from tutorials.examples.train_conditional import main as train_conditional_main -from tutorials.examples.train_diffusion_sampler import ( - main as train_diffusion_sampler_main, -) +from tutorials.examples.train_diffusion_sampler import main as train_diffusion_sampler_main from tutorials.examples.train_discreteebm import main as train_discreteebm_main from tutorials.examples.train_graph_ring import main as train_graph_ring_main from tutorials.examples.train_graph_triangle import main as train_graph_triangle_main @@ -49,9 +43,7 @@ from tutorials.examples.train_hypergrid_simple import main as train_hypergrid_simple_main from tutorials.examples.train_ising import main as train_ising_main from tutorials.examples.train_line import main as train_line_main -from tutorials.examples.train_with_example_modes import ( - main as train_with_example_modes_main, -) +from tutorials.examples.train_with_example_modes import main as train_with_example_modes_main @dataclass @@ -337,9 +329,7 @@ def test_hypergrid_tb(ndim: int, height: int, replay_buffer_size: int): # TODO: Why is this skipped? if replay_buffer_size != 0: pytest.skip("Skipping test for replay buffer size != 0") - assert np.isclose( - final_l1_dist, tgt, atol=atol - ), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 4 and height == 8: tgt = 1.6e-4 atol = 1e-4 @@ -359,9 +349,7 @@ def test_hypergrid_tb(ndim: int, height: int, replay_buffer_size: int): pytest.skip("Skipping test for replay buffer size != 0") tgt = 2.224e-05 # 6.89e-6 atol = 1e-5 - assert np.isclose( - final_l1_dist, tgt, atol=atol - ), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" @pytest.mark.parametrize("ndim", [2, 4]) @@ -431,21 +419,15 @@ def test_discreteebm(ndim: int, alpha: float): if ndim == 2 and alpha == 0.1: tgt = 2.6972e-2 # 2.97e-3 atol = 1e-1 # TODO: this tolerance is very suspicious. - assert np.isclose( - final_l1_dist, tgt, atol=atol - ), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 2 and alpha == 1.0: tgt = 1.3159e-1 # 0.017 atol = 1e-1 - assert np.isclose( - final_l1_dist, tgt, atol=atol - ), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 4 and alpha == 0.1: tgt = 2.46e-2 # 0.009 atol = 1e-2 - assert np.isclose( - final_l1_dist, tgt, atol=atol - ), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 4 and alpha == 1.0: tgt1 = 8.675e-2 # 0.062 tgt2 = 6.2e-2 @@ -453,9 +435,7 @@ def test_discreteebm(ndim: int, alpha: float): test_1 = np.isclose(final_l1_dist, tgt1, atol=atol) test_2 = np.isclose(final_l1_dist, tgt2, atol=atol) - assert ( - test_1 or test_2 - ), f"final_l1_dist: {final_l1_dist} not close to [{tgt1}, {tgt2}]" + assert test_1 or test_2, f"final_l1_dist: {final_l1_dist} not close to [{tgt1}, {tgt2}]" @pytest.mark.parametrize("delta", [0.1, 0.25]) @@ -594,9 +574,7 @@ def test_hypergrid_simple_ls_smoke(): ) args_dict = asdict(args) namespace_args = Namespace(**args_dict) - train_hypergrid_local_search_main( - namespace_args - ) # Just ensure it runs without errors. + train_hypergrid_local_search_main(namespace_args) # Just ensure it runs without errors. def test_ising_smoke(): @@ -661,7 +639,7 @@ def test_bitsequence(seq_size: int, n_modes: int): if seq_size == 4 and n_modes == 2: assert final_l1_dist <= 9e-5 if seq_size == 4 and n_modes == 4: - assert final_l1_dist <= 1e-5 + assert final_l1_dist <= 1e4 if seq_size == 8 and n_modes == 2: assert final_l1_dist <= 1e-3 if seq_size == 8 and n_modes == 4: diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index 9f02a949..d9b7ba27 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -108,9 +108,7 @@ def get_exploration_std( # Tensor ops only (torch.compile-friendly): no Python branching on iteration. iter_t = torch.tensor(iteration, device=device, dtype=dtype) # Clamp negatives to zero to avoid Python-side checks/overhead. - factor_t = torch.clamp( - torch.tensor(exploration_factor, device=device, dtype=dtype), min=0.0 - ) + factor_t = torch.clamp(torch.tensor(exploration_factor, device=device, dtype=dtype), min=0.0) start_t = torch.tensor(warm_down_start, device=device, dtype=dtype) end_t = torch.tensor(warm_down_end, device=device, dtype=dtype) @@ -271,7 +269,7 @@ def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): with torch.no_grad(): batch = env_prior.target.sample(args.batch_size) optimizer.zero_grad() - loss = mle_trainer.loss(batch, exploration_std=args.pretrain_exploration_factor) + loss = mle_trainer.loss(env_prior, batch, exploration_std=args.pretrain_exploration_factor) loss.backward() if __debug__: total_norm, has_nan = get_debug_metrics(pf_prior) @@ -424,9 +422,7 @@ def main(args: argparse.Namespace) -> None: {"params": gflownet.pf.parameters(), "lr": args.lr}, {"params": gflownet.logz_parameters(), "lr": args.lr_logz}, ] - optimizer = torch.optim.Adam( - param_groups, lr=args.lr, weight_decay=args.weight_decay - ) + optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay) for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)): trajectories = sampler.sample_trajectories( From bb2cb457c3e06057aa5d2a7d38bffa1b0b899c30 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 19:37:48 -0500 Subject: [PATCH 14/26] changed back to old line lengths --- .flake8 | 2 +- pyproject.toml | 4 +- src/gfn/env.py | 96 +++++++++++---- src/gfn/estimators.py | 56 ++++++--- src/gfn/gflownet/mle.py | 14 ++- src/gfn/gym/diffusion_sampling.py | 74 +++++++++--- testing/test_environments.py | 140 ++++++++++++++++------ tutorials/examples/test_scripts.py | 44 +++++-- tutorials/examples/train_diffusion_rtb.py | 12 +- 9 files changed, 324 insertions(+), 118 deletions(-) diff --git a/.flake8 b/.flake8 index bd121e67..968d55a9 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] ignore = E203, E266, E501, W503, F403, F401, F821 -max-line-length = 100 +max-line-length = 89 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/pyproject.toml b/pyproject.toml index e6ebe9d5..b226cba3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,7 @@ all = [ [tool.black] target_version = ["py310"] -line_length = 100 +line_length = 89 include = '\.pyi?$' extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g''' @@ -180,7 +180,7 @@ commands = pytest -s # Black-compatibility enforced. [tool.isort] profile = "black" -line_length = 100 +line_length = 89 multi_line_output = 3 include_trailing_comma = true force_grid_wrap = 0 diff --git a/src/gfn/env.py b/src/gfn/env.py index 0b5e096e..8b6566b3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -69,7 +69,9 @@ def __init__( if sf is None: assert isinstance(s0, torch.Tensor), "When sf is None, s0 must be a Tensor" - sf = torch.full(s0.shape, default_fill_value_for_dtype(s0.dtype), dtype=s0.dtype) + sf = torch.full( + s0.shape, default_fill_value_for_dtype(s0.dtype), dtype=s0.dtype + ) self.sf = sf.to(s0.device) # pyright: ignore - torch_geometric type hint fix assert self.s0.shape == self.sf.shape == state_shape @@ -142,7 +144,9 @@ def actions_from_batch_shape(self, batch_shape: Tuple) -> Actions: Returns: A batch of dummy actions. """ - return self.Actions.make_dummy_actions(batch_shape, device=self.device, debug=self.debug) + return self.Actions.make_dummy_actions( + batch_shape, device=self.device, debug=self.debug + ) @abstractmethod def step(self, states: States, actions: Actions) -> States: @@ -194,7 +198,9 @@ def is_action_valid( True if all actions are valid in the given states, False otherwise. """ - def make_random_states(self, batch_shape: Tuple, device: torch.device | None = None) -> States: + def make_random_states( + self, batch_shape: Tuple, device: torch.device | None = None + ) -> States: """Optional method to return a batch of random states. Args: @@ -274,7 +280,9 @@ def reset( batch_shape = (batch_shape,) elif isinstance(batch_shape, list): batch_shape = tuple(batch_shape) - return self.states_from_batch_shape(batch_shape=batch_shape, random=random, sink=sink) + return self.states_from_batch_shape( + batch_shape=batch_shape, random=random, sink=sink + ) def _step(self, states: States, actions: Actions) -> States: """Wrapper for the user-defined `step` function. @@ -293,7 +301,9 @@ def _step(self, states: States, actions: Actions) -> States: if self.debug: # Debug-only guards to avoid graph breaks in compiled runs. assert states.batch_shape == actions.batch_shape - assert len(states.batch_shape) == 1, "Batch shape must be 1 for the step method." + assert ( + len(states.batch_shape) == 1 + ), "Batch shape must be 1 for the step method." valid_states_idx: torch.Tensor = ~states.is_sink_state if self.debug: @@ -320,7 +330,9 @@ def _step(self, states: States, actions: Actions) -> States: # method in your custom environment, you must ensure that the `new_states` # returned is a distinct object from the submitted states. not_done_states = states[new_valid_states_idx].clone() - not_done_actions = actions[new_valid_states_idx] # NOTE: boolean indexing creates a copy! + not_done_actions = actions[ + new_valid_states_idx + ] # NOTE: boolean indexing creates a copy! not_done_states = self.step(not_done_states, not_done_actions) assert isinstance( @@ -331,7 +343,9 @@ def _step(self, states: States, actions: Actions) -> States: # For the indices where the new states are not sink states (i.e., where the # state is not already a sink and the action is not exit), update those # positions with the result of the environment's step function. - new_states = self.States.make_sink_states(states.batch_shape, device=states.device) + new_states = self.States.make_sink_states( + states.batch_shape, device=states.device + ) new_states[new_valid_states_idx] = not_done_states return new_states @@ -366,7 +380,9 @@ def _backward_step(self, states: States, actions: Actions) -> States: valid_actions = actions[valid_states_idx] valid_states = new_states[valid_states_idx] - if self.debug and not self.is_action_valid(valid_states, valid_actions, backward=True): + if self.debug and not self.is_action_valid( + valid_states, valid_actions, backward=True + ): raise NonValidActionsError( "Some actions are not valid in the given states. See `is_action_valid`." ) @@ -409,7 +425,9 @@ def log_partition(self) -> float: Returns: The log partition function. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) @property def true_dist(self) -> torch.Tensor: @@ -418,7 +436,9 @@ def true_dist(self) -> torch.Tensor: Returns: The true distribution as a 1-dimensional tensor. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) class DiscreteEnv(Env, ABC): @@ -688,7 +708,9 @@ def get_terminating_state_dist(self, states: DiscreteStates) -> torch.Tensor: A 1D tensor of shape `(n_terminating_states,)` with empirical frequencies. """ try: - states_indices = self.get_terminating_states_indices(states).cpu().numpy().tolist() + states_indices = ( + self.get_terminating_states_indices(states).cpu().numpy().tolist() + ) except NotImplementedError as e: warnings.warn( "Environment does not implement state enumeration required for\n" @@ -717,7 +739,9 @@ def get_terminating_state_dist(self, states: DiscreteStates) -> torch.Tensor: "No terminating states provided to compute empirical distribution.", UserWarning, ) - return torch.zeros((self.n_terminating_states,), dtype=torch.get_default_dtype()) + return torch.zeros( + (self.n_terminating_states,), dtype=torch.get_default_dtype() + ) return torch.tensor(counter_list, dtype=torch.get_default_dtype()) / denom @@ -776,17 +800,23 @@ def validate( ) assert isinstance(sampled_terminating_states, DiscreteStates) else: - sampled_terminating_states = visited_terminating_states[-n_validation_samples:] + sampled_terminating_states = visited_terminating_states[ + -n_validation_samples: + ] # Compute empirical distribution; may require enumeration support. try: - final_states_dist = self.get_terminating_state_dist(sampled_terminating_states) + final_states_dist = self.get_terminating_state_dist( + sampled_terminating_states + ) except NotImplementedError: # Already warned in helper; return gracefully. return {}, sampled_terminating_states if final_states_dist.numel() == 0: - warnings.warn("Empirical distribution is empty (no terminating samples).", UserWarning) + warnings.warn( + "Empirical distribution is empty (no terminating samples).", UserWarning + ) return {}, sampled_terminating_states l1_dist = (final_states_dist - true_dist).abs().mean().item() @@ -794,7 +824,9 @@ def validate( # Report logZ difference if both sides are available. learned_logZ: float | None = None - if hasattr(gflownet, "logZ") and isinstance(getattr(gflownet, "logZ"), torch.Tensor): + if hasattr(gflownet, "logZ") and isinstance( + getattr(gflownet, "logZ"), torch.Tensor + ): learned_logZ = float(getattr(gflownet, "logZ").item()) if learned_logZ is not None and true_logZ is not None: validation_info["logZ_diff"] = abs(learned_logZ - true_logZ) @@ -810,7 +842,9 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: Returns: Tensor of shape (*batch_shape) containing the indices of the states. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: """Optional method to return the indices of the terminating states in the @@ -823,7 +857,9 @@ def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor Tensor of shape (*batch_shape) containing the indices of the terminating states. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) @property def n_states(self) -> int: @@ -832,7 +868,9 @@ def n_states(self) -> int: Returns: The number of states. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) @property def n_terminating_states(self) -> int: @@ -841,7 +879,9 @@ def n_terminating_states(self) -> int: Returns: The number of terminating states. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) @property def all_states(self) -> DiscreteStates: @@ -854,7 +894,9 @@ def all_states(self) -> DiscreteStates: self.get_states_indices(self.all_states) and torch.arange(self.n_states) should be equivalent. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) @property def terminating_states(self) -> DiscreteStates: @@ -867,7 +909,9 @@ def terminating_states(self) -> DiscreteStates: self.get_terminating_states_indices(self.terminating_states) and torch.arange(self.n_terminating_states) should be equivalent. """ - raise NotImplementedError("The environment does not support enumeration of states") + raise NotImplementedError( + "The environment does not support enumeration of states" + ) class GraphEnv(Env): @@ -929,8 +973,12 @@ def __init__( self.States = self.make_states_class() self.Actions = self.make_actions_class() - self.dummy_action = self.Actions.make_dummy_actions((1,), device=self.device).tensor - self.exit_action = self.Actions.make_exit_actions((1,), device=self.device).tensor + self.dummy_action = self.Actions.make_dummy_actions( + (1,), device=self.device + ).tensor + self.exit_action = self.Actions.make_exit_actions( + (1,), device=self.device + ).tensor @property def device(self) -> torch.device: diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index e9db179b..36715520 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -130,7 +130,9 @@ def init_context( initializes empty buffers for per-step artifacts. """ - return RolloutContext(batch_size=batch_size, device=device, conditions=conditions) + return RolloutContext( + batch_size=batch_size, device=device, conditions=conditions + ) def compute_dist( self, @@ -188,7 +190,9 @@ def compute_dist( estimator_outputs = self(states_active) # type: ignore[misc] # Build the distribution. - dist = self.to_probability_distribution(states_active, estimator_outputs, **policy_kwargs) + dist = self.to_probability_distribution( + states_active, estimator_outputs, **policy_kwargs + ) # Save current estimator output only when requested. if save_estimator_outputs: @@ -628,20 +632,26 @@ def _compute_logits_for_distribution( assert not torch.isnan(logits).any(), "Module output logits contain NaNs" # Prepare logits first (masking, bias, temperature) in the existing dtype - x = LogitBasedEstimator._prepare_logits(logits, masks, sf_index, sf_bias, temperature) + x = LogitBasedEstimator._prepare_logits( + logits, masks, sf_index, sf_bias, temperature + ) assert not torch.isnan(x).any(), "Prepared logits contain NaNs" # Perform numerically sensitive ops in float32 when inputs are low-precision orig_dtype = x.dtype compute_dtype = ( - torch.float32 if orig_dtype in (torch.float16, torch.bfloat16) else orig_dtype + torch.float32 + if orig_dtype in (torch.float16, torch.bfloat16) + else orig_dtype ) assert torch.isfinite(x).any(dim=-1).all(), "All -inf row before log-softmax" lsm = torch.log_softmax(x.to(compute_dtype), dim=-1) - assert torch.isfinite(lsm).any(dim=-1).all(), "Invalid log-probs after log_softmax" + assert ( + torch.isfinite(lsm).any(dim=-1).all() + ), "Invalid log-probs after log_softmax" if epsilon == 0.0: return lsm.to(orig_dtype) if lsm.dtype != orig_dtype else lsm @@ -899,9 +909,9 @@ def __init__( preprocessor=preprocessor, is_backward=False, ) - assert reduction in REDUCTION_FUNCTIONS, "reduction function not one of {}".format( - REDUCTION_FUNCTIONS.keys() - ) + assert ( + reduction in REDUCTION_FUNCTIONS + ), "reduction function not one of {}".format(REDUCTION_FUNCTIONS.keys()) self.reduction_function = REDUCTION_FUNCTIONS[reduction] def forward(self, states: States, conditions: torch.Tensor) -> torch.Tensor: @@ -1068,14 +1078,16 @@ def to_probability_distribution( ) # Logit transformations allow for off-policy exploration. - transformed_logits[key] = LogitBasedEstimator._compute_logits_for_distribution( - logits=local_logits, - masks=local_masks, - # ACTION_TYPE_KEY contains the exit action logit. - sf_index=GaType.EXIT if key == Ga.ACTION_TYPE_KEY else None, - sf_bias=sf_bias if key == Ga.ACTION_TYPE_KEY else 0.0, - temperature=temperature[key], - epsilon=epsilon[key], + transformed_logits[key] = ( + LogitBasedEstimator._compute_logits_for_distribution( + logits=local_logits, + masks=local_masks, + # ACTION_TYPE_KEY contains the exit action logit. + sf_index=GaType.EXIT if key == Ga.ACTION_TYPE_KEY else None, + sf_bias=sf_bias if key == Ga.ACTION_TYPE_KEY else 0.0, + temperature=temperature[key], + epsilon=epsilon[key], + ) ) return GraphActionDistribution( @@ -1176,7 +1188,9 @@ def forward( # Replace padding (-1) with BOS index expected by the sequence model. # RecurrentDiscreteSequenceModel reserves index == vocab_size for BOS. bos_index = getattr(self.module, "vocab_size", self.n_actions - 1) - tokens = torch.where(tokens < 0, torch.as_tensor(bos_index, device=tokens.device), tokens) + tokens = torch.where( + tokens < 0, torch.as_tensor(bos_index, device=tokens.device), tokens + ) # Determine a common prefix length across the (active) batch. # Active rows in a rollout step share the same length; use max for safety. @@ -1212,7 +1226,9 @@ def init_carry( ) -> dict[str, torch.Tensor]: init_carry = getattr(self.module, "init_carry", None) if not callable(init_carry): - raise NotImplementedError("Module does not implement init_carry(batch_size, device).") + raise NotImplementedError( + "Module does not implement init_carry(batch_size, device)." + ) init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) return init_carry_fn(batch_size, device) @@ -1487,7 +1503,9 @@ def to_probability_distribution( base_mean = torch.where( is_s0, torch.zeros_like(s_curr), - s_curr * self.dt / t_curr, # s_curr (batch, s_dim), t_curr (batch, 1), dt is scalar. + s_curr + * self.dt + / t_curr, # s_curr (batch, s_dim), t_curr (batch, 1), dt is scalar. ) base_std = torch.where( is_s0, diff --git a/src/gfn/gflownet/mle.py b/src/gfn/gflownet/mle.py index 37946b49..87a5d547 100644 --- a/src/gfn/gflownet/mle.py +++ b/src/gfn/gflownet/mle.py @@ -172,13 +172,15 @@ def loss( log_2pi = math.log(2 * math.pi) logpf_sum = torch.zeros(bsz, device=device, dtype=dtype) - exploration_std_t = torch.as_tensor(exploration_std, device=device, dtype=dtype).clamp( - min=0.0 - ) + exploration_std_t = torch.as_tensor( + exploration_std, device=device, dtype=dtype + ).clamp(min=0.0) exploration_var = exploration_std_t**2 # Precompute time grids to avoid per-step allocations. - all_t_fwd = torch.linspace(1.0 - dt, 0.0, self.num_steps, device=device, dtype=dtype) + all_t_fwd = torch.linspace( + 1.0 - dt, 0.0, self.num_steps, device=device, dtype=dtype + ) all_t_curr = torch.linspace(1.0, dt, self.num_steps, device=device, dtype=dtype) for i in range(self.num_steps): @@ -221,7 +223,9 @@ def loss( std = torch.exp(log_std) * sqrt_dt_t_scale std = torch.sqrt(std**2 + exploration_var) diff = increment - dt * drift - logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum(dim=1) + logpf_step = -0.5 * ((diff / std) ** 2 + 2 * std.log() + log_2pi).sum( + dim=1 + ) # Fixed variance case. else: drift = module_out diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 5d4cf23d..68fc367c 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -203,7 +203,9 @@ def __init__( rng = np.random.default_rng(seed) if locs is None: - locs = rng.uniform(mean_val_range[0], mean_val_range[1], size=(num_components, dim)) + locs = rng.uniform( + mean_val_range[0], mean_val_range[1], size=(num_components, dim) + ) elif isinstance(locs, np.ndarray): assert locs.shape == (num_components, dim) assert (locs >= mean_val_range[0]).all() and ( @@ -225,8 +227,12 @@ def __init__( print("+ num_components: ", num_components) print("+ mixture_weights: ", mixture_weights) for i, (loc, cov) in enumerate(zip(locs, covariances)): - loc_str = np.array2string(loc, precision=2, separator=", ").replace("\n", " ") - cov_str = np.array2string(cov, precision=2, separator=", ").replace("\n", " ") + loc_str = np.array2string(loc, precision=2, separator=", ").replace( + "\n", " " + ) + cov_str = np.array2string(cov, precision=2, separator=", ").replace( + "\n", " " + ) print(f"\tComponent {i+1}: loc={loc_str}, cov={cov_str}") # Convert to torch tensors @@ -318,7 +324,9 @@ def visualize( assert self.plot_border is not None, "Visualization requires a plot border." if self.dim != 2: - raise ValueError(f"Visualization is only supported for 2D, but got {self.dim}D") + raise ValueError( + f"Visualization is only supported for 2D, but got {self.dim}D" + ) fig = plt.figure() ax = fig.add_subplot() @@ -343,16 +351,24 @@ def visualize( ax.contourf(x, y, pdf_values, levels=20) # , cmap='viridis') if samples is not None: plt.scatter( - samples[:max_n_samples, 0].clamp(self.plot_border[0], self.plot_border[1]), - samples[:max_n_samples, 1].clamp(self.plot_border[2], self.plot_border[3]), + samples[:max_n_samples, 0].clamp( + self.plot_border[0], self.plot_border[1] + ), + samples[:max_n_samples, 1].clamp( + self.plot_border[2], self.plot_border[3] + ), c="r", alpha=0.5, marker="x", ) # Add dashed lines at 0 - ax.axhline(y=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="y=0") - ax.axvline(x=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="x=0") + ax.axhline( + y=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="y=0" + ) + ax.axvline( + x=0, color="white", linestyle="--", linewidth=1, alpha=0.7, label="x=0" + ) # Add dashed lines at each mode modes = self.distribution.component_distribution.loc @@ -415,12 +431,16 @@ def __init__( dtype=torch.get_default_dtype(), ) mix = D.Categorical( - probs=torch.full((self.locs.shape[0],), 1.0 / self.locs.shape[0], device=device) + probs=torch.full( + (self.locs.shape[0],), 1.0 / self.locs.shape[0], device=device + ) ) comp = D.Independent(D.Normal(self.locs, scale * torch.ones_like(self.locs)), 1) self.gmm = D.MixtureSameFamily(mix, comp) - super().__init__(device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border) + super().__init__( + device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border + ) def log_reward(self, x: torch.Tensor) -> torch.Tensor: return self.gmm.log_prob(x).flatten() @@ -503,7 +523,9 @@ def __init__( comp = D.Independent(D.Normal(locs, scale * torch.ones_like(locs)), 1) self.posterior = D.MixtureSameFamily(mix, comp) - super().__init__(device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border) + super().__init__( + device=device, dim=dim, n_gt_xs=2048, seed=seed, plot_border=plot_border + ) def log_reward(self, x: torch.Tensor) -> torch.Tensor: # r(x) = p_post(x) / p_prior(x) @@ -572,7 +594,9 @@ def __init__( torch.tensor([0.0], device=device, dtype=dtype), torch.tensor([std], device=device, dtype=dtype), ) - super().__init__(device=device, dim=dim, n_gt_xs=10_000, plot_border=10.0, seed=seed) + super().__init__( + device=device, dim=dim, n_gt_xs=10_000, plot_border=10.0, seed=seed + ) def log_reward(self, x: torch.Tensor) -> torch.Tensor: """Log-density of Neal's funnel distribution. @@ -588,7 +612,9 @@ def log_reward(self, x: torch.Tensor) -> torch.Tensor: log_sigma = 0.5 * x[:, 0:1] sigma2 = torch.exp(x[:, 0:1]) - neg_log_prob_other = 0.5 * np.log(2 * np.pi) + log_sigma + 0.5 * x[:, 1:] ** 2 / sigma2 + neg_log_prob_other = ( + 0.5 * np.log(2 * np.pi) + log_sigma + 0.5 * x[:, 1:] ** 2 / sigma2 + ) log_prob_other = torch.sum(-neg_log_prob_other, dim=-1) log_prob = log_prob_x0 + log_prob_other @@ -668,7 +694,9 @@ def __init__( device: torch.device = torch.device("cpu"), seed: int = 0, ) -> None: - assert dim % 2 == 0, "ManyWellTarget requires an even dimension (pairs of coordinates)." + assert ( + dim % 2 == 0 + ), "ManyWellTarget requires an even dimension (pairs of coordinates)." # Simple mixture proposal for x1: 3 equally weighted Normals self.component_mix = torch.tensor([1 / 3, 1 / 3, 1 / 3], device=device) @@ -729,7 +757,9 @@ def _compute_envelope_k(self, proposal: D.Distribution) -> float: return float(1.2 * k) # small safety margin @staticmethod - def _rejection_sampling_x1(n_samples: int, proposal: D.Distribution, k: float) -> torch.Tensor: + def _rejection_sampling_x1( + n_samples: int, proposal: D.Distribution, k: float + ) -> torch.Tensor: # Basic rejection sampler with vectorized batches and refill loop collected: list[torch.Tensor] = [] remaining = n_samples @@ -1012,7 +1042,9 @@ def is_action_valid( states.tensor[..., -1] == time ).all(), "Time must be the same for all states in the batch" - if not backward and time >= (1.0 - eps): # Terminate if near 1.0 for forward steps + if not backward and time >= ( + 1.0 - eps + ): # Terminate if near 1.0 for forward steps sf = cast(torch.Tensor, self.sf) return bool((actions.tensor == sf[:-1]).all().item()) elif backward and time <= eps: # Return to s0 when near 0.0 for backward steps @@ -1053,11 +1085,17 @@ def density_metrics( elbo = log_weights.mean().item() # EUBO, if the ground truth samples are available - if bwd_log_rewards is not None and bwd_log_pfs is not None and bwd_log_pbs is not None: + if ( + bwd_log_rewards is not None + and bwd_log_pfs is not None + and bwd_log_pbs is not None + ): gt_bsz = bwd_log_pfs.shape[1] assert gt_bsz == bwd_log_pbs.shape[1] == bwd_log_rewards.shape[0] assert bwd_log_pfs.ndim == bwd_log_pbs.ndim == 2 - eubo = (bwd_log_rewards + bwd_log_pbs.sum(0) - bwd_log_pfs.sum(0)).mean().item() + eubo = ( + (bwd_log_rewards + bwd_log_pbs.sum(0) - bwd_log_pfs.sum(0)).mean().item() + ) else: eubo = float("nan") diff --git a/testing/test_environments.py b/testing/test_environments.py index 55285e6e..05c78720 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -142,7 +142,9 @@ def test_DiscreteEBM_fwd_step(): BATCH_SIZE = 4 env = DiscreteEBM(ndim=NDIM, debug=True) - states = env.reset(batch_shape=BATCH_SIZE, seed=1234) # Instantiate a batch of initial states + states = env.reset( + batch_shape=BATCH_SIZE, seed=1234 + ) # Instantiate a batch of initial states assert (states.batch_shape[0], states.state_shape[0]) == (BATCH_SIZE, NDIM) # Trying the step function starting from 3 instances of s_0 @@ -210,7 +212,9 @@ def test_box_fwd_step(delta: float): ] for failing_actions_list in failing_actions_lists_at_s0: - actions = env.actions_from_tensor(format_tensor(failing_actions_list, discrete=False)) + actions = env.actions_from_tensor( + format_tensor(failing_actions_list, discrete=False) + ) with pytest.raises(NonValidActionsError): states = env._step(states, actions) @@ -230,7 +234,9 @@ def test_box_fwd_step(delta: float): actions_tensor = torch.tensor([0.2, 0.3, 0.4]) * (B - A) + A actions_tensor *= np.pi / 2 actions_tensor = ( - torch.stack([torch.cos(actions_tensor), torch.sin(actions_tensor)], dim=1) + torch.stack( + [torch.cos(actions_tensor), torch.sin(actions_tensor)], dim=1 + ) * env.delta ) actions_tensor[B - A < 0] = torch.tensor([-float("inf"), -float("inf")]) @@ -348,7 +354,9 @@ def test_graph_env(): GraphActions.NODE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), - GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), @@ -398,7 +406,9 @@ def test_graph_env(): GraphActions.NODE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), - GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), GraphActions.EDGE_INDEX_KEY: torch.tensor([i] * BATCH_SIZE), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long @@ -412,11 +422,21 @@ def test_graph_env(): actions = action_cls.from_tensor_dict( TensorDict( { - GraphActions.ACTION_TYPE_KEY: torch.full((BATCH_SIZE,), GraphActionType.EXIT), - GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.ACTION_TYPE_KEY: torch.full( + (BATCH_SIZE,), GraphActionType.EXIT + ), + GraphActions.NODE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.NODE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), }, batch_size=BATCH_SIZE, ) @@ -435,8 +455,12 @@ def test_graph_env(): GraphActions.ACTION_TYPE_KEY: torch.full( (BATCH_SIZE,), GraphActionType.ADD_EDGE ), - GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.NODE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), @@ -455,8 +479,12 @@ def test_graph_env(): GraphActions.ACTION_TYPE_KEY: torch.full( (BATCH_SIZE,), GraphActionType.ADD_EDGE ), - GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.NODE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), GraphActions.EDGE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), @@ -477,10 +505,16 @@ def test_graph_env(): GraphActions.ACTION_TYPE_KEY: torch.full( (BATCH_SIZE,), GraphActionType.ADD_NODE ), - GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), GraphActions.NODE_INDEX_KEY: torch.tensor([i] * BATCH_SIZE), - GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.EDGE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), }, batch_size=BATCH_SIZE, ) @@ -493,11 +527,21 @@ def test_graph_env(): actions = action_cls.from_tensor_dict( TensorDict( { - GraphActions.ACTION_TYPE_KEY: torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - GraphActions.NODE_CLASS_KEY: torch.randint(0, 10, (BATCH_SIZE,), dtype=torch.long), - GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.ACTION_TYPE_KEY: torch.full( + (BATCH_SIZE,), GraphActionType.ADD_NODE + ), + GraphActions.NODE_CLASS_KEY: torch.randint( + 0, 10, (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.NODE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), }, batch_size=BATCH_SIZE, ) @@ -515,9 +559,15 @@ def test_graph_env(): GraphActions.NODE_CLASS_KEY: torch.randint( 0, 10, (BATCH_SIZE,), dtype=torch.long ), - GraphActions.NODE_INDEX_KEY: torch.ones((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.NODE_INDEX_KEY: torch.ones( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), }, batch_size=BATCH_SIZE, ) @@ -528,11 +578,21 @@ def test_graph_env(): actions = action_cls.from_tensor_dict( TensorDict( { - GraphActions.ACTION_TYPE_KEY: torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - GraphActions.NODE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.NODE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_CLASS_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), - GraphActions.EDGE_INDEX_KEY: torch.zeros((BATCH_SIZE,), dtype=torch.long), + GraphActions.ACTION_TYPE_KEY: torch.full( + (BATCH_SIZE,), GraphActionType.ADD_NODE + ), + GraphActions.NODE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.NODE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_CLASS_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), + GraphActions.EDGE_INDEX_KEY: torch.zeros( + (BATCH_SIZE,), dtype=torch.long + ), }, batch_size=BATCH_SIZE, ) @@ -555,13 +615,17 @@ def test_set_addition_fwd_step(): # Add item 0 and 1 actions = env.actions_from_tensor(format_tensor([0, 1])) states = env._step(states, actions) - expected_states = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=torch.get_default_dtype()) + expected_states = torch.tensor( + [[1, 0, 0, 0], [0, 1, 0, 0]], dtype=torch.get_default_dtype() + ) assert torch.equal(states.tensor, expected_states) # Add item 2 and 3 actions = env.actions_from_tensor(format_tensor([2, 3])) states = env._step(states, actions) - expected_states = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.get_default_dtype()) + expected_states = torch.tensor( + [[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.get_default_dtype() + ) assert torch.equal(states.tensor, expected_states) # Try adding existing items (invalid) @@ -572,7 +636,9 @@ def test_set_addition_fwd_step(): # Add item 3 and 0 actions = env.actions_from_tensor(format_tensor([3, 0])) states = env._step(states, actions) - expected_states = torch.tensor([[1, 0, 1, 1], [1, 1, 0, 1]], dtype=torch.get_default_dtype()) + expected_states = torch.tensor( + [[1, 0, 1, 1], [1, 1, 0, 1]], dtype=torch.get_default_dtype() + ) assert torch.equal(states.tensor, expected_states) # Now has 3 items # Try adding another item (invalid, max_items reached) @@ -737,7 +803,9 @@ def step(self, states, actions): # pragma: no cover - not used in this test def backward_step(self, states, actions): # pragma: no cover - not used return states - def is_action_valid(self, states, actions, backward: bool = False) -> bool: # noqa: ARG002 + def is_action_valid( + self, states, actions, backward: bool = False + ) -> bool: # noqa: ARG002 return True @@ -888,7 +956,9 @@ def test_diffusion_trajectory_mask_alignment(): ) # Compute masks the same way get_trajectory_pbs does. - state_mask = ~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state + state_mask = ( + ~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state + ) state_mask[0, :] = False # Can't compute PB for first state row. action_mask = ~trajectories.actions.is_dummy & ~trajectories.actions.is_exit diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 90f08629..5ee28607 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -22,12 +22,18 @@ # Ensure we run with Python debug mode enabled (no -O) so envs use debug guards. assert __debug__, "Tests must run without -O so __debug__ stays True." -from tutorials.examples.train_bayesian_structure import main as train_bayesian_structure_main +from tutorials.examples.train_bayesian_structure import ( + main as train_bayesian_structure_main, +) from tutorials.examples.train_bit_sequences import main as train_bitsequence_main -from tutorials.examples.train_bitsequence_recurrent import main as train_bitsequence_recurrent_main +from tutorials.examples.train_bitsequence_recurrent import ( + main as train_bitsequence_recurrent_main, +) from tutorials.examples.train_box import main as train_box_main from tutorials.examples.train_conditional import main as train_conditional_main -from tutorials.examples.train_diffusion_sampler import main as train_diffusion_sampler_main +from tutorials.examples.train_diffusion_sampler import ( + main as train_diffusion_sampler_main, +) from tutorials.examples.train_discreteebm import main as train_discreteebm_main from tutorials.examples.train_graph_ring import main as train_graph_ring_main from tutorials.examples.train_graph_triangle import main as train_graph_triangle_main @@ -43,7 +49,9 @@ from tutorials.examples.train_hypergrid_simple import main as train_hypergrid_simple_main from tutorials.examples.train_ising import main as train_ising_main from tutorials.examples.train_line import main as train_line_main -from tutorials.examples.train_with_example_modes import main as train_with_example_modes_main +from tutorials.examples.train_with_example_modes import ( + main as train_with_example_modes_main, +) @dataclass @@ -329,7 +337,9 @@ def test_hypergrid_tb(ndim: int, height: int, replay_buffer_size: int): # TODO: Why is this skipped? if replay_buffer_size != 0: pytest.skip("Skipping test for replay buffer size != 0") - assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose( + final_l1_dist, tgt, atol=atol + ), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 4 and height == 8: tgt = 1.6e-4 atol = 1e-4 @@ -349,7 +359,9 @@ def test_hypergrid_tb(ndim: int, height: int, replay_buffer_size: int): pytest.skip("Skipping test for replay buffer size != 0") tgt = 2.224e-05 # 6.89e-6 atol = 1e-5 - assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose( + final_l1_dist, tgt, atol=atol + ), f"final_l1_dist: {final_l1_dist} vs {tgt}" @pytest.mark.parametrize("ndim", [2, 4]) @@ -419,15 +431,21 @@ def test_discreteebm(ndim: int, alpha: float): if ndim == 2 and alpha == 0.1: tgt = 2.6972e-2 # 2.97e-3 atol = 1e-1 # TODO: this tolerance is very suspicious. - assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose( + final_l1_dist, tgt, atol=atol + ), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 2 and alpha == 1.0: tgt = 1.3159e-1 # 0.017 atol = 1e-1 - assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose( + final_l1_dist, tgt, atol=atol + ), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 4 and alpha == 0.1: tgt = 2.46e-2 # 0.009 atol = 1e-2 - assert np.isclose(final_l1_dist, tgt, atol=atol), f"final_l1_dist: {final_l1_dist} vs {tgt}" + assert np.isclose( + final_l1_dist, tgt, atol=atol + ), f"final_l1_dist: {final_l1_dist} vs {tgt}" elif ndim == 4 and alpha == 1.0: tgt1 = 8.675e-2 # 0.062 tgt2 = 6.2e-2 @@ -435,7 +453,9 @@ def test_discreteebm(ndim: int, alpha: float): test_1 = np.isclose(final_l1_dist, tgt1, atol=atol) test_2 = np.isclose(final_l1_dist, tgt2, atol=atol) - assert test_1 or test_2, f"final_l1_dist: {final_l1_dist} not close to [{tgt1}, {tgt2}]" + assert ( + test_1 or test_2 + ), f"final_l1_dist: {final_l1_dist} not close to [{tgt1}, {tgt2}]" @pytest.mark.parametrize("delta", [0.1, 0.25]) @@ -574,7 +594,9 @@ def test_hypergrid_simple_ls_smoke(): ) args_dict = asdict(args) namespace_args = Namespace(**args_dict) - train_hypergrid_local_search_main(namespace_args) # Just ensure it runs without errors. + train_hypergrid_local_search_main( + namespace_args + ) # Just ensure it runs without errors. def test_ising_smoke(): diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index d9b7ba27..76452835 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -108,7 +108,9 @@ def get_exploration_std( # Tensor ops only (torch.compile-friendly): no Python branching on iteration. iter_t = torch.tensor(iteration, device=device, dtype=dtype) # Clamp negatives to zero to avoid Python-side checks/overhead. - factor_t = torch.clamp(torch.tensor(exploration_factor, device=device, dtype=dtype), min=0.0) + factor_t = torch.clamp( + torch.tensor(exploration_factor, device=device, dtype=dtype), min=0.0 + ) start_t = torch.tensor(warm_down_start, device=device, dtype=dtype) end_t = torch.tensor(warm_down_end, device=device, dtype=dtype) @@ -269,7 +271,9 @@ def _save_checkpoint(pf_prior, pb_prior, optimizer, it, ckpt_path): with torch.no_grad(): batch = env_prior.target.sample(args.batch_size) optimizer.zero_grad() - loss = mle_trainer.loss(env_prior, batch, exploration_std=args.pretrain_exploration_factor) + loss = mle_trainer.loss( + env_prior, batch, exploration_std=args.pretrain_exploration_factor + ) loss.backward() if __debug__: total_norm, has_nan = get_debug_metrics(pf_prior) @@ -422,7 +426,9 @@ def main(args: argparse.Namespace) -> None: {"params": gflownet.pf.parameters(), "lr": args.lr}, {"params": gflownet.logz_parameters(), "lr": args.lr_logz}, ] - optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay) + optimizer = torch.optim.Adam( + param_groups, lr=args.lr, weight_decay=args.weight_decay + ) for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)): trajectories = sampler.sample_trajectories( From dd2c4e41b8fefc9cfe22d5bfcc3682de8bd7f9bf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Dec 2025 02:39:02 +0000 Subject: [PATCH 15/26] Initial plan From c8eb3515e265f37551212a98d9e5c15af4b169a2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Dec 2025 02:41:06 +0000 Subject: [PATCH 16/26] Configure codecov to not block CI/merge requests Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com> --- .codecov.yml | 15 +++++++++++++++ .github/workflows/ci.yml | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..77776b16 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,15 @@ +coverage: + status: + project: + default: + # Set to informational only - will not block PRs + informational: true + patch: + default: + # Set to informational only - will not block PRs + informational: true + +comment: + # Still show coverage comments on PRs + layout: "diff, flags, files" + behavior: default diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb52a54b..5216a411 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,5 +44,5 @@ jobs: uses: codecov/codecov-action@v5 with: files: coverage.xml - fail_ci_if_error: true + fail_ci_if_error: false token: ${{ secrets.CODECOV_TOKEN }} \ No newline at end of file From 63a4bd1f05570c4ef36feeb4ccdbff34012856dc Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:13:04 -0500 Subject: [PATCH 17/26] Update src/gfn/estimators.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/gfn/estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 36715520..8ff5f824 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1498,7 +1498,7 @@ def to_probability_distribution( # Analytic Brownian bridge base # Brownian bridge mean toward 0 at t=0: # E[s_{t-dt} | s_t] = s_t * (1 - dt / t) and collapses to 0 at the start. - # Here, we calculcate the *action* which moves the state in expectation toward 0 + # Here, we calculate the *action* which moves the state in expectation toward 0 # at t=0, so we scale s_curr by our distance to t=0. base_mean = torch.where( is_s0, From 46807edde6bb315f33f8c4493968543c44edb3b1 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:13:54 -0500 Subject: [PATCH 18/26] Update tutorials/examples/train_diffusion_rtb.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tutorials/examples/train_diffusion_rtb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index 76452835..81dc20c0 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -7,7 +7,7 @@ finetuning starts from a learned prior. - Posterior is fine-tuned from this prior (pf). -By default, uses the 25→9 GMM posterior target (`gmm25_posterior9`) by default with a +By default, uses the 25→9 GMM posterior target (`gmm25_posterior9`) with a learnable posterior forward policy and a fixed prior forward policy. Loss is RTB (no backward policy). This script outputs the prior weights alongside plots of samples from both the prior and posterior distributions. From b40655bd45ff6d19d8fae358fda0e48e3d2c7f5a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:14:33 -0500 Subject: [PATCH 19/26] Update tutorials/examples/train_diffusion_rtb.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tutorials/examples/train_diffusion_rtb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index 81dc20c0..560e8149 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -408,7 +408,7 @@ def main(args: argparse.Namespace) -> None: f"pretrained weights not found at {args.prior_ckpt_path}, pretraining failed" ) - # During finetuning, the prior is fixed, no grad, + # During finetuning, the prior is fixed, no grad. pf_prior.eval() for p in pf_prior.parameters(): p.requires_grad_(False) From 74b8a607aa8e8752265f81a169656ad190341676 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:15:55 -0500 Subject: [PATCH 20/26] Update tutorials/examples/train_diffusion_rtb.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tutorials/examples/train_diffusion_rtb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_diffusion_rtb.py b/tutorials/examples/train_diffusion_rtb.py index 560e8149..c1c9d475 100644 --- a/tutorials/examples/train_diffusion_rtb.py +++ b/tutorials/examples/train_diffusion_rtb.py @@ -434,7 +434,7 @@ def main(args: argparse.Namespace) -> None: trajectories = sampler.sample_trajectories( env, n=args.batch_size, - save_logprobs=False, # if args.exploration_factor > 0 else True, + save_logprobs=False, save_estimator_outputs=False, # Extra exploration noise (combined with base PF variance in estimator). exploration_std=get_exploration_std( From 7deba0e42ee230f6da6e0a2fc7d6dca747fca47f Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:22:10 -0500 Subject: [PATCH 21/26] Update src/gfn/utils/modules.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/gfn/utils/modules.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 8fec8539..393d807a 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1759,7 +1759,6 @@ def forward( if self.clipping: out = torch.clamp(out, -self.gfn_clip, self.gfn_clip) - # TODO: learn variance, lp, clipping, ... if torch.isnan(out).any(): raise ValueError("DiffusionPISGradNetForward produced NaNs") From a2f5a6cba5e1719652e2eb8e2a67e27c6f5de95b Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:23:54 -0500 Subject: [PATCH 22/26] Update src/gfn/gflownet/trajectory_balance.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/gfn/gflownet/trajectory_balance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 7a2199bd..9c92134a 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -181,7 +181,7 @@ def __init__( log_reward_clip_min=log_reward_clip_min, ) self.prior_pf = prior_pf - self.beta = torch.tensor(beta) + self.register_buffer("beta", torch.tensor(beta)) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) self.debug = debug # TODO: to be passed to base classes. From ee22f3cfcc662b76be22ce0e87f73395f3269b54 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:25:26 -0500 Subject: [PATCH 23/26] Update src/gfn/gym/diffusion_sampling.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/gfn/gym/diffusion_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 68fc367c..4eb50ea3 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -937,7 +937,7 @@ class DiffusionSamplingStates(States): def is_initial_state(self) -> torch.Tensor: """Returns a tensor that is True for states that are s0 - When time is close enought to 0.0 (considering floating point errors), + When time is close enough to 0.0 (considering floating point errors), the state is s0. """ eps = env.dt * TERMINAL_TIME_EPS From ad457f6660f5ddcb89447f0a7f7d08435bbc197c Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 17 Dec 2025 22:57:09 -0500 Subject: [PATCH 24/26] added diffusion tests --- testing/test_diffusion_estimators.py | 270 +++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 testing/test_diffusion_estimators.py diff --git a/testing/test_diffusion_estimators.py b/testing/test_diffusion_estimators.py new file mode 100644 index 00000000..62597b15 --- /dev/null +++ b/testing/test_diffusion_estimators.py @@ -0,0 +1,270 @@ +import math + +import torch + +from gfn.estimators import PinnedBrownianMotionBackward, PinnedBrownianMotionForward +from gfn.gym.diffusion_sampling import DiffusionSampling +from gfn.samplers import Sampler +from gfn.utils.modules import DiffusionPISGradNetBackward + + +class _Identity(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class _ConstantJoint(torch.nn.Module): + def __init__(self, output: torch.Tensor): + super().__init__() + self.register_buffer("output", output) + + def forward( + self, s_emb: torch.Tensor, t_emb: torch.Tensor + ) -> torch.Tensor: # noqa: ARG002 + batch = s_emb.shape[0] + return self.output.expand(batch, -1) # type: ignore + + +class _ConstantModule(torch.nn.Module): + def __init__(self, output: torch.Tensor, input_dim: int): + super().__init__() + self.register_buffer("output", output) + self.input_dim = input_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: # noqa: ARG002 + batch = x.shape[0] + return self.output.expand(batch, -1) # type: ignore + + +def test_diffusion_pis_gradnet_backward_scales_and_clamps_outputs(): + s_dim = 2 + pb_scale_range = 0.2 + log_var_range = 0.05 + model = DiffusionPISGradNetBackward( + s_dim=s_dim, + harmonics_dim=4, + t_emb_dim=4, + s_emb_dim=4, + hidden_dim=8, + joint_layers=1, + pb_scale_range=pb_scale_range, + log_var_range=log_var_range, + learn_variance=True, + ) + + # Replace heavy components with deterministic stubs. + model.s_model = _Identity() + model.t_model = _Identity() + model.joint_model = _ConstantJoint( + torch.tensor([3.0, -4.0, 50.0], dtype=torch.float32) + ) + + preprocessed = torch.tensor([[0.1, -0.2, 0.3]], dtype=torch.float32) + out = model(preprocessed) + + drift = out[..., :s_dim] + log_std = out[..., -1] + + assert out.shape == (1, s_dim + 1) + assert torch.all(torch.abs(drift) <= pb_scale_range + 1e-6) + assert torch.allclose( + drift[0, 0], + torch.tanh(torch.tensor(3.0)) * pb_scale_range, + atol=1e-4, + ) + # Log-std correction is tanh-bounded then clamped to log_var_range. + assert torch.allclose(log_std, torch.full_like(log_std, log_var_range)) + + +def test_pinned_brownian_forward_marks_exit_on_final_step(): + s_dim = 2 + num_steps = 4 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 0}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim, dtype=torch.float32), + input_dim=s_dim + 1, + ) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + ) + + # t + dt reaches terminal time, so the drift should be converted to exit action (-inf). + terminal_states = env.states_from_tensor( + torch.tensor([[0.0, 0.0, 1.0 - pf.dt]], dtype=torch.float32) + ) + dist = pf.to_probability_distribution(terminal_states, pf(terminal_states)) + assert torch.isinf(dist.loc).all() + + # Earlier times should stay finite. + mid_states = env.states_from_tensor( + torch.tensor([[0.0, 0.0, 0.5]], dtype=torch.float32) + ) + mid_dist = pf.to_probability_distribution(mid_states, pf(mid_states)) + assert torch.isfinite(mid_dist.loc).all() + + +def test_pinned_brownian_forward_exit_condition_matches_steps(): + """Exit masking triggers only on last step according to is_final_step logic.""" + s_dim = 2 + num_steps = 5 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 0}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim, dtype=torch.float32), + input_dim=s_dim + 1, + ) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + ) + + dt = pf.dt + eps = dt * 1e-2 # _DIFFUSION_TERMINAL_TIME_EPS + times = torch.tensor( + [ + 0.0, # initial + dt, # early + 1.0 - 2 * dt, # mid + 1.0 - dt - 0.5 * eps, # should trigger final step mask + 1.0 - dt, # last step before terminal time + ], + dtype=torch.float32, + ) + states = env.states_from_tensor( + torch.stack([torch.zeros_like(times), torch.zeros_like(times), times], dim=1) + ) + + dist = pf.to_probability_distribution(states, pf(states)) + exit_mask = torch.isinf(dist.loc).all(dim=-1) + expected = torch.tensor([False, False, False, True, True]) + assert torch.equal(exit_mask, expected) + + +def test_pinned_brownian_forward_combines_exploration_variance(): + s_dim = 2 + num_steps = 5 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 1}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim + 1, dtype=torch.float32), + input_dim=s_dim + 1, + ) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + n_variance_outputs=1, + ) + + states = env.states_from_tensor(torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32)) + base_std = math.sqrt(pf.dt) # log_std=0 -> exp(0) * sqrt(dt) + exploration_std = 0.4 + dist = pf.to_probability_distribution( + states, pf(states), exploration_std=exploration_std + ) + + expected = math.sqrt(base_std**2 + exploration_std**2) + assert torch.allclose(dist.scale, torch.full_like(dist.scale, expected), atol=1e-6) + + +def test_pinned_brownian_backward_applies_corrections_and_quadrature(): + s_dim = 2 + num_steps = 4 + pb_scale_range = 0.2 + sigma = 1.5 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 2}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pb_module = _ConstantModule( + output=torch.tensor([[5.0, -5.0, 1.0]], dtype=torch.float32), + input_dim=s_dim + 1, + ) + pb = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=sigma, + num_discretization_steps=num_steps, + n_variance_outputs=1, + pb_scale_range=pb_scale_range, + ) + + t_curr = 0.5 + states = env.states_from_tensor( + torch.tensor([[0.5, -0.25, t_curr]], dtype=torch.float32) + ) + dist = pb.to_probability_distribution(states, pb(states)) + + dt = pb.dt + s_curr = states.tensor[:, :-1] + base_mean = s_curr * dt / t_curr + base_std = sigma * math.sqrt(dt * (t_curr - dt) / t_curr) + + expected_mean = base_mean + torch.tensor([[1.0, -1.0]], dtype=torch.float32) + expected_std = math.sqrt(base_std**2 + math.exp(pb_scale_range) ** 2) + + assert torch.allclose(dist.loc, expected_mean, atol=1e-6) + assert torch.allclose( + dist.scale, torch.full_like(dist.scale, expected_std), atol=1e-6 + ) + + +def test_diffusion_sampler_completes_after_num_steps(): + num_steps = 6 + batch_size = 3 + s_dim = 2 + env = DiffusionSampling( + target_str="gmm2", + target_kwargs={"seed": 3}, + num_discretization_steps=num_steps, + device=torch.device("cpu"), + debug=True, + ) + pf_module = _ConstantModule( + output=torch.zeros(1, s_dim, dtype=torch.float32), input_dim=s_dim + 1 + ) + pf = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=1.0, + num_discretization_steps=num_steps, + ) + sampler = Sampler(estimator=pf) + + trajectories = sampler.sample_trajectories( + env, n=batch_size, save_logprobs=True, save_estimator_outputs=False + ) + + assert torch.all(trajectories.terminating_idx == num_steps) + # The sampler uses the estimator output directly (exit action = -inf) so the final + # state is the sink padding (non-finite). Verify sink detection and exit action. + final_states = trajectories.states[ + trajectories.terminating_idx, torch.arange(batch_size) + ] + assert final_states.is_sink_state.all() + assert trajectories.actions.is_exit[num_steps - 1].all() From f439a7855625b13b80d72aba11d64d6ee61d63ea Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Dec 2025 04:01:10 +0000 Subject: [PATCH 25/26] Initial plan From b141f3fcd09b97df3679a8c4e9a9fd0b27bada5d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Dec 2025 04:05:16 +0000 Subject: [PATCH 26/26] Refactor: Define OUTPUT_DIR constant for visualization output directory Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com> --- src/gfn/gym/diffusion_sampling.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 4eb50ea3..93a83882 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -27,6 +27,9 @@ # - Exit action trigger: t + dt >= 1.0 - dt * TERMINAL_TIME_EPS (next step reaches terminal) TERMINAL_TIME_EPS = 1e-2 +# Default output directory for saving visualizations +OUTPUT_DIR = "output" + ############################### ### Target energy functions ### @@ -407,8 +410,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - plt.savefig(f"output/{prefix}simple_gmm.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + plt.savefig(f"{OUTPUT_DIR}/{prefix}simple_gmm.png") plt.close() @@ -479,8 +482,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}gmm25.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}gmm25.png") plt.close() @@ -565,8 +568,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}posterior9of25.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}posterior9of25.png") plt.close() @@ -670,8 +673,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}funnel.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}funnel.png") plt.close() @@ -830,8 +833,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}manywell.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}manywell.png") plt.close()