diff --git a/pyproject.toml b/pyproject.toml
index 2fc0f22d..ec4dd579 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,9 @@ classifiers = [
 einops = ">=0.6.1"
 numpy = ">=1.21.2"
 python = "^3.10"
-torch = ">=1.9.0"
+torch = ">=2.6.0"
+tensordict = ">=0.6.1"
+torch_geometric = ">=2.6.1"
 
 # dev dependencies.
 black = { version = "24.3", optional = true }
diff --git a/src/gfn/actions.py b/src/gfn/actions.py
index f5886ecb..b85691a7 100644
--- a/src/gfn/actions.py
+++ b/src/gfn/actions.py
@@ -1,10 +1,12 @@
 from __future__ import annotations  # This allows to use the class name in type hints
 
+import enum
 from abc import ABC
 from math import prod
 from typing import ClassVar, Sequence
 
 import torch
+from tensordict import TensorDict
 
 
 class Actions(ABC):
@@ -170,3 +172,156 @@ def is_exit(self) -> torch.Tensor:
             *self.batch_shape, *((1,) * len(self.__class__.action_shape))
         )
         return self.compare(exit_actions_tensor)
+
+
+class GraphActionType(enum.IntEnum):
+    ADD_NODE = 0
+    ADD_EDGE = 1
+    EXIT = 2
+    DUMMY = 3
+
+
+class GraphActions(Actions):
+    """Actions for graph-based environments.
+
+    Each action is one of:
+    - ADD_NODE: Add a node with given features
+    - ADD_EDGE: Add an edge between two nodes with given features
+    - EXIT: Terminate the trajectory
+    - DUMMY: Pad a batch of actions (e.g., for trajectories of different lengths)
+
+    Attributes:
+        features_dim: Dimension of node/edge features
+        tensor: TensorDict containing:
+            - action_type: Type of action (GraphActionType)
+            - features: Features for nodes/edges
+            - edge_index: Source/target nodes for edges
+    """
+
+    features_dim: ClassVar[int]
+
+    def __init__(self, tensor: TensorDict):
+        """Initializes a GraphActions object.
+
+        Args:
+            tensor: a TensorDict with the following entries:
+                - action_type: a tensor of shape (batch_shape,) of GraphActionType values.
+                - features: a tensor of shape (batch_shape, features_dim) representing the
+                    features of the nodes or of the edges, depending on the action type.
+                    May be omitted for EXIT and DUMMY actions.
+                - edge_index: a tensor of shape (batch_shape, 2) representing the edge to
+                    add. This must be defined if and only if the action type is
+                    GraphActionType.ADD_EDGE.
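+
+        Example (illustrative sketch; `MyGraphActions` is a hypothetical subclass
+        with `features_dim = 4`):
+            >>> actions = MyGraphActions(
+            ...     TensorDict(
+            ...         {
+            ...             "action_type": torch.full((3,), GraphActionType.ADD_NODE),
+            ...             "features": torch.rand(3, 4),
+            ...         },
+            ...         batch_size=(3,),
+            ...     )
+            ... )
+            >>> bool(actions.is_exit.any())
+            False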
+ """ + self.batch_shape = tensor["action_type"].shape + features = tensor.get("features", None) + if features is None: + assert torch.all( + torch.logical_or( + tensor["action_type"] == GraphActionType.EXIT, + tensor["action_type"] == GraphActionType.DUMMY, + ) + ) + features = torch.zeros((*self.batch_shape, self.features_dim)) + edge_index = tensor.get("edge_index", None) + if edge_index is None: + assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE) + edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) + + self.tensor = TensorDict( + { + "action_type": tensor["action_type"], + "features": features, + "edge_index": edge_index, + }, + batch_size=self.batch_shape, + ) + + def __repr__(self): + return f"""GraphAction object with {self.batch_shape} actions.""" + + @property + def device(self) -> torch.device | None: + """Returns the device of the features tensor.""" + return self.tensor.device + + def __len__(self) -> int: + """Returns the number of actions in the batch.""" + return int(prod(self.batch_shape)) + + def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: + """Get particular actions of the batch.""" + return GraphActions(self.tensor[index]) + + def __setitem__( + self, index: int | Sequence[int] | Sequence[bool], action: GraphActions + ) -> None: + """Set particular actions of the batch.""" + self.tensor[index] = action.tensor + + def compare(self, other: GraphActions) -> torch.Tensor: + """Compares the actions to another GraphAction object. + + Args: + other: GraphAction object to compare. + + Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. + """ + compare = torch.all(self.tensor == other.tensor, dim=-1) + return ( + compare["action_type"] + & (compare["action_type"] == GraphActionType.EXIT | compare["features"]) + & ( + compare["action_type"] + != GraphActionType.ADD_EDGE | compare["edge_index"] + ) + ) + + @property + def is_exit(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" + return self.action_type == GraphActionType.EXIT + + @property + def is_dummy(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions.""" + return self.action_type == GraphActionType.DUMMY + + @property + def action_type(self) -> torch.Tensor: + """Returns the action type tensor.""" + return self.tensor["action_type"] + + @property + def features(self) -> torch.Tensor: + """Returns the features tensor.""" + return self.tensor["features"] + + @property + def edge_index(self) -> torch.Tensor: + """Returns the edge index tensor.""" + return self.tensor["edge_index"] + + @classmethod + def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: + """Creates a GraphActions object of dummy actions with the given batch shape.""" + return cls( + TensorDict( + { + "action_type": torch.full( + batch_shape, fill_value=GraphActionType.DUMMY + ), + }, + batch_size=batch_shape, + ) + ) + + @classmethod + def make_exit_actions(cls, batch_shape: tuple[int]) -> Actions: + """Creates an GraphActions object of exit actions with the given batch shape.""" + return cls( + TensorDict( + { + "action_type": torch.full( + batch_shape, fill_value=GraphActionType.EXIT + ), + }, + batch_size=batch_shape, + ) + ) diff --git a/src/gfn/containers/__init__.py b/src/gfn/containers/__init__.py index a133fe4a..1b64a853 100644 --- a/src/gfn/containers/__init__.py +++ 
b/src/gfn/containers/__init__.py @@ -1,3 +1,3 @@ -from .replay_buffer import PrioritizedReplayBuffer, ReplayBuffer +from .replay_buffer import NormBasedDiversePrioritizedReplayBuffer, ReplayBuffer from .trajectories import Trajectories from .transitions import Transitions diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index bb12c0db..9df46014 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -22,6 +22,7 @@ class ReplayBuffer: training_objects: the buffer of objects used for training. terminating_states: a States class representation of $s_f$. objects_type: the type of buffer (transitions, trajectories, or states). + prioritized: whether the buffer is prioritized by log_reward or not. """ def __init__( @@ -29,6 +30,7 @@ def __init__( env: Env, objects_type: Literal["transitions", "trajectories", "states"], capacity: int = 1000, + prioritized: bool = False, ): """Instantiates a replay buffer. Args: @@ -53,6 +55,7 @@ def __init__( raise ValueError(f"Unknown objects_type: {objects_type}") self._is_full = False + self.prioritized = prioritized def __repr__(self): return f"ReplayBuffer(capacity={self.capacity}, containing {len(self)} {self.objects_type})" @@ -60,27 +63,53 @@ def __repr__(self): def __len__(self): return len(self.training_objects) - def add(self, training_objects: Transitions | Trajectories | tuple[States]): + def _add_objs( + self, + training_objects: Transitions | Trajectories | tuple[States], + ): """Adds a training object to the buffer.""" terminating_states = None if isinstance(training_objects, tuple): assert self.objects_type == "states" and self.terminating_states is not None training_objects, terminating_states = training_objects # pyright: ignore - + to_add = len(training_objects) - self._is_full |= len(self) + to_add >= self.capacity + # Adds the objects to the buffer. self.training_objects.extend(training_objects) # pyright: ignore + + # Sort elements by logreward, capping the size at the defined capacity. + if self.prioritized: + + if ( + self.training_objects.log_rewards is None + or training_objects.log_rewards is None # pyright: ignore + ): + raise ValueError("log_rewards must be defined for prioritized replay.") + + # Ascending sort. + ix = torch.argsort(self.training_objects.log_rewards) # pyright: ignore + self.training_objects = self.training_objects[ix] # pyright: ignore self.training_objects = self.training_objects[ - -self.capacity : + -self.capacity : # Ascending sort, so we retain the final elements. ] # pyright: ignore + # Add the terminating states to the buffer. if self.terminating_states is not None: assert terminating_states is not None - self.terminating_states.extend(terminating_states) # pyright: ignore + self.terminating_states.extend(terminating_states) + + # Sort terminating states by logreward as well. + if self.prioritized: + self.terminating_states = self.terminating_states[ix] + self.terminating_states = self.terminating_states[-self.capacity :] + def add(self, training_objects: Transitions | Trajectories | tuple[States]): + """Adds a training object to the buffer.""" + self._add_objs(training_objects) + def sample(self, n_trajectories: int) -> Transitions | Trajectories | tuple[States]: """Samples `n_trajectories` training objects from the buffer.""" if self.terminating_states is not None: @@ -113,8 +142,8 @@ def load(self, directory: str): ) -class PrioritizedReplayBuffer(ReplayBuffer): - """A replay buffer of trajectories or transitions. 
+class NormBasedDiversePrioritizedReplayBuffer(ReplayBuffer):
+    """A replay buffer of trajectories or transitions with diverse trajectories.
 
     Attributes:
         env: the Environment instance.
@@ -152,53 +181,27 @@ def __init__(
-        super().__init__(env, objects_type, capacity)
+        # The base class owns the `prioritized` flag; enabling it here makes
+        # `_add_objs` keep the buffer sorted by log-reward.
+        super().__init__(env, objects_type, capacity, prioritized=True)
         self.cutoff_distance = cutoff_distance
         self.p_norm_distance = p_norm_distance
 
-    def _add_objs(
-        self,
-        training_objects: Transitions | Trajectories | tuple[States],
-        terminating_states: States | None = None,
-    ):
-        """Adds a training object to the buffer."""
-        # Adds the objects to the buffer.
-        self.training_objects.extend(training_objects)  # pyright: ignore
-
-        # Sort elements by logreward, capping the size at the defined capacity.
-        ix = torch.argsort(self.training_objects.log_rewards)  # pyright: ignore
-        self.training_objects = self.training_objects[ix]  # pyright: ignore
-        self.training_objects = self.training_objects[
-            -self.capacity :
-        ]  # pyright: ignore
-
-        # Add the terminating states to the buffer.
-        if self.terminating_states is not None:
-            assert terminating_states is not None
-            self.terminating_states.extend(terminating_states)
-
-            # Sort terminating states by logreward as well.
-            self.terminating_states = self.terminating_states[ix]
-            self.terminating_states = self.terminating_states[-self.capacity :]
 
     def add(self, training_objects: Transitions | Trajectories | tuple[States]):
         """Adds a training object to the buffer."""
-        terminating_states = None
-        if isinstance(training_objects, tuple):
-            assert self.objects_type == "states" and self.terminating_states is not None
-            training_objects, terminating_states = training_objects  # pyright: ignore
-
         to_add = len(training_objects)
         self._is_full |= len(self) + to_add >= self.capacity
 
         # The buffer isn't full yet.
         if len(self.training_objects) < self.capacity:
-            self._add_objs(training_objects, terminating_states)
+            self._add_objs(training_objects)
 
         # Our buffer is full and we will prioritize diverse, high reward additions.
         else:
-            if (
-                self.training_objects.log_rewards is None
-                or training_objects.log_rewards is None  # pyright: ignore
-            ):
-                raise ValueError("log_rewards must be defined for prioritized replay.")
+            terminating_states = None
+            if isinstance(training_objects, tuple):
+                assert self.objects_type == "states" and self.terminating_states is not None
+                training_objects, terminating_states = training_objects  # pyright: ignore
 
             # Sort the incoming elements by their logrewards.
             ix = torch.argsort(
diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py
index e9927ef4..fe8a662b 100644
--- a/src/gfn/containers/trajectories.py
+++ b/src/gfn/containers/trajectories.py
@@ -104,7 +104,7 @@ def __init__(
             assert (
                 log_probs.shape == (self.max_length, self.n_trajectories)
                 and log_probs.dtype == torch.float
-            )
+            ), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}"
         else:
             log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
         self.log_probs: torch.Tensor = log_probs
@@ -256,7 +256,7 @@ def extend(self, other: Trajectories) -> None:
 
         # TODO: The replay buffer is storing `dones` - this wastes a lot of space.
         self.actions.extend(other.actions)
-        self.states.extend(other.states)
+        self.states.extend(other.states)  # n_trajectories comes from this.
         self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)
 
         # For log_probs, we first need to make the first dimensions of self.log_probs
diff --git a/src/gfn/env.py b/src/gfn/env.py
index 01b238ce..8c1992b3 100644
--- a/src/gfn/env.py
+++ b/src/gfn/env.py
@@ -2,17 +2,18 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from torch_geometric.data import Batch, Data
 
-from gfn.actions import Actions
+from gfn.actions import Actions, GraphActions
 from gfn.preprocessors import IdentityPreprocessor, Preprocessor
-from gfn.states import DiscreteStates, States
+from gfn.states import DiscreteStates, GraphStates, States
 from gfn.utils.common import set_seed
 
 # Errors
 NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
 
 
-def get_device(device_str, default_device):
+def get_device(device_str, default_device) -> torch.device:
     return torch.device(device_str) if device_str is not None else default_device
 
 
@@ -22,12 +23,12 @@ class Env(ABC):
 
     def __init__(
         self,
-        s0: torch.Tensor,
+        s0: torch.Tensor | Data,
         state_shape: Tuple,
         action_shape: Tuple,
         dummy_action: torch.Tensor,
         exit_action: torch.Tensor,
-        sf: Optional[torch.Tensor] = None,
+        sf: Optional[torch.Tensor | Data] = None,
         device_str: Optional[str] = None,
         preprocessor: Optional[Preprocessor] = None,
     ):
@@ -54,7 +55,7 @@ def __init__(
         assert s0.shape == state_shape
         if sf is None:
             sf = torch.full(s0.shape, -float("inf")).to(self.device)
-        self.sf: torch.Tensor = sf
+        self.sf = sf
         assert self.sf.shape == state_shape
         self.state_shape = state_shape
         self.action_shape = action_shape
@@ -260,8 +261,12 @@ def _step(
                 "Some actions are not valid in the given states. See `is_action_valid`."
             )
 
+        # Set to the sink state when the action is exit.
         new_sink_states_idx = actions.is_exit
-        new_states.tensor[new_sink_states_idx] = self.sf
+        sf_tensor = self.States.make_sink_states_tensor(
+            (int(new_sink_states_idx.sum().item()),)
+        )
+        new_states[new_sink_states_idx] = self.States(sf_tensor)
         new_sink_states_idx = ~valid_states_idx | new_sink_states_idx
         assert new_sink_states_idx.shape == states.batch_shape
 
@@ -269,13 +274,13 @@ def _step(
         not_done_actions = actions[~new_sink_states_idx]
 
         new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
-        if not isinstance(new_not_done_states_tensor, torch.Tensor):
+
+        if not isinstance(new_not_done_states_tensor, (torch.Tensor, Batch)):
             raise Exception(
-                "User implemented env.step function *must* return a torch.Tensor!"
+                "User implemented env.step function *must* return a torch.Tensor "
+                "or a torch_geometric Batch!"
             )
 
-        new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor
-
+        new_states[~new_sink_states_idx] = self.States(new_not_done_states_tensor)
         return new_states
 
     def _backward_step(
@@ -303,7 +308,7 @@ def _backward_step(
 
         # Calculate the backward step, and update only the states which are not Done.
         new_not_done_states_tensor = self.backward_step(valid_states, valid_actions)
-        new_states.tensor[valid_states_idx] = new_not_done_states_tensor
+        new_states[valid_states_idx] = self.States(new_not_done_states_tensor)
 
         if isinstance(new_states, DiscreteStates):
             self.update_masks(new_states)  # pyright: ignore
@@ -358,6 +363,9 @@ class DiscreteEnv(Env, ABC):
     via mask tensors, that are directly attached to `States` objects.
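+
+    Example (sketch; `MyDiscreteEnv` is a hypothetical subclass that wires up
+    `s0`, the state/action shapes, and the abstract methods):
+        >>> env = MyDiscreteEnv(n_actions=2)
+        >>> states = env.reset(batch_shape=(5,))
+        >>> states.forward_masks.shape  # one boolean per action, per state
+        torch.Size([5, 2])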
""" + s0: torch.Tensor # this tells the type checker that s0 is a torch.Tensor + sf: torch.Tensor # this tells the type checker that sf is a torch.Tensor + def __init__( self, n_actions: int, @@ -507,28 +515,14 @@ def _step(self, states: DiscreteStates, actions: Actions) -> States: return new_states def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """Returns the indices of the states in the environment. - - Args: - states: The batch of states. - - Returns: - torch.Tensor: Tensor of shape "batch_shape" containing the indices of the states. - """ - return NotImplementedError( + """Returns the indices of the states in the environment.""" + raise NotImplementedError( "The environment does not support enumeration of states" ) def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """Returns the indices of the terminating states in the environment. - - Args: - states: The batch of states. - - Returns: - torch.Tensor: Tensor of shape "batch_shape" containing the indices of the terminating states. - """ - return NotImplementedError( + """Returns the indices of the terminating states in the environment.""" + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -572,3 +566,71 @@ def terminating_states(self) -> DiscreteStates: raise NotImplementedError( "The environment does not support enumeration of states" ) + + +class GraphEnv(Env): + """Base class for graph-based environments.""" + + sf: Data # this tells the type checker that sf is a Data + + def __init__( + self, + s0: Data, + sf: Data, + device_str: Optional[str] = None, + preprocessor: Optional[Preprocessor] = None, + ): + """Initializes a graph-based environment. + + Args: + s0: The initial graph state. + sf: The sink graph state. + device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is + inferred from s0. + preprocessor: a Preprocessor object that converts raw graph states to a tensor + that can be fed into a neural network. Defaults to None, in which case + the IdentityPreprocessor is used. + """ + self.s0 = s0.to(device_str) + self.features_dim = s0.x.shape[-1] + self.sf = sf + + self.States = self.make_states_class() + self.Actions = self.make_actions_class() + + self.preprocessor = preprocessor + self.is_discrete = False + + def make_states_class(self) -> type[GraphStates]: + env = self + + class GraphEnvStates(GraphStates): + s0 = env.s0 + sf = env.sf + make_random_states_graph = env.make_random_states_tensor + + return GraphEnvStates + + def make_actions_class(self) -> type[GraphActions]: + """The default Actions class factory for all Environments. + + Returns a class that inherits from Actions and implements assumed methods. + The make_actions_class method should be overwritten to achieve more + environment-specific Actions functionality. 
+ """ + env = self + + class DefaultGraphAction(GraphActions): + features_dim = env.features_dim + + return DefaultGraphAction + + @abstractmethod + def step(self, states: GraphStates, actions: Actions) -> torch.Tensor: + """Function that takes a batch of graph states and actions and returns a batch of next + graph states.""" + + @abstractmethod + def backward_step(self, states: GraphStates, actions: Actions) -> torch.Tensor: + """Function that takes a batch of graph states and actions and returns a batch of previous + graph states.""" diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 6b7a679a..d879916d 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch @@ -33,7 +33,7 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]): def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): super().__init__() - assert isinstance( # TODO: need a more flexible type check. + assert isinstance( logF, DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" @@ -157,7 +157,7 @@ def reward_matching_loss( self, env: DiscreteEnv, terminating_states: DiscreteStates, - conditioning: torch.Tensor, + conditioning: Optional[torch.Tensor], ) -> torch.Tensor: """Calculates the reward matching loss from the terminating states.""" del env # Unused @@ -173,6 +173,7 @@ def reward_matching_loss( # Handle the boundary condition (for all x, F(X->S_f) = R(x)). terminating_log_edge_flows = log_edge_flows[:, -1] log_rewards = terminating_states.log_rewards + return (terminating_log_edge_flows - log_rewards).pow(2).mean() def loss( diff --git a/src/gfn/gym/__init__.py b/src/gfn/gym/__init__.py index fbec4831..ebec6f20 100644 --- a/src/gfn/gym/__init__.py +++ b/src/gfn/gym/__init__.py @@ -1,3 +1,4 @@ from gfn.gym.box import Box from gfn.gym.discrete_ebm import DiscreteEBM +from gfn.gym.graph_building import GraphBuilding from gfn.gym.hypergrid import HyperGrid diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py new file mode 100644 index 00000000..2ba32b31 --- /dev/null +++ b/src/gfn/gym/graph_building.py @@ -0,0 +1,301 @@ +from typing import Callable, Literal, Tuple + +import torch +from torch_geometric.data import Data, Batch + +from gfn.actions import GraphActions, GraphActionType +from gfn.env import GraphEnv, NonValidActionsError +from gfn.states import GraphStates + + +class GraphBuilding(GraphEnv): + """Environment for incrementally building graphs. + + This environment allows constructing graphs by: + - Adding nodes with features + - Adding edges between existing nodes with features + - Terminating construction (EXIT) + + Args: + feature_dim: Dimension of node and edge features + state_evaluator: Callable that computes rewards for final states. 
+        device_str: Device to run computations on ('cpu' or 'cuda')
+    """
+
+    def __init__(
+        self,
+        feature_dim: int,
+        state_evaluator: Callable[[GraphStates], torch.Tensor],
+        device_str: Literal["cpu", "cuda"] = "cpu",
+    ):
+        s0 = Data(
+            x=torch.zeros((0, feature_dim), dtype=torch.float32),
+            edge_attr=torch.zeros((0, feature_dim), dtype=torch.float32),
+            edge_index=torch.zeros((2, 0), dtype=torch.long),
+            device=device_str,
+        )
+        sf = Data(
+            x=torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"),
+            edge_attr=torch.ones((0, feature_dim), dtype=torch.float32) * float("inf"),
+            edge_index=torch.zeros((2, 0), dtype=torch.long),
+            device=device_str,
+        )
+
+        self.state_evaluator = state_evaluator
+        self.feature_dim = feature_dim
+
+        super().__init__(
+            s0=s0,
+            sf=sf,
+            device_str=device_str,
+        )
+
+    def reset(self, batch_shape: Tuple | int) -> GraphStates:
+        """Reset the environment to a new batch of graphs."""
+        states = super().reset(batch_shape)
+        assert isinstance(states, GraphStates)
+        return states
+
+    def step(self, states: GraphStates, actions: GraphActions) -> Batch:
+        """Step function for the GraphBuilding environment.
+
+        Args:
+            states: GraphStates object representing the current graph.
+            actions: Actions indicating which edge to add.
+
+        Returns the next states' graphs as a `Batch`.
+        """
+        if not self.is_action_valid(states, actions):
+            raise NonValidActionsError("Invalid action.")
+        if len(actions) == 0:
+            return states.tensor
+
+        action_type = actions.action_type[0]
+        assert torch.all(actions.action_type == action_type)  # TODO: allow different action types
+        if action_type == GraphActionType.EXIT:
+            return self.States.make_sink_states_tensor(states.batch_shape)
+
+        if action_type == GraphActionType.ADD_NODE:
+            batch_indices = torch.arange(len(states))[
+                actions.action_type == GraphActionType.ADD_NODE
+            ]
+            states.tensor = self._add_node(
+                states.tensor, batch_indices, actions.features
+            )
+
+        if action_type == GraphActionType.ADD_EDGE:
+            # Get the data list from the batch
+            data_list = states.tensor.to_data_list()
+
+            # Add edges to each graph
+            for i, (src, dst) in enumerate(actions.edge_index):
+                # Get the graph to modify
+                graph = data_list[i]
+
+                # Add the new edge
+                graph.edge_index = torch.cat([
+                    graph.edge_index,
+                    torch.tensor([[src], [dst]], device=graph.edge_index.device)
+                ], dim=1)
+
+                # Add the edge feature
+                graph.edge_attr = torch.cat([
+                    graph.edge_attr,
+                    actions.features[i].unsqueeze(0)
+                ], dim=0)
+
+            # Create a new batch from the updated data list
+            new_tensor = Batch.from_data_list(data_list)
+            new_tensor.batch_shape = states.tensor.batch_shape
+            states.tensor = new_tensor
+
+        return states.tensor
+
+    def backward_step(self, states: GraphStates, actions: GraphActions) -> Batch:
+        """Backward step function for the GraphBuilding environment.
+
+        Args:
+            states: GraphStates object representing the current graph.
+            actions: Actions indicating which edge to remove.
+
+        Returns the previous states' graphs as a `Batch`.
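+
+        Example (sketch; `edge_actions` is a hypothetical batch of ADD_EDGE
+        actions whose edges are all present in `states`):
+            >>> prev_batch = env.backward_step(states, edge_actions)
+            >>> prev_batch.num_edges < states.tensor.num_edges
+            True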
+ """ + if not self.is_action_valid(states, actions, backward=True): + raise NonValidActionsError("Invalid action.") + if len(actions) == 0: + return states.tensor + + action_type = actions.action_type[0] + assert torch.all(actions.action_type == action_type) + + # Get the data list from the batch + data_list = states.tensor.to_data_list() + + if action_type == GraphActionType.ADD_NODE: + # Remove nodes with matching features + for i, features in enumerate(actions.features): + graph = data_list[i] + + # Find nodes with matching features + is_equal = torch.all(graph.x == features.unsqueeze(0), dim=1) + + if torch.any(is_equal): + # Remove the first matching node + node_idx = torch.where(is_equal)[0][0].item() + + # Remove the node + mask = torch.ones(graph.num_nodes, dtype=torch.bool, device=graph.x.device) + mask[node_idx] = False + + # Update node features + graph.x = graph.x[mask] + + elif action_type == GraphActionType.ADD_EDGE: + # Remove edges with matching indices + for i, (src, dst) in enumerate(actions.edge_index): + graph = data_list[i] + + # Find the edge to remove + edge_mask = ~((graph.edge_index[0] == src) & (graph.edge_index[1] == dst)) + + # Remove the edge + graph.edge_index = graph.edge_index[:, edge_mask] + graph.edge_attr = graph.edge_attr[edge_mask] + + # Create a new batch from the updated data list + new_batch = Batch.from_data_list(data_list) + + # Preserve the batch shape + new_batch.batch_shape = torch.tensor(states.batch_shape, device=states.tensor.x.device) + + return new_batch + + def is_action_valid( + self, states: GraphStates, actions: GraphActions, backward: bool = False + ) -> bool: + """Check if actions are valid for the given states. + + Args: + states: Current graph states. + actions: Actions to validate. + backward: Whether this is a backward step. + + Returns: + True if all actions are valid, False otherwise. + """ + # Get the data list from the batch + data_list = states.tensor.to_data_list() + + for i in range(len(actions)): + graph = data_list[i] + if actions.action_type[i] == GraphActionType.ADD_NODE: + # Check if a node with these features already exists + equal_nodes = torch.all(graph.x == actions.features[i].unsqueeze(0), dim=1) + + if backward: + # For backward actions, we need at least one matching node + if not torch.any(equal_nodes): + return False + else: + # For forward actions, we should not have any matching nodes + if torch.any(equal_nodes): + return False + + elif actions.action_type[i] == GraphActionType.ADD_EDGE: + src, dst = actions.edge_index[i] + + # Check if src and dst are valid node indices + if src >= graph.num_nodes or dst >= graph.num_nodes or src == dst: + return False + + # Check if the edge already exists + edge_exists = torch.any((graph.edge_index[0] == src) & (graph.edge_index[1] == dst)) + + if backward: + # For backward actions, the edge must exist + if not edge_exists: + return False + else: + # For forward actions, the edge must not exist + if edge_exists: + return False + + return True + + def _add_node( + self, + tensor: Batch, + batch_indices: torch.Tensor, + nodes_to_add: torch.Tensor, + ) -> Batch: + """Add nodes to graphs in a batch. + + Args: + tensor_dict: The current batch of graphs. + batch_indices: Indices of graphs to add nodes to. + nodes_to_add: Features of nodes to add. + + Returns: + Updated batch of graphs. 
+ """ + if isinstance(batch_indices, list): + batch_indices = torch.tensor(batch_indices) + if len(batch_indices) != len(nodes_to_add): + raise ValueError( + "Number of batch indices must match number of node feature lists" + ) + # Get the data list from the batch + data_list = tensor.to_data_list() + + # Add nodes to the specified graphs + for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): + # Get the graph to modify + graph = data_list[graph_idx] + + # Ensure new_nodes is 2D + new_nodes = torch.atleast_2d(new_nodes) + + # Check feature dimension + if new_nodes.shape[1] != graph.x.shape[1]: + raise ValueError(f"Node features must have dimension {graph.x.shape[1]}") + + # Generate unique indices for new nodes + num_new_nodes = new_nodes.shape[0] + new_indices = GraphStates.unique_node_indices(num_new_nodes) + + # Add new nodes to the graph + graph.x = torch.cat([graph.x, new_nodes], dim=0) + + # Add node indices if they exist + if hasattr(graph, "node_index"): + graph.node_index = torch.cat([graph.node_index, new_indices], dim=0) + else: + # If node_index doesn't exist, create it + graph.node_index = torch.cat([ + torch.arange(graph.num_nodes - num_new_nodes, device=graph.x.device), + new_indices + ], dim=0) + + # Create a new batch from the updated data list + new_batch = Batch.from_data_list(data_list) + + # Preserve the batch shape + new_batch.batch_shape = tensor.batch_shape + return new_batch + + def reward(self, final_states: GraphStates) -> torch.Tensor: + """The environment's reward given a state. + This or log_reward must be implemented. + + Args: + final_states: A batch of final states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the rewards. + """ + return self.state_evaluator(final_states) + + def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: + """Generates random states tensor of shape (*batch_shape, feature_dim).""" + return self.States.from_batch_shape(batch_shape) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 3ab6a991..342b4632 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -1,13 +1,19 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Any import torch import torch.nn as nn -from torch.distributions import Categorical, Distribution - -from gfn.preprocessors import IdentityPreprocessor, Preprocessor -from gfn.states import DiscreteStates, States -from gfn.utils.distributions import UnsqueezedCategorical +from tensordict import TensorDict +from torch.distributions import Categorical, Distribution, Normal + +from gfn.preprocessors import GraphPreprocessor, IdentityPreprocessor, Preprocessor +from gfn.states import DiscreteStates, GraphStates, States +from gfn.utils.distributions import ( + CategoricalActionType, + CategoricalIndexes, + CompositeDistribution, + UnsqueezedCategorical, +) REDUCTION_FXNS = { "mean": torch.mean, @@ -77,7 +83,6 @@ def __init__( ) preprocessor = IdentityPreprocessor(module.input_dim) self.preprocessor = preprocessor - self._output_dim_is_checked = False self.is_backward = is_backward def forward(self, input: States | torch.Tensor) -> torch.Tensor: @@ -90,32 +95,11 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: """ if isinstance(input, States): input = self.preprocessor(input) - - out = self.module(input) - - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - - return out + return self.module(input) def __repr__(self): return f"{self.__class__.__name__} module" - 
@property - @abstractmethod - def expected_output_dim(self) -> int: - """Expected output dimension of the module.""" - - def check_output_dim(self, module_output: torch.Tensor) -> None: - """Check that the output of the module has the correct shape. Raises an error if not.""" - assert module_output.dtype == torch.float - if module_output.shape[-1] != self.expected_output_dim(): # pyright: ignore - raise ValueError( - f"{self.__class__.__name__} output dimension should be {self.expected_output_dim()}" # pyright: ignore - + f" but is {module_output.shape[-1]}." - ) - def to_probability_distribution( self, states: States, @@ -192,9 +176,6 @@ def __init__( ) self.reduction_fxn = REDUCTION_FXNS[reduction] - def expected_output_dim(self) -> int: - return 1 - def forward(self, input: States | torch.Tensor) -> torch.Tensor: """Forward pass of the module. @@ -212,10 +193,7 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - # self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == 1 return out @@ -250,12 +228,21 @@ def __init__( """ super().__init__(module, preprocessor, is_backward=is_backward) self.n_actions = n_actions + self.expected_output_dim = self.n_actions - int(self.is_backward) - def expected_output_dim(self) -> int: - if self.is_backward: - return self.n_actions - 1 - else: - return self.n_actions + def forward(self, states: DiscreteStates) -> torch.Tensor: + """Forward pass of the module. + + Args: + states: The input states. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). + """ + out = super().forward(states) + assert ( + out.shape[-1] == self.expected_output_dim + ), f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}" + return out def to_probability_distribution( self, @@ -279,7 +266,7 @@ def to_probability_distribution( on policy. epsilon: with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy.""" - # self.check_output_dim(module_output) + assert module_output.shape[-1] == self.expected_output_dim masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output @@ -362,11 +349,7 @@ def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). """ out = self._forward_trunk(states, conditioning) - - if not self._output_dim_is_checked: - # self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == self.expected_output_dim return out @@ -450,15 +433,9 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - # self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == 1 return out - def expected_output_dim(self) -> int: - return 1 - def to_probability_distribution( self, states: States, @@ -466,3 +443,102 @@ def to_probability_distribution( **policy_kwargs: Any, ) -> Distribution: raise NotImplementedError + + +class GraphActionPolicyEstimator(GFNModule): + r"""Container for forward and backward policy estimators for graph environments. + + $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$. + + or + + $s \mapsto (P_B(s' \mid s))_{s' \in Parents(s)}$. 
+
+    Attributes:
+        temperature: scalar to divide the logits by before softmax (a kwarg of
+            `to_probability_distribution`).
+        epsilon: with probability epsilon, a random action is chosen (a kwarg of
+            `to_probability_distribution`).
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+    ):
+        """Initializes an estimator for P_F for graph environments.
+
+        Args:
+            is_backward: if False, then this is a forward policy, else backward policy.
+        """
+        if preprocessor is None:
+            preprocessor = GraphPreprocessor()
+        super().__init__(module, preprocessor, is_backward)
+
+    def forward(self, states: GraphStates) -> TensorDict:
+        """Forward pass of the module.
+
+        Args:
+            states: The input graph states.
+
+        Returns:
+            TensorDict containing:
+            - action_type: logits for action type selection (batch_shape, n_actions)
+            - features: parameters for node/edge features (batch_shape, feature_dim)
+            - edge_index: logits for edge connections (batch_shape, n_nodes, n_nodes)
+        """
+        return self.module(states)
+
+    def to_probability_distribution(
+        self,
+        states: GraphStates,
+        module_output: TensorDict,
+        temperature: float = 1.0,
+        epsilon: float = 0.0,
+    ) -> CompositeDistribution:
+        """Returns a probability distribution given a batch of states and module output.
+
+        We handle off-policyness using these kwargs.
+
+        Args:
+            states: The states to use.
+            module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
+            temperature: scalar to divide the logits by before softmax. Does nothing
+                if set to 1.0 (default), in which case it's on policy.
+            epsilon: with probability epsilon, a random action is chosen. Does nothing
+                if set to 0.0 (default), in which case it's on policy."""
+
+        dists = {}
+
+        action_type_logits = module_output["action_type"]
+        masks = states.backward_masks if self.is_backward else states.forward_masks
+        action_type_logits[~masks["action_type"]] = -float("inf")
+        action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1)
+        uniform_dist_probs = masks["action_type"].float() / masks["action_type"].sum(
+            dim=-1, keepdim=True
+        )
+        action_type_probs = (
+            1 - epsilon
+        ) * action_type_probs + epsilon * uniform_dist_probs
+        dists["action_type"] = CategoricalActionType(probs=action_type_probs)
+
+        edge_index_logits = module_output["edge_index"]
+        edge_index_logits[~masks["edge_index"]] = -float("inf")
+        if torch.any(edge_index_logits != -float("inf")):
+            B, N, _ = edge_index_logits.shape  # the per-graph logits are square (N, N)
+            edge_index_logits = edge_index_logits.reshape(B, N * N)
+            edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1)
+            uniform_dist_probs = (
+                torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1]
+            )
+            edge_index_probs = (
+                1 - epsilon
+            ) * edge_index_probs + epsilon * uniform_dist_probs
+            edge_index_probs[torch.isnan(edge_index_probs)] = 1
+            dists["edge_index"] = CategoricalIndexes(
+                probs=edge_index_probs, n_nodes=states.tensor.num_nodes
+            )
+
+        dists["features"] = Normal(module_output["features"], temperature)
+        return CompositeDistribution(dists=dists)
diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py
index dfa3e2b1..bee77249 100644
--- a/src/gfn/preprocessors.py
+++ b/src/gfn/preprocessors.py
@@ -2,8 +2,9 @@
 from typing import Callable
 
 import torch
+from torch_geometric.data import Batch
 
-from gfn.states import States
+from gfn.states import GraphStates, States
 
 
 class Preprocessor(ABC):
@@ -72,3 +73,11 @@ def preprocess(self, states) -> torch.Tensor:
Returns the unique indices of the states as a tensor of shape `batch_shape`. """ return self.get_states_indices(states).long().unsqueeze(-1) + + +class GraphPreprocessor(Preprocessor): + def __init__(self) -> None: + super().__init__(-1) # TODO: review output_dim API + + def preprocess(self, states: GraphStates) -> Batch: + return states.tensor diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index ead9d004..c20423f4 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -147,7 +147,7 @@ def sample_trajectories( if conditioning is not None: assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] - device = states.tensor.device + device = states.device dones = ( states.is_initial_state @@ -207,11 +207,12 @@ def sample_trajectories( all_estimator_outputs.append(estimator_outputs_padded) actions[~dones] = valid_actions - trajectories_actions.append(actions) if save_logprobs: # When off_policy, actions_log_probs are None. log_probs[~dones] = actions_log_probs - trajectories_logprobs.append(log_probs) + + trajectories_actions.append(actions) + trajectories_logprobs.append(log_probs) if self.estimator.is_backward: new_states = env._backward_step(states, actions) @@ -219,7 +220,7 @@ def sample_trajectories( new_states = env._step(states, actions) sink_states_mask = new_states.is_sink_state - # Increment the step, determine which trajectories are finisihed, and eval + # Increment the step, determine which trajectories are finished, and eval # rewards. step += 1 @@ -242,37 +243,24 @@ def sample_trajectories( ) states = new_states dones = dones | new_dones - trajectories_states.append(deepcopy(states)) - # TODO: do not ignore the next three ignores - trajectories_states = states.stack_states( - trajectories_states - ) # pyright: ignore - trajectories_actions = env.Actions.stack(trajectories_actions)[ - 1: # Drop dummy action - ] # pyright: ignore - trajectories_logprobs = ( - torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob - if save_logprobs - else None - ) # pyright: ignore - # TODO: use torch.nested.nested_tensor(dtype, device, requires_grad). 
- if save_estimator_outputs: - all_estimator_outputs = torch.stack(all_estimator_outputs, dim=0) - # TODO: do not ignore the next ignores trajectories = Trajectories( env=env, - states=trajectories_states, # pyright: ignore + states=env.States.stack(trajectories_states), conditioning=conditioning, - actions=trajectories_actions, # pyright: ignore + actions=env.Actions.stack(trajectories_actions[1:]), when_is_done=trajectories_dones, is_backward=self.estimator.is_backward, log_rewards=trajectories_log_rewards, - log_probs=trajectories_logprobs, # pyright: ignore + log_probs=( + torch.stack(trajectories_logprobs, dim=0)[1:] if save_logprobs else None + ), estimator_outputs=( - all_estimator_outputs if save_estimator_outputs else None - ), # pyright: ignore + torch.stack(all_estimator_outputs, dim=0) + if save_estimator_outputs + else None + ), ) return trajectories diff --git a/src/gfn/states.py b/src/gfn/states.py index 4b31a3e7..3ac4988f 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,9 +3,13 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence +from typing import Callable, ClassVar, List, Optional, Sequence, Tuple +import numpy as np import torch +from torch_geometric.data import Data, Batch + +from gfn.actions import GraphActionType class States(ABC): @@ -40,13 +44,14 @@ class States(ABC): Attributes: tensor: Tensor representing a batch of states. - batch_shape: Sizes of the batch dimensions. + _batch_shape: Sizes of the batch dimensions. _log_rewards: Stores the log rewards of each state. """ - state_shape: ClassVar[tuple[int, ...]] # Shape of one state - s0: ClassVar[torch.Tensor] # Source state of the DAG - sf: ClassVar[torch.Tensor] # Dummy state, used to pad a batch of states + state_shape: ClassVar[tuple[int, ...]] + s0: ClassVar[torch.Tensor] + sf: ClassVar[torch.Tensor] + make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw( NotImplementedError( "The environment does not support initialization of random states." 
         )
     )
 
@@ -63,11 +68,19 @@ def __init__(self, tensor: torch.Tensor):
         assert tensor.shape[-len(self.state_shape) :] == self.state_shape
 
         self.tensor = tensor
-        self.batch_shape = tuple(self.tensor.shape)[: -len(self.state_shape)]
+        self._batch_shape = tuple(self.tensor.shape)[: -len(self.state_shape)]
         self._log_rewards = (
             None  # Useful attribute if we want to store the log-reward of the states
         )
 
+    @property
+    def batch_shape(self) -> tuple[int, ...]:
+        return self._batch_shape
+
+    @batch_shape.setter
+    def batch_shape(self, batch_shape: tuple[int, ...]) -> None:
+        self._batch_shape = batch_shape
+
     @classmethod
     def from_batch_shape(
         cls, batch_shape: tuple[int, ...], random: bool = False, sink: bool = False
@@ -104,14 +117,24 @@ def make_initial_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tensor:
         """Makes a tensor with a `batch_shape` of states consisting of $s_0$s."""
         state_ndim = len(cls.state_shape)
         assert cls.s0 is not None and state_ndim is not None
-        return cls.s0.repeat(*batch_shape, *((1,) * state_ndim))
+        if isinstance(cls.s0, torch.Tensor):
+            return cls.s0.repeat(*batch_shape, *((1,) * state_ndim))
+        else:
+            raise NotImplementedError(
+                f"make_initial_states_tensor is not implemented by default for {cls.__name__}"
+            )
 
     @classmethod
     def make_sink_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tensor:
         """Makes a tensor with a `batch_shape` of states consisting of $s_f$s."""
         state_ndim = len(cls.state_shape)
         assert cls.sf is not None and state_ndim is not None
-        return cls.sf.repeat(*batch_shape, *((1,) * state_ndim))
+        if isinstance(cls.sf, torch.Tensor):
+            return cls.sf.repeat(*batch_shape, *((1,) * state_ndim))
+        else:
+            raise NotImplementedError(
+                f"make_sink_states_tensor is not implemented by default for {cls.__name__}"
+            )
 
     def __len__(self):
         return prod(self.batch_shape)
@@ -135,7 +158,9 @@ def __getitem__(
         return out
 
     def __setitem__(
-        self, index: int | Sequence[int] | Sequence[bool], states: States
+        self,
+        index: int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor,
+        states: States,
     ) -> None:
         """Set particular states of the batch."""
         self.tensor[index] = states.tensor
@@ -205,7 +230,7 @@ def extend_with_sf(self, required_first_dim: int) -> None:
 
         Args:
             required_first_dim: The size of the first batch dimension post-expansion.
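+
+        Example (sketch): a (3, 2) batch is padded along its first dimension
+        to (5, 2) with $s_f$:
+            >>> states.batch_shape
+            (3, 2)
+            >>> states.extend_with_sf(5)
+            >>> states.batch_shape
+            (5, 2)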
""" - if len(self.batch_shape) == 2: + if len(self.batch_shape) == 2 and isinstance(self.__class__.sf, torch.Tensor): if self.batch_shape[0] >= required_first_dim: return self.tensor = torch.cat( @@ -220,7 +245,7 @@ def extend_with_sf(self, required_first_dim: int) -> None: self.batch_shape = (required_first_dim, self.batch_shape[1]) else: raise ValueError( - f"extend_with_sf is not implemented for batch shapes {self.batch_shape}" + f"extend_with_sf is not implemented for graph states nor for batch shapes {self.batch_shape}" ) def compare(self, other: torch.Tensor) -> torch.Tensor: @@ -244,22 +269,32 @@ def compare(self, other: torch.Tensor) -> torch.Tensor: @property def is_initial_state(self) -> torch.Tensor: """Returns a tensor of shape `batch_shape` that is True for states that are $s_0$ of the DAG.""" - source_states_tensor = self.__class__.s0.repeat( - *self.batch_shape, *((1,) * len(self.__class__.state_shape)) - ) + if isinstance(self.__class__.s0, torch.Tensor): + source_states_tensor = self.__class__.s0.repeat( + *self.batch_shape, *((1,) * len(self.__class__.state_shape)) + ) + else: + raise NotImplementedError( + f"is_initial_state is not implemented by default for {self.__class__.__name__}" + ) return self.compare(source_states_tensor) @property def is_sink_state(self) -> torch.Tensor: """Returns a tensor of shape `batch_shape` that is True for states that are $s_f$ of the DAG.""" # TODO: self.__class__.sf == self.tensor -- or something similar? - sink_states = self.__class__.sf.repeat( - *self.batch_shape, *((1,) * len(self.__class__.state_shape)) - ).to(self.tensor.device) + if isinstance(self.__class__.sf, torch.Tensor): + sink_states = self.__class__.sf.repeat( + *self.batch_shape, *((1,) * len(self.__class__.state_shape)) + ).to(self.tensor.device) + else: + raise NotImplementedError( + f"is_sink_state is not implemented by default for {self.__class__.__name__}" + ) return self.compare(sink_states) @property - def log_rewards(self) -> torch.Tensor: + def log_rewards(self) -> torch.Tensor | None: """Returns the log rewards of the states as tensor of shape `batch_shape`.""" return self._log_rewards @@ -278,17 +313,22 @@ def sample(self, n_samples: int) -> States: return self[torch.randperm(len(self))[:n_samples]] @classmethod - def stack_states(cls, states: List[States]): + def stack(cls, states: Sequence[States]) -> States: """Given a list of states, stacks them along a new dimension (0).""" - state_example = states[0] # We assume all elems of `states` are the same. + state_example = states[0] + assert all( + state.batch_shape == state_example.batch_shape for state in states + ), "All states must have the same batch_shape" stacked_states = state_example.from_batch_shape((0, 0)) # Empty. stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) - # TODO: do not ignore the next ignore if state_example._log_rewards: - stacked_states._log_rewards = torch.stack( - [s._log_rewards for s in states], dim=0 # pyright: ignore - ) + log_rewards = [] + for s in states: + if s._log_rewards is None: + raise ValueError("Some states have no log rewards.") + log_rewards.append(s._log_rewards) + stacked_states._log_rewards = torch.stack(log_rewards, dim=0) # Adds the trajectory dimension. 
         stacked_states.batch_shape = (
             len(states),
         ) + state_example.batch_shape
 
         return stacked_states
 
@@ -484,12 +524,590 @@ def init_forward_masks(self, set_ones: bool = True):
         self.forward_masks = torch.zeros(shape).bool()
 
     @classmethod
-    def stack_states(cls, states: List[DiscreteStates]):
-        stacked_states: DiscreteStates = super().stack_states(states)  # pyright: ignore
-        stacked_states.forward_masks = torch.stack(
-            [s.forward_masks for s in states], dim=0  # pyright: ignore
+    def stack(cls, states: List[DiscreteStates]) -> DiscreteStates:
+        """Stacks a list of DiscreteStates objects along a new dimension (0)."""
+        out = super().stack(states)
+        assert isinstance(out, DiscreteStates)
+        out.forward_masks = torch.stack([s.forward_masks for s in states], dim=0)
+        out.backward_masks = torch.stack([s.backward_masks for s in states], dim=0)
+        return out
+
+
+class GraphStates(States):
+    """Base class for graph-based state representations.
+
+    A `GraphStates` object is a batched collection of graphs; the batch is stored
+    as a PyTorch Geometric `Batch` object.
+    """
+
+    s0: ClassVar[Data]
+    sf: ClassVar[Data]
+
+    def __init__(self, tensor: Batch):
+        """Initialize the GraphStates with a PyG Batch object.
+
+        Args:
+            tensor: A PyG Batch object representing a batch of graphs.
+        """
+        self.tensor = tensor
+        if not hasattr(self.tensor, "batch_shape"):
+            # PyG's `Batch` exposes `num_graphs`, not `batch_size`; store the batch
+            # shape as a tensor, as the other constructors below do.
+            self.tensor.batch_shape = torch.tensor([self.tensor.num_graphs])
+        self._log_rewards: Optional[torch.Tensor] = None
+
+    @property
+    def batch_shape(self) -> tuple[int, ...]:
+        """Returns the batch shape as a tuple."""
+        return tuple(self.tensor.batch_shape.tolist())
+
+    @classmethod
+    def from_batch_shape(
+        cls, batch_shape: int | Tuple, random: bool = False, sink: bool = False
+    ) -> GraphStates:
+        """Create a GraphStates object with the given batch shape.
+
+        Args:
+            batch_shape: Shape of the batch dimensions.
+            random: Initialize states randomly.
+            sink: States initialized with s_f (the sink state).
+
+        Returns:
+            A GraphStates object with the specified batch shape.
+        """
+        if random and sink:
+            raise ValueError("Only one of `random` and `sink` should be True.")
+        if random:
+            tensor = cls.make_random_states_tensor(batch_shape)
+        elif sink:
+            tensor = cls.make_sink_states_tensor(batch_shape)
+        else:
+            tensor = cls.make_initial_states_tensor(batch_shape)
+        return cls(tensor)
+
+    @classmethod
+    def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> Batch:
+        """Makes a batch of graphs consisting of s0 states.
+
+        Args:
+            batch_shape: Shape of the batch dimensions.
+
+        Returns:
+            A PyG Batch object containing copies of the initial state.
+        """
+        batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,)
+        num_graphs = int(np.prod(batch_shape))
+
+        # Create a list of Data objects by copying s0
+        data_list = [cls.s0.clone() for _ in range(num_graphs)]
+
+        if len(data_list) == 0:  # If batch_shape is 0, create a single empty graph
+            data_list = [
+                Data(
+                    x=torch.zeros(0, cls.s0.x.size(1)),
+                    edge_index=torch.zeros(2, 0, dtype=torch.long),
+                    edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1)),
+                )
+            ]
+
+        # Create a batch from the list
+        batch = Batch.from_data_list(data_list)
+
+        # Store the batch shape for later reference
+        batch.batch_shape = torch.tensor(batch_shape, device=cls.s0.x.device)
+
+        return batch
+
+    @classmethod
+    def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> Batch:
+        """Makes a batch of graphs consisting of sf states.
+
+        Args:
+            batch_shape: Shape of the batch dimensions.
+ + Returns: + A PyG Batch object containing copies of the sink state. + """ + if cls.sf is None: + raise NotImplementedError("Sink state is not defined") + + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) + num_graphs = int(np.prod(batch_shape)) + + # Create a list of Data objects by copying sf + data_list = [cls.sf.clone() for _ in range(num_graphs)] + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [Data( + x=torch.zeros(0, cls.sf.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.sf.edge_attr.size(1)) + )] + + # Create a batch from the list + batch = Batch.from_data_list(data_list) + + # Store the batch shape for later reference + batch.batch_shape = torch.tensor(batch_shape, device=cls.sf.x.device) + + return batch + + @classmethod + def make_random_states_tensor(cls, batch_shape: int | Tuple) -> Batch: + """Makes a batch of random graph states. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns: + A PyG Batch object containing random graph states. + """ + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) + num_graphs = int(np.prod(batch_shape)) + device = cls.s0.x.device + + data_list = [] + for _ in range(num_graphs): + # Create a random graph with random number of nodes + num_nodes = np.random.randint(1, 10) + + # Create random node features + x = torch.rand(num_nodes, cls.s0.x.size(1), device=device) + + # Create random edges (not all possible edges to keep it sparse) + num_edges = np.random.randint(0, num_nodes * (num_nodes - 1) // 2 + 1) + if num_edges > 0 and num_nodes > 1: + # Generate random source and target nodes + edge_index = torch.zeros(2, num_edges, dtype=torch.long, device=device) + for i in range(num_edges): + src, dst = np.random.choice(num_nodes, 2, replace=False) + edge_index[0, i] = src + edge_index[1, i] = dst + + # Create random edge features + edge_attr = torch.rand(num_edges, cls.s0.edge_attr.size(1), device=device) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + else: + # No edges + data = Data(x=x, edge_index=torch.zeros(2, 0, dtype=torch.long, device=device), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1), device=device)) + + data_list.append(data) + + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [Data( + x=torch.zeros(0, cls.s0.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1)) + )] + + # Create a batch from the list + batch = Batch.from_data_list(data_list) + + # Store the batch shape for later reference + batch.batch_shape = torch.tensor(batch_shape, device=device) + + return batch + + def __len__(self) -> int: + """Returns the number of graphs in the batch.""" + return int(np.prod(self.batch_shape)) + + def __repr__(self): + """Returns a string representation of the GraphStates object.""" + return ( + f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"node feature dim {self.tensor.x.size(1)} and edge feature dim {self.tensor.edge_attr.size(1)}" ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 # pyright: ignore + + def __getitem__( + self, index: int | Sequence[int] | slice | torch.Tensor + ) -> GraphStates: + """Get a subset of the GraphStates. + + Args: + index: Index or indices to select. + + Returns: + A new GraphStates object containing the selected graphs. 
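+
+        Example (sketch; `MyGraphStates` is a hypothetical concrete subclass):
+            >>> batch = MyGraphStates.from_batch_shape((4,))
+            >>> batch[torch.tensor([0, 2])].batch_shape
+            (2,)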
+ """ + # Convert the index to a list of indices + tensor_idx = torch.arange(len(self)).view(*self.batch_shape) + if isinstance(index, int): + new_shape = (1,) + else: + new_shape = tensor_idx[index].shape + indices = tensor_idx[index].flatten().tolist() + + # Get the selected graphs from the batch + selected_graphs = self.tensor.index_select(indices) + if len(selected_graphs) == 0: + assert np.prod(new_shape) == 0 + selected_graphs = [Data( + x=torch.zeros(0, self.tensor.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, self.tensor.edge_attr.size(1)) + )] + + # Create a new batch from the selected graphs + new_batch = Batch.from_data_list(selected_graphs) + new_batch.batch_shape = torch.tensor(new_shape, device=self.tensor.batch_shape.device) + + # Create a new GraphStates object + out = self.__class__(new_batch) + + # Copy log rewards if they exist + if self._log_rewards is not None: + out._log_rewards = self._log_rewards[indices] + + return out + + def __setitem__(self, index: int | Sequence[int], graph: GraphStates): + """Set a subset of the GraphStates. + + Args: + index: Index or indices to set. + graph: GraphStates object containing the new graphs. + """ + # Convert the index to a list of indices + batch_shape = self.batch_shape + if isinstance(index, int): + indices = [index] + else: + tensor_idx = torch.arange(len(self)).view(*batch_shape) + indices = tensor_idx[index].flatten().tolist() + + # Get the data list from the current batch + data_list = self.tensor.to_data_list() + + # Get the data list from the new graphs + new_data_list = graph.tensor.to_data_list() + + # Replace the selected graphs + for i, idx in enumerate(indices): + if i < len(new_data_list): + data_list[idx] = new_data_list[i] + + # Create a new batch from the updated data list + self.tensor = Batch.from_data_list(data_list) + + # Preserve the batch shape + self.tensor.batch_shape = torch.tensor(batch_shape, device=self.tensor.x.device) + + @property + def device(self) -> torch.device: + """Returns the device of the tensor.""" + return self.tensor.x.device + + def to(self, device: torch.device) -> GraphStates: + """Moves the GraphStates to the specified device. + + Args: + device: The device to move to. + + Returns: + The GraphStates object on the specified device. + """ + self.tensor = self.tensor.to(device) + if self._log_rewards is not None: + self._log_rewards = self._log_rewards.to(device) + return self + + def clone(self) -> GraphStates: + """Returns a detached clone of the current instance. + + Returns: + A new GraphStates object with the same data. + """ + # Create a deep copy of the batch + data_list = [data.clone() for data in self.tensor.to_data_list()] + new_batch = Batch.from_data_list(data_list) + new_batch.batch_shape = self.tensor.batch_shape.clone() + + # Create a new GraphStates object + out = self.__class__(new_batch) + + # Copy log rewards if they exist + if self._log_rewards is not None: + out._log_rewards = self._log_rewards.clone() + + return out + + def extend(self, other: GraphStates): + """Concatenates to another GraphStates object along the batch dimension. + + Args: + other: GraphStates object to concatenate with. 
+ """ + if len(self) == 0: + # If self is empty, just copy other + self.tensor = other.tensor.clone() + if other._log_rewards is not None: + self._log_rewards = other._log_rewards.clone() + return + + # Get the data lists + self_data_list = self.tensor.to_data_list() + other_data_list = other.tensor.to_data_list() + + # Update the batch shape + if len(self.batch_shape) == 1: + # Create a new batch + new_batch_shape = (self.batch_shape[0] + other.batch_shape[0],) + self.tensor = Batch.from_data_list(self_data_list + other_data_list) + self.tensor.batch_shape = torch.tensor(new_batch_shape, device=self.tensor.x.device) + else: + # Handle the case where batch_shape is (T, B) + # and we want to concatenate along the B dimension + assert len(self.batch_shape) == 2 + max_len = max(self.batch_shape[0], other.batch_shape[0]) + + # We need to extend both batches to the same length T + if self.batch_shape[0] < max_len: + self_extension = self.make_sink_states_tensor( + (max_len - self.batch_shape[0], self.batch_shape[1]) + ) + self_data_list = self_data_list + self_extension.to_data_list() + + if other.batch_shape[0] < max_len: + other_extension = other.make_sink_states_tensor( + (max_len - other.batch_shape[0], other.batch_shape[1]) + ) + other_data_list = other_data_list + other_extension.to_data_list() + + # Now both have the same length T, we can concatenate along B + batch_shape = (max_len, self.batch_shape[1] + other.batch_shape[1]) + self.tensor = Batch.from_data_list(self_data_list + other_data_list) + self.tensor.batch_shape = torch.tensor(batch_shape, device=self.tensor.x.device) + + # Combine log rewards if they exist + if self._log_rewards is not None and other._log_rewards is not None: + self._log_rewards = torch.cat([self._log_rewards, other._log_rewards], dim=0) + elif other._log_rewards is not None: + self._log_rewards = other._log_rewards.clone() + + @property + def log_rewards(self) -> torch.Tensor | None: + """Returns the log rewards of the states.""" + return self._log_rewards + + @log_rewards.setter + def log_rewards(self, log_rewards: torch.Tensor) -> None: + """Sets the log rewards of the states. + + Args: + log_rewards: Tensor of shape `batch_shape` representing the log rewards. + """ + assert log_rewards.shape == self.batch_shape + self._log_rewards = log_rewards + + def _compare(self, other: Data) -> torch.Tensor: + """Compares the current batch of graphs with another graph. + + Args: + other: A PyG Data object to compare with. + + Returns: + A boolean tensor indicating which graphs in the batch are equal to other. 
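+
+        Example (sketch): `self._compare(self.s0)` yields the same boolean
+        tensor as the `is_initial_state` property defined below.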
+ """ + out = torch.zeros(len(self), dtype=torch.bool, device=self.device) + + # Get the data list from the batch + data_list = self.tensor.to_data_list() + + for i, data in enumerate(data_list): + # Check if the number of nodes is the same + if data.num_nodes != other.num_nodes: + continue + + # Check if node features are the same + if not torch.all(data.x == other.x): + continue + + # Check if the number of edges is the same + if data.edge_index.size(1) != other.edge_index.size(1): + continue + + # Check if edge indices are the same (this is more complex due to potential reordering) + # We'll use a simple heuristic: sort edges and compare + data_edges = data.edge_index.t().tolist() + other_edges = other.edge_index.t().tolist() + data_edges.sort() + other_edges.sort() + if data_edges != other_edges: + continue + + # Check if edge attributes are the same (after sorting) + data_edge_attr = data.edge_attr[torch.argsort(data.edge_index[0] * data.num_nodes + data.edge_index[1])] + other_edge_attr = other.edge_attr[torch.argsort(other.edge_index[0] * other.num_nodes + other.edge_index[1])] + if not torch.all(data_edge_attr == other_edge_attr): + continue + + # If all checks pass, the graphs are equal + out[i] = True + + return out.view(self.batch_shape) + + @property + def is_sink_state(self) -> torch.Tensor: + """Returns a tensor that is True for states that are sf.""" + return self._compare(self.sf) + + @property + def is_initial_state(self) -> torch.Tensor: + """Returns a tensor that is True for states that are s0.""" + return self._compare(self.s0) + + @classmethod + def stack(cls, states: List[GraphStates]) -> GraphStates: + """Given a list of states, stacks them along a new dimension (0). + + Args: + states: List of GraphStates objects to stack. + + Returns: + A new GraphStates object with the stacked states. + """ + # Check that all states have the same batch shape + state_batch_shape = states[0].batch_shape + assert all(state.batch_shape == state_batch_shape for state in states) + + # Get all data lists + all_data_lists = [state.tensor.to_data_list() for state in states] + + # Flatten the list of lists + flat_data_list = [data for data_list in all_data_lists for data in data_list] + + # Create a new batch + batch = Batch.from_data_list(flat_data_list) + + # Set the batch shape + batch.batch_shape = torch.tensor( + (len(states),) + state_batch_shape, + device=states[0].device ) - return stacked_states + + # Create a new GraphStates object + out = cls(batch) + + # Stack log rewards if they exist + if all(state._log_rewards is not None for state in states): + out._log_rewards = torch.stack([state._log_rewards for state in states], dim=0) + + return out + + @property + def forward_masks(self) -> dict: + """Returns masks denoting allowed forward actions. + + Returns: + A dictionary containing masks for different action types. 
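+
+        Illustrative shapes, for a 1-D batch of B graphs with F node features
+        and N nodes in total across the batch:
+            - "action_type": (B, 3) bool
+            - "features": (B, F) bool
+            - "edge_index": (B, N, N) bool, True where an edge may still be added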
+        """
+        # Get the data list from the batch
+        data_list = self.tensor.to_data_list()
+        N = self.tensor.x.size(0)
+
+        # Initialize masks
+        action_type_mask = torch.ones(self.batch_shape + (3,), dtype=torch.bool, device=self.device)
+        features_mask = torch.ones(
+            self.batch_shape + (self.tensor.x.size(1),),
+            dtype=torch.bool,
+            device=self.device
+        )
+        edge_index_masks = torch.ones((len(data_list), N, N), dtype=torch.bool, device=self.device)
+
+        # For each graph in the batch
+        for i, data in enumerate(data_list):
+            # ADD_NODE is always allowed
+            action_type_mask[i, GraphActionType.ADD_NODE] = True
+
+            # ADD_EDGE is allowed only if there are at least 2 nodes
+            action_type_mask[i, GraphActionType.ADD_EDGE] = data.num_nodes > 1
+
+            # EXIT is always allowed
+            action_type_mask[i, GraphActionType.EXIT] = True
+
+        # Create edge_index mask as a dense representation (NxN matrix)
+        start_n = 0
+        for i, data in enumerate(data_list):
+            # For each graph, create a dense mask for potential edges
+            n = data.num_nodes
+
+            edge_mask = torch.ones((n, n), dtype=torch.bool, device=self.device)
+            # Remove self-loops by setting diagonal to False
+            edge_mask.fill_diagonal_(False)
+
+            # Exclude existing edges
+            if data.edge_index.size(1) > 0:
+                for j in range(data.edge_index.size(1)):
+                    src, dst = data.edge_index[0, j], data.edge_index[1, j]
+                    edge_mask[src, dst] = False
+
+            edge_index_masks[i, start_n:(start_n + n), start_n:(start_n + n)] = edge_mask
+            start_n += n
+
+            # Refine this graph's ADD_EDGE mask: only allow it if a valid edge remains
+            action_type_mask[i, GraphActionType.ADD_EDGE] &= edge_mask.any()
+
+        return {
+            "action_type": action_type_mask,
+            "features": features_mask,
+            "edge_index": edge_index_masks
+        }
+
+    @property
+    def backward_masks(self) -> dict:
+        """Returns masks denoting allowed backward actions.
+
+        Returns:
+            A dictionary containing masks for different action types.
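+
+        The shapes mirror `forward_masks` above, but "edge_index" is True only
+        at existing edges (those that can be removed).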
+        """
+        # Get the data list from the batch
+        data_list = self.tensor.to_data_list()
+        N = self.tensor.x.size(0)
+
+        # Initialize masks
+        action_type_mask = torch.ones(self.batch_shape + (3,), dtype=torch.bool, device=self.device)
+        features_mask = torch.ones(
+            self.batch_shape + (self.tensor.x.size(1),),
+            dtype=torch.bool,
+            device=self.device
+        )
+        edge_index_masks = torch.zeros((len(data_list), N, N), dtype=torch.bool, device=self.device)
+
+        # For each graph in the batch
+        for i, data in enumerate(data_list):
+            # ADD_NODE is allowed if there's at least one node (can remove a node)
+            action_type_mask[i, GraphActionType.ADD_NODE] = data.num_nodes >= 1
+
+            # ADD_EDGE is allowed if there's at least one edge (can remove an edge)
+            action_type_mask[i, GraphActionType.ADD_EDGE] = data.edge_index.size(1) > 0
+
+            # EXIT is allowed if there's at least one node
+            action_type_mask[i, GraphActionType.EXIT] = data.num_nodes >= 1
+
+        # Create edge_index mask for backward actions (existing edges that can be removed)
+        start_n = 0
+        for i, data in enumerate(data_list):
+            # For backward actions, we can only remove existing edges
+            n = data.num_nodes
+            edge_mask = torch.zeros((n, n), dtype=torch.bool, device=self.device)
+
+            # Include only existing edges
+            if data.edge_index.size(1) > 0:
+                for j in range(data.edge_index.size(1)):
+                    src, dst = data.edge_index[0, j].item(), data.edge_index[1, j].item()
+                    edge_mask[src, dst] = True
+
+            edge_index_masks[i, start_n:(start_n + n), start_n:(start_n + n)] = edge_mask
+            start_n += n
+
+        return {
+            "action_type": action_type_mask,
+            "features": features_mask,
+            "edge_index": edge_index_masks
+        }
diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py
index f4948d0d..3ea029cf 100644
--- a/src/gfn/utils/distributions.py
+++ b/src/gfn/utils/distributions.py
@@ -1,5 +1,7 @@
+from typing import Dict
+
 import torch
-from torch.distributions import Categorical
+from torch.distributions import Categorical, Distribution
 
 
 class UnsqueezedCategorical(Categorical):
@@ -39,3 +41,75 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor:
         """
         assert sample.shape[-1] == 1
         return super().log_prob(sample.squeeze(-1))
+
+
+class CompositeDistribution(
+    Distribution
+):  # TODO: may use CompositeDistribution in TensorDict
+    """A composite distribution over a dictionary of independent component distributions."""
+
+    def __init__(self, dists: Dict[str, Distribution]):
+        """Initializes the composite distribution.
+
+        Args:
+            dists: A dictionary of distributions.
+        """
+        super().__init__()
+        self.dists = dists
+
+    def sample(self, sample_shape=torch.Size()) -> Dict[str, torch.Tensor]:
+        return {k: v.sample(sample_shape) for k, v in self.dists.items()}
+
+    def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor:
+        log_probs = [
+            v.log_prob(sample[k]).reshape(sample[k].shape[0], -1).sum(dim=-1)
+            for k, v in self.dists.items()
+        ]
+        # Note: the total log-probability is the sum of the per-component
+        # log-probabilities, since the components are sampled independently.
+        return sum(log_probs)
+
+
+class CategoricalIndexes(Categorical):
+    """Samples indexes from a categorical distribution."""
+
+    def __init__(self, probs: torch.Tensor, n_nodes: int):
+        """Initializes the distribution.
+
+        Args:
+            probs: The probabilities of the categorical distribution.
+            n_nodes: The number of nodes in the graph.
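+
+        Example (sketch): with `n_nodes=3`, `probs` must have 9 columns, one
+        per ordered (source, target) node pair; `sample()` returns index
+        pairs of shape (..., 2) obtained by divmod with `n_nodes`.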
+        """
+        assert probs.shape == (
+            probs.shape[0],
+            n_nodes * n_nodes,
+        )
+        self.n_nodes = n_nodes  # Needed by sample() and log_prob() below.
+        super().__init__(probs)
+
+    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
+        samples = super().sample(sample_shape)
+        out = torch.stack(
+            [
+                samples // self.n_nodes,
+                samples % self.n_nodes,
+            ],
+            dim=-1,
+        )
+        return out
+
+    def log_prob(self, value):
+        value = value[..., 0] * self.n_nodes + value[..., 1]
+        return super().log_prob(value)
+
+
+class CategoricalActionType(Categorical):  # TODO: remove, just to sample 1 action_type
+    def __init__(self, probs: torch.Tensor):
+        self.batch_len = len(probs)
+        super().__init__(probs[0])
+
+    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
+        samples = super().sample(sample_shape)
+        return samples.repeat(self.batch_len)
+
+    def log_prob(self, value):
+        return super().log_prob(value[0]).repeat(self.batch_len)
diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py
index d8106783..d13028d7 100644
--- a/src/gfn/utils/modules.py
+++ b/src/gfn/utils/modules.py
@@ -17,6 +17,7 @@ def __init__(
         n_hidden_layers: Optional[int] = 2,
         activation_fn: Optional[Literal["relu", "tanh", "elu"]] = "relu",
         trunk: Optional[nn.Module] = None,
+        add_layer_norm: bool = False,
     ):
         """Instantiates a MLP instance.
@@ -28,6 +29,8 @@ def __init__(
             activation_fn: Activation function.
             trunk: If provided, this module will be used as the trunk of the network
                 (i.e. all layers except last layer).
+            add_layer_norm: If True, add a LayerNorm after each linear layer.
+                (incompatible with `trunk` argument)
         """
         super().__init__()
         self._input_dim = input_dim
@@ -44,9 +47,18 @@ def __init__(
             activation = nn.ReLU
         elif activation_fn == "tanh":
             activation = nn.Tanh
-        arch = [nn.Linear(input_dim, hidden_dim), activation()]
+        if add_layer_norm:
+            arch = [
+                nn.Linear(input_dim, hidden_dim),
+                nn.LayerNorm(hidden_dim),
+                activation(),
+            ]
+        else:
+            arch = [nn.Linear(input_dim, hidden_dim), activation()]
         for _ in range(n_hidden_layers - 1):
             arch.append(nn.Linear(hidden_dim, hidden_dim))
+            if add_layer_norm:
+                arch.append(nn.LayerNorm(hidden_dim))
             arch.append(activation())
         self.trunk = nn.Sequential(*arch)
         self.trunk.hidden_dim = hidden_dim
diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py
index 36afb450..d42b61dc 100644
--- a/src/gfn/utils/training.py
+++ b/src/gfn/utils/training.py
@@ -132,7 +132,7 @@ def states_actions_tns_to_traj(
     # stack is a class method, so actions[0] is just to access a class instance and is not particularly relevant
     actions = actions[0].stack(actions)
     log_rewards = env.log_reward(states[-2])
-    states = states[0].stack_states(states)
+    states = states[0].stack(states)
     when_is_done = torch.tensor([len(states_tns) - 1])
     log_probs = None
diff --git a/testing/test_actions.py b/testing/test_actions.py
new file mode 100644
index 00000000..03e26a94
--- /dev/null
+++ b/testing/test_actions.py
@@ -0,0 +1,139 @@
+from copy import deepcopy
+
+import pytest
+import torch
+from tensordict import TensorDict
+
+from gfn.actions import Actions, GraphActions
+
+
+class ContinuousActions(Actions):
+    action_shape = (10,)
+    dummy_action = torch.zeros(10)
+    exit_action = torch.ones(10)
+
+
+class TestGraphActions(GraphActions):
+    features_dim = 10
+
+
+@pytest.fixture
+def continuous_action():
+    return ContinuousActions(tensor=torch.arange(0, 10))
+
+
+@pytest.fixture
+def graph_action():
+    return TestGraphActions(
+        tensor=TensorDict(
+            {
+                "action_type": torch.zeros((1,), dtype=torch.float32),
+                "features": torch.zeros((1, 10), dtype=torch.float32),
+            },
+
device="cpu", + ) + ) + + +def test_continuous_action(continuous_action): + BATCH = 5 + + exit_actions = continuous_action.make_exit_actions((BATCH,)) + assert torch.all( + exit_actions.tensor == continuous_action.exit_action.repeat(BATCH, 1) + ) + assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool)) + + dummy_actions = continuous_action.make_dummy_actions((BATCH,)) + assert torch.all( + dummy_actions.tensor == continuous_action.dummy_action.repeat(BATCH, 1) + ) + assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool)) + + # Test stack + stacked_actions = continuous_action.stack([exit_actions, dummy_actions]) + assert stacked_actions.batch_shape == (2, BATCH) + assert torch.all( + stacked_actions.tensor + == torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0) + ) + is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + assert stacked_actions[0, 1].is_exit + stacked_actions[0, 1] = stacked_actions[1, 1] + is_exit_stacked[0, 1] = False + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + + # Test extend + extended_actions = deepcopy(exit_actions) + extended_actions.extend(dummy_actions) + assert extended_actions.batch_shape == (BATCH * 2,) + assert torch.all( + extended_actions.tensor + == torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0) + ) + is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(extended_actions.is_exit == is_exit_extended) + assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy + extended_actions[0] = extended_actions[BATCH] + is_exit_extended[0] = False + assert torch.all(extended_actions.is_exit == is_exit_extended) + + +def test_graph_action(graph_action): + BATCH = 5 + + exit_actions = graph_action.make_exit_actions((BATCH,)) + assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool)) + dummy_actions = graph_action.make_dummy_actions((BATCH,)) + assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool)) + + # Test stack + stacked_actions = graph_action.stack([exit_actions, dummy_actions]) + assert stacked_actions.batch_shape == (2, BATCH) + manually_stacked_tensor = torch.stack( + [exit_actions.tensor, dummy_actions.tensor], dim=0 + ) + assert torch.all( + stacked_actions.tensor["action_type"] == manually_stacked_tensor["action_type"] + ) + assert torch.all( + stacked_actions.tensor["features"] == manually_stacked_tensor["features"] + ) + assert torch.all( + stacked_actions.tensor["edge_index"] == manually_stacked_tensor["edge_index"] + ) + is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + assert stacked_actions[0, 1].is_exit + stacked_actions[0, 1] = stacked_actions[1, 1] + is_exit_stacked[0, 1] = False + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + + # Test extend + extended_actions = deepcopy(exit_actions) + extended_actions.extend(dummy_actions) + assert extended_actions.batch_shape == (BATCH * 2,) + manually_extended_tensor = torch.cat( + 
[exit_actions.tensor, dummy_actions.tensor], dim=0 + ) + assert torch.all( + extended_actions.tensor["action_type"] + == manually_extended_tensor["action_type"] + ) + assert torch.all( + extended_actions.tensor["features"] == manually_extended_tensor["features"] + ) + assert torch.all( + extended_actions.tensor["edge_index"] == manually_extended_tensor["edge_index"] + ) + is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(extended_actions.is_exit == is_exit_extended) + assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy + extended_actions[0] = extended_actions[BATCH] + is_exit_extended[0] = False + assert torch.all(extended_actions.is_exit == is_exit_extended) diff --git a/testing/test_environments.py b/testing/test_environments.py index dc852537..22137a3d 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -1,9 +1,12 @@ import numpy as np import pytest import torch +from tensordict import TensorDict +from gfn.actions import GraphActionType from gfn.env import NonValidActionsError from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.graph_building import GraphBuilding # Utilities. @@ -334,3 +337,166 @@ def test_get_grid(): # State indices of the grid are ordered from 0:HEIGHT**2. assert (env.get_states_indices(grid).ravel() == torch.arange(HEIGHT**2)).all() + + +def test_graph_env(): + FEATURE_DIM = 8 + BATCH_SIZE = 3 + NUM_NODES = 5 + + env = GraphBuilding( + feature_dim=FEATURE_DIM, state_evaluator=lambda s: torch.zeros(s.batch_shape) + ) + states = env.reset(batch_shape=BATCH_SIZE) + assert states.batch_shape == (BATCH_SIZE,) + action_cls = env.make_actions_class() + + with pytest.raises(NonValidActionsError): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.step(states, actions) + + for _ in range(NUM_NODES): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.step(states, actions) + states = env.States(states) + + assert states.tensor.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) + + with pytest.raises(NonValidActionsError): + first_node_mask = ( + torch.arange(len(states.tensor.x)) // BATCH_SIZE == 0 + ) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor.x[first_node_mask], + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.step(states, actions) + + with pytest.raises(NonValidActionsError): + edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([edge_index, edge_index], dim=1), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.step(states, actions) + + for i in range(NUM_NODES - 1): + # node_is = torch.arange(BATCH_SIZE) * NUM_NODES + i + # node_js = torch.arange(BATCH_SIZE) * NUM_NODES + i + 1 + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": 
torch.tensor([[i, i + 1]] * BATCH_SIZE), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.step(states, actions) + states = env.States(states) + + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT), + }, + batch_size=BATCH_SIZE, + ) + ) + + states.forward_masks + states.backward_masks + sf_states = env.step(states, actions) + sf_states = env.States(sf_states) + assert torch.all(sf_states.is_sink_state) + env.reward(sf_states) + + num_edges_per_batch = len(states.tensor.edge_attr) // BATCH_SIZE + for i in reversed(range(num_edges_per_batch)): + edge_idx = torch.arange(i, (i + 1) * BATCH_SIZE, i + 1) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": states.tensor.edge_attr[edge_idx], + "edge_index": states.tensor.edge_index[:, edge_idx].T - states.tensor.ptr[:-1, None], + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.backward_step(states, actions) + states = env.States(states) + + with pytest.raises(NonValidActionsError): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.backward_step(states, actions) + + for i in reversed(range(1, NUM_NODES + 1)): + edge_idx = torch.arange(BATCH_SIZE) * i + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor.x[edge_idx], + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.backward_step(states, actions) + states = env.States(states) + + assert states.tensor.x.shape == (0, FEATURE_DIM) + + with pytest.raises(NonValidActionsError): + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, + batch_size=BATCH_SIZE, + ) + ) + states = env.backward_step(states, actions) diff --git a/testing/test_graph_states.py b/testing/test_graph_states.py new file mode 100644 index 00000000..f03f6373 --- /dev/null +++ b/testing/test_graph_states.py @@ -0,0 +1,350 @@ +import pytest +import torch +from torch_geometric.data import Data, Batch +from gfn.states import GraphStates +from gfn.actions import GraphActionType + + +class MyGraphStates(GraphStates): + # Initial state: a graph with 2 nodes and 1 edge + s0 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + + # Sink state: a graph with 2 nodes and 1 edge (different from s0) + sf = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]) + ) + + +@pytest.fixture +def simple_graph_state(): + """Creates a simple graph state with 2 nodes and 1 edge""" + data = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch = Batch.from_data_list([data]) + batch.batch_shape = torch.tensor([1]) + return MyGraphStates(batch) + + +@pytest.fixture +def empty_graph_state(): + """Creates an empty GraphStates object""" + # Create an empty batch + batch = Batch() + batch.x = torch.zeros((0, 1)) + batch.edge_index = torch.zeros((2, 0), dtype=torch.long) + batch.edge_attr = torch.zeros((0, 1)) + batch.batch = torch.zeros((0,), dtype=torch.long) + batch.batch_shape = torch.tensor([0]) + 
return MyGraphStates(batch) + + +def test_extend_empty_state(empty_graph_state, simple_graph_state): + """Test extending an empty state with a non-empty state""" + empty_graph_state.extend(simple_graph_state) + + # Check that the empty state now has the same content as the simple state + assert torch.equal(empty_graph_state.tensor.batch_shape, simple_graph_state.tensor.batch_shape) + assert torch.equal(empty_graph_state.tensor.x, simple_graph_state.tensor.x) + assert torch.equal(empty_graph_state.tensor.edge_index, simple_graph_state.tensor.edge_index) + assert torch.equal(empty_graph_state.tensor.edge_attr, simple_graph_state.tensor.edge_attr) + assert torch.equal(empty_graph_state.tensor.batch, simple_graph_state.tensor.batch) + + +def test_extend_1d_batch(simple_graph_state): + """Test extending two 1D batch states""" + other_state = simple_graph_state.clone() + + # Store original number of nodes and edges + original_num_nodes = simple_graph_state.tensor.num_nodes + original_num_edges = simple_graph_state.tensor.num_edges + + simple_graph_state.extend(other_state) + + # Check batch shape is updated + assert simple_graph_state.tensor.batch_shape[0] == 2 + + # Check number of nodes and edges doubled + assert simple_graph_state.tensor.num_nodes == 2 * original_num_nodes + assert simple_graph_state.tensor.num_edges == 2 * original_num_edges + + # Check that batch indices are properly updated + batch_indices = simple_graph_state.tensor.batch + assert torch.equal(batch_indices[:original_num_nodes], torch.zeros(original_num_nodes, dtype=torch.long)) + assert torch.equal(batch_indices[original_num_nodes:], torch.ones(original_num_nodes, dtype=torch.long)) + + +def test_extend_2d_batch(): + """Test extending two 2D batch states""" + # Create first state (T=2, B=1) + data1 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + data2 = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.6]]) + ) + batch1 = Batch.from_data_list([data1, data2]) + batch1.batch_shape = torch.tensor([2, 1]) + state1 = MyGraphStates(batch1) + + # Create second state (T=3, B=1) + data3 = Data( + x=torch.tensor([[5.0], [6.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]) + ) + data4 = Data( + x=torch.tensor([[7.0], [8.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.8]]) + ) + data5 = Data( + x=torch.tensor([[9.0], [10.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.9]]) + ) + batch2 = Batch.from_data_list([data3, data4, data5]) + batch2.batch_shape = torch.tensor([3, 1]) + state2 = MyGraphStates(batch2) + + # Extend state1 with state2 + state1.extend(state2) + + # Check final shape should be (max_len=3, B=2) + assert torch.equal(state1.tensor.batch_shape, torch.tensor([3, 2])) + + # Check that we have the correct number of nodes and edges + # Each graph has 2 nodes and 1 edge + # For 3 time steps and 2 batches, we should have: + expected_nodes = 3 * 2 * 2 # T * nodes_per_graph * B + expected_edges = 3 * 1 * 2 # T * edges_per_graph * B + + # The actual count might be higher due to padding with sink states + assert state1.tensor.num_nodes >= expected_nodes + assert state1.tensor.num_edges >= expected_edges + + +def test_getitem(): + """Test indexing into GraphStates""" + # Create a batch with 3 graphs + data1 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + 
) + data2 = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.6]]) + ) + data3 = Data( + x=torch.tensor([[5.0], [6.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]) + ) + batch = Batch.from_data_list([data1, data2, data3]) + batch.batch_shape = torch.tensor([3]) + states = MyGraphStates(batch) + + # Get a single graph + single_state = states[1] + assert single_state.tensor.batch_shape[0] == 1 + assert single_state.tensor.num_nodes == 2 + assert torch.allclose(single_state.tensor.x, torch.tensor([[3.0], [4.0]])) + + # Get multiple graphs + multi_state = states[[0, 2]] + assert multi_state.tensor.batch_shape[0] == 2 + assert multi_state.tensor.num_nodes == 4 + + # Check first graph in selection + first_graph = multi_state.tensor.get_example(0) + assert torch.allclose(first_graph.x, torch.tensor([[1.0], [2.0]])) + + # Check second graph in selection + second_graph = multi_state.tensor.get_example(1) + assert torch.allclose(second_graph.x, torch.tensor([[5.0], [6.0]])) + + +def test_clone(simple_graph_state): + """Test cloning a GraphStates object""" + cloned = simple_graph_state.clone() + + # Check that the clone has the same content + assert torch.equal(cloned.tensor.batch_shape, simple_graph_state.tensor.batch_shape) + assert torch.equal(cloned.tensor.x, simple_graph_state.tensor.x) + assert torch.equal(cloned.tensor.edge_index, simple_graph_state.tensor.edge_index) + assert torch.equal(cloned.tensor.edge_attr, simple_graph_state.tensor.edge_attr) + + # Modify the clone and check that the original is unchanged + cloned.tensor.x[0, 0] = 99.0 + assert cloned.tensor.x[0, 0] == 99.0 + assert simple_graph_state.tensor.x[0, 0] == 1.0 + + +def test_is_initial_state(): + """Test is_initial_state property""" + # Create a batch with s0 and a different graph + s0 = MyGraphStates.s0.clone() + different = Data( + x=torch.tensor([[5.0], [6.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.9]]) + ) + batch = Batch.from_data_list([s0, different]) + batch.batch_shape = torch.tensor([2]) + states = MyGraphStates(batch) + + # Check is_initial_state + is_initial = states.is_initial_state + assert is_initial[0] == True + assert is_initial[1] == False + + +def test_is_sink_state(): + """Test is_sink_state property""" + # Create a batch with sf and a different graph + sf = MyGraphStates.sf.clone() + different = Data( + x=torch.tensor([[5.0], [6.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.9]]) + ) + batch = Batch.from_data_list([sf, different]) + batch.batch_shape = torch.tensor([2]) + states = MyGraphStates(batch) + + # Check is_sink_state + is_sink = states.is_sink_state + assert is_sink[0] == True + assert is_sink[1] == False + + +def test_from_batch_shape(): + """Test creating states from batch shape""" + # Create states with initial state + states = MyGraphStates.from_batch_shape((3,)) + assert states.tensor.batch_shape[0] == 3 + assert states.tensor.num_nodes == 6 # 3 graphs * 2 nodes per graph + + # Check all graphs are s0 + is_initial = states.is_initial_state + assert torch.all(is_initial) + + # Create states with sink state + sink_states = MyGraphStates.from_batch_shape((2,), sink=True) + assert sink_states.tensor.batch_shape[0] == 2 + assert sink_states.tensor.num_nodes == 4 # 2 graphs * 2 nodes per graph + + # Check all graphs are sf + is_sink = sink_states.is_sink_state + assert torch.all(is_sink) + + +def test_forward_masks(): + """Test forward_masks property""" 
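+    # Layout assumed by the assertions below: a single 2-node graph whose only
+    # edge is 0 -> 1, so the one remaining addable edge is 1 -> 0.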
+ # Create a graph with 2 nodes and 1 edge + data = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch = Batch.from_data_list([data]) + batch.batch_shape = torch.tensor([1]) + states = MyGraphStates(batch) + + # Get forward masks + masks = states.forward_masks + + # Check action type mask + assert masks["action_type"].shape == (1, 3) + assert masks["action_type"][0, GraphActionType.ADD_NODE] == True # Can add node + assert masks["action_type"][0, GraphActionType.ADD_EDGE] == True # Can add edge (2 nodes) + assert masks["action_type"][0, GraphActionType.EXIT] == True # Can exit + + # Check features mask + assert masks["features"].shape == (1, 1) # 1 feature dimension + assert masks["features"][0, 0] == True # All features allowed + + # Check edge_index masks + assert len(masks["edge_index"]) == 1 # 1 graph + assert torch.all(masks["edge_index"][0] == torch.tensor([[False, False], [True, False]])) + + +def test_backward_masks(): + """Test backward_masks property""" + # Create a graph with 2 nodes and 1 edge + data = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch = Batch.from_data_list([data]) + batch.batch_shape = torch.tensor([1]) + states = MyGraphStates(batch) + + # Get backward masks + masks = states.backward_masks + + # Check action type mask + assert masks["action_type"].shape == (1, 3) + assert masks["action_type"][0, GraphActionType.ADD_NODE] == True # Can remove node + assert masks["action_type"][0, GraphActionType.ADD_EDGE] == True # Can remove edge + assert masks["action_type"][0, GraphActionType.EXIT] == True # Can exit + + # Check features mask + assert masks["features"].shape == (1, 1) # 1 feature dimension + assert masks["features"][0, 0] == True # All features allowed + + # Check edge_index masks + assert len(masks["edge_index"]) == 1 # 1 graph + assert torch.all(masks["edge_index"][0] == torch.tensor([[False, True], [False, False]])) + + +def test_stack(): + """Test stacking GraphStates objects""" + # Create two states + data1 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch1 = Batch.from_data_list([data1]) + batch1.batch_shape = torch.tensor([1]) + state1 = MyGraphStates(batch1) + + data2 = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]) + ) + batch2 = Batch.from_data_list([data2]) + batch2.batch_shape = torch.tensor([1]) + state2 = MyGraphStates(batch2) + + # Stack the states + stacked = MyGraphStates.stack([state1, state2]) + + # Check the batch shape + assert torch.equal(stacked.tensor.batch_shape, torch.tensor([2, 1])) + + # Check the number of nodes and edges + assert stacked.tensor.num_nodes == 4 # 2 states * 2 nodes + assert stacked.tensor.num_edges == 2 # 2 states * 1 edge + + # Check the batch indices + batch_indices = stacked.tensor.batch + assert torch.equal(batch_indices[:2], torch.zeros(2, dtype=torch.long)) + assert torch.equal(batch_indices[2:], torch.ones(2, dtype=torch.long)) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index df5fd643..2cb1cbb5 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -2,13 +2,19 @@ import pytest import torch +from tensordict import TensorDict +from torch import nn +from torch_geometric.nn import GCNConv +from gfn.actions import 
GraphActionType from gfn.containers import Trajectories from gfn.containers.replay_buffer import ReplayBuffer from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.graph_building import GraphBuilding from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP -from gfn.modules import DiscretePolicyEstimator, GFNModule +from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator from gfn.samplers import LocalSearchSampler, Sampler +from gfn.states import GraphStates from gfn.utils.modules import MLP from gfn.utils.prob_calculations import get_trajectory_pfs from gfn.utils.training import states_actions_tns_to_traj @@ -356,3 +362,65 @@ def test_states_actions_tns_to_traj(): trajs = states_actions_tns_to_traj(states, actions, env) replay_buffer.add(trajs) + + +# ------ GRAPH TESTS ------ + + +class GraphActionNet(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.feature_dim = feature_dim + self.action_type_conv = GCNConv(feature_dim, 3) + self.features_conv = GCNConv(feature_dim, feature_dim) + self.edge_index_conv = GCNConv(feature_dim, 8) + + def forward(self, states: GraphStates) -> TensorDict: + node_feature = states.tensor.x.reshape(-1, self.feature_dim) + + if states.tensor.x.shape[0] == 0: + action_type = torch.zeros((len(states), 3)) + action_type[:, GraphActionType.ADD_NODE] = 1 + features = torch.zeros((len(states), self.feature_dim)) + else: + action_type = self.action_type_conv(node_feature, states.tensor.edge_index) + action_type = action_type.reshape( + len(states), -1, action_type.shape[-1] + ).mean(dim=1) + features = self.features_conv(node_feature, states.tensor.edge_index) + features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) + + edge_index = self.edge_index_conv(node_feature, states.tensor.edge_index) + edge_index = torch.einsum("nf,mf->nm", edge_index, edge_index) + edge_index = edge_index[None].repeat(len(states), 1, 1) + + return TensorDict( + { + "action_type": action_type, + "features": features, + "edge_index": edge_index.reshape( + states.batch_shape + edge_index.shape[1:] + ), + }, + batch_size=states.batch_shape, + ) + + +def test_graph_building(): + feature_dim = 8 + env = GraphBuilding( + feature_dim=feature_dim, state_evaluator=lambda s: torch.zeros(s.batch_shape) + ) + + module = GraphActionNet(feature_dim) + pf_estimator = GraphActionPolicyEstimator(module=module) + + sampler = Sampler(estimator=pf_estimator) + trajectories = sampler.sample_trajectories( + env, + n=7, + save_logprobs=True, + save_estimator_outputs=False, + ) + + assert len(trajectories) == 7 diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py new file mode 100644 index 00000000..b7003f7b --- /dev/null +++ b/tutorials/examples/train_graph_ring.py @@ -0,0 +1,970 @@ +"""Train a GFlowNet to generate ring graphs. + +This example demonstrates training a GFlowNet to generate graphs that are rings - where each vertex +has exactly two neighbors and the edges form a single cycle containing all vertices. The environment +supports both directed and undirected ring generation. 
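+The number of nodes is fixed in advance, and the policy chooses at each step
+among all possible edges plus an EXIT action (see RingGraphBuilding below).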
+ +Key components: +- RingGraphBuilding: Environment for building ring graphs +- RingPolicyModule: GNN-based policy network for predicting actions +- directed_reward/undirected_reward: Reward functions for validating ring structures +""" + +import math +import time +from typing import Optional + +from matplotlib import patches +import matplotlib.pyplot as plt +import torch +from tensordict import TensorDict +from torch import nn +from torch_geometric.data import Batch, Data +from torch_geometric.nn import GINConv, GCNConv, DirGNNConv + +from gfn.actions import Actions, GraphActions, GraphActionType +from gfn.gflownet.trajectory_balance import TBGFlowNet +from gfn.gym import GraphBuilding +from gfn.modules import DiscretePolicyEstimator +from gfn.preprocessors import Preprocessor +from gfn.states import GraphStates +from gfn.utils.modules import MLP +from gfn.containers import ReplayBuffer + + +REW_VAL = 100.0 +EPS_VAL = 1e-6 + + +def directed_reward(states: GraphStates) -> torch.Tensor: + """Compute reward for directed ring graphs. + + This function evaluates if a graph forms a valid directed ring (directed cycle). + A valid directed ring must satisfy these conditions: + 1. Each node must have exactly one outgoing edge (row sum = 1 in adjacency matrix) + 2. Each node must have exactly one incoming edge (column sum = 1 in adjacency matrix) + 3. Following the edges must form a single cycle that includes all nodes + + The reward is binary: + - REW_VAL (100.0) for valid directed rings + - EPS_VAL (1e-6) for invalid structures + + Args: + states: A batch of graph states to evaluate + + Returns: + A tensor of rewards with the same batch shape as states + """ + if states.tensor.edge_index.numel() == 0: + return torch.full(states.batch_shape, EPS_VAL) + + out = torch.full((len(states),), EPS_VAL) # Default reward. + + for i in range(len(states)): + graph = states[i] + adj_matrix = torch.zeros(graph.tensor.num_nodes, graph.tensor.num_nodes) + adj_matrix[graph.tensor.edge_index[0], graph.tensor.edge_index[1]] = 1 + + # Check if each node has exactly one outgoing edge (row sum = 1) + if not torch.all(adj_matrix.sum(dim=1) == 1): + continue + + # Check that each node has exactly one incoming edge (column sum = 1) + if not torch.all(adj_matrix.sum(dim=0) == 1): + continue + + # Starting from node 0, follow edges and see if we visit all nodes + # and return to the start + visited = [] + current = 0 # Start from node 0 + + while current not in visited: + visited.append(current) + + # Get the outgoing neighbor + current = torch.where(adj_matrix[current] == 1)[0].item() + + # If we've visited all nodes and returned to 0, it's a valid ring + if len(visited) == graph.tensor.num_nodes and current == 0: + out[i] = REW_VAL + break + + return out.view(*states.batch_shape) + + +def undirected_reward(states: GraphStates) -> torch.Tensor: + """Compute reward for undirected ring graphs. + + This function evaluates if a graph forms a valid undirected ring (cycle). + A valid undirected ring must satisfy these conditions: + 1. Each node must have exactly two neighbors (degree = 2) + 2. The graph must form a single connected cycle including all nodes + + The reward is binary: + - REW_VAL (100.0) for valid undirected rings + - EPS_VAL (1e-6) for invalid structures + + The algorithm: + 1. Checks that all nodes have degree 2 + 2. Performs a traversal starting from node 0, following edges + 3. 
Checks if the traversal visits all nodes and returns to start + + Args: + states: A batch of graph states to evaluate + + Returns: + A tensor of rewards with the same batch shape as states + """ + if states.tensor.edge_index.numel() == 0: + return torch.full(states.batch_shape, EPS_VAL) + + out = torch.full((len(states),), EPS_VAL) # Default reward. + + for i in range(len(states)): + graph = states[i] + if graph.tensor.num_nodes == 0: + continue + adj_matrix = torch.zeros(graph.tensor.num_nodes, graph.tensor.num_nodes) + adj_matrix[graph.tensor.edge_index[0], graph.tensor.edge_index[1]] = 1 + adj_matrix[graph.tensor.edge_index[1], graph.tensor.edge_index[0]] = 1 + + # In an undirected ring, every vertex should have degree 2. + if not torch.all(adj_matrix.sum(dim=1) == 2): + continue + + # Traverse the cycle starting from vertex 0. + start_vertex = 0 + visited = [start_vertex] + neighbors = torch.where(adj_matrix[start_vertex] == 1)[0] + if neighbors.numel() == 0: + continue + # Arbitrarily choose one neighbor to begin the traversal. + current = neighbors[0].item() + prev = start_vertex + + while True: + if current == start_vertex: + break + visited.append(current) + current_neighbors = torch.where(adj_matrix[current] == 1)[0] + # Exclude the neighbor we just came from. + current_neighbors_list = [n.item() for n in current_neighbors] + possible = [n for n in current_neighbors_list if n != prev] + if len(possible) != 1: + break + next_node = possible[0] + prev, current = current, next_node + + if current == start_vertex and len(visited) == graph.tensor.num_nodes: + out[i] = REW_VAL + + return out.view(*states.batch_shape) + + +class RingPolicyModule(nn.Module): + """Simple module which outputs a fixed logits for the actions, depending on the number of edges. + + Args: + n_nodes: The number of nodes in the graph. + """ + + def __init__( + self, + n_nodes: int, + directed: bool, + num_conv_layers: int = 1, + embedding_dim: int = 128, + is_backward: bool = False, + ): + super().__init__() + self.hidden_dim = self.embedding_dim = embedding_dim + self.is_directed = directed + self.is_backward = is_backward + self.n_nodes = n_nodes + self.num_conv_layers = num_conv_layers + + # Node embedding layer. + self.embedding = nn.Embedding(n_nodes, self.embedding_dim) + self.conv_blks = nn.ModuleList() + self.exit_mlp = MLP( + input_dim=self.hidden_dim, + output_dim=1, + hidden_dim=self.hidden_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + if directed: + for i in range(num_conv_layers): + self.conv_blks.extend( + [ + DirGNNConv( + GCNConv( + self.embedding_dim if i == 0 else self.hidden_dim, + self.hidden_dim, + ), + alpha=0.5, + root_weight=True, + ), + # Process in/out components separately + nn.ModuleList( + [ + nn.Sequential( + nn.Linear( + self.hidden_dim // 2, self.hidden_dim // 2 + ), + nn.ReLU(), + nn.Linear( + self.hidden_dim // 2, self.hidden_dim // 2 + ), + ) + for _ in range( + 2 + ) # One for in-features, one for out-features. + ] + ), + ] + ) + else: # Undirected case. 
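+            # Sketch note: each undirected block pairs a GINConv (whose
+            # aggregator is an MLP) with a Linear -> ReLU -> Linear
+            # post-processor, mirroring the directed branch above.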
+            for _ in range(num_conv_layers):
+                self.conv_blks.extend(
+                    [
+                        GINConv(
+                            MLP(
+                                input_dim=self.embedding_dim,
+                                output_dim=self.hidden_dim,
+                                hidden_dim=self.hidden_dim,
+                                n_hidden_layers=1,
+                                add_layer_norm=True,
+                            ),
+                        ),
+                        nn.Sequential(
+                            nn.Linear(self.hidden_dim, self.hidden_dim),
+                            nn.ReLU(),
+                            nn.Linear(self.hidden_dim, self.hidden_dim),
+                        ),
+                    ]
+                )
+
+        self.norm = nn.LayerNorm(self.hidden_dim)
+
+    def _group_mean(
+        self, tensor: torch.Tensor, batch_ptr: torch.Tensor
+    ) -> torch.Tensor:
+        cumsum = torch.zeros(
+            (len(tensor) + 1, *tensor.shape[1:]),
+            dtype=tensor.dtype,
+            device=tensor.device,
+        )
+        cumsum[1:] = torch.cumsum(tensor, dim=0)
+
+        # Per-graph mean: subtract each graph's starting cumulative sum from
+        # its ending cumulative sum, then divide by the graph's node count.
+        size = batch_ptr[1:] - batch_ptr[:-1]
+        return (cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]) / size[:, None]
+
+    def forward(self, states_tensor: Batch) -> torch.Tensor:
+        node_features, batch_ptr = (
+            states_tensor.x,
+            states_tensor.ptr,
+        )
+        batch_size = int(torch.prod(states_tensor.batch_shape))
+
+        # Multiple action type convolutions with residual connections.
+        x = self.embedding(node_features.squeeze().int())
+        for i in range(0, len(self.conv_blks), 2):
+            x_new = self.conv_blks[i](x, states_tensor.edge_index)  # GIN/GCN conv.
+            if self.is_directed:
+                x_in, x_out = torch.chunk(x_new, 2, dim=-1)
+                # Process each component separately through its own MLP
+                x_in = self.conv_blks[i + 1][0](x_in)  # First MLP in ModuleList.
+                x_out = self.conv_blks[i + 1][1](x_out)  # Second MLP in ModuleList.
+                x_new = torch.cat([x_in, x_out], dim=-1)
+            else:
+                x_new = self.conv_blks[i + 1](x_new)  # Linear -> ReLU -> Linear.
+
+            x = x_new + x if i > 0 else x_new  # Residual connection.
+            x = self.norm(x)  # Layernorm.
+
+        # This MLP computes the exit action.
+        node_feature_means = self._group_mean(x, batch_ptr)
+        exit_action = self.exit_mlp(node_feature_means)
+
+        x = x.reshape(*states_tensor.batch_shape, self.n_nodes, self.hidden_dim)
+
+        # Directed case.
+        if self.is_directed:
+            feature_dim = self.hidden_dim // 2
+            source_features = x[..., :feature_dim]
+            target_features = x[..., feature_dim:]
+
+            # Dot product between source and target features (asymmetric).
+            edgewise_dot_prod = torch.einsum(
+                "bnf,bmf->bnm", source_features, target_features
+            )
+            edgewise_dot_prod = edgewise_dot_prod / torch.sqrt(
+                torch.tensor(feature_dim)
+            )
+
+            i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1)
+            i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1)
+
+            # Combine them.
+            i0 = torch.cat([i_up, i_lo])
+            i1 = torch.cat([j_up, j_lo])
+            out_size = self.n_nodes**2 - self.n_nodes
+
+        else:
+            # Dot product between all node features (symmetric).
+            edgewise_dot_prod = torch.einsum("bnf,bmf->bnm", x, x)
+            edgewise_dot_prod = edgewise_dot_prod / torch.sqrt(
+                torch.tensor(self.hidden_dim)
+            )
+            i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1)
+            out_size = (self.n_nodes**2 - self.n_nodes) // 2
+
+        # Grab the needed elems from the adjacency matrix and reshape.
+        edge_actions = edgewise_dot_prod[
+            torch.arange(batch_size)[:, None, None], i0, i1
+        ]
+        edge_actions = edge_actions.reshape(*states_tensor.batch_shape, out_size)
+
+        if self.is_backward:
+            return edge_actions
+        else:
+            return torch.cat([edge_actions, exit_action], dim=-1)
+
+
+class RingGraphBuilding(GraphBuilding):
+    """Environment for building ring graphs with discrete action space.
+ + This environment is specialized for creating ring graphs where each node has + exactly two neighbors and the edges form a single cycle. The environment supports + both directed and undirected graphs. + + In each state, the policy can: + 1. Add an edge between existing nodes + 2. Use the exit action to terminate graph building + + The action space is discrete, with size: + - For directed graphs: n_nodes^2 - n_nodes + 1 (all possible directed edges + exit) + - For undirected graphs: (n_nodes^2 - n_nodes)/2 + 1 (upper triangle + exit) + + Args: + n_nodes: The number of nodes in the graph. + state_evaluator: A function that evaluates a state and returns a reward. + directed: Whether the graph should be directed. + """ + + def __init__(self, n_nodes: int, state_evaluator: callable, directed: bool): + self.n_nodes = n_nodes + if directed: + # all off-diagonal edges + exit. + self.n_actions = (n_nodes**2 - n_nodes) + 1 + else: + # bottom triangle + exit. + self.n_actions = ((n_nodes**2 - n_nodes) // 2) + 1 + super().__init__(feature_dim=n_nodes, state_evaluator=state_evaluator) + self.is_discrete = True # actions here are discrete, needed for FlowMatching + self.is_directed = directed + + def make_actions_class(self) -> type[Actions]: + env = self + + class RingActions(Actions): + """Actions for building ring graphs. + + Actions are represented as discrete indices where: + - 0 to n_actions-2: Adding an edge between specific nodes + - n_actions-1: EXIT action to terminate the trajectory + - n_actions: DUMMY action (used for padding) + """ + + action_shape = (1,) + dummy_action = torch.tensor([env.n_actions]) + exit_action = torch.tensor([env.n_actions - 1]) + + return RingActions + + def make_states_class(self) -> type[GraphStates]: + env = self + + class RingStates(GraphStates): + """Represents the state of a ring graph building process. + + This class extends GraphStates to specifically handle ring graph states. + Each state represents a graph with a fixed number of nodes where edges + are being added incrementally to form a ring structure. + + The state representation consists of: + - node_feature: Node IDs for each node in the graph (shape: [n_nodes, 1]) + - edge_feature: Features for each edge (shape: [n_edges, 1]) + - edge_index: Indices representing the source and target nodes for each edge (shape: [n_edges, 2]) + + Special states: + - s0: Initial state with n_nodes and no edges + - sf: Terminal state (used as a placeholder) + + The class provides masks for both forward and backward actions to determine + which actions are valid from the current state. + """ + s0 = Data( + x=torch.arange(env.n_nodes)[:, None], + edge_attr=torch.ones((0, 1)), + edge_index=torch.ones((2, 0), dtype=torch.long), + ) + sf = Data( + x=-torch.ones(env.n_nodes)[:, None], + edge_attr=torch.zeros((0, 1)), + edge_index=torch.zeros((2, 0), dtype=torch.long), + ) + + def __init__(self, tensor: Batch): + self.tensor = tensor + self.node_features_dim = tensor.x.shape[-1] + self.edge_features_dim = tensor.edge_attr.shape[-1] + self._log_rewards: Optional[float] = None + + self.n_nodes = env.n_nodes + self.n_actions = env.n_actions + + @property + def forward_masks(self): + """Compute masks for valid forward actions from the current state. + + A forward action is valid if: + 1. The edge doesn't already exist in the graph + 2. The edge connects two distinct nodes + + For directed graphs, all possible src->dst edges are considered. + For undirected graphs, only the upper triangular portion of the adjacency matrix is used. 
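+
+                Example (illustrative, undirected, n_nodes=3): edge actions
+                0, 1, 2 correspond to node pairs (0, 1), (0, 2), (1, 2) in
+                `torch.triu_indices` order, and action 3 is EXIT.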
+ + The last action is always the EXIT action, which is always valid. + + Returns: + Tensor: Boolean mask of shape [batch_size, n_actions] where True indicates valid actions + """ + # Allow all actions. + forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) + + if env.is_directed: + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. + + # Combine them + ei0 = torch.cat([i_up, i_lo]) + ei1 = torch.cat([j_up, j_lo]) + else: + ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + + # Remove existing edges. + for i in range(len(self)): + existing_edges = self[i].tensor.edge_index + assert torch.all(existing_edges >= 0) # TODO: convert to test. + + if existing_edges.numel() == 0: + edge_idx = torch.zeros(0, dtype=torch.bool) + else: + edge_idx = torch.logical_and( + existing_edges[0][..., None] == ei0[None], + existing_edges[1][..., None] == ei1[None], + ) + + # Collapse across the edge dimension. + if len(edge_idx.shape) == 2: + edge_idx = edge_idx.sum(0).bool() + + # Adds an unmasked exit action. + edge_idx = torch.cat((edge_idx, torch.BoolTensor([False]))) + forward_masks[i, edge_idx] = ( + False # Disallow the addition of this edge. + ) + + return forward_masks.view(*self.batch_shape, self.n_actions) + + @forward_masks.setter + def forward_masks(self, value: torch.Tensor): + pass # fwd masks is computed on the fly + + @property + def backward_masks(self): + """Compute masks for valid backward actions from the current state. + + A backward action is valid if: + 1. The edge exists in the current graph (i.e., can be removed) + + For directed graphs, all existing edges are considered for removal. + For undirected graphs, only the upper triangular edges are considered. + + The EXIT action is not included in backward masks. + + Returns: + Tensor: Boolean mask of shape [batch_size, n_actions-1] where True indicates valid actions + """ + # Disallow all actions. + backward_masks = torch.zeros( + len(self), self.n_actions - 1, dtype=torch.bool + ) + + for i in range(len(self)): + existing_edges = self[i].tensor.edge_index + + if env.is_directed: + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. + + # Combine them + ei0 = torch.cat([i_up, i_lo]) + ei1 = torch.cat([j_up, j_lo]) + else: + ei0, ei1 = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) + + if len(existing_edges) == 0: + edge_idx = torch.zeros(0, dtype=torch.bool) + else: + edge_idx = torch.logical_and( + existing_edges[0][..., None] == ei0[None], + existing_edges[1][..., None] == ei1[None], + ) + # Collapse across the edge dimension. + if len(edge_idx.shape) == 2: + edge_idx = edge_idx.sum(0).bool() + + backward_masks[i, edge_idx] = ( + True # Allow the removal of this edge. + ) + + return backward_masks.view(*self.batch_shape, self.n_actions - 1) + + @backward_masks.setter + def backward_masks(self, value: torch.Tensor): + pass # bwd masks is computed on the fly + + return RingStates + + def _step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Take a step in the environment by applying actions to states. 
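+
+        Discrete ring actions are first translated to `GraphActions` via
+        `convert_actions`, then delegated to the parent environment's step.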
+
+        Args:
+            states: Current states batch
+            actions: Actions to apply
+
+        Returns:
+            New states after applying the actions
+        """
+        actions = self.convert_actions(states, actions)
+        new_states = super()._step(states, actions)
+        assert isinstance(new_states, GraphStates)
+        return new_states
+
+    def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates:
+        """Take a backward step in the environment.
+
+        Args:
+            states: Current states batch
+            actions: Actions to apply in reverse
+
+        Returns:
+            New states after applying the backward actions
+        """
+        actions = self.convert_actions(states, actions)
+        new_states = super()._backward_step(states, actions)
+        assert isinstance(new_states, GraphStates)
+        return new_states
+
+    def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions:
+        """Converts the action from discrete space to graph action space.
+
+        This method maps discrete action indices to specific graph operations:
+        - GraphActionType.ADD_EDGE: Add an edge between specific nodes
+        - GraphActionType.EXIT: Terminate trajectory
+        - GraphActionType.DUMMY: No-op action (for padding)
+
+        Args:
+            states: Current states batch
+            actions: Discrete actions to convert
+
+        Returns:
+            Equivalent actions in the GraphActions format
+        """
+        # TODO: factor out into utility function.
+        action_tensor = actions.tensor.squeeze(-1).clone()
+        action_type = torch.where(
+            action_tensor == self.n_actions - 1,
+            GraphActionType.EXIT,
+            GraphActionType.ADD_EDGE,
+        )
+        action_type[action_tensor == self.n_actions] = GraphActionType.DUMMY
+
+        # Convert action indices to source-target node pairs
+        # TODO: factor out into utility function.
+        if self.is_directed:
+            i_up, j_up = torch.triu_indices(
+                self.n_nodes, self.n_nodes, offset=1
+            )  # Upper triangle.
+            i_lo, j_lo = torch.tril_indices(
+                self.n_nodes, self.n_nodes, offset=-1
+            )  # Lower triangle.
+
+            # Combine them
+            ei0 = torch.cat([i_up, i_lo])
+            ei1 = torch.cat([j_up, j_lo])
+
+        else:
+            ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1)
+
+        # Adds -1 "edge" representing exit, -2 "edge" representing dummy.
+        ei0 = torch.cat((ei0, torch.IntTensor([-1, -2])), dim=0)
+        ei1 = torch.cat((ei1, torch.IntTensor([-1, -2])), dim=0)
+
+        # EXIT actions pick the second-to-last entry (the -1 sentinel);
+        # DUMMY actions pick the last entry (the -2 sentinel).
+        # action_tensor[action_tensor >= (self.n_actions - 1)] = 0
+        ei0, ei1 = ei0[action_tensor], ei1[action_tensor]
+
+        out = GraphActions(
+            TensorDict(
+                {
+                    "action_type": action_type,
+                    "features": torch.ones(action_tensor.shape + (1,)),
+                    "edge_index": torch.stack([ei0, ei1], dim=-1),
+                },
+                batch_size=action_tensor.shape,
+            )
+        )
+        return out
+
+
+class GraphPreprocessor(Preprocessor):
+    """Preprocessor for graph states to extract the tensor representation.
+
+    This simple preprocessor extracts the torch_geometric Batch from GraphStates to make
+    it compatible with the policy networks. It doesn't perform any complex
+    transformations, just ensuring the tensors are accessible in the right format.
+
+    Args:
+        feature_dim: The dimension of features in the graph (default: 1)
+    """
+
+    def __init__(self, feature_dim: int = 1):
+        super().__init__(output_dim=feature_dim)
+
+    def preprocess(self, states: GraphStates) -> Batch:
+        return states.tensor
+
+    def __call__(self, states: GraphStates) -> Batch:
+        return self.preprocess(states)
+
+
+def render_states(states: GraphStates, state_evaluator: callable, directed: bool):
+    """Visualize a batch of graph states as ring structures.
+
+def render_states(states: GraphStates, state_evaluator: callable, directed: bool):
+    """Visualize a batch of graph states as ring structures.
+
+    This function creates a matplotlib visualization of graph states, rendering them
+    as circular layouts with nodes positioned evenly around a circle. For directed
+    graphs, edges are shown as arrows; for undirected graphs, edges are shown as lines.
+
+    The visualization includes:
+    - Circular positioning of nodes
+    - Drawing edges between connected nodes
+    - Displaying the reward value for each graph
+
+    Args:
+        states: A batch of graphs to visualize
+        state_evaluator: Function to compute rewards for each graph
+        directed: Whether to render directed or undirected edges
+    """
+    rewards = state_evaluator(states)
+    fig, ax = plt.subplots(2, 4, figsize=(15, 7))
+    for i in range(8):
+        current_ax = ax[i // 4, i % 4]
+        state = states[i]
+        n_circles = state.tensor.num_nodes
+        radius = 5
+        xs, ys = [], []
+        for j in range(n_circles):
+            angle = 2 * math.pi * j / n_circles
+            x = radius * math.cos(angle)
+            y = radius * math.sin(angle)
+            xs.append(x)
+            ys.append(y)
+            current_ax.add_patch(
+                patches.Circle((x, y), 0.5, facecolor="none", edgecolor="black")
+            )
+
+        edge_index = state.tensor.edge_index
+
+        for edge in edge_index.T:
+            start_x, start_y = xs[edge[0]], ys[edge[0]]
+            end_x, end_y = xs[edge[1]], ys[edge[1]]
+            dx = end_x - start_x
+            dy = end_y - start_y
+            length = math.sqrt(dx**2 + dy**2)
+            dx, dy = dx / length, dy / length
+
+            circle_radius = 0.5
+            head_thickness = 0.2
+
+            # Offset the endpoints so edges start and end at circle borders.
+            start_x += dx * circle_radius
+            start_y += dy * circle_radius
+            if directed:
+                end_x -= dx * circle_radius
+                end_y -= dy * circle_radius
+                current_ax.arrow(
+                    start_x,
+                    start_y,
+                    end_x - start_x,
+                    end_y - start_y,
+                    head_width=head_thickness,
+                    head_length=head_thickness,
+                    fc="black",
+                    ec="black",
+                )
+
+            else:
+                end_x -= dx * (circle_radius + head_thickness)
+                end_y -= dy * (circle_radius + head_thickness)
+                current_ax.plot([start_x, end_x], [start_y, end_y], color="black")
+
+        current_ax.set_title(f"State {i}, $r={rewards[i]:.2f}$")
+        current_ax.set_xlim(-(radius + 1), radius + 1)
+        current_ax.set_ylim(-(radius + 1), radius + 1)
+        current_ax.set_aspect("equal")
+        current_ax.set_xticks([])
+        current_ax.set_yticks([])
+
+    plt.show()
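+
+
+# A small worked example of the layout math above (hypothetical helper, not
+# called anywhere): node j of n is placed at angle 2*pi*j/n, so for n = 4 the
+# positions land on the axes at (5, 0), (0, 5), (-5, 0) and (0, -5).
+def _ring_layout(n_nodes: int, radius: float = 5.0) -> list[tuple[float, float]]:
+    """Return the (x, y) position of each node on the rendering circle."""
+    return [
+        (radius * math.cos(2 * math.pi * j / n_nodes),
+         radius * math.sin(2 * math.pi * j / n_nodes))
+        for j in range(n_nodes)
+    ]
+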
Does not require complex graph neural network operations + + The module architecture consists of: + - An MLP to process the flattened adjacency matrix into an embedding + - An edge MLP that predicts logits for each possible edge action + - An exit MLP that predicts a logit for the exit action + + Args: + n_nodes: Number of nodes in the graph + directed: Whether the graph is directed or undirected + embedding_dim: Dimension of internal embeddings (default: 128) + is_backward: Whether this is a backward policy (default: False) + """ + + def __init__( + self, + n_nodes: int, + directed: bool, + embedding_dim: int = 128, + is_backward: bool = False, + ): + super().__init__() + self.n_nodes = n_nodes + self.is_directed = directed + self.is_backward = is_backward + self.hidden_dim = embedding_dim + + # MLP for processing the flattened adjacency matrix + self.mlp = MLP( + input_dim=n_nodes * n_nodes, # Flattened adjacency matrix + output_dim=embedding_dim, + hidden_dim=embedding_dim, + n_hidden_layers=2, + add_layer_norm=True, + ) + + # Exit action MLP + self.exit_mlp = MLP( + input_dim=embedding_dim, + output_dim=1, + hidden_dim=embedding_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + # Edge prediction MLP + if directed: + # For directed graphs: all off-diagonal elements + out_size = n_nodes**2 - n_nodes + else: + # For undirected graphs: upper triangle without diagonal + out_size = (n_nodes**2 - n_nodes) // 2 + + self.edge_mlp = MLP( + input_dim=embedding_dim, + output_dim=out_size, + hidden_dim=embedding_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + def forward(self, states_tensor: Batch) -> torch.Tensor: + """Forward pass to compute action logits from graph states. + + Process: + 1. Convert the graph representation to adjacency matrices + 2. Process the flattened adjacency matrices through the main MLP + 3. Predict logits for edge actions and exit action + + Args: + states_tensor: A torch_geometric Batch containing graph state information + + Returns: + A tensor of logits for all possible actions + """ + # Convert the graph to adjacency matrix + batch_size = int(states_tensor.batch_size) + adj_matrices = torch.zeros( + (batch_size, self.n_nodes, self.n_nodes), + device=states_tensor.x.device, + ) + + # Fill the adjacency matrices from edge indices + if states_tensor.edge_index.numel() > 0: + for i in range(batch_size): + eis = states_tensor[i].edge_index + adj_matrices[i, eis[0], eis[1]] = 1 + + # Flatten the adjacency matrices for the MLP + adj_matrices_flat = adj_matrices.view(batch_size, -1) + + # Process with MLP + embedding = self.mlp(adj_matrices_flat) + + # Generate edge and exit actions + edge_actions = self.edge_mlp(embedding) + exit_action = self.exit_mlp(embedding) + + if self.is_backward: + return edge_actions + else: + return torch.cat([edge_actions, exit_action], dim=-1) + + +if __name__ == "__main__": + """ + Main execution for training a GFlowNet to generate ring graphs. + + This script demonstrates the complete workflow of training a GFlowNet + to generate valid ring structures in both directed and undirected settings. 
+
+
+if __name__ == "__main__":
+    """
+    Main execution for training a GFlowNet to generate ring graphs.
+
+    This script demonstrates the complete workflow of training a GFlowNet
+    to generate valid ring structures in both directed and undirected settings.
+
+    Configurable parameters:
+    - N_NODES: Number of nodes in the graph (default: 4)
+    - N_ITERATIONS: Number of training iterations (default: 200)
+    - LR: Learning rate for the optimizer (default: 0.001)
+    - BATCH_SIZE: Batch size for training (default: 1024)
+    - DIRECTED: Whether to generate directed rings (default: True)
+    - USE_BUFFER: Whether to use a replay buffer (default: False)
+    - USE_GNN: Whether to use the GNN-based policy (True) or the MLP-based policy (False)
+
+    The script performs the following steps:
+    1. Initialize the environment and policy networks
+    2. Train the GFlowNet using trajectory balance
+    3. Visualize sample generated graphs
+    """
+    N_NODES = 4
+    N_ITERATIONS = 200
+    LR = 0.001
+    BATCH_SIZE = 1024
+    DIRECTED = True
+    USE_BUFFER = False
+    USE_GNN = True  # Set to False to use the MLP over adjacency matrices instead of the GNN.
+
+    state_evaluator = directed_reward if DIRECTED else undirected_reward
+    torch.random.manual_seed(7)
+    env = RingGraphBuilding(
+        n_nodes=N_NODES, state_evaluator=state_evaluator, directed=DIRECTED
+    )
+
+    # Choose the model type based on the USE_GNN flag.
+    if USE_GNN:
+        module_pf = RingPolicyModule(env.n_nodes, DIRECTED)
+        module_pb = RingPolicyModule(env.n_nodes, DIRECTED, is_backward=True)
+    else:
+        module_pf = AdjacencyPolicyModule(env.n_nodes, DIRECTED)
+        module_pb = AdjacencyPolicyModule(env.n_nodes, DIRECTED, is_backward=True)
+
+    pf = DiscretePolicyEstimator(
+        module=module_pf, n_actions=env.n_actions, preprocessor=GraphPreprocessor()
+    )
+    pb = DiscretePolicyEstimator(
+        module=module_pb,
+        n_actions=env.n_actions,
+        preprocessor=GraphPreprocessor(),
+        is_backward=True,
+    )
+    gflownet = TBGFlowNet(pf, pb)
+    optimizer = torch.optim.Adam(gflownet.parameters(), lr=LR)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
+
+    replay_buffer = ReplayBuffer(
+        env,
+        objects_type="trajectories",
+        capacity=BATCH_SIZE,
+        prioritized=True,
+    )
+
+    losses = []
+
+    t1 = time.time()
+    for iteration in range(N_ITERATIONS):
+        trajectories = gflownet.sample_trajectories(
+            env,
+            n=BATCH_SIZE,
+            save_logprobs=True,  # pyright: ignore
+            epsilon=0.2 * (1 - iteration / N_ITERATIONS),  # Anneal exploration.
+        )
+        training_samples = gflownet.to_training_samples(trajectories)
+
+        # Collect rewards for reporting.
+        if isinstance(training_samples, tuple):
+            last_states = training_samples[1]
+        else:
+            last_states = training_samples.last_states
+        assert isinstance(last_states, GraphStates)
+        rewards = state_evaluator(last_states)
+
+        if USE_BUFFER:
+            with torch.no_grad():
+                replay_buffer.add(training_samples)
+                if iteration > 20:
+                    training_samples = training_samples[: BATCH_SIZE // 2]
+                    buffer_samples = replay_buffer.sample(
+                        n_trajectories=BATCH_SIZE // 2
+                    )
+                    training_samples.extend(buffer_samples)
+
+        optimizer.zero_grad()
+        loss = gflownet.loss(env, training_samples)  # pyright: ignore
+        pct_rings = torch.mean(rewards > 0.1, dtype=torch.float) * 100
+        print(
+            "Iteration {} - Loss: {:.2f}, rings: {:.0f}%".format(
+                iteration, loss.item(), pct_rings
+            )
+        )
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+        losses.append(loss.item())
+
+    t2 = time.time()
+    print("Time:", t2 - t1)
+
+    # These trajectories come from the sampler, not the buffer.
+    last_states = trajectories.last_states[:8]
+    assert isinstance(last_states, GraphStates)
+    render_states(last_states, state_evaluator, DIRECTED)
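+
+    # The `losses` list collected above is otherwise unused; a minimal sketch
+    # for inspecting the training curve (matplotlib is already imported):
+    plt.figure()
+    plt.plot(losses)
+    plt.xlabel("Iteration")
+    plt.ylabel("TB loss")
+    plt.title("Training loss")
+    plt.show()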