Add backward logprobs and estimator_outputs in Trajectories #396
Conversation
josephdviviano
left a comment
Some questions / comments, but I do like this direction quite a bit! Thank you!
@@ -55,7 +55,9 @@ def __init__(
    is_backward: bool = False,
    log_rewards: torch.Tensor | None = None,
    log_probs: torch.Tensor | None = None,
This is going to be an annoying change but I suspect we should be explicit about forward_log_probs and forward_estimator_outputs in the namespace as well.
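Roughly what I have in mind, as a minimal sketch (the argument list is illustrative, not the full constructor):

```python
import torch


class Trajectories:
    """Sketch: explicit forward_* / backward_* names in the namespace."""

    def __init__(
        self,
        is_backward: bool = False,
        log_rewards: torch.Tensor | None = None,
        forward_log_probs: torch.Tensor | None = None,
        backward_log_probs: torch.Tensor | None = None,
        forward_estimator_outputs: torch.Tensor | None = None,
        backward_estimator_outputs: torch.Tensor | None = None,
    ):
        self.is_backward = is_backward
        self.log_rewards = log_rewards
        self.forward_log_probs = forward_log_probs
        self.backward_log_probs = backward_log_probs
        self.forward_estimator_outputs = forward_estimator_outputs
        self.backward_estimator_outputs = backward_estimator_outputs
```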
        self.backward_estimator_outputs.shape[: len(self.states.batch_shape)]
        == self.actions.batch_shape
        and self.backward_estimator_outputs.is_floating_point()
    )
We could probably write a private _check_estimator_outputs() and check_log_probs() method which deduplicates this logic.
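Something like this, as a sketch (the shape/dtype checks are copied from the hunk above; the method name is just a suggestion):

```python
import torch


class Trajectories:
    # ... existing attributes: self.states, self.actions, etc.

    def _check_estimator_outputs(self, estimator_outputs: torch.Tensor | None) -> bool:
        """Shared shape/dtype validation for forward and backward estimator outputs."""
        if estimator_outputs is None:
            return True
        return bool(
            estimator_outputs.shape[: len(self.states.batch_shape)]
            == self.actions.batch_shape
            and estimator_outputs.is_floating_point()
        )
```

The forward and backward assertions would then both call this method (and a similar helper for the log-probs).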
  def terminating_states(self) -> States:
-     """The terminating states of the trajectories.
+     """The terminating states of the trajectories. If backward, the terminating states
+     are in 0-th position.
Worth explicitly stating whether these are s0 or s0+1?
Worth specifying the terminating state of the backward trajectory is s0, and the terminating state of the reversed forward trajectory is NOT s0.
We need to be careful about the definition of reversed and backward. A backward trajectory (evaluated under pb) runs in the opposite direction of a forward trajectory (evaluated under pf). These can be reversed, but the reverse of a forward trajectory still corresponds to pf, and a reversed backward trajectory still corresponds to pb.
Can we add these definitions here and ensure that the language is consistent throughout?
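For concreteness, a possible wording that pulls these definitions together (a sketch, to be adapted for the docstring):

```python
# Terminology sketch, to keep the language consistent:
# - A forward trajectory is evaluated under pf; a backward trajectory runs in
#   the opposite direction and is evaluated under pb.
# - The terminating state of a backward trajectory is s0; the terminating
#   state of a reversed forward trajectory is NOT s0.
# - Reversing only flips the storage order of the states: the reverse of a
#   forward trajectory still corresponds to pf, and a reversed backward
#   trajectory still corresponds to pb.
```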
        (max_len + 2, len(self)), device=states.device
    )
    # shape (max_len + 2, n_trajectories, *state_dim)
    actions = self.actions  # shape (max_len, n_trajectories, *action_dim)
Is it now possible to replace reverse_backward_trajectories() with simply reverse() which works on either forward or backward trajectories?
We could keep reverse_backward_trajectories() as essentially an alias:
def reverse_backward_trajectories(self):
    assert self.is_backward
    return self.reverse()
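And a rough sketch of a direction-agnostic reverse() (the private helpers are placeholders for the existing reversal logic and a hypothetical forward-to-backward counterpart):

```python
class Trajectories:
    # ... existing attributes and methods ...

    def reverse(self) -> "Trajectories":
        """Flip the state order without changing which policy the trajectories
        correspond to (sketch only)."""
        if self.is_backward:
            return self._reverse_backward_to_forward_order()  # current reversal logic
        return self._reverse_forward_to_backward_order()  # hypothetical counterpart

    def reverse_backward_trajectories(self) -> "Trajectories":
        """Kept as a thin alias for backward compatibility."""
        assert self.is_backward
        return self.reverse()
```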
I'm wondering if we should have a @property on both Transitions and Trajectories, self.has_forward and self.has_backward, for readability.
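Sketch (attribute names assumed; adjust to whatever the stored log-prob fields end up being called):

```python
class Trajectories:
    # ... existing __init__ storing self.log_probs and self.backward_log_probs ...

    @property
    def has_forward(self) -> bool:
        """Whether forward log-probs were stored during sampling (sketch)."""
        return self.log_probs is not None

    @property
    def has_backward(self) -> bool:
        """Whether backward log-probs were stored during sampling (sketch)."""
        return self.backward_log_probs is not None
```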
    Raises:
        ValueError: If backward transitions are provided.
    """
    if transitions.is_backward:
I think we need some kind of check here. For example, transitions.has_forward()
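For instance (sketch; has_forward is the hypothetical property suggested above):

```python
def loss(self, env, transitions):
    # Guard: this loss requires forward log-probs on the transitions (sketch).
    if not transitions.has_forward:
        raise ValueError("Transitions must carry forward log-probs for this loss.")
    ...
```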
  optimizer.zero_grad()
- loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=True)
+ loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=False)
This is because the forward logprobs are already populated? How?
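My mental model of the flag, which may not match the actual implementation, is roughly:

```python
# Sketch: with recalculate_all_logprobs=False, stored log-probs are reused
# instead of re-running the estimator.
if recalculate_all_logprobs or trajectories.log_probs is None:
    log_pf = compute_log_probs_with_estimator(trajectories)  # calls the NN module (hypothetical helper)
else:
    log_pf = trajectories.log_probs  # reuse what was stored at sampling time
```

Is that what makes False safe here?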
            # terminating states (can be passed to log_reward fn)
        if self.is_backward:
            # [IMPORTANT ASSUMPTION] When backward sampling, all provided states are the
            # *terminating* states (can be passed to log_reward fn)
Why do we need this assumption?
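If I read the hunk right, this is what lets log_rewards be filled in at sampling time; roughly (my reading, not the actual code):

```python
# Backward sampling starts from the provided states; if those are assumed to be
# terminating states, their log-rewards are well-defined and can be stored now.
if self.is_backward:
    log_rewards = env.log_reward(states)
```

Is that the motivation, or is there another reason?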
        transition.
        """
        if transitions.is_backward:
            raise ValueError("Backward transitions are not supported")
I'm confused about how this will work in practice. I suspect we still need some kind of assertion here.
Description
- Add backward_logprobs and backward_estimator_outputs in Trajectories, so that they don't have to be recomputed every time to get the loss.
- Compute log_rewards for backward sampling.
- Remove the is_backward flag from Transitions, since we do not properly support it.

Discussion needed
We can't calculate the loss with "backward" Trajectories. We always need to call .reverse_backward_trajectories before the loss calculation. The question is: is it worth supporting direct loss calculation using backward Trajectories? This wouldn't be too hard to implement.

This was also the case for Transitions (although I removed the backward flag from Transitions; this can be easily reverted). If we want to support training with backward Trajectories, do we also want the same for backward Transitions?
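For reference, the workflow this implies is roughly the following (sketch; the sampler and optimizer names are placeholders):

```python
# Sample trajectories backward, then reverse them before computing any loss.
backward_trajectories = backward_sampler.sample_trajectories(env, n=16)
forward_ordered = backward_trajectories.reverse_backward_trajectories()

optimizer.zero_grad()
loss = gflownet.loss(env, forward_ordered, recalculate_all_logprobs=False)
loss.backward()
optimizer.step()
```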
TODO
Currently, we can store only one of PF or PB during sampling (i.e., we always need to call our NN module for the loss calculation). My next PR will resolve this by calculating both PF and PB within a single sampling process.