From c4ba5f5bfeaad5deb780c16d624f704c16566a3a Mon Sep 17 00:00:00 2001 From: younik Date: Thu, 29 Jan 2026 15:04:30 +0100 Subject: [PATCH 1/7] compute and cache masks --- src/gfn/env.py | 22 +-- src/gfn/gym/bitSequence.py | 216 ++++++++++++++---------- src/gfn/gym/discrete_ebm.py | 51 ++++-- src/gfn/gym/hypergrid.py | 51 ++++-- src/gfn/gym/perfect_tree.py | 98 ++++++----- src/gfn/gym/set_addition.py | 96 ++++++----- src/gfn/states.py | 327 ++++++++++++++---------------------- testing/test_states.py | 26 ++- 8 files changed, 460 insertions(+), 427 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index c16bf067..bb485c33 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -603,7 +603,7 @@ def __init__( def states_from_tensor( self, tensor: torch.Tensor, conditions: torch.Tensor | None = None ) -> DiscreteStates: - """Wraps the supplied tensor in a DiscreteStates instance and updates masks. + """Wraps the supplied tensor in a DiscreteStates instance. Args: tensor: Tensor of shape (*batch_shape, *state_shape) representing the states. @@ -617,7 +617,6 @@ def states_from_tensor( DiscreteStates, self.States(tensor=tensor, conditions=conditions, debug=self.debug), ) - self.update_masks(states_instance) return states_instance def states_from_batch_shape( @@ -649,7 +648,7 @@ def reset( seed: Optional[int] = None, conditions: Optional[torch.Tensor] = None, ) -> DiscreteStates: - """Instantiates a batch of random, initial, or sink states and updates masks. + """Instantiates a batch of random, initial, or sink states. Args: batch_shape: Shape of the batch (int or tuple). @@ -664,19 +663,8 @@ def reset( """ states = super().reset(batch_shape, random, sink, seed, conditions=conditions) states = cast(DiscreteStates, states) - self.update_masks(states) return states - @abstractmethod - def update_masks(self, states: DiscreteStates) -> None: - """Updates the masks in DiscreteStates. - - Called automatically after each step for discrete environments. - - Args: - states: The DiscreteStates object whose masks will be updated. - """ - def make_states_class(self) -> type[DiscreteStates]: """Returns the DiscreteStates class for this environment. @@ -736,8 +724,6 @@ def is_action_valid( def _step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Wrapper for the user-defined `step` function. - This calls the `_step` method of the parent class and updates masks. - Args: states: The batch of discrete states. actions: The batch of actions. @@ -747,14 +733,11 @@ def _step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """ new_states = super()._step(states, actions) new_states = cast(DiscreteStates, new_states) - self.update_masks(new_states) return new_states def _backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Wrapper for the user-defined `backward_step` function. - This calls the `_backward_step` method of the parent class and updates masks. - Args: states: The batch of discrete states. actions: The batch of actions. 
@@ -764,7 +747,6 @@ def _backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSt """ new_states = super()._backward_step(states, actions) new_states = cast(DiscreteStates, new_states) - self.update_masks(new_states) return new_states def get_terminating_state_dist(self, states: DiscreteStates) -> torch.Tensor: diff --git a/src/gfn/gym/bitSequence.py b/src/gfn/gym/bitSequence.py index 318d0d03..be8a6d6a 100644 --- a/src/gfn/gym/bitSequence.py +++ b/src/gfn/gym/bitSequence.py @@ -19,23 +19,19 @@ class BitSequenceStates(DiscreteStates): Attributes: word_size (ClassVar[int]): The size of each word in the bit sequence. + words_per_seq (ClassVar[int]): The number of words per sequence. tensor (torch.Tensor): The tensor representing the states. length (torch.Tensor): The tensor representing the length of each bit sequence. - forward_masks (torch.Tensor): The tensor representing the forward masks. - backward_masks (torch.Tensor): The tensor representing the backward masks. """ word_size: ClassVar[int] + words_per_seq: ClassVar[int] length: torch.Tensor - forward_masks: torch.Tensor - backward_masks: torch.Tensor def __init__( self, tensor: torch.Tensor, length: Optional[torch.Tensor] = None, - forward_masks: Optional[torch.Tensor] = None, - backward_masks: Optional[torch.Tensor] = None, conditions: Optional[torch.Tensor] = None, debug: bool = False, ) -> None: @@ -44,25 +40,64 @@ def __init__( Args: tensor: The tensor representing the states. length: The tensor representing the length of each bit sequence. - forward_masks: The tensor representing the forward masks. - backward_masks: The tensor representing the backward masks. conditions: The tensor representing the conditions. debug: If True, enable runtime guards in the parent class (not compile-friendly). 
""" super().__init__( tensor, - forward_masks=forward_masks, - backward_masks=backward_masks, conditions=conditions, debug=debug, ) if length is None: - length = torch.zeros(self.batch_shape, dtype=torch.long, device=self.device) + # Compute length from tensor: count non -1 values + length = (tensor != -1).sum(dim=-1).long() assert is_int_dtype(length) self.length = length - assert self.length is not None - assert self.forward_masks is not None - assert self.backward_masks is not None + + def _compute_forward_masks(self) -> torch.Tensor: + """Computes forward masks for BitSequence states.""" + forward_masks = torch.ones( + (*self.batch_shape, self.n_actions), + dtype=torch.bool, + device=self.device, + ) + is_done = self.length == self.words_per_seq + # When done, only exit action is allowed + forward_masks[is_done, :-1] = False + # When not done, exit action is not allowed + forward_masks[~is_done, -1] = False + return forward_masks + + def _compute_backward_masks(self) -> torch.Tensor: + """Computes backward masks for BitSequence states.""" + backward_masks = torch.zeros( + (*self.batch_shape, self.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + is_sink = self.is_sink_state + # For non-sink states, the valid backward action is the last action taken + non_sink_flat = ~is_sink.flatten() + if non_sink_flat.any(): + flat_tensor = self.tensor.view(-1, *self.state_shape) + flat_length = self.length.view(-1) + flat_backward_masks = backward_masks.view(-1, self.n_actions - 1) + + # Get indices of non-sink states with valid length + non_sink_indices = torch.where(non_sink_flat)[0] + non_sink_lengths = flat_length[non_sink_indices] + + # Only process states with length > 0 + valid_mask = non_sink_lengths > 0 + if valid_mask.any(): + valid_indices = non_sink_indices[valid_mask] + valid_lengths = non_sink_lengths[valid_mask] + # Get the last action for each valid state + last_actions = flat_tensor[valid_indices, valid_lengths - 1] + # Set the backward mask for each state + flat_backward_masks[valid_indices, last_actions] = True + backward_masks = flat_backward_masks.view(*self.batch_shape, self.n_actions - 1) + return backward_masks def clone(self) -> BitSequenceStates: """Returns a clone of the current BitSequencesStates object. @@ -73,15 +108,10 @@ def clone(self) -> BitSequenceStates: return self.__class__( self.tensor.detach().clone(), self.length.detach().clone(), - self.forward_masks.detach().clone(), - self.backward_masks.detach().clone(), self.conditions.detach().clone() if self.conditions is not None else None, debug=self.debug, ) - def _check_both_forward_backward_masks_exist(self): - assert self.forward_masks is not None and self.backward_masks is not None - def __getitem__( self, index: int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor ) -> BitSequenceStates: @@ -93,12 +123,9 @@ def __getitem__( Returns: A subset of the BitSequencesStates object. 
""" - self._check_both_forward_backward_masks_exist() return self.__class__( self.tensor[index], self.length[index], - self.forward_masks[index], - self.backward_masks[index], self.conditions[index] if self.conditions is not None else None, debug=self.debug, ) @@ -123,16 +150,11 @@ def flatten(self) -> BitSequenceStates: """ states = self.tensor.view(-1, *self.state_shape) length = self.length.view(-1) - self._check_both_forward_backward_masks_exist() - forward_masks = self.forward_masks.view(-1, self.forward_masks.shape[-1]) - backward_masks = self.backward_masks.view(-1, self.backward_masks.shape[-1]) if self.conditions is not None: conditions = self.conditions.view(-1, self.conditions.shape[-1]) else: conditions = None - return self.__class__( - states, length, forward_masks, backward_masks, conditions, debug=self.debug - ) + return self.__class__(states, length, conditions, debug=self.debug) def extend(self, other: BitSequenceStates) -> None: """Extends the current BitSequencesStates object with another BitSequencesStates object. @@ -149,25 +171,22 @@ def pad_dim0_with_sf(self, required_first_dim: int) -> None: """Extends the current BitSequencesStates object with sink states. Args: - required_first_dim: The required first dimension of the extended masks. + required_first_dim: The required first dimension of the extended states. """ - super().pad_dim0_with_sf(required_first_dim) + if self.batch_shape[0] >= required_first_dim: + return - def _extend(masks, first_dim): - return torch.cat( - ( - masks, - torch.ones( - first_dim - masks.shape[0], - *masks.shape[1:], - dtype=torch.bool, - device=self.device, - ), - ), - dim=0, - ) + pad_count = required_first_dim - self.batch_shape[0] + super().pad_dim0_with_sf(required_first_dim) - self.length = _extend(self.length, required_first_dim) + # Extend length with zeros for the sink states + length_pad = torch.zeros( + pad_count, + *self.length.shape[1:], + dtype=self.length.dtype, + device=self.device, + ) + self.length = torch.cat((self.length, length_pad), dim=0) @classmethod def stack(cls, states: Sequence[BitSequenceStates]) -> BitSequenceStates: @@ -290,13 +309,14 @@ class BitSequenceStatesImplementation(BitSequenceStates): make_random_states = env.make_random_states n_actions = env.n_actions word_size = env.word_size + words_per_seq = env.words_per_seq return BitSequenceStatesImplementation def states_from_tensor( self, tensor: torch.Tensor, length: Optional[torch.Tensor] = None ) -> BitSequenceStates: - """Wraps the supplied Tensor in a States instance & updates masks. + """Wraps the supplied Tensor in a States instance. Args: tensor: The tensor of shape `state_shape` representing the states. @@ -311,7 +331,6 @@ def states_from_tensor( states_instance = self.make_states_class()( tensor, length=length, debug=self.debug ) - self.update_masks(states_instance) return states_instance # In some cases overwritten by the user to support specific use-cases. @@ -339,34 +358,8 @@ def reset( batch_shape=batch_shape, random=False, sink=sink ) assert isinstance(states, BitSequenceStates) - self.update_masks(states) - return states - def update_masks(self, states: BitSequenceStates) -> None: - """Updates the forward and backward masks. - - Called automatically after each step. - - Args: - states: The states for which to update the masks. 
- """ - - is_done = states.length == self.words_per_seq - states.forward_masks = torch.ones_like( - states.forward_masks, dtype=torch.bool, device=states.device - ) - states.forward_masks[is_done, :-1] = False - states.forward_masks[~is_done, -1] = False - - is_sink = states.is_sink_state - - last_actions = states.tensor[~is_sink, states[~is_sink].length - 1] - states.backward_masks = torch.zeros_like( - states.backward_masks, dtype=torch.bool, device=states.device - ) - states.backward_masks[~is_sink, last_actions] = True - def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates: """Performs a step in the environment. @@ -425,7 +418,6 @@ def _step(self, states: BitSequenceStates, actions: Actions): new_states = super(DiscreteEnv, self)._step(states, actions) assert isinstance(new_states, BitSequenceStates) new_states.length = length + 1 - self.update_masks(new_states) return new_states def _backward_step(self, states: BitSequenceStates, actions: Actions): @@ -442,7 +434,6 @@ def _backward_step(self, states: BitSequenceStates, actions: Actions): new_states = super(DiscreteEnv, self)._backward_step(states, actions) assert isinstance(new_states, BitSequenceStates) new_states.length = length - 1 - self.update_masks(new_states) return new_states def make_modes_set(self, seed) -> torch.Tensor: @@ -789,29 +780,70 @@ def __init__( self.modes = self.make_modes_set(seed) # set of modes written as binary self.temperature = temperature - def update_masks(self, states: BitSequenceStates) -> None: - """Updates the forward and backward masks. + def make_states_class(self) -> type[BitSequenceStates]: + """Creates a BitSequenceStates class implementation for BitSequencePlus. - Args: - states: The states for which to update the masks. + Returns: + A BitSequenceStates class implementation with prepend-append mask logic. """ + env = self - is_done = states.length == self.words_per_seq - states.forward_masks = torch.ones_like( - states.forward_masks, dtype=torch.bool, device=states.device - ) - states.forward_masks[is_done, :-1] = False - states.forward_masks[~is_done, -1] = False + class BitSequencePlusStates(BitSequenceStates): + state_shape = (env.words_per_seq,) + s0 = env.s0 + sf = env.sf + make_random_states = env.make_random_states + n_actions = env.n_actions + word_size = env.word_size + words_per_seq = env.words_per_seq - is_sink = states.is_sink_state + def _compute_backward_masks(self) -> torch.Tensor: + """Computes backward masks for BitSequencePlus states. - last_actions = states.tensor[~is_sink, states[~is_sink].length - 1] - first_actions = states.tensor[~is_sink, 0] - states.backward_masks = torch.zeros_like( - states.backward_masks, dtype=torch.bool, device=states.device - ) - states.backward_masks[~is_sink, last_actions] = True - states.backward_masks[~is_sink, first_actions + (self.n_actions - 1) // 2] = True + Allows removing from either end (prepend or append). 
+ """ + backward_masks = torch.zeros( + (*self.batch_shape, self.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + is_sink = self.is_sink_state + non_sink_flat = ~is_sink.flatten() + if non_sink_flat.any(): + flat_tensor = self.tensor.view(-1, *self.state_shape) + flat_length = self.length.view(-1) + flat_backward_masks = backward_masks.view(-1, self.n_actions - 1) + + non_sink_indices = torch.where(non_sink_flat)[0] + non_sink_lengths = flat_length[non_sink_indices] + + # Only process states with length > 0 + valid_mask = non_sink_lengths > 0 + if valid_mask.any(): + valid_indices = non_sink_indices[valid_mask] + valid_lengths = non_sink_lengths[valid_mask] + valid_tensors = flat_tensor[valid_indices] + + # Last action (remove from end) + last_actions = valid_tensors[ + torch.arange(len(valid_lengths), device=self.device), + valid_lengths - 1 + ] + flat_backward_masks[valid_indices, last_actions] = True + + # First action (remove from front) - shifted by (n_actions-1)//2 + first_actions = valid_tensors[:, 0] + flat_backward_masks[ + valid_indices, + first_actions + (env.n_actions - 1) // 2 + ] = True + + backward_masks = flat_backward_masks.view( + *self.batch_shape, self.n_actions - 1 + ) + return backward_masks + + return BitSequencePlusStates def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates: """Performs a step in the environment. diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index c3b46a09..795e36f8 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -120,17 +120,46 @@ def __init__( ) self.States: type[DiscreteStates] = self.States - def update_masks(self, states: DiscreteStates) -> None: - """Updates the masks of the states. - - Args: - states: The states to update the masks of. 
- """ - states.forward_masks[..., : self.ndim] = states.tensor == -1 - states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1 - states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1) - states.backward_masks[..., : self.ndim] = states.tensor == 0 - states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1 + def make_states_class(self) -> type[DiscreteStates]: + """Returns the DiscreteStates class for the DiscreteEBM environment.""" + env = self + + class DiscreteEBMStates(DiscreteStates): + state_shape = (env.ndim,) + s0 = env.s0 + sf = env.sf + make_random_states = env.make_random_states + n_actions = env.n_actions + + def _compute_forward_masks(self) -> torch.Tensor: + """Computes forward masks for DiscreteEBM states.""" + forward_masks = torch.zeros( + (*self.batch_shape, self.n_actions), + dtype=torch.bool, + device=self.device, + ) + # Action i in [0, ndim-1] replaces s[i] with 0 (only if s[i] == -1) + forward_masks[..., : env.ndim] = self.tensor == -1 + # Action i in [ndim, 2*ndim-1] replaces s[i-ndim] with 1 (only if s[i-ndim] == -1) + forward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == -1 + # Exit action is only valid when state is complete (no -1s) + forward_masks[..., -1] = torch.all(self.tensor != -1, dim=-1) + return forward_masks + + def _compute_backward_masks(self) -> torch.Tensor: + """Computes backward masks for DiscreteEBM states.""" + backward_masks = torch.zeros( + (*self.batch_shape, self.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + # Backward action i in [0, ndim-1] sets s[i] back to -1 (only if s[i] == 0) + backward_masks[..., : env.ndim] = self.tensor == 0 + # Backward action i in [ndim, 2*ndim-1] sets s[i-ndim] back to -1 (only if s[i-ndim] == 1) + backward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == 1 + return backward_masks + + return DiscreteEBMStates def make_random_states( self, diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 9196aa11..49555f7c 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -146,20 +146,45 @@ def __init__( ) self.States: type[DiscreteStates] = self.States # for type checking - def update_masks(self, states: DiscreteStates) -> None: - """Updates the masks of the states. + def make_states_class(self) -> type[DiscreteStates]: + """Returns the DiscreteStates class for the HyperGrid environment.""" + env = self + + class HyperGridStates(DiscreteStates): + state_shape = env.state_shape + s0 = env.s0 + sf = env.sf + make_random_states = env.make_random_states + n_actions = env.n_actions + + def _compute_forward_masks(self) -> torch.Tensor: + """Computes forward masks for HyperGrid states. + + Not allowed to take any action beyond the environment height, + but allow early termination. + """ + # Create mask: True where action would go beyond height + at_height_limit = self.tensor == env.height - 1 + # Forward masks: all True except where at height limit + forward_masks = torch.ones( + (*self.batch_shape, self.n_actions), + dtype=torch.bool, + device=self.device, + ) + # Set non-exit actions to False where at height limit + # Exit action (last action) remains True + exit_mask = torch.zeros( + self.batch_shape + (1,), device=self.device, dtype=torch.bool + ) + full_mask = torch.cat([at_height_limit, exit_mask], dim=-1) + forward_masks[full_mask] = False + return forward_masks - Args: - states: The states to update the masks of. 
- """ - # Not allowed to take any action beyond the environment height, but - # allow early termination. - # TODO: do we need to handle the conditional case here? - states.set_nonexit_action_masks( - states.tensor == self.height - 1, - allow_exit=True, - ) - states.backward_masks = states.tensor != 0 + def _compute_backward_masks(self) -> torch.Tensor: + """Computes backward masks for HyperGrid states.""" + return self.tensor != 0 + + return HyperGridStates def make_random_states( self, diff --git a/src/gfn/gym/perfect_tree.py b/src/gfn/gym/perfect_tree.py index 6035b99a..51a0434b 100644 --- a/src/gfn/gym/perfect_tree.py +++ b/src/gfn/gym/perfect_tree.py @@ -79,6 +79,64 @@ def __init__( self.term_states, ) = self._build_tree() + def make_states_class(self) -> type[DiscreteStates]: + """Returns the DiscreteStates class for the PerfectBinaryTree environment.""" + env = self + + class PerfectBinaryTreeStates(DiscreteStates): + state_shape = (1,) + s0 = env.s0 + sf = env.sf + make_random_states = env.make_random_states + n_actions = env.n_actions + + def _compute_forward_masks(self) -> torch.Tensor: + """Computes forward masks for PerfectBinaryTree states.""" + forward_masks = torch.ones( + (*self.batch_shape, self.n_actions), + dtype=torch.bool, + device=self.device, + ) + + # Flatten the states and terminating states tensors for efficient comparison. + states_flat = self.tensor.view(-1, 1) + term_tensor = env.term_states.tensor.view(1, -1) + terminating_states_mask = (states_flat == term_tensor).any(dim=1) + + # Going from any node, we can choose action 0 or 1 + # Except terminating states where we must end the trajectory + # Reshape mask to match batch shape + terminating_states_mask = terminating_states_mask.view(self.batch_shape) + + # Non-terminating states: can take actions 0 or 1, but not exit + forward_masks[~terminating_states_mask, -1] = False + + # Terminating states: only exit action allowed + forward_masks[terminating_states_mask, :] = False + forward_masks[terminating_states_mask, -1] = True + + return forward_masks + + def _compute_backward_masks(self) -> torch.Tensor: + """Computes backward masks for PerfectBinaryTree states.""" + backward_masks = torch.zeros( + (*self.batch_shape, self.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + + initial_state_mask = (self.tensor == env.s0).view(self.batch_shape) + even_states = (self.tensor % 2 == 0).view(self.batch_shape) + + # Even states are to the right, so tied to action 1 + # Uneven states are to the left, tied to action 0 + backward_masks[even_states & ~initial_state_mask, 1] = True + backward_masks[~even_states & ~initial_state_mask, 0] = True + + return backward_masks + + return PerfectBinaryTreeStates + def _build_tree(self) -> tuple[dict, dict, DiscreteStates]: """Builds the tree and the transition tables. @@ -151,46 +209,6 @@ def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: ).reshape(-1, 1) return self.States(next_states_tns) - def update_masks(self, states: DiscreteStates) -> None: - """Updates the masks of the states. - - Args: - states: The states to update the masks of. - """ - # Flatten the states and terminating states tensors for efficient comparison. 
- states_flat = states.tensor.view(-1, 1) - term_tensor = self.term_states.tensor.view(1, -1) - terminating_states_mask = (states_flat == term_tensor).any(dim=1) - initial_state_mask = (states.tensor == self.s0).view(-1) - even_states = (states.tensor % 2 == 0).view(-1) - - # Going from any node, we can choose action 0 or 1 - # Except terminating states where we must end the trajectory - not_term_mask = states.forward_masks[~terminating_states_mask] - not_term_mask[:, -1] = False - - term_mask = states.forward_masks[terminating_states_mask] - term_mask[:, :] = False - term_mask[:, -1] = True - - states.forward_masks[~terminating_states_mask] = not_term_mask - states.forward_masks[terminating_states_mask] = term_mask - - # Even states are to the right, so tied to action 1 - # Uneven states are to the left, tied to action 0 - even_mask = states.backward_masks[even_states] - odd_mask = states.backward_masks[~even_states] - - even_mask[:, 0] = False - even_mask[:, 1] = True - odd_mask[:, 0] = True - odd_mask[:, 1] = False - states.backward_masks[even_states] = even_mask - states.backward_masks[~even_states] = odd_mask - - # Initial state has no available backward action - states.backward_masks[initial_state_mask] = False - def get_states_indices(self, states: States): """Returns the indices of the states. diff --git a/src/gfn/gym/set_addition.py b/src/gfn/gym/set_addition.py index 04fef178..a0c9e2a7 100644 --- a/src/gfn/gym/set_addition.py +++ b/src/gfn/gym/set_addition.py @@ -60,6 +60,58 @@ def __init__( ) self.States: type[DiscreteStates] = self.States + def make_states_class(self) -> type[DiscreteStates]: + """Returns the DiscreteStates class for the SetAddition environment.""" + env = self + + class SetAdditionStates(DiscreteStates): + state_shape = (env.n_items,) + s0 = env.s0 + sf = env.sf + make_random_states = env.make_random_states + n_actions = env.n_actions + + def _compute_forward_masks(self) -> torch.Tensor: + """Computes forward masks for SetAddition states.""" + n_items_per_state = self.tensor.sum(dim=-1) + states_that_must_end = n_items_per_state >= env.max_traj_len + states_that_may_continue = (n_items_per_state < env.max_traj_len) & ( + n_items_per_state >= 0 + ) + + forward_masks = torch.zeros( + (*self.batch_shape, self.n_actions), + dtype=torch.bool, + device=self.device, + ) + + # For states that may continue: can add items not yet in set + forward_masks[states_that_may_continue, : env.n_items] = ( + self.tensor[states_that_may_continue] == 0 + ) + + # For states that must end: only exit action allowed + forward_masks[states_that_must_end, -1] = True + + # Allow exit action for all states if not fixed_length + if not env.fixed_length: + forward_masks[..., -1] = True + + return forward_masks + + def _compute_backward_masks(self) -> torch.Tensor: + """Computes backward masks for SetAddition states.""" + backward_masks = torch.zeros( + (*self.batch_shape, self.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + # Can remove items that are in the set + backward_masks[..., : env.n_items] = self.tensor != 0 + return backward_masks + + return SetAdditionStates + def get_states_indices(self, states: DiscreteStates): """Returns the indices of the states. @@ -77,50 +129,6 @@ def get_states_indices(self, states: DiscreteStates): indices = (canonical_base * states_raw).sum(-1).long() return indices - def update_masks(self, states: DiscreteStates) -> None: - """Updates the masks of the states. - - Args: - states: The states to update the masks of. 
- """ - n_items_per_state = states.tensor.sum(dim=-1) - states_that_must_end = n_items_per_state >= self.max_traj_len - states_that_may_continue = (n_items_per_state < self.max_traj_len) & ( - n_items_per_state >= 0 - ) - - cont_f_mask = torch.cat( - ( - (states.tensor[states_that_may_continue] == 0), - torch.zeros( - states.tensor[states_that_may_continue].shape[0], - 1, - dtype=torch.bool, - device=states.tensor.device, - ), - ), - 1, - ) - - states.forward_masks[states_that_may_continue] = cont_f_mask - # Disallow everything for trajs that must end - end_f_mask = torch.zeros( - states.tensor[states_that_must_end].shape[0], - states.forward_masks.shape[-1], - dtype=torch.bool, - device=states.tensor.device, - ) - end_f_mask[..., -1] = True - - states.forward_masks[states_that_must_end] = end_f_mask - - # Disallow everything for trajs that must end - # states.forward_masks[states_that_must_end, : self.n_items] = 0 - if not self.fixed_length: - states.forward_masks[..., -1] = 1 # Allow exit action - - states.backward_masks[..., : self.n_items] = states.tensor != 0 - def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Performs a step in the environment. diff --git a/src/gfn/states.py b/src/gfn/states.py index ce881a93..a18865ad 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -630,21 +630,24 @@ def to(self, device: torch.device) -> States: class DiscreteStates(States, ABC): """Base class for states of discrete environments. - DiscreteStates are endowed with `forward_masks` and `backward_masks`: boolean - attributes representing which actions are allowed at each state. This is the mechanism - by which all elements of the library verifies the allowed actions at each state. + DiscreteStates provide `forward_masks` and `backward_masks` as cached properties + that compute which actions are allowed at each state on demand. This approach + (similar to GraphStates) makes slicing operations faster since masks don't need + to be sliced - they are recomputed only when accessed. + + Subclasses must implement `_compute_forward_masks` and `_compute_backward_masks` + to define the mask computation logic for their specific environment. Attributes: n_actions: Number of possible actions. device: The device on which the states are stored. - forward_masks: Boolean tensor indicating forward actions allowed at each state. - backward_masks: Boolean tensor indicating backward actions allowed at each state. + forward_masks: Property that returns boolean tensor of allowed forward actions. + backward_masks: Property that returns boolean tensor of allowed backward actions. Compile-related expectations: - - Inputs (state tensor and masks) should already be on the target device with - correct shapes; debug can be used to validate during development/tests. - - Mask helpers reset masks before applying new conditions; rely on this behavior - to avoid cross-step leakage. + - Inputs (state tensor) should already be on the target device with correct shapes; + debug can be used to validate during development/tests. + - Masks are computed on-demand and cached; cache is invalidated when needed. 
""" n_actions: ClassVar[int] @@ -652,21 +655,15 @@ class DiscreteStates(States, ABC): def __init__( self, tensor: torch.Tensor, - forward_masks: Optional[torch.Tensor] = None, - backward_masks: Optional[torch.Tensor] = None, conditions: Optional[torch.Tensor] = None, device: torch.device | None = None, debug: bool = False, ) -> None: - """Initializes a DiscreteStates container with a batch of states and masks. + """Initializes a DiscreteStates container with a batch of states. Args: tensor: Tensor of shape (*batch_shape, *state_shape) representing a batch of states. - forward_masks: Optional boolean tensor of shape (*batch_shape, n_actions) - indicating forward actions allowed at each state. - backward_masks: Optional boolean tensor of shape (*batch_shape, n_actions - 1) - indicating backward actions allowed at each state. conditions: Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets. device: The device to store the states on. @@ -677,54 +674,105 @@ def __init__( # Keep shape validation in debug to avoid graph breaks in compiled regions. assert tensor.shape == self.batch_shape + self.state_shape - # In the usual case, no masks are provided and we produce these defaults. - # Note: this **must** be updated externally by the env. - if forward_masks is None: - forward_masks = torch.ones( - (*self.batch_shape, self.__class__.n_actions), - dtype=torch.bool, - device=self.device, - ) - else: - forward_masks = forward_masks.to(self.device) - if backward_masks is None: - backward_masks = torch.ones( - (*self.batch_shape, self.__class__.n_actions - 1), - dtype=torch.bool, - device=self.device, - ) - else: - backward_masks = backward_masks.to(self.device) + # Masks are computed on-demand and cached + self._forward_masks_cache: Optional[torch.Tensor] = None + self._backward_masks_cache: Optional[torch.Tensor] = None + + def _invalidate_masks_cache(self) -> None: + """Invalidates the cached masks, forcing recomputation on next access.""" + self._forward_masks_cache = None + self._backward_masks_cache = None + + @property + def forward_masks(self) -> torch.Tensor: + """Returns forward action masks, computing and caching if needed. + + Returns: + Boolean tensor of shape (*batch_shape, n_actions) indicating which + forward actions are allowed at each state. + """ + if self._forward_masks_cache is None: + self._forward_masks_cache = self._compute_forward_masks() + if self.debug: + assert self._forward_masks_cache.shape == ( + *self.batch_shape, + self.n_actions, + ) + return self._forward_masks_cache + + @forward_masks.setter + def forward_masks(self, value: torch.Tensor) -> None: + """Sets the forward masks cache directly. + + Args: + value: Boolean tensor of shape (*batch_shape, n_actions). + """ + self._forward_masks_cache = value + + @property + def backward_masks(self) -> torch.Tensor: + """Returns backward action masks, computing and caching if needed. + + Returns: + Boolean tensor of shape (*batch_shape, n_actions - 1) indicating which + backward actions are allowed at each state. + """ + if self._backward_masks_cache is None: + self._backward_masks_cache = self._compute_backward_masks() + if self.debug: + assert self._backward_masks_cache.shape == ( + *self.batch_shape, + self.n_actions - 1, + ) + return self._backward_masks_cache + + @backward_masks.setter + def backward_masks(self, value: torch.Tensor) -> None: + """Sets the backward masks cache directly. + + Args: + value: Boolean tensor of shape (*batch_shape, n_actions - 1). 
+ """ + self._backward_masks_cache = value + + def _compute_forward_masks(self) -> torch.Tensor: + """Computes forward action masks for the current states. + + Must be implemented by subclasses to define environment-specific mask logic. - self.forward_masks: torch.Tensor = forward_masks - self.backward_masks: torch.Tensor = backward_masks + Returns: + Boolean tensor of shape (*batch_shape, n_actions). + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement _compute_forward_masks" + ) - assert self.forward_masks.shape == (*self.batch_shape, self.n_actions) - assert self.backward_masks.shape == (*self.batch_shape, self.n_actions - 1) + def _compute_backward_masks(self) -> torch.Tensor: + """Computes backward action masks for the current states. + + Must be implemented by subclasses to define environment-specific mask logic. + + Returns: + Boolean tensor of shape (*batch_shape, n_actions - 1). + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement _compute_backward_masks" + ) def clone(self) -> DiscreteStates: """Returns a clone of the current instance. Returns: - A new DiscreteStates object with the same data, masks, and conditions. + A new DiscreteStates object with the same data and conditions. + Masks are recomputed on demand for the cloned states. """ cloned = self.__class__( tensor=self.tensor.clone(), - forward_masks=self.forward_masks.clone(), - backward_masks=self.backward_masks.clone(), conditions=self.conditions.clone() if self.conditions is not None else None, debug=self.debug, ) return cloned - def _check_both_forward_backward_masks_exist(self): - # Only validate in debug to avoid graph breaks in compiled regions. - if self.debug: - if not torch.is_tensor(self.forward_masks): - raise TypeError("forward_masks must be tensors") - if not torch.is_tensor(self.backward_masks): - raise TypeError("backward_masks must be tensors") - def __repr__(self) -> str: """Returns a detailed string representation of the DiscreteStates object. @@ -736,7 +784,6 @@ def __repr__(self) -> str: f"batch={self.batch_shape},", f"state={self.state_shape},", f"actions={self.n_actions},", - f"masks={tuple(self.forward_masks.shape)},", ] if self.conditions is not None: parts.append(f"conditions={self.conditions.shape},") @@ -746,56 +793,50 @@ def __repr__(self) -> str: def __getitem__( self, index: int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor ) -> DiscreteStates: - """Returns a subset of the discrete states and their masks. + """Returns a subset of the discrete states. + + Masks are computed on demand for the new states rather than being sliced, + which makes this operation faster. Args: index: Indices to select states. Returns: - A new DiscreteStates object with the selected states, masks, and conditions. + A new DiscreteStates object with the selected states and conditions. """ states = self.tensor[index] - self._check_both_forward_backward_masks_exist() - forward_masks = self.forward_masks[index] - backward_masks = self.backward_masks[index] conditions = self.conditions[index] if self.conditions is not None else None - out = self.__class__( - states, forward_masks, backward_masks, conditions, debug=self.debug - ) + out = self.__class__(states, conditions, debug=self.debug) return out def __setitem__( self, index: int | Sequence[int] | Sequence[bool], states: DiscreteStates ) -> None: - """Sets particular discrete states and their masks. + """Sets particular discrete states. Args: index: Indices to set. 
- states: DiscreteStates object containing the new states and masks. + states: DiscreteStates object containing the new states. """ super().__setitem__(index, states) - self._check_both_forward_backward_masks_exist() - self.forward_masks[index] = states.forward_masks - self.backward_masks[index] = states.backward_masks + # Invalidate masks cache since underlying tensor has changed + self._invalidate_masks_cache() def flatten(self) -> DiscreteStates: - """Flattens the batch dimension of the discrete states and their masks. + """Flattens the batch dimension of the discrete states. + + Masks are computed on demand for the flattened states. Returns: A new DiscreteStates object with the batch dimension flattened. """ states = self.tensor.view(-1, *self.state_shape) - self._check_both_forward_backward_masks_exist() - forward_masks = self.forward_masks.view(-1, self.forward_masks.shape[-1]) - backward_masks = self.backward_masks.view(-1, self.backward_masks.shape[-1]) conditions = ( self.conditions.view(-1, self.conditions.shape[-1]) if self.conditions is not None else None ) - return self.__class__( - states, forward_masks, backward_masks, conditions, debug=self.debug - ) + return self.__class__(states, conditions, debug=self.debug) def extend(self, other: DiscreteStates) -> None: """Concatenates another DiscreteStates object along the batch dimension. @@ -818,13 +859,6 @@ def extend(self, other: DiscreteStates) -> None: ) self.tensor = torch.cat((self.tensor, other.tensor), dim=1) - self.forward_masks = torch.cat( - (self.forward_masks, other.forward_masks), dim=len(self.batch_shape) - 1 - ) - self.backward_masks = torch.cat( - (self.backward_masks, other.backward_masks), dim=len(self.batch_shape) - 1 - ) - if self.conditions is not None and other.conditions is not None: self.conditions = torch.cat( (self.conditions, other.conditions), dim=len(self.batch_shape) - 1 @@ -837,34 +871,23 @@ def extend(self, other: DiscreteStates) -> None: ) self.conditions = None + # Invalidate masks cache since underlying tensor has changed + self._invalidate_masks_cache() + def pad_dim0_with_sf(self, required_first_dim: int) -> None: - r"""Extends forward and backward masks along the first batch dimension. + r"""Extends states along the first batch dimension with sink states. - After extending the state along the first batch dimensions with $s_f$ by - `required_first_dim`, also extends both forward and backward masks with ones - along the first dimension by `required_first_dim`. + Given a batch of states (i.e. of `batch_shape=(a, b)`), extends `a` to a + DiscreteStates object of `batch_shape = (required_first_dim, b)`, by adding the + required number of $s_f$ tensors. This is useful to extend trajectories of + different lengths. Args: required_first_dim: The size of the first batch dimension post-expansion. """ super().pad_dim0_with_sf(required_first_dim) - - def _extend(masks, first_dim): - return torch.cat( - ( - masks, - torch.ones( - first_dim - masks.shape[0], - *masks.shape[1:], - dtype=torch.bool, - device=self.device, - ), - ), - dim=0, - ) - - self.forward_masks = _extend(self.forward_masks, required_first_dim) - self.backward_masks = _extend(self.backward_masks, required_first_dim) + # Invalidate masks cache since underlying tensor has changed + self._invalidate_masks_cache() @classmethod def stack(cls, states: Sequence[DiscreteStates]) -> DiscreteStates: @@ -874,112 +897,18 @@ def stack(cls, states: Sequence[DiscreteStates]) -> DiscreteStates: states: List of DiscreteStates objects to stack. 
Returns: - A new DiscreteStates object with the stacked states, masks, and conditions. + A new DiscreteStates object with the stacked states and conditions. + Masks are computed on demand for the stacked states. """ out = super().stack(states) # Note: conditions are already stacked by parent class assert isinstance(out, DiscreteStates) - out.forward_masks = torch.stack([s.forward_masks for s in states], dim=0).to( - out.device - ) - out.backward_masks = torch.stack([s.backward_masks for s in states], dim=0).to( - out.device - ) return out - # The helper methods are convenience functions for common mask operations. - def set_nonexit_action_masks( - self, - cond: torch.Tensor, - allow_exit: bool, - ) -> None: - """Masks denoting disallowed actions according to cond, appending the exit mask. - - A convenience function for common mask operations. - - Args: - cond: a boolean of shape (*batch_shape,) + (n_actions - 1,), which - denotes which actions are *not* allowed. For example, if a state element - represents action count, and no action can be repeated more than 5 - times, cond might be state.tensor > 5 (assuming count starts at 0). - allow_exit: sets whether exiting can happen at any point in the - trajectory - if so, it should be set to True. - - Notes: - - Always resets `forward_masks` to all True before applying the new mask - so updates do not leak across steps. - - Works for 1D or 2D batch shapes; cond must match `batch_shape`. - - Debug guards validate shape/dtype but should be off in compiled regions. - """ - if self.debug: - # Validate mask shape/dtype to catch silent misalignment during testing. - expected_shape = self.batch_shape + (self.n_actions - 1,) - if cond.shape != expected_shape: - raise ValueError( - f"cond must have shape {expected_shape}; got {cond.shape}" - ) - if cond.dtype is not torch.bool: - raise ValueError(f"cond must be boolean; got {cond.dtype}") - - # Resets masks in place to prevent side-effects across steps. - self.forward_masks[:] = True - exit_mask = torch.zeros( - self.batch_shape + (1,), device=cond.device, dtype=cond.dtype - ) - - if not allow_exit: - exit_mask.fill_(True) - - # Concatenate and mask in a single tensor op to stay torch.compile friendly. - # Sets the forward mask to be False where this concatenated mask is True. - self.forward_masks[torch.cat([cond, exit_mask], dim=-1)] = False - - def set_exit_masks(self, batch_idx: torch.Tensor) -> None: - """Sets forward masks such that the only allowable next action is to exit. - - A convenience function for common mask operations. - - Args: - batch_idx: A boolean index along the batch dimension, along which to - enforce exits. - - Notes: - - Works for 1D or 2D batch shapes; `batch_idx` must match `batch_shape`. - - Clears all actions for the selected batch entries, then sets only the - exit action True via masked_fill to stay torch.compile friendly. - - Does not move devices; expects masks/tensors already on the target device. - """ - if self.debug: - if batch_idx.shape != self.batch_shape: - raise ValueError( - f"batch_idx must have shape {self.batch_shape}; got {batch_idx.shape}" - ) - if batch_idx.dtype is not torch.bool: - raise ValueError(f"batch_idx must be boolean; got {batch_idx.dtype}") - - # Avoid Python .item() to stay torch.compile friendly. For any True entry in - # batch_idx (1D or 2D), zero all actions then set only the exit action True. - self.forward_masks[batch_idx] = False - # Use masked_fill on the last action slice to avoid advanced indexing graph breaks. 
- self.forward_masks[..., -1].masked_fill_(batch_idx, True) - - def init_forward_masks(self, set_ones: bool = True) -> None: - """Initalizes forward masks. - - A convienience function for common mask operations. - - Args: - set_ones: if True, forward masks are initalized to all ones. Otherwise, - they are initalized to all zeros. - """ - shape = self.batch_shape + (self.n_actions,) - if set_ones: - self.forward_masks = torch.ones(shape).to(self.device).bool() - else: - self.forward_masks = torch.zeros(shape).to(self.device).bool() - def to(self, device: torch.device) -> DiscreteStates: - """Moves the tensor and masks to the specified device in-place. + """Moves the tensor to the specified device in-place. + + Masks will be recomputed on the new device when accessed. Args: device: The device to move to. @@ -988,10 +917,10 @@ def to(self, device: torch.device) -> DiscreteStates: The DiscreteStates object on the specified device. """ self.tensor = self.tensor.to(device) - self.forward_masks = self.forward_masks.to(device) - self.backward_masks = self.backward_masks.to(device) if self.conditions is not None: self.conditions = self.conditions.to(device) + # Invalidate masks cache; they will be recomputed on the new device when accessed + self._invalidate_masks_cache() return self diff --git a/testing/test_states.py b/testing/test_states.py index dd1c5765..9d12d230 100644 --- a/testing/test_states.py +++ b/testing/test_states.py @@ -34,6 +34,22 @@ class SimpleDiscreteStates(DiscreteStates): s0 = torch.tensor([0.0, 0.0]) sf = torch.tensor([1.0, 1.0]) + def _compute_forward_masks(self) -> torch.Tensor: + """All forward actions allowed.""" + return torch.ones( + (*self.batch_shape, self.n_actions), + dtype=torch.bool, + device=self.device, + ) + + def _compute_backward_masks(self) -> torch.Tensor: + """All backward actions allowed.""" + return torch.ones( + (*self.batch_shape, self.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + class SimpleTensorStates(States): state_shape = (2,) # 2-dimensional state @@ -79,10 +95,7 @@ def simple_discrete_state(): """Creates a simple discrete state with 3 possible actions""" # Create a single state tensor tensor = torch.tensor([[0.5, 0.5]]) - forward_masks = torch.tensor([[True, True, True]]) # All actions allowed - backward_masks = torch.tensor([[True, True]]) # All backward actions allowed - - return SimpleDiscreteStates(tensor, forward_masks, backward_masks) + return SimpleDiscreteStates(tensor) @pytest.fixture @@ -90,10 +103,7 @@ def empty_discrete_state(): """Creates an empty discrete state""" # Create an empty state tensor tensor = torch.zeros((0, 2)) - forward_masks = torch.zeros((0, 3), dtype=torch.bool) - backward_masks = torch.zeros((0, 2), dtype=torch.bool) - - return SimpleDiscreteStates(tensor, forward_masks, backward_masks) + return SimpleDiscreteStates(tensor) @pytest.fixture From c4ce384d6c2d08a8921b9668afd3fa1109fe1b50 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 30 Jan 2026 17:22:12 +0100 Subject: [PATCH 2/7] restore mask util functions --- src/gfn/gym/bitSequence.py | 9 ++-- src/gfn/states.py | 108 ++++++++++++++++++++++++++++++++++--- testing/test_states.py | 6 +-- 3 files changed, 108 insertions(+), 15 deletions(-) diff --git a/src/gfn/gym/bitSequence.py b/src/gfn/gym/bitSequence.py index be8a6d6a..60ae070b 100644 --- a/src/gfn/gym/bitSequence.py +++ b/src/gfn/gym/bitSequence.py @@ -96,7 +96,9 @@ def _compute_backward_masks(self) -> torch.Tensor: last_actions = flat_tensor[valid_indices, valid_lengths - 1] # Set the 
backward mask for each state flat_backward_masks[valid_indices, last_actions] = True - backward_masks = flat_backward_masks.view(*self.batch_shape, self.n_actions - 1) + backward_masks = flat_backward_masks.view( + *self.batch_shape, self.n_actions - 1 + ) return backward_masks def clone(self) -> BitSequenceStates: @@ -827,15 +829,14 @@ def _compute_backward_masks(self) -> torch.Tensor: # Last action (remove from end) last_actions = valid_tensors[ torch.arange(len(valid_lengths), device=self.device), - valid_lengths - 1 + valid_lengths - 1, ] flat_backward_masks[valid_indices, last_actions] = True # First action (remove from front) - shifted by (n_actions-1)//2 first_actions = valid_tensors[:, 0] flat_backward_masks[ - valid_indices, - first_actions + (env.n_actions - 1) // 2 + valid_indices, first_actions + (env.n_actions - 1) // 2 ] = True backward_masks = flat_backward_masks.view( diff --git a/src/gfn/states.py b/src/gfn/states.py index a18865ad..a5425f3d 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -738,25 +738,29 @@ def backward_masks(self, value: torch.Tensor) -> None: def _compute_forward_masks(self) -> torch.Tensor: """Computes forward action masks for the current states. - Must be implemented by subclasses to define environment-specific mask logic. + By default, all forward actions are allowed. + Typically, this method should be overridden by subclasses to define environment-specific mask logic. Returns: Boolean tensor of shape (*batch_shape, n_actions). """ - raise NotImplementedError( - f"{self.__class__.__name__} must implement _compute_forward_masks" + return torch.ones( + self.batch_shape + (self.n_actions,), dtype=torch.bool, device=self.device ) def _compute_backward_masks(self) -> torch.Tensor: """Computes backward action masks for the current states. - Must be implemented by subclasses to define environment-specific mask logic. + By default, all backward actions are allowed. + Typically, this method should be overridden by subclasses to define environment-specific mask logic. Returns: Boolean tensor of shape (*batch_shape, n_actions - 1). """ - raise NotImplementedError( - f"{self.__class__.__name__} must implement _compute_backward_masks" + return torch.ones( + self.batch_shape + (self.n_actions - 1,), + dtype=torch.bool, + device=self.device, ) def clone(self) -> DiscreteStates: @@ -905,6 +909,98 @@ def stack(cls, states: Sequence[DiscreteStates]) -> DiscreteStates: assert isinstance(out, DiscreteStates) return out + # The helper methods are convenience functions for common mask operations. + + def set_nonexit_action_masks( + self, + cond: torch.Tensor, + allow_exit: bool, + ) -> None: + """Masks denoting disallowed actions according to cond, appending the exit mask. + + A convenience function for common mask operations. + + Args: + cond: a boolean of shape (*batch_shape,) + (n_actions - 1,), which + denotes which actions are *not* allowed. For example, if a state element + represents action count, and no action can be repeated more than 5 + times, cond might be state.tensor > 5 (assuming count starts at 0). + allow_exit: sets whether exiting can happen at any point in the + trajectory - if so, it should be set to True. + + Notes: + - Always resets `forward_masks` to all True before applying the new mask + so updates do not leak across steps. + - Works for 1D or 2D batch shapes; cond must match `batch_shape`. + - Debug guards validate shape/dtype but should be off in compiled regions. 
+ """ + if self.debug: + # Validate mask shape/dtype to catch silent misalignment during testing. + expected_shape = self.batch_shape + (self.n_actions - 1,) + if cond.shape != expected_shape: + raise ValueError( + f"cond must have shape {expected_shape}; got {cond.shape}" + ) + if cond.dtype is not torch.bool: + raise ValueError(f"cond must be boolean; got {cond.dtype}") + + # Resets masks in place to prevent side-effects across steps. + self.forward_masks[:] = True + exit_mask = torch.zeros( + self.batch_shape + (1,), device=cond.device, dtype=cond.dtype + ) + + if not allow_exit: + exit_mask.fill_(True) + + # Concatenate and mask in a single tensor op to stay torch.compile friendly. + # Sets the forward mask to be False where this concatenated mask is True. + self.forward_masks[torch.cat([cond, exit_mask], dim=-1)] = False + + def set_exit_masks(self, batch_idx: torch.Tensor) -> None: + """Sets forward masks such that the only allowable next action is to exit. + + A convenience function for common mask operations. + + Args: + batch_idx: A boolean index along the batch dimension, along which to + enforce exits. + + Notes: + - Works for 1D or 2D batch shapes; `batch_idx` must match `batch_shape`. + - Clears all actions for the selected batch entries, then sets only the + exit action True via masked_fill to stay torch.compile friendly. + - Does not move devices; expects masks/tensors already on the target device. + """ + if self.debug: + if batch_idx.shape != self.batch_shape: + raise ValueError( + f"batch_idx must have shape {self.batch_shape}; got {batch_idx.shape}" + ) + if batch_idx.dtype is not torch.bool: + raise ValueError(f"batch_idx must be boolean; got {batch_idx.dtype}") + + # Avoid Python .item() to stay torch.compile friendly. For any True entry in + # batch_idx (1D or 2D), zero all actions then set only the exit action True. + self.forward_masks[batch_idx] = False + # Use masked_fill on the last action slice to avoid advanced indexing graph breaks. + self.forward_masks[..., -1].masked_fill_(batch_idx, True) + + def init_forward_masks(self, set_ones: bool = True) -> None: + """Initalizes forward masks. + + A convienience function for common mask operations. + + Args: + set_ones: if True, forward masks are initalized to all ones. Otherwise, + they are initalized to all zeros. + """ + shape = self.batch_shape + (self.n_actions,) + if set_ones: + self.forward_masks = torch.ones(shape).to(self.device).bool() + else: + self.forward_masks = torch.zeros(shape).to(self.device).bool() + def to(self, device: torch.device) -> DiscreteStates: """Moves the tensor to the specified device in-place. 
diff --git a/testing/test_states.py b/testing/test_states.py index 9d12d230..e28e38be 100644 --- a/testing/test_states.py +++ b/testing/test_states.py @@ -755,11 +755,7 @@ class SimpleDiscreteStates(DiscreteStates): sf = torch.ones(2) tensor = torch.zeros(batch_shape + SimpleDiscreteStates.state_shape) - fm = torch.ones(batch_shape + (SimpleDiscreteStates.n_actions,), dtype=torch.bool) - bm = torch.ones( - batch_shape + (SimpleDiscreteStates.n_actions - 1,), dtype=torch.bool - ) - return SimpleDiscreteStates(tensor, fm, bm, debug=True) + return SimpleDiscreteStates(tensor, debug=True) def test_set_nonexit_action_masks_resets_each_call_1d(): From 9a7b6e5d223de626090e02991502a8b97c4cc009 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 30 Jan 2026 17:36:07 +0100 Subject: [PATCH 3/7] fix test --- testing/test_estimators.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/testing/test_estimators.py b/testing/test_estimators.py index ca28c08a..cd66c04d 100644 --- a/testing/test_estimators.py +++ b/testing/test_estimators.py @@ -363,12 +363,9 @@ def test_discrete_policy_estimator_integration(): # Create test states batch_size = 8 states_tensor = torch.randint(0, env.height, (batch_size, env.ndim)) - forward_masks = torch.ones(batch_size, env.n_actions, dtype=torch.bool) - backward_masks = torch.ones(batch_size, env.n_actions - 1, dtype=torch.bool) # Create states using environment's States class - states = env.States(states_tensor, forward_masks, backward_masks) - env.update_masks(states) + states = env.States(states_tensor) # Test different parameter combinations test_params = [ From 93b1c10d9acf73c7598354c0d7c3dd30baf82bfb Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 30 Jan 2026 17:47:22 +0100 Subject: [PATCH 4/7] fix typing --- src/gfn/utils/distributed.py | 11 ++++++----- testing/test_states.py | 13 ++----------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py index 08bbf9b2..276105b5 100644 --- a/src/gfn/utils/distributed.py +++ b/src/gfn/utils/distributed.py @@ -255,10 +255,11 @@ def initialize_distributed_compute( training_ranks = [ r for r in range(num_training_ranks) ] # e.g., 0..num_training_ranks-1 - train_global_group = dist.new_group( + train_global_group = cast(dist.ProcessGroup, dist.new_group( ranks=training_ranks, - backend=dist_backend, - timeout=datetime.timedelta(minutes=5), + backend=dist_backend, + timeout=datetime.timedelta(minutes=5), + ), ) buffer_group = None @@ -268,11 +269,11 @@ def initialize_distributed_compute( buffer_ranks = list( range(num_training_ranks, num_training_ranks + num_remote_buffers) ) - buffer_group = dist.new_group( + buffer_group = cast(dist.ProcessGroup, dist.new_group( buffer_ranks, backend=dist_backend, timeout=datetime.timedelta(minutes=5), - ) + )) logger.info("Buffer group ranks: %s", buffer_ranks) # Each training rank gets assigned to a buffer rank diff --git a/testing/test_states.py b/testing/test_states.py index e28e38be..d3361752 100644 --- a/testing/test_states.py +++ b/testing/test_states.py @@ -901,13 +901,7 @@ class NoDebugDiscreteStates(DiscreteStates): @classmethod def make_random_states(cls, batch_shape, device=None): t = torch.zeros(batch_shape + cls.state_shape, device=device) - fm = torch.ones( - batch_shape + (cls.n_actions,), dtype=torch.bool, device=device - ) - bm = torch.ones( - batch_shape + (cls.n_actions - 1,), dtype=torch.bool, device=device - ) - return cls(t, fm, bm) + return cls(t) with 
pytest.raises(TypeError, match="must accept a `debug`"): NoDebugDiscreteStates.from_batch_shape((2,), random=True, debug=True) @@ -1004,11 +998,8 @@ class SimpleDiscreteStates(DiscreteStates): sf = torch.tensor([1.0, 1.0]) tensor = torch.tensor([[0.5, 0.5]], device=torch.device("cuda")) - # Provide masks on CUDA as well - forward_masks = torch.tensor([[True, True, True]], device=torch.device("cuda")) - backward_masks = torch.tensor([[True, True]], device=torch.device("cuda")) - state = SimpleDiscreteStates(tensor, forward_masks, backward_masks) + state = SimpleDiscreteStates(tensor) assert state.device.type == "cuda" assert state.forward_masks.device.type == "cuda" assert state.backward_masks.device.type == "cuda" From 0e2247e34ce9f03569a2b5206ac5fe9f3bde000b Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 30 Jan 2026 17:53:21 +0100 Subject: [PATCH 5/7] fix test state init --- testing/test_states.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/testing/test_states.py b/testing/test_states.py index d3361752..107d57d9 100644 --- a/testing/test_states.py +++ b/testing/test_states.py @@ -1207,11 +1207,7 @@ class SimpleDiscreteStates(DiscreteStates): sf = torch.tensor([1.0, 1.0]) A = SimpleDiscreteStates(torch.tensor([[0.5, 0.5]], device=cpu)) - B = SimpleDiscreteStates( - torch.tensor([[0.1, 0.2]], device=cuda), - torch.ones((1, 3), dtype=torch.bool, device=cuda), - torch.ones((1, 2), dtype=torch.bool, device=cuda), - ) + B = SimpleDiscreteStates(torch.tensor([[0.1, 0.2]], device=cuda)) assert A.device.type == "cpu" assert B.device.type == "cuda" # Mask devices are consistent with instance devices From 529d424ea3235190c9379d2368e3f3705c81cfa0 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 30 Jan 2026 18:53:37 +0100 Subject: [PATCH 6/7] fix tutorial --- src/gfn/utils/distributed.py | 19 +++--- tutorials/notebooks/intro_discrete.ipynb | 76 ++++++++++++++---------- 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py index 276105b5..7e109775 100644 --- a/src/gfn/utils/distributed.py +++ b/src/gfn/utils/distributed.py @@ -255,8 +255,10 @@ def initialize_distributed_compute( training_ranks = [ r for r in range(num_training_ranks) ] # e.g., 0..num_training_ranks-1 - train_global_group = cast(dist.ProcessGroup, dist.new_group( - ranks=training_ranks, + train_global_group = cast( + dist.ProcessGroup, + dist.new_group( + ranks=training_ranks, backend=dist_backend, timeout=datetime.timedelta(minutes=5), ), @@ -269,11 +271,14 @@ def initialize_distributed_compute( buffer_ranks = list( range(num_training_ranks, num_training_ranks + num_remote_buffers) ) - buffer_group = cast(dist.ProcessGroup, dist.new_group( - buffer_ranks, - backend=dist_backend, - timeout=datetime.timedelta(minutes=5), - )) + buffer_group = cast( + dist.ProcessGroup, + dist.new_group( + buffer_ranks, + backend=dist_backend, + timeout=datetime.timedelta(minutes=5), + ), + ) logger.info("Buffer group ranks: %s", buffer_ranks) # Each training rank gets assigned to a buffer rank diff --git a/tutorials/notebooks/intro_discrete.ipynb b/tutorials/notebooks/intro_discrete.ipynb index 784b5ea6..f5842b0a 100644 --- a/tutorials/notebooks/intro_discrete.ipynb +++ b/tutorials/notebooks/intro_discrete.ipynb @@ -1657,7 +1657,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": { "id": "92RxW4V7aLk7" }, @@ -1701,38 +1701,48 @@ " # Sf represents when a trajectory is done (we selected the exit 
action).\n", " sf=torch.ones(state_dim, dtype=torch.float, device=device) * -1,\n", " )\n", - " self.States: type[DiscreteStates] = self.States\n", - "\n", - " def update_masks(self, states: type[DiscreteStates]) -> None:\n", - " \"Update the masks based on the current states.\"\n", - " # Backward masks are simply any action we've already taken.\n", - " states.backward_masks = states.tensor != 0 # n - 1 actions.\n", - "\n", - " # Forward masks begin as allowing any action. Allowed actions are 1.\n", - " states.init_forward_masks(set_ones=True)\n", - "\n", - " # Then, we remove any done action, and also the exit action.\n", - " states.set_nonexit_action_masks(states.tensor == 1, allow_exit=False)\n", - "\n", - " if self.mask_invalid_actions:\n", - " # Now we remove invalid actions. Here we are enforcing that\n", - " # only one left eyebrow, one right eyebrow, and one smile can be\n", - " # selected. 0 = not allowed.\n", - " #invalid_actions = torch.ones(states.batch_shape + states.state_shape).bool()\n", - " invalid_actions = torch.ones(states.forward_masks.shape).bool()\n", - " invalid_actions[..., 0][states.tensor[..., 1].bool()] = 0 # l_eb\n", - " invalid_actions[..., 1][states.tensor[..., 0].bool()] = 0 # l_eb\n", - " invalid_actions[..., 2][states.tensor[..., 3].bool()] = 0 # r_eb\n", - " invalid_actions[..., 3][states.tensor[..., 2].bool()] = 0 # r_eb\n", - " invalid_actions[..., 4][states.tensor[..., 5].bool()] = 0 # smile\n", - " invalid_actions[..., 5][states.tensor[..., 4].bool()] = 0 # smile\n", - "\n", - " states.forward_masks = (states.forward_masks * invalid_actions)\n", - "\n", - " # Trajectories must be length 3. Any trajectory that has taken 3 actions\n", - " # should be forced to exit.\n", - " batch_idx = states.tensor.sum(-1) >= 3\n", - " states.set_exit_masks(batch_idx)\n", + "\n", + " def make_states_class(self) -> type[DiscreteStates]:\n", + " base_class = super().make_states_class()\n", + " env = self\n", + " class SmilesStates(base_class):\n", + " def _compute_backward_masks(self) -> torch.Tensor:\n", + " return self.tensor != 0\n", + "\n", + " def _compute_forward_masks(self) -> torch.Tensor:\n", + " # Forward masks begin as allowing any action. Allowed actions are 1.\n", + " forward_masks = torch.ones(self.batch_shape + (env.n_actions,), dtype=torch.bool, device=self.device)\n", + "\n", + " # Then, we remove any done action, and also the exit action.\n", + " done_or_exit = torch.zeros_like(forward_masks)\n", + " done_or_exit[..., :-1] = self.tensor == 1 # done\n", + " done_or_exit[..., -1] = 1 # exit\n", + " forward_masks[done_or_exit] = False\n", + "\n", + " if env.mask_invalid_actions:\n", + " # Now we remove invalid actions. Here we are enforcing that\n", + " # only one left eyebrow, one right eyebrow, and one smile can be\n", + " # selected. 0 = not allowed.\n", + " #invalid_actions = torch.ones(states.batch_shape + states.state_shape).bool()\n", + " invalid_actions = torch.ones(forward_masks.shape).bool()\n", + " invalid_actions[..., 0][self.tensor[..., 1].bool()] = 0 # l_eb\n", + " invalid_actions[..., 1][self.tensor[..., 0].bool()] = 0 # l_eb\n", + " invalid_actions[..., 2][self.tensor[..., 3].bool()] = 0 # r_eb\n", + " invalid_actions[..., 3][self.tensor[..., 2].bool()] = 0 # r_eb\n", + " invalid_actions[..., 4][self.tensor[..., 5].bool()] = 0 # smile\n", + " invalid_actions[..., 5][self.tensor[..., 4].bool()] = 0 # smile\n", + "\n", + " forward_masks = forward_masks * invalid_actions\n", + "\n", + " # Trajectories must be length 3. 
Any trajectory that has taken 3 actions\n", + " # should be forced to exit.\n", + " batch_idx = self.tensor.sum(-1) >= 3\n", + " forward_masks[batch_idx] = False\n", + " forward_masks[..., -1].masked_fill_(batch_idx, True) \n", + " return forward_masks\n", + "\n", + " return SmilesStates\n", + "\n", "\n", " def step(\n", " self, states: DiscreteStates, actions: Actions\n", From a2331339f577666be28b2656d607488969e717f7 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 6 Feb 2026 14:59:24 +0100 Subject: [PATCH 7/7] fix test script --- tutorials/examples/test_scripts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 217d8546..cd53ab37 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -90,7 +90,10 @@ class HypergridArgs(CommonArgs): back_ratio: float = 0.5 distributed: bool = False diverse_replay_buffer: bool = False - epsilon: float = 0.1 + epsilon: float = 0.0 + temperature: float = 1.0 + n_noisy_layers: int = 0 + noisy_std_init: float = 0.5 height: int = 8 loss: str = "TB" lr_logz: float = 1e-3
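
Note (illustration, not part of the patch): the series replaces `env.update_masks(states)` with masks that the states class computes and caches itself, via the `_compute_forward_masks` / `_compute_backward_masks` hooks shown in the notebook hunk above. The snippet below is a standalone sketch of that compute-and-cache pattern only; it does not import torchgfn, and every class name, shape, and threshold in it is illustrative rather than the library's actual API.

    # Standalone sketch of lazily computed, cached action masks.
    # ToyDiscreteStates and its attributes are hypothetical; only the hook
    # names mirror the ones introduced by this series.
    import torch

    class ToyDiscreteStates:
        n_actions = 3      # two increment actions + an exit action (illustrative)
        state_shape = (2,)

        def __init__(self, tensor: torch.Tensor) -> None:
            self.tensor = tensor
            self._forward_masks = None   # computed on first access, then cached
            self._backward_masks = None

        @property
        def forward_masks(self) -> torch.Tensor:
            if self._forward_masks is None:
                self._forward_masks = self._compute_forward_masks()
            return self._forward_masks

        @property
        def backward_masks(self) -> torch.Tensor:
            if self._backward_masks is None:
                self._backward_masks = self._compute_backward_masks()
            return self._backward_masks

        def _compute_forward_masks(self) -> torch.Tensor:
            # Allow an increment only while the coordinate is below a cap;
            # the last (exit) action stays allowed.
            masks = torch.ones(
                self.tensor.shape[:-1] + (self.n_actions,), dtype=torch.bool
            )
            masks[..., :-1] = self.tensor < 4
            return masks

        def _compute_backward_masks(self) -> torch.Tensor:
            # A backward action is valid only where the coordinate was incremented.
            return self.tensor > 0

    states = ToyDiscreteStates(torch.tensor([[0, 4], [2, 1]]))
    print(states.forward_masks)   # computed lazily on this first access ...
    print(states.forward_masks)   # ... and served from the cache afterwards

In the library itself, the place to attach such hooks is the environment's `make_states_class()` override, as the `SmilesStates` example in the tutorial hunk above demonstrates; environments no longer need an `update_masks` method, which is why the estimator and states tests now construct states from a tensor alone.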