Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def __init__(
def states_from_tensor(
self, tensor: torch.Tensor, conditions: torch.Tensor | None = None
) -> DiscreteStates:
"""Wraps the supplied tensor in a DiscreteStates instance and updates masks.
"""Wraps the supplied tensor in a DiscreteStates instance.

Args:
tensor: Tensor of shape (*batch_shape, *state_shape) representing the states.
Expand All @@ -617,7 +617,6 @@ def states_from_tensor(
DiscreteStates,
self.States(tensor=tensor, conditions=conditions, debug=self.debug),
)
self.update_masks(states_instance)
return states_instance

def states_from_batch_shape(
Expand Down Expand Up @@ -649,7 +648,7 @@ def reset(
seed: Optional[int] = None,
conditions: Optional[torch.Tensor] = None,
) -> DiscreteStates:
"""Instantiates a batch of random, initial, or sink states and updates masks.
"""Instantiates a batch of random, initial, or sink states.

Args:
batch_shape: Shape of the batch (int or tuple).
Expand All @@ -664,19 +663,8 @@ def reset(
"""
states = super().reset(batch_shape, random, sink, seed, conditions=conditions)
states = cast(DiscreteStates, states)
self.update_masks(states)
return states

@abstractmethod
def update_masks(self, states: DiscreteStates) -> None:
"""Updates the masks in DiscreteStates.

Called automatically after each step for discrete environments.

Args:
states: The DiscreteStates object whose masks will be updated.
"""

def make_states_class(self) -> type[DiscreteStates]:
"""Returns the DiscreteStates class for this environment.

Expand Down Expand Up @@ -736,8 +724,6 @@ def is_action_valid(
def _step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates:
"""Wrapper for the user-defined `step` function.

This calls the `_step` method of the parent class and updates masks.

Args:
states: The batch of discrete states.
actions: The batch of actions.
Expand All @@ -747,14 +733,11 @@ def _step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates:
"""
new_states = super()._step(states, actions)
new_states = cast(DiscreteStates, new_states)
self.update_masks(new_states)
return new_states

def _backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates:
"""Wrapper for the user-defined `backward_step` function.

This calls the `_backward_step` method of the parent class and updates masks.

Args:
states: The batch of discrete states.
actions: The batch of actions.
Expand All @@ -764,7 +747,6 @@ def _backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSt
"""
new_states = super()._backward_step(states, actions)
new_states = cast(DiscreteStates, new_states)
self.update_masks(new_states)
return new_states

def get_terminating_state_dist(self, states: DiscreteStates) -> torch.Tensor:
Expand Down
Loading