diff --git a/pyproject.toml b/pyproject.toml index 730d9858..568397ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ einops = ">=0.6.1" numpy = ">=1.21.2" python = "^3.10" torch = ">=2.6.0" +tensordict = ">=0.6.1" +torch_geometric = ">=2.6.1" # dev dependencies. black = { version = "24.3", optional = true } @@ -99,6 +101,7 @@ include = '\.pyi?$' extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g''' [tool.pyright] +pythonVersion = "3.10" include = ["src/gfn", "tutorials/examples", "testing"] # Removed ** globstars exclude = [ "**/node_modules", @@ -106,10 +109,14 @@ exclude = [ "**/.*", # Exclude dot files and folders ] +strict = [ + +] +# This is required as the CI pre-commit does not dl the module (i.e. numpy) +# Therefore, we have to ignore missing imports # Removed "strict": [], as it's redundant with typeCheckingMode typeCheckingMode = "basic" -pythonVersion = "3.10" # Removed enableTypeIgnoreComments, not available in pyproject.toml, and bad practice. @@ -124,6 +131,9 @@ reportUntypedFunctionDecorator = "none" reportMissingTypeStubs = false reportUnboundVariable = "warning" reportGeneralTypeIssues = "none" +reportAttributeAccessIssue = false + +[tool.pytest.ini_options] reportOptionalMemberAccess = "error" reportArgumentType = "error" #This setting doesn't exist, removed. diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 0f1007a2..a370e01a 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -1,10 +1,12 @@ from __future__ import annotations # This allows to use the class name in type hints +import enum from abc import ABC from math import prod from typing import ClassVar, List, Sequence import torch +from tensordict import TensorDict class Actions(ABC): @@ -134,7 +136,7 @@ def extend_with_dummy_actions(self, required_first_dim: int) -> None: "extend_with_dummy_actions is only implemented for bi-dimensional actions." 
) - def compare(self, other: torch.Tensor) -> torch.Tensor: + def _compare(self, other: torch.Tensor) -> torch.Tensor: """Compares the actions to a tensor of actions. Args: @@ -161,7 +163,7 @@ def is_dummy(self) -> torch.Tensor: dummy_actions_tensor = self.__class__.dummy_action.repeat( *self.batch_shape, *((1,) * len(self.__class__.action_shape)) ) - return self.compare(dummy_actions_tensor) + return self._compare(dummy_actions_tensor) @property def is_exit(self) -> torch.Tensor: @@ -169,4 +171,142 @@ def is_exit(self) -> torch.Tensor: exit_actions_tensor = self.__class__.exit_action.repeat( *self.batch_shape, *((1,) * len(self.__class__.action_shape)) ) - return self.compare(exit_actions_tensor) + return self._compare(exit_actions_tensor) + + +class GraphActionType(enum.IntEnum): + ADD_NODE = 0 + ADD_EDGE = 1 + EXIT = 2 + DUMMY = 3 + + +class GraphActions(Actions): + """Actions for graph-based environments. + + Each action is one of: + - ADD_NODE: Add a node with given features + - ADD_EDGE: Add an edge between two nodes with given features + - EXIT: Terminate the trajectory + + Attributes: + features_dim: Dimension of node/edge features + tensor: TensorDict containing: + - action_type: Type of action (GraphActionType) + - features: Features for nodes/edges + - edge_index: Source/target nodes for edges + """ + + features_dim: ClassVar[int] + + def __init__(self, tensor: TensorDict): + """Initializes a GraphAction object. + + Args: + action: a GraphActionType indicating the type of action. + features: a tensor of shape (batch_shape, feature_shape) representing the features + of the nodes or of the edges, depending on the action type. In case of EXIT + action, this can be None. + edge_index: a tensor of shape (batch_shape, 2) representing the edge to add. + This must be defined if and only if the action type is GraphActionType.ADD_EDGE. 
+ """ + self.batch_shape = tensor["action_type"].shape + features = tensor.get("features", None) + if features is None: + assert torch.all( + torch.logical_or( + tensor["action_type"] == GraphActionType.EXIT, + tensor["action_type"] == GraphActionType.DUMMY, + ) + ) + features = torch.zeros((*self.batch_shape, self.features_dim)) + edge_index = tensor.get("edge_index", None) + if edge_index is None: + assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE) + edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) + + self.tensor = TensorDict( + { + "action_type": tensor["action_type"], + "features": features, + "edge_index": edge_index, + }, + batch_size=self.batch_shape, + ) + + def __repr__(self): + return f"""GraphAction object with {self.batch_shape} actions.""" + + def _compare(self, other: GraphActions) -> torch.Tensor: + """Compares the actions to another GraphAction object. + + Args: + other: GraphAction object to compare. + + Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. 
+ """ + action_compare = torch.all( + self.tensor["action_type"] == other.tensor["action_type"] + ) + exit_compare = ( + torch.all(self.tensor["features"] == other.tensor["features"]) + | action_compare + == GraphActionType.EXIT + ) + edge_compare = (action_compare != GraphActionType.ADD_EDGE) | ( + torch.all(self.tensor["edge_index"] == other.tensor["edge_index"]) + ) + return action_compare & exit_compare & edge_compare + + @property + def is_exit(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" + return self.action_type == GraphActionType.EXIT + + @property + def is_dummy(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions.""" + return self.action_type == GraphActionType.DUMMY + + @property + def action_type(self) -> torch.Tensor: + """Returns the action type tensor.""" + return self.tensor["action_type"] + + @property + def features(self) -> torch.Tensor: + """Returns the features tensor.""" + return self.tensor["features"] + + @property + def edge_index(self) -> torch.Tensor: + """Returns the edge index tensor.""" + return self.tensor["edge_index"] + + @classmethod + def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: + """Creates a GraphActions object of dummy actions with the given batch shape.""" + return cls( + TensorDict( + { + "action_type": torch.full( + batch_shape, fill_value=GraphActionType.DUMMY + ), + }, + batch_size=batch_shape, + ) + ) + + @classmethod + def make_exit_actions(cls, batch_shape: tuple[int]) -> Actions: + """Creates an GraphActions object of exit actions with the given batch shape.""" + return cls( + TensorDict( + { + "action_type": torch.full( + batch_shape, fill_value=GraphActionType.EXIT + ), + }, + batch_size=batch_shape, + ) + ) diff --git a/src/gfn/containers/__init__.py b/src/gfn/containers/__init__.py index 60da5da0..7570c91c 100644 --- 
a/src/gfn/containers/__init__.py +++ b/src/gfn/containers/__init__.py @@ -1,11 +1,11 @@ from .base import Container -from .replay_buffer import PrioritizedReplayBuffer, ReplayBuffer +from .replay_buffer import NormBasedDiversePrioritizedReplayBuffer, ReplayBuffer from .state_pairs import StatePairs from .trajectories import Trajectories from .transitions import Transitions __all__ = [ - "PrioritizedReplayBuffer", + "NormBasedDiversePrioritizedReplayBuffer", "ReplayBuffer", "StatePairs", "Trajectories", diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 877901bc..81ac83bb 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import Generic, TypeVar, cast +from typing import Protocol, Union, cast, runtime_checkable import torch @@ -9,36 +9,47 @@ from gfn.containers.trajectories import Trajectories from gfn.containers.transitions import Transitions from gfn.env import Env -from gfn.states import DiscreteStates -ContainerType = TypeVar( - "ContainerType", Trajectories, Transitions, StatePairs[DiscreteStates] -) +@runtime_checkable +class Container(Protocol): + def __getitem__(self, idx): ... # noqa: E704 -class ReplayBuffer(Generic[ContainerType]): - """A replay buffer for GFlowNet training. + def extend(self, other): ... # noqa: E704 + + def __len__(self) -> int: ... # noqa: E704 + + @property + def log_rewards(self) -> torch.Tensor | None: ... # noqa: E704 + + @property + def last_states(self): ... # noqa: E704 - The buffer stores training objects (trajectories, transitions, or state pairs) - and provides functionality to add new objects and sample from the buffer. - When the buffer is full, new objects replace old ones in a FIFO manner. 
- """ + +ContainerUnion = Union[Trajectories, Transitions, StatePairs] +ValidContainerTypes = (Trajectories, Transitions, StatePairs) + + +class ReplayBuffer: + """A replay buffer of trajectories, transitions, or states.""" def __init__( self, env: Env, capacity: int = 1000, + prioritized: bool = False, ): - """Instantiates a replay buffer. - - Args: - env: the Environment instance. - capacity: the size of the buffer. - """ self.env = env self.capacity = capacity self._is_full = False - self.training_objects: ContainerType | None = None + self.training_objects: ContainerUnion | None = None + self.prioritized = prioritized + + def add(self, training_objects: ContainerUnion) -> None: + """Adds a training object to the buffer.""" + if not isinstance(training_objects, ValidContainerTypes): # type: ignore + raise TypeError("Must be a container type") + self._add_objs(training_objects) def __repr__(self): if self.training_objects is None: @@ -52,36 +63,49 @@ def __repr__(self): def __len__(self): return 0 if self.training_objects is None else len(self.training_objects) - def initialize(self, training_objects: ContainerType) -> None: + def initialize(self, training_objects: ContainerUnion) -> None: """Initializes the buffer with a training object.""" # Initialize with the same type as first added objects if isinstance(training_objects, Trajectories): - self.training_objects = cast(ContainerType, Trajectories(self.env)) + self.training_objects = cast(ContainerUnion, Trajectories(self.env)) elif isinstance(training_objects, Transitions): - self.training_objects = cast(ContainerType, Transitions(self.env)) + self.training_objects = cast(ContainerUnion, Transitions(self.env)) elif isinstance(training_objects, StatePairs): - self.training_objects = cast(ContainerType, StatePairs(self.env)) + self.training_objects = cast(ContainerUnion, StatePairs(self.env)) else: raise ValueError(f"Unsupported type: {type(training_objects)}") - def add(self, training_objects: ContainerType) -> 
None: - """Adds a batch of training objects to the buffer.""" + def _add_objs(self, training_objects: ContainerUnion): + """Adds a training object to the buffer.""" if self.training_objects is None: self.initialize(training_objects) assert self.training_objects is not None + assert isinstance(training_objects, type(self.training_objects)) # type: ignore - to_add = len(training_objects) - self._is_full |= len(self) + to_add >= self.capacity + # Adds the objects to the buffer. + self.training_objects.extend(training_objects) # type: ignore + + # Sort elements by log reward, capping the size at the defined capacity. + if self.prioritized: + if ( + self.training_objects.log_rewards is None + or training_objects.log_rewards is None + ): + raise ValueError("log_rewards must be defined for prioritized replay.") + + # Ascending sort. + ix = torch.argsort(self.training_objects.log_rewards) + self.training_objects = cast(ContainerUnion, self.training_objects[ix]) # type: ignore - self.training_objects.extend(training_objects) - self.training_objects = self.training_objects[-self.capacity :] + assert self.training_objects is not None + self.training_objects = cast(ContainerUnion, self.training_objects[-self.capacity :]) # type: ignore - def sample(self, n_trajectories: int) -> ContainerType: + def sample(self, n_trajectories: int) -> ContainerUnion: """Samples `n_trajectories` training objects from the buffer.""" if self.training_objects is None: raise ValueError("Buffer is empty") - return cast(ContainerType, self.training_objects.sample(n_trajectories)) + return cast(ContainerUnion, self.training_objects.sample(n_trajectories)) def save(self, directory: str): """Saves the buffer to disk.""" @@ -94,8 +118,8 @@ def load(self, directory: str): self.training_objects.load(os.path.join(directory, "training_objects")) -class PrioritizedReplayBuffer(ReplayBuffer[ContainerType]): - """A replay buffer of trajectories or transitions. 
+class NormBasedDiversePrioritizedReplayBuffer(ReplayBuffer): + """A replay buffer of trajectories or transitions with diverse trajectories. Attributes: env: the Environment instance. @@ -109,7 +133,7 @@ class PrioritizedReplayBuffer(ReplayBuffer[ContainerType]): def __init__( self, - env: "Env", + env: Env, capacity: int = 1000, cutoff_distance: float = 0.0, p_norm_distance: float = 1.0, @@ -128,30 +152,17 @@ def __init__( super().__init__(env, capacity) self.cutoff_distance = cutoff_distance self.p_norm_distance = p_norm_distance + self._prioritized = True - def _add_objs( - self, - training_objects: ContainerType, - ): + @property + def prioritized(self) -> bool: + return self._prioritized + + def add(self, training_objects: ContainerUnion): """Adds a training object to the buffer.""" - # Adds the objects to the buffer. - self.initialize(training_objects) - assert self.training_objects is not None - self.training_objects.extend(training_objects) - - # Sort elements by logreward, capping the size at the defined capacity. - assert self.training_objects.log_rewards is not None - ix = torch.argsort(self.training_objects.log_rewards) - self.training_objects = cast(ContainerType, self.training_objects[ix]) - self.training_objects = cast( - ContainerType, self.training_objects[-self.capacity :] - ) + if not isinstance(training_objects, ValidContainerTypes): # type: ignore + raise TypeError("Must be a container type") - def add( - self, - training_objects: ContainerType, - ): - """Adds a batch of training objects to the buffer.""" to_add = len(training_objects) self._is_full |= len(self) + to_add >= self.capacity @@ -166,9 +177,9 @@ def add( if log_rewards is None: raise ValueError("log_rewards must be defined for prioritized replay.") - # Sort the incoming elements by their log rewards. + # Sort the incoming elements by their logrewards. 
ix = torch.argsort(log_rewards, descending=True) - training_objects = training_objects[ix] + training_objects = cast(ContainerUnion, training_objects[ix]) # type: ignore # Filter all batch logrewards lower than the smallest logreward in buffer. assert ( @@ -181,7 +192,10 @@ training_objects = training_objects[idx_bigger_rewards] # TODO: Concatenate input with final state for conditional GFN. - # if self.is_conditional: + if self.is_conditional: + raise NotImplementedError( + f"{self.__class__.__name__} does not yet support conditional GFNs." + ) # batch = torch.cat( # [dict_curr_batch["input"], dict_curr_batch["final_state"]], # dim=-1, # ) @@ -230,7 +244,9 @@ # Filter the batch for diverse final_states w.r.t the buffer. idx_batch_buffer = batch_buffer_dist > self.cutoff_distance - training_objects = training_objects[idx_batch_buffer] + training_objects = cast( + ContainerUnion, training_objects[idx_batch_buffer] + ) # If any training object remain after filtering, add them. if len(training_objects): diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 1603bb83..c69d0348 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -33,7 +33,8 @@ class Trajectories(Container): when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends. is_backward: Whether the trajectories are backward or forward. log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories. - log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions. + log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the + trajectories' actions. """ @@ -57,7 +58,8 @@ def __init__( when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends. is_backward: Whether the trajectories are backward or forward. 
log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories. - log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions. + log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of + the trajectories' actions. estimator_outputs: Tensor of shape (batch_shape, output_dim). When forward sampling off-policy for an n-step trajectory, n forward passes will be made on some function approximator, @@ -103,25 +105,37 @@ def __init__( assert ( log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float - ) + ), f"log_probs.shape={log_probs.shape}, " + f"self.max_length={self.max_length}, " + f"self.n_trajectories={self.n_trajectories}" else: log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float) self.log_probs: torch.Tensor = log_probs self.estimator_outputs = estimator_outputs if self.estimator_outputs is not None: - # assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape TODO: check why fails + # TODO: check why this fails. + # assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape assert self.estimator_outputs.dtype == torch.float def __repr__(self) -> str: states = self.states.tensor.transpose(0, 1) assert states.ndim == 3 trajectories_representation = "" + assert isinstance( + self.env.s0, torch.Tensor + ), "not supported for Graph trajectories." + assert isinstance( + self.env.sf, torch.Tensor + ), "not supported for Graph trajectories." 
+ for traj in states[:10]: one_traj_repr = [] for step in traj: one_traj_repr.append(str(step.cpu().numpy())) - if step.equal(self.env.s0 if self.is_backward else self.env.sf): + if self.is_backward and step.equal(self.env.s0): + break + elif not self.is_backward and step.equal(self.env.sf): break trajectories_representation += "-> ".join(one_traj_repr) + "\n" return ( @@ -255,7 +269,7 @@ def extend(self, other: Trajectories) -> None: # TODO: The replay buffer is storing `dones` - this wastes a lot of space. self.actions.extend(other.actions) - self.states.extend(other.states) + self.states.extend(other.states) # n_trajectories comes from this. self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0) # For log_probs, we first need to make the first dimensions of self.log_probs @@ -482,6 +496,10 @@ def reverse_backward_trajectories(self, debug: bool = False) -> Trajectories: new_actions[torch.arange(len(self)), seq_lengths] = self.env.exit_action # Assign reversed states to new_states + assert isinstance(states[:, -1], torch.Tensor) + assert isinstance( + self.env.s0, torch.Tensor + ), "reverse_backward_trajectories not supported for Graph trajectories" assert torch.all(states[:, -1] == self.env.s0), "Last state must be s0" new_states[:, 0] = self.env.s0 new_states[:, 1:-1][mask] = states[:, :-1][mask][rev_idx[mask]] @@ -536,7 +554,7 @@ def reverse_backward_trajectories(self, debug: bool = False) -> Trajectories: def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor: - """Pads tensor a to match the dimention of b.""" + """Pads tensor a to match the dimension of b.""" assert a.shape[0] < target_dim0, "a is already larger than target_dim0!" 
pad_dim = target_dim0 - a.shape[0] pad_dim_full = (pad_dim,) + tuple(a.shape[1:]) diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index ccad21b8..8fe14f50 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -56,8 +56,8 @@ def __init__( the children of the transitions. is_backward: Whether the transitions are backward transitions (i.e. `next_states` is the parent of states). - log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a default value like - `-float('inf')` for non-terminating transitions). + log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a + default value like `-float('inf')` for non-terminating transitions). log_probs: Tensor of shape (n_transitions,) containing the log-probabilities of the actions. Raises: diff --git a/src/gfn/env.py b/src/gfn/env.py index e24906b2..82a81358 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,18 +1,20 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple, Union, cast +from typing import Optional, Tuple, cast import torch +from torch_geometric.data import Batch as GeometricBatch +from torch_geometric.data import Data as GeometricData -from gfn.actions import Actions +from gfn.actions import Actions, GraphActions from gfn.preprocessors import IdentityPreprocessor, Preprocessor -from gfn.states import DiscreteStates, States +from gfn.states import DiscreteStates, GraphStates, States from gfn.utils.common import set_seed # Errors NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) -def get_device(device_str, default_device): +def get_device(device_str, default_device) -> torch.device: return torch.device(device_str) if device_str is not None else default_device @@ -22,12 +24,12 @@ class Env(ABC): def __init__( self, - s0: torch.Tensor, + s0: torch.Tensor | GeometricData, state_shape: Tuple, action_shape: Tuple, 
dummy_action: torch.Tensor, exit_action: torch.Tensor, - sf: Optional[torch.Tensor] = None, + sf: Optional[torch.Tensor | GeometricData] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -50,16 +52,20 @@ def __init__( """ self.device = get_device(device_str, default_device=s0.device) - self.s0 = s0.to(self.device) + self.s0 = s0.to(self.device) # type: ignore assert s0.shape == state_shape + if sf is None: sf = torch.full(s0.shape, -float("inf")).to(self.device) - self.sf: torch.Tensor = sf + + self.sf = sf + assert self.sf is not None assert self.sf.shape == state_shape + self.state_shape = state_shape self.action_shape = action_shape - self.dummy_action = dummy_action - self.exit_action = exit_action + self.dummy_action = dummy_action.to(self.device) + self.exit_action = exit_action.to(self.device) # Warning: don't use self.States or self.Actions to initialize an instance of the class. # Use self.states_from_tensor or self.actions_from_tensor instead. @@ -94,8 +100,8 @@ def states_from_batch_shape( Args: batch_shape: Tuple representing the shape of the batch of states. - random (optional): Initalize states randomly. - sink (optional): States initialized with s_f (the sink state). + random (optional): Initialize states randomly. + sink (optional): States initialized with sf (the sink state). Returns: States: A batch of initial states. @@ -205,7 +211,7 @@ class DefaultEnvAction(Actions): # In some cases overwritten by the user to support specific use-cases. 
def reset( self, - batch_shape: Optional[Union[int, Tuple[int, ...]]] = None, + batch_shape: int | Tuple[int, ...], random: bool = False, sink: bool = False, seed: Optional[int] = None, @@ -220,28 +226,13 @@ def reset( if random and seed is not None: set_seed(seed, performance_mode=True) - if batch_shape is None: - batch_shape = (1,) if isinstance(batch_shape, int): batch_shape = (batch_shape,) return self.states_from_batch_shape( batch_shape=batch_shape, random=random, sink=sink ) - def validate_actions( - self, states: States, actions: Actions, backward: bool = False - ) -> bool: - """First, asserts that states and actions have the same batch_shape. - Then, uses `is_action_valid`. - Returns a boolean indicating whether states/actions pairs are valid.""" - assert states.batch_shape == actions.batch_shape - return self.is_action_valid(states, actions, backward) - - def _step( - self, - states: States, - actions: Actions, - ) -> States: + def _step(self, states: States, actions: Actions) -> States: """Core step function. Calls the user-defined self.step() function. Function that takes a batch of states and actions and returns a batch of next @@ -255,13 +246,17 @@ def _step( valid_actions = actions[valid_states_idx] valid_states = states[valid_states_idx] - if not self.validate_actions(valid_states, valid_actions): + if not self.is_action_valid(valid_states, valid_actions): raise NonValidActionsError( "Some actions are not valid in the given states. See `is_action_valid`." ) + # Set to the sink state when the action is exit. 
new_sink_states_idx = actions.is_exit - new_states.tensor[new_sink_states_idx] = self.sf + sf_tensor = self.States.make_sink_states_tensor( + (int(new_sink_states_idx.sum().item()),) + ) + new_states[new_sink_states_idx] = self.States(sf_tensor) new_sink_states_idx = ~valid_states_idx | new_sink_states_idx assert new_sink_states_idx.shape == states.batch_shape @@ -269,20 +264,17 @@ def _step( not_done_actions = actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - if not isinstance(new_not_done_states_tensor, torch.Tensor): + + if not isinstance(new_not_done_states_tensor, (torch.Tensor, GeometricBatch)): raise Exception( - "User implemented env.step function *must* return a torch.Tensor!" + "User implemented env.step function *must* return a torch.Tensor or " + "a GeometricBatch (for graph-based environments)." ) - new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor - + new_states[~new_sink_states_idx] = self.States(new_not_done_states_tensor) return new_states - def _backward_step( - self, - states: States, - actions: Actions, - ) -> States: + def _backward_step(self, states: States, actions: Actions) -> States: """Core backward_step function. Calls the user-defined self.backward_step fn. This function takes a batch of states and actions and returns a batch of next @@ -296,14 +288,14 @@ def _backward_step( valid_actions = actions[valid_states_idx] valid_states = states[valid_states_idx] - if not self.validate_actions(valid_states, valid_actions, backward=True): + if not self.is_action_valid(valid_states, valid_actions, backward=True): raise NonValidActionsError( "Some actions are not valid in the given states. See `is_action_valid`." ) # Calculate the backward step, and update only the states which are not Done. 
new_not_done_states_tensor = self.backward_step(valid_states, valid_actions) - new_states.tensor[valid_states_idx] = new_not_done_states_tensor + new_states[valid_states_idx] = self.States(new_not_done_states_tensor) return new_states @@ -355,6 +347,9 @@ class DiscreteEnv(Env, ABC): via mask tensors, that are directly attached to `States` objects. """ + s0: torch.Tensor # this tells the type checker that s0 is a torch.Tensor + sf: torch.Tensor # this tells the type checker that sf is a torch.Tensor + def __init__( self, n_actions: int, @@ -390,6 +385,8 @@ def __init__( if exit_action is None: exit_action = torch.tensor([n_actions - 1], device=device) + assert dummy_action is not None + assert exit_action is not None assert s0.shape == state_shape assert dummy_action.shape == action_shape assert exit_action.shape == action_shape @@ -432,32 +429,15 @@ def states_from_batch_shape( # In some cases overwritten by the user to support specific use-cases. def reset( self, - batch_shape: Optional[Union[int, Tuple[int, ...]]] = None, + batch_shape: int | Tuple[int, ...], random: bool = False, sink: bool = False, seed: Optional[int] = None, ) -> DiscreteStates: - """Instantiates a batch of initial states. - - `random` and `sink` cannot be both True. When `random` is `True` and `seed` is - not `None`, environment randomization is fixed by the submitted seed for - reproducibility. - """ - assert not (random and sink) - - if random and seed is not None: - torch.manual_seed(seed) # TODO: Improve seeding here? 
- - if batch_shape is None: - batch_shape = (1,) - if isinstance(batch_shape, int): - batch_shape = (batch_shape,) - states = self.states_from_batch_shape( - batch_shape=batch_shape, random=random, sink=sink - ) + """Instantiates a batch of initial DiscreteStates.""" + states = super().reset(batch_shape, random, sink, seed) states = cast(DiscreteStates, states) self.update_masks(states) - return states @abstractmethod @@ -481,12 +461,13 @@ class DiscreteEnvStates(DiscreteStates): return DiscreteEnvStates def make_actions_class(self) -> type[Actions]: + """Same functionality as the parent class, but with a different class name.""" env = self class DiscreteEnvActions(Actions): action_shape = env.action_shape - dummy_action = env.dummy_action.to(device=env.device) - exit_action = env.exit_action.to(device=env.device) + dummy_action = env.dummy_action + exit_action = env.exit_action return DiscreteEnvActions @@ -577,3 +558,72 @@ def terminating_states(self) -> DiscreteStates: raise NotImplementedError( "The environment does not support enumeration of states" ) + + +class GraphEnv(Env): + """Base class for graph-based environments.""" + + sf: GeometricData # this tells the type checker that sf is a GeometricData + + def __init__( + self, + s0: GeometricData, + sf: GeometricData, + device_str: Optional[str] = None, + preprocessor: Optional[Preprocessor] = None, + ): + """Initializes a graph-based environment. + + Args: + s0: The initial graph state. + sf: The sink graph state. + device_str: String representation of the device. + preprocessor: a Preprocessor object that converts raw graph states to a tensor + that can be fed into a neural network. Defaults to None, in which case + the IdentityPreprocessor is used. 
+ """ + device = get_device(device_str, default_device=s0.device) + assert s0.x is not None + + self.s0 = s0.to(device) # type: ignore + self.features_dim = s0.x.shape[-1] + self.sf = sf.to(device) # type: ignore + + self.States = self.make_states_class() + self.Actions = self.make_actions_class() + + self.preprocessor = preprocessor + + def make_states_class(self) -> type[GraphStates]: + env = self + + class GraphEnvStates(GraphStates): + s0 = env.s0 + sf = env.sf + make_random_states_graph = env.make_random_states_tensor + + return GraphEnvStates + + def make_actions_class(self) -> type[GraphActions]: + """The default Actions class factory for all Environments. + + Returns a class that inherits from Actions and implements assumed methods. + The make_actions_class method should be overwritten to achieve more + environment-specific Actions functionality. + """ + env = self + + class DefaultGraphAction(GraphActions): + features_dim = env.features_dim + + return DefaultGraphAction + + @abstractmethod + def step(self, states: GraphStates, actions: Actions) -> torch.Tensor: + """Function that takes a batch of graph states and actions and returns a batch of next + graph states.""" + + @abstractmethod + def backward_step(self, states: GraphStates, actions: Actions) -> torch.Tensor: + """Function that takes a batch of graph states and actions and returns a batch of previous + graph states.""" diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 90a21cb7..4f5cebe8 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -33,7 +33,7 @@ class FMGFlowNet(GFlowNet[StatePairs[DiscreteStates]]): def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): super().__init__() - assert isinstance( # TODO: need a more flexible type check. 
+ assert isinstance( logF, DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index c7a6a191..a83420fc 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -68,7 +68,7 @@ def loss( ) # If the conditioning values exist, we pass them to self.logZ - # (should be a ScalarEstimator or equivalant). + # (should be a ScalarEstimator or equivalent). if trajectories.conditioning is not None: with is_callable_exception_handler("logZ", self.logZ): assert isinstance(self.logZ, ScalarEstimator) diff --git a/src/gfn/gym/__init__.py b/src/gfn/gym/__init__.py index f19ea0ba..d684f1bb 100644 --- a/src/gfn/gym/__init__.py +++ b/src/gfn/gym/__init__.py @@ -1,5 +1,6 @@ from .box import Box from .discrete_ebm import DiscreteEBM +from .graph_building import GraphBuilding from .hypergrid import HyperGrid from .line import Line @@ -8,4 +9,5 @@ "DiscreteEBM", "HyperGrid", "Line", + "GraphBuilding", ] diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py new file mode 100644 index 00000000..8340fbec --- /dev/null +++ b/src/gfn/gym/graph_building.py @@ -0,0 +1,315 @@ +from typing import Callable, Literal, Optional, Tuple + +import torch +from torch_geometric.data import Batch as GeometricBatch +from torch_geometric.data import Data as GeometricData + +from gfn.actions import GraphActions, GraphActionType +from gfn.env import GraphEnv +from gfn.states import GraphStates + + +class GraphBuilding(GraphEnv): + """Environment for incrementally building graphs. 
+        Returns the next graph as a new GraphStates.
+ """ + if len(actions) == 0: + return states.tensor + + action_type = actions.action_type[0] + assert torch.all( + actions.action_type == action_type + ) # TODO: allow different action types + if action_type == GraphActionType.EXIT: + return self.States.make_sink_states_tensor(states.batch_shape) + + if action_type == GraphActionType.ADD_NODE: + batch_indices = torch.arange(len(states))[ + actions.action_type == GraphActionType.ADD_NODE + ] + states.tensor = self._add_node( + states.tensor, batch_indices, actions.features + ) + + if action_type == GraphActionType.ADD_EDGE: + # Get the data list from the batch + data_list = states.tensor.to_data_list() + + # Add edges to each graph + for i, (src, dst) in enumerate(actions.edge_index): + # Get the graph to modify + graph = data_list[i] + + # Add the new edge + graph.edge_index = torch.cat( + [ + graph.edge_index, + torch.tensor([[src], [dst]], device=graph.edge_index.device), + ], + dim=1, + ) + + # Add the edge feature + graph.edge_attr = torch.cat( + [graph.edge_attr, actions.features[i].unsqueeze(0)], dim=0 + ) + + # Create a new batch from the updated data list + new_tensor = GeometricBatch.from_data_list(data_list) + new_tensor.batch_shape = states.tensor.batch_shape + states.tensor = new_tensor + + return states.tensor + + def backward_step( + self, states: GraphStates, actions: GraphActions + ) -> GeometricBatch: + """Backward step function for the GraphBuilding environment. + + Args: + states: GraphStates object representing the current graph. + actions: Actions indicating which edge to remove. + + Returns the previous graph as a new GraphStates. 
+ """ + if len(actions) == 0: + return states.tensor + + action_type = actions.action_type[0] + assert torch.all(actions.action_type == action_type) + + # Get the data list from the batch + data_list = states.tensor.to_data_list() + + if action_type == GraphActionType.ADD_NODE: + # Remove nodes with matching features + for i, features in enumerate(actions.features): + graph = data_list[i] + assert isinstance(graph.num_nodes, int) + + # Find nodes with matching features + is_equal = torch.all(graph.x == features.unsqueeze(0), dim=1) + + if torch.any(is_equal): + # Remove the first matching node + node_idx = int(torch.where(is_equal)[0][0].item()) + + # Remove the node + mask = torch.ones( + graph.num_nodes, + dtype=torch.bool, + device=graph.x.device, + ) + mask[node_idx] = False + + # Update node features + graph.x = graph.x[mask] + + elif action_type == GraphActionType.ADD_EDGE: + # Remove edges with matching indices + for i, (src, dst) in enumerate(actions.edge_index): + graph = data_list[i] + + # Find the edge to remove + edge_mask = ~( + (graph.edge_index[0] == src) & (graph.edge_index[1] == dst) + ) + + # Remove the edge + graph.edge_index = graph.edge_index[:, edge_mask] + graph.edge_attr = graph.edge_attr[edge_mask] + + # Create a new batch from the updated data list + new_batch = GeometricBatch.from_data_list(data_list) + + # Preserve the batch shape + new_batch.batch_shape = states.batch_shape + + return new_batch + + def is_action_valid( + self, states: GraphStates, actions: GraphActions, backward: bool = False + ) -> bool: + """Check if actions are valid for the given states. + + Args: + states: Current graph states. + actions: Actions to validate. + backward: Whether this is a backward step. + + Returns: + True if all actions are valid, False otherwise. 
+ """ + # Get the data list from the batch + data_list = states.tensor.to_data_list() + + for i in range(len(actions)): + graph = data_list[i] + assert isinstance(graph.num_nodes, int) + + if actions.action_type[i] == GraphActionType.ADD_NODE: + # Check if a node with these features already exists + equal_nodes = torch.all( + graph.x == actions.features[i].unsqueeze(0), dim=1 + ) + + if backward: + # For backward actions, we need at least one matching node + if not torch.any(equal_nodes): + return False + else: + # For forward actions, we should not have any matching nodes + if torch.any(equal_nodes): + return False + + elif actions.action_type[i] == GraphActionType.ADD_EDGE: + src, dst = actions.edge_index[i] + + # Check if src and dst are valid node indices + if src >= graph.num_nodes or dst >= graph.num_nodes or src == dst: + return False + + # Check if the edge already exists + edge_exists = torch.any( + (graph.edge_index[0] == src) & (graph.edge_index[1] == dst) + ) + + if backward: + # For backward actions, the edge must exist + if not edge_exists: + return False + else: + # For forward actions, the edge must not exist + if edge_exists: + return False + + return True + + def _add_node( + self, + tensor: GeometricBatch, + batch_indices: torch.Tensor | list[int], + nodes_to_add: torch.Tensor, + ) -> GeometricBatch: + """Add nodes to graphs in a batch. + + Args: + tensor_dict: The current batch of graphs. + batch_indices: Indices of graphs to add nodes to. + nodes_to_add: Features of nodes to add. + + Returns: + Updated batch of graphs. 
+ """ + batch_indices = ( + torch.tensor(batch_indices) + if isinstance(batch_indices, list) + else batch_indices + ) + if len(batch_indices) != len(nodes_to_add): + raise ValueError( + "Number of batch indices must match number of node feature lists" + ) + + # Get the data list from the batch + data_list = tensor.to_data_list() + + # Add nodes to the specified graphs + for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): + # Get the graph to modify + graph = data_list[graph_idx] + + # Ensure new_nodes is 2D + new_nodes = torch.atleast_2d(new_nodes) + + # Check feature dimension + if new_nodes.shape[1] != graph.x.shape[1]: + raise ValueError(f"Node features must have dimension {graph.x.shape[1]}") + + # Add new nodes to the graph + graph.x = torch.cat([graph.x, new_nodes], dim=0) + + # Create a new batch from the updated data list + new_batch = GeometricBatch.from_data_list(data_list) + + # Preserve the batch shape + new_batch.batch_shape = tensor.batch_shape + return new_batch + + def reward(self, final_states: GraphStates) -> torch.Tensor: + """The environment's reward given a state. + This or log_reward must be implemented. + + Args: + final_states: A batch of final states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the rewards. + """ + return self.state_evaluator(final_states) + + def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: + """Generates random states tensor of shape (*batch_shape, feature_dim).""" + random_states_tensor = self.States.from_batch_shape(batch_shape) + assert isinstance(random_states_tensor, GraphStates) + return random_states_tensor diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index a4aa08a4..57adacb1 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -45,7 +45,8 @@ def __init__( delta: the radius of the quarter disk. northeastern: whether the quarter disk is northeastern or southwestern. 
centers: the centers of the distribution with shape (n_states, 2). - mixture_logits: Tensor of shape (n_states", n_components) containing the logits of the mixture of Beta distributions. + mixture_logits: Tensor of shape (n_states", n_components) containing the logits of + the mixture of Beta distributions. alpha: Tensor of shape (n_states", n_components) containing the alpha parameters of the Beta distributions. beta: Tensor of shape (n_states", n_components) containing the beta parameters of the Beta distributions. """ @@ -262,11 +263,16 @@ def __init__( Args: delta: the radius of the quarter disk. - mixture_logits: Tensor of shape (n_components,) containing the logits of the mixture of Beta distributions. - alpha_r: Tensor of shape (n_components,) containing the alpha parameters of the Beta distributions for the radius. - beta_r: Tensor of shape (n_components,) containing the beta parameters of the Beta distributions for the radius. - alpha_theta: Tensor of shape (n_components,) containing the alpha parameters of the Beta distributions for the angle. - beta_theta: Tensor of shape (n_components,) containing the beta parameters of the Beta distributions for the angle. + mixture_logits: Tensor of shape (n_components,) containing the logits of + the mixture of Beta distributions. + alpha_r: Tensor of shape (n_components,) containing the alpha parameters of + the Beta distributions for the radius. + beta_r: Tensor of shape (n_components,) containing the beta parameters of the + Beta distributions for the radius. + alpha_theta: Tensor of shape (n_components,) containing the alpha parameters of + the Beta distributions for the angle. + beta_theta: Tensor of shape (n_components,) containing the beta parameters of + the Beta distributions for the angle. """ self.delta = delta self.mixture_logits = mixture_logits @@ -374,7 +380,8 @@ def __init__( delta: the radius of the quarter disk. centers: the centers of the distribution with shape (n_states, 2). 
exit_probability: Tensor of shape (n_states,) containing the probability of exiting the quarter disk. - mixture_logits: Tensor of shape (n_states, n_components) containing the logits of the mixture of Beta distributions. + mixture_logits: Tensor of shape (n_states, n_components) containing the logits of the mixture of + Beta distributions. alpha: Tensor of shape (n_states, n_components) containing the alpha parameters of the Beta distributions. beta: Tensor of shape (n_states, n_components) containing the beta parameters of the Beta distributions. epsilon: the epsilon value to consider the state as being at the border of the square. @@ -578,9 +585,11 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: """Computes the forward pass of the neural network. Args: - preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute the forward pass of the neural network. + preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute + the forward pass of the neural network. - Returns the output of the neural network as a tensor of shape (*batch_shape, 1 + 5 * max_n_components). + Returns the output of the neural network as a tensor of shape (*batch_shape, + 1 + 5 * max_n_components). """ assert preprocessed_states.shape[-1] == 2 batch_shape = preprocessed_states.shape[:-1] @@ -635,8 +644,9 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: desired_out[~idx_s0] = desired_out_slice2 # Apply sigmoid to all except the dimensions between 1 and 1 + self._n_comp_max - # These are the components that represent the concentration parameters of the Betas, before normalizing, and should - # thus be between 0 and 1 (along with the exit probability) + # These are the components that represent the concentration parameters of the + # Betas, before normalizing, and should thus be between 0 and 1 (along with + # the exit probability). 
desired_out[..., 0] = torch.sigmoid(desired_out[..., 0]) desired_out[..., 1 + self._n_comp_max :] = torch.sigmoid( desired_out[..., 1 + self._n_comp_max :] @@ -688,9 +698,11 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: """Computes the forward pass of the neural network. Args: - preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute the forward pass of the neural network. + preprocessed_states: The tensor states of shape (*batch_shape, 2) to + compute the forward pass of the neural network. - Returns the output of the neural network as a tensor of shape (*batch_shape, 3 * n_components). + Returns the output of the neural network as a tensor of shape (*batch_shape, + 3 * n_components). """ assert preprocessed_states.shape[-1] == 2 batch_shape = preprocessed_states.shape[:-1] @@ -715,7 +727,8 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: """Computes the forward pass of the neural network. Args: - preprocessed_states: The tensor states of shape (*batch_shape, input_dim) to compute the forward pass of the neural network. + preprocessed_states: The tensor states of shape (*batch_shape, input_dim) to compute + the forward pass of the neural network. Returns the output of the neural network as a tensor of shape (*batch_shape, output_dim). """ @@ -729,8 +742,8 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: class BoxPBUniform(torch.nn.Module): """A module to be used to create a uniform PB distribution for the Box environment - A module that returns (1, 1, 1) for all states. Used with QuarterCircle, it leads to a - uniform distribution over parents in the south-western part of circle. + A module that returns (1, 1, 1) for all states. Used with QuarterCircle, it leads + to a uniform distribution over parents in the south-western part of circle. 
""" input_dim = 2 @@ -739,7 +752,8 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: """Computes the forward pass of the neural network. Args: - preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute the forward pass of the neural network. + preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute + the forward pass of the neural network. Returns a tensor of shape (*batch_shape, 3) filled by ones. """ @@ -753,7 +767,8 @@ def split_PF_module_output(output: torch.Tensor, n_comp_max: int): """Splits the module output into the expected parameter sets. Args: - output: the module_output from the P_F model as a tensor of shape (*batch_shape, output_dim). + output: the module_output from the P_F model as a tensor of shape + (*batch_shape, output_dim). n_comp_max: the larger number of the two n_components and n_components_s0. Returns: diff --git a/src/gfn/gym/helpers/preprocessors.py b/src/gfn/gym/helpers/preprocessors.py index f2ce11c0..347b2ffa 100644 --- a/src/gfn/gym/helpers/preprocessors.py +++ b/src/gfn/gym/helpers/preprocessors.py @@ -18,8 +18,9 @@ def __init__( Args: n_states (int): The total number of states in the environment (not including s_f). - get_states_indices (Callable[[States], BatchOutputTensor]): function that returns the unique indices of the states. - BatchOutputTensor is a tensor of shape (*batch_shape, input_dim). + get_states_indices (Callable[[States], BatchOutputTensor]): function that returns + the unique indices of the states. + BatchOutputTensor is a tensor of shape (*batch_shape, input_dim). """ super().__init__(output_dim=n_states) self.get_states_indices = get_states_indices diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py index 66e36eb0..59f5e74c 100644 --- a/src/gfn/gym/line.py +++ b/src/gfn/gym/line.py @@ -28,13 +28,9 @@ def __init__( self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)] self.init_value = init_value # Used in s0. 
+        partition = float(len(self.mus))
+        return torch.tensor(partition).log()
-> torch.Tensor: @@ -90,14 +95,7 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: """ if isinstance(input, States): input = self.preprocessor(input) - - out = self.module(input) - - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - - return out + return self.module(input) def __repr__(self): return f"{self.__class__.__name__} module" @@ -213,10 +211,7 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - # self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == 1 return out @@ -281,7 +276,7 @@ def to_probability_distribution( on policy. epsilon: with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy.""" - # self.check_output_dim(module_output) + assert module_output.shape[-1] == self.expected_output_dim masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output @@ -362,11 +357,7 @@ def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). 
""" out = self._forward_trunk(states, conditioning) - - if not self._output_dim_is_checked: - # self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == self.expected_output_dim return out @@ -450,10 +441,7 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - # self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == self.expected_output_dim return out @property @@ -467,3 +455,108 @@ def to_probability_distribution( **policy_kwargs: Any, ) -> Distribution: raise NotImplementedError + + +class GraphActionPolicyEstimator(GFNModule): + r"""Container for forward and backward policy estimators for graph environments. + + $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$. + + or + + $s \mapsto (P_B(s' \mid s))_{s' \in Parents(s)}$. + + Attributes: + temperature: scalar to divide the logits by before softmax. + sf_bias: scalar to subtract from the exit action logit before dividing by + temperature. + epsilon: with probability epsilon, a random action is chosen. + """ + + def __init__( + self, + module: nn.Module, + preprocessor: Preprocessor | None = None, + is_backward: bool = False, + ): + """Initializes a estimator for P_F for graph environments. + + Args: + is_backward: if False, then this is a forward policy, else backward policy. + """ + if preprocessor is None: + preprocessor = GraphPreprocessor() + super().__init__(module, preprocessor, is_backward) + + def expected_output_dim(self) -> int: + return 0 + + def forward(self, states: GraphStates) -> TensorDict: + """Forward pass of the module. + + Args: + states: The input graph states. 
+ + Returns: + TensorDict containing: + - action_type: logits for action type selection (batch_shape, n_actions) + - features: parameters for node/edge features (batch_shape, feature_dim) + - edge_index: logits for edge connections (batch_shape, n_nodes, n_nodes) + """ + return self.module(states) + + def to_probability_distribution( + self, + states: GraphStates, + module_output: TensorDict, + temperature: float = 1.0, + epsilon: float = 0.0, + ) -> CompositeDistribution: + """Returns a probability distribution given a batch of states and module output. + + We handle off-policyness using these kwargs. + + Args: + states: The states to use. + module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). + temperature: scalar to divide the logits by before softmax. Does nothing + if set to 1.0 (default), in which case it's on policy. + epsilon: with probability epsilon, a random action is chosen. Does nothing + if set to 0.0 (default), in which case it's on policy.""" + + raise NotImplementedError( + "This method is incompatible with pyg and will be fixed in a future PR." 
+ ) + dists = {} + + action_type_logits = module_output["action_type"] + masks = states.backward_masks if self.is_backward else states.forward_masks + action_type_logits[~masks["action_type"]] = -float("inf") + action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1) + uniform_dist_probs = masks["action_type"].float() / masks["action_type"].sum( + dim=-1, keepdim=True + ) + action_type_probs = ( + 1 - epsilon + ) * action_type_probs + epsilon * uniform_dist_probs + dists["action_type"] = CategoricalActionType(probs=action_type_probs) + + edge_index_logits = module_output["edge_index"] + edge_index_logits[~masks["edge_index"]] = -float("inf") + if torch.any(edge_index_logits != -float("inf")): + B, N, N = edge_index_logits.shape + edge_index_logits = edge_index_logits.reshape(B, N * N) + edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) + uniform_dist_probs = ( + torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] + ) + edge_index_probs = ( + 1 - epsilon + ) * edge_index_probs + epsilon * uniform_dist_probs + edge_index_probs[torch.isnan(edge_index_probs)] = 1 + dists["edge_index"] = CategoricalIndexes( + probs=edge_index_probs, n_nodes=states.tensor.num_nodes + ) + + dists["features"] = Normal(module_output["features"], temperature) + return CompositeDistribution(dists=dists) diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index 11abbf61..b80b2a57 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -2,8 +2,9 @@ from typing import Callable import torch +from torch_geometric.data import Batch as GeometricBatch -from gfn.states import DiscreteStates, States +from gfn.states import DiscreteStates, GraphStates, States class Preprocessor(ABC): @@ -57,8 +58,8 @@ def __init__( Each state is represented by a unique integer (>= 0) index. Args: - get_states_indices (Callable[[DiscreteStates], BatchOutputTensor]): function that returns the unique indices of the states. 
- BatchOutputTensor is a tensor of shape (*batch_shape, 1). + get_states_indices: function that returns the unique indices of the states. + torch.Tensor is a tensor of shape (*batch_shape, 1). """ super().__init__(output_dim=1) self.get_states_indices = get_states_indices @@ -72,3 +73,11 @@ def preprocess(self, states: DiscreteStates) -> torch.Tensor: Returns the unique indices of the states as a tensor of shape `batch_shape`. """ return self.get_states_indices(states).long().unsqueeze(-1) + + +class GraphPreprocessor(Preprocessor): + def __init__(self) -> None: + super().__init__(-1) # TODO: review output_dim API + + def preprocess(self, states: GraphStates) -> GeometricBatch: + return states.tensor diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 05587354..34beabac 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -147,7 +147,7 @@ def sample_trajectories( if conditioning is not None: assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] - device = states.tensor.device + device = states.device dones = ( states.is_initial_state @@ -205,13 +205,14 @@ def sample_trajectories( all_estimator_outputs.append(estimator_outputs_padded) actions[~dones] = valid_actions - trajectories_actions.append(actions) if save_logprobs: assert ( actions_log_probs is not None ), "actions_log_probs should not be None when save_logprobs is True" log_probs[~dones] = actions_log_probs - trajectories_logprobs.append(log_probs) + + trajectories_actions.append(actions) + trajectories_logprobs.append(log_probs) if self.estimator.is_backward: new_states = env._backward_step(states, actions) @@ -219,7 +220,7 @@ def sample_trajectories( new_states = env._step(states, actions) sink_states_mask = new_states.is_sink_state - # Increment the step, determine which trajectories are finisihed, and eval + # Increment the step, determine which trajectories are finished, and eval # rewards. 
+    without having to call the environment's `is_action_valid` method. Put differently,
_log_rewards: Stores the log rewards of each state. """ - state_shape: ClassVar[tuple[int, ...]] # Shape of one state - s0: ClassVar[torch.Tensor] # Source state of the DAG - sf: ClassVar[torch.Tensor] # Dummy state, used to pad a batch of states + state_shape: ClassVar[tuple[int, ...]] + s0: ClassVar[torch.Tensor | GeometricData] + sf: ClassVar[torch.Tensor | GeometricData] + make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw( NotImplementedError( "The environment does not support initialization of random states." @@ -63,12 +80,26 @@ def __init__(self, tensor: torch.Tensor): assert tensor.shape[-len(self.state_shape) :] == self.state_shape self.tensor = tensor - self.batch_shape = tuple(self.tensor.shape)[: -len(self.state_shape)] + self._batch_shape = tuple(self.tensor.shape)[: -len(self.state_shape)] + self._log_rewards = ( + None # Useful attribute if we want to store the log-reward of the states + ) + + @property + def batch_shape(self) -> tuple[int, ...]: + return self._batch_shape + + @batch_shape.setter + def batch_shape(self, batch_shape: tuple[int, ...]) -> None: + self._batch_shape = batch_shape @classmethod def from_batch_shape( - cls, batch_shape: tuple[int, ...], random: bool = False, sink: bool = False - ) -> States: + cls, + batch_shape: int | tuple[int, ...], + random: bool = False, + sink: bool = False, + ) -> States | GraphStates: """Create a States object with the given batch shape. By default, all states are initialized to $s_0$, the initial state. Optionally, @@ -85,6 +116,9 @@ def from_batch_shape( Raises: ValueError: If both Random and Sink are True. 
""" + if isinstance(batch_shape, int): + batch_shape = (batch_shape,) + if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") @@ -101,19 +135,29 @@ def make_initial_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tenso """Makes a tensor with a `batch_shape` of states consisting of $s_0`$s.""" state_ndim = len(cls.state_shape) assert cls.s0 is not None and state_ndim is not None - return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) + if isinstance(cls.s0, torch.Tensor): + return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) + else: + raise NotImplementedError( + f"make_initial_states_tensor is not implemented by default for {cls.__name__}" + ) @classmethod def make_sink_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tensor: """Makes a tensor with a `batch_shape` of states consisting of $s_f$s.""" state_ndim = len(cls.state_shape) assert cls.sf is not None and state_ndim is not None - return cls.sf.repeat(*batch_shape, *((1,) * state_ndim)) + if isinstance(cls.sf, torch.Tensor): + return cls.sf.repeat(*batch_shape, *((1,) * state_ndim)) + else: + raise NotImplementedError( + f"make_sink_states_tensor is not implemented by default for {cls.__name__}" + ) - def __len__(self): + def __len__(self) -> int: return prod(self.batch_shape) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__} object of batch shape {self.batch_shape} and state shape {self.state_shape}" @property @@ -130,7 +174,9 @@ def __getitem__( return out def __setitem__( - self, index: int | Sequence[int] | Sequence[bool], states: States + self, + index: int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor, + states: States, ) -> None: """Set particular states of the batch.""" self.tensor[index] = states.tensor @@ -195,7 +241,7 @@ def extend_with_sf(self, required_first_dim: int) -> None: Args: required_first_dim: The size of the first batch dimension post-expansion. 
""" - if len(self.batch_shape) == 2: + if len(self.batch_shape) == 2 and isinstance(self.__class__.sf, torch.Tensor): if self.batch_shape[0] >= required_first_dim: return self.tensor = torch.cat( @@ -210,10 +256,10 @@ def extend_with_sf(self, required_first_dim: int) -> None: self.batch_shape = (required_first_dim, self.batch_shape[1]) else: raise ValueError( - f"extend_with_sf is not implemented for batch shapes {self.batch_shape}" + f"extend_with_sf is not implemented for graph states nor for batch shapes {self.batch_shape}" ) - def compare(self, other: torch.Tensor) -> torch.Tensor: + def _compare(self, other: torch.Tensor) -> torch.Tensor: """Computes elementwise equality between state tensor with an external tensor. Args: @@ -234,31 +280,66 @@ def compare(self, other: torch.Tensor) -> torch.Tensor: @property def is_initial_state(self) -> torch.Tensor: """Returns a tensor of shape `batch_shape` that is True for states that are $s_0$ of the DAG.""" - source_states_tensor = self.__class__.s0.repeat( - *self.batch_shape, *((1,) * len(self.__class__.state_shape)) - ) - return self.compare(source_states_tensor) + if isinstance(self.__class__.s0, torch.Tensor): + source_states_tensor = self.__class__.s0.repeat( + *self.batch_shape, *((1,) * len(self.__class__.state_shape)) + ) + else: + raise NotImplementedError( + f"is_initial_state is not implemented by default for {self.__class__.__name__}" + ) + return self._compare(source_states_tensor) @property def is_sink_state(self) -> torch.Tensor: """Returns a tensor of shape `batch_shape` that is True for states that are $s_f$ of the DAG.""" # TODO: self.__class__.sf == self.tensor -- or something similar? 
- sink_states = self.__class__.sf.repeat( - *self.batch_shape, *((1,) * len(self.__class__.state_shape)) - ).to(self.tensor.device) - return self.compare(sink_states) + if isinstance(self.__class__.sf, torch.Tensor): + sink_states = self.__class__.sf.repeat( + *self.batch_shape, *((1,) * len(self.__class__.state_shape)) + ).to(self.tensor.device) + else: + raise NotImplementedError( + f"is_sink_state is not implemented by default for {self.__class__.__name__}" + ) + return self._compare(sink_states) + + @property + def log_rewards(self) -> torch.Tensor | None: + """Returns the log rewards of the states as tensor of shape `batch_shape`.""" + return self._log_rewards + + @log_rewards.setter + def log_rewards(self, log_rewards: torch.Tensor) -> None: + """Sets the log rewards of the states. + + Args: + log_rewards: Tensor of shape `batch_shape` representing the log rewards of the states. + """ + assert tuple(log_rewards.shape) == self.batch_shape + self._log_rewards = log_rewards def sample(self, n_samples: int) -> States: """Samples a subset of the States object.""" return self[torch.randperm(len(self))[:n_samples]] @classmethod - def stack_states(cls, states: Sequence[States]): + def stack(cls, states: Sequence[States]) -> States: """Given a list of states, stacks them along a new dimension (0).""" - state_example = states[0] # We assume all elems of `states` are the same. + state_example = states[0] + assert all( + state.batch_shape == state_example.batch_shape for state in states + ), "All states must have the same batch_shape" stacked_states = state_example.from_batch_shape((0, 0)) # Empty. stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) + if state_example._log_rewards: + log_rewards = [] + for s in states: + if s._log_rewards is None: + raise ValueError("Some states have no log rewards.") + log_rewards.append(s._log_rewards) + stacked_states._log_rewards = torch.stack(log_rewards, dim=0) # Adds the trajectory dimension. 
stacked_states.batch_shape = ( @@ -273,7 +354,7 @@ class DiscreteStates(States, ABC): States are endowed with a `forward_masks` and `backward_masks`: boolean attributes representing which actions are allowed at each state. This is the mechanism by - which all elements of the library (including an environment's `validate_actions` + which all elements of the library (including an environment's `is_action_valid` method) verifies the allowed actions at each state. Attributes: @@ -321,10 +402,10 @@ def __init__( assert self.forward_masks.shape == (*self.batch_shape, self.n_actions) assert self.backward_masks.shape == (*self.batch_shape, self.n_actions - 1) - def clone(self) -> States: + def clone(self) -> DiscreteStates: """Returns a clone of the current instance.""" return self.__class__( - self.tensor.detach().clone(), + self.tensor.detach().clone(), # TODO: Are States carrying gradients? self.forward_masks, self.backward_masks, ) @@ -450,12 +531,631 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.zeros(shape).bool() @classmethod - def stack_states(cls, states: Sequence[DiscreteStates]): - stacked_states = cast(DiscreteStates, super().stack_states(states)) - stacked_states.forward_masks = torch.stack( - [s.forward_masks for s in states], dim=0 + def stack(cls, states: List[DiscreteStates]) -> DiscreteStates: + """Stacks a list of DiscreteStates objects along a new dimension (0).""" + out = super().stack(states) + assert isinstance(out, DiscreteStates) + out.forward_masks = torch.stack([s.forward_masks for s in states], dim=0) + out.backward_masks = torch.stack([s.backward_masks for s in states], dim=0) + return out + + +class GraphStates(States): + """ + Base class for Graph as a state representation. The `GraphStates` object is a batched + collection of multiple graph objects. The `GeometricBatch` object is used to + represent the batch of graph objects as states. 
+ """ + + s0: ClassVar[GeometricData] + sf: ClassVar[GeometricData] + + def __init__(self, tensor: GeometricBatch): + """Initialize the GraphStates with a PyG Batch object. + + Args: + tensor: A PyG Batch object representing a batch of graphs. + """ + self.tensor = tensor + if not hasattr(self.tensor, "batch_shape"): + self.tensor.batch_shape = self.tensor.batch_size + + if tensor.x.size(0) > 0: + assert tensor.num_graphs == prod(tensor.batch_shape) + + self._log_rewards: Optional[torch.Tensor] = None + + @property + def batch_shape(self) -> tuple[int, ...]: + """Returns the batch shape as a tuple.""" + return self.tensor.batch_shape + + @classmethod + def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> GeometricBatch: + """Makes a batch of graphs consisting of s0 states. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns: + A PyG Batch object containing copies of the initial state. + """ + assert cls.s0.edge_attr is not None + assert cls.s0.x is not None + + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) + num_graphs = int(np.prod(batch_shape)) + + # Create a list of Data objects by copying s0 + data_list = [cls.s0.clone() for _ in range(num_graphs)] + + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [ + GeometricData( + x=torch.zeros(0, cls.s0.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1)), + ) + ] + + # Create a batch from the list + batch = GeometricBatch.from_data_list(cast(List[BaseData], data_list)) + + # Store the batch shape for later reference + batch.batch_shape = tuple(batch_shape) + + return batch + + @classmethod + def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> GeometricBatch: + """Makes a batch of graphs consisting of sf states. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns: + A PyG Batch object containing copies of the sink state. 
+ """ + assert cls.sf.edge_attr is not None + assert cls.sf.x is not None + if cls.sf is None: + raise NotImplementedError("Sink state is not defined") + + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) + num_graphs = int(np.prod(batch_shape)) + + # Create a list of Data objects by copying sf + data_list = [cls.sf.clone() for _ in range(num_graphs)] + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [ + GeometricData( + x=torch.zeros(0, cls.sf.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.sf.edge_attr.size(1)), + ) + ] + + # Create a batch from the list + batch = GeometricBatch.from_data_list(cast(List[BaseData], data_list)) + + # Store the batch shape for later reference + batch.batch_shape = batch_shape + + return batch + + @classmethod + def make_random_states_tensor(cls, batch_shape: int | Tuple) -> GeometricBatch: + """Makes a batch of random graph states. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns: + A PyG Batch object containing random graph states. 
+ """ + assert cls.s0.edge_attr is not None + assert cls.s0.x is not None + + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) + num_graphs = int(np.prod(batch_shape)) + device = cls.s0.x.device + + data_list = [] + for _ in range(num_graphs): + # Create a random graph with random number of nodes + num_nodes = np.random.randint(1, 10) + + # Create random node features + x = torch.rand(num_nodes, cls.s0.x.size(1), device=device) + + # Create random edges (not all possible edges to keep it sparse) + num_edges = np.random.randint(0, num_nodes * (num_nodes - 1) // 2 + 1) + if num_edges > 0 and num_nodes > 1: + # Generate random source and target nodes + edge_index = torch.zeros(2, num_edges, dtype=torch.long, device=device) + for i in range(num_edges): + src, dst = np.random.choice(num_nodes, 2, replace=False) + edge_index[0, i] = src + edge_index[1, i] = dst + + # Create random edge features + edge_attr = torch.rand( + num_edges, cls.s0.edge_attr.size(1), device=device + ) + + data = GeometricData(x=x, edge_index=edge_index, edge_attr=edge_attr) + else: + # No edges + data = GeometricData( + x=x, + edge_index=torch.zeros(2, 0, dtype=torch.long, device=device), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1), device=device), + ) + + data_list.append(data) + + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [ + GeometricData( + x=torch.zeros(0, cls.s0.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1)), + ) + ] + + # Create a batch from the list + batch = GeometricBatch.from_data_list(cast(List[BaseData], data_list)) + + # Store the batch shape for later reference + batch.batch_shape = batch_shape + + return batch + + def __repr__(self): + """Returns a string representation of the GraphStates object.""" + return ( + f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"node feature dim {self.tensor.x.size(1)} 
and edge feature dim {self.tensor.edge_attr.size(1)}" + ) + + def __getitem__( + self, + index: Union[int, Sequence[int], slice, torch.Tensor, Literal[1], Tuple], + ) -> GraphStates: + """Get a subset of the GraphStates. + + Args: + index: Index or indices to select. + + Returns: + A new GraphStates object containing the selected graphs. + """ + assert ( + self.batch_shape != () + ), "We can't index on a Batch with 0-dimensional batch shape." + + # Convert the index to a list of indices. + tensor_idx = torch.arange(len(self)).view(*self.batch_shape)[index] + new_shape = tuple(tensor_idx.shape) + flat_idx = tensor_idx.flatten() + + # Get the selected graphs from the batch + selected_graphs = self.tensor.index_select(flat_idx) + if len(selected_graphs) == 0: + assert np.prod(new_shape) == 0 and len(new_shape) > 0 + selected_graphs = [ # TODO: Is this the best way to create an empty Batch? + GeometricData( + x=torch.zeros(*new_shape, self.tensor.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(*new_shape, self.tensor.edge_attr.size(1)), + ) + ] + + # Create a new batch from the selected graphs. + # TODO: is there any downside to always using GeometricBatch even when the batch dimension is empty. + new_batch = GeometricBatch.from_data_list(cast(List[BaseData], selected_graphs)) + new_batch.batch_shape = new_shape + + # Create a new GraphStates object + out = self.__class__(new_batch) + + # Copy log rewards if they exist + if self._log_rewards is not None: + out.log_rewards = self._log_rewards[index] + + return out + + def __setitem__( + self, + index: Union[int, Sequence[int], slice, torch.Tensor, Literal[1], Tuple], + graph: GraphStates, + ) -> None: + """Set a subset of the GraphStates. + + Args: + index: Index or indices to set. + graph: GraphStates object containing the new graphs. 
+ """ + # Convert the index to a list of indices + batch_shape = self.batch_shape + if isinstance(index, int) and len(batch_shape) == 1: + indices = [index] + else: + tensor_idx = torch.arange(len(self)).view(*batch_shape) + indices = ( + tensor_idx[index].flatten().tolist() + ) # TODO: is .flatten() necessary? + + assert len(indices) == len(graph) + + # Get the data list from the current batch + data_list = self.tensor.to_data_list() + + # Get the data list from the new graphs + new_data_list = graph.tensor.to_data_list() + + # Replace the selected graphs + for i, idx in enumerate(indices): + data_list[idx] = new_data_list[i] + + # Create a new batch from the updated data list + self.tensor = GeometricBatch.from_data_list(data_list) + + # Preserve the batch shape + self.tensor.batch_shape = batch_shape + + @property + def device(self) -> torch.device: + """Returns the device of the tensor.""" + return self.tensor.x.device + + def to(self, device: torch.device) -> GraphStates: + """Moves the GraphStates to the specified device. + + Args: + device: The device to move to. + + Returns: + The GraphStates object on the specified device. + """ + self.tensor = self.tensor.to(device) + if self._log_rewards is not None: + self._log_rewards = self._log_rewards.to(device) + return self + + @staticmethod + def _clone_batch(batch: GeometricBatch) -> GeometricBatch: + """Clones a PyG Batch object. + + Args: + batch: The Batch object to clone. + + Returns: + A new Batch object with the same data. + """ + new_batch = batch.clone() + # The Batch.clone() changes the type of the batch shape to a list + # We need to set it back to a tuple + new_batch.batch_shape = batch.batch_shape + return new_batch + + def clone(self) -> GraphStates: + """Returns a detached clone of the current instance. + + Returns: + A new GraphStates object with the same data. 
+ """ + # Create a deep copy of the batch + new_batch = self._clone_batch(self.tensor) + + # Create a new GraphStates object + out = self.__class__(new_batch) + + # Copy log rewards if they exist + if self._log_rewards is not None: + out._log_rewards = self._log_rewards.clone() + + return out + + def extend(self, other: GraphStates): + """Concatenates to another GraphStates object along the batch dimension. + + Args: + other: GraphStates object to concatenate with. + """ + if len(self) == 0: + # If self is empty, just copy other + self.tensor = self._clone_batch(other.tensor) + if other._log_rewards is not None: + self._log_rewards = other._log_rewards.clone() + return + + # Get the data lists + self_data_list = self.tensor.to_data_list() + other_data_list = other.tensor.to_data_list() + + # Update the batch shape + if len(self.batch_shape) == 1: + # Create a new batch + new_batch_shape = (self.batch_shape[0] + other.batch_shape[0],) + self.tensor = GeometricBatch.from_data_list(self_data_list + other_data_list) + self.tensor.batch_shape = new_batch_shape + else: + # Handle the case where batch_shape is (T, B) + # and we want to concatenate along the B dimension + assert len(self.batch_shape) == 2 + max_len = max(self.batch_shape[0], other.batch_shape[0]) + + # We need to extend both batches to the same length T + if self.batch_shape[0] < max_len: + self_extension = self.make_sink_states_tensor( + (max_len - self.batch_shape[0], self.batch_shape[1]) + ) + self_data_list = self_data_list + self_extension.to_data_list() + + if other.batch_shape[0] < max_len: + other_extension = other.make_sink_states_tensor( + (max_len - other.batch_shape[0], other.batch_shape[1]) + ) + other_data_list = other_data_list + other_extension.to_data_list() + + # Now both have the same length T, we can concatenate along B + batch_shape = (max_len, self.batch_shape[1] + other.batch_shape[1]) + self.tensor = GeometricBatch.from_data_list(self_data_list + other_data_list) + 
self.tensor.batch_shape = batch_shape + + # Combine log rewards if they exist + if self._log_rewards is not None and other._log_rewards is not None: + self.log_rewards = torch.cat([self._log_rewards, other._log_rewards], dim=0) + elif other._log_rewards is not None: + self.log_rewards = other._log_rewards.clone() + + def _compare(self, other: GeometricData) -> torch.Tensor: + """Compares the current batch of graphs with another graph. + + Args: + other: A PyG Data object to compare with. + + Returns: + A boolean tensor indicating which graphs in the batch are equal to other. + """ + out = torch.zeros(len(self), dtype=torch.bool, device=self.device) + + # Get the data list from the batch + data_list = self.tensor.to_data_list() + + assert other.edge_index is not None # TODO: is allowing None here a good idea? + assert other.edge_attr is not None # + assert other.num_nodes is not None # + + for i, data in enumerate(data_list): + # Check if the number of nodes is the same + if data.num_nodes != other.num_nodes: + continue + + # Check if node features are the same + if not torch.all(data.x == other.x): + continue + + # Check if the number of edges is the same + if data.edge_index.size(1) != other.edge_index.size(1): + continue + + # Check if edge indices are the same (this is more complex due to potential reordering) + # We'll use a simple heuristic: sort edges and compare + data_edges = data.edge_index.t().tolist() + other_edges = other.edge_index.t().tolist() + data_edges.sort() + other_edges.sort() + if data_edges != other_edges: + continue + + # Check if edge attributes are the same (after sorting) + data_edge_attr = data.edge_attr[ + torch.argsort(data.edge_index[0] * data.num_nodes + data.edge_index[1]) + ] + other_edge_attr = other.edge_attr[ + torch.argsort( + other.edge_index[0] * other.num_nodes + other.edge_index[1] + ) + ] + if not torch.all(data_edge_attr == other_edge_attr): + continue + + # If all checks pass, the graphs are equal + out[i] = True + + 
return out.view(self.batch_shape) + + @property + def is_sink_state(self) -> torch.Tensor: + """Returns a tensor that is True for states that are sf.""" + return self._compare(self.sf) + + @property + def is_initial_state(self) -> torch.Tensor: + """Returns a tensor that is True for states that are s0.""" + return self._compare(self.s0) + + @classmethod + def stack(cls, states: List[GraphStates]) -> GraphStates: + """Given a list of states, stacks them along a new dimension (0). + + Args: + states: List of GraphStates objects to stack. + + Returns: + A new GraphStates object with the stacked states. + """ + # Check that all states have the same batch shape + state_batch_shape = states[0].batch_shape + assert all(state.batch_shape == state_batch_shape for state in states) + + # Get all data lists + all_data_lists = [state.tensor.to_data_list() for state in states] + + # Flatten the list of lists + flat_data_list = [data for data_list in all_data_lists for data in data_list] + + # Create a new batch + batch = GeometricBatch.from_data_list(flat_data_list) + + # Set the batch shape + batch.batch_shape = (len(states),) + state_batch_shape + + # Create a new GraphStates object + out = cls(batch) + + # Stack log rewards if they exist + if all(state._log_rewards is not None for state in states): + log_rewards = [] + for state in states: + log_rewards.append(state._log_rewards) + out.log_rewards = torch.stack(log_rewards) + + return out + + @property + def forward_masks(self) -> dict: + """Returns masks denoting allowed forward actions. + + Returns: + A dictionary containing masks for different action types. 
+ """ + # Get the data list from the batch + data_list = self.tensor.to_data_list() + N = self.tensor.x.size(0) + + # Initialize masks + action_type_mask = torch.ones( + self.batch_shape + (3,), dtype=torch.bool, device=self.device ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 + features_mask = torch.ones( + self.batch_shape + (self.tensor.x.size(1),), + dtype=torch.bool, + device=self.device, ) - return stacked_states + edge_index_masks = torch.ones( + (len(data_list), N, N), dtype=torch.bool, device=self.device + ) + + # For each graph in the batch + for i, data in enumerate(data_list): + # Flatten the batch index + flat_idx = i + + # ADD_NODE is always allowed + action_type_mask[flat_idx, GraphActionType.ADD_NODE] = True + + # ADD_EDGE is allowed only if there are at least 2 nodes + assert data.num_nodes is not None + action_type_mask[flat_idx, GraphActionType.ADD_EDGE] = data.num_nodes > 1 + + # EXIT is always allowed + action_type_mask[flat_idx, GraphActionType.EXIT] = True + + # Create edge_index mask as a dense representation (NxN matrix) + start_n = 0 + for i, data in enumerate(data_list): + # For each graph, create a dense mask for potential edges + n = data.num_nodes + assert n is not None + edge_mask = torch.ones((n, n), dtype=torch.bool, device=self.device) + # Remove self-loops by setting diagonal to False + edge_mask.fill_diagonal_(False) + + # Exclude existing edges + if data.edge_index.size(1) > 0: + for j in range(data.edge_index.size(1)): + src, dst = data.edge_index[0, j], data.edge_index[1, j] + edge_mask[src, dst] = False + + edge_index_masks[i, start_n : (start_n + n), start_n : (start_n + n)] = ( + edge_mask + ) + start_n += n + + # Update ADD_EDGE mask based on whether there are valid edges to add + action_type_mask[flat_idx, GraphActionType.ADD_EDGE] &= edge_mask.any() + + return { + "action_type": action_type_mask, + "features": features_mask, + "edge_index": edge_index_masks, + } + + 
@property + def backward_masks(self) -> dict: + """Returns masks denoting allowed backward actions. + + Returns: + A dictionary containing masks for different action types. + """ + # Get the data list from the batch + data_list = self.tensor.to_data_list() + N = self.tensor.x.size(0) + + # Initialize masks + action_type_mask = torch.ones( + self.batch_shape + (3,), dtype=torch.bool, device=self.device + ) + features_mask = torch.ones( + self.batch_shape + (self.tensor.x.size(1),), + dtype=torch.bool, + device=self.device, + ) + edge_index_masks = torch.zeros( + (len(data_list), N, N), dtype=torch.bool, device=self.device + ) + + # For each graph in the batch + for i, data in enumerate(data_list): + assert data.num_nodes is not None + + # Flatten the batch index + flat_idx = i + + # ADD_NODE is allowed if there's at least one node (can remove a node) + action_type_mask[flat_idx, GraphActionType.ADD_NODE] = data.num_nodes >= 1 + + # ADD_EDGE is allowed if there's at least one edge (can remove an edge) + action_type_mask[flat_idx, GraphActionType.ADD_EDGE] = ( + data.edge_index.size(1) > 0 + ) + + # EXIT is allowed if there's at least one node + action_type_mask[flat_idx, GraphActionType.EXIT] = data.num_nodes >= 1 + + # Create edge_index mask for backward actions (existing edges that can be removed) + start_n = 0 + for i, data in enumerate(data_list): + # For backward actions, we can only remove existing edges + n = data.num_nodes + assert n is not None + edge_mask = torch.zeros((n, n), dtype=torch.bool, device=self.device) + + # Include only existing edges + if data.edge_index.size(1) > 0: + for j in range(data.edge_index.size(1)): + src, dst = ( + data.edge_index[0, j].item(), + data.edge_index[1, j].item(), + ) + edge_mask[src, dst] = True + + edge_index_masks[i, start_n : (start_n + n), start_n : (start_n + n)] = ( + edge_mask + ) + start_n += n + + return { + "action_type": action_type_mask, + "features": features_mask, + "edge_index": edge_index_masks, + } + + 
def flatten(self) -> None: + raise NotImplementedError + + def extend_with_sf(self, required_first_dim: int) -> None: + raise NotImplementedError diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index f4948d0d..90131c89 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -1,5 +1,7 @@ +from typing import Dict, Literal + import torch -from torch.distributions import Categorical +from torch.distributions import Categorical, Distribution class UnsqueezedCategorical(Categorical): @@ -39,3 +41,73 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor: """ assert sample.shape[-1] == 1 return super().log_prob(sample.squeeze(-1)) + + +class CompositeDistribution( + Distribution +): # TODO: may use CompositeDistribution in TensorDict + """A mixture distribution.""" + + def __init__(self, dists: Dict[str, Distribution]): + """Initializes the mixture distribution. + + Args: + dists: A dictionary of distributions. + """ + super().__init__() + self.dists = dists + + def sample(self, sample_shape=torch.Size()) -> Dict[str, torch.Tensor]: + return {k: v.sample(sample_shape) for k, v in self.dists.items()} + + def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor | Literal[0]: + log_probs = [ + v.log_prob(sample[k]).reshape(sample[k].shape[0], -1).sum(dim=-1) + for k, v in self.dists.items() + ] + # Note: this returns the sum of the log_probs over all the components + # as it is a uniform mixture distribution. + return sum(log_probs) + + +class CategoricalIndexes(Categorical): + """Samples indexes from a categorical distribution.""" + + def __init__(self, probs: torch.Tensor, n_nodes: int): + """Initializes the distribution. + + Args: + probs: The probabilities of the categorical distribution. + n: The number of nodes in the graph. 
+ """ + assert probs.shape == (probs.shape[0], n_nodes * n_nodes) + self.n_nodes = n_nodes + super().__init__(probs) + + def sample(self, sample_shape=torch.Size()) -> torch.Tensor: + samples = super().sample(sample_shape) + out = torch.stack( + [ + samples // self.n_nodes, + samples % self.n_nodes, + ], + dim=-1, + ) + return out + + def log_prob(self, value): + value = value[..., 0] * self.n_nodes + value[..., 1] + return super().log_prob(value) + + +class CategoricalActionType(Categorical): # TODO: remove, just to sample 1 action_type + def __init__(self, probs: torch.Tensor): + self.batch_len = len(probs) + super().__init__(probs[0]) + + def sample(self, sample_shape=torch.Size()) -> torch.Tensor: + samples = super().sample(sample_shape) + return samples.repeat(self.batch_len) + + def log_prob(self, value): + return super().log_prob(value[0]).repeat(self.batch_len) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 86a474ac..1dc8a1f0 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -13,10 +13,11 @@ def __init__( self, input_dim: int, output_dim: int, - hidden_dim: Optional[int] = 256, + hidden_dim: int = 256, n_hidden_layers: Optional[int] = 2, activation_fn: Optional[Literal["relu", "tanh", "elu"]] = "relu", trunk: Optional[nn.Module] = None, + add_layer_norm: bool = False, ): """Instantiates a MLP instance. @@ -28,6 +29,8 @@ def __init__( activation_fn: Activation function. trunk: If provided, this module will be used as the trunk of the network (i.e. all layers except last layer). + add_layer_norm: If True, add a LayerNorm after each linear layer. 
+ (incompatible with `trunk` argument) """ super().__init__() self._input_dim = input_dim @@ -45,9 +48,18 @@ def __init__( activation = nn.ReLU elif activation_fn == "tanh": activation = nn.Tanh - arch = [nn.Linear(input_dim, hidden_dim), activation()] + if add_layer_norm: + arch = [ + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + activation(), + ] + else: + arch = [nn.Linear(input_dim, hidden_dim), activation()] for _ in range(n_hidden_layers - 1): arch.append(nn.Linear(hidden_dim, hidden_dim)) + if add_layer_norm: + arch.append(nn.LayerNorm(hidden_dim)) arch.append(activation()) self.trunk = nn.Sequential(*arch) self.trunk.hidden_dim = torch.tensor(hidden_dim) diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index a881194a..74dffdb2 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -26,9 +26,9 @@ def check_cond_forward( return module(states) -######################### -##### Trajectories ##### -######################### +# ------------ +# Trajectories +# ------------ def get_trajectory_pfs_and_pbs( @@ -170,9 +170,9 @@ def get_trajectory_pbs( return log_pb_trajectories -######################## -##### Transitions ##### -######################## +# ----------- +# Transitions +# ----------- def get_transition_pfs_and_pbs( diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py index 5102a63b..10dada67 100644 --- a/src/gfn/utils/training.py +++ b/src/gfn/utils/training.py @@ -129,7 +129,7 @@ def states_actions_tns_to_traj( # stack is a class method, so actions[0] is just to access a class instance and is not particularly relevant actions = actions[0].stack(actions) log_rewards = env.log_reward(states[-2]) - states = states[0].stack_states(states) + states = states[0].stack(states) when_is_done = torch.tensor([len(states_tns) - 1]) log_probs = None @@ -167,7 +167,8 @@ def warm_up( env: The environment instance n_epochs: Number of epochs for warmup batch_size: 
Number of trajectories to sample from replay buffer - recalculate_all_logprobs: For PFBasedGFlowNets only, force recalculating all log probs. Useful trajectories do not already have log probs. + recalculate_all_logprobs: For PFBasedGFlowNets only, force recalculating all log probs. + Useful trajectories do not already have log probs. Returns: GFlowNet: A trained GFlowNet """ diff --git a/testing/test_actions.py b/testing/test_actions.py new file mode 100644 index 00000000..807dc0a5 --- /dev/null +++ b/testing/test_actions.py @@ -0,0 +1,142 @@ +from copy import deepcopy + +import pytest +import torch +from tensordict import TensorDict + +from gfn.actions import Actions, GraphActions + + +class ContinuousActions(Actions): + action_shape = (10,) + dummy_action = torch.zeros(10) + exit_action = torch.ones(10) + + +class TestGraphActions(GraphActions): + features_dim = 10 + + +@pytest.fixture +def continuous_action(): + return ContinuousActions(tensor=torch.arange(0, 10)) + + +@pytest.fixture +def graph_action(): + return TestGraphActions( + tensor=TensorDict( + { + "action_type": torch.zeros((1,), dtype=torch.float32), + "features": torch.zeros((1, 10), dtype=torch.float32), + }, + device="cpu", + ) + ) + + +def test_continuous_action(continuous_action): + BATCH = 5 + + exit_actions = continuous_action.make_exit_actions((BATCH,)) + assert torch.all( + exit_actions.tensor == continuous_action.exit_action.repeat(BATCH, 1) + ) + assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool)) + + dummy_actions = continuous_action.make_dummy_actions((BATCH,)) + assert torch.all( + dummy_actions.tensor == continuous_action.dummy_action.repeat(BATCH, 1) + ) + assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool)) + + # Test stack + stacked_actions = 
continuous_action.stack([exit_actions, dummy_actions]) + assert stacked_actions.batch_shape == (2, BATCH) + assert torch.all( + stacked_actions.tensor + == torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0) + ) + is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + assert stacked_actions[0, 1].is_exit + stacked_actions[0, 1] = stacked_actions[1, 1] + is_exit_stacked[0, 1] = False + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + + # Test extend + extended_actions = deepcopy(exit_actions) + extended_actions.extend(dummy_actions) + assert extended_actions.batch_shape == (BATCH * 2,) + assert torch.all( + extended_actions.tensor + == torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0) + ) + is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(extended_actions.is_exit == is_exit_extended) + assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy + extended_actions[0] = extended_actions[BATCH] + is_exit_extended[0] = False + assert torch.all(extended_actions.is_exit == is_exit_extended) + + +def test_graph_action(graph_action): + BATCH = 5 + + exit_actions = graph_action.make_exit_actions((BATCH,)) + assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool)) + dummy_actions = graph_action.make_dummy_actions((BATCH,)) + assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool)) + + # Test stack + stacked_actions = graph_action.stack([exit_actions, dummy_actions]) + assert stacked_actions.batch_shape == (2, BATCH) + manually_stacked_tensor = torch.stack( + [exit_actions.tensor, dummy_actions.tensor], dim=0 + ) + assert torch.all( + stacked_actions.tensor["action_type"] + == 
manually_stacked_tensor.get("action_type") + ) + assert torch.all( + stacked_actions.tensor["features"] == manually_stacked_tensor.get("features") + ) + assert torch.all( + stacked_actions.tensor["edge_index"] == manually_stacked_tensor.get("edge_index") + ) + is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + assert stacked_actions[0, 1].is_exit + stacked_actions[0, 1] = stacked_actions[1, 1] + is_exit_stacked[0, 1] = False + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + + # Test extend + extended_actions = deepcopy(exit_actions) + extended_actions.extend(dummy_actions) + assert extended_actions.batch_shape == (BATCH * 2,) + manually_extended_tensor = torch.cat( + [exit_actions.tensor, dummy_actions.tensor], dim=0 + ) + assert torch.all( + extended_actions.tensor["action_type"] + == manually_extended_tensor.get("action_type") + ) + assert torch.all( + extended_actions.tensor["features"] == manually_extended_tensor.get("features") + ) + + assert torch.all( + extended_actions.tensor["edge_index"] + == manually_extended_tensor.get("edge_index") + ) + is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(extended_actions.is_exit == is_exit_extended) + assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy + extended_actions[0] = extended_actions[BATCH] + is_exit_extended[0] = False + assert torch.all(extended_actions.is_exit == is_exit_extended) diff --git a/testing/test_environments.py b/testing/test_environments.py index 97c263df..b3f1a21b 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -3,9 +3,13 @@ import numpy as np import pytest import torch +from tensordict import TensorDict +from gfn.actions import GraphActionType from gfn.env import NonValidActionsError from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.graph_building import GraphBuilding 
+from gfn.states import GraphStates # Utilities. @@ -318,3 +322,192 @@ def test_get_grid(): # State indices of the grid are ordered from 0:HEIGHT**2. assert (env.get_states_indices(grid).ravel() == torch.arange(HEIGHT**2)).all() + + +def test_graph_env(): + FEATURE_DIM = 8 + BATCH_SIZE = 3 + NUM_NODES = 5 + + env = GraphBuilding( + feature_dim=FEATURE_DIM, state_evaluator=lambda s: torch.zeros(s.batch_shape) + ) + states = env.reset(batch_shape=BATCH_SIZE) + assert states.batch_shape == (BATCH_SIZE,) + action_cls = env.make_actions_class() + + # We can't add an edge without nodes. + with pytest.raises(NonValidActionsError): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._step(states, actions) + + # Add nodes. + for _ in range(NUM_NODES): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._step(states, actions) + + assert states.tensor.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) + + # We can't add a node with the same features. + with pytest.raises(NonValidActionsError): + first_node_mask = torch.arange(len(states.tensor.x)) // BATCH_SIZE == 0 + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor.x[first_node_mask], + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._step(states, actions) + + # We can't add a self-loop edge for GraphBuilding env. 
+ with pytest.raises(NonValidActionsError): + edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([edge_index, edge_index], dim=1), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._step(states, actions) + + # Add edges. + for i in range(NUM_NODES - 1): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.tensor([[i, i + 1]] * BATCH_SIZE), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._step(states, actions) + + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT), + }, + batch_size=BATCH_SIZE, + ) + ) + + sf_states = env._step(states, actions) + assert torch.all(sf_states.is_sink_state) + assert isinstance(sf_states, GraphStates) + env.reward(sf_states) + + num_edges_per_batch = len(states.tensor.edge_attr) // BATCH_SIZE + # Remove edges. + for i in reversed(range(num_edges_per_batch)): + edge_idx = torch.arange(i, (i + 1) * BATCH_SIZE, i + 1) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": states.tensor.edge_attr[edge_idx], + "edge_index": states.tensor.edge_index[:, edge_idx].T + - states.tensor.ptr[:-1, None], + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._backward_step(states, actions) + + # We can't remove edges that don't exist. 
+ with pytest.raises(NonValidActionsError): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._backward_step(states, actions) + + # Remove nodes. + for i in reversed(range(1, NUM_NODES + 1)): + edge_idx = torch.arange(BATCH_SIZE) * i + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor.x[edge_idx], + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._backward_step(states, actions) + + assert states.tensor.x.shape == (0, FEATURE_DIM) + + # Add one random node again + features = torch.rand((BATCH_SIZE, FEATURE_DIM)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": features, + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._step(states, actions) + + # We can't remove nodes that don't exist. + with pytest.raises(NonValidActionsError): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": features + 1e-5, + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._backward_step(states, actions) + + # Remove the node. 
+ actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": features, + }, + batch_size=BATCH_SIZE, + ) + ) + states = env._backward_step(states, actions) + assert states.tensor.x.shape == (0, FEATURE_DIM) diff --git a/testing/test_graph_states.py b/testing/test_graph_states.py new file mode 100644 index 00000000..ce0f119b --- /dev/null +++ b/testing/test_graph_states.py @@ -0,0 +1,462 @@ +import pytest +import torch +from torch_geometric.data import Batch as GeometricBatch +from torch_geometric.data import Data as GeometricData + +from gfn.actions import GraphActionType +from gfn.states import GraphStates + + +class MyGraphStates(GraphStates): + # Initial state: a graph with 2 nodes and 1 edge + s0 = GeometricData( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]), + ) + + # Sink state: a graph with 2 nodes and 1 edge (different from s0) + sf = GeometricData( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]), + ) + + +@pytest.fixture +def datas(): + """Creates a list of 10 GeometricData objects""" + return [ + GeometricData( + x=torch.tensor([[i], [i + 0.5]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[i * 0.1]]), + ) + for i in range(10) + ] + + +@pytest.fixture +def simple_graph_state(datas): + """Creates a simple graph state with 2 nodes and 1 edge""" + data = datas[0] + batch = GeometricBatch.from_data_list([data]) + batch.batch_shape = (1,) + return MyGraphStates(batch) + + +@pytest.fixture +def empty_graph_state(): + """Creates an empty GraphStates object""" + # Create an empty batch + batch = GeometricBatch() + batch.x = torch.zeros((0, 1)) + batch.edge_index = torch.zeros((2, 0), dtype=torch.long) + batch.edge_attr = torch.zeros((0, 1)) + batch.batch = torch.zeros((0,), dtype=torch.long) + batch.batch_shape = (0,) + return MyGraphStates(batch) + + +def 
test_getitem_1d(datas): + """Test indexing into GraphStates + + Make sure the behavior is consistent with that of a Tensor.__getitem__. + """ + # Create a tensor with 3 elements for comparison + tsr = torch.tensor([1, 2, 3]) + + # Create a batch with 3 graphs + batch = GeometricBatch.from_data_list(datas[:3]) + batch.batch_shape = (3,) + assert tuple(tsr.shape) == batch.batch_shape == (3,) + states = MyGraphStates(batch) + states.log_rewards = tsr.clone() + + # Get a single graph + single_tsr = tsr[1] + single_state = states[1] + assert tuple(single_tsr.shape) == single_state.tensor.batch_shape == () + assert single_state.log_rewards is not None and single_state.log_rewards.shape == () + assert single_state.tensor.num_nodes == 2 + assert torch.allclose(single_state.tensor.x, datas[1].x) + assert torch.allclose(single_state.log_rewards, tsr[1]) + + # Get multiple graphs + multi_tsr = tsr[[0, 2]] + multi_state = states[[0, 2]] + assert tuple(multi_tsr.shape) == multi_state.tensor.batch_shape == (2,) + assert multi_state.log_rewards is not None and multi_state.log_rewards.shape == (2,) + assert multi_state.tensor.num_nodes == 4 + assert torch.allclose(multi_state.tensor.get_example(0).x, datas[0].x) + assert torch.allclose(multi_state.tensor.get_example(1).x, datas[2].x) + assert torch.allclose(multi_state.log_rewards, tsr[[0, 2]]) + + +def test_getitem_2d(datas): + """Test indexing into GraphStates with 2D batch shape + + Make sure the behavior is consistent with that of a Tensor.__getitem__. 
+ """ + # Create a tensor with 4 elements for comparison + tsr = torch.tensor([[1, 2], [3, 4]]) + + # Create a batch with 2x2 graphs + batch = GeometricBatch.from_data_list(datas[:4]) + batch.batch_shape = (2, 2) + assert tuple(tsr.shape) == batch.batch_shape == (2, 2) + states = MyGraphStates(batch) + states.log_rewards = tsr.clone() + + # Get a single row + tsr_row = tsr[0] + batch_row = states[0] + assert tuple(tsr_row.shape) == batch_row.tensor.batch_shape == (2,) + assert batch_row.log_rewards is not None and batch_row.log_rewards.shape == (2,) + assert batch_row.tensor.num_nodes == 4 # 2 graphs * 2 nodes + assert torch.allclose(batch_row.tensor.get_example(0).x, datas[0].x) + assert torch.allclose(batch_row.tensor.get_example(1).x, datas[1].x) + assert torch.allclose(batch_row.log_rewards, tsr[0]) + + # Try again with slicing + tsr_row2 = tsr[0, :] + batch_row2 = states[0, :] + assert tuple(tsr_row2.shape) == batch_row2.tensor.batch_shape == (2,) + assert torch.equal(batch_row.tensor.x, batch_row2.tensor.x) + + # Get a single graph with 2D indexing + single_tsr = tsr[1, 1] + single_state = states[1, 1] + assert tuple(single_tsr.shape) == single_state.tensor.batch_shape == () + assert single_state.log_rewards is not None and single_state.log_rewards.shape == () + assert single_state.tensor.num_nodes == 2 # 1 graph * 2 nodes + assert torch.allclose(single_state.tensor.x, datas[3].x) + assert torch.allclose(single_state.log_rewards, tsr[1, 1]) + + with pytest.raises(IndexError): + states[2, 2] + + # We can't index on a Batch with 0-dimensional batch shape + with pytest.raises(AssertionError): + single_state[0] + + +def test_setitem_1d(datas): + """Test setting values in GraphStates""" + # Create a graph state with 3 graphs + batch = GeometricBatch.from_data_list(datas[:3]) + batch.batch_shape = (3,) + states = MyGraphStates(batch) + + # Create a new graph state + new_batch = GeometricBatch.from_data_list(datas[3:5]) + new_batch.batch_shape = (2,) + new_states = 
MyGraphStates(new_batch) + + # Set the new graph in the first position + states[0] = new_states[0] + + # Check that the first graph is now the new graph + first_graph = states[0].tensor + assert torch.equal(first_graph.x, datas[3].x) + assert torch.equal(first_graph.edge_attr, datas[3].edge_attr) + assert torch.equal(first_graph.edge_index, datas[3].edge_index) + assert states.tensor.batch_shape == (3,) # Batch shape should not change + + # Set the new graph in the second and third positions + states[1:] = new_states # pyright: ignore # TODO: Fix pyright issue + + # Check that the second and third graphs are now the new graph + second_graph = states[1].tensor + assert torch.equal(second_graph.x, datas[3].x) + assert torch.equal(second_graph.edge_attr, datas[3].edge_attr) + assert torch.equal(second_graph.edge_index, datas[3].edge_index) + + third_graph = states[2].tensor + assert torch.equal(third_graph.x, datas[4].x) + assert torch.equal(third_graph.edge_attr, datas[4].edge_attr) + assert torch.equal(third_graph.edge_index, datas[4].edge_index) + assert states.tensor.batch_shape == (3,) # Batch shape should not change + + # Cannot set a graph with a wrong length + with pytest.raises(AssertionError): + states[0] = new_states + with pytest.raises(AssertionError): + states[1:] = new_states[0] # pyright: ignore + + +def test_setitem_2d(datas): + """Test setting values in GraphStates with 2D batch shape""" + # Create a graph state with 2x2 graphs + batch = GeometricBatch.from_data_list(datas[:4]) + batch.batch_shape = (2, 2) + states = MyGraphStates(batch) + + # Set the new graphs in the first row + new_batch_row = GeometricBatch.from_data_list(datas[4:6]) + new_batch_row.batch_shape = (2,) + new_states_row = MyGraphStates(new_batch_row) + states[0] = new_states_row + assert torch.equal(states[0, 0].tensor.x, datas[4].x) + assert torch.equal(states[0, 0].tensor.edge_attr, datas[4].edge_attr) + assert torch.equal(states[0, 0].tensor.edge_index, datas[4].edge_index) + 
assert states.tensor.batch_shape == (2, 2) # Batch shape should not change + + # Set the new graphs in the first column + new_batch_col = GeometricBatch.from_data_list(datas[6:8]) + new_batch_col.batch_shape = (2,) + new_states_col = MyGraphStates(new_batch_col) + states[:, 1] = new_states_col + assert torch.equal(states[1, 1].tensor.x, datas[7].x) + assert torch.equal(states[1, 1].tensor.edge_attr, datas[7].edge_attr) + assert torch.equal(states[1, 1].tensor.edge_index, datas[7].edge_index) + assert states.tensor.batch_shape == (2, 2) # Batch shape should not change + + +def test_clone(simple_graph_state): + """Test cloning a GraphStates object""" + cloned = simple_graph_state.clone() + + # Check that the clone has the same content + assert cloned.tensor.batch_shape == simple_graph_state.tensor.batch_shape + assert torch.equal(cloned.tensor.x, simple_graph_state.tensor.x) + assert torch.equal(cloned.tensor.edge_index, simple_graph_state.tensor.edge_index) + assert torch.equal(cloned.tensor.edge_attr, simple_graph_state.tensor.edge_attr) + + # Modify the clone and check that the original is unchanged + cloned.tensor.x[0, 0] = 99.0 + assert cloned.tensor.x[0, 0] == 99.0 + assert simple_graph_state.tensor.x[0, 0] == 0.0 + + +def test_is_initial_state(datas): + """Test is_initial_state property""" + # Create a batch with s0 and a different graph + s0 = MyGraphStates.s0.clone() + different = datas[9] + batch = GeometricBatch.from_data_list([s0, different]) + batch.batch_shape = (2,) + states = MyGraphStates(batch) + + # Check is_initial_state + is_initial = states.is_initial_state + assert is_initial[0].item() + assert not is_initial[1].item() + + +def test_is_sink_state(datas): + """Test is_sink_state property""" + # Create a batch with sf and a different graph + sf = MyGraphStates.sf.clone() + different = datas[9] + batch = GeometricBatch.from_data_list([sf, different]) + batch.batch_shape = (2,) + states = MyGraphStates(batch) + + # Check is_sink_state + is_sink = 
states.is_sink_state + assert is_sink[0].item() + assert not is_sink[1].item() + + +def test_from_batch_shape(): + """Test creating states from batch shape""" + # Create states with initial state + states = MyGraphStates.from_batch_shape((3,)) + assert states.tensor.batch_shape[0] == 3 + assert states.tensor.num_nodes == 6 # 3 graphs * 2 nodes per graph + + # Check all graphs are s0 + is_initial = states.is_initial_state + assert torch.all(is_initial) + + # Create states with sink state + sink_states = MyGraphStates.from_batch_shape((2,), sink=True) + assert sink_states.tensor.batch_shape[0] == 2 + assert sink_states.tensor.num_nodes == 4 # 2 graphs * 2 nodes per graph + + # Check all graphs are sf + is_sink = sink_states.is_sink_state + assert torch.all(is_sink) + + +def test_forward_masks(datas): + """Test forward_masks property""" + # Create a graph with 2 nodes and 1 edge + data = datas[0] + batch = GeometricBatch.from_data_list([data]) + batch.batch_shape = (1,) + states = MyGraphStates(batch) + + # Get forward masks + masks = states.forward_masks + + # Check action type mask + assert masks["action_type"].shape == (1, 3) + assert masks["action_type"][0, GraphActionType.ADD_NODE].item() # Can add node + assert ( + masks["action_type"][0, GraphActionType.ADD_EDGE] + ).item() # Can add edge (2 nodes) + assert masks["action_type"][0, GraphActionType.EXIT].item() # Can exit + + # Check features mask + assert masks["features"].shape == (1, 1) # 1 feature dimension + assert masks["features"][0, 0].item() # All features allowed + + # Check edge_index masks + assert len(masks["edge_index"]) == 1 # 1 graph + assert torch.all( + masks["edge_index"][0] == torch.tensor([[False, False], [True, False]]) + ) + + +def test_backward_masks(datas): + """Test backward_masks property""" + # Create a graph with 2 nodes and 1 edge + data = datas[0] + batch = GeometricBatch.from_data_list([data]) + batch.batch_shape = (1,) + states = MyGraphStates(batch) + + # Get backward masks + 
masks = states.backward_masks + + # Check action type mask + assert masks["action_type"].shape == (1, 3) + assert masks["action_type"][0, GraphActionType.ADD_NODE].item() # Can remove node + assert masks["action_type"][0, GraphActionType.ADD_EDGE].item() # Can remove edge + assert masks["action_type"][0, GraphActionType.EXIT].item() # Can exit + + # Check features mask + assert masks["features"].shape == (1, 1) # 1 feature dimension + assert masks["features"][0, 0].item() # All features allowed + + # Check edge_index masks + assert len(masks["edge_index"]) == 1 # 1 graph + assert torch.all( + masks["edge_index"][0] == torch.tensor([[False, True], [False, False]]) + ) + + +def test_stack_1d(datas): + """Test stacking GraphStates objects""" + # Create two states + batch1 = GeometricBatch.from_data_list(datas[0:2]) + batch1.batch_shape = (2,) + state1 = MyGraphStates(batch1) + + batch2 = GeometricBatch.from_data_list(datas[2:4]) + batch2.batch_shape = (2,) + state2 = MyGraphStates(batch2) + + # Stack the states + stacked = MyGraphStates.stack([state1, state2]) + + # Check the batch shape + assert stacked.tensor.batch_shape == (2, 2) + + # Check the number of nodes and edges + assert stacked.tensor.num_nodes == 8 # 4 states * 2 nodes + assert stacked.tensor.num_edges == 4 # 4 states * 1 edge + + # Check the batch indices + assert torch.equal(stacked.tensor.batch[:4], batch1.batch) + assert torch.equal(stacked.tensor.batch[4:], batch2.batch + 2) + + +def test_stack_2d(datas): + """Test stacking GraphStates objects with 2D batch shape""" + # Create two states + batch1 = GeometricBatch.from_data_list(datas[:4]) + batch1.batch_shape = (2, 2) + state1 = MyGraphStates(batch1) + + batch2 = GeometricBatch.from_data_list(datas[4:8]) + batch2.batch_shape = (2, 2) + state2 = MyGraphStates(batch2) + + # Stack the states + stacked = MyGraphStates.stack([state1, state2]) + + # Check the batch shape + assert stacked.tensor.batch_shape == (2, 2, 2) + + # Check the number of nodes and 
edges + assert stacked.tensor.num_nodes == 16 # 8 states * 2 nodes + assert stacked.tensor.num_edges == 8 # 8 states * 1 edge + + # Check the batch indices + assert torch.equal(stacked.tensor.batch[:8], batch1.batch) + assert torch.equal(stacked.tensor.batch[8:], batch2.batch + 4) + + +def test_extend_empty_state(empty_graph_state, simple_graph_state): + """Test extending an empty state with a non-empty state""" + empty_graph_state.extend(simple_graph_state) + + # Check that the empty state now has the same content as the simple state + assert empty_graph_state.tensor.batch_shape == simple_graph_state.tensor.batch_shape + assert torch.equal(empty_graph_state.tensor.x, simple_graph_state.tensor.x) + assert torch.equal( + empty_graph_state.tensor.edge_index, simple_graph_state.tensor.edge_index + ) + assert torch.equal( + empty_graph_state.tensor.edge_attr, simple_graph_state.tensor.edge_attr + ) + assert torch.equal(empty_graph_state.tensor.batch, simple_graph_state.tensor.batch) + + +def test_extend_1d(simple_graph_state): + """Test extending two 1D batch states""" + other_state = simple_graph_state.clone() + + # Store original number of nodes and edges + original_num_nodes = simple_graph_state.tensor.num_nodes + original_num_edges = simple_graph_state.tensor.num_edges + + simple_graph_state.extend(other_state) + + # Check batch shape is updated + assert simple_graph_state.tensor.batch_shape[0] == 2 + + # Check number of nodes and edges doubled + assert simple_graph_state.tensor.num_nodes == 2 * original_num_nodes + assert simple_graph_state.tensor.num_edges == 2 * original_num_edges + + # Check that batch indices are properly updated + batch_indices = simple_graph_state.tensor.batch + assert torch.equal( + batch_indices[:original_num_nodes], + torch.zeros(original_num_nodes, dtype=torch.long), + ) + assert torch.equal( + batch_indices[original_num_nodes:], + torch.ones(original_num_nodes, dtype=torch.long), + ) + + +def test_extend_2d(datas): + """Test extending 
two 2D batch states""" + batch1 = GeometricBatch.from_data_list(datas[:4]) + batch1.batch_shape = (2, 2) + state1 = MyGraphStates(batch1) + + batch2 = GeometricBatch.from_data_list(datas[4:]) + batch2.batch_shape = (3, 2) + state2 = MyGraphStates(batch2) + + # Extend state1 with state2 + state1.extend(state2) + + # Check final shape should be (max_len=3, B=4) + assert state1.tensor.batch_shape == (3, 4) + + # Check that we have the correct number of nodes and edges + # Each graph has 2 nodes and 1 edge + # For 3 time steps and 2 batches, we should have: + expected_nodes = 3 * 2 * 4 # T * nodes_per_graph * B + expected_edges = 3 * 1 * 4 # T * edges_per_graph * B + + # The actual count might be higher due to padding with sink states + assert state1.tensor.num_nodes >= expected_nodes + assert state1.tensor.num_edges >= expected_edges diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 55e93601..c941d256 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -367,5 +367,65 @@ def test_states_actions_tns_to_traj(): # Test that we can add the trajectories to a replay buffer replay_buffer = ReplayBuffer(env, capacity=10) replay_buffer.add(trajs) - assert len(replay_buffer) > 0 - assert isinstance(replay_buffer.training_objects, Trajectories) + + +# ------ GRAPH TESTS ------ + +# TODO: This test fails randomly. it should not rely on a custom GraphActionNet. 
+# def test_graph_building(): +# feature_dim = 8 +# env = GraphBuilding( +# feature_dim=feature_dim, state_evaluator=lambda s: torch.zeros(s.batch_shape) +# ) + +# module = GraphActionNet(feature_dim) +# pf_estimator = GraphActionPolicyEstimator(module=module) + +# sampler = Sampler(estimator=pf_estimator) +# trajectories = sampler.sample_trajectories( +# env, +# n=7, +# save_logprobs=True, +# save_estimator_outputs=False, +# ) + +# assert len(trajectories) == 7 + + +# class GraphActionNet(nn.Module): +# def __init__(self, feature_dim: int): +# super().__init__() +# self.feature_dim = feature_dim +# self.action_type_conv = GCNConv(feature_dim, 3) +# self.features_conv = GCNConv(feature_dim, feature_dim) +# self.edge_index_conv = GCNConv(feature_dim, 8) + +# def forward(self, states: GraphStates) -> TensorDict: +# node_feature = states.tensor.x.reshape(-1, self.feature_dim) + +# if states.tensor.x.shape[0] == 0: +# action_type = torch.zeros((len(states), 3)) +# action_type[:, GraphActionType.ADD_NODE] = 1 +# features = torch.zeros((len(states), self.feature_dim)) +# else: +# action_type = self.action_type_conv(node_feature, states.tensor.edge_index) +# action_type = action_type.reshape( +# len(states), -1, action_type.shape[-1] +# ).mean(dim=1) +# features = self.features_conv(node_feature, states.tensor.edge_index) +# features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) + +# edge_index = self.edge_index_conv(node_feature, states.tensor.edge_index) +# edge_index = torch.einsum("nf,mf->nm", edge_index, edge_index) +# edge_index = edge_index[None].repeat(len(states), 1, 1) + +# return TensorDict( +# { +# "action_type": action_type, +# "features": features, +# "edge_index": edge_index.reshape( +# states.batch_shape + edge_index.shape[1:] +# ), +# }, +# batch_size=states.batch_shape, +# ) diff --git a/tutorials/examples/README.md b/tutorials/examples/README.md index a636bff1..59ec3ff0 100644 --- a/tutorials/examples/README.md +++ 
b/tutorials/examples/README.md @@ -1,4 +1,4 @@ # Example training scripts The provided training scripts showcase different functionalities of the codebase. -At the top of the files, you will find commands to run in order to reproduce results published elsewhere \ No newline at end of file +At the top of the files, you will find commands to run in order to reproduce results published elsewhere. \ No newline at end of file diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 56eb096b..04682a0f 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -272,7 +272,9 @@ def main(args): # noqa: C901 wandb.log(to_log, step=iteration) if iteration % (args.validation_interval // 5) == 0: tqdm.write( - f"States: {states_visited}, Loss: {loss.item():.3f}, {logZ_info}true logZ: {env.log_partition:.2f}, JSD: {jsd:.4f}" + f"States: {states_visited}, " + f"Loss: {loss.item():.3f}, {logZ_info}" + f"true logZ: {env.log_partition:.2f}, JSD: {jsd:.4f}" ) if iteration % args.validation_interval == 0: diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py new file mode 100644 index 00000000..e702c1b7 --- /dev/null +++ b/tutorials/examples/train_graph_ring.py @@ -0,0 +1,969 @@ +"""Train a GFlowNet to generate ring graphs. + +This example demonstrates training a GFlowNet to generate graphs that are rings - where each vertex +has exactly two neighbors and the edges form a single cycle containing all vertices. The environment +supports both directed and undirected ring generation. 
+ +Key components: +- RingGraphBuilding: Environment for building ring graphs +- RingPolicyModule: GNN-based policy network for predicting actions +- directed_reward/undirected_reward: Reward functions for validating ring structures +""" + +import math +import time +from typing import Optional + +import matplotlib.pyplot as plt +import torch +from matplotlib import patches +from tensordict import TensorDict +from torch import nn +from torch_geometric.data import Batch as GeometricBatch +from torch_geometric.data import Data as GeometricData +from torch_geometric.nn import DirGNNConv, GCNConv, GINConv + +from gfn.actions import Actions, GraphActions, GraphActionType +from gfn.containers import ReplayBuffer +from gfn.gflownet.trajectory_balance import TBGFlowNet +from gfn.gym import GraphBuilding +from gfn.modules import DiscretePolicyEstimator +from gfn.preprocessors import Preprocessor +from gfn.states import GraphStates +from gfn.utils.modules import MLP + +REW_VAL = 100.0 +EPS_VAL = 1e-6 + + +def directed_reward(states: GraphStates) -> torch.Tensor: + """Compute reward for directed ring graphs. + + This function evaluates if a graph forms a valid directed ring (directed cycle). + A valid directed ring must satisfy these conditions: + 1. Each node must have exactly one outgoing edge (row sum = 1 in adjacency matrix) + 2. Each node must have exactly one incoming edge (column sum = 1 in adjacency matrix) + 3. Following the edges must form a single cycle that includes all nodes + + The reward is binary: + - REW_VAL (100.0) for valid directed rings + - EPS_VAL (1e-6) for invalid structures + + Args: + states: A batch of graph states to evaluate + + Returns: + A tensor of rewards with the same batch shape as states + """ + if states.tensor.edge_index.numel() == 0: + return torch.full(states.batch_shape, EPS_VAL) + + out = torch.full((len(states),), EPS_VAL) # Default reward. 
+ + for i in range(len(states)): + graph = states[i] + adj_matrix = torch.zeros(graph.tensor.num_nodes, graph.tensor.num_nodes) + adj_matrix[graph.tensor.edge_index[0], graph.tensor.edge_index[1]] = 1 + + # Check if each node has exactly one outgoing edge (row sum = 1) + if not torch.all(adj_matrix.sum(dim=1) == 1): + continue + + # Check that each node has exactly one incoming edge (column sum = 1) + if not torch.all(adj_matrix.sum(dim=0) == 1): + continue + + # Starting from node 0, follow edges and see if we visit all nodes + # and return to the start + visited = [] + current = 0 # Start from node 0 + + while current not in visited: + visited.append(current) + + # Get the outgoing neighbor + current = torch.where(adj_matrix[int(current)] == 1)[0].item() + + # If we've visited all nodes and returned to 0, it's a valid ring + if len(visited) == graph.tensor.num_nodes and current == 0: + out[i] = REW_VAL + break + + return out.view(*states.batch_shape) + + +def undirected_reward(states: GraphStates) -> torch.Tensor: + """Compute reward for undirected ring graphs. + + This function evaluates if a graph forms a valid undirected ring (cycle). + A valid undirected ring must satisfy these conditions: + 1. Each node must have exactly two neighbors (degree = 2) + 2. The graph must form a single connected cycle including all nodes + + The reward is binary: + - REW_VAL (100.0) for valid undirected rings + - EPS_VAL (1e-6) for invalid structures + + The algorithm: + 1. Checks that all nodes have degree 2 + 2. Performs a traversal starting from node 0, following edges + 3. Checks if the traversal visits all nodes and returns to start + + Args: + states: A batch of graph states to evaluate + + Returns: + A tensor of rewards with the same batch shape as states + """ + if states.tensor.edge_index.numel() == 0: + return torch.full(states.batch_shape, EPS_VAL) + + out = torch.full((len(states),), EPS_VAL) # Default reward. 
+ + for i in range(len(states)): + graph = states[i] + if graph.tensor.num_nodes == 0: + continue + adj_matrix = torch.zeros(graph.tensor.num_nodes, graph.tensor.num_nodes) + adj_matrix[graph.tensor.edge_index[0], graph.tensor.edge_index[1]] = 1 + adj_matrix[graph.tensor.edge_index[1], graph.tensor.edge_index[0]] = 1 + + # In an undirected ring, every vertex should have degree 2. + if not torch.all(adj_matrix.sum(dim=1) == 2): + continue + + # Traverse the cycle starting from vertex 0. + start_vertex = 0 + visited = [start_vertex] + neighbors = torch.where(adj_matrix[start_vertex] == 1)[0] + if neighbors.numel() == 0: + continue + # Arbitrarily choose one neighbor to begin the traversal. + current = neighbors[0].item() + prev = start_vertex + + while True: + if current == start_vertex: + break + visited.append(int(current)) + current_neighbors = torch.where(adj_matrix[int(current)] == 1)[0] + # Exclude the neighbor we just came from. + current_neighbors_list = [n.item() for n in current_neighbors] + possible = [n for n in current_neighbors_list if n != prev] + if len(possible) != 1: + break + next_node = possible[0] + prev, current = current, next_node + + if current == start_vertex and len(visited) == graph.tensor.num_nodes: + out[i] = REW_VAL + + return out.view(*states.batch_shape) + + +class RingPolicyModule(nn.Module): + """Simple module which outputs a fixed logits for the actions, depending on the number of edges. + + Args: + n_nodes: The number of nodes in the graph. + """ + + def __init__( + self, + n_nodes: int, + directed: bool, + num_conv_layers: int = 1, + embedding_dim: int = 128, + is_backward: bool = False, + ): + super().__init__() + self.hidden_dim = self.embedding_dim = embedding_dim + self.is_directed = directed + self.is_backward = is_backward + self.n_nodes = n_nodes + self.num_conv_layers = num_conv_layers + + # Node embedding layer. 
+ self.embedding = nn.Embedding(n_nodes, self.embedding_dim) + self.conv_blks = nn.ModuleList() + self.exit_mlp = MLP( + input_dim=self.hidden_dim, + output_dim=1, + hidden_dim=self.hidden_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + if directed: + for i in range(num_conv_layers): + self.conv_blks.extend( + [ + DirGNNConv( + GCNConv( + self.embedding_dim if i == 0 else self.hidden_dim, + self.hidden_dim, + ), + alpha=0.5, + root_weight=True, + ), + # Process in/out components separately + nn.ModuleList( + [ + nn.Sequential( + nn.Linear( + self.hidden_dim // 2, self.hidden_dim // 2 + ), + nn.ReLU(), + nn.Linear( + self.hidden_dim // 2, self.hidden_dim // 2 + ), + ) + for _ in range( + 2 + ) # One for in-features, one for out-features. + ] + ), + ] + ) + else: # Undirected case. + for _ in range(num_conv_layers): + self.conv_blks.extend( + [ + GINConv( + MLP( + input_dim=( + self.embedding_dim if i == 0 else self.hidden_dim + ), + output_dim=self.hidden_dim, + hidden_dim=self.hidden_dim, + n_hidden_layers=1, + add_layer_norm=True, + ), + ), + nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ), + ] + ) + + self.norm = nn.LayerNorm(self.hidden_dim) + + def _group_mean(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: + cumsum = torch.zeros( + (len(tensor) + 1, *tensor.shape[1:]), + dtype=tensor.dtype, + device=tensor.device, + ) + cumsum[1:] = torch.cumsum(tensor, dim=0) + + # Subtract the end val from each batch idx fom the start val of each batch idx. + size = batch_ptr[1:] - batch_ptr[:-1] + return (cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]) / size[:, None] + + def forward(self, states_tensor: GeometricBatch) -> torch.Tensor: + node_features, batch_ptr = (states_tensor.x, states_tensor.ptr) + batch_size = int(math.prod(states_tensor.batch_shape)) + + # Multiple action type convolutions with residual connections. 
+ x = self.embedding(node_features.squeeze().int()) + for i in range(0, len(self.conv_blks), 2): + x_new = self.conv_blks[i](x, states_tensor.edge_index) # GIN/GCN conv. + if self.is_directed: + assert isinstance(self.conv_blks[i + 1], nn.ModuleList) + x_in, x_out = torch.chunk(x_new, 2, dim=-1) + + # Process each component separately through its own MLP. + mlp_in, mlp_out = self.conv_blks[i + 1] + x_in = mlp_in(x_in) + x_out = mlp_out(x_out) + x_new = torch.cat([x_in, x_out], dim=-1) + else: + x_new = self.conv_blks[i + 1](x_new) # Linear -> ReLU -> Linear. + + x = x_new + x if i > 0 else x_new # Residual connection. + x = self.norm(x) # Layernorm. + + # This MLP computes the exit action. + node_feature_means = self._group_mean(x, batch_ptr) + exit_action = self.exit_mlp(node_feature_means) + + x = x.reshape(*states_tensor.batch_shape, self.n_nodes, self.hidden_dim) + + # Undirected. + if self.is_directed: + feature_dim = self.hidden_dim // 2 + source_features = x[..., :feature_dim] + target_features = x[..., feature_dim:] + + # Dot product between source and target features (asymmetric). + edgewise_dot_prod = torch.einsum( + "bnf,bmf->bnm", source_features, target_features + ) + edgewise_dot_prod = edgewise_dot_prod / torch.sqrt(torch.tensor(feature_dim)) + + i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) + + # Combine them. + i0 = torch.cat([i_up, i_lo]) + i1 = torch.cat([j_up, j_lo]) + out_size = self.n_nodes**2 - self.n_nodes + + else: + # Dot product between all node features (symmetric). + edgewise_dot_prod = torch.einsum("bnf,bmf->bnm", x, x) + edgewise_dot_prod = edgewise_dot_prod / torch.sqrt( + torch.tensor(self.hidden_dim) + ) + i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + out_size = (self.n_nodes**2 - self.n_nodes) // 2 + + # Grab the needed elems from the adjacency matrix and reshape. 
+ edge_actions = edgewise_dot_prod[torch.arange(batch_size)[:, None, None], i0, i1] + edge_actions = edge_actions.reshape(*states_tensor["batch_shape"], out_size) + + if self.is_backward: + return edge_actions + else: + return torch.cat([edge_actions, exit_action], dim=-1) + + +class RingGraphBuilding(GraphBuilding): + """Environment for building ring graphs with discrete action space. + + This environment is specialized for creating ring graphs where each node has + exactly two neighbors and the edges form a single cycle. The environment supports + both directed and undirected graphs. + + In each state, the policy can: + 1. Add an edge between existing nodes + 2. Use the exit action to terminate graph building + + The action space is discrete, with size: + - For directed graphs: n_nodes^2 - n_nodes + 1 (all possible directed edges + exit) + - For undirected graphs: (n_nodes^2 - n_nodes)/2 + 1 (upper triangle + exit) + + Args: + n_nodes: The number of nodes in the graph. + state_evaluator: A function that evaluates a state and returns a reward. + directed: Whether the graph should be directed. + """ + + def __init__(self, n_nodes: int, state_evaluator: callable, directed: bool): + self.n_nodes = n_nodes + if directed: + # all off-diagonal edges + exit. + self.n_actions = (n_nodes**2 - n_nodes) + 1 + else: + # bottom triangle + exit. + self.n_actions = ((n_nodes**2 - n_nodes) // 2) + 1 + super().__init__(feature_dim=n_nodes, state_evaluator=state_evaluator) + self.is_discrete = True # actions here are discrete, needed for FlowMatching + self.is_directed = directed + + def make_actions_class(self) -> type[Actions]: + env = self + + class RingActions(Actions): + """Actions for building ring graphs. 
+ + Actions are represented as discrete indices where: + - 0 to n_actions-2: Adding an edge between specific nodes + - n_actions-1: EXIT action to terminate the trajectory + - n_actions: DUMMY action (used for padding) + """ + + action_shape = (1,) + dummy_action = torch.tensor([env.n_actions]) + exit_action = torch.tensor([env.n_actions - 1]) + + return RingActions + + def make_states_class(self) -> type[GraphStates]: + env = self + + class RingStates(GraphStates): + """Represents the state of a ring graph building process. + + This class extends GraphStates to specifically handle ring graph states. + Each state represents a graph with a fixed number of nodes where edges + are being added incrementally to form a ring structure. + + The state representation consists of: + - node_feature: Node IDs for each node in the graph (shape: [n_nodes, 1]) + - edge_feature: Features for each edge (shape: [n_edges, 1]) + - edge_index: Indices representing the source and target nodes for each edge (shape: [n_edges, 2]) + + Special states: + - s0: Initial state with n_nodes and no edges + - sf: Terminal state (used as a placeholder) + + The class provides masks for both forward and backward actions to determine + which actions are valid from the current state. + """ + + s0 = GeometricData( + x=torch.arange(env.n_nodes)[:, None], + edge_attr=torch.ones((0, 1)), + edge_index=torch.ones((2, 0), dtype=torch.long), + ) + sf = GeometricData( + x=-torch.ones(env.n_nodes)[:, None], + edge_attr=torch.zeros((0, 1)), + edge_index=torch.zeros((2, 0), dtype=torch.long), + ) + + def __init__(self, tensor: GeometricBatch): + self.tensor = tensor + self.node_features_dim = tensor.x.shape[-1] + self.edge_features_dim = tensor.edge_attr.shape[-1] + self._log_rewards: Optional[float] = None + + self.n_nodes = env.n_nodes + self.n_actions = env.n_actions + + @property + def forward_masks(self): + """Compute masks for valid forward actions from the current state. 
+ + A forward action is valid if: + 1. The edge doesn't already exist in the graph + 2. The edge connects two distinct nodes + + For directed graphs, all possible src->dst edges are considered. + For undirected graphs, only the upper triangular portion of the adjacency matrix is used. + + The last action is always the EXIT action, which is always valid. + + Returns: + Tensor: Boolean mask of shape [batch_size, n_actions] where True indicates valid actions + """ + # Allow all actions. + forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) + + if env.is_directed: + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. + + # Combine them + ei0 = torch.cat([i_up, i_lo]) + ei1 = torch.cat([j_up, j_lo]) + else: + ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + + # Remove existing edges. + for i in range(len(self)): + existing_edges = self[i].tensor.edge_index + assert torch.all(existing_edges >= 0) # TODO: convert to test. + + if existing_edges.numel() == 0: + edge_idx = torch.zeros(0, dtype=torch.bool) + else: + edge_idx = torch.logical_and( + existing_edges[0][..., None] == ei0[None], + existing_edges[1][..., None] == ei1[None], + ) + + # Collapse across the edge dimension. + if len(edge_idx.shape) == 2: + edge_idx = edge_idx.sum(0).bool() + + # Adds an unmasked exit action. + edge_idx = torch.cat((edge_idx, torch.BoolTensor([False]))) + forward_masks[i, edge_idx] = ( + False # Disallow the addition of this edge. + ) + + return forward_masks.view(*self.batch_shape, self.n_actions) + + @forward_masks.setter + def forward_masks(self, value: torch.Tensor): + pass # fwd masks is computed on the fly + + @property + def backward_masks(self): + """Compute masks for valid backward actions from the current state. + + A backward action is valid if: + 1. 
The edge exists in the current graph (i.e., can be removed) + + For directed graphs, all existing edges are considered for removal. + For undirected graphs, only the upper triangular edges are considered. + + The EXIT action is not included in backward masks. + + Returns: + Tensor: Boolean mask of shape [batch_size, n_actions-1] where True indicates valid actions + """ + # Disallow all actions. + backward_masks = torch.zeros( + len(self), self.n_actions - 1, dtype=torch.bool + ) + + for i in range(len(self)): + existing_edges = self[i].tensor.edge_index + + if env.is_directed: + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. + + # Combine them + ei0 = torch.cat([i_up, i_lo]) + ei1 = torch.cat([j_up, j_lo]) + else: + ei0, ei1 = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) + + if len(existing_edges) == 0: + edge_idx = torch.zeros(0, dtype=torch.bool) + else: + edge_idx = torch.logical_and( + existing_edges[0][..., None] == ei0[None], + existing_edges[1][..., None] == ei1[None], + ) + # Collapse across the edge dimension. + if len(edge_idx.shape) == 2: + edge_idx = edge_idx.sum(0).bool() + + backward_masks[i, edge_idx] = ( + True # Allow the removal of this edge. + ) + + return backward_masks.view(*self.batch_shape, self.n_actions - 1) + + @backward_masks.setter + def backward_masks(self, value: torch.Tensor): + pass # bwd masks is computed on the fly + + return RingStates + + def _step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Take a step in the environment by applying actions to states. 
+ + Args: + states: Current states batch + actions: Actions to apply + + Returns: + New states after applying the actions + """ + actions = self.convert_actions(states, actions) + new_states = super()._step(states, actions) + assert isinstance(new_states, GraphStates) + return new_states + + def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Take a backward step in the environment. + + Args: + states: Current states batch + actions: Actions to apply in reverse + + Returns: + New states after applying the backward actions + """ + actions = self.convert_actions(states, actions) + new_states = super()._backward_step(states, actions) + assert isinstance(new_states, GraphStates) + return new_states + + def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions: + """Converts the action from discrete space to graph action space. + + This method maps discrete action indices to specific graph operations: + - GraphActionType.ADD_EDGE: Add an edge between specific nodes + - GraphActionType.EXIT: Terminate trajectory + - GraphActionType.DUMMY: No-op action (for padding) + + Args: + states: Current states batch + actions: Discrete actions to convert + + Returns: + Equivalent actions in the GraphActions format + """ + # TODO: factor out into utility function. + action_tensor = actions.tensor.squeeze(-1).clone() + action_type = torch.where( + action_tensor == self.n_actions - 1, + GraphActionType.EXIT, + GraphActionType.ADD_EDGE, + ) + action_type[action_tensor == self.n_actions] = GraphActionType.DUMMY + + # Convert action indices to source-target node pairs + # TODO: factor out into utility function. + if self.is_directed: + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. 
+ + # Combine them + ei0 = torch.cat([i_up, i_lo]) + ei1 = torch.cat([j_up, j_lo]) + + else: + ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + + # Adds -1 "edge" representing exit, -2 "edge" representing dummy. + ei0 = torch.cat((ei0, torch.IntTensor([-1, -2])), dim=0) + ei1 = torch.cat((ei1, torch.IntTensor([-1, -2])), dim=0) + + # Indexes either the second last element (exit) or la + # action_tensor[action_tensor >= (self.n_actions - 1)] = 0 + ei0, ei1 = ei0[action_tensor], ei1[action_tensor] + + out = GraphActions( + TensorDict( + { + "action_type": action_type, + "features": torch.ones(action_tensor.shape + (1,)), + "edge_index": torch.stack([ei0, ei1], dim=-1), + }, + batch_size=action_tensor.shape, + ) + ) + return out + + +class GraphPreprocessor(Preprocessor): + """Preprocessor for graph states to extract the tensor representation. + + This simple preprocessor extracts the GeometricBatch from GraphStates to make + it compatible with the policy networks. It doesn't perform any complex + transformations, just ensuring the tensors are accessible in the right format. + + Args: + feature_dim: The dimension of features in the graph (default: 1) + """ + + def __init__(self, feature_dim: int = 1): + super().__init__(output_dim=feature_dim) + + def preprocess(self, states: GraphStates) -> GeometricBatch: + return states.tensor + + def __call__(self, states: GraphStates) -> GeometricBatch: + return self.preprocess(states) + + +def render_states(states: GraphStates, state_evaluator: callable, directed: bool): + """Visualize a batch of graph states as ring structures. + + This function creates a matplotlib visualization of graph states, rendering them + as circular layouts with nodes positioned evenly around a circle. For directed + graphs, edges are shown as arrows; for undirected graphs, edges are shown as lines. 
+ + The visualization includes: + - Circular positioning of nodes + - Drawing edges between connected nodes + - Displaying the reward value for each graph + + Args: + states: A batch of graphs to visualize + state_evaluator: Function to compute rewards for each graph + directed: Whether to render directed or undirected edges + """ + rewards = state_evaluator(states) + fig, ax = plt.subplots(2, 4, figsize=(15, 7)) + for i in range(8): + current_ax = ax[i // 4, i % 4] + state = states[i] + n_circles = state.tensor.num_nodes + radius = 5 + xs, ys = [], [] + for j in range(n_circles): + angle = 2 * math.pi * j / n_circles + x = radius * math.cos(angle) + y = radius * math.sin(angle) + xs.append(x) + ys.append(y) + current_ax.add_patch( + patches.Circle((x, y), 0.5, facecolor="none", edgecolor="black") + ) + + edge_index = states[i].tensor.edge_index + + for edge in edge_index.T: + start_x, start_y = xs[edge[0]], ys[edge[0]] + end_x, end_y = xs[edge[1]], ys[edge[1]] + dx = end_x - start_x + dy = end_y - start_y + length = math.sqrt(dx**2 + dy**2) + dx, dy = dx / length, dy / length + + circle_radius = 0.5 + head_thickness = 0.2 + + start_x += dx * circle_radius + start_y += dy * circle_radius + if directed: + end_x -= dx * circle_radius + end_y -= dy * circle_radius + current_ax.arrow( + start_x, + start_y, + end_x - start_x, + end_y - start_y, + head_width=head_thickness, + head_length=head_thickness, + fc="black", + ec="black", + ) + + else: + end_x -= dx * (circle_radius + head_thickness) + end_y -= dy * (circle_radius + head_thickness) + current_ax.plot([start_x, end_x], [start_y, end_y], color="black") + + current_ax.set_title(f"State {i}, $r={rewards[i]:.2f}$") + current_ax.set_xlim(-(radius + 1), radius + 1) + current_ax.set_ylim(-(radius + 1), radius + 1) + current_ax.set_aspect("equal") + current_ax.set_xticks([]) + current_ax.set_yticks([]) + + plt.show() + + +class AdjacencyPolicyModule(nn.Module): + """Policy network that processes flattened adjacency 
matrices to predict graph actions. + + Unlike the GNN-based RingPolicyModule, this module uses standard MLPs to process + the entire adjacency matrix as a flattened vector. This approach: + + 1. Can directly process global graph structure without message passing + 2. May be more effective for small graphs where global patterns are important + 3. Does not require complex graph neural network operations + + The module architecture consists of: + - An MLP to process the flattened adjacency matrix into an embedding + - An edge MLP that predicts logits for each possible edge action + - An exit MLP that predicts a logit for the exit action + + Args: + n_nodes: Number of nodes in the graph + directed: Whether the graph is directed or undirected + embedding_dim: Dimension of internal embeddings (default: 128) + is_backward: Whether this is a backward policy (default: False) + """ + + def __init__( + self, + n_nodes: int, + directed: bool, + embedding_dim: int = 128, + is_backward: bool = False, + ): + super().__init__() + self.n_nodes = n_nodes + self.is_directed = directed + self.is_backward = is_backward + self.hidden_dim = embedding_dim + + # MLP for processing the flattened adjacency matrix + self.mlp = MLP( + input_dim=n_nodes * n_nodes, # Flattened adjacency matrix + output_dim=embedding_dim, + hidden_dim=embedding_dim, + n_hidden_layers=2, + add_layer_norm=True, + ) + + # Exit action MLP + self.exit_mlp = MLP( + input_dim=embedding_dim, + output_dim=1, + hidden_dim=embedding_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + # Edge prediction MLP + if directed: + # For directed graphs: all off-diagonal elements + out_size = n_nodes**2 - n_nodes + else: + # For undirected graphs: upper triangle without diagonal + out_size = (n_nodes**2 - n_nodes) // 2 + + self.edge_mlp = MLP( + input_dim=embedding_dim, + output_dim=out_size, + hidden_dim=embedding_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + def forward(self, states_tensor: GeometricBatch) -> 
torch.Tensor: + """Forward pass to compute action logits from graph states. + + Process: + 1. Convert the graph representation to adjacency matrices + 2. Process the flattened adjacency matrices through the main MLP + 3. Predict logits for edge actions and exit action + + Args: + states_tensor: A GeometricBatch containing graph state information + + Returns: + A tensor of logits for all possible actions + """ + # Convert the graph to adjacency matrix + batch_size = int(states_tensor.batch_size) + adj_matrices = torch.zeros( + (batch_size, self.n_nodes, self.n_nodes), + device=states_tensor.x.device, + ) + + # Fill the adjacency matrices from edge indices + if states_tensor.edge_index.numel() > 0: + for i in range(batch_size): + eis = states_tensor[i].edge_index + adj_matrices[i, eis[0], eis[1]] = 1 + + # Flatten the adjacency matrices for the MLP + adj_matrices_flat = adj_matrices.view(batch_size, -1) + + # Process with MLP + embedding = self.mlp(adj_matrices_flat) + + # Generate edge and exit actions + edge_actions = self.edge_mlp(embedding) + exit_action = self.exit_mlp(embedding) + + if self.is_backward: + return edge_actions + else: + return torch.cat([edge_actions, exit_action], dim=-1) + + +if __name__ == "__main__": + """ + Main execution for training a GFlowNet to generate ring graphs. + + This script demonstrates the complete workflow of training a GFlowNet + to generate valid ring structures in both directed and undirected settings. + + Configurable parameters: + - N_NODES: Number of nodes in the graph (default: 5) + - N_ITERATIONS: Number of training iterations (default: 1000) + - LR: Learning rate for optimizer (default: 0.001) + - BATCH_SIZE: Batch size for training (default: 128) + - DIRECTED: Whether to generate directed rings (default: True) + - USE_BUFFER: Whether to use a replay buffer (default: False) + - USE_GNN: Whether to use GNN-based policy (True) or MLP-based policy (False) + + The script performs the following steps: + 1. 
Initialize the environment and policy networks + 2. Train the GFlowNet using trajectory balance + 3. Visualize sample generated graphs + """ + N_NODES = 4 + N_ITERATIONS = 200 + LR = 0.001 + BATCH_SIZE = 1024 + DIRECTED = True + USE_BUFFER = False + USE_GNN = True # Set to False to use MLP with adjacency matrices instead of GNN + NUM_CONV_LAYERS = 1 + + state_evaluator = undirected_reward if not DIRECTED else directed_reward + torch.random.manual_seed(7) + env = RingGraphBuilding( + n_nodes=N_NODES, state_evaluator=state_evaluator, directed=DIRECTED + ) + + # Choose model type based on USE_GNN flag + if USE_GNN: + module_pf = RingPolicyModule( + env.n_nodes, DIRECTED, num_conv_layers=NUM_CONV_LAYERS + ) + module_pb = RingPolicyModule( + env.n_nodes, DIRECTED, is_backward=True, num_conv_layers=NUM_CONV_LAYERS + ) + else: + module_pf = AdjacencyPolicyModule(env.n_nodes, DIRECTED) + module_pb = AdjacencyPolicyModule(env.n_nodes, DIRECTED, is_backward=True) + + pf = DiscretePolicyEstimator( + module=module_pf, n_actions=env.n_actions, preprocessor=GraphPreprocessor() + ) + pb = DiscretePolicyEstimator( + module=module_pb, + n_actions=env.n_actions, + preprocessor=GraphPreprocessor(), + is_backward=True, + ) + gflownet = TBGFlowNet(pf, pb) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=LR) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1) + + replay_buffer = ReplayBuffer( + env, + capacity=BATCH_SIZE, + prioritized=True, + ) + + losses = [] + + t1 = time.time() + for iteration in range(N_ITERATIONS): + trajectories = gflownet.sample_trajectories( + env, + n=BATCH_SIZE, + save_logprobs=True, + epsilon=0.2 * (1 - iteration / N_ITERATIONS), + ) + training_samples = gflownet.to_training_samples(trajectories) + + # Collect rewards for reporting. 
+ if isinstance(training_samples, tuple): + last_states = training_samples[1] + else: + last_states = training_samples.last_states + assert isinstance(last_states, GraphStates) + rewards = state_evaluator(last_states) + + if USE_BUFFER: + with torch.no_grad(): + replay_buffer.add(training_samples) + if iteration > 20: + training_samples = training_samples[: BATCH_SIZE // 2] + buffer_samples = replay_buffer.sample(n_trajectories=BATCH_SIZE // 2) + training_samples.extend(buffer_samples) # type: ignore + + optimizer.zero_grad() + loss = gflownet.loss(env, training_samples) + pct_rings = torch.mean(rewards > 0.1, dtype=torch.float) * 100 + print( + "Iteration {} - Loss: {:.02f}, rings: {:.0f}%".format( + iteration, loss.item(), pct_rings + ) + ) + loss.backward() + optimizer.step() + scheduler.step() + losses.append(loss.item()) + + t2 = time.time() + print("Time:", t2 - t1) + + # This comes from the gflownet, not the buffer. + last_states = trajectories.last_states[:8] + assert isinstance(last_states, GraphStates) + render_states(last_states, state_evaluator, DIRECTED) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 8a8e2552..8bcfb720 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -18,7 +18,7 @@ import wandb from tqdm import tqdm, trange -from gfn.containers import PrioritizedReplayBuffer, ReplayBuffer +from gfn.containers import NormBasedDiversePrioritizedReplayBuffer, ReplayBuffer from gfn.gflownet import ( DBGFlowNet, FMGFlowNet, @@ -181,7 +181,7 @@ def main(args): # noqa: C901 replay_buffer = None if args.replay_buffer_size > 0: if args.replay_buffer_prioritized: - replay_buffer = PrioritizedReplayBuffer( + replay_buffer = NormBasedDiversePrioritizedReplayBuffer( env, capacity=args.replay_buffer_size, cutoff_distance=args.cutoff_distance,