Conversation
Codecov Report ❌ Patch coverage is …
Additional details and impacted files:
@@ Coverage Diff @@
## master #457 +/- ##
==========================================
+ Coverage 74.23% 74.33% +0.09%
==========================================
Files 47 48 +1
Lines 6805 7112 +307
Branches 800 830 +30
==========================================
+ Hits 5052 5287 +235
- Misses 1449 1503 +54
- Partials 304 322 +18
src/gfn/estimators.py (outdated)
    s_curr,
    s_curr * self.dt / t_curr,
    torch.zeros_like(s_curr),
    s_curr * (1.0 - self.dt / t_curr),
@hyeok9855 this might have been the cause of the problem you had before in your code (learning slower), worth checking.
            recalculate_all_logprobs=recalculate_all_logprobs,
        )

    def trajectory_log_probs_backward(
leaving these here because they might come in handy, but I don't think they're actually needed right now in this implementation.
…gfn into relative_trajectory_balance
@copilot how can I prevent codecov from failing CI? I want it to track coverage, but not block merge requests.
@josephdviviano I've opened a new pull request, #458, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Pull request overview
This PR adds Relative Trajectory Balance (RTB) support for fine-tuning diffusion-based GFlowNets from a prior to a posterior distribution. It introduces a two-stage training pipeline: (1) maximum likelihood estimation (MLE) pre-training of a prior diffusion model using either a fixed Brownian bridge or learned backward policy, and (2) RTB-based fine-tuning that adapts the prior to match a posterior distribution weighted by a reward function.
Key changes include:
- Implementation of RelativeTrajectoryBalanceGFlowNet for fine-tuning with the RTB loss (see the sketch after this list)
- Implementation of the MLEDiffusion trainer for prior pre-training via maximum likelihood
- Support for learned variance in both forward and backward diffusion policies
- New Gaussian mixture target distributions for RTB demonstrations (25-mode prior, 9-mode posterior)
- Enhanced diffusion trajectory mask alignment with consistent terminal time detection
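The RTB objective being added can be summarized with a minimal, self-contained sketch. This is not the library's RelativeTrajectoryBalanceGFlowNet class; the tensor names and the reward temperature `beta` are illustrative, assuming per-trajectory log-probabilities have already been computed.

```python
import torch
import torch.nn as nn

log_Z = nn.Parameter(torch.zeros(1))  # learned estimate of the log partition function


def rtb_loss(
    log_pf_posterior: torch.Tensor,  # summed forward log-probs under the fine-tuned policy, shape (B,)
    log_pf_prior: torch.Tensor,      # summed forward log-probs under the frozen prior, shape (B,)
    log_reward: torch.Tensor,        # log r(x) at the terminal states, shape (B,)
    beta: float = 1.0,               # reward temperature
) -> torch.Tensor:
    # RTB enforces Z * p_posterior(tau) = r(x)**beta * p_prior(tau) for each trajectory.
    residual = log_Z + log_pf_posterior - beta * log_reward - log_pf_prior
    return (residual**2).mean()
```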
Reviewed changes
Copilot reviewed 15 out of 16 changed files in this pull request and generated 19 comments.
Summary per file:
| File | Description |
|---|---|
| src/gfn/gflownet/trajectory_balance.py | Adds RelativeTrajectoryBalanceGFlowNet class implementing the RTB objective |
| src/gfn/gflownet/mle.py | New MLE trainer for diffusion pre-training with backward sampling |
| src/gfn/estimators.py | Adds learned variance support and exploration noise handling to forward/backward estimators |
| src/gfn/utils/modules.py | Implements DiffusionPISGradNetBackward and adds learned variance to forward module |
| src/gfn/gym/diffusion_sampling.py | Adds Grid25GaussianMixture and Posterior9of25GaussianMixture targets; improves terminal state detection |
| src/gfn/gym/helpers/diffusion_utils.py | Adds max_n_samples parameter to visualization utility |
| src/gfn/gflownet/base.py | Adds helper methods for computing forward/backward trajectory log-probs separately |
| src/gfn/gflownet/__init__.py | Exports RelativeTrajectoryBalanceGFlowNet |
| src/gfn/env.py | Minor comment clarification on boolean masking behavior |
| tutorials/examples/train_diffusion_rtb.py | Complete end-to-end tutorial demonstrating MLE pre-training and RTB fine-tuning |
| tutorials/examples/output/.gitignore | Ignores generated checkpoint and visualization files |
| testing/test_rtb.py | Unit tests for RTB loss computation and gradient flow |
| testing/test_environments.py | Test verifying diffusion trajectory mask alignment |
| testing/gym/test_diffusion_sampling_rtb.py | Tests for new GMM target distributions |
| testing/gflownet/test_mle_diffusion.py | Comprehensive tests for MLE diffusion trainer |
| .flake8 | Minor formatting (trailing newline) |
Comments suppressed due to low confidence (2)
src/gfn/gym/diffusion_sampling.py:69
- This call to BaseTarget.sample in an initialization method is overridden by SimpleGaussianMixture.sample, Grid25GaussianMixture.sample, Posterior9of25GaussianMixture.sample, Funnel.sample, and ManyWell.sample.
    self.gt_xs = self.sample(n_gt_xs, seed)
src/gfn/gym/diffusion_sampling.py:70
- This call to BaseTarget.log_reward in an initialization method is overridden by SimpleGaussianMixture.log_reward, Grid25GaussianMixture.log_reward, Posterior9of25GaussianMixture.log_reward, Funnel.log_reward, and ManyWell.log_reward.
    self.gt_xs_log_rewards = self.log_reward(self.gt_xs)
src/gfn/gym/diffusion_sampling.py (outdated)
    os.makedirs("output", exist_ok=True)
    plt.savefig(f"output/{prefix}simple_gmm.png")
The output directory is changed from "viz" to "output" across multiple target classes. While this change is consistent, it would be better to define the output directory as a constant or configuration parameter at the module level rather than hardcoding it in multiple places.
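A minimal sketch of the suggested refactor: the follow-up commit names the constant OUTPUT_DIR, but the helper function below is illustrative rather than the repository's actual code.

```python
import os

import matplotlib.pyplot as plt

# Single module-level source of truth for visualization paths,
# instead of hardcoding "output" in each target class.
OUTPUT_DIR = "output"


def save_figure(filename: str) -> None:
    """Save the current matplotlib figure under OUTPUT_DIR."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    plt.savefig(os.path.join(OUTPUT_DIR, filename))
```

Each target class would then call something like `save_figure(f"{prefix}simple_gmm.png")`.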
@copilot open a new pull request to apply changes based on this feedback
class DiffusionPISGradNetBackward(nn.Module):
    """Learnable backward correction module (PIS-style) for diffusion.

    Produces mean and optional log-std corrections that are tanh-scaled by
    `pb_scale_range` to stay close to the analytic Brownian bridge.
    """

    def __init__(
        self,
        s_dim: int,
        harmonics_dim: int = 64,
        t_emb_dim: int = 64,
        s_emb_dim: int = 64,
        hidden_dim: int = 64,
        joint_layers: int = 2,
        zero_init: bool = False,
        clipping: bool = False,
        gfn_clip: float = 1e4,
        pb_scale_range: float = 0.1,
        log_var_range: float = 4.0,
        learn_variance: bool = True,
    ) -> None:
        super().__init__()
        self.s_dim = s_dim
        self.out_dim = s_dim + (1 if learn_variance else 0)
        self.harmonics_dim = harmonics_dim
        self.t_emb_dim = t_emb_dim
        self.s_emb_dim = s_emb_dim
        self.hidden_dim = hidden_dim
        self.joint_layers = joint_layers
        self.zero_init = zero_init
        self.clipping = clipping
        self.gfn_clip = gfn_clip
        self.pb_scale_range = pb_scale_range
        self.log_var_range = log_var_range
        self.learn_variance = learn_variance

        assert (
            self.s_emb_dim == self.t_emb_dim
        ), "Dimensionality of state embedding and time embedding should be the same!"

        self.t_model = DiffusionPISTimeEncoding(
            self.harmonics_dim, self.t_emb_dim, self.hidden_dim
        )
        self.s_model = DiffusionPISStateEncoding(self.s_dim, self.s_emb_dim)
        self.joint_model = DiffusionPISJointPolicy(
            self.s_emb_dim,
            self.hidden_dim,
            self.out_dim,
            self.joint_layers,
            self.zero_init,
        )

    def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
        s = preprocessed_states[..., :-1]
        t = preprocessed_states[..., -1]
        s_emb = self.s_model(s)
        t_emb = self.t_model(t)
        out = self.joint_model(s_emb, t_emb)

        if self.clipping:
            out = torch.clamp(out, -self.gfn_clip, self.gfn_clip)

        # Tanh-scale to stay near Brownian bridge; last dim (if present) is log-std corr.
        drift_corr = torch.tanh(out[..., : self.s_dim]) * self.pb_scale_range
        if self.learn_variance and out.shape[-1] == self.s_dim + 1:
            log_std_corr = torch.tanh(out[..., [-1]]) * self.pb_scale_range
            log_std_corr = torch.clamp(
                log_std_corr, -self.log_var_range, self.log_var_range
            )
            out = torch.cat([drift_corr, log_std_corr], dim=-1)
        else:
            out = drift_corr

        if torch.isnan(out).any():
            raise ValueError("DiffusionPISGradNetBackward produced NaNs")

        return out
The new DiffusionPISGradNetBackward class lacks dedicated test coverage. While it's used indirectly in the tutorial script, there are no unit tests that directly verify its behavior (e.g., output shapes, learned corrections, variance handling). Consider adding tests similar to those for DiffusionPISGradNetForward to ensure the backward module works correctly.
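A hypothetical sketch of the kind of shape/variance checks being suggested; the test names are illustrative, and the constructor arguments are taken from the snippet above.

```python
import torch

from gfn.utils.modules import DiffusionPISGradNetBackward


def test_backward_module_output_shapes():
    s_dim = 2
    # Input is (state, time): last column is t in [0, 1].
    states = torch.cat([torch.randn(8, s_dim), torch.rand(8, 1)], dim=-1)

    module = DiffusionPISGradNetBackward(s_dim=s_dim, learn_variance=True)
    out = module(states)
    assert out.shape == (8, s_dim + 1)  # drift correction + log-std correction

    module_fixed_var = DiffusionPISGradNetBackward(s_dim=s_dim, learn_variance=False)
    out_fixed = module_fixed_var(states)
    assert out_fixed.shape == (8, s_dim)  # drift correction only
```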
I added test coverage.
    eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
    is_final_step = (t_curr + self.dt) >= (1.0 - eps)
    # TODO: The old code followed this convention (below). I believe the change
    # is slightly more correct, but I'd like to check this during review.
    # (1.0 - t_curr) < self.dt * 1e-2  # Triggers when t_curr ≈ 1.0
This TODO comment suggests uncertainty about the correctness of the exit condition change. TODOs requesting review in production code should be resolved before merging. Either verify the correctness and remove the TODO, or if there's genuine uncertainty, add a test to validate the behavior.
Suggested change:
- eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
- is_final_step = (t_curr + self.dt) >= (1.0 - eps)
- # TODO: The old code followed this convention (below). I believe the change
- # is slightly more correct, but I'd like to check this during review.
- # (1.0 - t_curr) < self.dt * 1e-2  # Triggers when t_curr ≈ 1.0
+ # Note: this replaces an older heuristic `(1.0 - t_curr) < self.dt * 1e-2`,
+ # using the shared `_DIFFUSION_TERMINAL_TIME_EPS` tolerance for consistency.
+ eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
+ is_final_step = (t_curr + self.dt) >= (1.0 - eps)
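A hypothetical sketch of the unit check being requested, showing which step the new condition flags as final. Only `dt`, `eps`, and the comparison mirror the snippet above; the epsilon constant's value is assumed for illustration.

```python
_DIFFUSION_TERMINAL_TIME_EPS = 1e-2  # assumed value, for illustration only
dt = 0.01  # e.g. a 100-step discretization


def is_final_step(t_curr: float) -> bool:
    eps = dt * _DIFFUSION_TERMINAL_TIME_EPS
    return (t_curr + dt) >= (1.0 - eps)


# The step starting at t = 1 - dt is flagged as final,
assert is_final_step(1.0 - dt)
# while the step starting at t = 1 - 2*dt is not.
assert not is_final_step(1.0 - 2 * dt)
```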
| """ | ||
| assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" | ||
| s_curr = states.tensor[:, :-1] | ||
| # s_curr = states.tensor[:, :-1] |
The commented-out line appears to be dead code that should be removed. If it's intended for reference, consider moving it to a comment explaining why the change was made rather than leaving commented code.
Suggested change:
- # s_curr = states.tensor[:, :-1]
    if self.debug and torch.isnan(logpf_sum).any():
        raise ValueError("NaNs in logpf_sum during MLE loss.")

    # TODO: Use included loss reduction helpers.
The TODO comment suggests using included loss reduction helpers, but the current implementation uses a simple conditional. Either implement the TODO or remove it if the current approach is acceptable.
Suggested change:
- # TODO: Use included loss reduction helpers.
    self.prior_pf = prior_pf
    self.beta = torch.tensor(beta)
    self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
    self.debug = debug  # TODO: to be passed to base classes.
The TODO comment about passing debug flag to base classes should be resolved. Either implement the propagation of the debug flag to parent classes or remove the TODO if it's not necessary for this feature.
Suggested change:
- self.debug = debug  # TODO: to be passed to base classes.
+ self.debug = debug
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Configure codecov to not block CI while tracking coverage
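One common way to achieve this is to mark the Codecov status checks as informational in codecov.yml, so coverage is still reported but the checks never fail. This is a sketch; the actual configuration added in the follow-up PR may differ.

```yaml
# codecov.yml sketch: report coverage, but never fail the status check.
coverage:
  status:
    project:
      default:
        informational: true
    patch:
      default:
        informational: true
```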
@josephdviviano I've opened a new pull request, #459, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Refactor: Extract OUTPUT_DIR constant for visualization paths
    # This matches the exit condition in DiffusionSampling.step() and ensures the
    # sampled action is marked as an exit action (-inf) so trajectory masks align
    # correctly in get_trajectory_pbs.
    eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
Should revert change - am terminating one step too early.
Description
- Training of the forward policy pf from provided samples via maximum likelihood, using either a learned backward policy or a Brownian bridge, implemented in DiffusionMLE.
- RelativeTrajectoryBalanceGFlowNet for RTB fine-tuning of the pre-trained prior.
- An end-to-end example script, train_diffusion_rtb.py.
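For intuition, here is a self-contained toy sketch of the MLE pre-training stage described above: backward-sample trajectories from data with a Brownian-bridge kernel pinned at the origin, then maximize the forward policy's log-likelihood of those trajectories. This is not the repository's DiffusionMLE code; the 1-D setting, network architecture, sigma, and step count are all illustrative.

```python
import torch
import torch.nn as nn

T, dt, sigma = 20, 1.0 / 20, 1.0
drift_net = nn.Sequential(nn.Linear(2, 64), nn.GELU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(drift_net.parameters(), lr=1e-3)


def mle_step(x1: torch.Tensor) -> torch.Tensor:
    """One MLE update from terminal samples x1 of shape (B, 1)."""
    x, log_pf = x1, torch.zeros(x1.shape[0])
    for k in range(T, 0, -1):
        t = k * dt
        # Brownian-bridge backward kernel, pinned to x = 0 at t = 0
        # (same mean convention as s_curr * (1 - dt / t) above).
        mean = x * (t - dt) / t
        std = sigma * (dt * (t - dt) / t) ** 0.5
        x_prev = mean + std * torch.randn_like(x) if k > 1 else torch.zeros_like(x)
        # Forward Gaussian kernel with a learned drift, evaluated on the sampled pair.
        t_prev_col = torch.full_like(x_prev, t - dt)
        fwd_mean = x_prev + drift_net(torch.cat([x_prev, t_prev_col], dim=-1)) * dt
        log_pf = log_pf + torch.distributions.Normal(fwd_mean, sigma * dt**0.5).log_prob(x).sum(-1)
        x = x_prev
    loss = -log_pf.mean()  # maximize forward log-likelihood of the bridged trajectories
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
```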