From 3841ea31e02c60e2a6ab29503a10ef8c87d06426 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Tue, 15 Apr 2025 02:03:09 +0900 Subject: [PATCH 1/3] fix partial energy forward looking trick in backward sampling --- energy_sampling/models/gfn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/energy_sampling/models/gfn.py b/energy_sampling/models/gfn.py index 190fb3d6..9b8aeee0 100644 --- a/energy_sampling/models/gfn.py +++ b/energy_sampling/models/gfn.py @@ -233,8 +233,8 @@ def get_trajectory_bwd(self, s, discretizer, exploration_std, log_r): logf[:, trajectory_length - i - 1] = flow if self.partial_energy: ref_log_var = (self.t_scale * ts[:, max(1, trajectory_length - i - 1)]).log() - log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s ** 2)).sum(1) - logf[:, trajectory_length - i - 1] += ts[:, trajectory_length - i - 1] * log_p_ref + ts[:, i + 1] * log_r(s) + log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s_ ** 2)).sum(1) + logf[:, trajectory_length - i - 1] += (1 - ts[:, trajectory_length - i - 1]) * log_p_ref + ts[:, trajectory_length - i - 1] * log_r(s_) noise = ((s - s_) - dts.unsqueeze(1) * pf_mean) / (dts.sqrt().unsqueeze(1) * (pflogvars / 2).exp()) logpf[:, trajectory_length - i - 1] = -0.5 * (noise ** 2 + logtwopi + dts.log().unsqueeze(1) + pflogvars).sum( From bd93bea6440e49c104b9c7d97fbb23c347530f6b Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Tue, 15 Apr 2025 02:14:11 +0900 Subject: [PATCH 2/3] vectorize partial energy calculations --- energy_sampling/models/gfn.py | 38 +++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/energy_sampling/models/gfn.py b/energy_sampling/models/gfn.py index 9b8aeee0..096ac15b 100644 --- a/energy_sampling/models/gfn.py +++ b/energy_sampling/models/gfn.py @@ -143,10 +143,11 @@ def get_trajectory_fwd(self, s, discretizer, exploration_std, log_r, pis=False): pf_mean, 
pflogvars = self.split_params(pfs) logf[:, i] = flow - if self.partial_energy: - ref_log_var = (self.t_scale * ts[:, max(1, i)]).log() - log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s ** 2)).sum(1) - logf[:, i] += (1 - ts[:, i]) * log_p_ref + ts[:, i] * log_r(s) + # Note: We instead use the vectorized version outside of the loop + # if self.partial_energy: + # ref_log_var = (self.t_scale * ts[:, max(1, i)]).log() + # log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s ** 2)).sum(1) + # logf[:, i] += (1 - ts[:, i]) * log_p_ref + ts[:, i] * log_r(s) if exploration_std is None: if pis: @@ -192,6 +193,16 @@ def get_trajectory_fwd(self, s, discretizer, exploration_std, log_r, pis=False): s = s_ states[:, i + 1] = s + if self.partial_energy: + assert log_r is not None + ref_log_var = (self.t_scale * ts[:, 1:-1]).log().unsqueeze(2) # (bsz, T - 1, 1) + log_p_ref = -0.5 * ( + logtwopi + ref_log_var + (-ref_log_var).exp() * (states[:, 1:-1] ** 2) + ).sum(-1) + logf[:, 1:-1] += (1 - ts[:, 1:-1]) * log_p_ref + ts[:, 1:-1] * log_r( + states[:, 1:-1].reshape(-1, self.dim) + ).view(bsz, trajectory_length - 1) + return states, logpf, logpb, logf def get_trajectory_bwd(self, s, discretizer, exploration_std, log_r): @@ -231,10 +242,11 @@ def get_trajectory_bwd(self, s, discretizer, exploration_std, log_r): pf_mean, pflogvars = self.split_params(pfs) logf[:, trajectory_length - i - 1] = flow - if self.partial_energy: - ref_log_var = (self.t_scale * ts[:, max(1, trajectory_length - i - 1)]).log() - log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s_ ** 2)).sum(1) - logf[:, trajectory_length - i - 1] += (1 - ts[:, trajectory_length - i - 1]) * log_p_ref + ts[:, trajectory_length - i - 1] * log_r(s_) + # Note: We instead use the vectorized version outside of the loop + # if self.partial_energy: + # ref_log_var = (self.t_scale * ts[:, max(1, 
trajectory_length - i - 1)]).log() + # log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s_ ** 2)).sum(1) + # logf[:, trajectory_length - i - 1] += (1 - ts[:, trajectory_length - i - 1]) * log_p_ref + ts[:, trajectory_length - i - 1] * log_r(s_) noise = ((s - s_) - dts.unsqueeze(1) * pf_mean) / (dts.sqrt().unsqueeze(1) * (pflogvars / 2).exp()) logpf[:, trajectory_length - i - 1] = -0.5 * (noise ** 2 + logtwopi + dts.log().unsqueeze(1) + pflogvars).sum( @@ -243,6 +255,16 @@ def get_trajectory_bwd(self, s, discretizer, exploration_std, log_r): s = s_ states[:, trajectory_length - i - 1] = s + if self.partial_energy: + assert log_r is not None + ref_log_var = (self.t_scale * ts[:, 1:-1]).log().unsqueeze(2) # (bsz, T - 1, 1) + log_p_ref = -0.5 * ( + logtwopi + ref_log_var + (-ref_log_var).exp() * (states[:, 1:-1] ** 2) + ).sum(-1) + logf[:, 1:-1] += (1 - ts[:, 1:-1]) * log_p_ref + ts[:, 1:-1] * log_r( + states[:, 1:-1].reshape(-1, self.dim) + ).view(bsz, trajectory_length - 1) + return states, logpf, logpb, logf def sample(self, batch_size, discretizer, log_r): From 27932506881a85e6bd7b9cb9c2310da366c1c75e Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Tue, 15 Apr 2025 02:14:30 +0900 Subject: [PATCH 3/3] vectorize manywell logprob calculation --- energy_sampling/energies/many_well.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/energy_sampling/energies/many_well.py b/energy_sampling/energies/many_well.py index 669a3a5c..7f4a615e 100644 --- a/energy_sampling/energies/many_well.py +++ b/energy_sampling/energies/many_well.py @@ -64,10 +64,10 @@ def doublewell_logprob(self, x): return x1_term + x2_term def manywell_logprob(self, x): - assert x.ndim == 2 - logprob = torch.stack( - [self.doublewell_logprob(x[:, i*2:i*2+2]) for i in range(self.n_wells)], - dim=1).sum(dim=1) + assert x.ndim == 2 # [batch_size, ndim] + x_reshaped = x.view(-1, self.n_wells, 2).reshape(-1, 2) # [batch_size * n_wells, 2] + 
logprob = self.doublewell_logprob(x_reshaped) # [batch_size * n_wells] + logprob = logprob.reshape(-1, self.n_wells).sum(dim=1) # [batch_size] return logprob def sample_first_dimension(self, batch_size):