From 9ae28b2419953b3bbe873cc7913d1d6447f8c240 Mon Sep 17 00:00:00 2001 From: alip67 Date: Wed, 6 Nov 2024 21:34:27 +0900 Subject: [PATCH 001/102] including Graphs as States for torchgfn --- src/gfn/states.py | 128 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 1 deletion(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index c95ac91d..a4a15f7b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,9 +3,10 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence +from typing import Callable, ClassVar, List, Optional, Sequence, cast, Tuple import torch +from torch_geometric.data import Batch, Data class States(ABC): @@ -501,3 +502,128 @@ def stack_states(states: List[States]): ) + state_example.batch_shape return stacked_states + +class GraphStates(ABC): + """ + Base class for Graph as a state representation. The `GraphStates` object is a batched collection of + multiple graph objects. The `Batch` object from PyTorch Geometric is used to represent the batch of + graph objects as states. + """ + + + s0: ClassVar[Data] + sf: ClassVar[Data] + node_feature_dim: ClassVar[int] + edge_feature_dim: ClassVar[int] + make_random_states_graph: Callable = lambda x: (_ for _ in ()).throw( + NotImplementedError( + "The environment does not support initialization of random Graph states." + ) + ) + + def __init__(self, graphs: Batch): + self.data: Batch = graphs + self.batch_shape: int = self.data.num_graphs + self._log_rewards: float = None + + @classmethod + def from_batch_shape(cls, batch_shape: int, random: bool = False, sink: bool=False) -> GraphStates: + if random and sink: + raise ValueError("Only one of `random` and `sink` should be True.") + if random: + data = cls.make_random_states_graph(batch_shape) + elif sink: + data = cls.make_sink_states_graph(batch_shape) + else: + data = cls.make_initial_states_graph(batch_shape) + return cls(data) + + @classmethod + def make_initial_states_graph(cls, batch_shape: int) -> Batch: + data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)]) + return data + + @classmethod + def make_sink_states_graph(cls, batch_shape: int) -> Batch: + data = Batch.from_data_list([cls.sf for _ in range(batch_shape)]) + return data + + # @classmethod + # def make_random_states_graph(cls, batch_shape: int) -> Batch: + # data = Batch.from_data_list([cls.make_random_states_graph() for _ in range(batch_shape)]) + # return data + + def __len__(self): + return self.data.batch_size + + def __repr__(self): + return (f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}") + + def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: + if isinstance(index, int): + out = self.__class__(Batch.from_data_list([self.data[index]])) + elif isinstance(index, (Sequence, slice)): + out = self.__class__(Batch.from_data_list(self.data.index_select(index))) + else: + raise NotImplementedError("Indexing with type {} is not implemented".format(type(index))) + + if self._log_rewards is not None: + out._log_rewards = self._log_rewards[index] + + return out + + def __setitem__(self, index: int | Sequence[int], graph: GraphStates): + """ + Set particular states of the Batch + """ + data_list = self.data.to_data_list() + if isinstance(index, int): + assert len(graph) == 1, "GraphStates must have a batch size of 1 for single index assignment" + 
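+            # A PyG Batch is a collated view over its graphs, so a single entry
+            # is replaced by editing the data list and re-collating it below.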
data_list[index] = graph.data[0] + self.data = Batch.from_data_list(data_list) + elif isinstance(index, Sequence): + assert len(index) == len(graph), "Index and GraphState must have the same length" + for i, idx in enumerate(index): + data_list[idx] = graph.data[i] + self.data = Batch.from_data_list(data_list) + elif isinstance(index, slice): + assert index.stop - index.start == len(graph), "Index slice and GraphStates must have the same length" + data_list[index] = graph.data.to_data_list() + self.data = Batch.from_data_list(data_list) + else: + raise NotImplementedError("Setters with type {} is not implemented".format(type(index))) + + @property + def device(self) -> torch.device: + return self.data.get_example(0).x.device + + def to(self, device: torch.device) -> GraphStates: + """ + Moves and/or casts the graph states to the specified device + """ + if self.device != device: + self.data = self.data.to(device) + return self + + def clone(self) -> GraphStates: + """Returns a *detached* clone of the current instance using deepcopy.""" + return deepcopy(self) + + def extend(self, other: GraphStates): + """Concatenates to another GraphStates object along the batch dimension""" + self.data = Batch.from_data_list(self.data.to_data_list() + other.data.to_data_list()) + if self._log_rewards is not None: + assert other._log_rewards is not None + self._log_rewards = torch.cat( + (self._log_rewards, other._log_rewards), dim=0 + ) + + + @property + def log_rewards(self) -> torch.Tensor: + return self._log_rewards + + @log_rewards.setter + def log_rewards(self, log_rewards: torch.Tensor) -> None: + self._log_rewards = log_rewards \ No newline at end of file From de6ab1c022a45b5e8462bf283ea30eec19c8e337 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 7 Nov 2024 21:50:10 +0100 Subject: [PATCH 002/102] add GraphEnv --- src/gfn/env.py | 107 +++++++++++++++++++++++++++++- src/gfn/gym/graph_building.py | 118 ++++++++++++++++++++++++++++++++++ src/gfn/states.py | 41 ++++++++---- 3 files changed, 252 insertions(+), 14 deletions(-) create mode 100644 src/gfn/gym/graph_building.py diff --git a/src/gfn/env.py b/src/gfn/env.py index 7a60e8ec..517fa97c 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,10 +2,11 @@ from typing import Optional, Tuple, Union import torch +from torch_geometric.data import Batch, Data from gfn.actions import Actions from gfn.preprocessors import IdentityPreprocessor, Preprocessor -from gfn.states import DiscreteStates, States +from gfn.states import DiscreteStates, GraphStates, States from gfn.utils.common import set_seed # Errors @@ -559,3 +560,107 @@ def terminating_states(self) -> DiscreteStates: raise NotImplementedError( "The environment does not support enumeration of states" ) + + +class GraphEnv(Env): + """Base class for graph-based environments.""" + + def __init__( + self, + s0: Data, + node_feature_dim: int, + edge_feature_dim: int, + action_shape: Tuple, + dummy_action: torch.Tensor, + exit_action: torch.Tensor, + sf: Optional[Data] = None, + device_str: Optional[str] = None, + preprocessor: Optional[Preprocessor] = None, + ): + """Initializes a graph-based environment. + + Args: + s0: The initial graph state. + node_feature_dim: The dimension of the node features. + edge_feature_dim: The dimension of the edge features. + action_shape: Tuple representing the shape of the actions. + dummy_action: Tensor of shape "action_shape" representing a dummy action. + exit_action: Tensor of shape "action_shape" representing the exit action. + sf: The final graph state. 
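+                Defaults to None, in which case a sink graph whose node and edge
+                features are filled with -inf (matching s0's topology) is built.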
+ device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is + inferred from s0. + preprocessor: a Preprocessor object that converts raw graph states to a tensor + that can be fed into a neural network. Defaults to None, in which case + the IdentityPreprocessor is used. + """ + self.device = get_device(device_str, default_device=s0.x.device) + + if sf is None: + sf = Data( + x=torch.full((s0.num_nodes, node_feature_dim), -float("inf")).to( + self.device + ), + edge_attr=torch.full( + (s0.num_edges, edge_feature_dim), -float("inf") + ).to(self.device), + edge_index=s0.edge_index, + batch=torch.zeros(s0.num_nodes, dtype=torch.long, device=self.device), + ) + + super().__init__( + s0=s0, + state_shape=(s0.num_nodes, node_feature_dim), + action_shape=action_shape, + dummy_action=dummy_action, + exit_action=exit_action, + sf=sf, + device_str=device_str, + preprocessor=preprocessor, + ) + + self.node_feature_dim = node_feature_dim + self.edge_feature_dim = edge_feature_dim + self.GraphStates = self.make_graph_states_class() + + def make_graph_states_class(self) -> type[GraphStates]: + env = self + + class GraphEnvStates(GraphStates): + s0 = env.s0 + sf = env.sf + node_feature_dim = env.node_feature_dim + edge_feature_dim = env.edge_feature_dim + make_random_states_graph = env.make_random_states_graph + + return GraphEnvStates + + def states_from_tensor(self, tensor: Batch) -> GraphStates: + """Wraps the supplied Batch in a GraphStates instance.""" + return self.GraphStates(tensor) + + def states_from_batch_shape(self, batch_shape: int) -> GraphStates: + """Returns a batch of s0 states with a given batch_shape.""" + return self.GraphStates.from_batch_shape(batch_shape) + + @abstractmethod + def step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Function that takes a batch of graph states and actions and returns a batch of next + graph states.""" + + @abstractmethod + def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Function that takes a batch of graph states and actions and returns a batch of previous + graph states.""" + + @abstractmethod + def is_action_valid( + self, + states: GraphStates, + actions: Actions, + backward: bool = False, + ) -> bool: + """Returns True if the actions are valid in the given graph states.""" + + @abstractmethod + def make_random_states_graph(self, batch_shape: int) -> Batch: + """Optional method inherited by all GraphStates instances to emit a random Batch of graphs.""" diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py new file mode 100644 index 00000000..99848870 --- /dev/null +++ b/src/gfn/gym/graph_building.py @@ -0,0 +1,118 @@ +from copy import copy +from typing import Callable, Literal, Tuple + +import torch +from gfn.actions import Actions +from torch_geometric.data import Data, Batch +from torch_geometric.nn import GCNConv +from gfn.env import GraphEnv +from gfn.states import GraphStates + + +class GraphBuilding(GraphEnv): + + def __init__(self, + num_nodes: int, + node_feature_dim: int, + edge_feature_dim: int, + state_evaluator: Callable[[Batch], torch.Tensor] | None = None, + device_str: Literal["cpu", "cuda"] = "cpu" + ): + s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str)) + exit_action = torch.tensor( + [-float("inf"), -float("inf")], device=torch.device(device_str) + ) + dummy_action = torch.tensor( + [float("inf"), float("inf")], device=torch.device(device_str) + ) + if state_evaluator is None: + state_evaluator = 
GCNConvEvaluator(node_feature_dim) + self.state_evaluator = state_evaluator + + super().__init__( + s0=s0, + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + action_shape=(2,), + dummy_action=dummy_action, + exit_action=exit_action, + device_str=device_str, + ) + + + def step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Step function for the GraphBuilding environment. + + Args: + states: GraphStates object representing the current graph. + actions: Actions indicating which edge to add. + + Returns the next graph the new GraphStates. + """ + graphs: Batch = copy.deepcopy(states.data) + assert len(graphs) == len(actions) + + for i, act in enumerate(actions.tensor): + edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1) + graphs[i].edge_index = edge_index + + return GraphStates(graphs) + + def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Backward step function for the GraphBuilding environment. + + Args: + states: GraphStates object representing the current graph. + actions: Actions indicating which edge to remove. + + Returns the previous graph as a new GraphStates. + """ + graphs: Batch = copy.deepcopy(states.data) + assert len(graphs) == len(actions) + + for i, act in enumerate(actions.tensor): + edge_index = graphs[i].edge_index + edge_index = edge_index[:, edge_index[1] != act] + graphs[i].edge_index = edge_index + + return GraphStates(graphs) + + def is_action_valid( + self, states: GraphStates, actions: Actions, backward: bool = False + ) -> bool: + for i, act in enumerate(actions.tensor): + if backward and len(states.data[i].edge_index[1]) == 0: + return False + if not backward and torch.any(states.data[i].edge_index[1] == act): + return False + return True + + def reward(self, final_states: GraphStates) -> torch.Tensor: + """The environment's reward given a state. + This or log_reward must be implemented. + + Args: + final_states: A batch of final states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the rewards. + """ + return self.state_evaluator(final_states.data).sum(dim=1) + + @property + def log_partition(self) -> float: + "Returns the logarithm of the partition function." + raise NotImplementedError + + @property + def true_dist_pmf(self) -> torch.Tensor: + "Returns a one-dimensional tensor representing the true distribution." + raise NotImplementedError + + +class GCNConvEvaluator: + def __init__(self, num_features): + self.net = GCNConv(num_features, 1) + + def __call__(self, batch: Batch) -> torch.Tensor: + return self.net(batch.x, batch.edge_index) \ No newline at end of file diff --git a/src/gfn/states.py b/src/gfn/states.py index a4a15f7b..8e6d513f 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence, cast, Tuple +from typing import Callable, ClassVar, List, Optional, Sequence import torch from torch_geometric.data import Batch, Data @@ -503,6 +503,7 @@ def stack_states(states: List[States]): return stacked_states + class GraphStates(ABC): """ Base class for Graph as a state representation. The `GraphStates` object is a batched collection of @@ -510,7 +511,6 @@ class GraphStates(ABC): graph objects as states. 
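     The batch dimension is a single integer: the number of graphs in the Batch.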
""" - s0: ClassVar[Data] sf: ClassVar[Data] node_feature_dim: ClassVar[int] @@ -527,7 +527,9 @@ def __init__(self, graphs: Batch): self._log_rewards: float = None @classmethod - def from_batch_shape(cls, batch_shape: int, random: bool = False, sink: bool=False) -> GraphStates: + def from_batch_shape( + cls, batch_shape: int, random: bool = False, sink: bool = False + ) -> GraphStates: if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") if random: @@ -557,8 +559,10 @@ def __len__(self): return self.data.batch_size def __repr__(self): - return (f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " - f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}") + return ( + f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" + ) def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: if isinstance(index, int): @@ -566,7 +570,9 @@ def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: elif isinstance(index, (Sequence, slice)): out = self.__class__(Batch.from_data_list(self.data.index_select(index))) else: - raise NotImplementedError("Indexing with type {} is not implemented".format(type(index))) + raise NotImplementedError( + "Indexing with type {} is not implemented".format(type(index)) + ) if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -579,20 +585,28 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ data_list = self.data.to_data_list() if isinstance(index, int): - assert len(graph) == 1, "GraphStates must have a batch size of 1 for single index assignment" + assert ( + len(graph) == 1 + ), "GraphStates must have a batch size of 1 for single index assignment" data_list[index] = graph.data[0] self.data = Batch.from_data_list(data_list) elif isinstance(index, Sequence): - assert len(index) == len(graph), "Index and GraphState must have the same length" + assert len(index) == len( + graph + ), "Index and GraphState must have the same length" for i, idx in enumerate(index): data_list[idx] = graph.data[i] self.data = Batch.from_data_list(data_list) elif isinstance(index, slice): - assert index.stop - index.start == len(graph), "Index slice and GraphStates must have the same length" + assert index.stop - index.start == len( + graph + ), "Index slice and GraphStates must have the same length" data_list[index] = graph.data.to_data_list() self.data = Batch.from_data_list(data_list) else: - raise NotImplementedError("Setters with type {} is not implemented".format(type(index))) + raise NotImplementedError( + "Setters with type {} is not implemented".format(type(index)) + ) @property def device(self) -> torch.device: @@ -612,18 +626,19 @@ def clone(self) -> GraphStates: def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" - self.data = Batch.from_data_list(self.data.to_data_list() + other.data.to_data_list()) + self.data = Batch.from_data_list( + self.data.to_data_list() + other.data.to_data_list() + ) if self._log_rewards is not None: assert other._log_rewards is not None self._log_rewards = torch.cat( (self._log_rewards, other._log_rewards), dim=0 ) - @property def log_rewards(self) -> torch.Tensor: return self._log_rewards @log_rewards.setter def log_rewards(self, log_rewards: torch.Tensor) -> None: - self._log_rewards = log_rewards \ No newline at 
end of file + self._log_rewards = log_rewards From 24e23e8a08a172982b2cd1cd5dc3ed762af69e5a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 7 Nov 2024 23:43:15 +0100 Subject: [PATCH 003/102] add deps and reformat --- pyproject.toml | 1 + src/gfn/gym/graph_building.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0523821a..53f7f768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ einops = ">=0.6.1" numpy = ">=1.21.2" python = "^3.10" torch = ">=1.9.0" +torch_geometric = ">=2.6.0" # dev dependencies. black = { version = "24.3", optional = true } diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 99848870..39f6d8d0 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -1,22 +1,23 @@ from copy import copy -from typing import Callable, Literal, Tuple +from typing import Callable, Literal import torch -from gfn.actions import Actions -from torch_geometric.data import Data, Batch +from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv + +from gfn.actions import Actions from gfn.env import GraphEnv from gfn.states import GraphStates class GraphBuilding(GraphEnv): - - def __init__(self, + def __init__( + self, num_nodes: int, node_feature_dim: int, edge_feature_dim: int, state_evaluator: Callable[[Batch], torch.Tensor] | None = None, - device_str: Literal["cpu", "cuda"] = "cpu" + device_str: Literal["cpu", "cuda"] = "cpu", ): s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str)) exit_action = torch.tensor( @@ -32,14 +33,13 @@ def __init__(self, super().__init__( s0=s0, node_feature_dim=node_feature_dim, - edge_feature_dim=edge_feature_dim, + edge_feature_dim=edge_feature_dim, action_shape=(2,), dummy_action=dummy_action, exit_action=exit_action, device_str=device_str, ) - def step(self, states: GraphStates, actions: Actions) -> GraphStates: """Step function for the GraphBuilding environment. @@ -55,7 +55,7 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: for i, act in enumerate(actions.tensor): edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1) graphs[i].edge_index = edge_index - + return GraphStates(graphs) def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: @@ -74,7 +74,7 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: edge_index = graphs[i].edge_index edge_index = edge_index[:, edge_index[1] != act] graphs[i].edge_index = edge_index - + return GraphStates(graphs) def is_action_valid( @@ -86,7 +86,7 @@ def is_action_valid( if not backward and torch.any(states.data[i].edge_index[1] == act): return False return True - + def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. This or log_reward must be implemented. 
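For context: the `state_evaluator` argument above accepts any callable mapping a PyG `Batch` to one reward value per graph. A minimal sketch of a custom evaluator, assuming only the public PyG API (the name `degree_evaluator` is illustrative, not part of this patch):

    import torch
    from torch_geometric.data import Batch

    def degree_evaluator(batch: Batch) -> torch.Tensor:
        # One reward per graph: count each graph's edges by mapping every
        # edge's source node to its graph via the batch assignment vector.
        src_graph = batch.batch[batch.edge_index[0]]
        return torch.bincount(src_graph, minlength=batch.num_graphs).float()

    # e.g.: GraphBuilding(num_nodes=4, node_feature_dim=8, edge_feature_dim=8,
    #                     state_evaluator=degree_evaluator)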
@@ -115,4 +115,4 @@ def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - return self.net(batch.x, batch.edge_index) \ No newline at end of file + return self.net(batch.x, batch.edge_index) From 1f7b220a22f2e77eb721d5f7b57caed4d8166b30 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 8 Nov 2024 15:10:33 +0100 Subject: [PATCH 004/102] add test, fix errors, add valid action check --- src/gfn/env.py | 67 ++++++++++++----------------------- src/gfn/gym/graph_building.py | 62 +++++++++++++++++++++----------- src/gfn/states.py | 50 +++++++++++++++++--------- testing/test_environments.py | 38 +++++++++++++++++++- 4 files changed, 134 insertions(+), 83 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 517fa97c..3d61a06a 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -594,35 +594,33 @@ def __init__( the IdentityPreprocessor is used. """ self.device = get_device(device_str, default_device=s0.x.device) - + self.s0 = s0.to(self.device) + + self.node_feature_dim = node_feature_dim + self.edge_feature_dim = edge_feature_dim + self.state_shape = (s0.num_nodes, self.node_feature_dim) + assert s0.x.shape == self.state_shape + if sf is None: sf = Data( - x=torch.full((s0.num_nodes, node_feature_dim), -float("inf")).to( - self.device - ), - edge_attr=torch.full( - (s0.num_edges, edge_feature_dim), -float("inf") - ).to(self.device), + x=torch.full(self.state_shape, -float("inf")), + edge_attr=torch.full((s0.num_edges, edge_feature_dim), -float("inf")), edge_index=s0.edge_index, - batch=torch.zeros(s0.num_nodes, dtype=torch.long, device=self.device), - ) + ).to(self.device) + self.sf: torch.Tensor = sf + assert self.sf.x.shape == self.state_shape - super().__init__( - s0=s0, - state_shape=(s0.num_nodes, node_feature_dim), - action_shape=action_shape, - dummy_action=dummy_action, - exit_action=exit_action, - sf=sf, - device_str=device_str, - preprocessor=preprocessor, - ) + self.action_shape = action_shape + self.dummy_action = dummy_action + self.exit_action = exit_action - self.node_feature_dim = node_feature_dim - self.edge_feature_dim = edge_feature_dim - self.GraphStates = self.make_graph_states_class() + self.States = self.make_states_class() + self.Actions = self.make_actions_class() - def make_graph_states_class(self) -> type[GraphStates]: + self.preprocessor = preprocessor + self.is_discrete = False + + def make_states_class(self) -> type[GraphStates]: env = self class GraphEnvStates(GraphStates): @@ -630,18 +628,10 @@ class GraphEnvStates(GraphStates): sf = env.sf node_feature_dim = env.node_feature_dim edge_feature_dim = env.edge_feature_dim - make_random_states_graph = env.make_random_states_graph + make_random_states_graph = env.make_random_states_tensor return GraphEnvStates - def states_from_tensor(self, tensor: Batch) -> GraphStates: - """Wraps the supplied Batch in a GraphStates instance.""" - return self.GraphStates(tensor) - - def states_from_batch_shape(self, batch_shape: int) -> GraphStates: - """Returns a batch of s0 states with a given batch_shape.""" - return self.GraphStates.from_batch_shape(batch_shape) - @abstractmethod def step(self, states: GraphStates, actions: Actions) -> GraphStates: """Function that takes a batch of graph states and actions and returns a batch of next @@ -651,16 +641,3 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: """Function that takes a batch of graph states and 
actions and returns a batch of previous graph states.""" - - @abstractmethod - def is_action_valid( - self, - states: GraphStates, - actions: Actions, - backward: bool = False, - ) -> bool: - """Returns True if the actions are valid in the given graph states.""" - - @abstractmethod - def make_random_states_graph(self, batch_shape: int) -> Batch: - """Optional method inherited by all GraphStates instances to emit a random Batch of graphs.""" diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 39f6d8d0..611a28fd 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -1,12 +1,12 @@ -from copy import copy -from typing import Callable, Literal +from copy import deepcopy +from typing import Callable, Literal, Tuple import torch from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv from gfn.actions import Actions -from gfn.env import GraphEnv +from gfn.env import GraphEnv, NonValidActionsError from gfn.states import GraphStates @@ -19,7 +19,10 @@ def __init__( state_evaluator: Callable[[Batch], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str)) + s0 = Data( + x=torch.zeros((num_nodes, node_feature_dim)), + edge_index=torch.zeros((2, 0), dtype=torch.long), + ).to(device_str) exit_action = torch.tensor( [-float("inf"), -float("inf")], device=torch.device(device_str) ) @@ -49,14 +52,14 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: Returns the next graph the new GraphStates. """ - graphs: Batch = copy.deepcopy(states.data) + if not self.is_action_valid(states, actions): + raise NonValidActionsError("Invalid action.") + graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) - for i, act in enumerate(actions.tensor): - edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1) - graphs[i].edge_index = edge_index - - return GraphStates(graphs) + edge_index = torch.cat([graphs.edge_index, actions.tensor.T], dim=1) + graphs.edge_index = edge_index + return self.States(graphs) def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -67,7 +70,9 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: Returns the previous graph as a new GraphStates. 
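         Action validity is checked first; each edge is then removed by masking
         the matching columns out of the graph's edge_index.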
""" - graphs: Batch = copy.deepcopy(states.data) + if not self.is_action_valid(states, actions, backward=True): + raise NonValidActionsError("Invalid action.") + graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) for i, act in enumerate(actions.tensor): @@ -75,17 +80,29 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: edge_index = edge_index[:, edge_index[1] != act] graphs[i].edge_index = edge_index - return GraphStates(graphs) + return self.States(graphs) def is_action_valid( self, states: GraphStates, actions: Actions, backward: bool = False ) -> bool: - for i, act in enumerate(actions.tensor): - if backward and len(states.data[i].edge_index[1]) == 0: - return False - if not backward and torch.any(states.data[i].edge_index[1] == act): - return False - return True + current_edges = states.data.edge_index + new_edges = actions.tensor + + if torch.any(new_edges[:, 0] == new_edges[:, 1]): + return False + if current_edges.shape[1] == 0: + return not backward + + if backward: + some_edges_not_exist = torch.any( + torch.all(current_edges[:, None, :] != new_edges.T[:, :, None], dim=0) + ) + return not some_edges_not_exist + else: + some_edges_exist = torch.any( + torch.all(current_edges[:, None, :] == new_edges.T[:, :, None], dim=0) + ) + return not some_edges_exist def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. @@ -97,7 +114,9 @@ def reward(self, final_states: GraphStates) -> torch.Tensor: Returns: torch.Tensor: Tensor of shape "batch_shape" containing the rewards. """ - return self.state_evaluator(final_states.data).sum(dim=1) + per_node_rew = self.state_evaluator(final_states.data) + node_batch_idx = final_states.data.batch + return torch.bincount(node_batch_idx, weights=per_node_rew) @property def log_partition(self) -> float: @@ -109,10 +128,13 @@ def true_dist_pmf(self) -> torch.Tensor: "Returns a one-dimensional tensor representing the true distribution." 
raise NotImplementedError + def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: + """Generates random states tensor of shape (*batch_shape, num_nodes, node_feature_dim).""" + return self.States.from_batch_shape(batch_shape) class GCNConvEvaluator: def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - return self.net(batch.x, batch.edge_index) + return self.net(batch.x, batch.edge_index).squeeze(-1) diff --git a/src/gfn/states.py b/src/gfn/states.py index 8e6d513f..6a94fe84 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence +from typing import Callable, ClassVar, List, Optional, Sequence, Tuple import torch from torch_geometric.data import Batch, Data @@ -523,7 +523,8 @@ class GraphStates(ABC): def __init__(self, graphs: Batch): self.data: Batch = graphs - self.batch_shape: int = self.data.num_graphs + self.batch_shape: int = len(self.data) + self.state_shape = (self.data.get_example(0).num_nodes, self.node_feature_dim) self._log_rewards: float = None @classmethod @@ -541,19 +542,41 @@ def from_batch_shape( return cls(data) @classmethod - def make_initial_states_graph(cls, batch_shape: int) -> Batch: + def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: + raise NotImplementedError("Batch shape with more than one dimension is not supported") + if isinstance(batch_shape, Tuple): + batch_shape = batch_shape[0] + data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)]) return data @classmethod - def make_sink_states_graph(cls, batch_shape: int) -> Batch: + def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: + raise NotImplementedError("Batch shape with more than one dimension is not supported") + if isinstance(batch_shape, Tuple): + batch_shape = batch_shape[0] + data = Batch.from_data_list([cls.sf for _ in range(batch_shape)]) return data - # @classmethod - # def make_random_states_graph(cls, batch_shape: int) -> Batch: - # data = Batch.from_data_list([cls.make_random_states_graph() for _ in range(batch_shape)]) - # return data + @classmethod + def make_random_states_graph(cls, batch_shape: int) -> Batch: + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: + raise NotImplementedError("Batch shape with more than one dimension is not supported") + if isinstance(batch_shape, Tuple): + batch_shape = batch_shape[0] + + data_list = [] + for _ in range(batch_shape): + data = Data( + x=torch.rand(cls.s0.num_nodes, cls.node_feature_dim), + edge_attr=torch.rand(cls.s0.num_edges, cls.edge_feature_dim), + edge_index=cls.s0.edge_index, # TODO: make it random + ) + data_list.append(data) + return Batch.from_data_list(data_list) def __len__(self): return self.data.batch_size @@ -564,15 +587,8 @@ def __repr__(self): f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" ) - def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: - if isinstance(index, int): - out = self.__class__(Batch.from_data_list([self.data[index]])) - elif isinstance(index, (Sequence, slice)): - out = self.__class__(Batch.from_data_list(self.data.index_select(index))) - else: - raise NotImplementedError( - "Indexing with type {} is not implemented".format(type(index)) - ) + def 
__getitem__(self, index: int | Sequence[int] | slice | torch.Tensor) -> GraphStates: + out = self.__class__(Batch(self.data[index])) if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] diff --git a/testing/test_environments.py b/testing/test_environments.py index 5dbd4cc6..2786d61a 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -4,6 +4,7 @@ from gfn.env import NonValidActionsError from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.graph_building import GraphBuilding # Utilities. @@ -273,7 +274,7 @@ def test_states_getitem(ndim: int, env_name: str): states = env.reset(batch_shape=ND_BATCH_SHAPE, random=True) # Boolean selector to index batch elements. - selections = torch.randint(0, 2, ND_BATCH_SHAPE, dtype=torch.bool) + selections = torch.randint(0, 2,ND_BATCH_SHAPE, dtype=torch.bool) n_selections = int(torch.sum(selections)) selected_states = states[selections] @@ -316,3 +317,38 @@ def test_get_grid(): # State indices of the grid are ordered from 0:HEIGHT**2. assert (env.get_states_indices(grid).ravel() == torch.arange(HEIGHT**2)).all() + + +def test_graph_env(): + NUM_NODES = 4 + FEATURE_DIM = 8 + BATCH_SIZE = 3 + + env = GraphBuilding(num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM) + states = env.reset(batch_shape=BATCH_SIZE) + assert states.batch_shape == BATCH_SIZE + assert states.state_shape == (NUM_NODES, FEATURE_DIM) + + actions_traj = torch.tensor([ + [[0, 1], [1, 2], [2, 3]], + [[0, 2], [1, 3], [2, 4]], + [[0, 3], [1, 4], [2, 5]], + [[0, 4], [1, 5], [2, 6]], + [[0, 5], [1, 6], [2, 7]], + ], dtype=torch.long) + + for action_tensor in actions_traj: + actions = env.actions_from_tensor(action_tensor) + states = env.step(states, actions) + + invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) + actions = env.actions_from_tensor(invalid_actions) + with pytest.raises(NonValidActionsError): + states = env.step(states, actions) + invalid_actions = torch.tensor(actions_traj[0]) + actions = env.actions_from_tensor(invalid_actions) + with pytest.raises(NonValidActionsError): + states = env.step(states, actions) + + expected_rewards = torch.zeros(BATCH_SIZE) + assert (env.reward(states) == expected_rewards).all() \ No newline at end of file From 63e4f1cb0eee6c17464d7844aedf9bfa31c33aa7 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 8 Nov 2024 15:13:28 +0100 Subject: [PATCH 005/102] fix formatting --- src/gfn/env.py | 4 ++-- src/gfn/gym/graph_building.py | 3 ++- src/gfn/states.py | 21 ++++++++++++--------- testing/test_environments.py | 29 +++++++++++++++++------------ 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 3d61a06a..3c86a3de 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -595,12 +595,12 @@ def __init__( """ self.device = get_device(device_str, default_device=s0.x.device) self.s0 = s0.to(self.device) - + self.node_feature_dim = node_feature_dim self.edge_feature_dim = edge_feature_dim self.state_shape = (s0.num_nodes, self.node_feature_dim) assert s0.x.shape == self.state_shape - + if sf is None: sf = Data( x=torch.full(self.state_shape, -float("inf")), diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 611a28fd..8b9dcc59 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -97,7 +97,7 @@ def is_action_valid( some_edges_not_exist = torch.any( torch.all(current_edges[:, None, :] != new_edges.T[:, :, None], dim=0) ) - return not some_edges_not_exist + return 
not some_edges_not_exist else: some_edges_exist = torch.any( torch.all(current_edges[:, None, :] == new_edges.T[:, :, None], dim=0) @@ -132,6 +132,7 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: """Generates random states tensor of shape (*batch_shape, num_nodes, node_feature_dim).""" return self.States.from_batch_shape(batch_shape) + class GCNConvEvaluator: def __init__(self, num_features): self.net = GCNConv(num_features, 1) diff --git a/src/gfn/states.py b/src/gfn/states.py index 6a94fe84..513f814b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -515,11 +515,6 @@ class GraphStates(ABC): sf: ClassVar[Data] node_feature_dim: ClassVar[int] edge_feature_dim: ClassVar[int] - make_random_states_graph: Callable = lambda x: (_ for _ in ()).throw( - NotImplementedError( - "The environment does not support initialization of random Graph states." - ) - ) def __init__(self, graphs: Batch): self.data: Batch = graphs @@ -544,7 +539,9 @@ def from_batch_shape( @classmethod def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError("Batch shape with more than one dimension is not supported") + raise NotImplementedError( + "Batch shape with more than one dimension is not supported" + ) if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] @@ -554,7 +551,9 @@ def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: @classmethod def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError("Batch shape with more than one dimension is not supported") + raise NotImplementedError( + "Batch shape with more than one dimension is not supported" + ) if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] @@ -564,7 +563,9 @@ def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: @classmethod def make_random_states_graph(cls, batch_shape: int) -> Batch: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError("Batch shape with more than one dimension is not supported") + raise NotImplementedError( + "Batch shape with more than one dimension is not supported" + ) if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] @@ -587,7 +588,9 @@ def __repr__(self): f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" ) - def __getitem__(self, index: int | Sequence[int] | slice | torch.Tensor) -> GraphStates: + def __getitem__( + self, index: int | Sequence[int] | slice | torch.Tensor + ) -> GraphStates: out = self.__class__(Batch(self.data[index])) if self._log_rewards is not None: diff --git a/testing/test_environments.py b/testing/test_environments.py index 2786d61a..002e5ea4 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -274,7 +274,7 @@ def test_states_getitem(ndim: int, env_name: str): states = env.reset(batch_shape=ND_BATCH_SHAPE, random=True) # Boolean selector to index batch elements. 
- selections = torch.randint(0, 2,ND_BATCH_SHAPE, dtype=torch.bool) + selections = torch.randint(0, 2, ND_BATCH_SHAPE, dtype=torch.bool) n_selections = int(torch.sum(selections)) selected_states = states[selections] @@ -324,31 +324,36 @@ def test_graph_env(): FEATURE_DIM = 8 BATCH_SIZE = 3 - env = GraphBuilding(num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM) + env = GraphBuilding( + num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM + ) states = env.reset(batch_shape=BATCH_SIZE) assert states.batch_shape == BATCH_SIZE assert states.state_shape == (NUM_NODES, FEATURE_DIM) - actions_traj = torch.tensor([ - [[0, 1], [1, 2], [2, 3]], - [[0, 2], [1, 3], [2, 4]], - [[0, 3], [1, 4], [2, 5]], - [[0, 4], [1, 5], [2, 6]], - [[0, 5], [1, 6], [2, 7]], - ], dtype=torch.long) + actions_traj = torch.tensor( + [ + [[0, 1], [1, 2], [2, 3]], + [[0, 2], [1, 3], [2, 4]], + [[0, 3], [1, 4], [2, 5]], + [[0, 4], [1, 5], [2, 6]], + [[0, 5], [1, 6], [2, 7]], + ], + dtype=torch.long, + ) for action_tensor in actions_traj: actions = env.actions_from_tensor(action_tensor) states = env.step(states, actions) - invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) + invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) actions = env.actions_from_tensor(invalid_actions) with pytest.raises(NonValidActionsError): states = env.step(states, actions) - invalid_actions = torch.tensor(actions_traj[0]) + invalid_actions = torch.tensor(actions_traj[0]) actions = env.actions_from_tensor(invalid_actions) with pytest.raises(NonValidActionsError): states = env.step(states, actions) expected_rewards = torch.zeros(BATCH_SIZE) - assert (env.reward(states) == expected_rewards).all() \ No newline at end of file + assert (env.reward(states) == expected_rewards).all() From 8034fb2bdcdb1bff2a5cc44cf777164de51b7736 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 14 Nov 2024 13:42:25 +0100 Subject: [PATCH 006/102] add GraphAction --- src/gfn/actions.py | 86 ++++++++++++++++++++++++- src/gfn/env.py | 38 ++++++----- src/gfn/gym/graph_building.py | 116 +++++++++++++++++++++------------- src/gfn/states.py | 6 +- testing/test_environments.py | 112 +++++++++++++++++++++++++------- 5 files changed, 268 insertions(+), 90 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 2006b018..e6a1e67f 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -1,8 +1,9 @@ from __future__ import annotations # This allows to use the class name in type hints from abc import ABC +import enum from math import prod -from typing import ClassVar, Sequence +from typing import ClassVar, Optional, Sequence import torch @@ -168,3 +169,86 @@ def is_exit(self) -> torch.Tensor: *self.batch_shape, *((1,) * len(self.__class__.action_shape)) ) return self.compare(exit_actions_tensor) + + +class GraphActionType(enum.Enum): + EXIT = enum.auto() + ADD_NODE = enum.auto() + ADD_EDGE = enum.auto() + + +class GraphActions: + + nodes_features_dim: ClassVar[int] # Dim size of the features tensor. + edge_features_dim: ClassVar[int] # Dim size of the edge features tensor. + + def __init__(self, action_type: GraphActionType, features: torch.Tensor, edge_index: Optional[torch.Tensor] = None): + """Initializes a GraphAction object. + + Args: + action: a GraphActionType indicating the type of action. 
+ features: a tensor of shape (*batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type + edge_index: an tensor of shape (*batch_shape, 2) representing the edge to add. + This must defined if and only if the action type is GraphActionType.AddEdge. + """ + self.action_type = action_type + if self.action_type == GraphActionType.ADD_NODE: + assert features.shape[-1] == self.nodes_features_dim + assert edge_index is None + elif self.action_type == GraphActionType.ADD_EDGE: + assert features.shape[-1] == self.edge_features_dim + assert edge_index is not None + assert edge_index.shape[-1] == 2 + + + self.features = features + self.edge_index = edge_index + self.batch_shape = tuple(self.features.shape[:-1]) + + def __repr__(self): + return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" + + @property + def device(self) -> torch.device: + """Returns the device of the features tensor.""" + return self.features.device + + def __len__(self) -> int: + """Returns the number of actions in the batch.""" + return prod(self.batch_shape) + + def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: + """Get particular actions of the batch.""" + features = self.features[index] + edge_index = self.edge_index[index] if self.edge_index is not None else None + return GraphActions(self.action_type, features, edge_index) + + def __setitem__(self, index: int | Sequence[int] | Sequence[bool], action: GraphActions) -> None: + """Set particular actions of the batch.""" + assert self.action_type == action.action_type + self.features[index] = action.features + if self.edge_index is not None: + self.edge_index[index] = action.edge_index + + def compare(self, other: GraphActions) -> torch.Tensor: + """Compares the actions to another GraphAction object. + + Args: + other: GraphAction object to compare. + + Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. + """ + if self.action_type != other.action_type: + return torch.zeros(self.batch_shape, dtype=torch.bool, device=self.device) + out = torch.all(self.features == other.features, dim=-1) + if self.edge_index is not None: + out &= torch.all(self.edge_index == other.edge_index, dim=-1) + return out + + @property + def is_exit(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" + return torch.full(self.batch_shape, self.action_type == GraphActionType.Exit, dtype=torch.bool, device=self.device) + + + diff --git a/src/gfn/env.py b/src/gfn/env.py index 3c86a3de..8780c069 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -4,7 +4,7 @@ import torch from torch_geometric.data import Batch, Data -from gfn.actions import Actions +from gfn.actions import Actions, GraphActions from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States from gfn.utils.common import set_seed @@ -570,9 +570,6 @@ def __init__( s0: Data, node_feature_dim: int, edge_feature_dim: int, - action_shape: Tuple, - dummy_action: torch.Tensor, - exit_action: torch.Tensor, sf: Optional[Data] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -593,26 +590,12 @@ def __init__( that can be fed into a neural network. Defaults to None, in which case the IdentityPreprocessor is used. 
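         Note: sf may remain None for graph environments; the GraphStates
         subclass then raises NotImplementedError if sink states are requested.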
""" - self.device = get_device(device_str, default_device=s0.x.device) - self.s0 = s0.to(self.device) + self.s0 = s0.to(device_str) self.node_feature_dim = node_feature_dim self.edge_feature_dim = edge_feature_dim - self.state_shape = (s0.num_nodes, self.node_feature_dim) - assert s0.x.shape == self.state_shape - - if sf is None: - sf = Data( - x=torch.full(self.state_shape, -float("inf")), - edge_attr=torch.full((s0.num_edges, edge_feature_dim), -float("inf")), - edge_index=s0.edge_index, - ).to(self.device) - self.sf: torch.Tensor = sf - assert self.sf.x.shape == self.state_shape - self.action_shape = action_shape - self.dummy_action = dummy_action - self.exit_action = exit_action + self.sf = sf self.States = self.make_states_class() self.Actions = self.make_actions_class() @@ -632,6 +615,21 @@ class GraphEnvStates(GraphStates): return GraphEnvStates + def make_actions_class(self) -> type[GraphActions]: + """The default Actions class factory for all Environments. + + Returns a class that inherits from Actions and implements assumed methods. + The make_actions_class method should be overwritten to achieve more + environment-specific Actions functionality. + """ + env = self + + class DefaultGraphAction(GraphActions): + nodes_features_dim = env.node_feature_dim + edge_features_dim = env.edge_feature_dim + + return DefaultGraphAction + @abstractmethod def step(self, states: GraphStates, actions: Actions) -> GraphStates: """Function that takes a batch of graph states and actions and returns a batch of next diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 8b9dcc59..6d017d38 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -5,7 +5,7 @@ from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv -from gfn.actions import Actions +from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError from gfn.states import GraphStates @@ -13,22 +13,13 @@ class GraphBuilding(GraphEnv): def __init__( self, - num_nodes: int, node_feature_dim: int, edge_feature_dim: int, state_evaluator: Callable[[Batch], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = Data( - x=torch.zeros((num_nodes, node_feature_dim)), - edge_index=torch.zeros((2, 0), dtype=torch.long), - ).to(device_str) - exit_action = torch.tensor( - [-float("inf"), -float("inf")], device=torch.device(device_str) - ) - dummy_action = torch.tensor( - [float("inf"), float("inf")], device=torch.device(device_str) - ) + s0 = Data().to(device_str) + if state_evaluator is None: state_evaluator = GCNConvEvaluator(node_feature_dim) self.state_evaluator = state_evaluator @@ -37,13 +28,10 @@ def __init__( s0=s0, node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, - action_shape=(2,), - dummy_action=dummy_action, - exit_action=exit_action, device_str=device_str, ) - def step(self, states: GraphStates, actions: Actions) -> GraphStates: + def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Step function for the GraphBuilding environment. 
Args: @@ -57,11 +45,26 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) - edge_index = torch.cat([graphs.edge_index, actions.tensor.T], dim=1) - graphs.edge_index = edge_index + if actions.action_type == GraphActionType.ADD_NODE: + if graphs.x is None: + graphs.x = actions.features[:, None, :] + else: + graphs.x = torch.cat([graphs.x, actions.features[:, None, :]], dim=1) + + if actions.action_type == GraphActionType.ADD_EDGE: + assert actions.edge_index is not None + if graphs.edge_attr is None: + graphs.edge_attr = actions.features[:, None, :] + assert graphs.edge_index is None + graphs.edge_index = actions.edge_index[:, :, None] + else: + graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features[:, None, :]], dim=1) + graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index[:, :, None]], dim=2) + return self.States(graphs) - def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + + def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. Args: @@ -75,34 +78,59 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) - for i, act in enumerate(actions.tensor): - edge_index = graphs[i].edge_index - edge_index = edge_index[:, edge_index[1] != act] - graphs[i].edge_index = edge_index + if actions.action_type == GraphActionType.ADD_NODE: + assert graphs.x is not None + is_equal = torch.all(graphs.x == actions.features[:, None], dim=-1) + assert torch.all(torch.sum(is_equal, dim=-1) == 1) + graphs.x = graphs.x[~is_equal].reshape(states.data.batch_size, -1, self.node_feature_dim) + + elif actions.action_type == GraphActionType.ADD_EDGE: + assert actions.edge_index is not None + is_equal = torch.all(graphs.edge_index == actions.edge_index[:, :, None], dim=1) + assert torch.all(torch.sum(is_equal, dim=-1) == 1) + graphs.edge_attr = graphs.edge_attr[~is_equal].reshape(states.data.batch_size, -1, self.edge_feature_dim) + edge_index = graphs.edge_index.permute(0, 2, 1)[~is_equal] + graphs.edge_index = edge_index.reshape(states.data.batch_size, -1, 2).permute(0, 2, 1) return self.States(graphs) def is_action_valid( - self, states: GraphStates, actions: Actions, backward: bool = False + self, states: GraphStates, actions: GraphActions, backward: bool = False ) -> bool: - current_edges = states.data.edge_index - new_edges = actions.tensor - - if torch.any(new_edges[:, 0] == new_edges[:, 1]): - return False - if current_edges.shape[1] == 0: - return not backward + if actions.action_type == GraphActionType.ADD_NODE: + if actions.edge_index is not None: + return False + if states.data.x is None: + return not backward + + equal_nodes_per_batch = torch.sum( + torch.all(states.data.x == actions.features[:, None], dim=-1), + dim=-1 + ) - if backward: - some_edges_not_exist = torch.any( - torch.all(current_edges[:, None, :] != new_edges.T[:, :, None], dim=0) + if backward: # TODO: check if no edge are connected? 
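+                # Backward removal requires exactly one matching node per graph;
+                # the forward branch below requires that no match exists yet.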
+ return torch.all(equal_nodes_per_batch == 1) + return torch.all(equal_nodes_per_batch == 0) + + if actions.action_type == GraphActionType.ADD_EDGE: + assert actions.edge_index is not None + if torch.any(actions.edge_index[:, 0] == actions.edge_index[:, 1]): + return False + if states.data.edge_index is None: + return not backward + + equal_edges_per_batch_attr = torch.sum( + torch.all(states.data.edge_attr == actions.features[:, None], dim=-1), + dim=-1 ) - return not some_edges_not_exist - else: - some_edges_exist = torch.any( - torch.all(current_edges[:, None, :] == new_edges.T[:, :, None], dim=0) + equal_edges_per_batch_index = torch.sum( + torch.all(states.data.edge_index == actions.edge_index[:, :, None], dim=1), + dim=-1 ) - return not some_edges_exist + if backward: + return torch.all(equal_edges_per_batch_attr == 1) and torch.all(equal_edges_per_batch_index == 1) + return torch.all(equal_edges_per_batch_attr == 0) and torch.all(equal_edges_per_batch_index == 0) + def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. @@ -114,9 +142,7 @@ def reward(self, final_states: GraphStates) -> torch.Tensor: Returns: torch.Tensor: Tensor of shape "batch_shape" containing the rewards. """ - per_node_rew = self.state_evaluator(final_states.data) - node_batch_idx = final_states.data.batch - return torch.bincount(node_batch_idx, weights=per_node_rew) + return self.state_evaluator(final_states.data) @property def log_partition(self) -> float: @@ -138,4 +164,8 @@ def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - return self.net(batch.x, batch.edge_index).squeeze(-1) + out = torch.empty(len(batch), device=batch.x.device) + for i in range(len(batch)): # looks like net doesn't work with batch + out[i] = self.net(batch.x[i], batch.edge_index[i]).mean() + + return out diff --git a/src/gfn/states.py b/src/gfn/states.py index 513f814b..503514c4 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -512,14 +512,13 @@ class GraphStates(ABC): """ s0: ClassVar[Data] - sf: ClassVar[Data] + sf: ClassVar[Optional[Data]] node_feature_dim: ClassVar[int] edge_feature_dim: ClassVar[int] def __init__(self, graphs: Batch): self.data: Batch = graphs self.batch_shape: int = len(self.data) - self.state_shape = (self.data.get_example(0).num_nodes, self.node_feature_dim) self._log_rewards: float = None @classmethod @@ -550,6 +549,9 @@ def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: @classmethod def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: + if cls.sf is None: + raise NotImplementedError("Sink state is not defined") + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: raise NotImplementedError( "Batch shape with more than one dimension is not supported" diff --git a/testing/test_environments.py b/testing/test_environments.py index 002e5ea4..cccc2bc1 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -2,6 +2,7 @@ import pytest import torch +from gfn.actions import GraphActionType from gfn.env import NonValidActionsError from gfn.gym import Box, DiscreteEBM, HyperGrid from gfn.gym.graph_building import GraphBuilding @@ -320,40 +321,103 @@ def test_get_grid(): def test_graph_env(): - NUM_NODES = 4 FEATURE_DIM = 8 BATCH_SIZE = 3 + NUM_NODES = 5 - env = GraphBuilding( - num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM - ) + env = GraphBuilding(node_feature_dim=FEATURE_DIM, 
edge_feature_dim=FEATURE_DIM) states = env.reset(batch_shape=BATCH_SIZE) assert states.batch_shape == BATCH_SIZE - assert states.state_shape == (NUM_NODES, FEATURE_DIM) - - actions_traj = torch.tensor( - [ - [[0, 1], [1, 2], [2, 3]], - [[0, 2], [1, 3], [2, 4]], - [[0, 3], [1, 4], [2, 5]], - [[0, 4], [1, 5], [2, 6]], - [[0, 5], [1, 6], [2, 7]], - ], - dtype=torch.long, - ) + action_cls = env.make_actions_class() - for action_tensor in actions_traj: - actions = env.actions_from_tensor(action_tensor) + with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + ) + states = env.step(states, actions) + + for _ in range(NUM_NODES): + actions = action_cls( + GraphActionType.ADD_NODE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + ) states = env.step(states, actions) + + assert states.data.x.shape == (BATCH_SIZE, NUM_NODES, FEATURE_DIM) - invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) - actions = env.actions_from_tensor(invalid_actions) with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_NODE, + states.data.x[:, 0], + ) states = env.step(states, actions) - invalid_actions = torch.tensor(actions_traj[0]) - actions = env.actions_from_tensor(invalid_actions) + with pytest.raises(NonValidActionsError): + edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + torch.stack([edge_index, edge_index], dim=1) + ) states = env.step(states, actions) - expected_rewards = torch.zeros(BATCH_SIZE) - assert (env.reward(states) == expected_rewards).all() + for i in range(NUM_NODES - 1): + edge_index = torch.tensor([[i, i + 1]] * BATCH_SIZE) + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + edge_index + ) + states = env.step(states, actions) + + with pytest.raises(NonValidActionsError): + edge_index = torch.tensor([[0, 1]] * BATCH_SIZE) + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + edge_index + ) + states = env.step(states, actions) + + env.reward(states) + + # with pytest.raises(NonValidActionsError): + # actions = action_cls( + # GraphActionType.ADD_NODE, + # states.data.x[:, 0], + # ) + # states = env.backward_step(states, actions) + + for i in reversed(range(states.data.edge_attr.shape[1])): + actions = action_cls( + GraphActionType.ADD_EDGE, + states.data.edge_attr[:, i], + states.data.edge_index[:, :, i] + ) + states = env.backward_step(states, actions) + + with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + ) + states = env.backward_step(states, actions) + + for i in reversed(range(NUM_NODES)): + actions = action_cls( + GraphActionType.ADD_NODE, + states.data.x[:, i], + ) + states = env.backward_step(states, actions) + + assert states.data.x.shape == (BATCH_SIZE, 0, FEATURE_DIM) + + with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_NODE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + ) + states = env.backward_step(states, actions) \ No newline at end of file From d17967155258fd33ac1e669ecc5c80e5afdb7d1f Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 14 Nov 2024 18:31:27 +0100 Subject: [PATCH 007/102] fix batching mechanism --- src/gfn/actions.py | 12 +++--- 
src/gfn/gym/graph_building.py | 79 +++++++++++++++++++---------------- testing/test_environments.py | 31 ++++++++------ 3 files changed, 67 insertions(+), 55 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index e6a1e67f..4628590a 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -192,18 +192,18 @@ def __init__(self, action_type: GraphActionType, features: torch.Tensor, edge_in This must defined if and only if the action type is GraphActionType.AddEdge. """ self.action_type = action_type + batch_dim, features_dim = features.shape if self.action_type == GraphActionType.ADD_NODE: - assert features.shape[-1] == self.nodes_features_dim + assert features_dim == self.nodes_features_dim assert edge_index is None elif self.action_type == GraphActionType.ADD_EDGE: - assert features.shape[-1] == self.edge_features_dim + assert features_dim == self.edge_features_dim assert edge_index is not None - assert edge_index.shape[-1] == 2 - + assert edge_index.shape == (2, batch_dim) self.features = features self.edge_index = edge_index - self.batch_shape = tuple(self.features.shape[:-1]) + self.batch_shape = (batch_dim,) def __repr__(self): return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" @@ -248,7 +248,7 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full(self.batch_shape, self.action_type == GraphActionType.Exit, dtype=torch.bool, device=self.device) + return torch.full(self.batch_shape, self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 6d017d38..b433340b 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -47,19 +47,19 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: if actions.action_type == GraphActionType.ADD_NODE: if graphs.x is None: - graphs.x = actions.features[:, None, :] + graphs.x = actions.features else: - graphs.x = torch.cat([graphs.x, actions.features[:, None, :]], dim=1) + graphs.x = torch.cat([graphs.x, actions.features]) if actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None if graphs.edge_attr is None: - graphs.edge_attr = actions.features[:, None, :] + graphs.edge_attr = actions.features assert graphs.edge_index is None - graphs.edge_index = actions.edge_index[:, :, None] + graphs.edge_index = actions.edge_index else: - graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features[:, None, :]], dim=1) - graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index[:, :, None]], dim=2) + graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features]) + graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index], dim=1) return self.States(graphs) @@ -80,17 +80,17 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat if actions.action_type == GraphActionType.ADD_NODE: assert graphs.x is not None - is_equal = torch.all(graphs.x == actions.features[:, None], dim=-1) - assert torch.all(torch.sum(is_equal, dim=-1) == 1) - graphs.x = graphs.x[~is_equal].reshape(states.data.batch_size, -1, self.node_feature_dim) - + is_equal = torch.any( + torch.all(graphs.x[:, None] == actions.features, dim=-1), + dim=-1 + ) + graphs.x = graphs.x[~is_equal] elif actions.action_type == 
GraphActionType.ADD_EDGE: assert actions.edge_index is not None - is_equal = torch.all(graphs.edge_index == actions.edge_index[:, :, None], dim=1) - assert torch.all(torch.sum(is_equal, dim=-1) == 1) - graphs.edge_attr = graphs.edge_attr[~is_equal].reshape(states.data.batch_size, -1, self.edge_feature_dim) - edge_index = graphs.edge_index.permute(0, 2, 1)[~is_equal] - graphs.edge_index = edge_index.reshape(states.data.batch_size, -1, 2).permute(0, 2, 1) + is_equal = torch.all(graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0) + is_equal = torch.any(is_equal, dim=0) + graphs.edge_attr = graphs.edge_attr[~is_equal] + graphs.edge_index = graphs.edge_index[:, ~is_equal] return self.States(graphs) @@ -103,10 +103,10 @@ def is_action_valid( if states.data.x is None: return not backward - equal_nodes_per_batch = torch.sum( - torch.all(states.data.x == actions.features[:, None], dim=-1), - dim=-1 - ) + equal_nodes_per_batch = torch.all( + states.data.x == actions.features[:, None], dim=-1 + ).reshape(states.data.batch_size, -1) + equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) if backward: # TODO: check if no edge are connected? return torch.all(equal_nodes_per_batch == 1) @@ -114,19 +114,28 @@ def is_action_valid( if actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - if torch.any(actions.edge_index[:, 0] == actions.edge_index[:, 1]): + if torch.any(actions.edge_index[0] == actions.edge_index[1]): return False - if states.data.edge_index is None: - return not backward - - equal_edges_per_batch_attr = torch.sum( - torch.all(states.data.edge_attr == actions.features[:, None], dim=-1), - dim=-1 - ) - equal_edges_per_batch_index = torch.sum( - torch.all(states.data.edge_index == actions.edge_index[:, :, None], dim=1), - dim=-1 - ) + if states.data.num_nodes is None or states.data.num_nodes == 0: + return False + if torch.any(actions.edge_index > states.data.num_nodes): + return False + + batch_idx = actions.edge_index % actions.batch_shape[0] + if torch.any(batch_idx != torch.arange(actions.batch_shape[0])): + return False + if states.data.edge_attr is None: + return True + + equal_edges_per_batch_attr = torch.all( + states.data.edge_attr == actions.features[:, None], dim=-1 + ).reshape(states.data.batch_size, -1) + equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) + + equal_edges_per_batch_index = torch.all( + states.data.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 + ).reshape(states.data.batch_size, -1) + equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: return torch.all(equal_edges_per_batch_attr == 1) and torch.all(equal_edges_per_batch_index == 1) return torch.all(equal_edges_per_batch_attr == 0) and torch.all(equal_edges_per_batch_index == 0) @@ -164,8 +173,6 @@ def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - out = torch.empty(len(batch), device=batch.x.device) - for i in range(len(batch)): # looks like net doesn't work with batch - out[i] = self.net(batch.x[i], batch.edge_index[i]).mean() - - return out + out = self.net(batch.x, batch.edge_index) + out = out.reshape(batch.batch_size, -1) + return out.mean(-1) \ No newline at end of file diff --git a/testing/test_environments.py b/testing/test_environments.py index cccc2bc1..55dda0d5 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -334,7 +334,7 @@ def test_graph_env(): actions = 
action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) ) states = env.step(states, actions) @@ -345,12 +345,13 @@ def test_graph_env(): ) states = env.step(states, actions) - assert states.data.x.shape == (BATCH_SIZE, NUM_NODES, FEATURE_DIM) + assert states.data.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): + first_node_mask = torch.arange(len(states.data.x)) // BATCH_SIZE == 0 actions = action_cls( GraphActionType.ADD_NODE, - states.data.x[:, 0], + states.data.x[first_node_mask], ) states = env.step(states, actions) @@ -359,16 +360,17 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([edge_index, edge_index], dim=1) + torch.stack([edge_index, edge_index]) ) states = env.step(states, actions) for i in range(NUM_NODES - 1): - edge_index = torch.tensor([[i, i + 1]] * BATCH_SIZE) + node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) + node_js = torch.arange((i + 1) * BATCH_SIZE, (i + 2) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index + torch.stack([node_is, node_js]) ) states = env.step(states, actions) @@ -377,7 +379,7 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index + edge_index.T ) states = env.step(states, actions) @@ -390,11 +392,13 @@ def test_graph_env(): # ) # states = env.backward_step(states, actions) - for i in reversed(range(states.data.edge_attr.shape[1])): + num_edges_per_batch = states.data.edge_attr.shape[0] // BATCH_SIZE + for i in reversed(range(num_edges_per_batch)): + edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_EDGE, - states.data.edge_attr[:, i], - states.data.edge_index[:, :, i] + states.data.edge_attr[edge_idx], + states.data.edge_index[:, edge_idx] ) states = env.backward_step(states, actions) @@ -402,18 +406,19 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) ) states = env.backward_step(states, actions) for i in reversed(range(NUM_NODES)): + edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_NODE, - states.data.x[:, i], + states.data.x[edge_idx], ) states = env.backward_step(states, actions) - assert states.data.x.shape == (BATCH_SIZE, 0, FEATURE_DIM) + assert states.data.x.shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): actions = action_cls( From 7ff96d5d8e0484e28813bd242cf6c07987d92355 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 16 Nov 2024 11:26:06 +0100 Subject: [PATCH 008/102] add support for EXIT action --- src/gfn/actions.py | 59 ++++++++++++++++++++++------------- src/gfn/gym/graph_building.py | 11 +++++-- testing/test_environments.py | 5 ++- 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 4628590a..7207c1c5 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -179,31 +179,38 @@ class GraphActionType(enum.Enum): class GraphActions: - nodes_features_dim: ClassVar[int] # Dim size of the features tensor. 
- edge_features_dim: ClassVar[int] # Dim size of the edge features tensor. + nodes_features_dim: ClassVar[int] + edge_features_dim: ClassVar[int] - def __init__(self, action_type: GraphActionType, features: torch.Tensor, edge_index: Optional[torch.Tensor] = None): + def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None): """Initializes a GraphAction object. Args: action: a GraphActionType indicating the type of action. - features: a tensor of shape (*batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type - edge_index: an tensor of shape (*batch_shape, 2) representing the edge to add. + features: a tensor of shape (batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type. + In case of EXIT action, this can be None. + edge_index: an tensor of shape (batch_shape, 2) representing the edge to add. This must defined if and only if the action type is GraphActionType.AddEdge. """ self.action_type = action_type - batch_dim, features_dim = features.shape - if self.action_type == GraphActionType.ADD_NODE: - assert features_dim == self.nodes_features_dim + if self.action_type == GraphActionType.EXIT: + assert features is None assert edge_index is None - elif self.action_type == GraphActionType.ADD_EDGE: - assert features_dim == self.edge_features_dim - assert edge_index is not None - assert edge_index.shape == (2, batch_dim) - - self.features = features - self.edge_index = edge_index - self.batch_shape = (batch_dim,) + self.features = None + self.edge_index = None + else: + assert features is not None + batch_dim, features_dim = features.shape + if self.action_type == GraphActionType.ADD_NODE: + assert features_dim == self.nodes_features_dim + assert edge_index is None + elif self.action_type == GraphActionType.ADD_EDGE: + assert features_dim == self.edge_features_dim + assert edge_index is not None + assert edge_index.shape == (2, batch_dim) + + self.features = features + self.edge_index = edge_index def __repr__(self): return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" @@ -215,19 +222,26 @@ def device(self) -> torch.device: def __len__(self) -> int: """Returns the number of actions in the batch.""" - return prod(self.batch_shape) + if self.action_type == GraphActionType.EXIT: + raise ValueError("Cannot get the length of exit actions.") + else: + assert self.features is not None + return self.features.shape[0] def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: """Get particular actions of the batch.""" - features = self.features[index] + features = self.features[index] if self.features is not None else None edge_index = self.edge_index[index] if self.edge_index is not None else None return GraphActions(self.action_type, features, edge_index) def __setitem__(self, index: int | Sequence[int] | Sequence[bool], action: GraphActions) -> None: """Set particular actions of the batch.""" assert self.action_type == action.action_type - self.features[index] = action.features - if self.edge_index is not None: + if self.action_type != GraphActionType.EXIT: + assert self.features is not None + self.features[index] = action.features + if self.action_type == GraphActionType.ADD_EDGE: + assert self.edge_index is not None self.edge_index[index] = action.edge_index def compare(self, other: GraphActions) -> torch.Tensor: @@ -239,7 +253,8 @@ def 
compare(self, other: GraphActions) -> torch.Tensor: Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. """ if self.action_type != other.action_type: - return torch.zeros(self.batch_shape, dtype=torch.bool, device=self.device) + len_ = self.features.shape[0] if self.features is not None else 1 + return torch.zeros(len_, dtype=torch.bool, device=self.device) out = torch.all(self.features == other.features, dim=-1) if self.edge_index is not None: out &= torch.all(self.edge_index == other.edge_index, dim=-1) @@ -248,7 +263,7 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full(self.batch_shape, self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) + return torch.full((1,), self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index b433340b..2d060759 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -43,15 +43,16 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") graphs: Batch = deepcopy(states.data) - assert len(graphs) == len(actions) if actions.action_type == GraphActionType.ADD_NODE: + assert len(graphs) == len(actions) if graphs.x is None: graphs.x = actions.features else: graphs.x = torch.cat([graphs.x, actions.features]) if actions.action_type == GraphActionType.ADD_EDGE: + assert len(graphs) == len(actions) assert actions.edge_index is not None if graphs.edge_attr is None: graphs.edge_attr = actions.features @@ -97,6 +98,9 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False ) -> bool: + if actions.action_type == GraphActionType.EXIT: + return True # TODO: what are the conditions for exit action? 
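+            # One plausible rule for the TODO above (an illustrative sketch only,
+            # not what this patch implements): allow EXIT only once the graph is
+            # non-empty, e.g.
+            #   return states.data.x is not None and states.data.x.shape[0] > 0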
+ if actions.action_type == GraphActionType.ADD_NODE: if actions.edge_index is not None: return False @@ -121,8 +125,9 @@ def is_action_valid( if torch.any(actions.edge_index > states.data.num_nodes): return False - batch_idx = actions.edge_index % actions.batch_shape[0] - if torch.any(batch_idx != torch.arange(actions.batch_shape[0])): + batch_dim = actions.features.shape[0] + batch_idx = actions.edge_index % batch_dim + if torch.any(batch_idx != torch.arange(batch_dim)): return False if states.data.edge_attr is None: return True diff --git a/testing/test_environments.py b/testing/test_environments.py index 55dda0d5..e157f1bb 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -383,12 +383,15 @@ def test_graph_env(): ) states = env.step(states, actions) + actions = action_cls(GraphActionType.EXIT) + states = env.step(states, actions) env.reward(states) # with pytest.raises(NonValidActionsError): + # node_idx = torch.arange(0, BATCH_SIZE) # actions = action_cls( # GraphActionType.ADD_NODE, - # states.data.x[:, 0], + # states.data.x[node_idxs], # ) # states = env.backward_step(states, actions) From dacbbf746b0f31bce7930adf56a15e63cfeaed38 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 19 Nov 2024 12:40:13 +0100 Subject: [PATCH 009/102] add GraphActionPolicyEstimator --- src/gfn/actions.py | 2 +- src/gfn/modules.py | 165 +++++++++++++++------- src/gfn/states.py | 12 ++ src/gfn/utils/distributions.py | 19 ++- testing/test_samplers_and_trajectories.py | 12 +- 5 files changed, 155 insertions(+), 55 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 7207c1c5..0bdb5529 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -172,9 +172,9 @@ def is_exit(self) -> torch.Tensor: class GraphActionType(enum.Enum): - EXIT = enum.auto() ADD_NODE = enum.auto() ADD_EDGE = enum.auto() + EXIT = enum.auto() class GraphActions: diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 2eabf53d..4083d169 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -1,13 +1,14 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict import torch import torch.nn as nn -from torch.distributions import Categorical, Distribution +from torch.distributions import Categorical, Distribution, Normal +from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor -from gfn.states import DiscreteStates, States -from gfn.utils.distributions import UnsqueezedCategorical +from gfn.states import DiscreteStates, GraphStates, States +from gfn.utils.distributions import ComposedDistribution, UnsqueezedCategorical REDUCTION_FXNS = { "mean": torch.mean, @@ -90,32 +91,11 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: """ if isinstance(input, States): input = self.preprocessor(input) - - out = self.module(input) - - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - - return out + return self.module(input) def __repr__(self): return f"{self.__class__.__name__} module" - @property - @abstractmethod - def expected_output_dim(self) -> int: - """Expected output dimension of the module.""" - - def check_output_dim(self, module_output: torch.Tensor) -> None: - """Check that the output of the module has the correct shape. 
Raises an error if not.""" - assert module_output.dtype == torch.float - if module_output.shape[-1] != self.expected_output_dim(): - raise ValueError( - f"{self.__class__.__name__} output dimension should be {self.expected_output_dim()}" - + f" but is {module_output.shape[-1]}." - ) - def to_probability_distribution( self, states: States, @@ -192,9 +172,6 @@ def __init__( ) self.reduction_fxn = REDUCTION_FXNS[reduction] - def expected_output_dim(self) -> int: - return 1 - def forward(self, input: States | torch.Tensor) -> torch.Tensor: """Forward pass of the module. @@ -212,10 +189,7 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == 1 return out @@ -250,12 +224,19 @@ def __init__( """ super().__init__(module, preprocessor, is_backward=is_backward) self.n_actions = n_actions + self.expected_output_dim = self.n_actions - int(self.is_backward) - def expected_output_dim(self) -> int: - if self.is_backward: - return self.n_actions - 1 - else: - return self.n_actions + def forward(self, states: DiscreteStates) -> torch.Tensor: + """Forward pass of the module. + + Args: + states: The input states. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). + """ + out = super().forward(states) + assert out.shape[-1] == self.expected_output_dim + return out def to_probability_distribution( self, @@ -279,7 +260,7 @@ def to_probability_distribution( on policy. epsilon: with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy.""" - self.check_output_dim(module_output) + assert module_output.shape[-1] == self.expected_output_dim masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output @@ -296,7 +277,7 @@ def to_probability_distribution( # LogEdgeFlows are greedy, as are most P_B. else: - return UnsqueezedCategorical(logits=logits) + return UnsqueezedCategorical(logits=logits) class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator): @@ -362,11 +343,7 @@ def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). """ out = self._forward_trunk(states, conditioning) - - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == self.expected_output_dim return out @@ -450,15 +427,9 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == 1 return out - def expected_output_dim(self) -> int: - return 1 - def to_probability_distribution( self, states: States, @@ -466,3 +437,93 @@ def to_probability_distribution( **policy_kwargs: Any, ) -> Distribution: raise NotImplementedError + + +class GraphActionPolicyEstimator(GFNModule): + r"""Container for forward and backward policy estimators for graph environments. + + $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$. + + or + + $s \mapsto (P_B(s' \mid s))_{s' \in Parents(s)}$. + + Attributes: + temperature: scalar to divide the logits by before softmax. + sf_bias: scalar to subtract from the exit action logit before dividing by + temperature. 
+        epsilon: with probability epsilon, a random action is chosen.
+    """
+
+    def __init__(
+        self,
+        module: nn.ModuleDict,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+    ):
+        """Initializes an estimator for P_F for graph environments.
+
+        Args:
+            is_backward: if False, then this is a forward policy, else backward policy.
+        """
+        super().__init__(module, preprocessor, is_backward)
+        assert isinstance(self.module, nn.ModuleDict)
+        assert self.module.keys() == {"action_type", "edge_index", "features"}
+
+    def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]:
+        """Forward pass of the module.
+
+        Args:
+            states: The input graph states.
+
+        Returns a dict with logits for the action type and edge index, plus the predicted features.
+        """
+        action_type_logits = self.module["action_type"](states)
+        edge_index_logits = self.module["edge_index"](states)
+        features = self.module["features"](states)
+
+        assert action_type_logits == len(GraphActionType)
+        assert edge_index_logits.shape[-1] == 2
+        return {
+            "action_type": action_type_logits,
+            "edge_index": edge_index_logits,
+            "features": features
+        }
+
+    def to_probability_distribution(
+        self,
+        states: GraphStates,
+        module_output: Dict[str, torch.Tensor],
+        temperature: float = 1.0,
+        epsilon: float = 0.0,
+    ) -> ComposedDistribution:
+        """Returns a probability distribution given a batch of states and module output.
+
+        We handle off-policyness using these kwargs.
+
+        Args:
+            states: The states to use.
+            module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
+            temperature: scalar to divide the logits by before softmax. Does nothing
+                if set to 1.0 (default), in which case it's on policy.
+            epsilon: with probability epsilon, a random action is chosen. Does nothing
+                if set to 0.0 (default), in which case it's on policy."""
+
+        dists = {}
+
+        action_type_logits = module_output["action_type"]
+        action_type_masks = states.backward_masks if self.is_backward else states.forward_masks
+        action_type_logits[~action_type_masks] = -float("inf")
+        action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1)
+        uniform_dist_probs = action_type_masks.float() / action_type_masks.sum(dim=-1, keepdim=True)
+        action_type_probs = (1 - epsilon) * action_type_probs + epsilon * uniform_dist_probs
+
+        edge_index_logits = module_output["edge_index"]
+        edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1)
+        uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1]
+        edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs
+
+        dists["action_type"] = Categorical(probs=action_type_probs)
+        dists["features"] = Normal(module_output["features"], temperature)
+        dists["edge_index"] = Categorical(probs=edge_index_probs)
+        return ComposedDistribution(dists=dists)
diff --git a/src/gfn/states.py b/src/gfn/states.py
index 503514c4..5d63a41b 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -8,6 +8,8 @@
 import torch
 from torch_geometric.data import Batch, Data
 
+from gfn.actions import GraphActionType
+
 
 class States(ABC):
     """Base class for states, seen as nodes of the DAG.
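
A note on the masks added in the next hunk: each mask holds one row per graph and
one column per GraphActionType, so a policy can disable invalid action types with
the same masking used by the estimator above (a minimal sketch, not part of this patch):

    # forward_masks[i] == [can ADD_NODE, can ADD_EDGE, can EXIT] for graph i
    action_type_logits[~states.forward_masks] = -float("inf")
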
@@ -521,6 +523,16 @@ def __init__(self, graphs: Batch):
         self.batch_shape: int = len(self.data)
         self._log_rewards: float = None
 
+        # TODO logic repeated from env.is_valid_action
+        self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool)
+        self.forward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.x.shape[0] > 0
+        self.forward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0
+
+        self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool)
+        self.backward_masks[:, GraphActionType.ADD_NODE.value] = self.data.x.shape[0] > 0
+        self.backward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.edge_attr.shape[0] > 0
+        self.backward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0
+
     @classmethod
     def from_batch_shape(
         cls, batch_shape: int, random: bool = False, sink: bool = False
diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py
index f4948d0d..5cfb9cc8 100644
--- a/src/gfn/utils/distributions.py
+++ b/src/gfn/utils/distributions.py
@@ -1,5 +1,6 @@
+from typing import Dict
 import torch
-from torch.distributions import Categorical
+from torch.distributions import Distribution, Categorical
 
 
 class UnsqueezedCategorical(Categorical):
@@ -39,3 +40,19 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor:
         """
         assert sample.shape[-1] == 1
         return super().log_prob(sample.squeeze(-1))
+
+
+class ComposedDistribution(Distribution):
+    """A distribution over a dictionary of named component distributions."""
+
+    def __init__(self, dists: Dict[str, Distribution]):
+        """Initializes the composed distribution.
+
+        Args:
+            dists: A dictionary of distributions.
+        """
+        super().__init__()
+        self.dists = dists
+
+    def sample(self, sample_shape: torch.Size) -> Dict[str, torch.Tensor]:
+        return {k: v.sample(sample_shape) for k, v in self.dists.items()}
\ No newline at end of file
diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py
index 65199552..02014cb0 100644
--- a/testing/test_samplers_and_trajectories.py
+++ b/testing/test_samplers_and_trajectories.py
@@ -5,10 +5,12 @@
 from gfn.containers import Trajectories
 from gfn.containers.replay_buffer import ReplayBuffer
 from gfn.gym import Box, DiscreteEBM, HyperGrid
+from gfn.gym.graph_building import GraphBuilding
 from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP
-from gfn.modules import DiscretePolicyEstimator, GFNModule
+from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator
 from gfn.samplers import Sampler
 from gfn.utils.modules import MLP
+from torch_geometric.nn import GCNConv
 
 
 def trajectory_sampling_with_return(
@@ -214,3 +216,11 @@ def test_replay_buffer(
             replay_buffer.add(training_objects)
         except Exception as e:
             raise ValueError(f"Error while testing {env_name}") from e
+
+
+def test_graph_building():
+    node_feature_dim = 8
+    env = GraphBuilding(node_feature_dim=node_feature_dim, edge_feature_dim=4)
+    graph_net = GCNConv(node_feature_dim, 1)
+
+    GraphActionPolicyEstimator(module=graph_net)
From e74e5000e4636f8fc3960a5f60323420d34784fa Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 22 Nov 2024 15:05:54 +0100
Subject: [PATCH 010/102] Sampler integration work

---
 src/gfn/actions.py                        | 24 +++---
 src/gfn/env.py                            | 25 +++++---
 src/gfn/gym/graph_building.py             | 12 +++-
 src/gfn/modules.py                        | 25 ++++----
 src/gfn/samplers.py                       | 18 +++---
 src/gfn/states.py                         | 69 +++++++++-------
 src/gfn/utils/distributions.py            | 11 +++-
 testing/test_samplers_and_trajectories.py | 72 +++++++++++++++++++++--
 8 files 
changed, 171 insertions(+), 85 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 0bdb5529..03a1bfdc 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -171,10 +171,10 @@ def is_exit(self) -> torch.Tensor: return self.compare(exit_actions_tensor) -class GraphActionType(enum.Enum): - ADD_NODE = enum.auto() - ADD_EDGE = enum.auto() - EXIT = enum.auto() +class GraphActionType(enum.IntEnum): + ADD_NODE = 0 + ADD_EDGE = 1 + EXIT = 2 class GraphActions: @@ -182,7 +182,7 @@ class GraphActions: nodes_features_dim: ClassVar[int] edge_features_dim: ClassVar[int] - def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None): + def __init__(self, action_type: torch.Tensor, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None): """Initializes a GraphAction object. Args: @@ -192,7 +192,9 @@ def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor edge_index: an tensor of shape (batch_shape, 2) representing the edge to add. This must defined if and only if the action type is GraphActionType.AddEdge. """ - self.action_type = action_type + self.batch_shape = action_type.shape + assert torch.all(action_type == action_type[0]) + self.action_type = action_type[0] if self.action_type == GraphActionType.EXIT: assert features is None assert edge_index is None @@ -201,9 +203,9 @@ def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor else: assert features is not None batch_dim, features_dim = features.shape + assert (batch_dim,) == self.batch_shape if self.action_type == GraphActionType.ADD_NODE: assert features_dim == self.nodes_features_dim - assert edge_index is None elif self.action_type == GraphActionType.ADD_EDGE: assert features_dim == self.edge_features_dim assert edge_index is not None @@ -265,5 +267,13 @@ def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" return torch.full((1,), self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) + @classmethod + def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: # TODO: remove make_dummy_actions + """Creates an Actions object of dummy actions with the given batch shape.""" + return GraphActions( + action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), + features=None, + edge_index=None + ) diff --git a/src/gfn/env.py b/src/gfn/env.py index 8780c069..8db703b3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch from torch_geometric.data import Batch, Data @@ -568,8 +568,8 @@ class GraphEnv(Env): def __init__( self, s0: Data, - node_feature_dim: int, - edge_feature_dim: int, + # node_feature_dim: int, + # edge_feature_dim: int, sf: Optional[Data] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -592,8 +592,8 @@ def __init__( """ self.s0 = s0.to(device_str) - self.node_feature_dim = node_feature_dim - self.edge_feature_dim = edge_feature_dim + self.node_feature_dim = s0.x.shape[1] + self.edge_feature_dim = s0.edge_attr.shape[1] self.sf = sf @@ -609,8 +609,8 @@ def make_states_class(self) -> type[GraphStates]: class GraphEnvStates(GraphStates): s0 = env.s0 sf = env.sf - node_feature_dim = env.node_feature_dim - edge_feature_dim = env.edge_feature_dim + # 
node_feature_dim = env.node_feature_dim + # edge_feature_dim = env.edge_feature_dim make_random_states_graph = env.make_random_states_tensor return GraphEnvStates @@ -630,6 +630,17 @@ class DefaultGraphAction(GraphActions): return DefaultGraphAction + def actions_from_tensor(self, tensor: Dict[str, torch.Tensor]): + """Wraps the supplied Tensor in an Actions instance. + + Args: + tensor: The tensor of shape "action_shape" representing the actions. + + Returns: + Actions: An instance of Actions. + """ + return self.Actions(**tensor) + @abstractmethod def step(self, states: GraphStates, actions: Actions) -> GraphStates: """Function that takes a batch of graph states and actions and returns a batch of next diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 2d060759..e33fafea 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -18,7 +18,14 @@ def __init__( state_evaluator: Callable[[Batch], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = Data().to(device_str) + s0 = Data( + x=torch.zeros((0, node_feature_dim), dtype=torch.float32), + edge_attr=torch.zeros((0, edge_feature_dim), dtype=torch.float32), + edge_index=torch.zeros((2, 0), dtype=torch.long), + ).to(device_str) + sf = Data( + x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float('inf'), + ).to(device_str) if state_evaluator is None: state_evaluator = GCNConvEvaluator(node_feature_dim) @@ -26,8 +33,7 @@ def __init__( super().__init__( s0=s0, - node_feature_dim=node_feature_dim, - edge_feature_dim=edge_feature_dim, + sf=sf, device_str=device_str, ) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 4083d169..adbcc1c3 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -458,7 +458,7 @@ class GraphActionPolicyEstimator(GFNModule): def __init__( self, module: nn.ModuleDict, - preprocessor: Preprocessor | None = None, + # preprocessor: Preprocessor | None = None, is_backward: bool = False, ): """Initializes a estimator for P_F for graph environments. @@ -466,9 +466,12 @@ def __init__( Args: is_backward: if False, then this is a forward policy, else backward policy. """ - super().__init__(module, preprocessor, is_backward) - assert isinstance(self.module, nn.ModuleDict) - assert self.module.keys() == {"action_type", "edge_index", "features"} + #super().__init__(module, preprocessor, is_backward) + nn.Module.__init__(self) + assert isinstance(module, nn.ModuleDict) + assert module.keys() == {"action_type", "edge_index", "features"} + self.module = module + self.is_backward = is_backward def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]: """Forward pass of the module. 
@@ -482,8 +485,7 @@ def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]: edge_index_logits = self.module["edge_index"](states) features = self.module["features"](states) - assert action_type_logits == len(GraphActionType) - assert edge_index_logits.shape[-1] == 2 + assert action_type_logits.shape[-1] == len(GraphActionType) return { "action_type": action_type_logits, "edge_index": edge_index_logits, @@ -517,13 +519,14 @@ def to_probability_distribution( action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1) uniform_dist_probs = action_type_masks.float() / action_type_masks.sum(dim=-1, keepdim=True) action_type_probs = (1 - epsilon) * action_type_probs + epsilon * uniform_dist_probs + dists["action_type"] = Categorical(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) - uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] - edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs + if edge_index_logits.shape[-1] != 0: + edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) + uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] + edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs + dists["edge_index"] = UnsqueezedCategorical(probs=edge_index_probs) - dists["action_type"] = Categorical(probs=action_type_probs) dists["features"] = Normal(module_output["features"], temperature) - dists["edge_index"] = Categorical(probs=edge_index_probs) return ComposedDistribution(dists=dists) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 819620f0..8a8a112c 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -7,7 +7,7 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.modules import GFNModule -from gfn.states import States, stack_states +from gfn.states import States from gfn.utils.handlers import ( has_conditioning_exception_handler, no_conditioning_exception_handler, @@ -147,7 +147,7 @@ def sample_trajectories( if conditioning is not None: assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] - device = states.tensor.device + device = states.device dones = ( states.is_initial_state @@ -155,8 +155,8 @@ def sample_trajectories( else states.is_sink_state ) - trajectories_states: List[States] = [deepcopy(states)] - trajectories_actions: List[torch.Tensor] = [] + trajectories_states: States = deepcopy(states) + trajectories_actions: Optional[Actions] = None trajectories_logprobs: List[torch.Tensor] = [] trajectories_dones = torch.zeros( n_trajectories, dtype=torch.long, device=device @@ -206,7 +206,11 @@ def sample_trajectories( if save_logprobs: # When off_policy, actions_log_probs are None. 
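             # Only trajectories that are still active get a log-probability written
             # here; rows for already-finished trajectories keep their fill value.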
log_probs[~dones] = actions_log_probs - trajectories_actions.append(actions) + + if trajectories_actions is None: + trajectories_actions = actions + else: + trajectories_actions.extend(actions) trajectories_logprobs.append(log_probs) if self.estimator.is_backward: @@ -239,10 +243,8 @@ def sample_trajectories( states = new_states dones = dones | new_dones - trajectories_states.append(deepcopy(states)) + trajectories_states.extend(deepcopy(states)) - trajectories_states = stack_states(trajectories_states) - trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = ( torch.stack(trajectories_logprobs, dim=0) if save_logprobs else None ) diff --git a/src/gfn/states.py b/src/gfn/states.py index 5d63a41b..d1cbed9b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -5,6 +5,7 @@ from math import prod from typing import Callable, ClassVar, List, Optional, Sequence, Tuple +import numpy as np import torch from torch_geometric.data import Batch, Data @@ -478,34 +479,6 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.zeros(shape).bool() -def stack_states(states: List[States]): - """Given a list of states, stacks them along a new dimension (0).""" - state_example = states[0] # We assume all elems of `states` are the same. - - stacked_states = state_example.from_batch_shape((0, 0)) # Empty. - stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) - if state_example._log_rewards: - stacked_states._log_rewards = torch.stack( - [s._log_rewards for s in states], dim=0 - ) - - # We are dealing with a list of DiscretrStates instances. - if hasattr(state_example, "forward_masks"): - stacked_states.forward_masks = torch.stack( - [s.forward_masks for s in states], dim=0 - ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 - ) - - # Adds the trajectory dimension. - stacked_states.batch_shape = ( - stacked_states.tensor.shape[0], - ) + state_example.batch_shape - - return stacked_states - - class GraphStates(ABC): """ Base class for Graph as a state representation. 
The `GraphStates` object is a batched collection of @@ -515,8 +488,6 @@ class GraphStates(ABC): s0: ClassVar[Data] sf: ClassVar[Optional[Data]] - node_feature_dim: ClassVar[int] - edge_feature_dim: ClassVar[int] def __init__(self, graphs: Batch): self.data: Batch = graphs @@ -524,14 +495,15 @@ def __init__(self, graphs: Batch): self._log_rewards: float = None # TODO logic repeated from env.is_valid_action + not_empty = self.data.x is not None and self.data.x.shape[0] > 0 self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) - self.forward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.x.shape[0] > 0 - self.forward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0 - + self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty + self.forward_masks[:, GraphActionType.EXIT] = not_empty + self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) - self.backward_masks[:, GraphActionType.ADD_NODE.value] = self.data.x.shape[0] > 0 - self.backward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.edge_attr.shape[0] > 0 - self.backward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0 + self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty + self.backward_masks[:, GraphActionType.ADD_EDGE] = not_empty and self.data.edge_attr.shape[0] > 0 + self.backward_masks[:, GraphActionType.EXIT] = not_empty @classmethod def from_batch_shape( @@ -586,8 +558,8 @@ def make_random_states_graph(cls, batch_shape: int) -> Batch: data_list = [] for _ in range(batch_shape): data = Data( - x=torch.rand(cls.s0.num_nodes, cls.node_feature_dim), - edge_attr=torch.rand(cls.s0.num_edges, cls.edge_feature_dim), + x=torch.rand(cls.s0.num_nodes, cls.s0.x.shape[1]), + edge_attr=torch.rand(cls.s0.num_edges, cls.s0.edge_attr.shape[1]), edge_index=cls.s0.edge_index, # TODO: make it random ) data_list.append(data) @@ -599,13 +571,18 @@ def __len__(self): def __repr__(self): return ( f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " - f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" + f"node feature dim {self.s0.x.shape[1]} and edge feature dim {self.s0.edge_attr.shape[1]}" ) def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: - out = self.__class__(Batch(self.data[index])) + idxs = np.arange(len(self.data))[index] + data = [] + for i in idxs: + data.append(self.data.get_example(i)) + + out = GraphStates(Batch.from_data_list(data)) if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -643,7 +620,10 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): @property def device(self) -> torch.device: - return self.data.get_example(0).x.device + sample = self.data.get_example(0).x + if sample is not None: + return sample.device + return torch.device("cuda" if torch.cuda.is_available() else "cpu") def to(self, device: torch.device) -> GraphStates: """ @@ -675,3 +655,10 @@ def log_rewards(self) -> torch.Tensor: @log_rewards.setter def log_rewards(self, log_rewards: torch.Tensor) -> None: self._log_rewards = log_rewards + + @property + def is_sink_state(self) -> torch.Tensor: + batch_dim = len(self.data.ptr) - 1 + if len(self.data.x) == 0: + return torch.zeros(batch_dim, dtype=torch.bool) + return torch.all(self.data.x == self.sf.x, dim=-1).reshape(batch_dim,) diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 5cfb9cc8..a44727b3 100644 --- a/src/gfn/utils/distributions.py +++ 
b/src/gfn/utils/distributions.py @@ -54,5 +54,12 @@ def __init__(self, dists: Dict[str, Distribution]): super().__init__() self.dists = dists - def sample(self, sample_shape: torch.Size) -> Dict[str, torch.Tensor]: - return {k: v.sample(sample_shape) for k, v in self.dists.items()} \ No newline at end of file + def sample(self, sample_shape=torch.Size()) -> Dict[str, torch.Tensor]: + return {k: v.sample(sample_shape) for k, v in self.dists.items()} + + def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: + log_probs = [ + v.log_prob(sample[k]).reshape(sample[k].shape[0], -1).sum(dim=-1) + for k, v in self.dists.items() + ] + return sum(log_probs) \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 02014cb0..598a7437 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -1,7 +1,12 @@ from typing import Literal, Tuple import pytest +import torch +from torch import nn +from torch_geometric.nn import GCNConv +from torch_geometric.data import Batch +from gfn.actions import GraphActionType from gfn.containers import Trajectories from gfn.containers.replay_buffer import ReplayBuffer from gfn.gym import Box, DiscreteEBM, HyperGrid @@ -9,9 +14,8 @@ from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator from gfn.samplers import Sampler +from gfn.states import GraphStates from gfn.utils.modules import MLP -from torch_geometric.nn import GCNConv - def trajectory_sampling_with_return( env_name: str, @@ -218,9 +222,65 @@ def test_replay_buffer( raise ValueError(f"Error while testing {env_name}") from e +# ------ GRAPH TESTS ------ + + +class ActionTypeNet(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.conv = GCNConv(feature_dim, len(GraphActionType)) + + def forward(self, states: GraphStates) -> torch.Tensor: + if len(states.data.x) == 0: + out = torch.zeros((len(states), len(GraphActionType))) + out[:, GraphActionType.ADD_NODE] = 1 + return out + + x = self.conv(states.data.x, states.data.edge_index) + return torch.mean(x, dim=0) + +class FeaturesNet(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.feature_dim = feature_dim + self.conv = GCNConv(feature_dim, feature_dim) + + def forward(self, states: GraphStates) -> torch.Tensor: + if len(states.data.x) == 0: + return torch.zeros((len(states), self.feature_dim)) + x = self.conv(states.data.x, states.data.edge_index) + x = x.reshape(len(states), -1, x.shape[-1]).mean(dim=0) + return x + +class EdgeIndexNet(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.conv = GCNConv(feature_dim, 8) + + def forward(self, states: GraphStates) -> torch.Tensor: + x = self.conv(states.data.x, states.data.edge_index) + return torch.einsum("nf,mf->nm", x, x) + def test_graph_building(): - node_feature_dim = 8 - env = GraphBuilding(node_feature_dim=node_feature_dim, edge_feature_dim=4) - graph_net = GCNConv(node_feature_dim, 1) + feature_dim = 8 + env = GraphBuilding(node_feature_dim=feature_dim, edge_feature_dim=feature_dim) + + action_type_net = ActionTypeNet(feature_dim) + features_net = FeaturesNet(feature_dim) + edge_index = EdgeIndexNet(feature_dim) + module = nn.ModuleDict({ + "action_type": action_type_net, + "features": features_net, + "edge_index": edge_index + }) + + pf_estimator = 
GraphActionPolicyEstimator(module=module) + + sampler = Sampler(estimator=pf_estimator) + trajectories = sampler.sample_trajectories( + env, + n=5, + save_logprobs=True, + save_estimator_outputs=True, + ) - GraphActionPolicyEstimator(module=graph_net) From 5e64c84b8d3779d291c0a42100c8b14d055e4ce3 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 26 Nov 2024 11:48:39 +0100 Subject: [PATCH 011/102] use TensorDict --- src/gfn/actions.py | 45 ++++++++------- src/gfn/gym/graph_building.py | 37 ++++++------ src/gfn/modules.py | 48 ++++++++-------- src/gfn/samplers.py | 12 ++-- src/gfn/states.py | 14 +++-- src/gfn/utils/distributions.py | 7 ++- testing/test_environments.py | 22 ++++---- testing/test_samplers_and_trajectories.py | 69 ++++++++++------------- 8 files changed, 130 insertions(+), 124 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 03a1bfdc..49ac974c 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -1,7 +1,7 @@ from __future__ import annotations # This allows to use the class name in type hints -from abc import ABC import enum +from abc import ABC from math import prod from typing import ClassVar, Optional, Sequence @@ -178,13 +178,17 @@ class GraphActionType(enum.IntEnum): class GraphActions: - nodes_features_dim: ClassVar[int] edge_features_dim: ClassVar[int] - def __init__(self, action_type: torch.Tensor, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None): + def __init__( + self, + action_type: torch.Tensor, + features: Optional[torch.Tensor] = None, + edge_index: Optional[torch.Tensor] = None, + ): """Initializes a GraphAction object. - + Args: action: a GraphActionType indicating the type of action. features: a tensor of shape (batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type. @@ -210,7 +214,7 @@ def __init__(self, action_type: torch.Tensor, features: Optional[torch.Tensor] = assert features_dim == self.edge_features_dim assert edge_index is not None assert edge_index.shape == (2, batch_dim) - + self.features = features self.edge_index = edge_index @@ -236,16 +240,14 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActio edge_index = self.edge_index[index] if self.edge_index is not None else None return GraphActions(self.action_type, features, edge_index) - def __setitem__(self, index: int | Sequence[int] | Sequence[bool], action: GraphActions) -> None: + def __setitem__( + self, index: int | Sequence[int] | Sequence[bool], action: GraphActions + ) -> None: """Set particular actions of the batch.""" - assert self.action_type == action.action_type - if self.action_type != GraphActionType.EXIT: - assert self.features is not None - self.features[index] = action.features - if self.action_type == GraphActionType.ADD_EDGE: - assert self.edge_index is not None - self.edge_index[index] = action.edge_index - + self.action_type[index] = action.action_type + self.features[index] = action.features + self.edge_index[index] = action.edge_index + def compare(self, other: GraphActions) -> torch.Tensor: """Compares the actions to another GraphAction object. 
@@ -265,15 +267,20 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full((1,), self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) + return torch.full( + (1,), + self.action_type == GraphActionType.EXIT, + dtype=torch.bool, + device=self.device, + ) @classmethod - def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: # TODO: remove make_dummy_actions + def make_dummy_actions( + cls, batch_shape: tuple[int] + ) -> GraphActions: # TODO: remove make_dummy_actions """Creates an Actions object of dummy actions with the given batch shape.""" return GraphActions( action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), features=None, - edge_index=None + edge_index=None, ) - - diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index e33fafea..f8a4f624 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -24,7 +24,7 @@ def __init__( edge_index=torch.zeros((2, 0), dtype=torch.long), ).to(device_str) sf = Data( - x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float('inf'), + x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float("inf"), ).to(device_str) if state_evaluator is None: @@ -66,11 +66,12 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: graphs.edge_index = actions.edge_index else: graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features]) - graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index], dim=1) + graphs.edge_index = torch.cat( + [graphs.edge_index, actions.edge_index], dim=1 + ) return self.States(graphs) - def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -88,13 +89,14 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat if actions.action_type == GraphActionType.ADD_NODE: assert graphs.x is not None is_equal = torch.any( - torch.all(graphs.x[:, None] == actions.features, dim=-1), - dim=-1 + torch.all(graphs.x[:, None] == actions.features, dim=-1), dim=-1 ) graphs.x = graphs.x[~is_equal] elif actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - is_equal = torch.all(graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0) + is_equal = torch.all( + graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 + ) is_equal = torch.any(is_equal, dim=0) graphs.edge_attr = graphs.edge_attr[~is_equal] graphs.edge_index = graphs.edge_index[:, ~is_equal] @@ -106,13 +108,13 @@ def is_action_valid( ) -> bool: if actions.action_type == GraphActionType.EXIT: return True # TODO: what are the conditions for exit action? - + if actions.action_type == GraphActionType.ADD_NODE: if actions.edge_index is not None: return False if states.data.x is None: return not backward - + equal_nodes_per_batch = torch.all( states.data.x == actions.features[:, None], dim=-1 ).reshape(states.data.batch_size, -1) @@ -121,7 +123,7 @@ def is_action_valid( if backward: # TODO: check if no edge are connected? 
return torch.all(equal_nodes_per_batch == 1) return torch.all(equal_nodes_per_batch == 0) - + if actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None if torch.any(actions.edge_index[0] == actions.edge_index[1]): @@ -130,9 +132,9 @@ def is_action_valid( return False if torch.any(actions.edge_index > states.data.num_nodes): return False - + batch_dim = actions.features.shape[0] - batch_idx = actions.edge_index % batch_dim + batch_idx = actions.edge_index % batch_dim if torch.any(batch_idx != torch.arange(batch_dim)): return False if states.data.edge_attr is None: @@ -142,15 +144,18 @@ def is_action_valid( states.data.edge_attr == actions.features[:, None], dim=-1 ).reshape(states.data.batch_size, -1) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) - + equal_edges_per_batch_index = torch.all( states.data.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 ).reshape(states.data.batch_size, -1) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: - return torch.all(equal_edges_per_batch_attr == 1) and torch.all(equal_edges_per_batch_index == 1) - return torch.all(equal_edges_per_batch_attr == 0) and torch.all(equal_edges_per_batch_index == 0) - + return torch.all(equal_edges_per_batch_attr == 1) and torch.all( + equal_edges_per_batch_index == 1 + ) + return torch.all(equal_edges_per_batch_attr == 0) and torch.all( + equal_edges_per_batch_index == 0 + ) def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. @@ -186,4 +191,4 @@ def __init__(self, num_features): def __call__(self, batch: Batch) -> torch.Tensor: out = self.net(batch.x, batch.edge_index) out = out.reshape(batch.batch_size, -1) - return out.mean(-1) \ No newline at end of file + return out.mean(-1) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index adbcc1c3..90087e5a 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -1,8 +1,9 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Any, Dict import torch import torch.nn as nn +from tensordict import TensorDict from torch.distributions import Categorical, Distribution, Normal from gfn.actions import GraphActionType @@ -277,7 +278,7 @@ def to_probability_distribution( # LogEdgeFlows are greedy, as are most P_B. else: - return UnsqueezedCategorical(logits=logits) + return UnsqueezedCategorical(logits=logits) class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator): @@ -457,7 +458,7 @@ class GraphActionPolicyEstimator(GFNModule): def __init__( self, - module: nn.ModuleDict, + module: nn.Module, # preprocessor: Preprocessor | None = None, is_backward: bool = False, ): @@ -466,14 +467,12 @@ def __init__( Args: is_backward: if False, then this is a forward policy, else backward policy. """ - #super().__init__(module, preprocessor, is_backward) + # super().__init__(module, preprocessor, is_backward) nn.Module.__init__(self) - assert isinstance(module, nn.ModuleDict) - assert module.keys() == {"action_type", "edge_index", "features"} self.module = module self.is_backward = is_backward - - def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]: + + def forward(self, states: GraphStates) -> TensorDict: """Forward pass of the module. Args: @@ -481,16 +480,7 @@ def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]: Returns the . 
""" - action_type_logits = self.module["action_type"](states) - edge_index_logits = self.module["edge_index"](states) - features = self.module["features"](states) - - assert action_type_logits.shape[-1] == len(GraphActionType) - return { - "action_type": action_type_logits, - "edge_index": edge_index_logits, - "features": features - } + return self.module(states) def to_probability_distribution( self, @@ -510,22 +500,32 @@ def to_probability_distribution( if set to 1.0 (default), in which case it's on policy. epsilon: with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy.""" - + dists = {} action_type_logits = module_output["action_type"] - action_type_masks = states.backward_masks if self.is_backward else states.forward_masks + action_type_masks = ( + states.backward_masks if self.is_backward else states.forward_masks + ) action_type_logits[~action_type_masks] = -float("inf") action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1) - uniform_dist_probs = action_type_masks.float() / action_type_masks.sum(dim=-1, keepdim=True) - action_type_probs = (1 - epsilon) * action_type_probs + epsilon * uniform_dist_probs + uniform_dist_probs = action_type_masks.float() / action_type_masks.sum( + dim=-1, keepdim=True + ) + action_type_probs = ( + 1 - epsilon + ) * action_type_probs + epsilon * uniform_dist_probs dists["action_type"] = Categorical(probs=action_type_probs) edge_index_logits = module_output["edge_index"] if edge_index_logits.shape[-1] != 0: edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) - uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] - edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs + uniform_dist_probs = ( + torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] + ) + edge_index_probs = ( + 1 - epsilon + ) * edge_index_probs + epsilon * uniform_dist_probs dists["edge_index"] = UnsqueezedCategorical(probs=edge_index_probs) dists["features"] = Normal(module_output["features"], temperature) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 8a8a112c..4b706a39 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -193,12 +193,12 @@ def sample_trajectories( if estimator_outputs is not None: # Place estimator outputs into a stackable tensor. Note that this # will be replaced with torch.nested.nested_tensor in the future. - estimator_outputs_padded = torch.full( - (n_trajectories,) + estimator_outputs.shape[1:], + estimator_outputs_padded = torch.full_like( + estimator_outputs.expand( + (n_trajectories,) + estimator_outputs.shape[1:] + ), fill_value=-float("inf"), - dtype=torch.float, - device=device, - ) + ).clone() # TODO: inefficient estimator_outputs_padded[~dones] = estimator_outputs all_estimator_outputs.append(estimator_outputs_padded) @@ -206,7 +206,7 @@ def sample_trajectories( if save_logprobs: # When off_policy, actions_log_probs are None. 
log_probs[~dones] = actions_log_probs - + if trajectories_actions is None: trajectories_actions = actions else: diff --git a/src/gfn/states.py b/src/gfn/states.py index d1cbed9b..b6c9bfb7 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence, Tuple +from typing import Callable, ClassVar, Optional, Sequence, Tuple import numpy as np import torch @@ -499,10 +499,12 @@ def __init__(self, graphs: Batch): self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty self.forward_masks[:, GraphActionType.EXIT] = not_empty - + self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty - self.backward_masks[:, GraphActionType.ADD_EDGE] = not_empty and self.data.edge_attr.shape[0] > 0 + self.backward_masks[:, GraphActionType.ADD_EDGE] = ( + not_empty and self.data.edge_attr.shape[0] > 0 + ) self.backward_masks[:, GraphActionType.EXIT] = not_empty @classmethod @@ -655,10 +657,12 @@ def log_rewards(self) -> torch.Tensor: @log_rewards.setter def log_rewards(self, log_rewards: torch.Tensor) -> None: self._log_rewards = log_rewards - + @property def is_sink_state(self) -> torch.Tensor: batch_dim = len(self.data.ptr) - 1 if len(self.data.x) == 0: return torch.zeros(batch_dim, dtype=torch.bool) - return torch.all(self.data.x == self.sf.x, dim=-1).reshape(batch_dim,) + return torch.all(self.data.x == self.sf.x, dim=-1).reshape( + batch_dim, + ) diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index a44727b3..421a04e9 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -1,6 +1,7 @@ from typing import Dict + import torch -from torch.distributions import Distribution, Categorical +from torch.distributions import Categorical, Distribution class UnsqueezedCategorical(Categorical): @@ -56,10 +57,10 @@ def __init__(self, dists: Dict[str, Distribution]): def sample(self, sample_shape=torch.Size()) -> Dict[str, torch.Tensor]: return {k: v.sample(sample_shape) for k, v in self.dists.items()} - + def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: log_probs = [ v.log_prob(sample[k]).reshape(sample[k].shape[0], -1).sum(dim=-1) for k, v in self.dists.items() ] - return sum(log_probs) \ No newline at end of file + return sum(log_probs) diff --git a/testing/test_environments.py b/testing/test_environments.py index e157f1bb..80198fd2 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -334,7 +334,7 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), ) states = env.step(states, actions) @@ -344,7 +344,7 @@ def test_graph_env(): torch.rand((BATCH_SIZE, FEATURE_DIM)), ) states = env.step(states, actions) - + assert states.data.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): @@ -360,17 +360,17 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([edge_index, edge_index]) + torch.stack([edge_index, edge_index]), ) states = env.step(states, actions) for i in range(NUM_NODES - 1): - node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) + 
node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) node_js = torch.arange((i + 1) * BATCH_SIZE, (i + 2) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([node_is, node_js]) + torch.stack([node_is, node_js]), ) states = env.step(states, actions) @@ -379,7 +379,7 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index.T + edge_index.T, ) states = env.step(states, actions) @@ -401,15 +401,15 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, states.data.edge_attr[edge_idx], - states.data.edge_index[:, edge_idx] + states.data.edge_index[:, edge_idx], ) states = env.backward_step(states, actions) - + with pytest.raises(NonValidActionsError): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), ) states = env.backward_step(states, actions) @@ -420,7 +420,7 @@ def test_graph_env(): states.data.x[edge_idx], ) states = env.backward_step(states, actions) - + assert states.data.x.shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): @@ -428,4 +428,4 @@ def test_graph_env(): GraphActionType.ADD_NODE, torch.rand((BATCH_SIZE, FEATURE_DIM)), ) - states = env.backward_step(states, actions) \ No newline at end of file + states = env.backward_step(states, actions) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 598a7437..302f31b3 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -2,9 +2,9 @@ import pytest import torch +from tensordict import TensorDict from torch import nn from torch_geometric.nn import GCNConv -from torch_geometric.data import Batch from gfn.actions import GraphActionType from gfn.containers import Trajectories @@ -17,6 +17,7 @@ from gfn.states import GraphStates from gfn.utils.modules import MLP + def trajectory_sampling_with_return( env_name: str, preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"], @@ -225,55 +226,44 @@ def test_replay_buffer( # ------ GRAPH TESTS ------ -class ActionTypeNet(nn.Module): - def __init__(self, feature_dim: int): - super().__init__() - self.conv = GCNConv(feature_dim, len(GraphActionType)) - - def forward(self, states: GraphStates) -> torch.Tensor: - if len(states.data.x) == 0: - out = torch.zeros((len(states), len(GraphActionType))) - out[:, GraphActionType.ADD_NODE] = 1 - return out - - x = self.conv(states.data.x, states.data.edge_index) - return torch.mean(x, dim=0) - -class FeaturesNet(nn.Module): +class GraphActionNet(nn.Module): def __init__(self, feature_dim: int): super().__init__() self.feature_dim = feature_dim - self.conv = GCNConv(feature_dim, feature_dim) + self.action_type_conv = GCNConv(feature_dim, len(GraphActionType)) + self.features_conv = GCNConv(feature_dim, feature_dim) + self.edge_index_conv = GCNConv(feature_dim, 8) - def forward(self, states: GraphStates) -> torch.Tensor: + def forward(self, states: GraphStates) -> TensorDict: if len(states.data.x) == 0: - return torch.zeros((len(states), self.feature_dim)) - x = self.conv(states.data.x, states.data.edge_index) - x = x.reshape(len(states), -1, x.shape[-1]).mean(dim=0) - return x - -class EdgeIndexNet(nn.Module): - def __init__(self, feature_dim: int): - super().__init__() - self.conv = GCNConv(feature_dim, 8) + action_type = 
torch.zeros((len(states), len(GraphActionType))) + action_type[:, GraphActionType.ADD_NODE] = 1 + features = torch.zeros((len(states), self.feature_dim)) + else: + action_type = self.action_type_conv(states.data.x, states.data.edge_index) + action_type = torch.mean(action_type, dim=0) + features = self.features_conv(states.data.x, states.data.edge_index) + features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=0) + + edge_index = self.edge_index_conv(states.data.x, states.data.edge_index) + edge_index = edge_index.reshape(states.batch_shape, -1, 8) + edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) + + return TensorDict( + { + "action_type": action_type, + "features": features, + "edge_index": edge_index, + }, + batch_size=states.batch_shape, + ) - def forward(self, states: GraphStates) -> torch.Tensor: - x = self.conv(states.data.x, states.data.edge_index) - return torch.einsum("nf,mf->nm", x, x) def test_graph_building(): feature_dim = 8 env = GraphBuilding(node_feature_dim=feature_dim, edge_feature_dim=feature_dim) - action_type_net = ActionTypeNet(feature_dim) - features_net = FeaturesNet(feature_dim) - edge_index = EdgeIndexNet(feature_dim) - module = nn.ModuleDict({ - "action_type": action_type_net, - "features": features_net, - "edge_index": edge_index - }) - + module = GraphActionNet(feature_dim) pf_estimator = GraphActionPolicyEstimator(module=module) sampler = Sampler(estimator=pf_estimator) @@ -283,4 +273,3 @@ def test_graph_building(): save_logprobs=True, save_estimator_outputs=True, ) - From 81f8b7142ebfa2495aaa7e30e034ca39217314fd Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 28 Nov 2024 19:40:13 +0100 Subject: [PATCH 012/102] solve some errors --- src/gfn/actions.py | 69 +++++++----------- src/gfn/env.py | 24 ++---- src/gfn/gym/graph_building.py | 89 ++++++++++++----------- src/gfn/states.py | 21 ++++-- src/gfn/utils/distributions.py | 2 +- testing/test_environments.py | 2 +- testing/test_samplers_and_trajectories.py | 5 +- 7 files changed, 99 insertions(+), 113 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 49ac974c..89119fa1 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -177,9 +177,8 @@ class GraphActionType(enum.IntEnum): EXIT = 2 -class GraphActions: - nodes_features_dim: ClassVar[int] - edge_features_dim: ClassVar[int] +class GraphActions(Actions): + features_dim: ClassVar[int] def __init__( self, @@ -197,26 +196,20 @@ def __init__( This must defined if and only if the action type is GraphActionType.AddEdge. 
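
         An illustrative construction (a sketch, not from the library; assumes
         a concrete subclass `MyGraphActions` with `features_dim = 4`):

             >>> actions = MyGraphActions(
             ...     action_type=torch.full((3,), GraphActionType.ADD_NODE),
             ...     features=torch.rand(3, 4),
             ... )
             >>> len(actions)
             3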
""" self.batch_shape = action_type.shape - assert torch.all(action_type == action_type[0]) - self.action_type = action_type[0] - if self.action_type == GraphActionType.EXIT: - assert features is None - assert edge_index is None - self.features = None - self.edge_index = None - else: - assert features is not None - batch_dim, features_dim = features.shape - assert (batch_dim,) == self.batch_shape - if self.action_type == GraphActionType.ADD_NODE: - assert features_dim == self.nodes_features_dim - elif self.action_type == GraphActionType.ADD_EDGE: - assert features_dim == self.edge_features_dim - assert edge_index is not None - assert edge_index.shape == (2, batch_dim) - - self.features = features - self.edge_index = edge_index + self.action_type = action_type + + if features is None: + assert torch.all(action_type == GraphActionType.EXIT) + features = torch.zeros((*self.batch_shape, self.features_dim)) + if edge_index is None: + assert torch.all(action_type != GraphActionType.ADD_EDGE) + edge_index = torch.zeros((2, *self.batch_shape)) + + batch_dim, _ = features.shape + assert (batch_dim,) == self.batch_shape + assert edge_index.shape == (2, batch_dim) + self.features = features + self.edge_index = edge_index def __repr__(self): return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" @@ -228,17 +221,14 @@ def device(self) -> torch.device: def __len__(self) -> int: """Returns the number of actions in the batch.""" - if self.action_type == GraphActionType.EXIT: - raise ValueError("Cannot get the length of exit actions.") - else: - assert self.features is not None - return self.features.shape[0] + return prod(self.batch_shape) def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: """Get particular actions of the batch.""" - features = self.features[index] if self.features is not None else None - edge_index = self.edge_index[index] if self.edge_index is not None else None - return GraphActions(self.action_type, features, edge_index) + action_type = self.action_type[index] + features = self.features[index] + edge_index = self.edge_index[:, index] + return GraphActions(action_type, features, edge_index) def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions @@ -246,7 +236,7 @@ def __setitem__( """Set particular actions of the batch.""" self.action_type[index] = action.action_type self.features[index] = action.features - self.edge_index[index] = action.edge_index + self.edge_index[:, index] = action.edge_index def compare(self, other: GraphActions) -> torch.Tensor: """Compares the actions to another GraphAction object. 
@@ -267,20 +257,15 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full( - (1,), - self.action_type == GraphActionType.EXIT, - dtype=torch.bool, - device=self.device, - ) + return self.action_type == GraphActionType.EXIT @classmethod def make_dummy_actions( cls, batch_shape: tuple[int] - ) -> GraphActions: # TODO: remove make_dummy_actions + ) -> GraphActions: """Creates an Actions object of dummy actions with the given batch shape.""" - return GraphActions( + return cls( action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), - features=None, - edge_index=None, + #features=torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), + #edge_index=torch.zeros((2, *batch_shape, 0)), ) diff --git a/src/gfn/env.py b/src/gfn/env.py index 8db703b3..d47e4cfa 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -264,10 +264,12 @@ def _step( not_done_actions = actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - if not isinstance(new_not_done_states_tensor, torch.Tensor): - raise Exception( - "User implemented env.step function *must* return a torch.Tensor!" - ) + + # TODO: uncomment (change Data to TensorDict) + # if not isinstance(new_not_done_states_tensor, torch.Tensor): + # raise Exception( + # "User implemented env.step function *must* return a torch.Tensor!" + # ) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor @@ -568,8 +570,6 @@ class GraphEnv(Env): def __init__( self, s0: Data, - # node_feature_dim: int, - # edge_feature_dim: int, sf: Optional[Data] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -578,8 +578,6 @@ def __init__( Args: s0: The initial graph state. - node_feature_dim: The dimension of the node features. - edge_feature_dim: The dimension of the edge features. action_shape: Tuple representing the shape of the actions. dummy_action: Tensor of shape "action_shape" representing a dummy action. exit_action: Tensor of shape "action_shape" representing the exit action. @@ -591,10 +589,7 @@ def __init__( the IdentityPreprocessor is used. 
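
         A sketch of a typical `s0` for this signature (an empty PyG graph,
         mirroring the one `GraphBuilding` constructs; `MyGraphEnv` stands in
         for a concrete subclass implementing `step` and `backward_step`):

             >>> s0 = Data(
             ...     x=torch.zeros((0, 4), dtype=torch.float32),
             ...     edge_attr=torch.zeros((0, 4), dtype=torch.float32),
             ...     edge_index=torch.zeros((2, 0), dtype=torch.long),
             ... )
             >>> env = MyGraphEnv(s0=s0)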
""" self.s0 = s0.to(device_str) - - self.node_feature_dim = s0.x.shape[1] - self.edge_feature_dim = s0.edge_attr.shape[1] - + self.features_dim = s0.x.shape[1] self.sf = sf self.States = self.make_states_class() @@ -609,8 +604,6 @@ def make_states_class(self) -> type[GraphStates]: class GraphEnvStates(GraphStates): s0 = env.s0 sf = env.sf - # node_feature_dim = env.node_feature_dim - # edge_feature_dim = env.edge_feature_dim make_random_states_graph = env.make_random_states_tensor return GraphEnvStates @@ -625,8 +618,7 @@ def make_actions_class(self) -> type[GraphActions]: env = self class DefaultGraphAction(GraphActions): - nodes_features_dim = env.node_feature_dim - edge_features_dim = env.edge_feature_dim + features_dim = env.features_dim return DefaultGraphAction diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index f8a4f624..67c6cbae 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -13,22 +13,21 @@ class GraphBuilding(GraphEnv): def __init__( self, - node_feature_dim: int, - edge_feature_dim: int, + feature_dim: int, state_evaluator: Callable[[Batch], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): s0 = Data( - x=torch.zeros((0, node_feature_dim), dtype=torch.float32), - edge_attr=torch.zeros((0, edge_feature_dim), dtype=torch.float32), + x=torch.zeros((0, feature_dim), dtype=torch.float32), + edge_attr=torch.zeros((0, feature_dim), dtype=torch.float32), edge_index=torch.zeros((2, 0), dtype=torch.long), ).to(device_str) sf = Data( - x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float("inf"), + x=torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), ).to(device_str) if state_evaluator is None: - state_evaluator = GCNConvEvaluator(node_feature_dim) + state_evaluator = GCNConvEvaluator(feature_dim) self.state_evaluator = state_evaluator super().__init__( @@ -37,7 +36,7 @@ def __init__( device_str=device_str, ) - def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: + def step(self, states: GraphStates, actions: GraphActions) -> Data: """Step function for the GraphBuilding environment. Args: @@ -50,14 +49,17 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: raise NonValidActionsError("Invalid action.") graphs: Batch = deepcopy(states.data) - if actions.action_type == GraphActionType.ADD_NODE: + action_type = actions.action_type[0] + assert torch.all(actions.action_type == action_type) + + if action_type == GraphActionType.ADD_NODE: assert len(graphs) == len(actions) if graphs.x is None: graphs.x = actions.features else: graphs.x = torch.cat([graphs.x, actions.features]) - if actions.action_type == GraphActionType.ADD_EDGE: + if action_type == GraphActionType.ADD_EDGE: assert len(graphs) == len(actions) assert actions.edge_index is not None if graphs.edge_attr is None: @@ -70,7 +72,7 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: [graphs.edge_index, actions.edge_index], dim=1 ) - return self.States(graphs) + return graphs def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -106,56 +108,57 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False ) -> bool: - if actions.action_type == GraphActionType.EXIT: - return True # TODO: what are the conditions for exit action? 
- - if actions.action_type == GraphActionType.ADD_NODE: - if actions.edge_index is not None: - return False - if states.data.x is None: - return not backward - + add_node_mask = actions.action_type == GraphActionType.ADD_NODE + if not torch.any(add_node_mask): + add_node_out = True + else: equal_nodes_per_batch = torch.all( - states.data.x == actions.features[:, None], dim=-1 + states[add_node_mask].data.x == actions[add_node_mask].features[:, None], dim=-1 ).reshape(states.data.batch_size, -1) equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) - if backward: # TODO: check if no edge are connected? - return torch.all(equal_nodes_per_batch == 1) - return torch.all(equal_nodes_per_batch == 0) - - if actions.action_type == GraphActionType.ADD_EDGE: - assert actions.edge_index is not None - if torch.any(actions.edge_index[0] == actions.edge_index[1]): + add_node_out = torch.all(equal_nodes_per_batch == 1) + else: + add_node_out = torch.all(equal_nodes_per_batch == 0) + + add_edge_mask = actions.action_type == GraphActionType.ADD_EDGE + if not torch.any(add_edge_mask): + add_edge_out = True + else: + add_edge_states = states[add_edge_mask] + add_edge_actions = actions[add_edge_mask] + + if torch.any(add_edge_actions.edge_index[0] == add_edge_actions.edge_index[1]): return False - if states.data.num_nodes is None or states.data.num_nodes == 0: + if add_edge_states.data.num_nodes == 0: return False - if torch.any(actions.edge_index > states.data.num_nodes): + if torch.any(add_edge_actions.edge_index > add_edge_states.data.num_nodes): return False - batch_dim = actions.features.shape[0] - batch_idx = actions.edge_index % batch_dim + batch_dim = add_edge_actions.features.shape[0] + batch_idx = add_edge_actions.edge_index % batch_dim if torch.any(batch_idx != torch.arange(batch_dim)): return False - if states.data.edge_attr is None: - return True equal_edges_per_batch_attr = torch.all( - states.data.edge_attr == actions.features[:, None], dim=-1 - ).reshape(states.data.batch_size, -1) + add_edge_states.data.edge_attr == add_edge_actions.features[:, None], dim=-1 + ).reshape(add_edge_states.data.batch_size, -1) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) - equal_edges_per_batch_index = torch.all( - states.data.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 - ).reshape(states.data.batch_size, -1) + add_edge_states.data.edge_index[:, None] == add_edge_actions.edge_index[:, :, None], dim=0 + ).reshape(add_edge_states.data.batch_size, -1) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) + if backward: - return torch.all(equal_edges_per_batch_attr == 1) and torch.all( + add_edge_out = torch.all(equal_edges_per_batch_attr == 1) and torch.all( equal_edges_per_batch_index == 1 ) - return torch.all(equal_edges_per_batch_attr == 0) and torch.all( - equal_edges_per_batch_index == 0 - ) + else: + add_edge_out = torch.all(equal_edges_per_batch_attr == 0) and torch.all( + equal_edges_per_batch_index == 0 + ) + + return bool(add_node_out) and bool(add_edge_out) def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. 
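
By default the reward comes from `GCNConvEvaluator` (updated at the end of
this file's diff), but any callable over a PyG `Batch` can be passed as
`state_evaluator`. A rough, illustrative sketch (the evaluator name and
`feature_dim = 4` are made up for the example):

    import torch
    from torch_geometric.data import Batch

    def node_count_evaluator(batch: Batch) -> torch.Tensor:
        # one scalar per graph in the batch: here, simply its node count
        counts = [graph.num_nodes for graph in batch.to_data_list()]
        return torch.tensor(counts, dtype=torch.float32)

    # env = GraphBuilding(feature_dim=4, state_evaluator=node_count_evaluator)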
@@ -180,7 +183,7 @@ def true_dist_pmf(self) -> torch.Tensor: raise NotImplementedError def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: - """Generates random states tensor of shape (*batch_shape, num_nodes, node_feature_dim).""" + """Generates random states tensor of shape (*batch_shape, feature_dim).""" return self.States.from_batch_shape(batch_shape) diff --git a/src/gfn/states.py b/src/gfn/states.py index b6c9bfb7..8e5ec735 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -491,16 +491,16 @@ class GraphStates(ABC): def __init__(self, graphs: Batch): self.data: Batch = graphs - self.batch_shape: int = len(self.data) + self.batch_shape: tuple = (len(self.data),) self._log_rewards: float = None # TODO logic repeated from env.is_valid_action not_empty = self.data.x is not None and self.data.x.shape[0] > 0 - self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) + self.forward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty self.forward_masks[:, GraphActionType.EXIT] = not_empty - self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) + self.backward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty self.backward_masks[:, GraphActionType.ADD_EDGE] = ( not_empty and self.data.edge_attr.shape[0] > 0 @@ -580,10 +580,13 @@ def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: idxs = np.arange(len(self.data))[index] - data = [] - for i in idxs: - data.append(self.data.get_example(i)) - + data = [self.data.get_example(i) for i in idxs] + if len(data) == 0: + data.append(Data( + x=torch.zeros((0, self.data.x.shape[1]), dtype=torch.float32), + edge_attr=torch.zeros((0, self.data.edge_attr.shape[1]), dtype=torch.float32), + edge_index=torch.zeros((2, 0), dtype=torch.long), + )) out = GraphStates(Batch.from_data_list(data)) if self._log_rewards is not None: @@ -666,3 +669,7 @@ def is_sink_state(self) -> torch.Tensor: return torch.all(self.data.x == self.sf.x, dim=-1).reshape( batch_dim, ) + + @property + def tensor(self) -> Batch: + return self.data diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 421a04e9..1fa6aa89 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -43,7 +43,7 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor: return super().log_prob(sample.squeeze(-1)) -class ComposedDistribution(Distribution): +class ComposedDistribution(Distribution): # TODO: CompositeDistribution in TensorDict """A mixture distribution.""" def __init__(self, dists: Dict[str, Distribution]): diff --git a/testing/test_environments.py b/testing/test_environments.py index 80198fd2..948742bf 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -325,7 +325,7 @@ def test_graph_env(): BATCH_SIZE = 3 NUM_NODES = 5 - env = GraphBuilding(node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM) + env = GraphBuilding(feature_dim=FEATURE_DIM) states = env.reset(batch_shape=BATCH_SIZE) assert states.batch_shape == BATCH_SIZE action_cls = env.make_actions_class() diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 302f31b3..d8a1c260 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -225,7 +225,6 @@ def test_replay_buffer( # ------ GRAPH TESTS ------ - class 
GraphActionNet(nn.Module): def __init__(self, feature_dim: int): super().__init__() @@ -246,7 +245,7 @@ def forward(self, states: GraphStates) -> TensorDict: features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=0) edge_index = self.edge_index_conv(states.data.x, states.data.edge_index) - edge_index = edge_index.reshape(states.batch_shape, -1, 8) + edge_index = edge_index.reshape(*states.batch_shape, -1, 8) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) return TensorDict( @@ -261,7 +260,7 @@ def forward(self, states: GraphStates) -> TensorDict: def test_graph_building(): feature_dim = 8 - env = GraphBuilding(node_feature_dim=feature_dim, edge_feature_dim=feature_dim) + env = GraphBuilding(feature_dim=feature_dim) module = GraphActionNet(feature_dim) pf_estimator = GraphActionPolicyEstimator(module=module) From 34781efbd1cd420676682e62bf8d765996218fe3 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 28 Nov 2024 22:25:15 +0100 Subject: [PATCH 013/102] use tensordict in actions --- src/gfn/actions.py | 57 +++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 89119fa1..400a9480 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -6,6 +6,7 @@ from typing import ClassVar, Optional, Sequence import torch +from tensordict import TensorDict class Actions(ABC): @@ -196,8 +197,6 @@ def __init__( This must defined if and only if the action type is GraphActionType.AddEdge. """ self.batch_shape = action_type.shape - self.action_type = action_type - if features is None: assert torch.all(action_type == GraphActionType.EXIT) features = torch.zeros((*self.batch_shape, self.features_dim)) @@ -205,19 +204,19 @@ def __init__( assert torch.all(action_type != GraphActionType.ADD_EDGE) edge_index = torch.zeros((2, *self.batch_shape)) - batch_dim, _ = features.shape - assert (batch_dim,) == self.batch_shape - assert edge_index.shape == (2, batch_dim) - self.features = features - self.edge_index = edge_index + self.tensor = TensorDict({ + "action_type": action_type, + "features": features, + "edge_index": edge_index.T, + }, batch_size=self.batch_shape) def __repr__(self): - return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" + return f"""GraphAction object with {self.batch_shape} actions.""" @property def device(self) -> torch.device: """Returns the device of the features tensor.""" - return self.features.device + return self.tensor.device def __len__(self) -> int: """Returns the number of actions in the batch.""" @@ -225,18 +224,18 @@ def __len__(self) -> int: def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: """Get particular actions of the batch.""" - action_type = self.action_type[index] - features = self.features[index] - edge_index = self.edge_index[:, index] - return GraphActions(action_type, features, edge_index) + tensor = self.tensor[index] + return GraphActions( + tensor["action_type"], + tensor["features"], + tensor["edge_index"].T + ) def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions ) -> None: """Set particular actions of the batch.""" - self.action_type[index] = action.action_type - self.features[index] = action.features - self.edge_index[:, index] = action.edge_index + self.tensor[index] = action.tensor def compare(self, other: GraphActions) -> torch.Tensor: """Compares the actions to another GraphAction object. 
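
Packing all three fields into a single `TensorDict` (as `__init__` now does
above) keeps them aligned under one batch dimension. A sketch of the
resulting layout for a hypothetical subclass with `features_dim = 4` and
batch shape `(3,)`:

    actions.tensor["action_type"]   # shape (3,): integer GraphActionType values
    actions.tensor["features"]      # shape (3, 4): one feature vector per action
    actions.tensor["edge_index"]    # shape (3, 2): stored transposed, hence .T below
    subset = actions[[0, 2]]        # TensorDict indexing keeps all fields aligned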
@@ -246,19 +245,39 @@ def compare(self, other: GraphActions) -> torch.Tensor:
 
         Returns: boolean tensor of shape batch_shape indicating whether the actions are equal.
         """
-        if self.action_type != other.action_type:
-            len_ = self.features.shape[0] if self.features is not None else 1
-            return torch.zeros(len_, dtype=torch.bool, device=self.device)
-        out = torch.all(self.features == other.features, dim=-1)
-        if self.edge_index is not None:
-            out &= torch.all(self.edge_index == other.edge_index, dim=-1)
-        return out
+        same_type = self.action_type == other.action_type
+        same_features = torch.all(self.features == other.features, dim=-1)
+        same_edges = torch.all(
+            self.tensor["edge_index"] == other.tensor["edge_index"], dim=-1
+        )
+        # Features only have to match for non-EXIT actions, and edge indices
+        # only for ADD_EDGE actions.
+        return (
+            same_type
+            & ((self.action_type == GraphActionType.EXIT) | same_features)
+            & ((self.action_type != GraphActionType.ADD_EDGE) | same_edges)
+        )
 
     @property
     def is_exit(self) -> torch.Tensor:
         """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
         return self.action_type == GraphActionType.EXIT
 
+    @property
+    def action_type(self) -> torch.Tensor:
+        """Returns the action type tensor."""
+        return self.tensor["action_type"]
+
+    @property
+    def features(self) -> torch.Tensor:
+        """Returns the features tensor."""
+        return self.tensor["features"]
+
+    @property
+    def edge_index(self) -> torch.Tensor:
+        """Returns the edge index tensor."""
+        return self.tensor["edge_index"].T
+
     @classmethod
     def make_dummy_actions(
         cls, batch_shape: tuple[int]
From 3e584f2168eec0a6688b0ccbba7f1b00bd68549f Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Mon, 2 Dec 2024 19:50:59 +0100
Subject: [PATCH 014/102] handle sf

---
 src/gfn/gym/graph_building.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index 67c6cbae..f3448068 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -52,6 +52,9 @@ def step(self, states: GraphStates, actions: GraphActions) -> Data:
         action_type = actions.action_type[0]
         assert torch.all(actions.action_type == action_type)
 
+        if action_type == GraphActionType.EXIT:
+            return self.sf  # TODO: not possible to backtrack then... maybe a boolean flag in the state?
+ if action_type == GraphActionType.ADD_NODE: assert len(graphs) == len(actions) if graphs.x is None: @@ -72,6 +75,7 @@ def step(self, states: GraphStates, actions: GraphActions) -> Data: [graphs.edge_index, actions.edge_index], dim=1 ) + import pdb; pdb.set_trace() return graphs def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: From d5e438f80444f267f647b8af2a55dd8a2071b951 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 3 Dec 2024 15:27:40 +0100 Subject: [PATCH 015/102] remove Data --- src/gfn/actions.py | 8 +- src/gfn/env.py | 26 ++-- src/gfn/gym/graph_building.py | 90 +++++++------- src/gfn/modules.py | 6 +- src/gfn/samplers.py | 2 +- src/gfn/states.py | 142 +++++++++------------- src/gfn/utils/distributions.py | 24 ++++ testing/test_samplers_and_trajectories.py | 18 ++- 8 files changed, 154 insertions(+), 162 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 400a9480..36c52ae5 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -202,12 +202,12 @@ def __init__( features = torch.zeros((*self.batch_shape, self.features_dim)) if edge_index is None: assert torch.all(action_type != GraphActionType.ADD_EDGE) - edge_index = torch.zeros((2, *self.batch_shape)) + edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) self.tensor = TensorDict({ "action_type": action_type, "features": features, - "edge_index": edge_index.T, + "edge_index": edge_index, }, batch_size=self.batch_shape) def __repr__(self): @@ -228,7 +228,7 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActio return GraphActions( tensor["action_type"], tensor["features"], - tensor["edge_index"].T + tensor["edge_index"] ) def __setitem__( @@ -268,7 +268,7 @@ def features(self) -> torch.Tensor: @property def edge_index(self) -> torch.Tensor: """Returns the edge index tensor.""" - return self.tensor["edge_index"].T + return self.tensor["edge_index"] @classmethod def make_dummy_actions( diff --git a/src/gfn/env.py b/src/gfn/env.py index d47e4cfa..e09c5b79 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,7 +2,7 @@ from typing import Dict, Optional, Tuple, Union import torch -from torch_geometric.data import Batch, Data +from tensordict import TensorDict from gfn.actions import Actions, GraphActions from gfn.preprocessors import IdentityPreprocessor, Preprocessor @@ -256,7 +256,8 @@ def _step( ) new_sink_states_idx = actions.is_exit - new_states.tensor[new_sink_states_idx] = self.sf + sf_tensor = self.States.make_sink_states_tensor(new_sink_states_idx.sum()) + new_states[new_sink_states_idx] = self.States(sf_tensor) new_sink_states_idx = ~valid_states_idx | new_sink_states_idx assert new_sink_states_idx.shape == states.batch_shape @@ -265,14 +266,12 @@ def _step( new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - # TODO: uncomment (change Data to TensorDict) - # if not isinstance(new_not_done_states_tensor, torch.Tensor): - # raise Exception( - # "User implemented env.step function *must* return a torch.Tensor!" - # ) - - new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor + if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)): + raise Exception( + "User implemented env.step function *must* return a torch.Tensor!" 
+ ) + new_states[~new_sink_states_idx] = self.States(new_not_done_states_tensor) return new_states def _backward_step( @@ -569,8 +568,8 @@ class GraphEnv(Env): def __init__( self, - s0: Data, - sf: Optional[Data] = None, + s0: TensorDict, + sf: Optional[TensorDict] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -578,9 +577,6 @@ def __init__( Args: s0: The initial graph state. - action_shape: Tuple representing the shape of the actions. - dummy_action: Tensor of shape "action_shape" representing a dummy action. - exit_action: Tensor of shape "action_shape" representing the exit action. sf: The final graph state. device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0. @@ -589,7 +585,7 @@ def __init__( the IdentityPreprocessor is used. """ self.s0 = s0.to(device_str) - self.features_dim = s0.x.shape[1] + self.features_dim = s0["node_feature"].shape[-1] self.sf = sf self.States = self.make_states_class() diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index f3448068..12fa8943 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -2,8 +2,8 @@ from typing import Callable, Literal, Tuple import torch -from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv +from tensordict import TensorDict from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError @@ -14,17 +14,19 @@ class GraphBuilding(GraphEnv): def __init__( self, feature_dim: int, - state_evaluator: Callable[[Batch], torch.Tensor] | None = None, + state_evaluator: Callable[[GraphStates], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = Data( - x=torch.zeros((0, feature_dim), dtype=torch.float32), - edge_attr=torch.zeros((0, feature_dim), dtype=torch.float32), - edge_index=torch.zeros((2, 0), dtype=torch.long), - ).to(device_str) - sf = Data( - x=torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), - ).to(device_str) + s0 = TensorDict({ + "node_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, device=device_str) + sf = TensorDict({ + "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), + "edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, device=device_str) if state_evaluator is None: state_evaluator = GCNConvEvaluator(feature_dim) @@ -36,7 +38,7 @@ def __init__( device_str=device_str, ) - def step(self, states: GraphStates, actions: GraphActions) -> Data: + def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: """Step function for the GraphBuilding environment. Args: @@ -47,36 +49,24 @@ def step(self, states: GraphStates, actions: GraphActions) -> Data: """ if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") - graphs: Batch = deepcopy(states.data) + state_tensor = deepcopy(states.tensor) action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) - if action_type == GraphActionType.EXIT: - return self.sf # TODO: not possible to backtrack then... maybe a boolen in state? 
+ return self.States.make_sink_states_tensor(states.batch_shape) if action_type == GraphActionType.ADD_NODE: - assert len(graphs) == len(actions) - if graphs.x is None: - graphs.x = actions.features - else: - graphs.x = torch.cat([graphs.x, actions.features]) + assert len(state_tensor) == len(actions) + state_tensor["node_feature"] = torch.cat([state_tensor["node_feature"], actions.features[:, None]], dim=1) if action_type == GraphActionType.ADD_EDGE: - assert len(graphs) == len(actions) - assert actions.edge_index is not None - if graphs.edge_attr is None: - graphs.edge_attr = actions.features - assert graphs.edge_index is None - graphs.edge_index = actions.edge_index - else: - graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features]) - graphs.edge_index = torch.cat( - [graphs.edge_index, actions.edge_index], dim=1 - ) - - import pdb; pdb.set_trace() - return graphs + assert len(state_tensor) == len(actions) + state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features[:, None]], dim=1) + state_tensor["edge_index"] = torch.cat( + [state_tensor["edge_index"], actions.edge_index[:, None]], dim=1 + ) + return state_tensor def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -116,9 +106,10 @@ def is_action_valid( if not torch.any(add_node_mask): add_node_out = True else: + node_feature = states.tensor["node_feature"][add_node_mask] equal_nodes_per_batch = torch.all( - states[add_node_mask].data.x == actions[add_node_mask].features[:, None], dim=-1 - ).reshape(states.data.batch_size, -1) + node_feature == actions[add_node_mask].features[:, None], dim=-1 + ).reshape(len(node_feature), -1) equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) if backward: # TODO: check if no edge are connected? 
add_node_out = torch.all(equal_nodes_per_batch == 1) @@ -129,14 +120,14 @@ def is_action_valid( if not torch.any(add_edge_mask): add_edge_out = True else: - add_edge_states = states[add_edge_mask] + add_edge_states = states[add_edge_mask].tensor add_edge_actions = actions[add_edge_mask] - if torch.any(add_edge_actions.edge_index[0] == add_edge_actions.edge_index[1]): + if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): return False - if add_edge_states.data.num_nodes == 0: + if add_edge_states["node_feature"].shape[1] == 0: return False - if torch.any(add_edge_actions.edge_index > add_edge_states.data.num_nodes): + if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[1]): return False batch_dim = add_edge_actions.features.shape[0] @@ -145,12 +136,12 @@ def is_action_valid( return False equal_edges_per_batch_attr = torch.all( - add_edge_states.data.edge_attr == add_edge_actions.features[:, None], dim=-1 - ).reshape(add_edge_states.data.batch_size, -1) + add_edge_states["edge_feature"] == add_edge_actions.features[:, None], dim=-1 + ).reshape(len(add_edge_states), -1) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) equal_edges_per_batch_index = torch.all( - add_edge_states.data.edge_index[:, None] == add_edge_actions.edge_index[:, :, None], dim=0 - ).reshape(add_edge_states.data.batch_size, -1) + add_edge_states["edge_index"] == add_edge_actions.edge_index, dim=0 + ).reshape(len(add_edge_states), -1) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: @@ -161,7 +152,7 @@ def is_action_valid( add_edge_out = torch.all(equal_edges_per_batch_attr == 0) and torch.all( equal_edges_per_batch_index == 0 ) - + return bool(add_node_out) and bool(add_edge_out) def reward(self, final_states: GraphStates) -> torch.Tensor: @@ -174,7 +165,7 @@ def reward(self, final_states: GraphStates) -> torch.Tensor: Returns: torch.Tensor: Tensor of shape "batch_shape" containing the rewards. 
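
         The evaluator now receives the `GraphStates` object itself, so a
         custom one can read `states.tensor` directly. A sketch (illustrative
         only; `mean_node_feature` is not part of the library):

             >>> def mean_node_feature(states: GraphStates) -> torch.Tensor:
             ...     return states.tensor["node_feature"].mean(dim=(1, 2))
             >>> env = GraphBuilding(feature_dim=8,
             ...                     state_evaluator=mean_node_feature)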
""" - return self.state_evaluator(final_states.data) + return self.state_evaluator(final_states) @property def log_partition(self) -> float: @@ -193,9 +184,12 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: class GCNConvEvaluator: def __init__(self, num_features): + self.num_features = num_features self.net = GCNConv(num_features, 1) - def __call__(self, batch: Batch) -> torch.Tensor: - out = self.net(batch.x, batch.edge_index) - out = out.reshape(batch.batch_size, -1) + def __call__(self, state: GraphStates) -> torch.Tensor: + node_feature = state.tensor["node_feature"].reshape(-1, self.num_features) + edge_index = state.tensor["edge_index"].reshape(-1, 2).T + out = self.net(node_feature, edge_index) + out = out.reshape(len(state), state.tensor["node_feature"].shape[1]) return out.mean(-1) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 90087e5a..9d0f52f7 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,7 +9,7 @@ from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States -from gfn.utils.distributions import ComposedDistribution, UnsqueezedCategorical +from gfn.utils.distributions import CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical REDUCTION_FXNS = { "mean": torch.mean, @@ -518,7 +518,7 @@ def to_probability_distribution( dists["action_type"] = Categorical(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - if edge_index_logits.shape[-1] != 0: + if states.tensor["node_feature"].shape[1] > 1: edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) uniform_dist_probs = ( torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] @@ -526,7 +526,7 @@ def to_probability_distribution( edge_index_probs = ( 1 - epsilon ) * edge_index_probs + epsilon * uniform_dist_probs - dists["edge_index"] = UnsqueezedCategorical(probs=edge_index_probs) + dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs) dists["features"] = Normal(module_output["features"], temperature) return ComposedDistribution(dists=dists) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 4b706a39..083d8792 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -167,7 +167,7 @@ def sample_trajectories( step = 0 all_estimator_outputs = [] - + while not all(dones): actions = env.actions_from_batch_shape((n_trajectories,)) # Dummy actions. log_probs = torch.full( diff --git a/src/gfn/states.py b/src/gfn/states.py index 8e5ec735..81762909 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -7,6 +7,7 @@ import numpy as np import torch +from tensordict import TensorDict from torch_geometric.data import Batch, Data from gfn.actions import GraphActionType @@ -486,16 +487,19 @@ class GraphStates(ABC): graph objects as states. 
""" - s0: ClassVar[Data] - sf: ClassVar[Optional[Data]] + s0: ClassVar[TensorDict] + sf: ClassVar[TensorDict] - def __init__(self, graphs: Batch): - self.data: Batch = graphs - self.batch_shape: tuple = (len(self.data),) + def __init__(self, tensor: TensorDict): + self.tensor = tensor + self.node_features_dim = tensor["node_feature"].shape[-1] + self.edge_features_dim = tensor["edge_feature"].shape[-1] + + self.batch_shape: tuple = tensor.batch_size self._log_rewards: float = None # TODO logic repeated from env.is_valid_action - not_empty = self.data.x is not None and self.data.x.shape[0] > 0 + not_empty = self.tensor["node_feature"].shape[1] > 1 self.forward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty self.forward_masks[:, GraphActionType.EXIT] = not_empty @@ -503,7 +507,7 @@ def __init__(self, graphs: Batch): self.backward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty self.backward_masks[:, GraphActionType.ADD_EDGE] = ( - not_empty and self.data.edge_attr.shape[0] > 0 + not_empty and self.tensor["edge_feature"].shape[1] > 0 > 0 ) self.backward_masks[:, GraphActionType.EXIT] = not_empty @@ -514,15 +518,15 @@ def from_batch_shape( if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") if random: - data = cls.make_random_states_graph(batch_shape) + tensor = cls.make_random_states_tensor(batch_shape) elif sink: - data = cls.make_sink_states_graph(batch_shape) + tensor = cls.make_sink_states_tensor(batch_shape) else: - data = cls.make_initial_states_graph(batch_shape) - return cls(data) + tensor = cls.make_initial_states_tensor(batch_shape) + return cls(tensor) @classmethod - def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: + def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: raise NotImplementedError( "Batch shape with more than one dimension is not supported" @@ -530,11 +534,14 @@ def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] - data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)]) - return data + return TensorDict({ + "node_feature": cls.s0["node_feature"].repeat(batch_shape, 1, 1), + "edge_feature": cls.s0["edge_feature"].repeat(batch_shape, 1, 1), + "edge_index": cls.s0["edge_index"].repeat(batch_shape, 1, 1) + }, batch_size=batch_shape) @classmethod - def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: + def make_sink_states_tensor(cls, batch_shape: Tuple) -> TensorDict: if cls.sf is None: raise NotImplementedError("Sink state is not defined") @@ -545,11 +552,14 @@ def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] - data = Batch.from_data_list([cls.sf for _ in range(batch_shape)]) - return data + return TensorDict({ + "node_feature": cls.sf["node_feature"].repeat(batch_shape, 1, 1), + "edge_feature": cls.sf["edge_feature"].repeat(batch_shape, 1, 1), + "edge_index": cls.sf["edge_index"].repeat(batch_shape, 1, 1) + }, batch_size=int(batch_shape)) @classmethod - def make_random_states_graph(cls, batch_shape: int) -> Batch: + def make_random_states_tensor(cls, batch_shape: int) -> TensorDict: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: raise NotImplementedError( "Batch shape with more than one dimension 
is not supported" @@ -557,37 +567,30 @@ def make_random_states_graph(cls, batch_shape: int) -> Batch: if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] - data_list = [] - for _ in range(batch_shape): - data = Data( - x=torch.rand(cls.s0.num_nodes, cls.s0.x.shape[1]), - edge_attr=torch.rand(cls.s0.num_edges, cls.s0.edge_attr.shape[1]), - edge_index=cls.s0.edge_index, # TODO: make it random - ) - data_list.append(data) - return Batch.from_data_list(data_list) + num_nodes = np.random.randint(10) + num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2) + node_features_dim = cls.s0["node_feature"].shape[-1] + edge_features_dim = cls.s0["edge_feature"].shape[-1] + tensor = TensorDict({ + "node_feature": torch.rand(batch_shape, num_nodes, node_features_dim), + "edge_feature": torch.rand(batch_shape, num_edges, edge_features_dim), + "edge_index": torch.randint(num_nodes, size=(batch_shape, num_edges, 2)), + }) + return tensor def __len__(self): - return self.data.batch_size + return np.prod(self.tensor.batch_size) def __repr__(self): return ( f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " - f"node feature dim {self.s0.x.shape[1]} and edge feature dim {self.s0.edge_attr.shape[1]}" + f"node feature dim {self.node_features_dim} and edge feature dim {self.edge_features_dim}" ) def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: - idxs = np.arange(len(self.data))[index] - data = [self.data.get_example(i) for i in idxs] - if len(data) == 0: - data.append(Data( - x=torch.zeros((0, self.data.x.shape[1]), dtype=torch.float32), - edge_attr=torch.zeros((0, self.data.edge_attr.shape[1]), dtype=torch.float32), - edge_index=torch.zeros((2, 0), dtype=torch.long), - )) - out = GraphStates(Batch.from_data_list(data)) + out = GraphStates(self.tensor[index]) if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -598,44 +601,21 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ Set particular states of the Batch """ - data_list = self.data.to_data_list() - if isinstance(index, int): - assert ( - len(graph) == 1 - ), "GraphStates must have a batch size of 1 for single index assignment" - data_list[index] = graph.data[0] - self.data = Batch.from_data_list(data_list) - elif isinstance(index, Sequence): - assert len(index) == len( - graph - ), "Index and GraphState must have the same length" - for i, idx in enumerate(index): - data_list[idx] = graph.data[i] - self.data = Batch.from_data_list(data_list) - elif isinstance(index, slice): - assert index.stop - index.start == len( - graph - ), "Index slice and GraphStates must have the same length" - data_list[index] = graph.data.to_data_list() - self.data = Batch.from_data_list(data_list) - else: - raise NotImplementedError( - "Setters with type {} is not implemented".format(type(index)) - ) + len_index = len(self.tensor[index]) + if len_index != 0 and len_index != len(self.tensor): + raise ValueError("Can only set states with the same batch size as the original batch") + + self.tensor[index] = graph.tensor @property def device(self) -> torch.device: - sample = self.data.get_example(0).x - if sample is not None: - return sample.device - return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return self.tensor.device def to(self, device: torch.device) -> GraphStates: """ Moves and/or casts the graph states to the specified device """ - if self.device != device: - self.data = self.data.to(device) + self.tensor = 
self.tensor.to(device) return self def clone(self) -> GraphStates: @@ -644,14 +624,9 @@ def clone(self) -> GraphStates: def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" - self.data = Batch.from_data_list( - self.data.to_data_list() + other.data.to_data_list() - ) - if self._log_rewards is not None: - assert other._log_rewards is not None - self._log_rewards = torch.cat( - (self._log_rewards, other._log_rewards), dim=0 - ) + self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=1) + self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=1) + self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=1) @property def log_rewards(self) -> torch.Tensor: @@ -663,13 +638,10 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: @property def is_sink_state(self) -> torch.Tensor: - batch_dim = len(self.data.ptr) - 1 - if len(self.data.x) == 0: - return torch.zeros(batch_dim, dtype=torch.bool) - return torch.all(self.data.x == self.sf.x, dim=-1).reshape( - batch_dim, + if self.tensor["node_feature"].shape[1] == 0: + return torch.zeros(self.batch_shape, dtype=torch.bool) + return ( + torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=(1, 2)) & + torch.all(self.tensor["edge_feature"] == self.sf["edge_feature"], dim=(1, 2)) & + torch.all(self.tensor["edge_index"] == self.sf["edge_index"], dim=(1, 2)) ) - - @property - def tensor(self) -> Batch: - return self.data diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 1fa6aa89..e1e387c6 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -64,3 +64,27 @@ def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: for k, v in self.dists.items() ] return sum(log_probs) + + +class CategoricalIndexes(Categorical): + """Samples indexes from a categorical distribution.""" + + def __init__(self, probs: torch.Tensor): + """Initializes the distribution. + + Args: + probs: The probabilities of the categorical distribution. 
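+
+        The (n, n) table of pair probabilities is flattened into n * n
+        categories, and index pairs are recovered on sampling; a sketch:
+
+            >>> probs = torch.rand(2, 5, 5)   # unnormalized weights are fine
+            >>> dist = CategoricalIndexes(probs)
+            >>> pair = dist.sample()          # shape (2, 2): one (i, j) pair each
+            >>> dist.log_prob(pair).shape
+            torch.Size([2])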
+ """ + self.n = probs.shape[-1] + batch_size = probs.shape[0] + assert probs.shape == (batch_size, self.n, self.n) + super().__init__(probs.reshape(batch_size, self.n * self.n)) + + def sample(self, sample_shape=torch.Size()) -> torch.Tensor: + samples = super().sample(sample_shape) + out = torch.stack([samples // self.n, samples % self.n], dim=-1) + return out + + def log_prob(self, value): + value = value[..., 0] * self.n + value[..., 1] + return super().log_prob(value) \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index d8a1c260..96911165 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -234,19 +234,24 @@ def __init__(self, feature_dim: int): self.edge_index_conv = GCNConv(feature_dim, 8) def forward(self, states: GraphStates) -> TensorDict: - if len(states.data.x) == 0: + node_feature = states.tensor["node_feature"].reshape(-1, self.feature_dim) + edge_index = states.tensor["edge_index"].reshape(-1, 2).T + + if states.tensor["node_feature"].shape[1] == 0: action_type = torch.zeros((len(states), len(GraphActionType))) action_type[:, GraphActionType.ADD_NODE] = 1 features = torch.zeros((len(states), self.feature_dim)) else: - action_type = self.action_type_conv(states.data.x, states.data.edge_index) - action_type = torch.mean(action_type, dim=0) - features = self.features_conv(states.data.x, states.data.edge_index) - features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=0) + action_type = self.action_type_conv(node_feature, edge_index) + action_type = action_type.reshape(len(states), -1, action_type.shape[-1]).mean(dim=1) + action_type = action_type.mean(dim=0).expand(len(states), -1) + features = self.features_conv(node_feature, edge_index) + features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) - edge_index = self.edge_index_conv(states.data.x, states.data.edge_index) + edge_index = self.edge_index_conv(node_feature, edge_index) edge_index = edge_index.reshape(*states.batch_shape, -1, 8) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) + torch.diagonal(edge_index, dim1=-2, dim2=-1).fill_(float("-inf")) return TensorDict( { @@ -259,6 +264,7 @@ def forward(self, states: GraphStates) -> TensorDict: def test_graph_building(): + torch.manual_seed(7) feature_dim = 8 env = GraphBuilding(feature_dim=feature_dim) From fba5d509d64c72af5d3cb0f920f497944705b2f7 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 6 Dec 2024 15:46:45 +0100 Subject: [PATCH 016/102] categorical action type --- src/gfn/containers/trajectories.py | 2 +- src/gfn/gym/graph_building.py | 2 ++ src/gfn/modules.py | 4 ++-- src/gfn/samplers.py | 9 ++++----- src/gfn/states.py | 8 +++++--- src/gfn/utils/distributions.py | 16 +++++++++++++++- testing/test_samplers_and_trajectories.py | 9 ++++++--- 7 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 5feb665a..c56cd1ee 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -76,7 +76,7 @@ def __init__( self.states = ( states if states is not None else env.states_from_batch_shape((0, 0)) ) - assert len(self.states.batch_shape) == 2 + assert len(self.states.batch_shape) == 2, self.states.batch_shape self.actions = ( actions if actions is not None else env.actions_from_batch_shape((0, 0)) ) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py 
index 12fa8943..728f1fd5 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -50,6 +50,8 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") state_tensor = deepcopy(states.tensor) + if len(actions) == 0: + return state_tensor action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 9d0f52f7..0500a157 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,7 +9,7 @@ from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States -from gfn.utils.distributions import CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical +from gfn.utils.distributions import CategoricalActionType, CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical REDUCTION_FXNS = { "mean": torch.mean, @@ -515,7 +515,7 @@ def to_probability_distribution( action_type_probs = ( 1 - epsilon ) * action_type_probs + epsilon * uniform_dist_probs - dists["action_type"] = Categorical(probs=action_type_probs) + dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] if states.tensor["node_feature"].shape[1] > 1: diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 083d8792..c8289955 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -194,11 +194,9 @@ def sample_trajectories( # Place estimator outputs into a stackable tensor. Note that this # will be replaced with torch.nested.nested_tensor in the future. estimator_outputs_padded = torch.full_like( - estimator_outputs.expand( - (n_trajectories,) + estimator_outputs.shape[1:] - ), - fill_value=-float("inf"), - ).clone() # TODO: inefficient + estimator_outputs.expand((n_trajectories,) + estimator_outputs.shape[1:]).clone(), + fill_value=-float("inf") + ) estimator_outputs_padded[~dones] = estimator_outputs all_estimator_outputs.append(estimator_outputs_padded) @@ -243,6 +241,7 @@ def sample_trajectories( states = new_states dones = dones | new_dones + import pdb; pdb.set_trace() trajectories_states.extend(deepcopy(states)) trajectories_logprobs = ( diff --git a/src/gfn/states.py b/src/gfn/states.py index 81762909..6df93eca 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -602,11 +602,13 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): Set particular states of the Batch """ len_index = len(self.tensor[index]) - if len_index != 0 and len_index != len(self.tensor): + if len_index == 0: + return + elif len_index == len(self.tensor): + self.tensor = graph.tensor + else: # TODO: fix this raise ValueError("Can only set states with the same batch size as the original batch") - self.tensor[index] = graph.tensor - @property def device(self) -> torch.device: return self.tensor.device diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index e1e387c6..e6f6beaa 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -87,4 +87,18 @@ def sample(self, sample_shape=torch.Size()) -> torch.Tensor: def log_prob(self, value): value = value[..., 0] * self.n + value[..., 1] - return super().log_prob(value) \ No newline at end of file + return super().log_prob(value) + + +class CategoricalActionType(Categorical): # TODO: remove, just to sample 1 action_type + + def __init__(self, probs: 
torch.Tensor): + self.batch_len = len(probs) + super().__init__(probs[0]) + + def sample(self, sample_shape=torch.Size()) -> torch.Tensor: + samples = super().sample(sample_shape) + return samples.repeat(self.batch_len) + + def log_prob(self, value): + return super().log_prob(value[0]).repeat(self.batch_len) \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 96911165..e3ffef3f 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -244,7 +244,6 @@ def forward(self, states: GraphStates) -> TensorDict: else: action_type = self.action_type_conv(node_feature, edge_index) action_type = action_type.reshape(len(states), -1, action_type.shape[-1]).mean(dim=1) - action_type = action_type.mean(dim=0).expand(len(states), -1) features = self.features_conv(node_feature, edge_index) features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) @@ -274,7 +273,11 @@ def test_graph_building(): sampler = Sampler(estimator=pf_estimator) trajectories = sampler.sample_trajectories( env, - n=5, + n=7, save_logprobs=True, - save_estimator_outputs=True, + save_estimator_outputs=False, ) + + +if __name__ == "__main__": + test_graph_building() \ No newline at end of file From 478bd148cb68d8ec49e87e8eaaf8008645e5f126 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 10 Dec 2024 16:46:06 +0100 Subject: [PATCH 017/102] change batching --- src/gfn/gym/graph_building.py | 68 ++++-- src/gfn/modules.py | 2 +- src/gfn/samplers.py | 1 - src/gfn/states.py | 239 ++++++++++++++++------ testing/test_samplers_and_trajectories.py | 4 +- 5 files changed, 230 insertions(+), 84 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 728f1fd5..67d3d316 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -59,14 +59,14 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: return self.States.make_sink_states_tensor(states.batch_shape) if action_type == GraphActionType.ADD_NODE: - assert len(state_tensor) == len(actions) - state_tensor["node_feature"] = torch.cat([state_tensor["node_feature"], actions.features[:, None]], dim=1) + batch_indices = torch.arange(len(states))[actions.action_type == GraphActionType.ADD_NODE] + state_tensor = self._add_node(state_tensor, batch_indices, actions.features) if action_type == GraphActionType.ADD_EDGE: assert len(state_tensor) == len(actions) - state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features[:, None]], dim=1) + state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0) state_tensor["edge_index"] = torch.cat( - [state_tensor["edge_index"], actions.edge_index[:, None]], dim=1 + [state_tensor["edge_index"], actions.edge_index], dim=0 ) return state_tensor @@ -108,11 +108,10 @@ def is_action_valid( if not torch.any(add_node_mask): add_node_out = True else: - node_feature = states.tensor["node_feature"][add_node_mask] + node_feature = states[add_node_mask].tensor["node_feature"] equal_nodes_per_batch = torch.all( node_feature == actions[add_node_mask].features[:, None], dim=-1 - ).reshape(len(node_feature), -1) - equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) + ).reshape(-1) if backward: # TODO: check if no edge are connected? 
add_node_out = torch.all(equal_nodes_per_batch == 1) else: @@ -127,9 +126,9 @@ def is_action_valid( if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): return False - if add_edge_states["node_feature"].shape[1] == 0: + if add_edge_states["node_feature"].shape[0] == 0: return False - if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[1]): + if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): return False batch_dim = add_edge_actions.features.shape[0] @@ -156,6 +155,47 @@ def is_action_valid( ) return bool(add_node_out) and bool(add_edge_out) + + def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_to_add: torch.Tensor) -> TensorDict: + if isinstance(batch_indices, list): + batch_indices = torch.tensor(batch_indices) + if len(batch_indices) != len(nodes_to_add): + raise ValueError("Number of batch indices must match number of node feature lists") + + modified_dict = tensor_dict.clone() + node_feature_dim = modified_dict['node_feature'].shape[1] + edge_feature_dim = modified_dict['edge_feature'].shape[1] + + for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): + start_ptr = tensor_dict['batch_ptr'][graph_idx] + end_ptr = tensor_dict['batch_ptr'][graph_idx + 1] + num_original_nodes = end_ptr - start_ptr + + if new_nodes.ndim == 1: + new_nodes = new_nodes.unsqueeze(0) + if new_nodes.shape[1] != node_feature_dim: + raise ValueError(f"Node features must have dimension {node_feature_dim}") + + # Update batch pointers for subsequent graphs + shift = new_nodes.shape[0] + modified_dict['batch_ptr'][graph_idx + 1:] += shift + + # Expand node features + original_nodes = modified_dict['node_feature'][start_ptr:end_ptr] + modified_dict['node_feature'] = torch.cat([ + modified_dict['node_feature'][:end_ptr], + new_nodes, + modified_dict['node_feature'][end_ptr:] + ]) + + # Update edge indices + # Increment indices for edges after the current graph + edge_mask_0 = modified_dict['edge_index'][:, 0] >= end_ptr + edge_mask_1 = modified_dict['edge_index'][:, 1] >= end_ptr + modified_dict['edge_index'][edge_mask_0, 0] += shift + modified_dict['edge_index'][edge_mask_1, 1] += shift + + return modified_dict def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. 
@@ -186,12 +226,14 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: class GCNConvEvaluator: def __init__(self, num_features): - self.num_features = num_features self.net = GCNConv(num_features, 1) def __call__(self, state: GraphStates) -> torch.Tensor: - node_feature = state.tensor["node_feature"].reshape(-1, self.num_features) - edge_index = state.tensor["edge_index"].reshape(-1, 2).T + node_feature = state.tensor["node_feature"] + edge_index = state.tensor["edge_index"].T + if len(node_feature) == 0: + return torch.zeros(len(state)) + out = self.net(node_feature, edge_index) - out = out.reshape(len(state), state.tensor["node_feature"].shape[1]) + out = out.reshape(*state.batch_shape, -1) return out.mean(-1) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 0500a157..86a345d4 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -518,7 +518,7 @@ def to_probability_distribution( dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - if states.tensor["node_feature"].shape[1] > 1: + if states.tensor["node_feature"].shape[0] > 1 and torch.any(edge_index_logits != -float("inf")): edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) uniform_dist_probs = ( torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index c8289955..3085b697 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -241,7 +241,6 @@ def sample_trajectories( states = new_states dones = dones | new_dones - import pdb; pdb.set_trace() trajectories_states.extend(deepcopy(states)) trajectories_logprobs = ( diff --git a/src/gfn/states.py b/src/gfn/states.py index 6df93eca..f91bec88 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, Optional, Sequence, Tuple +from typing import Callable, ClassVar, List, Optional, Sequence, Tuple import numpy as np import torch @@ -495,25 +495,25 @@ def __init__(self, tensor: TensorDict): self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] - self.batch_shape: tuple = tensor.batch_size + self.batch_shape: tuple = tuple(tensor["batch_shape"].tolist()) self._log_rewards: float = None # TODO logic repeated from env.is_valid_action - not_empty = self.tensor["node_feature"].shape[1] > 1 - self.forward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) - self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty - self.forward_masks[:, GraphActionType.EXIT] = not_empty - - self.backward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) - self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty - self.backward_masks[:, GraphActionType.ADD_EDGE] = ( - not_empty and self.tensor["edge_feature"].shape[1] > 0 > 0 - ) - self.backward_masks[:, GraphActionType.EXIT] = not_empty + not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] + self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.forward_masks[..., GraphActionType.ADD_EDGE] = not_empty + self.forward_masks[..., GraphActionType.EXIT] = not_empty + self.forward_masks = self.forward_masks.view(*self.batch_shape, 3) + + self.backward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.backward_masks[..., GraphActionType.ADD_NODE] = not_empty + self.backward_masks[..., 
GraphActionType.ADD_EDGE] = not_empty # TODO: check at least one edge is present + self.backward_masks[..., GraphActionType.EXIT] = not_empty + self.backward_masks = self.backward_masks.view(*self.batch_shape, 3) @classmethod def from_batch_shape( - cls, batch_shape: int, random: bool = False, sink: bool = False + cls, batch_shape: int | Tuple, random: bool = False, sink: bool = False ) -> GraphStates: if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") @@ -527,71 +527,100 @@ def from_batch_shape( @classmethod def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: - if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError( - "Batch shape with more than one dimension is not supported" - ) - if isinstance(batch_shape, Tuple): - batch_shape = batch_shape[0] + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) return TensorDict({ - "node_feature": cls.s0["node_feature"].repeat(batch_shape, 1, 1), - "edge_feature": cls.s0["edge_feature"].repeat(batch_shape, 1, 1), - "edge_index": cls.s0["edge_index"].repeat(batch_shape, 1, 1) - }, batch_size=batch_shape) + "node_feature": cls.s0["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.s0["node_feature"].shape[0], + "batch_shape": batch_shape + }) @classmethod - def make_sink_states_tensor(cls, batch_shape: Tuple) -> TensorDict: + def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: if cls.sf is None: raise NotImplementedError("Sink state is not defined") - if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError( - "Batch shape with more than one dimension is not supported" - ) - if isinstance(batch_shape, Tuple): - batch_shape = batch_shape[0] - + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) return TensorDict({ - "node_feature": cls.sf["node_feature"].repeat(batch_shape, 1, 1), - "edge_feature": cls.sf["edge_feature"].repeat(batch_shape, 1, 1), - "edge_index": cls.sf["edge_index"].repeat(batch_shape, 1, 1) - }, batch_size=int(batch_shape)) + "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.sf["node_feature"].shape[0], + "batch_shape": batch_shape + }) @classmethod - def make_random_states_tensor(cls, batch_shape: int) -> TensorDict: - if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError( - "Batch shape with more than one dimension is not supported" - ) - if isinstance(batch_shape, Tuple): - batch_shape = batch_shape[0] + def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) num_nodes = np.random.randint(10) num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2) node_features_dim = cls.s0["node_feature"].shape[-1] edge_features_dim = cls.s0["edge_feature"].shape[-1] - tensor = TensorDict({ - "node_feature": torch.rand(batch_shape, num_nodes, node_features_dim), - "edge_feature": torch.rand(batch_shape, num_edges, edge_features_dim), - "edge_index": torch.randint(num_nodes, 
size=(batch_shape, num_edges, 2)), + return TensorDict({ + "node_feature": torch.rand(np.prod(batch_shape) * num_nodes, node_features_dim), + "edge_feature": torch.rand(np.prod(batch_shape) * num_edges, edge_features_dim), + "edge_index": torch.randint(num_nodes, size=(np.prod(batch_shape) * num_edges, 2)), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * num_nodes, + "batch_shape": batch_shape }) - return tensor def __len__(self): - return np.prod(self.tensor.batch_size) + return np.prod(self.batch_shape) def __repr__(self): return ( - f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"{self.__class__.__name__} object of batch shape {self.tensor['batch_shape']} and " f"node feature dim {self.node_features_dim} and edge feature dim {self.edge_features_dim}" ) def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: - out = GraphStates(self.tensor[index]) - + if isinstance(index, (int, list)): + index = torch.tensor(index) + if index.dtype == torch.bool: + index = torch.where(index)[0] + + if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + raise ValueError("Graph index out of bounds") + + start_ptrs = self.tensor['batch_ptr'][:-1][index] + end_ptrs = self.tensor['batch_ptr'][1:][index] + + node_features = [torch.empty(0, self.node_features_dim)] + edge_features = [torch.empty(0, self.edge_features_dim)] + edge_indices = [torch.empty(0, 2, dtype=torch.long)] + batch_ptr = [0] + + for start, end in zip(start_ptrs, end_ptrs): + graph_nodes = self.tensor['node_feature'][start:end] + node_features.append(graph_nodes) + batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) + + # Find edges for this graph + edge_mask = ((self.tensor['edge_index'][:, 0] >= start) & + (self.tensor['edge_index'][:, 0] < end)) + graph_edges = self.tensor['edge_feature'][edge_mask] + edge_features.append(graph_edges) + + # Adjust edge indices to be local to this graph + graph_edge_index = self.tensor['edge_index'][edge_mask] + graph_edge_index[:, 0] -= start + graph_edge_index[:, 1] -= start + edge_indices.append(graph_edge_index) + + out = self.__class__(TensorDict({ + 'node_feature': torch.cat(node_features), + 'edge_feature': torch.cat(edge_features), + 'edge_index': torch.cat(edge_indices), + 'batch_ptr': torch.tensor(batch_ptr), + 'batch_shape': (len(index),) + })) + + if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -601,13 +630,64 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ Set particular states of the Batch """ - len_index = len(self.tensor[index]) - if len_index == 0: - return - elif len_index == len(self.tensor): - self.tensor = graph.tensor - else: # TODO: fix this - raise ValueError("Can only set states with the same batch size as the original batch") + if isinstance(index, (int, list)): + index = torch.tensor(index) + if index.dtype == torch.bool: + index = torch.where(index)[0] + + # Validate indices + if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + raise ValueError("Target graph index out of bounds") + + # Get batch pointers for target and source + target_start_ptrs = self.tensor['batch_ptr'][:-1][index] + target_end_ptrs = self.tensor['batch_ptr'][1:][index] + + # Source graph details + source_tensor_dict = graph.tensor + source_num_graphs = torch.prod(source_tensor_dict['batch_shape']) + + # Validate source and target indices match + if len(index) != source_num_graphs: + raise ValueError("Number of source graphs must match number of target indices") + 
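+        # Splice each source graph into the flat node storage: replace the target
+        # graph's slab of node features, then shift the edge indices and batch
+        # pointers of everything that follows by the resulting size change.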
+        for i, graph_idx in enumerate(index):
+            # Get start and end pointers for the current graph
+            start_ptr = self.tensor['batch_ptr'][graph_idx]
+            end_ptr = self.tensor['batch_ptr'][graph_idx + 1]
+
+            new_nodes = source_tensor_dict['node_feature'][
+                source_tensor_dict['batch_ptr'][i]:source_tensor_dict['batch_ptr'][i + 1]
+            ]
+
+            # Ensure new nodes have correct feature dimension
+            if new_nodes.ndim == 1:
+                new_nodes = new_nodes.unsqueeze(0)
+
+            if new_nodes.shape[1] != self.node_features_dim:
+                raise ValueError(f"Node features must have dimension {self.node_features_dim}")
+
+            # Number of new nodes to add
+            shift = new_nodes.shape[0] - (end_ptr - start_ptr)
+
+            # Concatenate node features
+            self.tensor['node_feature'] = torch.cat([
+                self.tensor['node_feature'][:start_ptr],  # Nodes before the current graph
+                new_nodes,  # New nodes to add
+                self.tensor['node_feature'][end_ptr:]  # Nodes after the current graph
+            ])
+
+            # Update edge indices for subsequent graphs
+            edge_mask_0 = self.tensor['edge_index'][:, 0] >= end_ptr
+            edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr
+            self.tensor['edge_index'][edge_mask_0, 0] += shift
+            self.tensor['edge_index'][edge_mask_1, 1] += shift
+
+            # Update batch pointers
+            self.tensor['batch_ptr'][graph_idx + 1:] += shift
+
+        # TODO: add new edges
+
 
     @property
     def device(self) -> torch.device:
@@ -626,9 +706,10 @@ def clone(self) -> GraphStates:
 
     def extend(self, other: GraphStates):
         """Concatenates to another GraphStates object along the batch dimension"""
-        self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=1)
-        self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=1)
-        self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=1)
+        self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0)
+        self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0)
+        self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=0)
+
 
     @property
     def log_rewards(self) -> torch.Tensor:
@@ -640,10 +721,34 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None:
 
     @property
     def is_sink_state(self) -> torch.Tensor:
-        if self.tensor["node_feature"].shape[1] == 0:
+        if len(self.tensor["node_feature"]) != np.prod(self.batch_shape):
             return torch.zeros(self.batch_shape, dtype=torch.bool)
-        return (
-            torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=(1, 2)) &
-            torch.all(self.tensor["edge_feature"] == self.sf["edge_feature"], dim=(1, 2)) &
-            torch.all(self.tensor["edge_index"] == self.sf["edge_index"], dim=(1, 2))
+        return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape)
+
+
+def stack_states(states: List[States]):
+    """Given a list of states, stacks them along a new dimension (0)."""
+    state_example = states[0]  # We assume all elems of `states` are the same.
+
+    stacked_states = state_example.from_batch_shape((0, 0))  # Empty.
+    stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
+    if state_example._log_rewards:
+        stacked_states._log_rewards = torch.stack(
+            [s._log_rewards for s in states], dim=0
         )
+
+    # We are dealing with a list of DiscretrStates instances.
+ if hasattr(state_example, "forward_masks"): + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) + + # Adds the trajectory dimension. + stacked_states.batch_shape = ( + stacked_states.tensor.shape[0], + ) + state_example.batch_shape + + return stacked_states \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index e3ffef3f..90bdfa49 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -235,9 +235,9 @@ def __init__(self, feature_dim: int): def forward(self, states: GraphStates) -> TensorDict: node_feature = states.tensor["node_feature"].reshape(-1, self.feature_dim) - edge_index = states.tensor["edge_index"].reshape(-1, 2).T + edge_index = states.tensor["edge_index"].T - if states.tensor["node_feature"].shape[1] == 0: + if states.tensor["node_feature"].shape[0] == 0: action_type = torch.zeros((len(states), len(GraphActionType))) action_type[:, GraphActionType.ADD_NODE] = 1 features = torch.zeros((len(states), self.feature_dim)) From dd80f2815237d7b64ec3885bc0139b063edcc122 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 11 Dec 2024 14:52:22 +0100 Subject: [PATCH 018/102] fix stacking --- src/gfn/actions.py | 13 ++++++++++ src/gfn/samplers.py | 13 ++++------ src/gfn/states.py | 58 ++++++++++++++++++++++----------------------- 3 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 36c52ae5..f31ac032 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -280,3 +280,16 @@ def make_dummy_actions( #features=torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), #edge_index=torch.zeros((2, *batch_shape, 0)), ) + + @classmethod + def stack(cls, actions_list: list[GraphActions]) -> GraphActions: + """Stacks a list of GraphActions objects into a single GraphActions object.""" + actions_tensor = torch.stack( + [actions.tensor for actions in actions_list], dim=0 + ) + return cls( + actions_tensor["action_type"], + actions_tensor["features"], + actions_tensor["edge_index"] + ) + diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 3085b697..e056de12 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -155,8 +155,8 @@ def sample_trajectories( else states.is_sink_state ) - trajectories_states: States = deepcopy(states) - trajectories_actions: Optional[Actions] = None + trajectories_states: List[States] = [deepcopy(states)] + trajectories_actions: List[Actions] = [] trajectories_logprobs: List[torch.Tensor] = [] trajectories_dones = torch.zeros( n_trajectories, dtype=torch.long, device=device @@ -205,10 +205,7 @@ def sample_trajectories( # When off_policy, actions_log_probs are None. 
log_probs[~dones] = actions_log_probs - if trajectories_actions is None: - trajectories_actions = actions - else: - trajectories_actions.extend(actions) + trajectories_actions.append(actions) trajectories_logprobs.append(log_probs) if self.estimator.is_backward: @@ -241,8 +238,8 @@ def sample_trajectories( states = new_states dones = dones | new_dones - trajectories_states.extend(deepcopy(states)) - + trajectories_states = env.States.stack(trajectories_states) + trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = ( torch.stack(trajectories_logprobs, dim=0) if save_logprobs else None ) diff --git a/src/gfn/states.py b/src/gfn/states.py index f91bec88..07969891 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -293,6 +293,34 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: def sample(self, n_samples: int) -> States: """Samples a subset of the States object.""" return self[torch.randperm(len(self))[:n_samples]] + + @classmethod + def stack(cls, states: List[States]): + """Given a list of states, stacks them along a new dimension (0).""" + state_example = states[0] # We assume all elems of `states` are the same. + + stacked_states = state_example.from_batch_shape((0, 0)) # Empty. + stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) + if state_example._log_rewards: + stacked_states._log_rewards = torch.stack( + [s._log_rewards for s in states], dim=0 + ) + + # We are dealing with a list of DiscretrStates instances. + if hasattr(state_example, "forward_masks"): + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) + + # Adds the trajectory dimension. + stacked_states.batch_shape = ( + stacked_states.tensor.shape[0], + ) + state_example.batch_shape + + return stacked_states class DiscreteStates(States, ABC): @@ -480,7 +508,7 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.zeros(shape).bool() -class GraphStates(ABC): +class GraphStates(States): """ Base class for Graph as a state representation. The `GraphStates` object is a batched collection of multiple graph objects. The `Batch` object from PyTorch Geometric is used to represent the batch of @@ -724,31 +752,3 @@ def is_sink_state(self) -> torch.Tensor: if len(self.tensor["node_feature"]) != np.prod(self.batch_shape): return torch.zeros(self.batch_shape, dtype=torch.bool) return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape) - - -def stack_states(states: List[States]): - """Given a list of states, stacks them along a new dimension (0).""" - state_example = states[0] # We assume all elems of `states` are the same. - - stacked_states = state_example.from_batch_shape((0, 0)) # Empty. - stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) - if state_example._log_rewards: - stacked_states._log_rewards = torch.stack( - [s._log_rewards for s in states], dim=0 - ) - - # We are dealing with a list of DiscretrStates instances. - if hasattr(state_example, "forward_masks"): - stacked_states.forward_masks = torch.stack( - [s.forward_masks for s in states], dim=0 - ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 - ) - - # Adds the trajectory dimension. 
- stacked_states.batch_shape = ( - stacked_states.tensor.shape[0], - ) + state_example.batch_shape - - return stacked_states \ No newline at end of file From 616551c46f56e2b53334ab3f169432d976d4f317 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 11 Dec 2024 17:17:00 +0100 Subject: [PATCH 019/102] fix graph stacking --- src/gfn/env.py | 2 +- src/gfn/samplers.py | 9 ++++++--- src/gfn/states.py | 34 ++++++++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index e09c5b79..25026dd7 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -256,7 +256,7 @@ def _step( ) new_sink_states_idx = actions.is_exit - sf_tensor = self.States.make_sink_states_tensor(new_sink_states_idx.sum()) + sf_tensor = self.States.make_sink_states_tensor((new_sink_states_idx.sum(),)) new_states[new_sink_states_idx] = self.States(sf_tensor) new_sink_states_idx = ~valid_states_idx | new_sink_states_idx assert new_sink_states_idx.shape == states.batch_shape diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index e056de12..d0f580fd 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -193,9 +193,11 @@ def sample_trajectories( if estimator_outputs is not None: # Place estimator outputs into a stackable tensor. Note that this # will be replaced with torch.nested.nested_tensor in the future. - estimator_outputs_padded = torch.full_like( - estimator_outputs.expand((n_trajectories,) + estimator_outputs.shape[1:]).clone(), - fill_value=-float("inf") + estimator_outputs_padded = torch.full( + (n_trajectories,) + estimator_outputs.shape[1:], + fill_value=-float("inf"), + dtype=torch.float, + device=device, ) estimator_outputs_padded[~dones] = estimator_outputs all_estimator_outputs.append(estimator_outputs_padded) @@ -237,6 +239,7 @@ def sample_trajectories( ) states = new_states dones = dones | new_dones + trajectories_states.append(deepcopy(states)) trajectories_states = env.States.stack(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) diff --git a/src/gfn/states.py b/src/gfn/states.py index 07969891..408a41d8 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -523,9 +523,7 @@ def __init__(self, tensor: TensorDict): self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] - self.batch_shape: tuple = tuple(tensor["batch_shape"].tolist()) self._log_rewards: float = None - # TODO logic repeated from env.is_valid_action not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) @@ -538,6 +536,10 @@ def __init__(self, tensor: TensorDict): self.backward_masks[..., GraphActionType.ADD_EDGE] = not_empty # TODO: check at least one edge is present self.backward_masks[..., GraphActionType.EXIT] = not_empty self.backward_masks = self.backward_masks.view(*self.batch_shape, 3) + + @property + def batch_shape(self) -> tuple: + return tuple(self.tensor["batch_shape"].tolist()) @classmethod def from_batch_shape( @@ -667,10 +669,6 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): if torch.any(index >= len(self.tensor['batch_ptr']) - 1): raise ValueError("Target graph index out of bounds") - # Get batch pointers for target and source - target_start_ptrs = self.tensor['batch_ptr'][:-1][index] - target_end_ptrs = self.tensor['batch_ptr'][1:][index] - # Source graph details source_tensor_dict = graph.tensor source_num_graphs = 
torch.prod(source_tensor_dict['batch_shape']) @@ -736,8 +734,10 @@ def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0) self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0) - self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=0) - + self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"] + self.tensor["batch_ptr"][-1]], dim=0) + self.tensor["batch_ptr"] = torch.cat([self.tensor["batch_ptr"], other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1]], dim=0) + assert torch.all(self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:]) + self.tensor["batch_shape"] = (self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0],) + self.batch_shape[1:] @property def log_rewards(self) -> torch.Tensor: @@ -752,3 +752,21 @@ def is_sink_state(self) -> torch.Tensor: if len(self.tensor["node_feature"]) != np.prod(self.batch_shape): return torch.zeros(self.batch_shape, dtype=torch.bool) return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape) + + @classmethod + def stack(cls, states: List[GraphStates]): + """Given a list of states, stacks them along a new dimension (0).""" + stacked_states = cls.from_batch_shape(0) + state_batch_shape = states[0].batch_shape + for state in states: + assert state.batch_shape == state_batch_shape + stacked_states.extend(state) + + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) + stacked_states.tensor["batch_shape"] = (len(states),) + state_batch_shape + return stacked_states From 77611d41dccf4317427e9a86dbcb80808c608ee7 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 12 Dec 2024 12:57:19 +0100 Subject: [PATCH 020/102] fix test graph env --- src/gfn/env.py | 4 +- src/gfn/gym/graph_building.py | 49 ++++++++++-------------- src/gfn/states.py | 8 ++-- testing/test_environments.py | 71 +++++++++++++++++------------------ 4 files changed, 60 insertions(+), 72 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 25026dd7..3c543d51 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -630,11 +630,11 @@ def actions_from_tensor(self, tensor: Dict[str, torch.Tensor]): return self.Actions(**tensor) @abstractmethod - def step(self, states: GraphStates, actions: Actions) -> GraphStates: + def step(self, states: GraphStates, actions: Actions) -> torch.Tensor: """Function that takes a batch of graph states and actions and returns a batch of next graph states.""" @abstractmethod - def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + def backward_step(self, states: GraphStates, actions: Actions) -> torch.Tensor: """Function that takes a batch of graph states and actions and returns a batch of previous graph states.""" diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 67d3d316..8a2936d3 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -63,14 +63,12 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: state_tensor = self._add_node(state_tensor, batch_indices, actions.features) if action_type == GraphActionType.ADD_EDGE: - assert len(state_tensor) == 
len(actions) state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0) - state_tensor["edge_index"] = torch.cat( - [state_tensor["edge_index"], actions.edge_index], dim=0 - ) + state_tensor["edge_index"] = torch.cat([state_tensor["edge_index"], actions.edge_index], dim=0) + return state_tensor - def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: + def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Tensor: """Backward step function for the GraphBuilding environment. Args: @@ -81,25 +79,26 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat """ if not self.is_action_valid(states, actions, backward=True): raise NonValidActionsError("Invalid action.") - graphs: Batch = deepcopy(states.data) - assert len(graphs) == len(actions) + state_tensor = deepcopy(states.tensor) - if actions.action_type == GraphActionType.ADD_NODE: - assert graphs.x is not None + action_type = actions.action_type[0] + assert torch.all(actions.action_type == action_type) + if action_type == GraphActionType.ADD_NODE: is_equal = torch.any( - torch.all(graphs.x[:, None] == actions.features, dim=-1), dim=-1 + torch.all(state_tensor["node_feature"][:, None] == actions.features, dim=-1), + dim=-1 ) - graphs.x = graphs.x[~is_equal] - elif actions.action_type == GraphActionType.ADD_EDGE: + state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] + elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None is_equal = torch.all( - graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 + state_tensor["edge_index"] == actions.edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) - graphs.edge_attr = graphs.edge_attr[~is_equal] - graphs.edge_index = graphs.edge_index[:, ~is_equal] + state_tensor["edge_feature"] = state_tensor["edge_feature"][~is_equal] + state_tensor["edge_index"] = state_tensor["edge_index"][~is_equal] - return self.States(graphs) + return state_tensor def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False @@ -111,8 +110,8 @@ def is_action_valid( node_feature = states[add_node_mask].tensor["node_feature"] equal_nodes_per_batch = torch.all( node_feature == actions[add_node_mask].features[:, None], dim=-1 - ).reshape(-1) - if backward: # TODO: check if no edge are connected? + ).sum(dim=-1) + if backward: # TODO: check if no edge is connected? 
add_node_out = torch.all(equal_nodes_per_batch == 1) else: add_node_out = torch.all(equal_nodes_per_batch == 0) @@ -131,18 +130,13 @@ def is_action_valid( if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): return False - batch_dim = add_edge_actions.features.shape[0] - batch_idx = add_edge_actions.edge_index % batch_dim - if torch.any(batch_idx != torch.arange(batch_dim)): - return False - equal_edges_per_batch_attr = torch.all( add_edge_states["edge_feature"] == add_edge_actions.features[:, None], dim=-1 - ).reshape(len(add_edge_states), -1) + ) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) equal_edges_per_batch_index = torch.all( - add_edge_states["edge_index"] == add_edge_actions.edge_index, dim=0 - ).reshape(len(add_edge_states), -1) + add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1 + ) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: @@ -164,12 +158,10 @@ def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_ modified_dict = tensor_dict.clone() node_feature_dim = modified_dict['node_feature'].shape[1] - edge_feature_dim = modified_dict['edge_feature'].shape[1] for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): start_ptr = tensor_dict['batch_ptr'][graph_idx] end_ptr = tensor_dict['batch_ptr'][graph_idx + 1] - num_original_nodes = end_ptr - start_ptr if new_nodes.ndim == 1: new_nodes = new_nodes.unsqueeze(0) @@ -181,7 +173,6 @@ def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_ modified_dict['batch_ptr'][graph_idx + 1:] += shift # Expand node features - original_nodes = modified_dict['node_feature'][start_ptr:end_ptr] modified_dict['node_feature'] = torch.cat([ modified_dict['node_feature'][:end_ptr], new_nodes, diff --git a/src/gfn/states.py b/src/gfn/states.py index 408a41d8..cccab384 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -628,7 +628,6 @@ def __getitem__( for start, end in zip(start_ptrs, end_ptrs): graph_nodes = self.tensor['node_feature'][start:end] node_features.append(graph_nodes) - batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) # Find edges for this graph edge_mask = ((self.tensor['edge_index'][:, 0] >= start) & @@ -638,10 +637,11 @@ def __getitem__( # Adjust edge indices to be local to this graph graph_edge_index = self.tensor['edge_index'][edge_mask] - graph_edge_index[:, 0] -= start - graph_edge_index[:, 1] -= start + graph_edge_index[:, 0] -= (batch_ptr[-1] - start) + graph_edge_index[:, 1] -= (batch_ptr[-1] - start) edge_indices.append(graph_edge_index) - + batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) + out = self.__class__(TensorDict({ 'node_feature': torch.cat(node_features), 'edge_feature': torch.cat(edge_features), diff --git a/testing/test_environments.py b/testing/test_environments.py index 948742bf..a061014d 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -327,64 +327,59 @@ def test_graph_env(): env = GraphBuilding(feature_dim=FEATURE_DIM) states = env.reset(batch_shape=BATCH_SIZE) - assert states.batch_shape == BATCH_SIZE + assert states.batch_shape == (BATCH_SIZE,) action_cls = env.make_actions_class() with pytest.raises(NonValidActionsError): actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), ) 
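+        # The freshly reset graphs are empty, so adding random edges must be rejected.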
states = env.step(states, actions) for _ in range(NUM_NODES): actions = action_cls( - GraphActionType.ADD_NODE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), torch.rand((BATCH_SIZE, FEATURE_DIM)), ) states = env.step(states, actions) + states = env.States(states) - assert states.data.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) + assert states.tensor["node_feature"].shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): - first_node_mask = torch.arange(len(states.data.x)) // BATCH_SIZE == 0 + first_node_mask = torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 actions = action_cls( - GraphActionType.ADD_NODE, - states.data.x[first_node_mask], + torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + states.tensor["node_feature"][first_node_mask], ) states = env.step(states, actions) with pytest.raises(NonValidActionsError): edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([edge_index, edge_index]), + torch.stack([edge_index, edge_index], dim=1), ) states = env.step(states, actions) for i in range(NUM_NODES - 1): - node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) - node_js = torch.arange((i + 1) * BATCH_SIZE, (i + 2) * BATCH_SIZE) + node_is = states.tensor["batch_ptr"][:-1] + i + node_js = states.tensor["batch_ptr"][:-1] + i + 1 actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([node_is, node_js]), + torch.stack([node_is, node_js], dim=1), ) states = env.step(states, actions) + states = env.States(states) - with pytest.raises(NonValidActionsError): - edge_index = torch.tensor([[0, 1]] * BATCH_SIZE) - actions = action_cls( - GraphActionType.ADD_EDGE, - torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index.T, - ) - states = env.step(states, actions) - - actions = action_cls(GraphActionType.EXIT) - states = env.step(states, actions) + actions = action_cls(torch.full((BATCH_SIZE,), GraphActionType.EXIT)) + sf_states = env.step(states, actions) + sf_states = env.States(sf_states) + assert torch.all(sf_states.is_sink_state) env.reward(states) # with pytest.raises(NonValidActionsError): @@ -395,37 +390,39 @@ def test_graph_env(): # ) # states = env.backward_step(states, actions) - num_edges_per_batch = states.data.edge_attr.shape[0] // BATCH_SIZE + num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE for i in reversed(range(num_edges_per_batch)): edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) actions = action_cls( - GraphActionType.ADD_EDGE, - states.data.edge_attr[edge_idx], - states.data.edge_index[:, edge_idx], + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + states.tensor["edge_feature"][edge_idx], + states.tensor["edge_index"][edge_idx], ) states = env.backward_step(states, actions) + states = env.States(states) with pytest.raises(NonValidActionsError): actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), ) states = env.backward_step(states, actions) - for i in reversed(range(NUM_NODES)): - edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) + for i in reversed(range(1, NUM_NODES + 1)): + 
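+        # Each graph still holds i nodes here, so index j * i is the first node of graph j.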
+        edge_idx = torch.arange(BATCH_SIZE) * i
         actions = action_cls(
-            GraphActionType.ADD_NODE,
-            states.data.x[edge_idx],
+            torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
+            states.tensor["node_feature"][edge_idx],
         )
         states = env.backward_step(states, actions)
+        states = env.States(states)
 
-    assert states.data.x.shape == (0, FEATURE_DIM)
+    assert states.tensor["node_feature"].shape == (0, FEATURE_DIM)
     with pytest.raises(NonValidActionsError):
         actions = action_cls(
-            GraphActionType.ADD_NODE,
+            torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
             torch.rand((BATCH_SIZE, FEATURE_DIM)),
         )
         states = env.backward_step(states, actions)

From 5874ff6443f0496b0ef13ece7355d5f8adfe06d9 Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 20 Dec 2024 00:56:55 +0100
Subject: [PATCH 021/102] add ring example

---
 src/gfn/actions.py                        |  39 ++----
 src/gfn/env.py                            |  12 +-
 src/gfn/gflownet/flow_matching.py         |  16 +--
 src/gfn/gym/__init__.py                   |   1 +
 src/gfn/modules.py                        |   3 +-
 src/gfn/states.py                         |  28 ++--
 testing/test_environments.py              |  87 ++++++------
 testing/test_samplers_and_trajectories.py |   4 -
 tutorials/examples/test_graph_ring.py     | 159 ++++++++++++++++++++++
 9 files changed, 249 insertions(+), 100 deletions(-)
 create mode 100644 tutorials/examples/test_graph_ring.py

diff --git a/src/gfn/actions.py b/src/gfn/actions.py
index f31ac032..d2e9b3b0 100644
--- a/src/gfn/actions.py
+++ b/src/gfn/actions.py
@@ -181,12 +181,7 @@ class GraphActionType(enum.IntEnum):
 class GraphActions(Actions):
     features_dim: ClassVar[int]
 
-    def __init__(
-        self,
-        action_type: torch.Tensor,
-        features: Optional[torch.Tensor] = None,
-        edge_index: Optional[torch.Tensor] = None,
-    ):
+    def __init__(self, tensor: TensorDict):
         """Initializes a GraphAction object.
 
         Args:
             action_type: a GraphActionType indicating the type of action.
             features: a tensor of shape (batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type.
                 In case of EXIT action, this can be None.
             edge_index: a tensor of shape (batch_shape, 2) representing the edge to add.
                 This must be defined if and only if the action type is GraphActionType.AddEdge.
""" - self.batch_shape = action_type.shape + self.batch_shape = tensor["action_type"].shape + features = tensor.get("features", None) if features is None: - assert torch.all(action_type == GraphActionType.EXIT) + assert torch.all(tensor["action_type"] == GraphActionType.EXIT) features = torch.zeros((*self.batch_shape, self.features_dim)) + edge_index = tensor.get("edge_index", None) if edge_index is None: - assert torch.all(action_type != GraphActionType.ADD_EDGE) + assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE) edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) self.tensor = TensorDict({ - "action_type": action_type, + "action_type": tensor["action_type"], "features": features, "edge_index": edge_index, }, batch_size=self.batch_shape) @@ -224,12 +221,8 @@ def __len__(self) -> int: def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: """Get particular actions of the batch.""" - tensor = self.tensor[index] - return GraphActions( - tensor["action_type"], - tensor["features"], - tensor["edge_index"] - ) + return GraphActions(self.tensor[index]) + def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions @@ -276,9 +269,11 @@ def make_dummy_actions( ) -> GraphActions: """Creates an Actions object of dummy actions with the given batch shape.""" return cls( - action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), - #features=torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), - #edge_index=torch.zeros((2, *batch_shape, 0)), + TensorDict({ + "action_type": torch.full(batch_shape, fill_value=GraphActionType.EXIT), + # "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), + # "edge_index": torch.zeros((2, *batch_shape, 0)), + }, batch_size=batch_shape) ) @classmethod @@ -287,9 +282,5 @@ def stack(cls, actions_list: list[GraphActions]) -> GraphActions: actions_tensor = torch.stack( [actions.tensor for actions in actions_list], dim=0 ) - return cls( - actions_tensor["action_type"], - actions_tensor["features"], - actions_tensor["edge_index"] - ) + return cls(actions_tensor) diff --git a/src/gfn/env.py b/src/gfn/env.py index 3c543d51..86c592b5 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -219,6 +219,7 @@ def reset( batch_shape = (1,) if isinstance(batch_shape, int): batch_shape = (batch_shape,) + return self.States.from_batch_shape( batch_shape=batch_shape, random=random, sink=sink ) @@ -618,17 +619,6 @@ class DefaultGraphAction(GraphActions): return DefaultGraphAction - def actions_from_tensor(self, tensor: Dict[str, torch.Tensor]): - """Wraps the supplied Tensor in an Actions instance. - - Args: - tensor: The tensor of shape "action_shape" representing the actions. - - Returns: - Actions: An instance of Actions. - """ - return self.Actions(**tensor) - @abstractmethod def step(self, states: GraphStates, actions: Actions) -> torch.Tensor: """Function that takes a batch of graph states and actions and returns a batch of next diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 38072080..8347b835 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -33,10 +33,10 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]): def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): super().__init__() - assert isinstance( # TODO: need a more flexible type check. 
- logF, - DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, - ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" + # assert isinstance( # TODO: need a more flexible type check. + # logF, + # DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, + # ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" self.logF = logF self.alpha = alpha @@ -50,10 +50,10 @@ def sample_trajectories( **policy_kwargs: Any, ) -> Trajectories: """Sample trajectory with optional kwargs controling the policy.""" - if not env.is_discrete: - raise NotImplementedError( - "Flow Matching GFlowNet only supports discrete environments for now." - ) + # if not env.is_discrete: + # raise NotImplementedError( + # "Flow Matching GFlowNet only supports discrete environments for now." + # ) sampler = Sampler(estimator=self.logF) trajectories = sampler.sample_trajectories( env, diff --git a/src/gfn/gym/__init__.py b/src/gfn/gym/__init__.py index fbec4831..20490566 100644 --- a/src/gfn/gym/__init__.py +++ b/src/gfn/gym/__init__.py @@ -1,3 +1,4 @@ from gfn.gym.box import Box from gfn.gym.discrete_ebm import DiscreteEBM from gfn.gym.hypergrid import HyperGrid +from gfn.gym.graph_building import GraphBuilding \ No newline at end of file diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 86a345d4..169a1f57 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -79,7 +79,6 @@ def __init__( ) preprocessor = IdentityPreprocessor(module.input_dim) self.preprocessor = preprocessor - self._output_dim_is_checked = False self.is_backward = is_backward def forward(self, input: States | torch.Tensor) -> torch.Tensor: @@ -236,7 +235,7 @@ def forward(self, states: DiscreteStates) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). 
""" out = super().forward(states) - assert out.shape[-1] == self.expected_output_dim + assert out.shape[-1] == self.expected_output_dim, f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}" return out def to_probability_distribution( diff --git a/src/gfn/states.py b/src/gfn/states.py index cccab384..215d49b0 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -597,8 +597,8 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: "batch_shape": batch_shape }) - def __len__(self): - return np.prod(self.batch_shape) + def __len__(self) -> int: + return int(np.prod(self.batch_shape)) def __repr__(self): return ( @@ -609,10 +609,8 @@ def __repr__(self): def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: - if isinstance(index, (int, list)): - index = torch.tensor(index) - if index.dtype == torch.bool: - index = torch.where(index)[0] + tensor_idx = torch.arange(len(self)).view(*self.batch_shape) + index = tensor_idx[index].flatten() if torch.any(index >= len(self.tensor['batch_ptr']) - 1): raise ValueError("Graph index out of bounds") @@ -747,11 +745,23 @@ def log_rewards(self) -> torch.Tensor: def log_rewards(self, log_rewards: torch.Tensor) -> None: self._log_rewards = log_rewards + def _compare(self, other: TensorDict) -> torch.Tensor: + out = torch.zeros(len(self.tensor["batch_ptr"]) - 1, dtype=torch.bool) + for i in range(len(self.tensor["batch_ptr"]) - 1): + start, end = self.tensor["batch_ptr"][i], self.tensor["batch_ptr"][i + 1] + if end - start != len(other["node_feature"]): + out[i] = False + else: + out[i] = torch.all(self.tensor["node_feature"][start:end] == other["node_feature"]) + return out.view(self.batch_shape) + @property def is_sink_state(self) -> torch.Tensor: - if len(self.tensor["node_feature"]) != np.prod(self.batch_shape): - return torch.zeros(self.batch_shape, dtype=torch.bool) - return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape) + return self._compare(self.sf) + + @property + def is_initial_state(self) -> torch.Tensor: + return self._compare(self.s0) @classmethod def stack(cls, states: List[GraphStates]): diff --git a/testing/test_environments.py b/testing/test_environments.py index a061014d..a2919d50 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +from tensordict import TensorDict from gfn.actions import GraphActionType from gfn.env import NonValidActionsError @@ -331,18 +332,18 @@ def test_graph_env(): action_cls = env.make_actions_class() with pytest.raises(NonValidActionsError): - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), + }, batch_size=BATCH_SIZE)) states = env.step(states, actions) for _ in range(NUM_NODES): - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - torch.rand((BATCH_SIZE, FEATURE_DIM)), - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, batch_size=BATCH_SIZE)) states = env.step(states, actions) states = 
env.States(states) @@ -350,33 +351,35 @@ def test_graph_env(): with pytest.raises(NonValidActionsError): first_node_mask = torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - states.tensor["node_feature"][first_node_mask], - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor["node_feature"][first_node_mask], + }, batch_size=BATCH_SIZE)) states = env.step(states, actions) with pytest.raises(NonValidActionsError): edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([edge_index, edge_index], dim=1), - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([edge_index, edge_index], dim=1), + }, batch_size=BATCH_SIZE)) states = env.step(states, actions) for i in range(NUM_NODES - 1): node_is = states.tensor["batch_ptr"][:-1] + i node_js = states.tensor["batch_ptr"][:-1] + i + 1 - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([node_is, node_js], dim=1), - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([node_is, node_js], dim=1), + }, batch_size=BATCH_SIZE)) states = env.step(states, actions) states = env.States(states) - actions = action_cls(torch.full((BATCH_SIZE,), GraphActionType.EXIT)) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT), + }, batch_size=BATCH_SIZE)) sf_states = env.step(states, actions) sf_states = env.States(sf_states) assert torch.all(sf_states.is_sink_state) @@ -393,36 +396,36 @@ def test_graph_env(): num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE for i in reversed(range(num_edges_per_batch)): edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - states.tensor["edge_feature"][edge_idx], - states.tensor["edge_index"][edge_idx], - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": states.tensor["edge_feature"][edge_idx], + "edge_index": states.tensor["edge_index"][edge_idx], + }, batch_size=BATCH_SIZE)) states = env.backward_step(states, actions) states = env.States(states) with pytest.raises(NonValidActionsError): - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), + }, batch_size=BATCH_SIZE)) states = env.backward_step(states, actions) for i in reversed(range(1, NUM_NODES + 1)): edge_idx = torch.arange(BATCH_SIZE) * i - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - states.tensor["node_feature"][edge_idx], - ) + actions = action_cls(TensorDict({ + "action_type": 
torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
+            "features": states.tensor["node_feature"][edge_idx],
+        }, batch_size=BATCH_SIZE))
         states = env.backward_step(states, actions)
 
     states = env.States(states)
     assert states.tensor["node_feature"].shape == (0, FEATURE_DIM)
 
     with pytest.raises(NonValidActionsError):
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
-            torch.rand((BATCH_SIZE, FEATURE_DIM)),
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
+            "features": torch.rand((BATCH_SIZE, FEATURE_DIM)),
+        }, batch_size=BATCH_SIZE))
         states = env.backward_step(states, actions)
diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py
index 90bdfa49..470c8b09 100644
--- a/testing/test_samplers_and_trajectories.py
+++ b/testing/test_samplers_and_trajectories.py
@@ -277,7 +277,3 @@ def test_graph_building():
         save_logprobs=True,
         save_estimator_outputs=False,
     )
-
-
-if __name__ == "__main__":
-    test_graph_building()
\ No newline at end of file
diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py
new file mode 100644
index 00000000..8546ec41
--- /dev/null
+++ b/tutorials/examples/test_graph_ring.py
@@ -0,0 +1,159 @@
+"""Write an example where we want to create graphs that are rings."""
+
+import torch
+from torch import nn
+from gfn.actions import Actions, GraphActionType, GraphActions
+from gfn.gflownet.flow_matching import FMGFlowNet
+from gfn.gym import GraphBuilding
+from gfn.modules import DiscretePolicyEstimator
+from gfn.preprocessors import Preprocessor
+from gfn.states import GraphStates
+from tensordict import TensorDict
+from torch_geometric.nn import GCNConv
+
+
+def state_evaluator(states: GraphStates) -> torch.Tensor:
+    if states.tensor["edge_index"].shape[0] == 0:
+        return torch.zeros(states.batch_shape)
+    if states.tensor["edge_index"].shape[0] != states.tensor["node_feature"].shape[0]:
+        return torch.zeros(states.batch_shape)
+
+    i0 = torch.unique(states.tensor["edge_index"][0], sorted=False)
+    i1 = torch.unique(states.tensor["edge_index"][1], sorted=False)
+
+    if len(i0) == len(i1) == states.tensor["node_feature"].shape[0]:
+        return torch.ones(states.batch_shape)
+    return torch.zeros(states.batch_shape)
+
+
+class RingPolicyEstimator(nn.Module):
+    def __init__(self, n_nodes: int):
+        super().__init__()
+        self.action_type_conv = GCNConv(1, 1)
+        self.edge_index_conv = GCNConv(1, 8)
+        self.n_nodes = n_nodes
+
+    def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor:
+        cumsum = torch.zeros((len(tensor) + 1, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+        cumsum[1:] = torch.cumsum(tensor, dim=0)
+        return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]
+
+    def forward(self, states_tensor: TensorDict) -> torch.Tensor:
+        node_feature = states_tensor["node_feature"].reshape(-1, 1)
+        edge_index = states_tensor["edge_index"].T
+        batch_ptr = states_tensor["batch_ptr"]
+
+        action_type = self.action_type_conv(node_feature, edge_index)
+        action_type = self._group_sum(action_type, batch_ptr)
+
+        edge_index = self.edge_index_conv(node_feature, edge_index)
+        #edge_index = self._group_sum(edge_index, batch_ptr)
+        edge_index = edge_index.reshape(*states_tensor["batch_shape"], -1, 8)
+        edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index)
+        torch.diagonal(edge_index, dim1=-2, dim2=-1).fill_(float("-inf"))
+        edge_actions = edge_index.reshape(*states_tensor["batch_shape"], -1)
+
+
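+        # The einsum above scores every ordered node pair at once: given per-node
+        # embeddings of shape (batch, n_nodes, 8), "bnf,bmf->bnm" produces a
+        # (batch, n_nodes, n_nodes) matrix whose (i, j) entry is the dot product
+        # of the embeddings of nodes i and j, i.e. the logit for adding the edge
+        # i -> j. Filling the diagonal with -inf rules out self-loops, and the
+        # final reshape flattens the matrix into n_nodes * n_nodes edge logits.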
return torch.cat([action_type, edge_actions], dim=-1) + +class RingGraphBuilding(GraphBuilding): + def __init__(self, nodes: int = 10): + self.nodes = nodes + self.n_actions = 1 + nodes * nodes + super().__init__(feature_dim=1, state_evaluator=state_evaluator) + + + def make_actions_class(self) -> type[Actions]: + env = self + class RingActions(Actions): + action_shape = (1,) + dummy_action = torch.tensor([env.n_actions]) + exit_action = torch.zeros(1,) + + return RingActions + + + def make_states_class(self) -> type[GraphStates]: + env = self + + class RingStates(GraphStates): + s0 = TensorDict({ + "node_feature": torch.zeros((env.nodes, 1)), + "edge_feature": torch.zeros((0, 1)), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, batch_size=()) + sf = TensorDict({ + "node_feature": torch.ones((env.nodes, 1)), + "edge_feature": torch.zeros((0, 1)), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, batch_size=()) + n_actions = env.n_actions + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.forward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool) + self.backward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool) + return RingStates + + def _step(self, states: GraphStates, actions: Actions) -> GraphStates: + actions = self.convert_actions(actions) + return super()._step(states, actions) + + def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + actions = self.convert_actions(actions) + return super()._backward_step(states, actions) + + def convert_actions(self, actions: Actions) -> GraphActions: + action_tensor = actions.tensor.squeeze(-1) + action_type = torch.where(action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE) + edge_index_i0 = (action_tensor - 1) // (self.nodes) + edge_index_i1 = (action_tensor - 1) % (self.nodes) + # edge_index_i1 = edge_index_i1 + (edge_index_i1 >= edge_index_i0) + + edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) + return GraphActions(TensorDict({ + "action_type": action_type, + "features": torch.ones(action_tensor.shape + (1,)), + "edge_index": edge_index, + }, batch_size=action_tensor.shape)) + + +class GraphPreprocessor(Preprocessor): + + def __init__(self, feature_dim: int = 1): + super().__init__(output_dim=feature_dim) + + def preprocess(self, states: GraphStates) -> TensorDict: + return states.tensor + + def __call__(self, states: GraphStates) -> torch.Tensor: + return self.preprocess(states) + + +if __name__ == "__main__": + torch.random.manual_seed(42) + env = RingGraphBuilding(nodes=10) + module = RingPolicyEstimator(env.nodes) + + pf_estimator = DiscretePolicyEstimator(module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor()) + + gflownet = FMGFlowNet(pf_estimator) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3) + + visited_terminating_states = env.States.from_batch_shape((0,)) + losses = [] + + for iteration in range(100): + print(f"Iteration {iteration}") + trajectories = gflownet.sample_trajectories(env, n=128) + samples = gflownet.to_training_samples(trajectories) + optimizer.zero_grad() + loss = gflownet.loss(env, samples) + loss.backward() + optimizer.step() + + visited_terminating_states.extend(trajectories.last_states) + losses.append(loss.item()) + + + From 9d42332b6946a01f0457a66a6255deb8455ee48a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 20 Dec 2024 13:09:59 +0100 Subject: [PATCH 022/102] remove check edge_features --- 
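With this change, is_action_valid identifies an edge by its endpoints alone: the
edge_feature equality test is dropped and only edge_index matches are counted
(0 matches required for a forward ADD_EDGE, exactly 1 for a backward one). A
minimal sketch of the remaining duplicate test, assuming an edge_index tensor of
shape (E, 2) and a proposed pair (i, j):

    import torch

    def edge_already_present(edge_index: torch.Tensor, i: int, j: int) -> bool:
        # Count existing edges whose (src, dst) endpoints equal (i, j);
        # edge features no longer take part in the comparison.
        proposed = torch.tensor([i, j])
        return bool(torch.all(edge_index == proposed, dim=-1).any())
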
src/gfn/gym/graph_building.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 8a2936d3..d1eb9af1 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -123,6 +123,7 @@ def is_action_valid( add_edge_states = states[add_edge_mask].tensor add_edge_actions = actions[add_edge_mask] + import pdb; pdb.set_trace() if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: @@ -130,23 +131,15 @@ def is_action_valid( if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): return False - equal_edges_per_batch_attr = torch.all( - add_edge_states["edge_feature"] == add_edge_actions.features[:, None], dim=-1 - ) - equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) equal_edges_per_batch_index = torch.all( add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1 ) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: - add_edge_out = torch.all(equal_edges_per_batch_attr == 1) and torch.all( - equal_edges_per_batch_index == 1 - ) + add_edge_out = torch.all(equal_edges_per_batch_index == 1) else: - add_edge_out = torch.all(equal_edges_per_batch_attr == 0) and torch.all( - equal_edges_per_batch_index == 0 - ) + add_edge_out = torch.all(equal_edges_per_batch_index == 0) return bool(add_node_out) and bool(add_edge_out) From 2d442426521ca58709e2afb9d63c6335aeee3852 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 20 Dec 2024 13:10:17 +0100 Subject: [PATCH 023/102] fix GraphStates set --- src/gfn/states.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 215d49b0..9be90bc4 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -658,10 +658,8 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ Set particular states of the Batch """ - if isinstance(index, (int, list)): - index = torch.tensor(index) - if index.dtype == torch.bool: - index = torch.where(index)[0] + tensor_idx = torch.arange(len(self)).view(*self.batch_shape) + index = tensor_idx[index].flatten() # Validate indices if torch.any(index >= len(self.tensor['batch_ptr']) - 1): @@ -679,10 +677,10 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): # Get start and end pointers for the current graph start_ptr = self.tensor['batch_ptr'][graph_idx] end_ptr = self.tensor['batch_ptr'][graph_idx + 1] + source_start_ptr = source_tensor_dict['batch_ptr'][i] + source_end_ptr = source_tensor_dict['batch_ptr'][i + 1] - new_nodes = source_tensor_dict['node_feature'][ - source_tensor_dict['batch_ptr'][i]:source_tensor_dict['batch_ptr'][i + 1] - ] + new_nodes = source_tensor_dict['node_feature'][source_start_ptr:source_end_ptr] # Ensure new nodes have correct feature dimension if new_nodes.ndim == 1: @@ -706,13 +704,18 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr self.tensor['edge_index'][edge_mask_0, 0] += shift self.tensor['edge_index'][edge_mask_1, 1] += shift + self.tensor['edge_index'] = torch.cat([ + self.tensor['edge_index'], + source_tensor_dict['edge_index'] - source_start_ptr + start_ptr, + ], dim=0) + self.tensor['edge_feature'] = torch.cat([ + self.tensor['edge_feature'], + source_tensor_dict['edge_feature'], + 
], dim=0) # Update batch pointers self.tensor['batch_ptr'][graph_idx + 1:] += shift - # TODO: add new edges - - @property def device(self) -> torch.device: return self.tensor.device From 173d4fb029eebda8ce8065a880c27ed3b3ffd885 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 20 Dec 2024 13:12:50 +0100 Subject: [PATCH 024/102] remove debug --- src/gfn/gym/graph_building.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index d1eb9af1..1bbb4704 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -123,7 +123,6 @@ def is_action_valid( add_edge_states = states[add_edge_mask].tensor add_edge_actions = actions[add_edge_mask] - import pdb; pdb.set_trace() if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: From 7265857e21198dada0a281d4c0397f04f2c26d27 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 20 Dec 2024 15:19:09 +0100 Subject: [PATCH 025/102] fix add_edge action --- src/gfn/gym/graph_building.py | 19 ++++++++++++------- src/gfn/states.py | 7 +++++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 1bbb4704..cec3c6c6 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -64,7 +64,10 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: if action_type == GraphActionType.ADD_EDGE: state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0) - state_tensor["edge_index"] = torch.cat([state_tensor["edge_index"], actions.edge_index], dim=0) + state_tensor["edge_index"] = torch.cat([ + state_tensor["edge_index"], + actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] + ], dim=0) return state_tensor @@ -91,8 +94,9 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Ten state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None + global_edge_index = actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] is_equal = torch.all( - state_tensor["edge_index"] == actions.edge_index[:, None], dim=-1 + state_tensor["edge_index"] == global_edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) state_tensor["edge_feature"] = state_tensor["edge_feature"][~is_equal] @@ -130,15 +134,16 @@ def is_action_valid( if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): return False - equal_edges_per_batch_index = torch.all( - add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1 + global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] + equal_edges_per_batch = torch.all( + add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1 ) - equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) + equal_edges_per_batch = equal_edges_per_batch.sum(dim=-1) if backward: - add_edge_out = torch.all(equal_edges_per_batch_index == 1) + add_edge_out = torch.all(equal_edges_per_batch == 1) else: - add_edge_out = torch.all(equal_edges_per_batch_index == 0) + add_edge_out = torch.all(equal_edges_per_batch == 0) return bool(add_node_out) and bool(add_edge_out) diff --git a/src/gfn/states.py b/src/gfn/states.py index 9be90bc4..471145bf 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -704,13 
+704,15 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr self.tensor['edge_index'][edge_mask_0, 0] += shift self.tensor['edge_index'][edge_mask_1, 1] += shift + edge_to_add_mask = torch.all(source_tensor_dict['edge_index'] >= source_start_ptr, dim=-1) + edge_to_add_mask &= torch.all(source_tensor_dict['edge_index'] < source_end_ptr, dim=-1) self.tensor['edge_index'] = torch.cat([ self.tensor['edge_index'], - source_tensor_dict['edge_index'] - source_start_ptr + start_ptr, + source_tensor_dict['edge_index'][edge_to_add_mask] - source_start_ptr + start_ptr, ], dim=0) self.tensor['edge_feature'] = torch.cat([ self.tensor['edge_feature'], - source_tensor_dict['edge_feature'], + source_tensor_dict['edge_feature'][edge_to_add_mask], ], dim=0) # Update batch pointers @@ -735,6 +737,7 @@ def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0) self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0) + # TODO: fix indices self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"] + self.tensor["batch_ptr"][-1]], dim=0) self.tensor["batch_ptr"] = torch.cat([self.tensor["batch_ptr"], other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1]], dim=0) assert torch.all(self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:]) From 2b3208fd738aabaa873191635ddaa0628e96012e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 20 Dec 2024 16:12:42 +0100 Subject: [PATCH 026/102] fix edge_index after get --- src/gfn/gym/graph_building.py | 3 +-- src/gfn/states.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index cec3c6c6..48ea17e3 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -137,8 +137,7 @@ def is_action_valid( global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] equal_edges_per_batch = torch.all( add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1 - ) - equal_edges_per_batch = equal_edges_per_batch.sum(dim=-1) + ).sum(dim=-1) if backward: add_edge_out = torch.all(equal_edges_per_batch == 1) diff --git a/src/gfn/states.py b/src/gfn/states.py index 471145bf..2b419696 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -635,8 +635,8 @@ def __getitem__( # Adjust edge indices to be local to this graph graph_edge_index = self.tensor['edge_index'][edge_mask] - graph_edge_index[:, 0] -= (batch_ptr[-1] - start) - graph_edge_index[:, 1] -= (batch_ptr[-1] - start) + graph_edge_index[:, 0] -= (start - batch_ptr[-1]) + graph_edge_index[:, 1] -= (start - batch_ptr[-1]) edge_indices.append(graph_edge_index) batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) From b84246f2aa497dd668552145ff0fd1641b2cd5ac Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 22 Dec 2024 23:44:05 +0100 Subject: [PATCH 027/102] push updated code --- src/gfn/env.py | 2 +- src/gfn/gym/graph_building.py | 6 +- src/gfn/states.py | 20 ++++-- tutorials/examples/test_graph_ring.py | 98 ++++++++++++++++++--------- 4 files changed, 83 insertions(+), 43 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 86c592b5..38884c97 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -300,7 +300,7 @@ def _backward_step( # 
Calculate the backward step, and update only the states which are not Done. new_not_done_states_tensor = self.backward_step(valid_states, valid_actions) - new_states.tensor[valid_states_idx] = new_not_done_states_tensor + new_states[valid_states_idx] = self.States(new_not_done_states_tensor) if isinstance(new_states, DiscreteStates): self.update_masks(new_states) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 48ea17e3..9a15d468 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -71,7 +71,7 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: return state_tensor - def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Tensor: + def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDict: """Backward step function for the GraphBuilding environment. Args: @@ -83,6 +83,8 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Ten if not self.is_action_valid(states, actions, backward=True): raise NonValidActionsError("Invalid action.") state_tensor = deepcopy(states.tensor) + if len(actions) == 0: + return state_tensor action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) @@ -126,14 +128,12 @@ def is_action_valid( else: add_edge_states = states[add_edge_mask].tensor add_edge_actions = actions[add_edge_mask] - if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: return False if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): return False - global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] equal_edges_per_batch = torch.all( add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1 diff --git a/src/gfn/states.py b/src/gfn/states.py index 2b419696..7875038c 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -523,7 +523,7 @@ def __init__(self, tensor: TensorDict): self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] - self._log_rewards: float = None + self._log_rewards: Optional[float] = None # TODO logic repeated from env.is_valid_action not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) @@ -700,18 +700,19 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): ]) # Update edge indices for subsequent graphs - edge_mask_0 = self.tensor['edge_index'][:, 0] >= end_ptr - edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr - self.tensor['edge_index'][edge_mask_0, 0] += shift - self.tensor['edge_index'][edge_mask_1, 1] += shift + edge_mask = self.tensor['edge_index'] >= end_ptr + assert torch.all(edge_mask[..., 0] == edge_mask[..., 1]) + edge_mask = torch.all(edge_mask, dim=-1) + self.tensor['edge_index'][edge_mask] += shift + edge_mask |= torch.all(self.tensor['edge_index'] < start_ptr, dim=-1) edge_to_add_mask = torch.all(source_tensor_dict['edge_index'] >= source_start_ptr, dim=-1) edge_to_add_mask &= torch.all(source_tensor_dict['edge_index'] < source_end_ptr, dim=-1) self.tensor['edge_index'] = torch.cat([ - self.tensor['edge_index'], + self.tensor['edge_index'][edge_mask], source_tensor_dict['edge_index'][edge_to_add_mask] - source_start_ptr + start_ptr, ], dim=0) self.tensor['edge_feature'] = torch.cat([ - 
self.tensor['edge_feature'], + self.tensor['edge_feature'][edge_mask], source_tensor_dict['edge_feature'][edge_to_add_mask], ], dim=0) @@ -759,6 +760,11 @@ def _compare(self, other: TensorDict) -> torch.Tensor: out[i] = False else: out[i] = torch.all(self.tensor["node_feature"][start:end] == other["node_feature"]) + edge_mask = torch.all((self.tensor["edge_index"] >= start) & (self.tensor["edge_index"] < end), dim=-1) + edge_index = self.tensor["edge_index"][edge_mask] - start + out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all(edge_index == other["edge_index"]) + edge_feature = self.tensor["edge_feature"][edge_mask] + out[i] &= len(edge_feature) == len(other["edge_feature"]) and torch.all(edge_feature == other["edge_feature"]) return out.view(self.batch_shape) @property diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py index 8546ec41..3b2679b8 100644 --- a/tutorials/examples/test_graph_ring.py +++ b/tutorials/examples/test_graph_ring.py @@ -1,5 +1,6 @@ """Write ane xamples where we want to create graphs that are rings.""" +from typing import Optional import torch from torch import nn from gfn.actions import Actions, GraphActionType, GraphActions @@ -13,17 +14,24 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: + eps = 1e-6 if states.tensor["edge_index"].shape[0] == 0: - return torch.zeros(states.batch_shape) + return torch.full(states.batch_shape, eps) if states.tensor["edge_index"].shape[0] != states.tensor["node_feature"].shape[0]: - return torch.zeros(states.batch_shape) - - i0 = torch.unique(states.tensor["edge_index"][0], sorted=False) - i1 = torch.unique(states.tensor["edge_index"][1], sorted=False) - - if len(i0) == len(i1) == states.tensor["node_feature"].shape[0]: - return torch.ones(states.batch_shape) - return torch.zeros(states.batch_shape) + return torch.full(states.batch_shape, eps) + + out = torch.zeros(len(states)) + for i in range(len(states)): + start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] + edge_index_mask = torch.all(states.tensor["edge_index"] >= start, dim=-1) & torch.all(states.tensor["edge_index"] < end, dim=-1) + edge_index = states.tensor["edge_index"][edge_index_mask] + arange = torch.arange(start, end) + # TODO: not correct, accepts multiple rings + if torch.all(torch.sort(edge_index[:, 0])[0] == arange) and torch.all(torch.sort(edge_index[:, 1])[0] == arange): + out[i] = 1 + else: + out[i] = eps + return out.view(*states.batch_shape) class RingPolicyEstimator(nn.Module): @@ -47,18 +55,16 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: action_type = self._group_sum(action_type, batch_ptr) edge_index = self.edge_index_conv(node_feature, edge_index) - #edge_index = self._group_sum(edge_index, batch_ptr) - edge_index = edge_index.reshape(*states_tensor["batch_shape"], -1, 8) + edge_index = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes, 8) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - torch.diagonal(edge_index, dim1=-2, dim2=-1).fill_(float("-inf")) - edge_actions = edge_index.reshape(*states_tensor["batch_shape"], -1) + edge_actions = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes * self.n_nodes) return torch.cat([action_type, edge_actions], dim=-1) class RingGraphBuilding(GraphBuilding): - def __init__(self, nodes: int = 10): - self.nodes = nodes - self.n_actions = 1 + nodes * nodes + def __init__(self, n_nodes: int = 10): + self.n_nodes = n_nodes + self.n_actions = 1 + n_nodes * n_nodes 
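+        # Action 0 is EXIT; action 1 + i * n_nodes + j adds the edge i -> j
+        # (decoded in convert_actions below). E.g. with n_nodes = 3, action 5
+        # decodes to ((5 - 1) // 3, (5 - 1) % 3) = (1, 1), a self-loop, which
+        # is why the forward masks disable every index in 1::n_nodes + 1.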
super().__init__(feature_dim=1, state_evaluator=state_evaluator) @@ -77,22 +83,52 @@ def make_states_class(self) -> type[GraphStates]: class RingStates(GraphStates): s0 = TensorDict({ - "node_feature": torch.zeros((env.nodes, 1)), + "node_feature": torch.zeros((env.n_nodes, 1)), "edge_feature": torch.zeros((0, 1)), "edge_index": torch.zeros((0, 2), dtype=torch.long), }, batch_size=()) sf = TensorDict({ - "node_feature": torch.ones((env.nodes, 1)), + "node_feature": torch.ones((env.n_nodes, 1)), "edge_feature": torch.zeros((0, 1)), "edge_index": torch.zeros((0, 2), dtype=torch.long), }, batch_size=()) - n_actions = env.n_actions - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.forward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool) - self.backward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool) + def __init__(self, tensor: TensorDict): + self.tensor = tensor + self.node_features_dim = tensor["node_feature"].shape[-1] + self.edge_features_dim = tensor["edge_feature"].shape[-1] + self._log_rewards: Optional[float] = None + + self.n_nodes = env.n_nodes + self.n_actions = env.n_actions + + @property + def forward_masks(self): + forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) + forward_masks[:, 1::self.n_nodes + 1] = False + for i in range(len(self)): + existing_edges = self[i].tensor["edge_index"] + forward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = False + + return forward_masks.view(*self.batch_shape, self.n_actions) + + @forward_masks.setter + def forward_masks(self, value: torch.Tensor): + pass # fwd masks is computed on the fly + + @property + def backward_masks(self): + backward_masks = torch.zeros(len(self), self.n_actions, dtype=torch.bool) + for i in range(len(self)): + existing_edges = self[i].tensor["edge_index"] + backward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = True + + return backward_masks.view(*self.batch_shape, self.n_actions) + + @backward_masks.setter + def backward_masks(self, value: torch.Tensor): + pass # bwd masks is computed on the fly + return RingStates def _step(self, states: GraphStates, actions: Actions) -> GraphStates: @@ -106,9 +142,8 @@ def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: def convert_actions(self, actions: Actions) -> GraphActions: action_tensor = actions.tensor.squeeze(-1) action_type = torch.where(action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE) - edge_index_i0 = (action_tensor - 1) // (self.nodes) - edge_index_i1 = (action_tensor - 1) % (self.nodes) - # edge_index_i1 = edge_index_i1 + (edge_index_i1 >= edge_index_i0) + edge_index_i0 = (action_tensor - 1) // (self.n_nodes) + edge_index_i1 = (action_tensor - 1) % (self.n_nodes) edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) return GraphActions(TensorDict({ @@ -132,8 +167,8 @@ def __call__(self, states: GraphStates) -> torch.Tensor: if __name__ == "__main__": torch.random.manual_seed(42) - env = RingGraphBuilding(nodes=10) - module = RingPolicyEstimator(env.nodes) + env = RingGraphBuilding(n_nodes=3) + module = RingPolicyEstimator(env.n_nodes) pf_estimator = DiscretePolicyEstimator(module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor()) @@ -143,9 +178,8 @@ def __call__(self, states: GraphStates) -> torch.Tensor: visited_terminating_states = env.States.from_batch_shape((0,)) losses = [] - for iteration in range(100): - print(f"Iteration 
{iteration}") - trajectories = gflownet.sample_trajectories(env, n=128) + for iteration in range(128): + trajectories = gflownet.sample_trajectories(env, n=32) samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() loss = gflownet.loss(env, samples) From fa0d22ae2ccb18609279712281042baf1a8863ea Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 27 Dec 2024 13:43:11 +0100 Subject: [PATCH 028/102] add rendering --- tutorials/examples/test_graph_ring.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py index 3b2679b8..f1d513c3 100644 --- a/tutorials/examples/test_graph_ring.py +++ b/tutorials/examples/test_graph_ring.py @@ -1,5 +1,6 @@ """Write ane xamples where we want to create graphs that are rings.""" +import math from typing import Optional import torch from torch import nn @@ -11,6 +12,7 @@ from gfn.states import GraphStates from tensordict import TensorDict from torch_geometric.nn import GCNConv +import matplotlib.pyplot as plt def state_evaluator(states: GraphStates) -> torch.Tensor: @@ -165,6 +167,53 @@ def __call__(self, states: GraphStates) -> torch.Tensor: return self.preprocess(states) +def render_states(states: GraphStates): + fig, ax = plt.subplots(2, 4, figsize=(15, 7)) + for i in range(8): + current_ax = ax[i // 4, i % 4] + state = states[i] + n_circles = state.tensor["node_feature"].shape[0] + radius = 5 + xs, ys = [], [] + for j in range(n_circles): + angle = 2 * math.pi * j / n_circles + x = radius * math.cos(angle) + y = radius * math.sin(angle) + xs.append(x) + ys.append(y) + current_ax.add_patch(plt.Circle((x, y), 0.5, facecolor='none', edgecolor='black')) + + for edge in state.tensor["edge_index"]: + start_x, start_y = xs[edge[0]], ys[edge[0]] + end_x, end_y = xs[edge[1]], ys[edge[1]] + dx = end_x - start_x + dy = end_y - start_y + length = math.sqrt(dx**2 + dy**2) + dx, dy = dx/length, dy/length + + circle_radius = 0.5 + head_thickness = 0.2 + start_x += dx * (circle_radius) + start_y += dy * (circle_radius) + end_x -= dx * (circle_radius + head_thickness) + end_y -= dy * (circle_radius + head_thickness) + + current_ax.arrow(start_x, start_y, + end_x - start_x, end_y - start_y, + head_width=head_thickness, head_length=head_thickness, + fc='black', ec='black') + + current_ax.set_title(f"State {i}") + current_ax.set_xlim(-(radius + 1), radius + 1) + current_ax.set_ylim(-(radius + 1), radius + 1) + current_ax.set_aspect("equal") + current_ax.set_xticks([]) + current_ax.set_yticks([]) + + + plt.show() + + if __name__ == "__main__": torch.random.manual_seed(42) env = RingGraphBuilding(n_nodes=3) @@ -188,6 +237,8 @@ def __call__(self, states: GraphStates) -> torch.Tensor: visited_terminating_states.extend(trajectories.last_states) losses.append(loss.item()) + + render_states(visited_terminating_states[:-8]) From 27d192a9c7346724cccb73ae1ce50c240ba2062e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 6 Jan 2025 10:56:08 +0100 Subject: [PATCH 029/102] fix gradient propagation --- tutorials/examples/test_graph_ring.py | 41 ++++++++++++++++----------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py index f1d513c3..c6b591c2 100644 --- a/tutorials/examples/test_graph_ring.py +++ b/tutorials/examples/test_graph_ring.py @@ -1,6 +1,7 @@ """Write ane xamples where we want to create graphs that are rings.""" import math +import time from typing import Optional 
import torch from torch import nn @@ -37,11 +38,12 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: class RingPolicyEstimator(nn.Module): - def __init__(self, n_nodes: int): + def __init__(self, n_nodes: int, edge_hidden_dim: int = 128): super().__init__() self.action_type_conv = GCNConv(1, 1) - self.edge_index_conv = GCNConv(1, 8) + self.edge_index_conv = GCNConv(1, edge_hidden_dim) self.n_nodes = n_nodes + self.edge_hidden_dim = edge_hidden_dim def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: cumsum = torch.zeros((len(tensor) + 1, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) @@ -57,7 +59,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: action_type = self._group_sum(action_type, batch_ptr) edge_index = self.edge_index_conv(node_feature, edge_index) - edge_index = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes, 8) + edge_index = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) edge_actions = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes * self.n_nodes) @@ -85,12 +87,12 @@ def make_states_class(self) -> type[GraphStates]: class RingStates(GraphStates): s0 = TensorDict({ - "node_feature": torch.zeros((env.n_nodes, 1)), - "edge_feature": torch.zeros((0, 1)), - "edge_index": torch.zeros((0, 2), dtype=torch.long), + "node_feature": torch.arange(env.n_nodes).unsqueeze(-1), + "edge_feature": torch.ones((0, 1)), + "edge_index": torch.ones((0, 2), dtype=torch.long), }, batch_size=()) sf = TensorDict({ - "node_feature": torch.ones((env.n_nodes, 1)), + "node_feature": torch.zeros((env.n_nodes, 1)), "edge_feature": torch.zeros((0, 1)), "edge_index": torch.zeros((0, 2), dtype=torch.long), }, batch_size=()) @@ -168,6 +170,7 @@ def __call__(self, states: GraphStates) -> torch.Tensor: def render_states(states: GraphStates): + rewards = state_evaluator(states) fig, ax = plt.subplots(2, 4, figsize=(15, 7)) for i in range(8): current_ax = ax[i // 4, i % 4] @@ -203,7 +206,7 @@ def render_states(states: GraphStates): head_width=head_thickness, head_length=head_thickness, fc='black', ec='black') - current_ax.set_title(f"State {i}") + current_ax.set_title(f"State {i}, $r={rewards[i]:.2f}$") current_ax.set_xlim(-(radius + 1), radius + 1) current_ax.set_ylim(-(radius + 1), radius + 1) current_ax.set_aspect("equal") @@ -215,30 +218,34 @@ def render_states(states: GraphStates): if __name__ == "__main__": + N_NODES = 3 + N_ITERATIONS = 1024 torch.random.manual_seed(42) - env = RingGraphBuilding(n_nodes=3) + env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) pf_estimator = DiscretePolicyEstimator(module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor()) gflownet = FMGFlowNet(pf_estimator) - optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-2) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_ITERATIONS, eta_min=1e-4) - visited_terminating_states = env.States.from_batch_shape((0,)) losses = [] - for iteration in range(128): - trajectories = gflownet.sample_trajectories(env, n=32) + t1 = time.time() + for iteration in range(N_ITERATIONS): + trajectories = gflownet.sample_trajectories(env, n=64) samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() loss = gflownet.loss(env, samples) + print("Iteration", iteration, "Loss:", loss.item()) 
loss.backward() optimizer.step() - - visited_terminating_states.extend(trajectories.last_states) losses.append(loss.item()) - - render_states(visited_terminating_states[:-8]) + scheduler.step() + t2 = time.time() + print("Time:", t2 - t1) + render_states(trajectories.last_states[:8]) From f4fc3ab1a5f8ac953ec98c0cd05264868257a418 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 12 Jan 2025 20:26:30 +0100 Subject: [PATCH 030/102] fix formatting --- src/gfn/actions.py | 51 ++-- src/gfn/env.py | 6 +- src/gfn/gym/__init__.py | 2 +- src/gfn/gym/graph_building.py | 134 +++++++---- src/gfn/modules.py | 15 +- src/gfn/samplers.py | 2 +- src/gfn/states.py | 270 ++++++++++++++-------- src/gfn/utils/distributions.py | 5 +- testing/test_environments.py | 146 ++++++++---- testing/test_samplers_and_trajectories.py | 7 +- tutorials/examples/test_graph_ring.py | 160 ++++++++----- 11 files changed, 517 insertions(+), 281 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index d2e9b3b0..c80eb2d3 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -3,7 +3,7 @@ import enum from abc import ABC from math import prod -from typing import ClassVar, Optional, Sequence +from typing import ClassVar, Sequence import torch from tensordict import TensorDict @@ -201,11 +201,14 @@ def __init__(self, tensor: TensorDict): assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE) edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) - self.tensor = TensorDict({ - "action_type": tensor["action_type"], - "features": features, - "edge_index": edge_index, - }, batch_size=self.batch_shape) + self.tensor = TensorDict( + { + "action_type": tensor["action_type"], + "features": features, + "edge_index": edge_index, + }, + batch_size=self.batch_shape, + ) def __repr__(self): return f"""GraphAction object with {self.batch_shape} actions.""" @@ -223,7 +226,6 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActio """Get particular actions of the batch.""" return GraphActions(self.tensor[index]) - def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions ) -> None: @@ -239,9 +241,14 @@ def compare(self, other: GraphActions) -> torch.Tensor: Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. 
""" compare = torch.all(self.tensor == other.tensor, dim=-1) - return compare["action_type"] & \ - (compare["action_type"] == GraphActionType.EXIT | compare["features"]) & \ - (compare["action_type"] != GraphActionType.ADD_EDGE | compare["edge_index"]) + return ( + compare["action_type"] + & (compare["action_type"] == GraphActionType.EXIT | compare["features"]) + & ( + compare["action_type"] + != GraphActionType.ADD_EDGE | compare["edge_index"] + ) + ) @property def is_exit(self) -> torch.Tensor: @@ -257,25 +264,28 @@ def action_type(self) -> torch.Tensor: def features(self) -> torch.Tensor: """Returns the features tensor.""" return self.tensor["features"] - + @property def edge_index(self) -> torch.Tensor: """Returns the edge index tensor.""" return self.tensor["edge_index"] @classmethod - def make_dummy_actions( - cls, batch_shape: tuple[int] - ) -> GraphActions: + def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: """Creates an Actions object of dummy actions with the given batch shape.""" return cls( - TensorDict({ - "action_type": torch.full(batch_shape, fill_value=GraphActionType.EXIT), - # "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), - # "edge_index": torch.zeros((2, *batch_shape, 0)), - }, batch_size=batch_shape) + TensorDict( + { + "action_type": torch.full( + batch_shape, fill_value=GraphActionType.EXIT + ), + # "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), + # "edge_index": torch.zeros((2, *batch_shape, 0)), + }, + batch_size=batch_shape, + ) ) - + @classmethod def stack(cls, actions_list: list[GraphActions]) -> GraphActions: """Stacks a list of GraphActions objects into a single GraphActions object.""" @@ -283,4 +293,3 @@ def stack(cls, actions_list: list[GraphActions]) -> GraphActions: [actions.tensor for actions in actions_list], dim=0 ) return cls(actions_tensor) - diff --git a/src/gfn/env.py b/src/gfn/env.py index 38884c97..28734926 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from tensordict import TensorDict @@ -219,7 +219,7 @@ def reset( batch_shape = (1,) if isinstance(batch_shape, int): batch_shape = (batch_shape,) - + return self.States.from_batch_shape( batch_shape=batch_shape, random=random, sink=sink ) @@ -266,7 +266,7 @@ def _step( not_done_actions = actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - + if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)): raise Exception( "User implemented env.step function *must* return a torch.Tensor!" 
diff --git a/src/gfn/gym/__init__.py b/src/gfn/gym/__init__.py index 20490566..ebec6f20 100644 --- a/src/gfn/gym/__init__.py +++ b/src/gfn/gym/__init__.py @@ -1,4 +1,4 @@ from gfn.gym.box import Box from gfn.gym.discrete_ebm import DiscreteEBM +from gfn.gym.graph_building import GraphBuilding from gfn.gym.hypergrid import HyperGrid -from gfn.gym.graph_building import GraphBuilding \ No newline at end of file diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 9a15d468..415c4ec1 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -2,8 +2,8 @@ from typing import Callable, Literal, Tuple import torch -from torch_geometric.nn import GCNConv from tensordict import TensorDict +from torch_geometric.nn import GCNConv from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError @@ -17,16 +17,24 @@ def __init__( state_evaluator: Callable[[GraphStates], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = TensorDict({ - "node_feature": torch.zeros((0, feature_dim), dtype=torch.float32), - "edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, device=device_str) - sf = TensorDict({ - "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), - "edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, device=device_str) + s0 = TensorDict( + { + "node_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, + device=device_str, + ) + sf = TensorDict( + { + "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) + * float("inf"), + "edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) + * float("inf"), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, + device=device_str, + ) if state_evaluator is None: state_evaluator = GCNConvEvaluator(feature_dim) @@ -59,15 +67,22 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: return self.States.make_sink_states_tensor(states.batch_shape) if action_type == GraphActionType.ADD_NODE: - batch_indices = torch.arange(len(states))[actions.action_type == GraphActionType.ADD_NODE] + batch_indices = torch.arange(len(states))[ + actions.action_type == GraphActionType.ADD_NODE + ] state_tensor = self._add_node(state_tensor, batch_indices, actions.features) if action_type == GraphActionType.ADD_EDGE: - state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0) - state_tensor["edge_index"] = torch.cat([ - state_tensor["edge_index"], - actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] - ], dim=0) + state_tensor["edge_feature"] = torch.cat( + [state_tensor["edge_feature"], actions.features], dim=0 + ) + state_tensor["edge_index"] = torch.cat( + [ + state_tensor["edge_index"], + actions.edge_index + state_tensor["batch_ptr"][:-1][:, None], + ], + dim=0, + ) return state_tensor @@ -90,13 +105,17 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic assert torch.all(actions.action_type == action_type) if action_type == GraphActionType.ADD_NODE: is_equal = torch.any( - torch.all(state_tensor["node_feature"][:, None] == actions.features, dim=-1), - dim=-1 + torch.all( + state_tensor["node_feature"][:, None] 
== actions.features, dim=-1 + ), + dim=-1, ) state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - global_edge_index = actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] + global_edge_index = ( + actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] + ) is_equal = torch.all( state_tensor["edge_index"] == global_edge_index[:, None], dim=-1 ) @@ -121,67 +140,84 @@ def is_action_valid( add_node_out = torch.all(equal_nodes_per_batch == 1) else: add_node_out = torch.all(equal_nodes_per_batch == 0) - + add_edge_mask = actions.action_type == GraphActionType.ADD_EDGE if not torch.any(add_edge_mask): add_edge_out = True else: add_edge_states = states[add_edge_mask].tensor add_edge_actions = actions[add_edge_mask] - if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): + if torch.any( + add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1] + ): return False if add_edge_states["node_feature"].shape[0] == 0: return False - if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): + if torch.any( + add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0] + ): return False - global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] + global_edge_index = ( + add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] + ) equal_edges_per_batch = torch.all( add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1 ).sum(dim=-1) - + if backward: add_edge_out = torch.all(equal_edges_per_batch == 1) else: add_edge_out = torch.all(equal_edges_per_batch == 0) - + return bool(add_node_out) and bool(add_edge_out) - - def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_to_add: torch.Tensor) -> TensorDict: + + def _add_node( + self, + tensor_dict: TensorDict, + batch_indices: torch.Tensor, + nodes_to_add: torch.Tensor, + ) -> TensorDict: if isinstance(batch_indices, list): batch_indices = torch.tensor(batch_indices) if len(batch_indices) != len(nodes_to_add): - raise ValueError("Number of batch indices must match number of node feature lists") - + raise ValueError( + "Number of batch indices must match number of node feature lists" + ) + modified_dict = tensor_dict.clone() - node_feature_dim = modified_dict['node_feature'].shape[1] - + node_feature_dim = modified_dict["node_feature"].shape[1] + for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): - start_ptr = tensor_dict['batch_ptr'][graph_idx] - end_ptr = tensor_dict['batch_ptr'][graph_idx + 1] + tensor_dict["batch_ptr"][graph_idx] + end_ptr = tensor_dict["batch_ptr"][graph_idx + 1] if new_nodes.ndim == 1: new_nodes = new_nodes.unsqueeze(0) if new_nodes.shape[1] != node_feature_dim: - raise ValueError(f"Node features must have dimension {node_feature_dim}") - + raise ValueError( + f"Node features must have dimension {node_feature_dim}" + ) + # Update batch pointers for subsequent graphs shift = new_nodes.shape[0] - modified_dict['batch_ptr'][graph_idx + 1:] += shift - + modified_dict["batch_ptr"][graph_idx + 1 :] += shift + # Expand node features - modified_dict['node_feature'] = torch.cat([ - modified_dict['node_feature'][:end_ptr], - new_nodes, - modified_dict['node_feature'][end_ptr:] - ]) - + modified_dict["node_feature"] = torch.cat( + [ + modified_dict["node_feature"][:end_ptr], + new_nodes, + modified_dict["node_feature"][end_ptr:], + ] + ) + # Update edge 
indices # Increment indices for edges after the current graph - edge_mask_0 = modified_dict['edge_index'][:, 0] >= end_ptr - edge_mask_1 = modified_dict['edge_index'][:, 1] >= end_ptr - modified_dict['edge_index'][edge_mask_0, 0] += shift - modified_dict['edge_index'][edge_mask_1, 1] += shift - + edge_mask_0 = modified_dict["edge_index"][:, 0] >= end_ptr + edge_mask_1 = modified_dict["edge_index"][:, 1] >= end_ptr + modified_dict["edge_index"][edge_mask_0, 0] += shift + modified_dict["edge_index"][edge_mask_1, 1] += shift + return modified_dict def reward(self, final_states: GraphStates) -> torch.Tensor: diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 169a1f57..e8058e65 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,7 +9,12 @@ from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States -from gfn.utils.distributions import CategoricalActionType, CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical +from gfn.utils.distributions import ( + CategoricalActionType, + CategoricalIndexes, + ComposedDistribution, + UnsqueezedCategorical, +) REDUCTION_FXNS = { "mean": torch.mean, @@ -235,7 +240,9 @@ def forward(self, states: DiscreteStates) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). """ out = super().forward(states) - assert out.shape[-1] == self.expected_output_dim, f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}" + assert ( + out.shape[-1] == self.expected_output_dim + ), f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}" return out def to_probability_distribution( @@ -517,7 +524,9 @@ def to_probability_distribution( dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - if states.tensor["node_feature"].shape[0] > 1 and torch.any(edge_index_logits != -float("inf")): + if states.tensor["node_feature"].shape[0] > 1 and torch.any( + edge_index_logits != -float("inf") + ): edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) uniform_dist_probs = ( torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index d0f580fd..17352f22 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -167,7 +167,7 @@ def sample_trajectories( step = 0 all_estimator_outputs = [] - + while not all(dones): actions = env.actions_from_batch_shape((n_trajectories,)) # Dummy actions. 
log_probs = torch.full( diff --git a/src/gfn/states.py b/src/gfn/states.py index 7875038c..d7cc96f6 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -293,7 +293,7 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: def sample(self, n_samples: int) -> States: """Samples a subset of the States object.""" return self[torch.randperm(len(self))[:n_samples]] - + @classmethod def stack(cls, states: List[States]): """Given a list of states, stacks them along a new dimension (0).""" @@ -526,17 +526,23 @@ def __init__(self, tensor: TensorDict): self._log_rewards: Optional[float] = None # TODO logic repeated from env.is_valid_action not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] - self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.forward_masks = torch.ones( + (np.prod(self.batch_shape), 3), dtype=torch.bool + ) self.forward_masks[..., GraphActionType.ADD_EDGE] = not_empty self.forward_masks[..., GraphActionType.EXIT] = not_empty self.forward_masks = self.forward_masks.view(*self.batch_shape, 3) - self.backward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.backward_masks = torch.ones( + (np.prod(self.batch_shape), 3), dtype=torch.bool + ) self.backward_masks[..., GraphActionType.ADD_NODE] = not_empty - self.backward_masks[..., GraphActionType.ADD_EDGE] = not_empty # TODO: check at least one edge is present + self.backward_masks[ + ..., GraphActionType.ADD_EDGE + ] = not_empty # TODO: check at least one edge is present self.backward_masks[..., GraphActionType.EXIT] = not_empty self.backward_masks = self.backward_masks.view(*self.batch_shape, 3) - + @property def batch_shape(self) -> tuple: return tuple(self.tensor["batch_shape"].tolist()) @@ -559,13 +565,16 @@ def from_batch_shape( def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - return TensorDict({ - "node_feature": cls.s0["node_feature"].repeat(np.prod(batch_shape), 1), - "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), - "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.s0["node_feature"].shape[0], - "batch_shape": batch_shape - }) + return TensorDict( + { + "node_feature": cls.s0["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) + * cls.s0["node_feature"].shape[0], + "batch_shape": batch_shape, + } + ) @classmethod def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: @@ -573,13 +582,16 @@ def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: raise NotImplementedError("Sink state is not defined") batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - return TensorDict({ - "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), - "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), - "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.sf["node_feature"].shape[0], - "batch_shape": batch_shape - }) + return TensorDict( + { + "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": 
cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) + * cls.sf["node_feature"].shape[0], + "batch_shape": batch_shape, + } + ) @classmethod def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: @@ -589,13 +601,21 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2) node_features_dim = cls.s0["node_feature"].shape[-1] edge_features_dim = cls.s0["edge_feature"].shape[-1] - return TensorDict({ - "node_feature": torch.rand(np.prod(batch_shape) * num_nodes, node_features_dim), - "edge_feature": torch.rand(np.prod(batch_shape) * num_edges, edge_features_dim), - "edge_index": torch.randint(num_nodes, size=(np.prod(batch_shape) * num_edges, 2)), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * num_nodes, - "batch_shape": batch_shape - }) + return TensorDict( + { + "node_feature": torch.rand( + np.prod(batch_shape) * num_nodes, node_features_dim + ), + "edge_feature": torch.rand( + np.prod(batch_shape) * num_edges, edge_features_dim + ), + "edge_index": torch.randint( + num_nodes, size=(np.prod(batch_shape) * num_edges, 2) + ), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * num_nodes, + "batch_shape": batch_shape, + } + ) def __len__(self) -> int: return int(np.prod(self.batch_shape)) @@ -611,44 +631,48 @@ def __getitem__( ) -> GraphStates: tensor_idx = torch.arange(len(self)).view(*self.batch_shape) index = tensor_idx[index].flatten() - - if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + + if torch.any(index >= len(self.tensor["batch_ptr"]) - 1): raise ValueError("Graph index out of bounds") - - start_ptrs = self.tensor['batch_ptr'][:-1][index] - end_ptrs = self.tensor['batch_ptr'][1:][index] - + + start_ptrs = self.tensor["batch_ptr"][:-1][index] + end_ptrs = self.tensor["batch_ptr"][1:][index] + node_features = [torch.empty(0, self.node_features_dim)] edge_features = [torch.empty(0, self.edge_features_dim)] edge_indices = [torch.empty(0, 2, dtype=torch.long)] batch_ptr = [0] - + for start, end in zip(start_ptrs, end_ptrs): - graph_nodes = self.tensor['node_feature'][start:end] + graph_nodes = self.tensor["node_feature"][start:end] node_features.append(graph_nodes) # Find edges for this graph - edge_mask = ((self.tensor['edge_index'][:, 0] >= start) & - (self.tensor['edge_index'][:, 0] < end)) - graph_edges = self.tensor['edge_feature'][edge_mask] + edge_mask = (self.tensor["edge_index"][:, 0] >= start) & ( + self.tensor["edge_index"][:, 0] < end + ) + graph_edges = self.tensor["edge_feature"][edge_mask] edge_features.append(graph_edges) - + # Adjust edge indices to be local to this graph - graph_edge_index = self.tensor['edge_index'][edge_mask] - graph_edge_index[:, 0] -= (start - batch_ptr[-1]) - graph_edge_index[:, 1] -= (start - batch_ptr[-1]) + graph_edge_index = self.tensor["edge_index"][edge_mask] + graph_edge_index[:, 0] -= start - batch_ptr[-1] + graph_edge_index[:, 1] -= start - batch_ptr[-1] edge_indices.append(graph_edge_index) batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) - out = self.__class__(TensorDict({ - 'node_feature': torch.cat(node_features), - 'edge_feature': torch.cat(edge_features), - 'edge_index': torch.cat(edge_indices), - 'batch_ptr': torch.tensor(batch_ptr), - 'batch_shape': (len(index),) - })) + out = self.__class__( + TensorDict( + { + "node_feature": torch.cat(node_features), + "edge_feature": 
torch.cat(edge_features), + "edge_index": torch.cat(edge_indices), + "batch_ptr": torch.tensor(batch_ptr), + "batch_shape": (len(index),), + } + ) + ) - if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -660,64 +684,88 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ tensor_idx = torch.arange(len(self)).view(*self.batch_shape) index = tensor_idx[index].flatten() - + # Validate indices - if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + if torch.any(index >= len(self.tensor["batch_ptr"]) - 1): raise ValueError("Target graph index out of bounds") - + # Source graph details source_tensor_dict = graph.tensor - source_num_graphs = torch.prod(source_tensor_dict['batch_shape']) - + source_num_graphs = torch.prod(source_tensor_dict["batch_shape"]) + # Validate source and target indices match if len(index) != source_num_graphs: - raise ValueError("Number of source graphs must match number of target indices") - + raise ValueError( + "Number of source graphs must match number of target indices" + ) + for i, graph_idx in enumerate(index): # Get start and end pointers for the current graph - start_ptr = self.tensor['batch_ptr'][graph_idx] - end_ptr = self.tensor['batch_ptr'][graph_idx + 1] - source_start_ptr = source_tensor_dict['batch_ptr'][i] - source_end_ptr = source_tensor_dict['batch_ptr'][i + 1] + start_ptr = self.tensor["batch_ptr"][graph_idx] + end_ptr = self.tensor["batch_ptr"][graph_idx + 1] + source_start_ptr = source_tensor_dict["batch_ptr"][i] + source_end_ptr = source_tensor_dict["batch_ptr"][i + 1] + + new_nodes = source_tensor_dict["node_feature"][ + source_start_ptr:source_end_ptr + ] - new_nodes = source_tensor_dict['node_feature'][source_start_ptr:source_end_ptr] - # Ensure new nodes have correct feature dimension if new_nodes.ndim == 1: new_nodes = new_nodes.unsqueeze(0) - + if new_nodes.shape[1] != self.node_features_dim: - raise ValueError(f"Node features must have dimension {node_feature_dim}") - + raise ValueError( + f"Node features must have dimension {node_feature_dim}" + ) + # Number of new nodes to add shift = new_nodes.shape[0] - (end_ptr - start_ptr) - + # Concatenate node features - self.tensor['node_feature'] = torch.cat([ - self.tensor['node_feature'][:start_ptr], # Nodes before the current graph - new_nodes, # New nodes to add - self.tensor['node_feature'][end_ptr:] # Nodes after the current graph - ]) - + self.tensor["node_feature"] = torch.cat( + [ + self.tensor["node_feature"][ + :start_ptr + ], # Nodes before the current graph + new_nodes, # New nodes to add + self.tensor["node_feature"][ + end_ptr: + ], # Nodes after the current graph + ] + ) + # Update edge indices for subsequent graphs - edge_mask = self.tensor['edge_index'] >= end_ptr + edge_mask = self.tensor["edge_index"] >= end_ptr assert torch.all(edge_mask[..., 0] == edge_mask[..., 1]) edge_mask = torch.all(edge_mask, dim=-1) - self.tensor['edge_index'][edge_mask] += shift - edge_mask |= torch.all(self.tensor['edge_index'] < start_ptr, dim=-1) - edge_to_add_mask = torch.all(source_tensor_dict['edge_index'] >= source_start_ptr, dim=-1) - edge_to_add_mask &= torch.all(source_tensor_dict['edge_index'] < source_end_ptr, dim=-1) - self.tensor['edge_index'] = torch.cat([ - self.tensor['edge_index'][edge_mask], - source_tensor_dict['edge_index'][edge_to_add_mask] - source_start_ptr + start_ptr, - ], dim=0) - self.tensor['edge_feature'] = torch.cat([ - self.tensor['edge_feature'][edge_mask], - 
source_tensor_dict['edge_feature'][edge_to_add_mask], - ], dim=0) + self.tensor["edge_index"][edge_mask] += shift + edge_mask |= torch.all(self.tensor["edge_index"] < start_ptr, dim=-1) + edge_to_add_mask = torch.all( + source_tensor_dict["edge_index"] >= source_start_ptr, dim=-1 + ) + edge_to_add_mask &= torch.all( + source_tensor_dict["edge_index"] < source_end_ptr, dim=-1 + ) + self.tensor["edge_index"] = torch.cat( + [ + self.tensor["edge_index"][edge_mask], + source_tensor_dict["edge_index"][edge_to_add_mask] + - source_start_ptr + + start_ptr, + ], + dim=0, + ) + self.tensor["edge_feature"] = torch.cat( + [ + self.tensor["edge_feature"][edge_mask], + source_tensor_dict["edge_feature"][edge_to_add_mask], + ], + dim=0, + ) # Update batch pointers - self.tensor['batch_ptr'][graph_idx + 1:] += shift + self.tensor["batch_ptr"][graph_idx + 1 :] += shift @property def device(self) -> torch.device: @@ -736,13 +784,33 @@ def clone(self) -> GraphStates: def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" - self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0) - self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0) + self.tensor["node_feature"] = torch.cat( + [self.tensor["node_feature"], other.tensor["node_feature"]], dim=0 + ) + self.tensor["edge_feature"] = torch.cat( + [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 + ) # TODO: fix indices - self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"] + self.tensor["batch_ptr"][-1]], dim=0) - self.tensor["batch_ptr"] = torch.cat([self.tensor["batch_ptr"], other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1]], dim=0) - assert torch.all(self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:]) - self.tensor["batch_shape"] = (self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0],) + self.batch_shape[1:] + self.tensor["edge_index"] = torch.cat( + [ + self.tensor["edge_index"], + other.tensor["edge_index"] + self.tensor["batch_ptr"][-1], + ], + dim=0, + ) + self.tensor["batch_ptr"] = torch.cat( + [ + self.tensor["batch_ptr"], + other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1], + ], + dim=0, + ) + assert torch.all( + self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:] + ) + self.tensor["batch_shape"] = ( + self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], + ) + self.batch_shape[1:] @property def log_rewards(self) -> torch.Tensor: @@ -759,12 +827,22 @@ def _compare(self, other: TensorDict) -> torch.Tensor: if end - start != len(other["node_feature"]): out[i] = False else: - out[i] = torch.all(self.tensor["node_feature"][start:end] == other["node_feature"]) - edge_mask = torch.all((self.tensor["edge_index"] >= start) & (self.tensor["edge_index"] < end), dim=-1) + out[i] = torch.all( + self.tensor["node_feature"][start:end] == other["node_feature"] + ) + edge_mask = torch.all( + (self.tensor["edge_index"] >= start) + & (self.tensor["edge_index"] < end), + dim=-1, + ) edge_index = self.tensor["edge_index"][edge_mask] - start - out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all(edge_index == other["edge_index"]) + out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all( + edge_index == other["edge_index"] + ) edge_feature = self.tensor["edge_feature"][edge_mask] - out[i] &= len(edge_feature) == len(other["edge_feature"]) and 
torch.all(edge_feature == other["edge_feature"]) + out[i] &= len(edge_feature) == len(other["edge_feature"]) and torch.all( + edge_feature == other["edge_feature"] + ) return out.view(self.batch_shape) @property diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index e6f6beaa..250a275b 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -91,7 +91,6 @@ def log_prob(self, value): class CategoricalActionType(Categorical): # TODO: remove, just to sample 1 action_type - def __init__(self, probs: torch.Tensor): self.batch_len = len(probs) super().__init__(probs[0]) @@ -99,6 +98,6 @@ def __init__(self, probs: torch.Tensor): def sample(self, sample_shape=torch.Size()) -> torch.Tensor: samples = super().sample(sample_shape) return samples.repeat(self.batch_len) - + def log_prob(self, value): - return super().log_prob(value[0]).repeat(self.batch_len) \ No newline at end of file + return super().log_prob(value[0]).repeat(self.batch_len) diff --git a/testing/test_environments.py b/testing/test_environments.py index a2919d50..45768c2e 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -332,54 +332,88 @@ def test_graph_env(): action_cls = env.make_actions_class() with pytest.raises(NonValidActionsError): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) for _ in range(NUM_NODES): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) states = env.States(states) assert states.tensor["node_feature"].shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): - first_node_mask = torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": states.tensor["node_feature"][first_node_mask], - }, batch_size=BATCH_SIZE)) + first_node_mask = ( + torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 + ) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor["node_feature"][first_node_mask], + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) with pytest.raises(NonValidActionsError): edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.stack([edge_index, edge_index], dim=1), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), 
GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([edge_index, edge_index], dim=1), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) for i in range(NUM_NODES - 1): node_is = states.tensor["batch_ptr"][:-1] + i node_js = states.tensor["batch_ptr"][:-1] + i + 1 - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.stack([node_is, node_js], dim=1), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([node_is, node_js], dim=1), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) states = env.States(states) - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT), + }, + batch_size=BATCH_SIZE, + ) + ) sf_states = env.step(states, actions) sf_states = env.States(sf_states) assert torch.all(sf_states.is_sink_state) @@ -396,36 +430,58 @@ def test_graph_env(): num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE for i in reversed(range(num_edges_per_batch)): edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx], - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": states.tensor["edge_feature"][edge_idx], + "edge_index": states.tensor["edge_index"][edge_idx], + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) states = env.States(states) with pytest.raises(NonValidActionsError): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) for i in reversed(range(1, NUM_NODES + 1)): edge_idx = torch.arange(BATCH_SIZE) * i - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": states.tensor["node_feature"][edge_idx], - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor["node_feature"][edge_idx], + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) states = env.States(states) assert states.tensor["node_feature"].shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - }, 
batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 470c8b09..d23dbd61 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -225,6 +225,7 @@ def test_replay_buffer( # ------ GRAPH TESTS ------ + class GraphActionNet(nn.Module): def __init__(self, feature_dim: int): super().__init__() @@ -243,7 +244,9 @@ def forward(self, states: GraphStates) -> TensorDict: features = torch.zeros((len(states), self.feature_dim)) else: action_type = self.action_type_conv(node_feature, edge_index) - action_type = action_type.reshape(len(states), -1, action_type.shape[-1]).mean(dim=1) + action_type = action_type.reshape( + len(states), -1, action_type.shape[-1] + ).mean(dim=1) features = self.features_conv(node_feature, edge_index) features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) @@ -277,3 +280,5 @@ def test_graph_building(): save_logprobs=True, save_estimator_outputs=False, ) + + assert len(trajectories) == 7 diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py index c6b591c2..b6d401f7 100644 --- a/tutorials/examples/test_graph_ring.py +++ b/tutorials/examples/test_graph_ring.py @@ -3,17 +3,19 @@ import math import time from typing import Optional + +import matplotlib.pyplot as plt import torch +from tensordict import TensorDict from torch import nn -from gfn.actions import Actions, GraphActionType, GraphActions +from torch_geometric.nn import GCNConv + +from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.flow_matching import FMGFlowNet from gfn.gym import GraphBuilding from gfn.modules import DiscretePolicyEstimator from gfn.preprocessors import Preprocessor from gfn.states import GraphStates -from tensordict import TensorDict -from torch_geometric.nn import GCNConv -import matplotlib.pyplot as plt def state_evaluator(states: GraphStates) -> torch.Tensor: @@ -26,11 +28,15 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: out = torch.zeros(len(states)) for i in range(len(states)): start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] - edge_index_mask = torch.all(states.tensor["edge_index"] >= start, dim=-1) & torch.all(states.tensor["edge_index"] < end, dim=-1) + edge_index_mask = torch.all( + states.tensor["edge_index"] >= start, dim=-1 + ) & torch.all(states.tensor["edge_index"] < end, dim=-1) edge_index = states.tensor["edge_index"][edge_index_mask] arange = torch.arange(start, end) # TODO: not correct, accepts multiple rings - if torch.all(torch.sort(edge_index[:, 0])[0] == arange) and torch.all(torch.sort(edge_index[:, 1])[0] == arange): + if torch.all(torch.sort(edge_index[:, 0])[0] == arange) and torch.all( + torch.sort(edge_index[:, 1])[0] == arange + ): out[i] = 1 else: out[i] = eps @@ -46,7 +52,11 @@ def __init__(self, n_nodes: int, edge_hidden_dim: int = 128): self.edge_hidden_dim = edge_hidden_dim def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: - cumsum = torch.zeros((len(tensor) + 1, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + cumsum = torch.zeros( + (len(tensor) + 1, *tensor.shape[1:]), + dtype=tensor.dtype, + device=tensor.device, + ) 
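        # The two lines around this comment implement a classic segment sum:
        # prepend a zero row to the running cumulative sum, then difference it at
        # the batch_ptr boundaries. A self-contained toy check of the trick
        # (shapes assumed: tensor is (total_nodes, F), batch_ptr is (B + 1,)):
        #
        #   t = torch.arange(12, dtype=torch.float32).reshape(6, 2)
        #   ptr = torch.tensor([0, 2, 6])        # graph 0: nodes 0-1, graph 1: nodes 2-5
        #   cs = torch.zeros((len(t) + 1, *t.shape[1:]))
        #   cs[1:] = torch.cumsum(t, dim=0)
        #   sums = cs[ptr[1:]] - cs[ptr[:-1]]    # per-graph feature sums, shape (2, 2)
        #   assert torch.allclose(sums, torch.stack([t[0:2].sum(0), t[2:6].sum(0)]))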
cumsum[1:] = torch.cumsum(tensor, dim=0) return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] @@ -59,82 +69,102 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: action_type = self._group_sum(action_type, batch_ptr) edge_index = self.edge_index_conv(node_feature, edge_index) - edge_index = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim) + edge_index = edge_index.reshape( + *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim + ) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - edge_actions = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes * self.n_nodes) + edge_actions = edge_index.reshape( + *states_tensor["batch_shape"], self.n_nodes * self.n_nodes + ) return torch.cat([action_type, edge_actions], dim=-1) + class RingGraphBuilding(GraphBuilding): def __init__(self, n_nodes: int = 10): self.n_nodes = n_nodes self.n_actions = 1 + n_nodes * n_nodes super().__init__(feature_dim=1, state_evaluator=state_evaluator) - def make_actions_class(self) -> type[Actions]: env = self + class RingActions(Actions): action_shape = (1,) dummy_action = torch.tensor([env.n_actions]) - exit_action = torch.zeros(1,) + exit_action = torch.zeros( + 1, + ) return RingActions - def make_states_class(self) -> type[GraphStates]: env = self class RingStates(GraphStates): - s0 = TensorDict({ - "node_feature": torch.arange(env.n_nodes).unsqueeze(-1), - "edge_feature": torch.ones((0, 1)), - "edge_index": torch.ones((0, 2), dtype=torch.long), - }, batch_size=()) - sf = TensorDict({ - "node_feature": torch.zeros((env.n_nodes, 1)), - "edge_feature": torch.zeros((0, 1)), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, batch_size=()) + s0 = TensorDict( + { + "node_feature": torch.arange(env.n_nodes).unsqueeze(-1), + "edge_feature": torch.ones((0, 1)), + "edge_index": torch.ones((0, 2), dtype=torch.long), + }, + batch_size=(), + ) + sf = TensorDict( + { + "node_feature": torch.zeros((env.n_nodes, 1)), + "edge_feature": torch.zeros((0, 1)), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, + batch_size=(), + ) def __init__(self, tensor: TensorDict): self.tensor = tensor self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] self._log_rewards: Optional[float] = None - + self.n_nodes = env.n_nodes self.n_actions = env.n_actions @property def forward_masks(self): forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) - forward_masks[:, 1::self.n_nodes + 1] = False + forward_masks[:, 1 :: self.n_nodes + 1] = False for i in range(len(self)): existing_edges = self[i].tensor["edge_index"] - forward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = False - + forward_masks[ + i, + 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], + ] = False + return forward_masks.view(*self.batch_shape, self.n_actions) - + @forward_masks.setter def forward_masks(self, value: torch.Tensor): - pass # fwd masks is computed on the fly + pass # fwd masks is computed on the fly @property def backward_masks(self): - backward_masks = torch.zeros(len(self), self.n_actions, dtype=torch.bool) + backward_masks = torch.zeros( + len(self), self.n_actions, dtype=torch.bool + ) for i in range(len(self)): existing_edges = self[i].tensor["edge_index"] - backward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = True - + backward_masks[ + i, + 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], + ] 
= True + return backward_masks.view(*self.batch_shape, self.n_actions) - + @backward_masks.setter def backward_masks(self, value: torch.Tensor): - pass # bwd masks is computed on the fly - + pass # bwd masks is computed on the fly + return RingStates - + def _step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(actions) return super()._step(states, actions) @@ -145,20 +175,26 @@ def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: def convert_actions(self, actions: Actions) -> GraphActions: action_tensor = actions.tensor.squeeze(-1) - action_type = torch.where(action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE) + action_type = torch.where( + action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE + ) edge_index_i0 = (action_tensor - 1) // (self.n_nodes) edge_index_i1 = (action_tensor - 1) % (self.n_nodes) edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) - return GraphActions(TensorDict({ - "action_type": action_type, - "features": torch.ones(action_tensor.shape + (1,)), - "edge_index": edge_index, - }, batch_size=action_tensor.shape)) + return GraphActions( + TensorDict( + { + "action_type": action_type, + "features": torch.ones(action_tensor.shape + (1,)), + "edge_index": edge_index, + }, + batch_size=action_tensor.shape, + ) + ) class GraphPreprocessor(Preprocessor): - def __init__(self, feature_dim: int = 1): super().__init__(output_dim=feature_dim) @@ -184,28 +220,36 @@ def render_states(states: GraphStates): y = radius * math.sin(angle) xs.append(x) ys.append(y) - current_ax.add_patch(plt.Circle((x, y), 0.5, facecolor='none', edgecolor='black')) - + current_ax.add_patch( + plt.Circle((x, y), 0.5, facecolor="none", edgecolor="black") + ) + for edge in state.tensor["edge_index"]: start_x, start_y = xs[edge[0]], ys[edge[0]] end_x, end_y = xs[edge[1]], ys[edge[1]] dx = end_x - start_x dy = end_y - start_y length = math.sqrt(dx**2 + dy**2) - dx, dy = dx/length, dy/length - + dx, dy = dx / length, dy / length + circle_radius = 0.5 head_thickness = 0.2 start_x += dx * (circle_radius) start_y += dy * (circle_radius) end_x -= dx * (circle_radius + head_thickness) end_y -= dy * (circle_radius + head_thickness) - - current_ax.arrow(start_x, start_y, - end_x - start_x, end_y - start_y, - head_width=head_thickness, head_length=head_thickness, - fc='black', ec='black') - + + current_ax.arrow( + start_x, + start_y, + end_x - start_x, + end_y - start_y, + head_width=head_thickness, + head_length=head_thickness, + fc="black", + ec="black", + ) + current_ax.set_title(f"State {i}, $r={rewards[i]:.2f}$") current_ax.set_xlim(-(radius + 1), radius + 1) current_ax.set_ylim(-(radius + 1), radius + 1) @@ -213,7 +257,6 @@ def render_states(states: GraphStates): current_ax.set_xticks([]) current_ax.set_yticks([]) - plt.show() @@ -224,11 +267,15 @@ def render_states(states: GraphStates): env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) - pf_estimator = DiscretePolicyEstimator(module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor()) + pf_estimator = DiscretePolicyEstimator( + module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor() + ) gflownet = FMGFlowNet(pf_estimator) optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-2) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_ITERATIONS, eta_min=1e-4) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=N_ITERATIONS, eta_min=1e-4 + 
) losses = [] @@ -246,6 +293,3 @@ def render_states(states: GraphStates): t2 = time.time() print("Time:", t2 - t1) render_states(trajectories.last_states[:8]) - - - From 8f1c62c65b9a68de92e22aa4ca918dc8f2975493 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 12 Jan 2025 22:50:17 +0100 Subject: [PATCH 031/102] address comments --- pyproject.toml | 3 +- src/gfn/actions.py | 17 ++++++++++-- src/gfn/containers/trajectories.py | 2 +- src/gfn/gym/graph_building.py | 34 ++++++++++------------- src/gfn/modules.py | 11 +++++--- src/gfn/states.py | 32 +++++++++++++++------ testing/test_environments.py | 4 ++- testing/test_samplers_and_trajectories.py | 4 ++- 8 files changed, 70 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd1c9cf5..0b1958c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ einops = ">=0.6.1" numpy = ">=1.21.2" python = "^3.10" torch = ">=1.9.0" -torch_geometric = ">=2.6.0" +tensordict = ">=0.6.1" # dev dependencies. black = { version = "24.3", optional = true } @@ -60,6 +60,7 @@ dev = [ "sphinx", "tox", "flake8", + "torch_geometric>=2.6.0", ] scripts = ["tqdm", "wandb", "scikit-learn", "scipy", "matplotlib"] diff --git a/src/gfn/actions.py b/src/gfn/actions.py index c80eb2d3..e76388d9 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -179,6 +179,21 @@ class GraphActionType(enum.IntEnum): class GraphActions(Actions): + """Actions for graph-based environments. + + Each action is one of: + - ADD_NODE: Add a node with given features + - ADD_EDGE: Add an edge between two nodes with given features + - EXIT: Terminate the trajectory + + Attributes: + features_dim: Dimension of node/edge features + tensor: TensorDict containing: + - action_type: Type of action (GraphActionType) + - features: Features for nodes/edges + - edge_index: Source/target nodes for edges + """ + features_dim: ClassVar[int] def __init__(self, tensor: TensorDict): @@ -279,8 +294,6 @@ def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: "action_type": torch.full( batch_shape, fill_value=GraphActionType.EXIT ), - # "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), - # "edge_index": torch.zeros((2, *batch_shape, 0)), }, batch_size=batch_shape, ) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index c56cd1ee..5feb665a 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -76,7 +76,7 @@ def __init__( self.states = ( states if states is not None else env.states_from_batch_shape((0, 0)) ) - assert len(self.states.batch_shape) == 2, self.states.batch_shape + assert len(self.states.batch_shape) == 2 self.actions = ( actions if actions is not None else env.actions_from_batch_shape((0, 0)) ) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 415c4ec1..c3b58a5b 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -3,7 +3,6 @@ import torch from tensordict import TensorDict -from torch_geometric.nn import GCNConv from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError @@ -11,10 +10,24 @@ class GraphBuilding(GraphEnv): + """Environment for incrementally building graphs. 
+
+    This environment allows constructing graphs by:
+    - Adding nodes with features
+    - Adding edges between existing nodes with features
+    - Terminating construction (EXIT)
+
+    Args:
+        feature_dim: Dimension of node and edge features
+        state_evaluator: Callable that computes rewards for final states
+            (required; the previous GCNConvEvaluator default was removed along
+            with the hard torch_geometric dependency)
+        device_str: Device to run computations on ('cpu' or 'cuda')
+    """
+
     def __init__(
         self,
         feature_dim: int,
-        state_evaluator: Callable[[GraphStates], torch.Tensor] | None = None,
+        state_evaluator: Callable[[GraphStates], torch.Tensor],
         device_str: Literal["cpu", "cuda"] = "cpu",
     ):
         s0 = TensorDict(
@@ -36,8 +49,6 @@ def __init__(
             device=device_str,
         )
 
-        if state_evaluator is None:
-            state_evaluator = GCNConvEvaluator(feature_dim)
         self.state_evaluator = state_evaluator
 
         super().__init__(
@@ -245,18 +256,3 @@ def true_dist_pmf(self) -> torch.Tensor:
     def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates:
         """Generates random states tensor of shape (*batch_shape, feature_dim)."""
         return self.States.from_batch_shape(batch_shape)
-
-
-class GCNConvEvaluator:
-    def __init__(self, num_features):
-        self.net = GCNConv(num_features, 1)
-
-    def __call__(self, state: GraphStates) -> torch.Tensor:
-        node_feature = state.tensor["node_feature"]
-        edge_index = state.tensor["edge_index"].T
-        if len(node_feature) == 0:
-            return torch.zeros(len(state))
-
-        out = self.net(node_feature, edge_index)
-        out = out.reshape(*state.batch_shape, -1)
-        return out.mean(-1)
diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index e8058e65..c11d4758 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -1,12 +1,11 @@
 from abc import ABC
-from typing import Any, Dict
+from typing import Any
 
 import torch
 import torch.nn as nn
 from tensordict import TensorDict
 from torch.distributions import Categorical, Distribution, Normal
 
-from gfn.actions import GraphActionType
 from gfn.preprocessors import IdentityPreprocessor, Preprocessor
 from gfn.states import DiscreteStates, GraphStates, States
 from gfn.utils.distributions import (
@@ -484,14 +483,18 @@ def forward(self, states: GraphStates) -> TensorDict:
         Args:
             states: The input graph states.
 
-        Returns the .
+        Returns:
+            TensorDict containing:
+            - action_type: logits for action type selection (batch_shape, n_actions)
+            - features: parameters for node/edge features (batch_shape, feature_dim)
+            - edge_index: logits for edge connections (batch_shape, n_nodes, n_nodes)
         """
         return self.module(states)
 
     def to_probability_distribution(
         self,
         states: GraphStates,
-        module_output: Dict[str, torch.Tensor],
+        module_output: TensorDict,
         temperature: float = 1.0,
         epsilon: float = 0.0,
     ) -> ComposedDistribution:
diff --git a/src/gfn/states.py b/src/gfn/states.py
index d7cc96f6..564116d8 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -8,7 +8,6 @@
 import numpy as np
 import torch
 from tensordict import TensorDict
-from torch_geometric.data import Batch, Data
 
 from gfn.actions import GraphActionType
 
@@ -297,7 +296,10 @@ def sample(self, n_samples: int) -> States:
     @classmethod
     def stack(cls, states: List[States]):
         """Given a list of states, stacks them along a new dimension (0)."""
-        state_example = states[0]  # We assume all elems of `states` are the same.
+        state_example = states[0]
+        assert all(
+            state.batch_shape == state_example.batch_shape for state in states
+        ), "All states must have the same batch_shape"
         stacked_states = state_example.from_batch_shape((0, 0))  # Empty.
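        # stack() inserts a new leading (time) dimension: T equally-shaped States
        # with batch_shape (B,) become one States with batch_shape (T, B), which
        # is the layout Trajectories expects. Toy check of the tensor-level step
        # performed just below (torch only):
        #
        #   per_step = [torch.randn(4, 2) for _ in range(3)]   # three steps, B = 4
        #   stacked = torch.stack(per_step, dim=0)
        #   assert stacked.shape == (3, 4, 2)                  # (T, B, feature_dim)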
stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) @@ -519,6 +521,18 @@ class GraphStates(States): sf: ClassVar[TensorDict] def __init__(self, tensor: TensorDict): + REQUIRED_KEYS = { + "node_feature", + "edge_feature", + "edge_index", + "batch_ptr", + "batch_shape", + } + if not all(key in tensor for key in REQUIRED_KEYS): + raise ValueError( + f"TensorDict must contain all required keys: {REQUIRED_KEYS}" + ) + self.tensor = tensor self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] @@ -601,18 +615,20 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2) node_features_dim = cls.s0["node_feature"].shape[-1] edge_features_dim = cls.s0["edge_feature"].shape[-1] + device = cls.s0.device return TensorDict( { "node_feature": torch.rand( - np.prod(batch_shape) * num_nodes, node_features_dim + np.prod(batch_shape) * num_nodes, node_features_dim, device=device ), "edge_feature": torch.rand( - np.prod(batch_shape) * num_edges, edge_features_dim + np.prod(batch_shape) * num_edges, edge_features_dim, device=device ), "edge_index": torch.randint( - num_nodes, size=(np.prod(batch_shape) * num_edges, 2) + num_nodes, size=(np.prod(batch_shape) * num_edges, 2), device=device ), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * num_nodes, + "batch_ptr": torch.arange(np.prod(batch_shape) + 1, device=device) + * num_nodes, "batch_shape": batch_shape, } ) @@ -716,7 +732,7 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): if new_nodes.shape[1] != self.node_features_dim: raise ValueError( - f"Node features must have dimension {node_feature_dim}" + f"Node features must have dimension {self.node_features_dim}" ) # Number of new nodes to add @@ -768,7 +784,7 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): self.tensor["batch_ptr"][graph_idx + 1 :] += shift @property - def device(self) -> torch.device: + def device(self) -> torch.device | None: return self.tensor.device def to(self, device: torch.device) -> GraphStates: diff --git a/testing/test_environments.py b/testing/test_environments.py index 45768c2e..7106c2cf 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -326,7 +326,9 @@ def test_graph_env(): BATCH_SIZE = 3 NUM_NODES = 5 - env = GraphBuilding(feature_dim=FEATURE_DIM) + env = GraphBuilding( + feature_dim=FEATURE_DIM, state_evaluator=lambda s: torch.zeros(s.batch_shape) + ) states = env.reset(batch_shape=BATCH_SIZE) assert states.batch_shape == (BATCH_SIZE,) action_cls = env.make_actions_class() diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index d23dbd61..de8d9ab0 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -268,7 +268,9 @@ def forward(self, states: GraphStates) -> TensorDict: def test_graph_building(): torch.manual_seed(7) feature_dim = 8 - env = GraphBuilding(feature_dim=feature_dim) + env = GraphBuilding( + feature_dim=feature_dim, state_evaluator=lambda s: torch.zeros(s.batch_shape) + ) module = GraphActionNet(feature_dim) pf_estimator = GraphActionPolicyEstimator(module=module) From 6482834f3bdefd873caf7459280b5fb96b12d5b1 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 12 Jan 2025 23:20:50 +0100 Subject: [PATCH 032/102] fix test --- src/gfn/gym/graph_building.py | 2 +- src/gfn/states.py | 5 +++-- 2 files changed, 4 
insertions(+), 3 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index c3b58a5b..fd894e3a 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -42,7 +42,7 @@ def __init__( { "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), - "edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) + "edge_feature": torch.ones((0, feature_dim), dtype=torch.float32) * float("inf"), "edge_index": torch.zeros((0, 2), dtype=torch.long), }, diff --git a/src/gfn/states.py b/src/gfn/states.py index 564116d8..8d7db2f1 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -596,16 +596,17 @@ def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: raise NotImplementedError("Sink state is not defined") batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - return TensorDict( + out = TensorDict( { "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) + "batch_ptr": torch.arange(np.prod(batch_shape) + 1, device=cls.sf.device) * cls.sf["node_feature"].shape[0], "batch_shape": batch_shape, } ) + return out @classmethod def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: From 6db601d0a4c55ff50496788430d208423ae131c8 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 14 Jan 2025 00:05:49 +0100 Subject: [PATCH 033/102] fix test --- testing/test_environments.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/testing/test_environments.py b/testing/test_environments.py index 7106c2cf..a2c913cb 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -393,8 +393,8 @@ def test_graph_env(): states = env.step(states, actions) for i in range(NUM_NODES - 1): - node_is = states.tensor["batch_ptr"][:-1] + i - node_js = states.tensor["batch_ptr"][:-1] + i + 1 + node_is = torch.full((BATCH_SIZE,), i) + node_js = torch.full((BATCH_SIZE,), i + 1) actions = action_cls( TensorDict( { @@ -437,7 +437,7 @@ def test_graph_env(): { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx], + "edge_index": states.tensor["edge_index"][edge_idx] - states.tensor["batch_ptr"][:-1][:, None], }, batch_size=BATCH_SIZE, ) From c7f8243994ee783b1c5d3eb834c382e0b730843b Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 14 Jan 2025 00:07:22 +0100 Subject: [PATCH 034/102] fix pre-commit --- src/gfn/states.py | 4 +++- testing/test_environments.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 8d7db2f1..d8bb86b7 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -601,7 +601,9 @@ def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1, device=cls.sf.device) + "batch_ptr": torch.arange( + np.prod(batch_shape) + 1, device=cls.sf.device + ) * cls.sf["node_feature"].shape[0], "batch_shape": batch_shape, } diff --git a/testing/test_environments.py 
b/testing/test_environments.py index a2c913cb..835b6687 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -437,7 +437,8 @@ def test_graph_env(): { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx] - states.tensor["batch_ptr"][:-1][:, None], + "edge_index": states.tensor["edge_index"][edge_idx] + - states.tensor["batch_ptr"][:-1][:, None], }, batch_size=BATCH_SIZE, ) From 78b729a84bf103e10a31f0f1df04b24255340e19 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 14 Jan 2025 00:25:24 +0100 Subject: [PATCH 035/102] fix merging issues --- src/gfn/containers/trajectories.py | 2 +- src/gfn/samplers.py | 6 +++--- testing/test_samplers_and_trajectories.py | 6 ++++++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index d010f503..9d7ccd48 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -104,7 +104,7 @@ def __init__( assert ( log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float - ) + ), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}" else: log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float) self.log_probs: torch.Tensor = log_probs diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 0d222d8d..820bd2f0 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -207,7 +207,6 @@ def sample_trajectories( all_estimator_outputs.append(estimator_outputs_padded) actions[~dones] = valid_actions - trajectories_actions.append(actions) if save_logprobs: # When off_policy, actions_log_probs are None. log_probs[~dones] = actions_log_probs @@ -247,7 +246,9 @@ def sample_trajectories( trajectories_states.append(deepcopy(states)) trajectories_states = env.States.stack(trajectories_states) - trajectories_actions = env.Actions.stack(trajectories_actions)[1:] # Drop dummy action + trajectories_actions = env.Actions.stack(trajectories_actions)[ + 1: + ] # Drop dummy action trajectories_logprobs = ( torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob if save_logprobs @@ -257,7 +258,6 @@ def sample_trajectories( # TODO: use torch.nested.nested_tensor(dtype, device, requires_grad). 
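        # The stack-then-slice pattern used above: each per-step list is seeded
        # with a dummy entry before the loop, stacked to (T + 1, n_trajectories,
        # ...), and the leading dummy row is dropped with [1:]. Toy check:
        #
        #   steps = [torch.zeros(4)]                        # dummy seed, batch of 4
        #   steps += [torch.full((4,), float(t)) for t in range(3)]
        #   out = torch.stack(steps, dim=0)[1:]             # drop the dummy row
        #   assert out.shape == (3, 4)                      # (max_length, n_trajectories)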
if save_estimator_outputs: all_estimator_outputs = torch.stack(all_estimator_outputs, dim=0) - trajectories = Trajectories( env=env, states=trajectories_states, diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 4ffb3c4f..3066fe4f 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -1,10 +1,16 @@ from typing import Literal, Tuple import pytest +import torch +from tensordict import TensorDict +from torch import nn +from torch_geometric.nn import GCNConv +from gfn.actions import GraphActionType from gfn.containers import Trajectories from gfn.containers.replay_buffer import ReplayBuffer from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.graph_building import GraphBuilding from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator from gfn.samplers import LocalSearchSampler, Sampler From 38dd2b0aec74f47a61e9949c5537d3ce08e78afe Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 14 Jan 2025 00:27:38 +0100 Subject: [PATCH 036/102] fix toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0b1958c7..f2ae2eb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ dev = [ "sphinx", "tox", "flake8", - "torch_geometric>=2.6.0", + "torch_geometric", ] scripts = ["tqdm", "wandb", "scikit-learn", "scipy", "matplotlib"] From 12c49b7687369bcbd67377ff99b77faa9e15a492 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 14 Jan 2025 00:31:48 +0100 Subject: [PATCH 037/102] add dep & address issue --- pyproject.toml | 1 + src/gfn/states.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f2ae2eb7..477bfd72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ all = [ "tox", "tqdm", "wandb", + "torch_geometric", ] [tool.poetry.urls] diff --git a/src/gfn/states.py b/src/gfn/states.py index 133d1a87..2157651a 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -809,7 +809,6 @@ def extend(self, other: GraphStates): self.tensor["edge_feature"] = torch.cat( [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 ) - # TODO: fix indices self.tensor["edge_index"] = torch.cat( [ self.tensor["edge_index"], From fe237ed1acaf83cf9179b25b3aa924e23e083094 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 14 Jan 2025 00:36:50 +0100 Subject: [PATCH 038/102] fix toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 477bfd72..585e31da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ wandb = { version = "*", optional = true } scikit-learn = {version = "*", optional = true } scipy = { version = "*", optional = true } matplotlib = { version = "*", optional = true } +torch_geometric = { version = "*", optional = true } [tool.poetry.extras] dev = [ From 9bbc48db178dbbace8365f0f30acbaba5f1c35a6 Mon Sep 17 00:00:00 2001 From: Omar Younis <42100908+younik@users.noreply.github.com> Date: Tue, 14 Jan 2025 00:43:24 +0100 Subject: [PATCH 039/102] fix pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 477bfd72..585e31da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ wandb = { version = "*", optional = true } scikit-learn = {version = "*", optional = true } scipy = { 
version = "*", optional = true } matplotlib = { version = "*", optional = true } +torch_geometric = { version = "*", optional = true } [tool.poetry.extras] dev = [ From 5e4fc4e9236b184312ba678e3f7562f1ce4635d0 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 15 Jan 2025 00:28:16 +0100 Subject: [PATCH 040/102] address comments --- src/gfn/gflownet/flow_matching.py | 8 ++++---- src/gfn/modules.py | 4 +++- src/gfn/utils/distributions.py | 12 ++++++------ tutorials/examples/test_graph_ring.py | 1 + 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 8347b835..606c0a8a 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -50,10 +50,10 @@ def sample_trajectories( **policy_kwargs: Any, ) -> Trajectories: """Sample trajectory with optional kwargs controling the policy.""" - # if not env.is_discrete: - # raise NotImplementedError( - # "Flow Matching GFlowNet only supports discrete environments for now." - # ) + if not env.is_discrete: + raise NotImplementedError( + "Flow Matching GFlowNet only supports discrete environments for now." + ) sampler = Sampler(estimator=self.logF) trajectories = sampler.sample_trajectories( env, diff --git a/src/gfn/modules.py b/src/gfn/modules.py index c11d4758..4f6b9f7b 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -527,6 +527,8 @@ def to_probability_distribution( dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] + B, N, N = edge_index_logits.shape + edge_index_logits = edge_index_logits.reshape(B, N * N) if states.tensor["node_feature"].shape[0] > 1 and torch.any( edge_index_logits != -float("inf") ): @@ -537,7 +539,7 @@ def to_probability_distribution( edge_index_probs = ( 1 - epsilon ) * edge_index_probs + epsilon * uniform_dist_probs - dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs) + dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs, n=N) dists["features"] = Normal(module_output["features"], temperature) return ComposedDistribution(dists=dists) diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 250a275b..86b47c55 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -43,7 +43,7 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor: return super().log_prob(sample.squeeze(-1)) -class ComposedDistribution(Distribution): # TODO: CompositeDistribution in TensorDict +class ComposedDistribution(Distribution): # TODO: remove in favor of CompositeDistribution in TensorDict """A mixture distribution.""" def __init__(self, dists: Dict[str, Distribution]): @@ -69,16 +69,16 @@ def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: class CategoricalIndexes(Categorical): """Samples indexes from a categorical distribution.""" - def __init__(self, probs: torch.Tensor): + def __init__(self, probs: torch.Tensor, n: int): """Initializes the distribution. Args: probs: The probabilities of the categorical distribution. + n: The number of nodes in the graph. 
""" - self.n = probs.shape[-1] - batch_size = probs.shape[0] - assert probs.shape == (batch_size, self.n, self.n) - super().__init__(probs.reshape(batch_size, self.n * self.n)) + self.n = n + assert probs.shape == (probs.shape[0], self.n * self.n) + super().__init__(probs) def sample(self, sample_shape=torch.Size()) -> torch.Tensor: samples = super().sample(sample_shape) diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py index b6d401f7..a33ae61d 100644 --- a/tutorials/examples/test_graph_ring.py +++ b/tutorials/examples/test_graph_ring.py @@ -85,6 +85,7 @@ def __init__(self, n_nodes: int = 10): self.n_nodes = n_nodes self.n_actions = 1 + n_nodes * n_nodes super().__init__(feature_dim=1, state_evaluator=state_evaluator) + self.is_discrete = True # actions here are discrete, needed for FlowMatching def make_actions_class(self) -> type[Actions]: env = self From 705b4ccbf6d3f089e58cff7b1ac3d8b2dcb8faf3 Mon Sep 17 00:00:00 2001 From: "Omar G. Younis" Date: Wed, 15 Jan 2025 17:37:50 +0100 Subject: [PATCH 041/102] add tests for action --- src/gfn/actions.py | 29 ++++++++--- testing/test_actions.py | 109 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 8 deletions(-) create mode 100644 testing/test_actions.py diff --git a/src/gfn/actions.py b/src/gfn/actions.py index e76388d9..88b4e635 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -176,6 +176,7 @@ class GraphActionType(enum.IntEnum): ADD_NODE = 0 ADD_EDGE = 1 EXIT = 2 + DUMMY = 3 class GraphActions(Actions): @@ -209,7 +210,7 @@ def __init__(self, tensor: TensorDict): self.batch_shape = tensor["action_type"].shape features = tensor.get("features", None) if features is None: - assert torch.all(tensor["action_type"] == GraphActionType.EXIT) + assert torch.all(torch.logical_or(tensor["action_type"] == GraphActionType.EXIT, tensor["action_type"] == GraphActionType.DUMMY)) features = torch.zeros((*self.batch_shape, self.features_dim)) edge_index = tensor.get("edge_index", None) if edge_index is None: @@ -269,6 +270,11 @@ def compare(self, other: GraphActions) -> torch.Tensor: def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" return self.action_type == GraphActionType.EXIT + + @property + def is_dummy(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions.""" + return self.action_type == GraphActionType.DUMMY @property def action_type(self) -> torch.Tensor: @@ -287,12 +293,12 @@ def edge_index(self) -> torch.Tensor: @classmethod def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: - """Creates an Actions object of dummy actions with the given batch shape.""" + """Creates a GraphActions object of dummy actions with the given batch shape.""" return cls( TensorDict( { "action_type": torch.full( - batch_shape, fill_value=GraphActionType.EXIT + batch_shape, fill_value=GraphActionType.DUMMY ), }, batch_size=batch_shape, @@ -300,9 +306,16 @@ def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: ) @classmethod - def stack(cls, actions_list: list[GraphActions]) -> GraphActions: - """Stacks a list of GraphActions objects into a single GraphActions object.""" - actions_tensor = torch.stack( - [actions.tensor for actions in actions_list], dim=0 + def make_exit_actions(cls, batch_shape: tuple[int]) -> Actions: + """Creates an GraphActions object of exit actions with the given batch shape.""" + 
return cls( + TensorDict( + { + "action_type": torch.full( + batch_shape, fill_value=GraphActionType.EXIT + ), + }, + batch_size=batch_shape, + ) ) - return cls(actions_tensor) + diff --git a/testing/test_actions.py b/testing/test_actions.py new file mode 100644 index 00000000..4dcf05f4 --- /dev/null +++ b/testing/test_actions.py @@ -0,0 +1,109 @@ +from copy import deepcopy +from gfn.actions import Actions, GraphActions +import pytest +import torch +from tensordict import TensorDict + + +class ContinuousActions(Actions): + action_shape = (10,) + dummy_action = torch.zeros(10) + exit_action = torch.ones(10) + +class GraphActions(GraphActions): + features_dim = 10 + + +@pytest.fixture +def continuous_action(): + return ContinuousActions( + tensor=torch.arange(0, 10) + ) + +@pytest.fixture +def graph_action(): + return GraphActions( + tensor=TensorDict( + { + "action_type": torch.zeros((1,), dtype=torch.float32), + "features": torch.zeros((1, 10), dtype=torch.float32), + }, + device="cpu", + ) + ) + + +def test_continuous_action(continuous_action): + BATCH = 5 + + exit_actions = continuous_action.make_exit_actions((BATCH,)) + assert torch.all(exit_actions.tensor == continuous_action.exit_action.repeat(BATCH, 1)) + assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool)) + + dummy_actions = continuous_action.make_dummy_actions((BATCH,)) + assert torch.all(dummy_actions.tensor == continuous_action.dummy_action.repeat(BATCH, 1)) + assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool)) + + # Test stack + stacked_actions = continuous_action.stack([exit_actions, dummy_actions]) + assert stacked_actions.batch_shape == (2, BATCH) + assert torch.all(stacked_actions.tensor == torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0)) + is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + assert stacked_actions[0, 1].is_exit + stacked_actions[0, 1] = stacked_actions[1, 1] + is_exit_stacked[0, 1] = False + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + + # Test extend + extended_actions = deepcopy(exit_actions) + extended_actions.extend(dummy_actions) + assert extended_actions.batch_shape == (BATCH * 2,) + assert torch.all(extended_actions.tensor == torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0)) + is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(extended_actions.is_exit == is_exit_extended) + assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy + extended_actions[0] = extended_actions[BATCH] + is_exit_extended[0] = False + assert torch.all(extended_actions.is_exit == is_exit_extended) + +def test_graph_action(graph_action): + BATCH = 5 + + exit_actions = graph_action.make_exit_actions((BATCH,)) + assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool)) + dummy_actions = graph_action.make_dummy_actions((BATCH,)) + assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool)) + assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool)) + + # Test stack + stacked_actions = graph_action.stack([exit_actions, dummy_actions]) + assert stacked_actions.batch_shape 
== (2, BATCH) + manually_stacked_tensor = torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0) + assert torch.all(stacked_actions.tensor["action_type"] == manually_stacked_tensor["action_type"]) + assert torch.all(stacked_actions.tensor["features"] == manually_stacked_tensor["features"]) + assert torch.all(stacked_actions.tensor["edge_index"] == manually_stacked_tensor["edge_index"]) + is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + assert stacked_actions[0, 1].is_exit + stacked_actions[0, 1] = stacked_actions[1, 1] + is_exit_stacked[0, 1] = False + assert torch.all(stacked_actions.is_exit == is_exit_stacked) + + # Test extend + extended_actions = deepcopy(exit_actions) + extended_actions.extend(dummy_actions) + assert extended_actions.batch_shape == (BATCH * 2,) + manually_extended_tensor = torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0) + assert torch.all(extended_actions.tensor["action_type"] == manually_extended_tensor["action_type"]) + assert torch.all(extended_actions.tensor["features"] == manually_extended_tensor["features"]) + assert torch.all(extended_actions.tensor["edge_index"] == manually_extended_tensor["edge_index"]) + is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) + assert torch.all(extended_actions.is_exit == is_exit_extended) + assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy + extended_actions[0] = extended_actions[BATCH] + is_exit_extended[0] = False + assert torch.all(extended_actions.is_exit == is_exit_extended) \ No newline at end of file From d76533037266e72fb631f4543ffd260a31a3fa62 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 16 Jan 2025 00:08:50 +0100 Subject: [PATCH 042/102] fix test after added dummy action --- testing/test_samplers_and_trajectories.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 3066fe4f..9b2d84c3 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -313,7 +313,7 @@ class GraphActionNet(nn.Module): def __init__(self, feature_dim: int): super().__init__() self.feature_dim = feature_dim - self.action_type_conv = GCNConv(feature_dim, len(GraphActionType)) + self.action_type_conv = GCNConv(feature_dim, 3) self.features_conv = GCNConv(feature_dim, feature_dim) self.edge_index_conv = GCNConv(feature_dim, 8) @@ -322,7 +322,7 @@ def forward(self, states: GraphStates) -> TensorDict: edge_index = states.tensor["edge_index"].T if states.tensor["node_feature"].shape[0] == 0: - action_type = torch.zeros((len(states), len(GraphActionType))) + action_type = torch.zeros((len(states), 3)) action_type[:, GraphActionType.ADD_NODE] = 1 features = torch.zeros((len(states), self.feature_dim)) else: From 4ee6987f41678dd4ffe188693ec16e554f894f9e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 16 Jan 2025 00:16:45 +0100 Subject: [PATCH 043/102] add GraphPreprocessor --- src/gfn/actions.py | 10 +++-- src/gfn/modules.py | 11 +++--- src/gfn/preprocessors.py | 11 +++++- src/gfn/utils/distributions.py | 4 +- testing/test_actions.py | 72 ++++++++++++++++++++++++---------- 5 files changed, 76 insertions(+), 32 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 88b4e635..7d223311 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -210,7 +210,12 @@ def __init__(self, tensor: 
TensorDict): self.batch_shape = tensor["action_type"].shape features = tensor.get("features", None) if features is None: - assert torch.all(torch.logical_or(tensor["action_type"] == GraphActionType.EXIT, tensor["action_type"] == GraphActionType.DUMMY)) + assert torch.all( + torch.logical_or( + tensor["action_type"] == GraphActionType.EXIT, + tensor["action_type"] == GraphActionType.DUMMY, + ) + ) features = torch.zeros((*self.batch_shape, self.features_dim)) edge_index = tensor.get("edge_index", None) if edge_index is None: @@ -270,7 +275,7 @@ def compare(self, other: GraphActions) -> torch.Tensor: def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" return self.action_type == GraphActionType.EXIT - + @property def is_dummy(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions.""" @@ -318,4 +323,3 @@ def make_exit_actions(cls, batch_shape: tuple[int]) -> Actions: batch_size=batch_shape, ) ) - diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 4f6b9f7b..0937e017 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -6,7 +6,7 @@ from tensordict import TensorDict from torch.distributions import Categorical, Distribution, Normal -from gfn.preprocessors import IdentityPreprocessor, Preprocessor +from gfn.preprocessors import GraphPreprocessor, IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States from gfn.utils.distributions import ( CategoricalActionType, @@ -464,7 +464,7 @@ class GraphActionPolicyEstimator(GFNModule): def __init__( self, module: nn.Module, - # preprocessor: Preprocessor | None = None, + preprocessor: Preprocessor | None = None, is_backward: bool = False, ): """Initializes a estimator for P_F for graph environments. @@ -472,10 +472,9 @@ def __init__( Args: is_backward: if False, then this is a forward policy, else backward policy. """ - # super().__init__(module, preprocessor, is_backward) - nn.Module.__init__(self) - self.module = module - self.is_backward = is_backward + if preprocessor is None: + preprocessor = GraphPreprocessor() + super().__init__(module, preprocessor, is_backward) def forward(self, states: GraphStates) -> TensorDict: """Forward pass of the module. diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index dfa3e2b1..599d794a 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -2,8 +2,9 @@ from typing import Callable import torch +from tensordict import TensorDict -from gfn.states import States +from gfn.states import GraphStates, States class Preprocessor(ABC): @@ -72,3 +73,11 @@ def preprocess(self, states) -> torch.Tensor: Returns the unique indices of the states as a tensor of shape `batch_shape`. 
""" return self.get_states_indices(states).long().unsqueeze(-1) + + +class GraphPreprocessor(Preprocessor): + def __init__(self) -> None: + super().__init__(-1) + + def preprocess(self, states: GraphStates) -> TensorDict: + return states.tensor diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 86b47c55..c2ff97ec 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -43,7 +43,9 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor: return super().log_prob(sample.squeeze(-1)) -class ComposedDistribution(Distribution): # TODO: remove in favor of CompositeDistribution in TensorDict +class ComposedDistribution( + Distribution +): # TODO: remove in favor of CompositeDistribution in TensorDict """A mixture distribution.""" def __init__(self, dists: Dict[str, Distribution]): diff --git a/testing/test_actions.py b/testing/test_actions.py index 4dcf05f4..2058d0d8 100644 --- a/testing/test_actions.py +++ b/testing/test_actions.py @@ -1,24 +1,26 @@ from copy import deepcopy -from gfn.actions import Actions, GraphActions + import pytest import torch from tensordict import TensorDict +from gfn.actions import Actions, GraphActions + class ContinuousActions(Actions): action_shape = (10,) dummy_action = torch.zeros(10) exit_action = torch.ones(10) + class GraphActions(GraphActions): features_dim = 10 @pytest.fixture def continuous_action(): - return ContinuousActions( - tensor=torch.arange(0, 10) - ) + return ContinuousActions(tensor=torch.arange(0, 10)) + @pytest.fixture def graph_action(): @@ -37,19 +39,26 @@ def test_continuous_action(continuous_action): BATCH = 5 exit_actions = continuous_action.make_exit_actions((BATCH,)) - assert torch.all(exit_actions.tensor == continuous_action.exit_action.repeat(BATCH, 1)) + assert torch.all( + exit_actions.tensor == continuous_action.exit_action.repeat(BATCH, 1) + ) assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool)) assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool)) dummy_actions = continuous_action.make_dummy_actions((BATCH,)) - assert torch.all(dummy_actions.tensor == continuous_action.dummy_action.repeat(BATCH, 1)) + assert torch.all( + dummy_actions.tensor == continuous_action.dummy_action.repeat(BATCH, 1) + ) assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool)) assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool)) # Test stack stacked_actions = continuous_action.stack([exit_actions, dummy_actions]) - assert stacked_actions.batch_shape == (2, BATCH) - assert torch.all(stacked_actions.tensor == torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0)) + assert stacked_actions.batch_shape == (2, BATCH) + assert torch.all( + stacked_actions.tensor + == torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0) + ) is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) assert torch.all(stacked_actions.is_exit == is_exit_stacked) assert stacked_actions[0, 1].is_exit @@ -59,9 +68,12 @@ def test_continuous_action(continuous_action): # Test extend extended_actions = deepcopy(exit_actions) - extended_actions.extend(dummy_actions) + extended_actions.extend(dummy_actions) assert extended_actions.batch_shape == (BATCH * 2,) - assert torch.all(extended_actions.tensor == torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0)) + assert torch.all( + extended_actions.tensor + == torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0) + ) 
is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) assert torch.all(extended_actions.is_exit == is_exit_extended) assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy @@ -69,6 +81,7 @@ def test_continuous_action(continuous_action): is_exit_extended[0] = False assert torch.all(extended_actions.is_exit == is_exit_extended) + def test_graph_action(graph_action): BATCH = 5 @@ -81,11 +94,19 @@ def test_graph_action(graph_action): # Test stack stacked_actions = graph_action.stack([exit_actions, dummy_actions]) - assert stacked_actions.batch_shape == (2, BATCH) - manually_stacked_tensor = torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0) - assert torch.all(stacked_actions.tensor["action_type"] == manually_stacked_tensor["action_type"]) - assert torch.all(stacked_actions.tensor["features"] == manually_stacked_tensor["features"]) - assert torch.all(stacked_actions.tensor["edge_index"] == manually_stacked_tensor["edge_index"]) + assert stacked_actions.batch_shape == (2, BATCH) + manually_stacked_tensor = torch.stack( + [exit_actions.tensor, dummy_actions.tensor], dim=0 + ) + assert torch.all( + stacked_actions.tensor["action_type"] == manually_stacked_tensor["action_type"] + ) + assert torch.all( + stacked_actions.tensor["features"] == manually_stacked_tensor["features"] + ) + assert torch.all( + stacked_actions.tensor["edge_index"] == manually_stacked_tensor["edge_index"] + ) is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0) assert torch.all(stacked_actions.is_exit == is_exit_stacked) assert stacked_actions[0, 1].is_exit @@ -95,15 +116,24 @@ def test_graph_action(graph_action): # Test extend extended_actions = deepcopy(exit_actions) - extended_actions.extend(dummy_actions) + extended_actions.extend(dummy_actions) assert extended_actions.batch_shape == (BATCH * 2,) - manually_extended_tensor = torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0) - assert torch.all(extended_actions.tensor["action_type"] == manually_extended_tensor["action_type"]) - assert torch.all(extended_actions.tensor["features"] == manually_extended_tensor["features"]) - assert torch.all(extended_actions.tensor["edge_index"] == manually_extended_tensor["edge_index"]) + manually_extended_tensor = torch.cat( + [exit_actions.tensor, dummy_actions.tensor], dim=0 + ) + assert torch.all( + extended_actions.tensor["action_type"] + == manually_extended_tensor["action_type"] + ) + assert torch.all( + extended_actions.tensor["features"] == manually_extended_tensor["features"] + ) + assert torch.all( + extended_actions.tensor["edge_index"] == manually_extended_tensor["edge_index"] + ) is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0) assert torch.all(extended_actions.is_exit == is_exit_extended) assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy extended_actions[0] = extended_actions[BATCH] is_exit_extended[0] = False - assert torch.all(extended_actions.is_exit == is_exit_extended) \ No newline at end of file + assert torch.all(extended_actions.is_exit == is_exit_extended) From fe9713cdbc1749e63543b7afa603ecbbfaa455fe Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 16 Jan 2025 00:20:11 +0100 Subject: [PATCH 044/102] added TODO --- src/gfn/preprocessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index 599d794a..f37e87f9 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -77,7 
+77,7 @@ def preprocess(self, states) -> torch.Tensor: class GraphPreprocessor(Preprocessor): def __init__(self) -> None: - super().__init__(-1) + super().__init__(-1) # TODO: review output_dim API def preprocess(self, states: GraphStates) -> TensorDict: return states.tensor From 1425eb69148b8d85082f7c536d64b6ff9f2cce20 Mon Sep 17 00:00:00 2001 From: "Omar G. Younis" Date: Sun, 19 Jan 2025 16:27:36 +0100 Subject: [PATCH 045/102] add complete masks --- src/gfn/gym/graph_building.py | 10 +-- src/gfn/modules.py | 16 ++--- src/gfn/states.py | 80 ++++++++++++++++------- testing/test_environments.py | 4 +- testing/test_samplers_and_trajectories.py | 10 ++- 5 files changed, 73 insertions(+), 47 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index fd894e3a..be88b324 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -124,11 +124,8 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - global_edge_index = ( - actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] - ) is_equal = torch.all( - state_tensor["edge_index"] == global_edge_index[:, None], dim=-1 + state_tensor["edge_index"] == actions.edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) state_tensor["edge_feature"] = state_tensor["edge_feature"][~is_equal] @@ -168,11 +165,8 @@ def is_action_valid( add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0] ): return False - global_edge_index = ( - add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] - ) equal_edges_per_batch = torch.all( - add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1 + add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1 ).sum(dim=-1) if backward: diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 0937e017..20aee99a 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -512,12 +512,12 @@ def to_probability_distribution( dists = {} action_type_logits = module_output["action_type"] - action_type_masks = ( + masks = ( states.backward_masks if self.is_backward else states.forward_masks ) - action_type_logits[~action_type_masks] = -float("inf") + action_type_logits[~masks["action_type"]] = -float("inf") action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1) - uniform_dist_probs = action_type_masks.float() / action_type_masks.sum( + uniform_dist_probs = masks["action_type"].float() / masks["action_type"].sum( dim=-1, keepdim=True ) action_type_probs = ( @@ -526,11 +526,10 @@ def to_probability_distribution( dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - B, N, N = edge_index_logits.shape - edge_index_logits = edge_index_logits.reshape(B, N * N) - if states.tensor["node_feature"].shape[0] > 1 and torch.any( - edge_index_logits != -float("inf") - ): + edge_index_logits[~masks["edge_index"]] = -float("inf") + if torch.any(edge_index_logits != -float("inf")): + B, N, N = edge_index_logits.shape + edge_index_logits = edge_index_logits.reshape(B, N * N) edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) uniform_dist_probs = ( torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] @@ -538,6 +537,7 @@ def to_probability_distribution( edge_index_probs = ( 1 - epsilon ) * edge_index_probs + epsilon * 
uniform_dist_probs + edge_index_probs[torch.isnan(edge_index_probs)] = 1 dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs, n=N) dists["features"] = Normal(module_output["features"], temperature) diff --git a/src/gfn/states.py b/src/gfn/states.py index 2157651a..061f4ca9 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -536,26 +536,8 @@ def __init__(self, tensor: TensorDict): self.tensor = tensor self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] - self._log_rewards: Optional[float] = None - # TODO logic repeated from env.is_valid_action - not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] - self.forward_masks = torch.ones( - (np.prod(self.batch_shape), 3), dtype=torch.bool - ) - self.forward_masks[..., GraphActionType.ADD_EDGE] = not_empty - self.forward_masks[..., GraphActionType.EXIT] = not_empty - self.forward_masks = self.forward_masks.view(*self.batch_shape, 3) - self.backward_masks = torch.ones( - (np.prod(self.batch_shape), 3), dtype=torch.bool - ) - self.backward_masks[..., GraphActionType.ADD_NODE] = not_empty - self.backward_masks[ - ..., GraphActionType.ADD_EDGE - ] = not_empty # TODO: check at least one edge is present - self.backward_masks[..., GraphActionType.EXIT] = not_empty - self.backward_masks = self.backward_masks.view(*self.batch_shape, 3) @property def batch_shape(self) -> tuple: @@ -880,11 +862,61 @@ def stack(cls, states: List[GraphStates]): assert state.batch_shape == state_batch_shape stacked_states.extend(state) - stacked_states.forward_masks = torch.stack( - [s.forward_masks for s in states], dim=0 - ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 - ) stacked_states.tensor["batch_shape"] = (len(states),) + state_batch_shape return stacked_states + + @property + def forward_masks(self) -> TensorDict: + n_nodes = self.tensor["batch_ptr"][1:] - self.tensor["batch_ptr"][:-1] + ei_mask_shape = (len(self.tensor["node_feature"]), len(self.tensor["node_feature"])) + forward_masks = TensorDict({ + "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), + "features": torch.ones(self.batch_shape + (self.node_features_dim,), dtype=torch.bool), + "edge_index": torch.zeros(self.batch_shape + ei_mask_shape, dtype=torch.bool), + }) # TODO: edge_index mask is very memory consuming... 
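        # Shape sketch (not part of this patch): with N nodes stored across the
        # whole batch, the "edge_index" entry above is one dense (N, N) boolean
        # slab per batch element. For a toy batch_ptr = [0, 2, 5] (graphs of 2
        # and 3 nodes, N = 5) the valid ADD_EDGE targets look like:
        #
        #     mask = torch.zeros(2, 5, 5, dtype=torch.bool)
        #     mask[0, 0:2, 0:2] = True          # graph 0 may only wire its own nodes
        #     mask[1, 2:5, 2:5] = True          # graph 1 likewise
        #     idx = torch.arange(5)
        #     mask[:, idx, idx] = False         # no self-loops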
+ forward_masks["action_type"][..., GraphActionType.ADD_EDGE] = n_nodes > 1 + forward_masks["action_type"][..., GraphActionType.EXIT] = n_nodes >= 1 + + arange = torch.arange(len(self)).view(self.batch_shape) + arange_nodes = torch.arange(len(self.tensor["node_feature"]))[None, :] + same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & (arange_nodes < self.tensor["batch_ptr"][1:, None]) + ei1 = self.tensor["edge_index"][..., 0] + ei2 = self.tensor["edge_index"][..., 1] + for _ in range(len(self.batch_shape)): + ei1, ei2, = ei1.unsqueeze(0), ei2.unsqueeze(0) + + # First allow nodes in the same graph to connect, then disable nodes with existing edges + forward_masks["edge_index"][same_graph_mask[:, :, None] & same_graph_mask[:, None, :]] = True + torch.diagonal(forward_masks["edge_index"], dim1=-2, dim2=-1).fill_(False) + forward_masks["edge_index"][arange[..., None], ei1, ei2] = False + forward_masks["action_type"][..., GraphActionType.ADD_EDGE] &= torch.any(forward_masks["edge_index"], dim=(-1, -2)) + return forward_masks + + @property + def backward_masks(self) -> TensorDict: + n_nodes = self.tensor["batch_ptr"][1:] - self.tensor["batch_ptr"][:-1] + n_edges = torch.count_nonzero( + (self.tensor["edge_index"][None, :, 0] >= self.tensor["batch_ptr"][:-1, None]) & + (self.tensor["edge_index"][None, :, 0] < self.tensor["batch_ptr"][1:, None]) & + (self.tensor["edge_index"][None, :, 1] >= self.tensor["batch_ptr"][:-1, None]) & + (self.tensor["edge_index"][None, :, 1] < self.tensor["batch_ptr"][1:, None]), + dim=-1 + ) + ei_mask_shape = (len(self.tensor["node_feature"]), len(self.tensor["node_feature"])) + backward_masks = TensorDict({ + "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), + "features": torch.ones(self.batch_shape + (self.node_features_dim,), dtype=torch.bool), + "edge_index": torch.zeros(self.batch_shape + ei_mask_shape, dtype=torch.bool), + }) # TODO: edge_index mask is very memory consuming... 
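        # Rough cost behind the TODO above (a back-of-the-envelope sketch, not
        # part of this patch): the dense mask needs B * N^2 booleans for N total
        # nodes. For 128 graphs of 100 nodes each, N = 128 * 100 = 12_800, so
        # 128 * 12_800 ** 2 = 20_971_520_000 one-byte bools, about 21 GB, which
        # is why a sparse, per-edge representation looks like the natural
        # follow-up here.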
+ backward_masks["action_type"][..., GraphActionType.ADD_NODE] = n_nodes >= 1 + backward_masks["action_type"][..., GraphActionType.ADD_EDGE] = n_edges + backward_masks["action_type"][..., GraphActionType.EXIT] = n_nodes >= 1 + + # Allow only existing edges + arange = torch.arange(len(self)).view(self.batch_shape) + ei1 = self.tensor["edge_index"][..., 0] + ei2 = self.tensor["edge_index"][..., 1] + for _ in range(len(self.batch_shape)): + ei1, ei2, = ei1.unsqueeze(0), ei2.unsqueeze(0) + backward_masks["edge_index"][arange[..., None], ei1, ei2] = False + return backward_masks \ No newline at end of file diff --git a/testing/test_environments.py b/testing/test_environments.py index 835b6687..79a690ef 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -416,6 +416,9 @@ def test_graph_env(): batch_size=BATCH_SIZE, ) ) + + states.forward_masks + states.backward_masks sf_states = env.step(states, actions) sf_states = env.States(sf_states) assert torch.all(sf_states.is_sink_state) @@ -438,7 +441,6 @@ def test_graph_env(): "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), "features": states.tensor["edge_feature"][edge_idx], "edge_index": states.tensor["edge_index"][edge_idx] - - states.tensor["batch_ptr"][:-1][:, None], }, batch_size=BATCH_SIZE, ) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 9b2d84c3..efcd6237 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -334,22 +334,20 @@ def forward(self, states: GraphStates) -> TensorDict: features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) edge_index = self.edge_index_conv(node_feature, edge_index) - edge_index = edge_index.reshape(*states.batch_shape, -1, 8) - edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - torch.diagonal(edge_index, dim1=-2, dim2=-1).fill_(float("-inf")) - + edge_index = torch.einsum("nf,mf->nm", edge_index, edge_index) + edge_index = edge_index[None].repeat(len(states), 1, 1) + return TensorDict( { "action_type": action_type, "features": features, - "edge_index": edge_index, + "edge_index": edge_index.reshape(states.batch_shape + edge_index.shape[1:]), }, batch_size=states.batch_shape, ) def test_graph_building(): - torch.manual_seed(7) feature_dim = 8 env = GraphBuilding( feature_dim=feature_dim, state_evaluator=lambda s: torch.zeros(s.batch_shape) From 36c42ecc384e378078d56ba6fcd42e0c1bc4b2bb Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 19 Jan 2025 16:28:48 +0100 Subject: [PATCH 046/102] pre-commit hook --- src/gfn/gym/graph_building.py | 3 +- src/gfn/modules.py | 4 +- src/gfn/states.py | 95 +++++++++++++++++------ testing/test_environments.py | 2 +- testing/test_samplers_and_trajectories.py | 6 +- 5 files changed, 78 insertions(+), 32 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index be88b324..51c4ed05 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -166,7 +166,8 @@ def is_action_valid( ): return False equal_edges_per_batch = torch.all( - add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1 + add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], + dim=-1, ).sum(dim=-1) if backward: diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 20aee99a..491db872 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -512,9 +512,7 @@ def to_probability_distribution( dists = {} action_type_logits = 
module_output["action_type"] - masks = ( - states.backward_masks if self.is_backward else states.forward_masks - ) + masks = states.backward_masks if self.is_backward else states.forward_masks action_type_logits[~masks["action_type"]] = -float("inf") action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1) uniform_dist_probs = masks["action_type"].float() / masks["action_type"].sum( diff --git a/src/gfn/states.py b/src/gfn/states.py index 061f4ca9..38dde67e 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -538,7 +538,6 @@ def __init__(self, tensor: TensorDict): self.edge_features_dim = tensor["edge_feature"].shape[-1] self._log_rewards: Optional[float] = None - @property def batch_shape(self) -> tuple: return tuple(self.tensor["batch_shape"].tolist()) @@ -864,50 +863,91 @@ def stack(cls, states: List[GraphStates]): stacked_states.tensor["batch_shape"] = (len(states),) + state_batch_shape return stacked_states - + @property def forward_masks(self) -> TensorDict: n_nodes = self.tensor["batch_ptr"][1:] - self.tensor["batch_ptr"][:-1] - ei_mask_shape = (len(self.tensor["node_feature"]), len(self.tensor["node_feature"])) - forward_masks = TensorDict({ - "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), - "features": torch.ones(self.batch_shape + (self.node_features_dim,), dtype=torch.bool), - "edge_index": torch.zeros(self.batch_shape + ei_mask_shape, dtype=torch.bool), - }) # TODO: edge_index mask is very memory consuming... + ei_mask_shape = ( + len(self.tensor["node_feature"]), + len(self.tensor["node_feature"]), + ) + forward_masks = TensorDict( + { + "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), + "features": torch.ones( + self.batch_shape + (self.node_features_dim,), dtype=torch.bool + ), + "edge_index": torch.zeros( + self.batch_shape + ei_mask_shape, dtype=torch.bool + ), + } + ) # TODO: edge_index mask is very memory consuming... 
forward_masks["action_type"][..., GraphActionType.ADD_EDGE] = n_nodes > 1 forward_masks["action_type"][..., GraphActionType.EXIT] = n_nodes >= 1 arange = torch.arange(len(self)).view(self.batch_shape) arange_nodes = torch.arange(len(self.tensor["node_feature"]))[None, :] - same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & (arange_nodes < self.tensor["batch_ptr"][1:, None]) + same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & ( + arange_nodes < self.tensor["batch_ptr"][1:, None] + ) ei1 = self.tensor["edge_index"][..., 0] ei2 = self.tensor["edge_index"][..., 1] for _ in range(len(self.batch_shape)): - ei1, ei2, = ei1.unsqueeze(0), ei2.unsqueeze(0) + ( + ei1, + ei2, + ) = ei1.unsqueeze( + 0 + ), ei2.unsqueeze(0) # First allow nodes in the same graph to connect, then disable nodes with existing edges - forward_masks["edge_index"][same_graph_mask[:, :, None] & same_graph_mask[:, None, :]] = True + forward_masks["edge_index"][ + same_graph_mask[:, :, None] & same_graph_mask[:, None, :] + ] = True torch.diagonal(forward_masks["edge_index"], dim1=-2, dim2=-1).fill_(False) forward_masks["edge_index"][arange[..., None], ei1, ei2] = False - forward_masks["action_type"][..., GraphActionType.ADD_EDGE] &= torch.any(forward_masks["edge_index"], dim=(-1, -2)) + forward_masks["action_type"][..., GraphActionType.ADD_EDGE] &= torch.any( + forward_masks["edge_index"], dim=(-1, -2) + ) return forward_masks @property def backward_masks(self) -> TensorDict: n_nodes = self.tensor["batch_ptr"][1:] - self.tensor["batch_ptr"][:-1] n_edges = torch.count_nonzero( - (self.tensor["edge_index"][None, :, 0] >= self.tensor["batch_ptr"][:-1, None]) & - (self.tensor["edge_index"][None, :, 0] < self.tensor["batch_ptr"][1:, None]) & - (self.tensor["edge_index"][None, :, 1] >= self.tensor["batch_ptr"][:-1, None]) & - (self.tensor["edge_index"][None, :, 1] < self.tensor["batch_ptr"][1:, None]), - dim=-1 + ( + self.tensor["edge_index"][None, :, 0] + >= self.tensor["batch_ptr"][:-1, None] + ) + & ( + self.tensor["edge_index"][None, :, 0] + < self.tensor["batch_ptr"][1:, None] + ) + & ( + self.tensor["edge_index"][None, :, 1] + >= self.tensor["batch_ptr"][:-1, None] + ) + & ( + self.tensor["edge_index"][None, :, 1] + < self.tensor["batch_ptr"][1:, None] + ), + dim=-1, + ) + ei_mask_shape = ( + len(self.tensor["node_feature"]), + len(self.tensor["node_feature"]), ) - ei_mask_shape = (len(self.tensor["node_feature"]), len(self.tensor["node_feature"])) - backward_masks = TensorDict({ - "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), - "features": torch.ones(self.batch_shape + (self.node_features_dim,), dtype=torch.bool), - "edge_index": torch.zeros(self.batch_shape + ei_mask_shape, dtype=torch.bool), - }) # TODO: edge_index mask is very memory consuming... + backward_masks = TensorDict( + { + "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), + "features": torch.ones( + self.batch_shape + (self.node_features_dim,), dtype=torch.bool + ), + "edge_index": torch.zeros( + self.batch_shape + ei_mask_shape, dtype=torch.bool + ), + } + ) # TODO: edge_index mask is very memory consuming... 
backward_masks["action_type"][..., GraphActionType.ADD_NODE] = n_nodes >= 1 backward_masks["action_type"][..., GraphActionType.ADD_EDGE] = n_edges backward_masks["action_type"][..., GraphActionType.EXIT] = n_nodes >= 1 @@ -917,6 +957,11 @@ def backward_masks(self) -> TensorDict: ei1 = self.tensor["edge_index"][..., 0] ei2 = self.tensor["edge_index"][..., 1] for _ in range(len(self.batch_shape)): - ei1, ei2, = ei1.unsqueeze(0), ei2.unsqueeze(0) + ( + ei1, + ei2, + ) = ei1.unsqueeze( + 0 + ), ei2.unsqueeze(0) backward_masks["edge_index"][arange[..., None], ei1, ei2] = False - return backward_masks \ No newline at end of file + return backward_masks diff --git a/testing/test_environments.py b/testing/test_environments.py index 79a690ef..44af115f 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -440,7 +440,7 @@ def test_graph_env(): { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx] + "edge_index": states.tensor["edge_index"][edge_idx], }, batch_size=BATCH_SIZE, ) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index efcd6237..2800a8eb 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -336,12 +336,14 @@ def forward(self, states: GraphStates) -> TensorDict: edge_index = self.edge_index_conv(node_feature, edge_index) edge_index = torch.einsum("nf,mf->nm", edge_index, edge_index) edge_index = edge_index[None].repeat(len(states), 1, 1) - + return TensorDict( { "action_type": action_type, "features": features, - "edge_index": edge_index.reshape(states.batch_shape + edge_index.shape[1:]), + "edge_index": edge_index.reshape( + states.batch_shape + edge_index.shape[1:] + ), }, batch_size=states.batch_shape, ) From 5747e97b52d18c67fa0cf495e47a1513bd16acf9 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 20 Jan 2025 00:31:41 +0100 Subject: [PATCH 047/102] adress comments --- pyproject.toml | 2 +- src/gfn/env.py | 1 + src/gfn/gym/graph_building.py | 33 +++++++++++++++------------------ 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 585e31da..d299aa8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ wandb = { version = "*", optional = true } scikit-learn = {version = "*", optional = true } scipy = { version = "*", optional = true } matplotlib = { version = "*", optional = true } -torch_geometric = { version = "*", optional = true } +torch_geometric = { version = ">=2.6.1", optional = true } [tool.poetry.extras] dev = [ diff --git a/src/gfn/env.py b/src/gfn/env.py index 9b0b3549..8f5931c4 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -261,6 +261,7 @@ def _step( "Some actions are not valid in the given states. See `is_action_valid`." ) + # Set to the sink state when the action is exit. 
new_sink_states_idx = actions.is_exit sf_tensor = self.States.make_sink_states_tensor((new_sink_states_idx.sum(),)) new_states[new_sink_states_idx] = self.States(sf_tensor) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 51c4ed05..ad3f6d90 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Callable, Literal, Tuple import torch @@ -68,9 +67,8 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: """ if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") - state_tensor = deepcopy(states.tensor) if len(actions) == 0: - return state_tensor + return states.tensor action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) @@ -81,21 +79,21 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: batch_indices = torch.arange(len(states))[ actions.action_type == GraphActionType.ADD_NODE ] - state_tensor = self._add_node(state_tensor, batch_indices, actions.features) + states.tensor = self._add_node(states.tensor, batch_indices, actions.features) if action_type == GraphActionType.ADD_EDGE: - state_tensor["edge_feature"] = torch.cat( - [state_tensor["edge_feature"], actions.features], dim=0 + states.tensor["edge_feature"] = torch.cat( + [states.tensor["edge_feature"], actions.features], dim=0 ) - state_tensor["edge_index"] = torch.cat( + states.tensor["edge_index"] = torch.cat( [ - state_tensor["edge_index"], - actions.edge_index + state_tensor["batch_ptr"][:-1][:, None], + states.tensor["edge_index"], + actions.edge_index + states.tensor["batch_ptr"][:-1][:, None], ], dim=0, ) - return state_tensor + return states.tensor def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDict: """Backward step function for the GraphBuilding environment. 
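# Design note on the deepcopy removal above (an editorial sketch, not part of
# the patch): step() now mutates states.tensor in place and returns it, saving
# one O(num_nodes + num_edges) copy per transition. Any caller that still needs
# the pre-step state has to copy first, e.g. (assuming GraphStates keeps its
# deepcopy-based clone() helper):
#
#     prev = states.clone()                    # detached copy of the state
#     new_tensor = env.step(states, actions)   # states is modified in place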
@@ -108,30 +106,29 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic """ if not self.is_action_valid(states, actions, backward=True): raise NonValidActionsError("Invalid action.") - state_tensor = deepcopy(states.tensor) if len(actions) == 0: - return state_tensor + return states.tensor action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) if action_type == GraphActionType.ADD_NODE: is_equal = torch.any( torch.all( - state_tensor["node_feature"][:, None] == actions.features, dim=-1 + states.tensor["node_feature"][:, None] == actions.features, dim=-1 ), dim=-1, ) - state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] + states.tensor["node_feature"] = states.tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None is_equal = torch.all( - state_tensor["edge_index"] == actions.edge_index[:, None], dim=-1 + states.tensor["edge_index"] == actions.edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) - state_tensor["edge_feature"] = state_tensor["edge_feature"][~is_equal] - state_tensor["edge_index"] = state_tensor["edge_index"][~is_equal] + states.tensor["edge_feature"] = states.tensor["edge_feature"][~is_equal] + states.tensor["edge_index"] = states.tensor["edge_index"][~is_equal] - return state_tensor + return states.tensor def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False From e9f9951f3facfcc5416fd3a3c5777cf952be4a8a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 20 Jan 2025 00:33:00 +0100 Subject: [PATCH 048/102] pre-commit --- src/gfn/gym/graph_building.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index ad3f6d90..6256d41f 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -79,7 +79,9 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: batch_indices = torch.arange(len(states))[ actions.action_type == GraphActionType.ADD_NODE ] - states.tensor = self._add_node(states.tensor, batch_indices, actions.features) + states.tensor = self._add_node( + states.tensor, batch_indices, actions.features + ) if action_type == GraphActionType.ADD_EDGE: states.tensor["edge_feature"] = torch.cat( From 406cfcaf29a615b81291ac26b41a22758536e7af Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 20 Jan 2025 22:38:05 +0100 Subject: [PATCH 049/102] address comments --- src/gfn/modules.py | 6 +++--- src/gfn/utils/distributions.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 491db872..286a9730 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -11,7 +11,7 @@ from gfn.utils.distributions import ( CategoricalActionType, CategoricalIndexes, - ComposedDistribution, + CompositeDistribution, UnsqueezedCategorical, ) @@ -496,7 +496,7 @@ def to_probability_distribution( module_output: TensorDict, temperature: float = 1.0, epsilon: float = 0.0, - ) -> ComposedDistribution: + ) -> CompositeDistribution: """Returns a probability distribution given a batch of states and module output. We handle off-policyness using these kwargs. 
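# The two exploration knobs named in the docstring above, in one place (a
# minimal sketch of the pattern to_probability_distribution already applies to
# each component, not a new API): temperature flattens the softmax, and epsilon
# mixes in a uniform distribution over the currently valid actions:
#
#     probs = torch.softmax(logits / temperature, dim=-1)
#     uniform = mask.float() / mask.sum(dim=-1, keepdim=True)
#     probs = (1 - epsilon) * probs + epsilon * uniform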
@@ -539,4 +539,4 @@ def to_probability_distribution( dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs, n=N) dists["features"] = Normal(module_output["features"], temperature) - return ComposedDistribution(dists=dists) + return CompositeDistribution(dists=dists) diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index c2ff97ec..715e73eb 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -43,9 +43,9 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor: return super().log_prob(sample.squeeze(-1)) -class ComposedDistribution( +class CompositeDistribution( Distribution -): # TODO: remove in favor of CompositeDistribution in TensorDict +): # TODO: may use CompositeDistribution in TensorDict """A mixture distribution.""" def __init__(self, dists: Dict[str, Distribution]): @@ -65,6 +65,8 @@ def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: v.log_prob(sample[k]).reshape(sample[k].shape[0], -1).sum(dim=-1) for k, v in self.dists.items() ] + # Note: this returns the sum of the log_probs over all the components + # as it is a uniform mixture distribution. return sum(log_probs) From 08e519bfd72c39c45f2ae1f9af2e04314067244a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 24 Jan 2025 23:43:34 +0100 Subject: [PATCH 050/102] fix ring example --- src/gfn/gym/graph_building.py | 14 +++++++------- testing/test_environments.py | 2 +- .../{test_graph_ring.py => train_graph_ring.py} | 0 3 files changed, 8 insertions(+), 8 deletions(-) rename tutorials/examples/{test_graph_ring.py => train_graph_ring.py} (100%) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 6256d41f..ef5d5645 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -123,8 +123,9 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic states.tensor["node_feature"] = states.tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None + remove_edge_index = actions.edge_index + states.tensor["batch_ptr"][:-1][:, None] is_equal = torch.all( - states.tensor["edge_index"] == actions.edge_index[:, None], dim=-1 + states.tensor["edge_index"] == remove_edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) states.tensor["edge_feature"] = states.tensor["edge_feature"][~is_equal] @@ -153,19 +154,18 @@ def is_action_valid( add_edge_out = True else: add_edge_states = states[add_edge_mask].tensor - add_edge_actions = actions[add_edge_mask] - if torch.any( - add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1] - ): + add_edge_actions = actions[add_edge_mask].edge_index + add_edge_states["batch_ptr"][:-1][:, None] + if torch.any(add_edge_actions[:, 0] == add_edge_actions[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: return False if torch.any( - add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0] + add_edge_actions > add_edge_states["node_feature"].shape[0] ): return False + equal_edges_per_batch = torch.all( - add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], + add_edge_states["edge_index"] == add_edge_actions[:, None], dim=-1, ).sum(dim=-1) diff --git a/testing/test_environments.py b/testing/test_environments.py index 44af115f..117abd1f 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -440,7 +440,7 @@ def test_graph_env(): { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), 
"features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx], + "edge_index": states.tensor["edge_index"][edge_idx] - states.tensor["batch_ptr"][:-1][:, None], }, batch_size=BATCH_SIZE, ) diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/train_graph_ring.py similarity index 100% rename from tutorials/examples/test_graph_ring.py rename to tutorials/examples/train_graph_ring.py From 17e07ad30320df847a5078912f0344c4ccd39e5d Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 29 Jan 2025 16:19:53 +0100 Subject: [PATCH 051/102] make edge_index global --- src/gfn/gym/graph_building.py | 13 ++++++------- testing/test_environments.py | 2 +- tutorials/examples/train_graph_ring.py | 8 ++++---- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index ef5d5645..952ea16f 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -90,7 +90,7 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: states.tensor["edge_index"] = torch.cat( [ states.tensor["edge_index"], - actions.edge_index + states.tensor["batch_ptr"][:-1][:, None], + actions.edge_index, ], dim=0, ) @@ -123,9 +123,8 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic states.tensor["node_feature"] = states.tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - remove_edge_index = actions.edge_index + states.tensor["batch_ptr"][:-1][:, None] is_equal = torch.all( - states.tensor["edge_index"] == remove_edge_index[:, None], dim=-1 + states.tensor["edge_index"] == actions.edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) states.tensor["edge_feature"] = states.tensor["edge_feature"][~is_equal] @@ -153,8 +152,8 @@ def is_action_valid( if not torch.any(add_edge_mask): add_edge_out = True else: - add_edge_states = states[add_edge_mask].tensor - add_edge_actions = actions[add_edge_mask].edge_index + add_edge_states["batch_ptr"][:-1][:, None] + add_edge_states = states.tensor + add_edge_actions = actions[add_edge_mask].edge_index if torch.any(add_edge_actions[:, 0] == add_edge_actions[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: @@ -216,8 +215,8 @@ def _add_node( ] ) - # Update edge indices - # Increment indices for edges after the current graph + # Increment indices for edges after the graph at graph_idx, + # i.e. 
the edges that point to nodes after end_ptr edge_mask_0 = modified_dict["edge_index"][:, 0] >= end_ptr edge_mask_1 = modified_dict["edge_index"][:, 1] >= end_ptr modified_dict["edge_index"][edge_mask_0, 0] += shift diff --git a/testing/test_environments.py b/testing/test_environments.py index 117abd1f..44af115f 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -440,7 +440,7 @@ def test_graph_env(): { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx] - states.tensor["batch_ptr"][:-1][:, None], + "edge_index": states.tensor["edge_index"][edge_idx], }, batch_size=BATCH_SIZE, ) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index a33ae61d..58d3659b 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -167,14 +167,14 @@ def backward_masks(self, value: torch.Tensor): return RingStates def _step(self, states: GraphStates, actions: Actions) -> GraphStates: - actions = self.convert_actions(actions) + actions = self.convert_actions(states, actions) return super()._step(states, actions) def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: - actions = self.convert_actions(actions) + actions = self.convert_actions(states, actions) return super()._backward_step(states, actions) - def convert_actions(self, actions: Actions) -> GraphActions: + def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions: action_tensor = actions.tensor.squeeze(-1) action_type = torch.where( action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE @@ -188,7 +188,7 @@ def convert_actions(self, actions: Actions) -> GraphActions: { "action_type": action_type, "features": torch.ones(action_tensor.shape + (1,)), - "edge_index": edge_index, + "edge_index": edge_index + states.tensor["batch_ptr"][:-1][:, None], }, batch_size=action_tensor.shape, ) From 46e36982a3b77fd78175ff281281ff6476e1f6d2 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 29 Jan 2025 16:24:36 +0100 Subject: [PATCH 052/102] make edge_index global --- src/gfn/gym/graph_building.py | 21 ++++++++------------- testing/test_environments.py | 3 +-- tutorials/examples/test_graph_ring.py | 8 ++++---- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index fd894e3a..6a6d5d02 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -89,8 +89,8 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: ) state_tensor["edge_index"] = torch.cat( [ - state_tensor["edge_index"], - actions.edge_index + state_tensor["batch_ptr"][:-1][:, None], + states.tensor["edge_index"], + actions.edge_index, ], dim=0, ) @@ -124,11 +124,8 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - global_edge_index = ( - actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] - ) is_equal = torch.all( - state_tensor["edge_index"] == global_edge_index[:, None], dim=-1 + states.tensor["edge_index"] == actions.edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) state_tensor["edge_feature"] = state_tensor["edge_feature"][~is_equal] @@ -156,11 +153,9 @@ def is_action_valid( if 
not torch.any(add_edge_mask): add_edge_out = True else: - add_edge_states = states[add_edge_mask].tensor - add_edge_actions = actions[add_edge_mask] - if torch.any( - add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1] - ): + add_edge_states = states.tensor + add_edge_actions = actions[add_edge_mask].edge_index + if torch.any(add_edge_actions[:, 0] == add_edge_actions[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: return False @@ -222,8 +217,8 @@ def _add_node( ] ) - # Update edge indices - # Increment indices for edges after the current graph + # Increment indices for edges after the graph at graph_idx, + # i.e. the edges that point to nodes after end_ptr edge_mask_0 = modified_dict["edge_index"][:, 0] >= end_ptr edge_mask_1 = modified_dict["edge_index"][:, 1] >= end_ptr modified_dict["edge_index"][edge_mask_0, 0] += shift diff --git a/testing/test_environments.py b/testing/test_environments.py index 835b6687..821c6b31 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -437,8 +437,7 @@ def test_graph_env(): { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx] - - states.tensor["batch_ptr"][:-1][:, None], + "edge_index": states.tensor["edge_index"][edge_idx], }, batch_size=BATCH_SIZE, ) diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py index b6d401f7..76ce5437 100644 --- a/tutorials/examples/test_graph_ring.py +++ b/tutorials/examples/test_graph_ring.py @@ -166,14 +166,14 @@ def backward_masks(self, value: torch.Tensor): return RingStates def _step(self, states: GraphStates, actions: Actions) -> GraphStates: - actions = self.convert_actions(actions) + actions = self.convert_actions(states, actions) return super()._step(states, actions) def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: - actions = self.convert_actions(actions) + actions = self.convert_actions(states, actions) return super()._backward_step(states, actions) - def convert_actions(self, actions: Actions) -> GraphActions: + def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions: action_tensor = actions.tensor.squeeze(-1) action_type = torch.where( action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE @@ -187,7 +187,7 @@ def convert_actions(self, actions: Actions) -> GraphActions: { "action_type": action_type, "features": torch.ones(action_tensor.shape + (1,)), - "edge_index": edge_index, + "edge_index": edge_index + states.tensor["batch_ptr"][:-1][:, None], }, batch_size=action_tensor.shape, ) From e6d909be9011293cfde94ba91abd90b93f364ef3 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 30 Jan 2025 00:35:48 +0100 Subject: [PATCH 053/102] fix test_env --- testing/test_environments.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/test_environments.py b/testing/test_environments.py index 44af115f..49f839a0 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -393,8 +393,8 @@ def test_graph_env(): states = env.step(states, actions) for i in range(NUM_NODES - 1): - node_is = torch.full((BATCH_SIZE,), i) - node_js = torch.full((BATCH_SIZE,), i + 1) + node_is = torch.arange(BATCH_SIZE) * NUM_NODES + i + node_js = torch.arange(BATCH_SIZE) * NUM_NODES + i + 1 actions = action_cls( TensorDict( { From da66adb967820b33aafea009c3c96d8c755c5306 Mon Sep 17 00:00:00 2001 From: Omar 
Younis Date: Wed, 5 Feb 2025 00:13:34 +0100 Subject: [PATCH 054/102] add global edge + pair programming session Co-authored-by: Joseph Viviano --- src/gfn/gflownet/flow_matching.py | 13 +- src/gfn/gym/graph_building.py | 32 ++--- src/gfn/modules.py | 2 +- src/gfn/states.py | 116 +++++++++++------ src/gfn/utils/distributions.py | 12 +- testing/test_environments.py | 10 +- testing/test_samplers_and_trajectories.py | 10 +- tutorials/examples/train_graph_ring.py | 148 ++++++++++++++++------ 8 files changed, 220 insertions(+), 123 deletions(-) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 606c0a8a..9b44d5e0 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -33,10 +33,10 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]): def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): super().__init__() - # assert isinstance( # TODO: need a more flexible type check. - # logF, - # DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, - # ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" + assert isinstance( + logF, + DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, + ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" self.logF = logF self.alpha = alpha @@ -169,10 +169,13 @@ def reward_matching_loss( else: with no_conditioning_exception_handler("logF", self.logF): log_edge_flows = self.logF(terminating_states) - + # Handle the boundary condition (for all x, F(X->S_f) = R(x)). terminating_log_edge_flows = log_edge_flows[:, -1] log_rewards = terminating_states.log_rewards + + print(terminating_log_edge_flows - log_rewards) + return (terminating_log_edge_flows - log_rewards).pow(2).mean() def loss( diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 952ea16f..1277d727 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -157,19 +157,17 @@ def is_action_valid( if torch.any(add_edge_actions[:, 0] == add_edge_actions[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: - return False - if torch.any( - add_edge_actions > add_edge_states["node_feature"].shape[0] - ): + return False + node_exists = torch.isin(add_edge_actions, add_edge_states["node_index"]) + if not torch.all(node_exists): return False equal_edges_per_batch = torch.all( add_edge_states["edge_index"] == add_edge_actions[:, None], dim=-1, ).sum(dim=-1) - if backward: - add_edge_out = torch.all(equal_edges_per_batch == 1) + add_edge_out = torch.all(equal_edges_per_batch != 0) else: add_edge_out = torch.all(equal_edges_per_batch == 0) @@ -194,17 +192,15 @@ def _add_node( for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): tensor_dict["batch_ptr"][graph_idx] end_ptr = tensor_dict["batch_ptr"][graph_idx + 1] - - if new_nodes.ndim == 1: - new_nodes = new_nodes.unsqueeze(0) + new_nodes = torch.atleast_2d(new_nodes) if new_nodes.shape[1] != node_feature_dim: raise ValueError( f"Node features must have dimension {node_feature_dim}" ) # Update batch pointers for subsequent graphs - shift = new_nodes.shape[0] - modified_dict["batch_ptr"][graph_idx + 1 :] += shift + num_new_nodes = new_nodes.shape[0] + modified_dict["batch_ptr"][graph_idx + 1 :] += num_new_nodes # Expand node features modified_dict["node_feature"] = torch.cat( @@ -214,13 +210,13 @@ def _add_node( modified_dict["node_feature"][end_ptr:], ] ) - - # Increment indices for edges after the graph at graph_idx, - # i.e. 
the edges that point to nodes after end_ptr - edge_mask_0 = modified_dict["edge_index"][:, 0] >= end_ptr - edge_mask_1 = modified_dict["edge_index"][:, 1] >= end_ptr - modified_dict["edge_index"][edge_mask_0, 0] += shift - modified_dict["edge_index"][edge_mask_1, 1] += shift + modified_dict["node_index"] = torch.cat( + [ + modified_dict["node_index"][:end_ptr], + GraphStates.unique_node_indices(num_new_nodes), + modified_dict["node_index"][end_ptr:], + ] + ) return modified_dict diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 286a9730..a2ef72d6 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -536,7 +536,7 @@ def to_probability_distribution( 1 - epsilon ) * edge_index_probs + epsilon * uniform_dist_probs edge_index_probs[torch.isnan(edge_index_probs)] = 1 - dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs, n=N) + dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs, node_indexes=states.tensor["node_index"]) dists["features"] = Normal(module_output["features"], temperature) return CompositeDistribution(dists=dists) diff --git a/src/gfn/states.py b/src/gfn/states.py index 38dde67e..ded01374 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -520,9 +520,12 @@ class GraphStates(States): s0: ClassVar[TensorDict] sf: ClassVar[TensorDict] + _next_node_index = 0 + def __init__(self, tensor: TensorDict): REQUIRED_KEYS = { "node_feature", + "node_index", "edge_feature", "edge_index", "batch_ptr", @@ -533,6 +536,7 @@ def __init__(self, tensor: TensorDict): f"TensorDict must contain all required keys: {REQUIRED_KEYS}" ) + assert tensor["node_index"].unique().numel() == len(tensor["node_index"]) self.tensor = tensor self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] @@ -559,10 +563,12 @@ def from_batch_shape( @classmethod def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - + nodes = cls.s0["node_feature"].repeat(np.prod(batch_shape), 1) + return TensorDict( { - "node_feature": cls.s0["node_feature"].repeat(np.prod(batch_shape), 1), + "node_feature": nodes, + "node_index": GraphStates.unique_node_indices(nodes.shape[0]), "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), "batch_ptr": torch.arange(np.prod(batch_shape) + 1) @@ -577,9 +583,11 @@ def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: raise NotImplementedError("Sink state is not defined") batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) + nodes = cls.sf["node_feature"].repeat(np.prod(batch_shape), 1) out = TensorDict( { - "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), + "node_feature": nodes, + "node_index": GraphStates.unique_node_indices(nodes.shape[0]), "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), "batch_ptr": torch.arange( @@ -605,6 +613,7 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: "node_feature": torch.rand( np.prod(batch_shape) * num_nodes, node_features_dim, device=device ), + "node_index": GraphStates.unique_node_indices(np.prod(batch_shape) * num_nodes), "edge_feature": torch.rand( np.prod(batch_shape) * num_edges, edge_features_dim, device=device ), @@ -635,36 +644,34 @@ def __getitem__( if torch.any(index >= 
len(self.tensor["batch_ptr"]) - 1): raise ValueError("Graph index out of bounds") + # TODO: explain batch_ptr and node_index semantics start_ptrs = self.tensor["batch_ptr"][:-1][index] end_ptrs = self.tensor["batch_ptr"][1:][index] node_features = [torch.empty(0, self.node_features_dim)] + node_indices = [torch.empty(0, dtype=torch.long)] edge_features = [torch.empty(0, self.edge_features_dim)] edge_indices = [torch.empty(0, 2, dtype=torch.long)] batch_ptr = [0] for start, end in zip(start_ptrs, end_ptrs): - graph_nodes = self.tensor["node_feature"][start:end] - node_features.append(graph_nodes) + node_features.append(self.tensor["node_feature"][start:end]) + node_indices.append(self.tensor["node_index"][start:end]) + batch_ptr.append(batch_ptr[-1] + end - start) # Find edges for this graph - edge_mask = (self.tensor["edge_index"][:, 0] >= start) & ( - self.tensor["edge_index"][:, 0] < end - ) - graph_edges = self.tensor["edge_feature"][edge_mask] - edge_features.append(graph_edges) - - # Adjust edge indices to be local to this graph - graph_edge_index = self.tensor["edge_index"][edge_mask] - graph_edge_index[:, 0] -= start - batch_ptr[-1] - graph_edge_index[:, 1] -= start - batch_ptr[-1] - edge_indices.append(graph_edge_index) - batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) + if self.tensor["node_index"].numel() > 0: + edge_mask = (self.tensor["edge_index"][:, 0] >= self.tensor["node_index"][start]) & ( + self.tensor["edge_index"][:, 0] <= self.tensor["node_index"][end - 1] + ) + edge_features.append(self.tensor["edge_feature"][edge_mask]) + edge_indices.append(self.tensor["edge_index"][edge_mask]) out = self.__class__( TensorDict( { "node_feature": torch.cat(node_features), + "node_index": torch.cat(node_indices), "edge_feature": torch.cat(edge_features), "edge_index": torch.cat(edge_indices), "batch_ptr": torch.tensor(batch_ptr), @@ -676,6 +683,7 @@ def __getitem__( if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] + assert out.tensor["node_index"].unique().numel() == len(out.tensor["node_index"]) return out def __setitem__(self, index: int | Sequence[int], graph: GraphStates): @@ -709,19 +717,13 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): new_nodes = source_tensor_dict["node_feature"][ source_start_ptr:source_end_ptr ] - - # Ensure new nodes have correct feature dimension - if new_nodes.ndim == 1: - new_nodes = new_nodes.unsqueeze(0) + new_nodes = torch.atleast_2d(new_nodes) if new_nodes.shape[1] != self.node_features_dim: raise ValueError( f"Node features must have dimension {self.node_features_dim}" ) - # Number of new nodes to add - shift = new_nodes.shape[0] - (end_ptr - start_ptr) - # Concatenate node features self.tensor["node_feature"] = torch.cat( [ @@ -735,24 +737,21 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): ] ) - # Update edge indices for subsequent graphs - edge_mask = self.tensor["edge_index"] >= end_ptr - assert torch.all(edge_mask[..., 0] == edge_mask[..., 1]) - edge_mask = torch.all(edge_mask, dim=-1) - self.tensor["edge_index"][edge_mask] += shift - edge_mask |= torch.all(self.tensor["edge_index"] < start_ptr, dim=-1) + edge_mask = torch.empty(0, dtype=torch.bool) + if self.tensor["edge_index"].numel() > 0: + edge_mask = torch.all(self.tensor["edge_index"] > self.tensor["node_index"][end_ptr - 1], dim=-1) + edge_mask |= torch.all(self.tensor["edge_index"] < self.tensor["node_index"][start_ptr], dim=-1) + edge_to_add_mask = torch.all( - source_tensor_dict["edge_index"] >= 
source_start_ptr, dim=-1 + source_tensor_dict["edge_index"] >= source_tensor_dict["node_index"][source_start_ptr], dim=-1 ) edge_to_add_mask &= torch.all( - source_tensor_dict["edge_index"] < source_end_ptr, dim=-1 + source_tensor_dict["edge_index"] <= source_tensor_dict["node_index"][source_end_ptr - 1], dim=-1 ) self.tensor["edge_index"] = torch.cat( [ self.tensor["edge_index"][edge_mask], source_tensor_dict["edge_index"][edge_to_add_mask] - - source_start_ptr - + start_ptr, ], dim=0, ) @@ -764,8 +763,18 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): dim=0, ) + self.tensor["node_index"] = torch.cat( + [ + self.tensor["node_index"][:start_ptr], + source_tensor_dict["node_index"][source_start_ptr:source_end_ptr], + self.tensor["node_index"][end_ptr:], + ] + ) # Update batch pointers + shift = new_nodes.shape[0] - (end_ptr - start_ptr) self.tensor["batch_ptr"][graph_idx + 1 :] += shift + + assert self.tensor["node_index"].unique().numel() == len(self.tensor["node_index"]) @property def device(self) -> torch.device | None: @@ -787,13 +796,30 @@ def extend(self, other: GraphStates): self.tensor["node_feature"] = torch.cat( [self.tensor["node_feature"], other.tensor["node_feature"]], dim=0 ) + + # find if there are common node indices + other_node_index = other.tensor["node_index"] + other_edge_index = other.tensor["edge_index"] + common_node_indices = torch.any(self.tensor["node_index"][:, None] == other_node_index[None, :], dim=0) + if torch.any(common_node_indices): + new_indices = self.unique_node_indices(torch.sum(common_node_indices)) + # find edge_index which contains other_node_index[common_node_indices] + edge_mask = other_edge_index[:, None] == other_node_index[common_node_indices, None] + repeat_indices = new_indices[None, :, None].repeat(edge_mask.shape[0], 1, 2) + other_edge_index[torch.any(edge_mask, dim=1)] = repeat_indices[edge_mask] + + other_node_index[common_node_indices] = new_indices + + self.tensor["node_index"] = torch.cat( + [self.tensor["node_index"], other_node_index], dim=0 + ) self.tensor["edge_feature"] = torch.cat( [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 ) self.tensor["edge_index"] = torch.cat( [ self.tensor["edge_index"], - other.tensor["edge_index"] + self.tensor["batch_ptr"][-1], + other.tensor["edge_index"] ], dim=0, ) @@ -830,11 +856,11 @@ def _compare(self, other: TensorDict) -> torch.Tensor: self.tensor["node_feature"][start:end] == other["node_feature"] ) edge_mask = torch.all( - (self.tensor["edge_index"] >= start) - & (self.tensor["edge_index"] < end), + (self.tensor["edge_index"] >= self.tensor["node_index"][start]) + & (self.tensor["edge_index"] <= self.tensor["node_index"][end - 1]), dim=-1, ) - edge_index = self.tensor["edge_index"][edge_mask] - start + edge_index = self.tensor["edge_index"][edge_mask] out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all( edge_index == other["edge_index"] ) @@ -862,6 +888,7 @@ def stack(cls, states: List[GraphStates]): stacked_states.extend(state) stacked_states.tensor["batch_shape"] = (len(states),) + state_batch_shape + assert stacked_states.tensor["node_index"].unique().numel() == len(stacked_states.tensor["node_index"]) return stacked_states @property @@ -890,8 +917,11 @@ def forward_masks(self) -> TensorDict: same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & ( arange_nodes < self.tensor["batch_ptr"][1:, None] ) - ei1 = self.tensor["edge_index"][..., 0] - ei2 = self.tensor["edge_index"][..., 1] + edge_index = torch.where( 
+ self.tensor["edge_index"][..., None] == self.tensor["node_index"] + )[2].reshape(self.tensor["edge_index"].shape) + ei1 = edge_index[..., 0] + ei2 = edge_index[..., 1] for _ in range(len(self.batch_shape)): ( ei1, @@ -965,3 +995,9 @@ def backward_masks(self) -> TensorDict: ), ei2.unsqueeze(0) backward_masks["edge_index"][arange[..., None], ei1, ei2] = False return backward_masks + + @classmethod + def unique_node_indices(cls, num_new_nodes: int) -> torch.Tensor: + indices = torch.arange(cls._next_node_index, cls._next_node_index + num_new_nodes) + cls._next_node_index += num_new_nodes + return indices \ No newline at end of file diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 715e73eb..cfb9b96b 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -73,24 +73,26 @@ def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: class CategoricalIndexes(Categorical): """Samples indexes from a categorical distribution.""" - def __init__(self, probs: torch.Tensor, n: int): + def __init__(self, probs: torch.Tensor, node_indexes: torch.Tensor): """Initializes the distribution. Args: probs: The probabilities of the categorical distribution. n: The number of nodes in the graph. """ - self.n = n - assert probs.shape == (probs.shape[0], self.n * self.n) + self.node_indexes = node_indexes + assert probs.shape == (probs.shape[0], node_indexes.shape[0] * node_indexes.shape[0]) super().__init__(probs) def sample(self, sample_shape=torch.Size()) -> torch.Tensor: samples = super().sample(sample_shape) - out = torch.stack([samples // self.n, samples % self.n], dim=-1) + out = torch.stack([samples // self.node_indexes.shape[0], samples % self.node_indexes.shape[0]], dim=-1) + out = self.node_indexes.index_select(0, out.flatten()).reshape(*out.shape) return out def log_prob(self, value): - value = value[..., 0] * self.n + value[..., 1] + value = value[..., 0] * self.node_indexes.shape[0] + value[..., 1] + value = torch.bucketize(value, self.node_indexes) return super().log_prob(value) diff --git a/testing/test_environments.py b/testing/test_environments.py index 49f839a0..29681ff6 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -422,15 +422,7 @@ def test_graph_env(): sf_states = env.step(states, actions) sf_states = env.States(sf_states) assert torch.all(sf_states.is_sink_state) - env.reward(states) - - # with pytest.raises(NonValidActionsError): - # node_idx = torch.arange(0, BATCH_SIZE) - # actions = action_cls( - # GraphActionType.ADD_NODE, - # states.data.x[node_idxs], - # ) - # states = env.backward_step(states, actions) + env.reward(sf_states) num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE for i in reversed(range(num_edges_per_batch)): diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 2800a8eb..b5f4294d 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -319,21 +319,23 @@ def __init__(self, feature_dim: int): def forward(self, states: GraphStates) -> TensorDict: node_feature = states.tensor["node_feature"].reshape(-1, self.feature_dim) - edge_index = states.tensor["edge_index"].T + edge_index = torch.where( + states.tensor["edge_index"][..., None] == states.tensor["node_index"] + )[2].reshape(states.tensor["edge_index"].shape) if states.tensor["node_feature"].shape[0] == 0: action_type = torch.zeros((len(states), 3)) action_type[:, GraphActionType.ADD_NODE] = 1 
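+            # With an empty graph there is nothing to convolve over, so the
+            # remaining outputs default to zeros.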
features = torch.zeros((len(states), self.feature_dim)) else: - action_type = self.action_type_conv(node_feature, edge_index) + action_type = self.action_type_conv(node_feature, edge_index.T) action_type = action_type.reshape( len(states), -1, action_type.shape[-1] ).mean(dim=1) - features = self.features_conv(node_feature, edge_index) + features = self.features_conv(node_feature, edge_index.T) features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) - edge_index = self.edge_index_conv(node_feature, edge_index) + edge_index = self.edge_index_conv(node_feature, edge_index.T) edge_index = torch.einsum("nf,mf->nm", edge_index, edge_index) edge_index = edge_index[None].repeat(len(states), 1, 1) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 58d3659b..e245289d 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -8,6 +8,7 @@ import torch from tensordict import TensorDict from torch import nn +import torch.nn.functional as F from torch_geometric.nn import GCNConv from gfn.actions import Actions, GraphActions, GraphActionType @@ -22,32 +23,68 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: eps = 1e-6 if states.tensor["edge_index"].shape[0] == 0: return torch.full(states.batch_shape, eps) - if states.tensor["edge_index"].shape[0] != states.tensor["node_feature"].shape[0]: - return torch.full(states.batch_shape, eps) - out = torch.zeros(len(states)) + out = torch.full((len(states),), eps) # Default reward. + for i in range(len(states)): start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] + nodes_index_range = states.tensor["node_index"][start:end] edge_index_mask = torch.all( - states.tensor["edge_index"] >= start, dim=-1 - ) & torch.all(states.tensor["edge_index"] < end, dim=-1) - edge_index = states.tensor["edge_index"][edge_index_mask] - arange = torch.arange(start, end) - # TODO: not correct, accepts multiple rings - if torch.all(torch.sort(edge_index[:, 0])[0] == arange) and torch.all( - torch.sort(edge_index[:, 1])[0] == arange - ): - out[i] = 1 - else: - out[i] = eps + states.tensor["edge_index"] >= nodes_index_range[0], dim=-1 + ) & torch.all(states.tensor["edge_index"] <= nodes_index_range[-1], dim=-1) + masked_edge_index = states.tensor["edge_index"][edge_index_mask] - nodes_index_range[0] + + n_nodes = nodes_index_range.shape[0] + adj_matrix = torch.zeros(n_nodes, n_nodes) + adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 + + # # Matrix must be symmetric (undirected graph). + # if not torch.all(adj_matrix == adj_matrix.T): + # continue + # Each vertex must have exactly degree 2 (sum of each row = 2). + if not torch.all(adj_matrix.sum(axis=1) == 1): + continue + + # Connectivity check: Start at vertex 0 and follow edges, keep track + # of visited edges, visit all edges once, end at vertex 0. + visited = [] + current = 0 + while current not in visited: + visited.append(current) + + def set_diff(tensor1, tensor2): + mask = ~torch.isin(tensor1, tensor2) + return tensor1[mask] + + # Find an unvisited neighbor + neighbors = torch.where(adj_matrix[current] == 1)[0] + valid_neighbours = set_diff(neighbors, torch.tensor(visited)) + + # Visit the fir + if len(valid_neighbours) == 1: + current = valid_neighbours[0] + elif len(valid_neighbours) == 0: + break + else: + break # TODO: This actually should never happen, should be caught on line 45. + + # Check if we visited all vertices and the last vertex connects back to start. 
+ if len(visited) == n_nodes and adj_matrix[current][0] == 1: + out[i] = 1.0 + return out.view(*states.batch_shape) class RingPolicyEstimator(nn.Module): - def __init__(self, n_nodes: int, edge_hidden_dim: int = 128): + def __init__( + self, + n_nodes: int, + action_hidden_dim: int = 16, + edge_hidden_dim: int = 16, + ): super().__init__() - self.action_type_conv = GCNConv(1, 1) - self.edge_index_conv = GCNConv(1, edge_hidden_dim) + self.action_type_conv = GCNConv(n_nodes, action_hidden_dim) + self.edge_index_conv = GCNConv(n_nodes, edge_hidden_dim) self.n_nodes = n_nodes self.edge_hidden_dim = edge_hidden_dim @@ -58,21 +95,27 @@ def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Ten device=tensor.device, ) cumsum[1:] = torch.cumsum(tensor, dim=0) + + # Subtract the end val from each batch idx fom the start val of each batch idx. return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] def forward(self, states_tensor: TensorDict) -> torch.Tensor: - node_feature = states_tensor["node_feature"].reshape(-1, 1) - edge_index = states_tensor["edge_index"].T - batch_ptr = states_tensor["batch_ptr"] + node_feature, batch_ptr = states_tensor["node_feature"], states_tensor["batch_ptr"] - action_type = self.action_type_conv(node_feature, edge_index) - action_type = self._group_sum(action_type, batch_ptr) + edge_index = torch.where( + states_tensor["edge_index"][..., None] == states_tensor["node_index"] + )[2].reshape(states_tensor["edge_index"].shape) # (M, 2) - edge_index = self.edge_index_conv(node_feature, edge_index) + action_type = self.action_type_conv(node_feature, edge_index.T) + action_type = self._group_sum(torch.mean(action_type, dim=-1, keepdim=True), batch_ptr) + + edge_index = self.edge_index_conv(node_feature, edge_index.T) edge_index = edge_index.reshape( *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim ) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) + + # edge_actions = edge_index.reshape( *states_tensor["batch_shape"], self.n_nodes * self.n_nodes ) @@ -81,10 +124,10 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: class RingGraphBuilding(GraphBuilding): - def __init__(self, n_nodes: int = 10): + def __init__(self, n_nodes: int): self.n_nodes = n_nodes self.n_actions = 1 + n_nodes * n_nodes - super().__init__(feature_dim=1, state_evaluator=state_evaluator) + super().__init__(feature_dim=n_nodes, state_evaluator=state_evaluator) self.is_discrete = True # actions here are discrete, needed for FlowMatching def make_actions_class(self) -> type[Actions]: @@ -105,7 +148,7 @@ def make_states_class(self) -> type[GraphStates]: class RingStates(GraphStates): s0 = TensorDict( { - "node_feature": torch.arange(env.n_nodes).unsqueeze(-1), + "node_feature": F.one_hot(torch.arange(env.n_nodes), num_classes=env.n_nodes).float(), "edge_feature": torch.ones((0, 1)), "edge_index": torch.ones((0, 2), dtype=torch.long), }, @@ -113,7 +156,7 @@ class RingStates(GraphStates): ) sf = TensorDict( { - "node_feature": torch.zeros((env.n_nodes, 1)), + "node_feature": torch.zeros((env.n_nodes, env.n_nodes)), "edge_feature": torch.zeros((0, 1)), "edge_index": torch.zeros((0, 2), dtype=torch.long), }, @@ -134,7 +177,7 @@ def forward_masks(self): forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) forward_masks[:, 1 :: self.n_nodes + 1] = False for i in range(len(self)): - existing_edges = self[i].tensor["edge_index"] + existing_edges = self[i].tensor["edge_index"] - self.tensor['node_index'][self.tensor['batch_ptr'][i]] 
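+            # Subtracting the graph's first global node id re-bases the edge
+            # endpoints to local ids in [0, n_nodes), which index the flat
+            # n_nodes * n_nodes action grid below.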
forward_masks[ i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], @@ -152,7 +195,7 @@ def backward_masks(self): len(self), self.n_actions, dtype=torch.bool ) for i in range(len(self)): - existing_edges = self[i].tensor["edge_index"] + existing_edges = self[i].tensor["edge_index"] - self.tensor['node_index'][self.tensor['batch_ptr'][i]] backward_masks[ i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], @@ -168,7 +211,8 @@ def backward_masks(self, value: torch.Tensor): def _step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(states, actions) - return super()._step(states, actions) + out = super()._step(states, actions) + return out def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(states, actions) @@ -183,12 +227,13 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions edge_index_i1 = (action_tensor - 1) % (self.n_nodes) edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) + offset = states.tensor["node_index"][states.tensor["batch_ptr"][:-1]] return GraphActions( TensorDict( { "action_type": action_type, "features": torch.ones(action_tensor.shape + (1,)), - "edge_index": edge_index + states.tensor["batch_ptr"][:-1][:, None], + "edge_index": edge_index + offset[:, None], }, batch_size=action_tensor.shape, ) @@ -225,7 +270,11 @@ def render_states(states: GraphStates): plt.Circle((x, y), 0.5, facecolor="none", edgecolor="black") ) - for edge in state.tensor["edge_index"]: + edge_index = states[i].tensor["edge_index"] + edge_index = torch.where( + edge_index[..., None] == states[i].tensor["node_index"] + )[2].reshape(edge_index.shape) + for edge in edge_index: start_x, start_y = xs[edge[0]], ys[edge[0]] end_x, end_y = xs[edge[1]], ys[edge[1]] dx = end_x - start_x @@ -262,27 +311,43 @@ def render_states(states: GraphStates): if __name__ == "__main__": + # ring_state = GraphStates( + # TensorDict( + # { + # "node_feature": torch.tensor([[0], [1], [2]]), + # "node_index": torch.tensor([0, 1, 2]), + # "edge_feature": torch.ones((3, 1)), + # "edge_index": torch.tensor([[1, 0], [1, 2], [2, 0]]), + # "batch_ptr": torch.tensor([0, 3]), + # "batch_shape": torch.ones((1,), dtype=torch.long), + # }, + # batch_size=(), + # ) + # ) + # print(state_evaluator(ring_state)) + N_NODES = 3 - N_ITERATIONS = 1024 - torch.random.manual_seed(42) + N_ITERATIONS = 128 + torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) - pf_estimator = DiscretePolicyEstimator( + logf_estimator = DiscretePolicyEstimator( module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor() ) - gflownet = FMGFlowNet(pf_estimator) + gflownet = FMGFlowNet(logf_estimator) optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-2) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=N_ITERATIONS, eta_min=1e-4 - ) + batch_size = 32 + # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer, T_max=N_ITERATIONS, eta_min=1e-4 + # ) losses = [] t1 = time.time() for iteration in range(N_ITERATIONS): - trajectories = gflownet.sample_trajectories(env, n=64) + trajectories = gflownet.sample_trajectories(env, n=batch_size) samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() loss = gflownet.loss(env, samples) @@ -290,7 +355,8 @@ def render_states(states: GraphStates): loss.backward() optimizer.step() losses.append(loss.item()) - scheduler.step() + # 
scheduler.step() + t2 = time.time() print("Time:", t2 - t1) render_states(trajectories.last_states[:8]) From 713028130ff99381b13d305c516288a4b532781c Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 6 Feb 2025 00:25:47 +0100 Subject: [PATCH 055/102] pair programming session Co-authored-by: Joseph Viviano --- src/gfn/gflownet/flow_matching.py | 2 -- src/gfn/states.py | 13 ++++--- testing/test_states.py | 50 ++++++++++++++++++++++++++ tutorials/examples/train_graph_ring.py | 34 +++++++++--------- 4 files changed, 77 insertions(+), 22 deletions(-) create mode 100644 testing/test_states.py diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 9b44d5e0..e5fd652d 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -173,8 +173,6 @@ def reward_matching_loss( # Handle the boundary condition (for all x, F(X->S_f) = R(x)). terminating_log_edge_flows = log_edge_flows[:, -1] log_rewards = terminating_states.log_rewards - - print(terminating_log_edge_flows - log_rewards) return (terminating_log_edge_flows - log_rewards).pow(2).mean() diff --git a/src/gfn/states.py b/src/gfn/states.py index ded01374..e38231d3 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -684,12 +684,14 @@ def __getitem__( out._log_rewards = self._log_rewards[index] assert out.tensor["node_index"].unique().numel() == len(out.tensor["node_index"]) + return out def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ Set particular states of the Batch """ + # This is to convert index to type int (linear indexing). tensor_idx = torch.arange(len(self)).view(*self.batch_shape) index = tensor_idx[index].flatten() @@ -802,12 +804,15 @@ def extend(self, other: GraphStates): other_edge_index = other.tensor["edge_index"] common_node_indices = torch.any(self.tensor["node_index"][:, None] == other_node_index[None, :], dim=0) if torch.any(common_node_indices): + # This renumbers nodes across batch indices such that all nodes have + # a unique ID. new_indices = self.unique_node_indices(torch.sum(common_node_indices)) - # find edge_index which contains other_node_index[common_node_indices] - edge_mask = other_edge_index[:, None] == other_node_index[common_node_indices, None] - repeat_indices = new_indices[None, :, None].repeat(edge_mask.shape[0], 1, 2) - other_edge_index[torch.any(edge_mask, dim=1)] = repeat_indices[edge_mask] + # find edge_index which contains other_node_index[common_node_indices]. this is + # because all new edges must be to new nodes (unique). 
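+            # Concretely: every edge endpoint equal to a colliding id is
+            # rewritten in place to its freshly issued id before the merge.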
+ edge_mask = other_edge_index[:, :, None] == other_node_index[None, common_node_indices] + repeat_indices = new_indices[None, None].repeat(edge_mask.shape[0], 2, 1) + other_edge_index[torch.any(edge_mask, dim=-1)] = repeat_indices[edge_mask] other_node_index[common_node_indices] = new_indices self.tensor["node_index"] = torch.cat( diff --git a/testing/test_states.py b/testing/test_states.py new file mode 100644 index 00000000..921971eb --- /dev/null +++ b/testing/test_states.py @@ -0,0 +1,50 @@ +from gfn.states import GraphStates +from tensordict import TensorDict +import torch + +def make_graph_states(n_graphs, n_nodes, n_edges): + batch_ptr = torch.cat([torch.zeros(1), n_nodes.cumsum(0)]).int() + node_feature = torch.randn(batch_ptr[-1].item(), 10) + node_index = torch.arange(0, batch_ptr[-1].item()) + + edge_features = torch.randn(n_edges.sum(), 10) + edge_index = [] + for i, (start, end) in enumerate(zip(batch_ptr[:-1], batch_ptr[1:])): + edge_index.append(torch.randint(start, end, (n_edges[i], 2))) + edge_index = torch.cat(edge_index) + + return GraphStates(TensorDict({ + "node_feature": node_feature, + "edge_feature": edge_features, + "edge_index": edge_index, + "node_index": node_index, + "batch_ptr": batch_ptr, + "batch_shape": torch.tensor([n_graphs]), + })) + + +def test_get_set(): + n_graphs = 10 + n_nodes = torch.randint(1, 10, (n_graphs,)) + n_edges = torch.randint(1, 10, (n_graphs,)) + graphs = make_graph_states(10, n_nodes, n_edges) + assert not graphs[0]._compare(graphs[9].tensor) + last_graph = graphs[9] + graphs = graphs[:-1] + graphs[0] = last_graph + assert graphs[0]._compare(last_graph.tensor) + + +def test_stack(): + GraphStates.s0 = make_graph_states(1, torch.tensor([1]), torch.tensor([0])).tensor + n_graphs = 10 + n_nodes = torch.randint(1, 10, (n_graphs,)) + n_edges = torch.randint(1, 10, (n_graphs,)) + graphs = make_graph_states(10, n_nodes, n_edges) + stacked_graphs = GraphStates.stack([graphs[0], graphs[1]]) + assert stacked_graphs.batch_shape == (2, 1) + assert stacked_graphs[0]._compare(graphs[0].tensor) + assert stacked_graphs[1]._compare(graphs[1].tensor) + +if __name__ == "__main__": + test_get_set() \ No newline at end of file diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index e245289d..39bbd170 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -115,7 +115,6 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: ) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - # edge_actions = edge_index.reshape( *states_tensor["batch_shape"], self.n_nodes * self.n_nodes ) @@ -178,6 +177,7 @@ def forward_masks(self): forward_masks[:, 1 :: self.n_nodes + 1] = False for i in range(len(self)): existing_edges = self[i].tensor["edge_index"] - self.tensor['node_index'][self.tensor['batch_ptr'][i]] + assert torch.all(existing_edges >= 0) # TODO: convert to test. forward_masks[ i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], @@ -223,8 +223,13 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions action_type = torch.where( action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE ) - edge_index_i0 = (action_tensor - 1) // (self.n_nodes) - edge_index_i1 = (action_tensor - 1) % (self.n_nodes) + + # TODO: refactor. 
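+        # Two candidate encodings for the discrete action space; the flat
+        # node-pair index below is the one convert_actions implements: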
+ # action = [exit, src_node_idx, dest_node_idx] <- 2n + 1 + # action = [node] <- n^2 + 1 + + edge_index_i0 = (action_tensor - 1) // (self.n_nodes) # column + edge_index_i1 = (action_tensor - 1) % (self.n_nodes) # row edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) offset = states.tensor["node_index"][states.tensor["batch_ptr"][:-1]] @@ -314,11 +319,11 @@ def render_states(states: GraphStates): # ring_state = GraphStates( # TensorDict( # { - # "node_feature": torch.tensor([[0], [1], [2]]), - # "node_index": torch.tensor([0, 1, 2]), - # "edge_feature": torch.ones((3, 1)), - # "edge_index": torch.tensor([[1, 0], [1, 2], [2, 0]]), - # "batch_ptr": torch.tensor([0, 3]), + # "node_feature": torch.tensor([[10], [11]]), + # "node_index": torch.tensor([10, 11]), + # "edge_feature": torch.ones((2, 1)), + # "edge_index": torch.tensor([[10, 11], [11, 10]]), + # "batch_ptr": torch.tensor([0, 2]), # "batch_shape": torch.ones((1,), dtype=torch.long), # }, # batch_size=(), @@ -326,8 +331,8 @@ def render_states(states: GraphStates): # ) # print(state_evaluator(ring_state)) - N_NODES = 3 - N_ITERATIONS = 128 + N_NODES = 2 + N_ITERATIONS = 4096 torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) @@ -337,17 +342,15 @@ def render_states(states: GraphStates): ) gflownet = FMGFlowNet(logf_estimator) - optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-2) - batch_size = 32 - # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - # optimizer, T_max=N_ITERATIONS, eta_min=1e-4 - # ) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3) + batch_size = 4 losses = [] t1 = time.time() for iteration in range(N_ITERATIONS): trajectories = gflownet.sample_trajectories(env, n=batch_size) + print(torch.count_nonzero(state_evaluator(trajectories.last_states) > 0.1)) samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() loss = gflownet.loss(env, samples) @@ -355,7 +358,6 @@ def render_states(states: GraphStates): loss.backward() optimizer.step() losses.append(loss.item()) - # scheduler.step() t2 = time.time() print("Time:", t2 - t1) From 1fa01df708a89a725293f72e2cdd90fbf6041927 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 6 Feb 2025 21:45:43 +0100 Subject: [PATCH 056/102] fix node_index --- src/gfn/states.py | 2 +- tutorials/examples/train_graph_ring.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index e38231d3..25e5acbf 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -806,7 +806,7 @@ def extend(self, other: GraphStates): if torch.any(common_node_indices): # This renumbers nodes across batch indices such that all nodes have # a unique ID. - new_indices = self.unique_node_indices(torch.sum(common_node_indices)) + new_indices = GraphStates.unique_node_indices(torch.sum(common_node_indices)) # find edge_index which contains other_node_index[common_node_indices]. this is # because all new edges must be to new nodes (unique). 
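At this point in the series, convert_actions decodes a flat integer action as: 0 = EXIT, and 1 + i * n_nodes + j = add edge (i, j), with the local pair re-based onto the graph's global node_index. A minimal sketch of that decode step, assuming the n^2 + 1 layout above (decode_ring_action is an illustrative helper, not library API):

    import torch

    def decode_ring_action(action: torch.Tensor, n_nodes: int, offset: torch.Tensor):
        # Action 0 is EXIT; actions 1 .. n_nodes**2 enumerate ordered (src, dst) pairs.
        is_exit = action == 0
        src = (action - 1) // n_nodes
        dst = (action - 1) % n_nodes
        # Re-base local ids onto each graph's global node_index space.
        edge_index = torch.stack([src, dst], dim=-1) + offset[:, None]
        return is_exit, edge_index

    # For a 3-node graph whose first global node id is 12:
    # action 6 -> local edge (1, 2) -> global edge (13, 14).
    is_exit, edge = decode_ring_action(torch.tensor([6]), 3, torch.tensor([12]))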
diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 39bbd170..26eb59b2 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -211,8 +211,7 @@ def backward_masks(self, value: torch.Tensor): def _step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(states, actions) - out = super()._step(states, actions) - return out + return super()._step(states, actions) def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(states, actions) @@ -331,8 +330,8 @@ def render_states(states: GraphStates): # ) # print(state_evaluator(ring_state)) - N_NODES = 2 - N_ITERATIONS = 4096 + N_NODES = 3 + N_ITERATIONS = 1024 torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) @@ -343,14 +342,15 @@ def render_states(states: GraphStates): gflownet = FMGFlowNet(logf_estimator) optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3) - batch_size = 4 + batch_size = 64 losses = [] t1 = time.time() for iteration in range(N_ITERATIONS): trajectories = gflownet.sample_trajectories(env, n=batch_size) - print(torch.count_nonzero(state_evaluator(trajectories.last_states) > 0.1)) + rews = state_evaluator(trajectories.last_states) + print(f"Percentage of rings sampled {torch.mean(rews > 1.1, dtype=torch.float) * 100:.0f}%") samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() loss = gflownet.loss(env, samples) From 48ceafd01a7e1689f0db5eb2035f01333baee450 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 10 Feb 2025 13:53:43 +0100 Subject: [PATCH 057/102] add comments --- src/gfn/gflownet/flow_matching.py | 4 +- tutorials/examples/train_graph_ring.py | 156 +++++++++++++------------ 2 files changed, 82 insertions(+), 78 deletions(-) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index e5fd652d..0540cb36 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch @@ -157,7 +157,7 @@ def reward_matching_loss( self, env: Env, terminating_states: DiscreteStates, - conditioning: torch.Tensor, + conditioning: Optional[torch.Tensor], ) -> torch.Tensor: """Calculates the reward matching loss from the terminating states.""" del env # Unused diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 26eb59b2..1304ecae 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -9,7 +9,6 @@ from tensordict import TensorDict from torch import nn import torch.nn.functional as F -from torch_geometric.nn import GCNConv from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.flow_matching import FMGFlowNet @@ -20,6 +19,16 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: + """Compute the reward of a graph. + + Specifically, the reward is 1 if the graph is a ring, 1e-6 otherwise. + + Args: + states: A batch of graphs. + + Returns: + A tensor of rewards. 
+ """ eps = 1e-6 if states.tensor["edge_index"].shape[0] == 0: return torch.full(states.batch_shape, eps) @@ -38,15 +47,9 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: adj_matrix = torch.zeros(n_nodes, n_nodes) adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 - # # Matrix must be symmetric (undirected graph). - # if not torch.all(adj_matrix == adj_matrix.T): - # continue - # Each vertex must have exactly degree 2 (sum of each row = 2). if not torch.all(adj_matrix.sum(axis=1) == 1): continue - - # Connectivity check: Start at vertex 0 and follow edges, keep track - # of visited edges, visit all edges once, end at vertex 0. + visited = [] current = 0 while current not in visited: @@ -76,53 +79,46 @@ def set_diff(tensor1, tensor2): class RingPolicyEstimator(nn.Module): - def __init__( - self, - n_nodes: int, - action_hidden_dim: int = 16, - edge_hidden_dim: int = 16, - ): + """Simple module which outputs a fixed logits for the actions, depending on the number of edges. + + Args: + n_nodes: The number of nodes in the graph. + """ + def __init__(self, n_nodes: int): super().__init__() - self.action_type_conv = GCNConv(n_nodes, action_hidden_dim) - self.edge_index_conv = GCNConv(n_nodes, edge_hidden_dim) self.n_nodes = n_nodes - self.edge_hidden_dim = edge_hidden_dim - - def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: - cumsum = torch.zeros( - (len(tensor) + 1, *tensor.shape[1:]), - dtype=tensor.dtype, - device=tensor.device, - ) - cumsum[1:] = torch.cumsum(tensor, dim=0) - - # Subtract the end val from each batch idx fom the start val of each batch idx. - return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] + self.params = nn.Parameter(nn.init.xavier_normal_(torch.empty(self.n_nodes + 1, self.n_nodes * self.n_nodes + 1))) + with torch.no_grad(): + self.params[-1] = torch.zeros(self.n_nodes * self.n_nodes + 1) + self.params[-1][-1] = 1.0 def forward(self, states_tensor: TensorDict) -> torch.Tensor: - node_feature, batch_ptr = states_tensor["node_feature"], states_tensor["batch_ptr"] - - edge_index = torch.where( - states_tensor["edge_index"][..., None] == states_tensor["node_index"] - )[2].reshape(states_tensor["edge_index"].shape) # (M, 2) - - action_type = self.action_type_conv(node_feature, edge_index.T) - action_type = self._group_sum(torch.mean(action_type, dim=-1, keepdim=True), batch_ptr) - - edge_index = self.edge_index_conv(node_feature, edge_index.T) - edge_index = edge_index.reshape( - *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim - ) - edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - - edge_actions = edge_index.reshape( - *states_tensor["batch_shape"], self.n_nodes * self.n_nodes - ) + batch_size = states_tensor["batch_shape"][0] + if batch_size == 0: + return torch.zeros(0, self.n_nodes * self.n_nodes + 1, requires_grad=True) + + first_idx = states_tensor["node_index"][states_tensor["batch_ptr"][:-1]] + last_idx = states_tensor["node_index"][states_tensor["batch_ptr"][1:] - 1] + n_edges = torch.logical_and( + states_tensor["edge_index"] >= first_idx[:, None, None], + states_tensor["edge_index"] <= last_idx[:, None, None], + ).all(dim=-1).sum(dim=-1) - return torch.cat([action_type, edge_actions], dim=-1) + n_edges = torch.clamp(n_edges, 0, self.n_nodes) + return self.params[n_edges] class RingGraphBuilding(GraphBuilding): + """Override the GraphBuilding class to create have discrete actions. + + Specifically, at initialization, we have n nodes. 
+ The policy can only add edges between existing nodes or use the exit action. + The action space is thus discrete and of size n^2 + 1, where the last action is the exit action, + and the first n^2 actions are the possible edges. + + Args: + n_nodes: The number of nodes in the graph. + """ def __init__(self, n_nodes: int): self.n_nodes = n_nodes self.n_actions = 1 + n_nodes * n_nodes @@ -135,9 +131,7 @@ def make_actions_class(self) -> type[Actions]: class RingActions(Actions): action_shape = (1,) dummy_action = torch.tensor([env.n_actions]) - exit_action = torch.zeros( - 1, - ) + exit_action = torch.tensor([env.n_actions - 1]) return RingActions @@ -174,13 +168,13 @@ def __init__(self, tensor: TensorDict): @property def forward_masks(self): forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) - forward_masks[:, 1 :: self.n_nodes + 1] = False + forward_masks[:, ::self.n_nodes + 1] = False for i in range(len(self)): existing_edges = self[i].tensor["edge_index"] - self.tensor['node_index'][self.tensor['batch_ptr'][i]] assert torch.all(existing_edges >= 0) # TODO: convert to test. forward_masks[ i, - 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], ] = False return forward_masks.view(*self.batch_shape, self.n_actions) @@ -198,7 +192,7 @@ def backward_masks(self): existing_edges = self[i].tensor["edge_index"] - self.tensor['node_index'][self.tensor['batch_ptr'][i]] backward_masks[ i, - 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], ] = True return backward_masks.view(*self.batch_shape, self.n_actions) @@ -220,15 +214,15 @@ def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions: action_tensor = actions.tensor.squeeze(-1) action_type = torch.where( - action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE + action_tensor == self.n_actions - 1, GraphActionType.EXIT, GraphActionType.ADD_EDGE ) # TODO: refactor. # action = [exit, src_node_idx, dest_node_idx] <- 2n + 1 # action = [node] <- n^2 + 1 - edge_index_i0 = (action_tensor - 1) // (self.n_nodes) # column - edge_index_i1 = (action_tensor - 1) % (self.n_nodes) # row + edge_index_i0 = action_tensor // (self.n_nodes) # column + edge_index_i1 = action_tensor % (self.n_nodes) # row edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) offset = states.tensor["node_index"][states.tensor["batch_ptr"][:-1]] @@ -245,6 +239,7 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions class GraphPreprocessor(Preprocessor): + """Extract the tensor from the states.""" def __init__(self, feature_dim: int = 1): super().__init__(output_dim=feature_dim) @@ -256,6 +251,11 @@ def __call__(self, states: GraphStates) -> torch.Tensor: def render_states(states: GraphStates): + """Render the states as a matplotlib plot. + + Args: + states: A batch of graphs. 
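+
+        The first eight states are drawn on a 2 x 4 grid.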
+ """ rewards = state_evaluator(states) fig, ax = plt.subplots(2, 4, figsize=(15, 7)) for i in range(8): @@ -314,24 +314,29 @@ def render_states(states: GraphStates): plt.show() -if __name__ == "__main__": - # ring_state = GraphStates( - # TensorDict( - # { - # "node_feature": torch.tensor([[10], [11]]), - # "node_index": torch.tensor([10, 11]), - # "edge_feature": torch.ones((2, 1)), - # "edge_index": torch.tensor([[10, 11], [11, 10]]), - # "batch_ptr": torch.tensor([0, 2]), - # "batch_shape": torch.ones((1,), dtype=torch.long), - # }, - # batch_size=(), - # ) - # ) - # print(state_evaluator(ring_state)) +def test(): + module = RingPolicyEstimator(4) + + states = GraphStates( + TensorDict({ + "node_feature": torch.eye(4), + "edge_feature": torch.zeros((0, 1)), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + "node_index": torch.arange(4), + "batch_ptr": torch.tensor([0, 4]), + "batch_shape": torch.tensor([1]), + }) + ) + out = module(states.tensor) + loss = torch.sum(out) + loss.backward() + print("Params:", [p for p in module.params]) + print("Gradients:", [p.grad for p in module.params]) + +if __name__ == "__main__": N_NODES = 3 - N_ITERATIONS = 1024 + N_ITERATIONS = 8192 torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) @@ -341,8 +346,8 @@ def render_states(states: GraphStates): ) gflownet = FMGFlowNet(logf_estimator) - optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3) - batch_size = 64 + optimizer = torch.optim.RMSprop(gflownet.parameters(), lr=0.1, momentum=0.9) + batch_size = 8 losses = [] @@ -350,11 +355,10 @@ def render_states(states: GraphStates): for iteration in range(N_ITERATIONS): trajectories = gflownet.sample_trajectories(env, n=batch_size) rews = state_evaluator(trajectories.last_states) - print(f"Percentage of rings sampled {torch.mean(rews > 1.1, dtype=torch.float) * 100:.0f}%") samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() loss = gflownet.loss(env, samples) - print("Iteration", iteration, "Loss:", loss.item()) + print("Iteration", iteration, "Loss:", loss.item(), f"rings: {torch.mean(rews > 0.1, dtype=torch.float) * 100:.0f}%") loss.backward() optimizer.step() losses.append(loss.item()) From 22717321255c257b82da0686b464b2bdbaa9d960 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 10 Feb 2025 16:55:20 +0100 Subject: [PATCH 058/102] fix linter --- src/gfn/actions.py | 4 +- src/gfn/gflownet/flow_matching.py | 4 +- src/gfn/gym/graph_building.py | 10 +- src/gfn/modules.py | 4 +- src/gfn/samplers.py | 28 ++--- src/gfn/states.py | 132 +++++++++++++--------- src/gfn/utils/distributions.py | 13 ++- testing/test_actions.py | 4 +- testing/test_samplers_and_trajectories.py | 82 -------------- testing/test_states.py | 24 ++-- tutorials/examples/train_graph_ring.py | 113 ++++++++++++------ 11 files changed, 211 insertions(+), 207 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 24159538..b85691a7 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -237,13 +237,13 @@ def __repr__(self): return f"""GraphAction object with {self.batch_shape} actions.""" @property - def device(self) -> torch.device: + def device(self) -> torch.device | None: """Returns the device of the features tensor.""" return self.tensor.device def __len__(self) -> int: """Returns the number of actions in the batch.""" - return prod(self.batch_shape) + return int(prod(self.batch_shape)) def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: 
"""Get particular actions of the batch.""" diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 1571f383..d879916d 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -169,11 +169,11 @@ def reward_matching_loss( else: with no_conditioning_exception_handler("logF", self.logF): log_edge_flows = self.logF(terminating_states) - + # Handle the boundary condition (for all x, F(X->S_f) = R(x)). terminating_log_edge_flows = log_edge_flows[:, -1] log_rewards = terminating_states.log_rewards - + return (terminating_log_edge_flows - log_rewards).pow(2).mean() def loss( diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 1277d727..fbde436b 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -56,6 +56,12 @@ def __init__( device_str=device_str, ) + def reset(self, batch_shape: Tuple | int) -> GraphStates: + """Reset the environment to a new batch of graphs.""" + states = super().reset(batch_shape) + assert isinstance(states, GraphStates) + return states + def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: """Step function for the GraphBuilding environment. @@ -157,11 +163,11 @@ def is_action_valid( if torch.any(add_edge_actions[:, 0] == add_edge_actions[:, 1]): return False if add_edge_states["node_feature"].shape[0] == 0: - return False + return False node_exists = torch.isin(add_edge_actions, add_edge_states["node_index"]) if not torch.all(node_exists): return False - + equal_edges_per_batch = torch.all( add_edge_states["edge_index"] == add_edge_actions[:, None], dim=-1, diff --git a/src/gfn/modules.py b/src/gfn/modules.py index a2ef72d6..5ac7b77f 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -536,7 +536,9 @@ def to_probability_distribution( 1 - epsilon ) * edge_index_probs + epsilon * uniform_dist_probs edge_index_probs[torch.isnan(edge_index_probs)] = 1 - dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs, node_indexes=states.tensor["node_index"]) + dists["edge_index"] = CategoricalIndexes( + probs=edge_index_probs, node_indexes=states.tensor["node_index"] + ) dists["features"] = Normal(module_output["features"], temperature) return CompositeDistribution(dists=dists) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 800c5bef..73997dae 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -245,32 +245,22 @@ def sample_trajectories( dones = dones | new_dones trajectories_states.append(deepcopy(states)) - trajectories_states = env.States.stack(trajectories_states) - trajectories_actions = env.Actions.stack(trajectories_actions)[ - 1: # Drop dummy action - ] # pyright: ignore - trajectories_logprobs = ( - torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob - if save_logprobs - else None - ) # pyright: ignore - - # TODO: use torch.nested.nested_tensor(dtype, device, requires_grad). 
- if save_estimator_outputs: - all_estimator_outputs = torch.stack(all_estimator_outputs, dim=0) - # TODO: do not ignore the next ignores trajectories = Trajectories( env=env, - states=trajectories_states, # pyright: ignore + states=env.States.stack(trajectories_states), conditioning=conditioning, - actions=trajectories_actions, # pyright: ignore + actions=env.Actions.stack(trajectories_actions[1:]), when_is_done=trajectories_dones, is_backward=self.estimator.is_backward, log_rewards=trajectories_log_rewards, - log_probs=trajectories_logprobs, # pyright: ignore + log_probs=( + torch.stack(trajectories_logprobs, dim=0)[1:] if save_logprobs else None + ), estimator_outputs=( - all_estimator_outputs if save_estimator_outputs else None - ), # pyright: ignore + torch.stack(all_estimator_outputs, dim=0) + if save_estimator_outputs + else None + ), ) return trajectories diff --git a/src/gfn/states.py b/src/gfn/states.py index 7e4150d4..b143539d 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -282,7 +282,7 @@ def sample(self, n_samples: int) -> States: return self[torch.randperm(len(self))[:n_samples]] @classmethod - def stack(cls, states: List[States]): + def stack(cls, states: Sequence[States]) -> States: """Given a list of states, stacks them along a new dimension (0).""" state_example = states[0] assert all( @@ -292,18 +292,12 @@ def stack(cls, states: List[States]): stacked_states = state_example.from_batch_shape((0, 0)) # Empty. stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) if state_example._log_rewards: - stacked_states._log_rewards = torch.stack( - [s._log_rewards for s in states], dim=0 - ) - - # We are dealing with a list of DiscretrStates instances. - if hasattr(state_example, "forward_masks"): - stacked_states.forward_masks = torch.stack( - [s.forward_masks for s in states], dim=0 - ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 - ) + log_rewards = [] + for s in states: + if s._log_rewards is None: + raise ValueError("Some states have no log rewards.") + log_rewards.append(s._log_rewards) + stacked_states._log_rewards = torch.stack(log_rewards, dim=0) # Adds the trajectory dimension. 
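+        # The result gains a leading axis of length len(states): batch_shape
+        # becomes (n_stacked, *state_batch_shape).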
stacked_states.batch_shape = ( @@ -498,6 +492,15 @@ def init_forward_masks(self, set_ones: bool = True): else: self.forward_masks = torch.zeros(shape).bool() + @classmethod + def stack(cls, states: List[DiscreteStates]) -> DiscreteStates: + """Stacks a list of DiscreteStates objects along a new dimension (0).""" + out = super().stack(states) + assert isinstance(out, DiscreteStates) + out.forward_masks = torch.stack([s.forward_masks for s in states], dim=0) + out.backward_masks = torch.stack([s.backward_masks for s in states], dim=0) + return out + class GraphStates(States): """ @@ -529,7 +532,7 @@ def __init__(self, tensor: TensorDict): self.tensor = tensor self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] - self._log_rewards: Optional[float] = None + self._log_rewards: Optional[torch.Tensor] = None @property def batch_shape(self) -> tuple: @@ -553,7 +556,7 @@ def from_batch_shape( def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) nodes = cls.s0["node_feature"].repeat(np.prod(batch_shape), 1) - + return TensorDict( { "node_feature": nodes, @@ -602,7 +605,9 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: "node_feature": torch.rand( np.prod(batch_shape) * num_nodes, node_features_dim, device=device ), - "node_index": GraphStates.unique_node_indices(np.prod(batch_shape) * num_nodes), + "node_index": GraphStates.unique_node_indices( + np.prod(batch_shape) * num_nodes + ), "edge_feature": torch.rand( np.prod(batch_shape) * num_edges, edge_features_dim, device=device ), @@ -628,14 +633,14 @@ def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: tensor_idx = torch.arange(len(self)).view(*self.batch_shape) - index = tensor_idx[index].flatten() + idx = tensor_idx[index].flatten() - if torch.any(index >= len(self.tensor["batch_ptr"]) - 1): + if torch.any(idx >= len(self.tensor["batch_ptr"]) - 1): raise ValueError("Graph index out of bounds") # TODO: explain batch_ptr and node_index semantics - start_ptrs = self.tensor["batch_ptr"][:-1][index] - end_ptrs = self.tensor["batch_ptr"][1:][index] + start_ptrs = self.tensor["batch_ptr"][:-1][idx] + end_ptrs = self.tensor["batch_ptr"][1:][idx] node_features = [torch.empty(0, self.node_features_dim)] node_indices = [torch.empty(0, dtype=torch.long)] @@ -650,8 +655,11 @@ def __getitem__( # Find edges for this graph if self.tensor["node_index"].numel() > 0: - edge_mask = (self.tensor["edge_index"][:, 0] >= self.tensor["node_index"][start]) & ( - self.tensor["edge_index"][:, 0] <= self.tensor["node_index"][end - 1] + edge_mask = ( + self.tensor["edge_index"][:, 0] >= self.tensor["node_index"][start] + ) & ( + self.tensor["edge_index"][:, 0] + <= self.tensor["node_index"][end - 1] ) edge_features.append(self.tensor["edge_feature"][edge_mask]) edge_indices.append(self.tensor["edge_index"][edge_mask]) @@ -664,15 +672,17 @@ def __getitem__( "edge_feature": torch.cat(edge_features), "edge_index": torch.cat(edge_indices), "batch_ptr": torch.tensor(batch_ptr), - "batch_shape": (len(index),), + "batch_shape": (len(idx),), } ) ) if self._log_rewards is not None: - out._log_rewards = self._log_rewards[index] + out._log_rewards = self._log_rewards[idx] - assert out.tensor["node_index"].unique().numel() == len(out.tensor["node_index"]) + assert out.tensor["node_index"].unique().numel() == len( + out.tensor["node_index"] + ) return out 
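The idx bookkeeping above is the linear-indexing trick used throughout GraphStates: an arange shaped like the batch is indexed with whatever the caller passed, then flattened, so ints, slices, and sequences all reduce to a flat list of graph positions. A standalone sketch (shapes chosen for illustration only):

    import torch

    batch_shape = (2, 3)  # e.g. (trajectory, batch) dimensions
    tensor_idx = torch.arange(6).view(*batch_shape)

    print(tensor_idx[1].flatten())     # row 1        -> tensor([3, 4, 5])
    print(tensor_idx[:, 0].flatten())  # column 0     -> tensor([0, 3])
    print(tensor_idx[1, 2].flatten())  # single entry -> tensor([5])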
@@ -681,11 +691,11 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): Set particular states of the Batch """ # This is to convert index to type int (linear indexing). - tensor_idx = torch.arange(len(self)).view(*self.batch_shape) - index = tensor_idx[index].flatten() + idx = torch.arange(len(self)).view(*self.batch_shape) + idx = idx[index].flatten() # Validate indices - if torch.any(index >= len(self.tensor["batch_ptr"]) - 1): + if torch.any(idx >= len(self.tensor["batch_ptr"]) - 1): raise ValueError("Target graph index out of bounds") # Source graph details @@ -693,12 +703,12 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): source_num_graphs = torch.prod(source_tensor_dict["batch_shape"]) # Validate source and target indices match - if len(index) != source_num_graphs: + if len(idx) != source_num_graphs: raise ValueError( "Number of source graphs must match number of target indices" ) - for i, graph_idx in enumerate(index): + for i, graph_idx in enumerate(idx): # Get start and end pointers for the current graph start_ptr = self.tensor["batch_ptr"][graph_idx] end_ptr = self.tensor["batch_ptr"][graph_idx + 1] @@ -730,19 +740,29 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): edge_mask = torch.empty(0, dtype=torch.bool) if self.tensor["edge_index"].numel() > 0: - edge_mask = torch.all(self.tensor["edge_index"] > self.tensor["node_index"][end_ptr - 1], dim=-1) - edge_mask |= torch.all(self.tensor["edge_index"] < self.tensor["node_index"][start_ptr], dim=-1) - + edge_mask = torch.all( + self.tensor["edge_index"] > self.tensor["node_index"][end_ptr - 1], + dim=-1, + ) + edge_mask |= torch.all( + self.tensor["edge_index"] < self.tensor["node_index"][start_ptr], + dim=-1, + ) + edge_to_add_mask = torch.all( - source_tensor_dict["edge_index"] >= source_tensor_dict["node_index"][source_start_ptr], dim=-1 + source_tensor_dict["edge_index"] + >= source_tensor_dict["node_index"][source_start_ptr], + dim=-1, ) edge_to_add_mask &= torch.all( - source_tensor_dict["edge_index"] <= source_tensor_dict["node_index"][source_end_ptr - 1], dim=-1 + source_tensor_dict["edge_index"] + <= source_tensor_dict["node_index"][source_end_ptr - 1], + dim=-1, ) self.tensor["edge_index"] = torch.cat( [ self.tensor["edge_index"][edge_mask], - source_tensor_dict["edge_index"][edge_to_add_mask] + source_tensor_dict["edge_index"][edge_to_add_mask], ], dim=0, ) @@ -764,8 +784,10 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): # Update batch pointers shift = new_nodes.shape[0] - (end_ptr - start_ptr) self.tensor["batch_ptr"][graph_idx + 1 :] += shift - - assert self.tensor["node_index"].unique().numel() == len(self.tensor["node_index"]) + + assert self.tensor["node_index"].unique().numel() == len( + self.tensor["node_index"] + ) @property def device(self) -> torch.device | None: @@ -791,15 +813,22 @@ def extend(self, other: GraphStates): # find if there are common node indices other_node_index = other.tensor["node_index"] other_edge_index = other.tensor["edge_index"] - common_node_indices = torch.any(self.tensor["node_index"][:, None] == other_node_index[None, :], dim=0) + common_node_indices = torch.any( + self.tensor["node_index"][:, None] == other_node_index[None, :], dim=0 + ) if torch.any(common_node_indices): - # This renumbers nodes across batch indices such that all nodes have + # This renumbers nodes across batch indices such that all nodes have # a unique ID. 
- new_indices = GraphStates.unique_node_indices(torch.sum(common_node_indices)) - - # find edge_index which contains other_node_index[common_node_indices]. this is + new_indices = GraphStates.unique_node_indices( + torch.sum(common_node_indices) + ) + + # find edge_index which contains other_node_index[common_node_indices]. this is # because all new edges must be to new nodes (unique). - edge_mask = other_edge_index[:, :, None] == other_node_index[None, common_node_indices] + edge_mask = ( + other_edge_index[:, :, None] + == other_node_index[None, common_node_indices] + ) repeat_indices = new_indices[None, None].repeat(edge_mask.shape[0], 2, 1) other_edge_index[torch.any(edge_mask, dim=-1)] = repeat_indices[edge_mask] other_node_index[common_node_indices] = new_indices @@ -811,10 +840,7 @@ def extend(self, other: GraphStates): [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 ) self.tensor["edge_index"] = torch.cat( - [ - self.tensor["edge_index"], - other.tensor["edge_index"] - ], + [self.tensor["edge_index"], other.tensor["edge_index"]], dim=0, ) self.tensor["batch_ptr"] = torch.cat( @@ -882,7 +908,9 @@ def stack(cls, states: List[GraphStates]): stacked_states.extend(state) stacked_states.tensor["batch_shape"] = (len(states),) + state_batch_shape - assert stacked_states.tensor["node_index"].unique().numel() == len(stacked_states.tensor["node_index"]) + assert stacked_states.tensor["node_index"].unique().numel() == len( + stacked_states.tensor["node_index"] + ) return stacked_states @property @@ -992,6 +1020,8 @@ def backward_masks(self) -> TensorDict: @classmethod def unique_node_indices(cls, num_new_nodes: int) -> torch.Tensor: - indices = torch.arange(cls._next_node_index, cls._next_node_index + num_new_nodes) + indices = torch.arange( + cls._next_node_index, cls._next_node_index + num_new_nodes + ) cls._next_node_index += num_new_nodes - return indices \ No newline at end of file + return indices diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index cfb9b96b..48db3f58 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -81,12 +81,21 @@ def __init__(self, probs: torch.Tensor, node_indexes: torch.Tensor): n: The number of nodes in the graph. 
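+            node_indexes: Global ids of the graph's nodes; the distribution's
+                support has size len(node_indexes) ** 2, one entry per ordered pair.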
""" self.node_indexes = node_indexes - assert probs.shape == (probs.shape[0], node_indexes.shape[0] * node_indexes.shape[0]) + assert probs.shape == ( + probs.shape[0], + node_indexes.shape[0] * node_indexes.shape[0], + ) super().__init__(probs) def sample(self, sample_shape=torch.Size()) -> torch.Tensor: samples = super().sample(sample_shape) - out = torch.stack([samples // self.node_indexes.shape[0], samples % self.node_indexes.shape[0]], dim=-1) + out = torch.stack( + [ + samples // self.node_indexes.shape[0], + samples % self.node_indexes.shape[0], + ], + dim=-1, + ) out = self.node_indexes.index_select(0, out.flatten()).reshape(*out.shape) return out diff --git a/testing/test_actions.py b/testing/test_actions.py index 2058d0d8..03e26a94 100644 --- a/testing/test_actions.py +++ b/testing/test_actions.py @@ -13,7 +13,7 @@ class ContinuousActions(Actions): exit_action = torch.ones(10) -class GraphActions(GraphActions): +class TestGraphActions(GraphActions): features_dim = 10 @@ -24,7 +24,7 @@ def continuous_action(): @pytest.fixture def graph_action(): - return GraphActions( + return TestGraphActions( tensor=TensorDict( { "action_type": torch.zeros((1,), dtype=torch.float32), diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 2630706e..44288e94 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -352,88 +352,6 @@ def test_replay_buffer( except Exception as e: raise ValueError(f"Error while testing {env_name}") from e -@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM"]) -def test_reverse_backward_trajectories(env_name: str): - """ - Ensures that the vectorized `Trajectories.reverse_backward_trajectories` - matches the for-loop approach by toggling `debug=True`. - - Note that `Trajectories.reverse_backward_trajectories` is not compatible with - environment with continuous states (e.g., Box). - """ - _, backward_trajectories, *_ = trajectory_sampling_with_return( - env_name, - preprocessor_name="Identity", - delta=0.1, - n_components=1, - n_components_s0=1, - ) - try: - _ = Trajectories.reverse_backward_trajectories( - backward_trajectories, debug=True # <--- TRIGGER THE COMPARISON - ) - except Exception as e: - raise ValueError( - f"Error while testing Trajectories.reverse_backward_trajectories in {env_name}" - ) from e - - -@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM"]) -def test_local_search_for_loop_equivalence(env_name): - """ - Ensures that the vectorized `LocalSearchSampler.local_search` matches - the for-loop approach by toggling `debug=True`. - - Note that this is not supported for environment with continuous state - space (e.g., Box), since `Trajectories.reverse_backward_trajectories` - is not compatible with continuous states. 
- """ - # Build environment - if env_name == "HyperGrid": - env = HyperGrid(ndim=2, height=5, preprocessor_name="KHot") - elif env_name == "DiscreteEBM": - env = DiscreteEBM(ndim=5) - else: - raise ValueError("Unknown environment name") - - # Build pf & pb - pf_module = MLP(env.preprocessor.output_dim, env.n_actions) - pb_module = MLP(env.preprocessor.output_dim, env.n_actions - 1) - pf_estimator = DiscretePolicyEstimator( - module=pf_module, - n_actions=env.n_actions, - is_backward=False, - preprocessor=env.preprocessor, - ) - pb_estimator = DiscretePolicyEstimator( - module=pb_module, - n_actions=env.n_actions, - is_backward=True, - preprocessor=env.preprocessor, - ) - - sampler = LocalSearchSampler(pf_estimator=pf_estimator, pb_estimator=pb_estimator) - - # Initial forward-sampler call - trajectories = sampler.sample_trajectories(env, n=3, save_logprobs=True) - - # Now run local_search in debug mode so that for-loop logic is compared - # to the vectorized logic. - # If there’s any mismatch, local_search() will raise AssertionError - try: - new_trajectories, is_updated = sampler.local_search( - env, - trajectories, - save_logprobs=True, - back_ratio=0.5, - use_metropolis_hastings=True, - debug=True, # <--- TRIGGER THE COMPARISON - ) - except Exception as e: - raise ValueError( - f"Error while testing LocalSearchSampler.local_search in {env_name}" - ) from e - # ------ GRAPH TESTS ------ diff --git a/testing/test_states.py b/testing/test_states.py index 921971eb..2961ab97 100644 --- a/testing/test_states.py +++ b/testing/test_states.py @@ -2,6 +2,7 @@ from tensordict import TensorDict import torch + def make_graph_states(n_graphs, n_nodes, n_edges): batch_ptr = torch.cat([torch.zeros(1), n_nodes.cumsum(0)]).int() node_feature = torch.randn(batch_ptr[-1].item(), 10) @@ -13,14 +14,18 @@ def make_graph_states(n_graphs, n_nodes, n_edges): edge_index.append(torch.randint(start, end, (n_edges[i], 2))) edge_index = torch.cat(edge_index) - return GraphStates(TensorDict({ - "node_feature": node_feature, - "edge_feature": edge_features, - "edge_index": edge_index, - "node_index": node_index, - "batch_ptr": batch_ptr, - "batch_shape": torch.tensor([n_graphs]), - })) + return GraphStates( + TensorDict( + { + "node_feature": node_feature, + "edge_feature": edge_features, + "edge_index": edge_index, + "node_index": node_index, + "batch_ptr": batch_ptr, + "batch_shape": torch.tensor([n_graphs]), + } + ) + ) def test_get_set(): @@ -46,5 +51,6 @@ def test_stack(): assert stacked_graphs[0]._compare(graphs[0].tensor) assert stacked_graphs[1]._compare(graphs[1].tensor) + if __name__ == "__main__": - test_get_set() \ No newline at end of file + test_get_set() diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 1304ecae..58f3a09f 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -4,6 +4,7 @@ import time from typing import Optional +from matplotlib import patches import matplotlib.pyplot as plt import torch from tensordict import TensorDict @@ -20,7 +21,7 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: """Compute the reward of a graph. - + Specifically, the reward is 1 if the graph is a ring, 1e-6 otherwise. 
Args: @@ -41,7 +42,9 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: edge_index_mask = torch.all( states.tensor["edge_index"] >= nodes_index_range[0], dim=-1 ) & torch.all(states.tensor["edge_index"] <= nodes_index_range[-1], dim=-1) - masked_edge_index = states.tensor["edge_index"][edge_index_mask] - nodes_index_range[0] + masked_edge_index = ( + states.tensor["edge_index"][edge_index_mask] - nodes_index_range[0] + ) n_nodes = nodes_index_range.shape[0] adj_matrix = torch.zeros(n_nodes, n_nodes) @@ -70,7 +73,7 @@ def set_diff(tensor1, tensor2): break else: break # TODO: This actually should never happen, should be caught on line 45. - + # Check if we visited all vertices and the last vertex connects back to start. if len(visited) == n_nodes and adj_matrix[current][0] == 1: out[i] = 1.0 @@ -80,14 +83,19 @@ def set_diff(tensor1, tensor2): class RingPolicyEstimator(nn.Module): """Simple module which outputs a fixed logits for the actions, depending on the number of edges. - + Args: n_nodes: The number of nodes in the graph. """ + def __init__(self, n_nodes: int): super().__init__() self.n_nodes = n_nodes - self.params = nn.Parameter(nn.init.xavier_normal_(torch.empty(self.n_nodes + 1, self.n_nodes * self.n_nodes + 1))) + self.params = nn.Parameter( + nn.init.xavier_normal_( + torch.empty(self.n_nodes + 1, self.n_nodes * self.n_nodes + 1) + ) + ) with torch.no_grad(): self.params[-1] = torch.zeros(self.n_nodes * self.n_nodes + 1) self.params[-1][-1] = 1.0 @@ -96,13 +104,17 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: batch_size = states_tensor["batch_shape"][0] if batch_size == 0: return torch.zeros(0, self.n_nodes * self.n_nodes + 1, requires_grad=True) - + first_idx = states_tensor["node_index"][states_tensor["batch_ptr"][:-1]] last_idx = states_tensor["node_index"][states_tensor["batch_ptr"][1:] - 1] - n_edges = torch.logical_and( - states_tensor["edge_index"] >= first_idx[:, None, None], - states_tensor["edge_index"] <= last_idx[:, None, None], - ).all(dim=-1).sum(dim=-1) + n_edges = ( + torch.logical_and( + states_tensor["edge_index"] >= first_idx[:, None, None], + states_tensor["edge_index"] <= last_idx[:, None, None], + ) + .all(dim=-1) + .sum(dim=-1) + ) n_edges = torch.clamp(n_edges, 0, self.n_nodes) return self.params[n_edges] @@ -110,7 +122,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: class RingGraphBuilding(GraphBuilding): """Override the GraphBuilding class to create have discrete actions. - + Specifically, at initialization, we have n nodes. The policy can only add edges between existing nodes or use the exit action. The action space is thus discrete and of size n^2 + 1, where the last action is the exit action, @@ -119,6 +131,7 @@ class RingGraphBuilding(GraphBuilding): Args: n_nodes: The number of nodes in the graph. 
""" + def __init__(self, n_nodes: int): self.n_nodes = n_nodes self.n_actions = 1 + n_nodes * n_nodes @@ -141,7 +154,9 @@ def make_states_class(self) -> type[GraphStates]: class RingStates(GraphStates): s0 = TensorDict( { - "node_feature": F.one_hot(torch.arange(env.n_nodes), num_classes=env.n_nodes).float(), + "node_feature": F.one_hot( + torch.arange(env.n_nodes), num_classes=env.n_nodes + ).float(), "edge_feature": torch.ones((0, 1)), "edge_index": torch.ones((0, 2), dtype=torch.long), }, @@ -168,9 +183,12 @@ def __init__(self, tensor: TensorDict): @property def forward_masks(self): forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) - forward_masks[:, ::self.n_nodes + 1] = False + forward_masks[:, :: self.n_nodes + 1] = False for i in range(len(self)): - existing_edges = self[i].tensor["edge_index"] - self.tensor['node_index'][self.tensor['batch_ptr'][i]] + existing_edges = ( + self[i].tensor["edge_index"] + - self.tensor["node_index"][self.tensor["batch_ptr"][i]] + ) assert torch.all(existing_edges >= 0) # TODO: convert to test. forward_masks[ i, @@ -189,7 +207,10 @@ def backward_masks(self): len(self), self.n_actions, dtype=torch.bool ) for i in range(len(self)): - existing_edges = self[i].tensor["edge_index"] - self.tensor['node_index'][self.tensor['batch_ptr'][i]] + existing_edges = ( + self[i].tensor["edge_index"] + - self.tensor["node_index"][self.tensor["batch_ptr"][i]] + ) backward_masks[ i, existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], @@ -205,24 +226,30 @@ def backward_masks(self, value: torch.Tensor): def _step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(states, actions) - return super()._step(states, actions) + new_states = super()._step(states, actions) + assert isinstance(new_states, GraphStates) + return new_states def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(states, actions) - return super()._backward_step(states, actions) + new_states = super()._backward_step(states, actions) + assert isinstance(new_states, GraphStates) + return new_states def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions: action_tensor = actions.tensor.squeeze(-1) action_type = torch.where( - action_tensor == self.n_actions - 1, GraphActionType.EXIT, GraphActionType.ADD_EDGE + action_tensor == self.n_actions - 1, + GraphActionType.EXIT, + GraphActionType.ADD_EDGE, ) # TODO: refactor. # action = [exit, src_node_idx, dest_node_idx] <- 2n + 1 # action = [node] <- n^2 + 1 - + edge_index_i0 = action_tensor // (self.n_nodes) # column - edge_index_i1 = action_tensor % (self.n_nodes) # row + edge_index_i1 = action_tensor % (self.n_nodes) # row edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) offset = states.tensor["node_index"][states.tensor["batch_ptr"][:-1]] @@ -240,6 +267,7 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions class GraphPreprocessor(Preprocessor): """Extract the tensor from the states.""" + def __init__(self, feature_dim: int = 1): super().__init__(output_dim=feature_dim) @@ -252,7 +280,7 @@ def __call__(self, states: GraphStates) -> torch.Tensor: def render_states(states: GraphStates): """Render the states as a matplotlib plot. - + Args: states: A batch of graphs. 
""" @@ -271,7 +299,7 @@ def render_states(states: GraphStates): xs.append(x) ys.append(y) current_ax.add_patch( - plt.Circle((x, y), 0.5, facecolor="none", edgecolor="black") + patches.Circle((x, y), 0.5, facecolor="none", edgecolor="black") ) edge_index = states[i].tensor["edge_index"] @@ -318,14 +346,16 @@ def test(): module = RingPolicyEstimator(4) states = GraphStates( - TensorDict({ - "node_feature": torch.eye(4), - "edge_feature": torch.zeros((0, 1)), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - "node_index": torch.arange(4), - "batch_ptr": torch.tensor([0, 4]), - "batch_shape": torch.tensor([1]), - }) + TensorDict( + { + "node_feature": torch.eye(4), + "edge_feature": torch.zeros((0, 1)), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + "node_index": torch.arange(4), + "batch_ptr": torch.tensor([0, 4]), + "batch_shape": torch.tensor([1]), + } + ) ) out = module(states.tensor) @@ -334,6 +364,7 @@ def test(): print("Params:", [p for p in module.params]) print("Gradients:", [p.grad for p in module.params]) + if __name__ == "__main__": N_NODES = 3 N_ITERATIONS = 8192 @@ -353,16 +384,28 @@ def test(): t1 = time.time() for iteration in range(N_ITERATIONS): - trajectories = gflownet.sample_trajectories(env, n=batch_size) - rews = state_evaluator(trajectories.last_states) + trajectories = gflownet.sample_trajectories( + env, n=batch_size # pyright: ignore + ) + last_states = trajectories.last_states + assert isinstance(last_states, GraphStates) + rews = state_evaluator(last_states) samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() - loss = gflownet.loss(env, samples) - print("Iteration", iteration, "Loss:", loss.item(), f"rings: {torch.mean(rews > 0.1, dtype=torch.float) * 100:.0f}%") + loss = gflownet.loss(env, samples) # pyright: ignore + print( + "Iteration", + iteration, + "Loss:", + loss.item(), + f"rings: {torch.mean(rews > 0.1, dtype=torch.float) * 100:.0f}%", + ) loss.backward() optimizer.step() losses.append(loss.item()) t2 = time.time() print("Time:", t2 - t1) - render_states(trajectories.last_states[:8]) + last_states = trajectories.last_states[:8] + assert isinstance(last_states, GraphStates) + render_states(last_states) From 4be8f9bdfaf05eadbf080b9824ef79d0db3ee1ca Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 11 Feb 2025 16:57:33 +0100 Subject: [PATCH 059/102] add two tested RingPolicyEstimator --- tutorials/examples/train_graph_ring.py | 100 ++++++++++++++++++------- 1 file changed, 74 insertions(+), 26 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 58f3a09f..1c09f56d 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -10,6 +10,7 @@ from tensordict import TensorDict from torch import nn import torch.nn.functional as F +from torch_geometric.nn import GCNConv from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.flow_matching import FMGFlowNet @@ -81,6 +82,44 @@ def set_diff(tensor1, tensor2): return out.view(*states.batch_shape) +# class RingPolicyEstimator(nn.Module): +# """Simple module which outputs a fixed logits for the actions, depending on the number of edges. + +# Args: +# n_nodes: The number of nodes in the graph. 
+# """ + +# def __init__(self, n_nodes: int): +# super().__init__() +# self.n_nodes = n_nodes +# self.params = nn.Parameter( +# nn.init.xavier_normal_( +# torch.empty(self.n_nodes + 1, self.n_nodes * self.n_nodes + 1) +# ) +# ) +# with torch.no_grad(): +# self.params[-1] = torch.zeros(self.n_nodes * self.n_nodes + 1) +# self.params[-1][-1] = 1.0 + +# def forward(self, states_tensor: TensorDict) -> torch.Tensor: +# batch_size = states_tensor["batch_shape"][0] +# if batch_size == 0: +# return torch.zeros(0, self.n_nodes * self.n_nodes + 1, requires_grad=True) + +# first_idx = states_tensor["node_index"][states_tensor["batch_ptr"][:-1]] +# last_idx = states_tensor["node_index"][states_tensor["batch_ptr"][1:] - 1] +# n_edges = ( +# torch.logical_and( +# states_tensor["edge_index"] >= first_idx[:, None, None], +# states_tensor["edge_index"] <= last_idx[:, None, None], +# ) +# .all(dim=-1) +# .sum(dim=-1) +# ) + +# n_edges = torch.clamp(n_edges, 0, self.n_nodes) +# return self.params[n_edges] + class RingPolicyEstimator(nn.Module): """Simple module which outputs a fixed logits for the actions, depending on the number of edges. @@ -91,33 +130,42 @@ class RingPolicyEstimator(nn.Module): def __init__(self, n_nodes: int): super().__init__() self.n_nodes = n_nodes - self.params = nn.Parameter( - nn.init.xavier_normal_( - torch.empty(self.n_nodes + 1, self.n_nodes * self.n_nodes + 1) - ) + self.action_type_conv = GCNConv(n_nodes, 4) + self.edge_index_conv = GCNConv(n_nodes, 16) + self.n_nodes = n_nodes + self.edge_hidden_dim = 16 + + def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: + cumsum = torch.zeros( + (len(tensor) + 1, *tensor.shape[1:]), + dtype=tensor.dtype, + device=tensor.device, ) - with torch.no_grad(): - self.params[-1] = torch.zeros(self.n_nodes * self.n_nodes + 1) - self.params[-1][-1] = 1.0 + cumsum[1:] = torch.cumsum(tensor, dim=0) + + # Subtract the end val from each batch idx fom the start val of each batch idx. 
+ return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] def forward(self, states_tensor: TensorDict) -> torch.Tensor: - batch_size = states_tensor["batch_shape"][0] - if batch_size == 0: - return torch.zeros(0, self.n_nodes * self.n_nodes + 1, requires_grad=True) - - first_idx = states_tensor["node_index"][states_tensor["batch_ptr"][:-1]] - last_idx = states_tensor["node_index"][states_tensor["batch_ptr"][1:] - 1] - n_edges = ( - torch.logical_and( - states_tensor["edge_index"] >= first_idx[:, None, None], - states_tensor["edge_index"] <= last_idx[:, None, None], - ) - .all(dim=-1) - .sum(dim=-1) + node_feature, batch_ptr = states_tensor["node_feature"], states_tensor["batch_ptr"] + + edge_index = torch.where( + states_tensor["edge_index"][..., None] == states_tensor["node_index"] + )[2].reshape(states_tensor["edge_index"].shape) # (M, 2) + + action_type = self.action_type_conv(node_feature, edge_index.T) + action_type = self._group_sum(torch.mean(action_type, dim=-1, keepdim=True), batch_ptr) + + edge_index = self.edge_index_conv(node_feature, edge_index.T) + edge_index = edge_index.reshape( + *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim ) + edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - n_edges = torch.clamp(n_edges, 0, self.n_nodes) - return self.params[n_edges] + edge_actions = edge_index.reshape( + *states_tensor["batch_shape"], self.n_nodes * self.n_nodes + ) + return torch.cat([edge_actions, action_type], dim=-1) class RingGraphBuilding(GraphBuilding): @@ -366,8 +414,8 @@ def test(): if __name__ == "__main__": - N_NODES = 3 - N_ITERATIONS = 8192 + N_NODES = 2 + N_ITERATIONS = 2048 torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) @@ -376,8 +424,8 @@ def test(): module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor() ) - gflownet = FMGFlowNet(logf_estimator) - optimizer = torch.optim.RMSprop(gflownet.parameters(), lr=0.1, momentum=0.9) + gflownet = FMGFlowNet(logf_estimator, alpha=1) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=0.001) batch_size = 8 losses = [] From e7465b89a0d729af6e9c1ee0378b92cbd6e7cf53 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 11 Feb 2025 23:58:33 +0100 Subject: [PATCH 060/102] push tentatives --- tutorials/examples/train_graph_ring.py | 112 +++++++++++++++++++++++-- 1 file changed, 103 insertions(+), 9 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 1c09f56d..1750470d 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -10,7 +10,7 @@ from tensordict import TensorDict from torch import nn import torch.nn.functional as F -from torch_geometric.nn import GCNConv +from torch_geometric.nn import GINEConv, GINConv, GCNConv from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.flow_matching import FMGFlowNet @@ -130,8 +130,10 @@ class RingPolicyEstimator(nn.Module): def __init__(self, n_nodes: int): super().__init__() self.n_nodes = n_nodes - self.action_type_conv = GCNConv(n_nodes, 4) - self.edge_index_conv = GCNConv(n_nodes, 16) + embedding_dim = 16 + self.embedding = nn.Embedding(n_nodes, embedding_dim) + self.action_type_conv = GINConv(nn.Linear(embedding_dim, embedding_dim)) + self.edge_index_conv = GINConv(nn.Linear(embedding_dim, embedding_dim)) self.n_nodes = n_nodes self.edge_hidden_dim = 16 @@ -148,10 +150,12 @@ def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> 
torch.Ten def forward(self, states_tensor: TensorDict) -> torch.Tensor: node_feature, batch_ptr = states_tensor["node_feature"], states_tensor["batch_ptr"] + node_feature = self.embedding(node_feature.squeeze().int()) edge_index = torch.where( states_tensor["edge_index"][..., None] == states_tensor["node_index"] )[2].reshape(states_tensor["edge_index"].shape) # (M, 2) + #edge_attrs = states_tensor["edge_feature"] action_type = self.action_type_conv(node_feature, edge_index.T) action_type = self._group_sum(torch.mean(action_type, dim=-1, keepdim=True), batch_ptr) @@ -168,6 +172,98 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: return torch.cat([edge_actions, action_type], dim=-1) +# class RingPolicyEstimator(nn.Module): +# """Module that implements a GCN-based policy estimator for graph-structured data. +# Outputs action logits for both node-level and edge-level actions. + +# Args: +# node_feature_dim: Dimension of input node features +# n_nodes: Number of nodes in the graph +# hidden_dim: Hidden dimension for edge features +# """ +# def __init__(self, node_feature_dim: int, n_nodes: int, hidden_dim: int = 16): +# super().__init__() +# self.n_nodes = n_nodes +# self.edge_hidden_dim = hidden_dim + +# self.feature_embedding = nn.Embedding(n_nodes, node_feature_dim) +# self.action_type_convs = nn.ModuleList([ +# GCNConv(node_feature_dim, hidden_dim), +# GCNConv(hidden_dim, hidden_dim), +# ]) +# self.edge_index_convs = nn.ModuleList([ +# GCNConv(node_feature_dim, hidden_dim), +# GCNConv(hidden_dim, hidden_dim), +# ]) + +# def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: +# cumsum = torch.zeros( +# (len(tensor) + 1, *tensor.shape[1:]), +# dtype=tensor.dtype, +# device=tensor.device, +# ) +# cumsum[1:] = torch.cumsum(tensor, dim=0) +# return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] + +# def forward(self, states_tensor: TensorDict) -> torch.Tensor: +# """ +# Forward pass of the GCN policy estimator. 
+ +# Args: +# states_tensor: TensorDict containing: +# - node_feature: Node features tensor (N, feature_dim) +# - batch_ptr: Batch pointers for batched graphs +# - edge_index: Edge indices tensor (M, 2) +# - batch_shape: Shape of the batch + +# Returns: +# torch.Tensor: Concatenated logits for edge actions and node actions +# """ +# node_feature = self.feature_embedding(states_tensor["node_feature"].squeeze().int()) +# node_feature = F.relu(node_feature) # Add non-linearity +# batch_ptr = states_tensor["batch_ptr"] +# edge_index = torch.where( +# states_tensor["edge_index"][..., None] == states_tensor["node_index"] +# )[2].reshape(states_tensor["edge_index"].shape).T + +# # Node-level actions through GCN +# action_type = node_feature +# for conv in self.action_type_convs: +# action_type = conv(action_type, edge_index) +# action_type = F.relu(action_type) + +# # Average over feature dimension and sum over batch +# action_type = self._group_sum( +# torch.mean(action_type, dim=-1, keepdim=True), +# batch_ptr +# ) + +# # Edge-level actions through GCN +# edge_embeddings = node_feature +# for conv in self.edge_index_convs: +# edge_embeddings = conv(edge_embeddings, edge_index) +# edge_embeddings = F.relu(edge_embeddings) # Add non-linearity + +# # Reshape to proper batch dimensions +# edge_embeddings = edge_embeddings.reshape( +# -1, # Flatten batch dimension +# self.n_nodes, +# self.edge_hidden_dim +# ) + +# # Compute pairwise interactions between nodes +# edge_scores = torch.einsum('bnf,bmf->bnm', edge_embeddings, edge_embeddings) + +# # Reshape to final output format +# batch_size = len(batch_ptr) - 1 +# edge_actions = edge_scores.reshape( +# batch_size, +# self.n_nodes**2 +# ) + +# # Concatenate edge and node action logits +# return torch.cat([edge_actions, action_type], dim=-1) + class RingGraphBuilding(GraphBuilding): """Override the GraphBuilding class to create have discrete actions. 
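The `torch.where(edge_index[..., None] == node_index)` idiom in the forward pass above converts edges expressed in global node ids into row positions of the node-feature matrix, which is what the GNN layers expect. A standalone illustration with made-up ids (not from the patch):

import torch

node_index = torch.tensor([7, 9, 12, 40])      # row i of the feature matrix holds id node_index[i]
edge_index = torch.tensor([[9, 40], [12, 7]])  # edges in global ids

local = torch.where(edge_index[..., None] == node_index)[2].reshape(edge_index.shape)
assert local.tolist() == [[1, 3], [2, 0]]      # rows into the feature matrix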
@@ -202,9 +298,7 @@ def make_states_class(self) -> type[GraphStates]: class RingStates(GraphStates): s0 = TensorDict( { - "node_feature": F.one_hot( - torch.arange(env.n_nodes), num_classes=env.n_nodes - ).float(), + "node_feature": torch.arange(env.n_nodes)[:, None], "edge_feature": torch.ones((0, 1)), "edge_index": torch.ones((0, 2), dtype=torch.long), }, @@ -212,7 +306,7 @@ class RingStates(GraphStates): ) sf = TensorDict( { - "node_feature": torch.zeros((env.n_nodes, env.n_nodes)), + "node_feature": -torch.ones(env.n_nodes)[:, None], "edge_feature": torch.zeros((0, 1)), "edge_index": torch.zeros((0, 2), dtype=torch.long), }, @@ -414,7 +508,7 @@ def test(): if __name__ == "__main__": - N_NODES = 2 + N_NODES = 3 N_ITERATIONS = 2048 torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) @@ -426,7 +520,7 @@ def test(): gflownet = FMGFlowNet(logf_estimator, alpha=1) optimizer = torch.optim.Adam(gflownet.parameters(), lr=0.001) - batch_size = 8 + batch_size = 256 losses = [] From 38041e895dd5b44e7667c73b207d344bd43d58b6 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 12 Feb 2025 19:53:19 +0100 Subject: [PATCH 061/102] trying TBGFN --- tutorials/examples/train_graph_ring.py | 172 ++++--------------------- 1 file changed, 23 insertions(+), 149 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 1750470d..f99e6db6 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -14,6 +14,7 @@ from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.flow_matching import FMGFlowNet +from gfn.gflownet.trajectory_balance import TBGFlowNet from gfn.gym import GraphBuilding from gfn.modules import DiscretePolicyEstimator from gfn.preprocessors import Preprocessor @@ -81,45 +82,6 @@ def set_diff(tensor1, tensor2): return out.view(*states.batch_shape) - -# class RingPolicyEstimator(nn.Module): -# """Simple module which outputs a fixed logits for the actions, depending on the number of edges. - -# Args: -# n_nodes: The number of nodes in the graph. -# """ - -# def __init__(self, n_nodes: int): -# super().__init__() -# self.n_nodes = n_nodes -# self.params = nn.Parameter( -# nn.init.xavier_normal_( -# torch.empty(self.n_nodes + 1, self.n_nodes * self.n_nodes + 1) -# ) -# ) -# with torch.no_grad(): -# self.params[-1] = torch.zeros(self.n_nodes * self.n_nodes + 1) -# self.params[-1][-1] = 1.0 - -# def forward(self, states_tensor: TensorDict) -> torch.Tensor: -# batch_size = states_tensor["batch_shape"][0] -# if batch_size == 0: -# return torch.zeros(0, self.n_nodes * self.n_nodes + 1, requires_grad=True) - -# first_idx = states_tensor["node_index"][states_tensor["batch_ptr"][:-1]] -# last_idx = states_tensor["node_index"][states_tensor["batch_ptr"][1:] - 1] -# n_edges = ( -# torch.logical_and( -# states_tensor["edge_index"] >= first_idx[:, None, None], -# states_tensor["edge_index"] <= last_idx[:, None, None], -# ) -# .all(dim=-1) -# .sum(dim=-1) -# ) - -# n_edges = torch.clamp(n_edges, 0, self.n_nodes) -# return self.params[n_edges] - class RingPolicyEstimator(nn.Module): """Simple module which outputs a fixed logits for the actions, depending on the number of edges. @@ -127,7 +89,7 @@ class RingPolicyEstimator(nn.Module): n_nodes: The number of nodes in the graph. 
""" - def __init__(self, n_nodes: int): + def __init__(self, n_nodes: int, is_backward: bool = False): super().__init__() self.n_nodes = n_nodes embedding_dim = 16 @@ -136,8 +98,9 @@ def __init__(self, n_nodes: int): self.edge_index_conv = GINConv(nn.Linear(embedding_dim, embedding_dim)) self.n_nodes = n_nodes self.edge_hidden_dim = 16 + self.is_backward = is_backward - def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: + def _group_mean(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: cumsum = torch.zeros( (len(tensor) + 1, *tensor.shape[1:]), dtype=tensor.dtype, @@ -146,7 +109,8 @@ def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Ten cumsum[1:] = torch.cumsum(tensor, dim=0) # Subtract the end val from each batch idx fom the start val of each batch idx. - return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] + size = batch_ptr[1:] - batch_ptr[:-1] + return (cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]) / size[:, None] def forward(self, states_tensor: TensorDict) -> torch.Tensor: node_feature, batch_ptr = states_tensor["node_feature"], states_tensor["batch_ptr"] @@ -158,7 +122,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: #edge_attrs = states_tensor["edge_feature"] action_type = self.action_type_conv(node_feature, edge_index.T) - action_type = self._group_sum(torch.mean(action_type, dim=-1, keepdim=True), batch_ptr) + action_type = self._group_mean(torch.mean(action_type, dim=-1, keepdim=True), batch_ptr) edge_index = self.edge_index_conv(node_feature, edge_index.T) edge_index = edge_index.reshape( @@ -169,100 +133,11 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: edge_actions = edge_index.reshape( *states_tensor["batch_shape"], self.n_nodes * self.n_nodes ) - return torch.cat([edge_actions, action_type], dim=-1) - - -# class RingPolicyEstimator(nn.Module): -# """Module that implements a GCN-based policy estimator for graph-structured data. -# Outputs action logits for both node-level and edge-level actions. - -# Args: -# node_feature_dim: Dimension of input node features -# n_nodes: Number of nodes in the graph -# hidden_dim: Hidden dimension for edge features -# """ -# def __init__(self, node_feature_dim: int, n_nodes: int, hidden_dim: int = 16): -# super().__init__() -# self.n_nodes = n_nodes -# self.edge_hidden_dim = hidden_dim - -# self.feature_embedding = nn.Embedding(n_nodes, node_feature_dim) -# self.action_type_convs = nn.ModuleList([ -# GCNConv(node_feature_dim, hidden_dim), -# GCNConv(hidden_dim, hidden_dim), -# ]) -# self.edge_index_convs = nn.ModuleList([ -# GCNConv(node_feature_dim, hidden_dim), -# GCNConv(hidden_dim, hidden_dim), -# ]) - -# def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: -# cumsum = torch.zeros( -# (len(tensor) + 1, *tensor.shape[1:]), -# dtype=tensor.dtype, -# device=tensor.device, -# ) -# cumsum[1:] = torch.cumsum(tensor, dim=0) -# return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] - -# def forward(self, states_tensor: TensorDict) -> torch.Tensor: -# """ -# Forward pass of the GCN policy estimator. 
- -# Args: -# states_tensor: TensorDict containing: -# - node_feature: Node features tensor (N, feature_dim) -# - batch_ptr: Batch pointers for batched graphs -# - edge_index: Edge indices tensor (M, 2) -# - batch_shape: Shape of the batch - -# Returns: -# torch.Tensor: Concatenated logits for edge actions and node actions -# """ -# node_feature = self.feature_embedding(states_tensor["node_feature"].squeeze().int()) -# node_feature = F.relu(node_feature) # Add non-linearity -# batch_ptr = states_tensor["batch_ptr"] -# edge_index = torch.where( -# states_tensor["edge_index"][..., None] == states_tensor["node_index"] -# )[2].reshape(states_tensor["edge_index"].shape).T - -# # Node-level actions through GCN -# action_type = node_feature -# for conv in self.action_type_convs: -# action_type = conv(action_type, edge_index) -# action_type = F.relu(action_type) - -# # Average over feature dimension and sum over batch -# action_type = self._group_sum( -# torch.mean(action_type, dim=-1, keepdim=True), -# batch_ptr -# ) - -# # Edge-level actions through GCN -# edge_embeddings = node_feature -# for conv in self.edge_index_convs: -# edge_embeddings = conv(edge_embeddings, edge_index) -# edge_embeddings = F.relu(edge_embeddings) # Add non-linearity - -# # Reshape to proper batch dimensions -# edge_embeddings = edge_embeddings.reshape( -# -1, # Flatten batch dimension -# self.n_nodes, -# self.edge_hidden_dim -# ) - -# # Compute pairwise interactions between nodes -# edge_scores = torch.einsum('bnf,bmf->bnm', edge_embeddings, edge_embeddings) - -# # Reshape to final output format -# batch_size = len(batch_ptr) - 1 -# edge_actions = edge_scores.reshape( -# batch_size, -# self.n_nodes**2 -# ) - -# # Concatenate edge and node action logits -# return torch.cat([edge_actions, action_type], dim=-1) + if self.is_backward: + return edge_actions + else: + return torch.cat([edge_actions, action_type], dim=-1) + class RingGraphBuilding(GraphBuilding): """Override the GraphBuilding class to create have discrete actions. 
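This patch trades flow matching for trajectory balance with separate forward and backward estimators; the next hunk drops the exit action from the backward head, since exiting has no backward counterpart. For reference, the trajectory-balance residual that a TB objective squares, written out by hand with made-up numbers (a sketch of the identity, not the library's implementation):

import torch

log_Z = torch.tensor(0.0, requires_grad=True)  # learned log partition function
log_pf = torch.tensor([-1.2, -0.7])            # forward log-probs along one trajectory
log_pb = torch.tensor([-0.9])                  # backward log-probs (no exit step)
log_reward = torch.tensor(0.5)

tb_loss = (log_Z + log_pf.sum() - log_pb.sum() - log_reward) ** 2
tb_loss.backward()                             # log_Z gets a gradient like any parameter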
@@ -346,7 +221,7 @@ def forward_masks(self, value: torch.Tensor): @property def backward_masks(self): backward_masks = torch.zeros( - len(self), self.n_actions, dtype=torch.bool + len(self), self.n_actions - 1, dtype=torch.bool ) for i in range(len(self)): existing_edges = ( @@ -358,7 +233,7 @@ def backward_masks(self): existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], ] = True - return backward_masks.view(*self.batch_shape, self.n_actions) + return backward_masks.view(*self.batch_shape, self.n_actions - 1) @backward_masks.setter def backward_masks(self, value: torch.Tensor): @@ -509,18 +384,17 @@ def test(): if __name__ == "__main__": N_NODES = 3 - N_ITERATIONS = 2048 + N_ITERATIONS = 4096 torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) - module = RingPolicyEstimator(env.n_nodes) - - logf_estimator = DiscretePolicyEstimator( - module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor() - ) - - gflownet = FMGFlowNet(logf_estimator, alpha=1) - optimizer = torch.optim.Adam(gflownet.parameters(), lr=0.001) - batch_size = 256 + module_pf = RingPolicyEstimator(env.n_nodes) + module_pb = RingPolicyEstimator(env.n_nodes, is_backward=True) + pf = DiscretePolicyEstimator(module=module_pf, n_actions=env.n_actions, preprocessor=GraphPreprocessor()) + pb = DiscretePolicyEstimator(module=module_pb, n_actions=env.n_actions, preprocessor=GraphPreprocessor(), is_backward=True) + + gflownet = TBGFlowNet(pf, pb) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=0.005) + batch_size = 128 losses = [] From 2f4b3f9c594f47033dcc595b7041ff5be6ff1ee0 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 13 Feb 2025 21:27:04 -0500 Subject: [PATCH 062/102] increased capacity of RingPolicyEstimator --- tutorials/examples/train_graph_ring.py | 72 ++++++++++++++++++-------- 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index f99e6db6..6a2045cf 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -9,11 +9,9 @@ import torch from tensordict import TensorDict from torch import nn -import torch.nn.functional as F -from torch_geometric.nn import GINEConv, GINConv, GCNConv +from torch_geometric.nn import GINConv from gfn.actions import Actions, GraphActions, GraphActionType -from gfn.gflownet.flow_matching import FMGFlowNet from gfn.gflownet.trajectory_balance import TBGFlowNet from gfn.gym import GraphBuilding from gfn.modules import DiscretePolicyEstimator @@ -78,10 +76,19 @@ def set_diff(tensor1, tensor2): # Check if we visited all vertices and the last vertex connects back to start. if len(visited) == n_nodes and adj_matrix[current][0] == 1: - out[i] = 1.0 + out[i] = 100.0 # 1.0 return out.view(*states.batch_shape) + +def create_mlp(in_channels, hidden_channels, out_channels): + return nn.Sequential( + nn.Linear(in_channels, hidden_channels), + nn.ReLU(), + nn.Linear(hidden_channels, out_channels), + ) + + class RingPolicyEstimator(nn.Module): """Simple module which outputs a fixed logits for the actions, depending on the number of edges. 
@@ -91,16 +98,20 @@ class RingPolicyEstimator(nn.Module): def __init__(self, n_nodes: int, is_backward: bool = False): super().__init__() - self.n_nodes = n_nodes - embedding_dim = 16 - self.embedding = nn.Embedding(n_nodes, embedding_dim) - self.action_type_conv = GINConv(nn.Linear(embedding_dim, embedding_dim)) - self.edge_index_conv = GINConv(nn.Linear(embedding_dim, embedding_dim)) - self.n_nodes = n_nodes - self.edge_hidden_dim = 16 + embedding_dim = 32 + self.edge_hidden_dim = 32 self.is_backward = is_backward + self.n_nodes = n_nodes - def _group_mean(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: + self.embedding = nn.Embedding(n_nodes, embedding_dim) + mlp_action = create_mlp(embedding_dim, embedding_dim, embedding_dim) + self.action_type_conv = GINConv(mlp_action) + mlp_edge = create_mlp(embedding_dim, embedding_dim, embedding_dim) + self.edge_index_conv = GINConv(mlp_edge) + + def _group_mean( + self, tensor: torch.Tensor, batch_ptr: torch.Tensor + ) -> torch.Tensor: cumsum = torch.zeros( (len(tensor) + 1, *tensor.shape[1:]), dtype=tensor.dtype, @@ -108,21 +119,28 @@ def _group_mean(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Te ) cumsum[1:] = torch.cumsum(tensor, dim=0) - # Subtract the end val from each batch idx fom the start val of each batch idx. + # Subtract the end val from each batch idx fom the start val of each batch idx. size = batch_ptr[1:] - batch_ptr[:-1] return (cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]) / size[:, None] def forward(self, states_tensor: TensorDict) -> torch.Tensor: - node_feature, batch_ptr = states_tensor["node_feature"], states_tensor["batch_ptr"] + node_feature, batch_ptr = ( + states_tensor["node_feature"], + states_tensor["batch_ptr"], + ) node_feature = self.embedding(node_feature.squeeze().int()) edge_index = torch.where( states_tensor["edge_index"][..., None] == states_tensor["node_index"] - )[2].reshape(states_tensor["edge_index"].shape) # (M, 2) - #edge_attrs = states_tensor["edge_feature"] + )[2].reshape( + states_tensor["edge_index"].shape + ) # (M, 2) + # edge_attrs = states_tensor["edge_feature"] action_type = self.action_type_conv(node_feature, edge_index.T) - action_type = self._group_mean(torch.mean(action_type, dim=-1, keepdim=True), batch_ptr) + action_type = self._group_mean( + torch.mean(action_type, dim=-1, keepdim=True), batch_ptr + ) edge_index = self.edge_index_conv(node_feature, edge_index.T) edge_index = edge_index.reshape( @@ -385,23 +403,31 @@ def test(): if __name__ == "__main__": N_NODES = 3 N_ITERATIONS = 4096 + LR = 0.005 + BATCH_SIZE = 128 + torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) module_pf = RingPolicyEstimator(env.n_nodes) module_pb = RingPolicyEstimator(env.n_nodes, is_backward=True) - pf = DiscretePolicyEstimator(module=module_pf, n_actions=env.n_actions, preprocessor=GraphPreprocessor()) - pb = DiscretePolicyEstimator(module=module_pb, n_actions=env.n_actions, preprocessor=GraphPreprocessor(), is_backward=True) - + pf = DiscretePolicyEstimator( + module=module_pf, n_actions=env.n_actions, preprocessor=GraphPreprocessor() + ) + pb = DiscretePolicyEstimator( + module=module_pb, + n_actions=env.n_actions, + preprocessor=GraphPreprocessor(), + is_backward=True, + ) gflownet = TBGFlowNet(pf, pb) - optimizer = torch.optim.Adam(gflownet.parameters(), lr=0.005) - batch_size = 128 + optimizer = torch.optim.Adam(gflownet.parameters(), lr=LR) losses = [] t1 = time.time() for iteration in range(N_ITERATIONS): trajectories = 
gflownet.sample_trajectories( - env, n=batch_size # pyright: ignore + env, n=BATCH_SIZE # pyright: ignore ) last_states = trajectories.last_states assert isinstance(last_states, GraphStates) From d0313d1869f0ce707607d3a633b5e09ef9e5a10b Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 18 Feb 2025 00:15:42 +0100 Subject: [PATCH 063/102] make undirected graph --- tutorials/examples/train_graph_ring.py | 162 ++++++++++--------------- 1 file changed, 66 insertions(+), 96 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index f99e6db6..6f1f4e14 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -24,7 +24,7 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: """Compute the reward of a graph. - Specifically, the reward is 1 if the graph is a ring, 1e-6 otherwise. + Specifically, the reward is 1 if the graph is an undirected ring, 1e-6 otherwise. Args: states: A batch of graphs. @@ -32,7 +32,7 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: Returns: A tensor of rewards. """ - eps = 1e-6 + eps = 1e-4 if states.tensor["edge_index"].shape[0] == 0: return torch.full(states.batch_shape, eps) @@ -49,36 +49,44 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: ) n_nodes = nodes_index_range.shape[0] - adj_matrix = torch.zeros(n_nodes, n_nodes) - adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 - - if not torch.all(adj_matrix.sum(axis=1) == 1): + if n_nodes == 0: continue - visited = [] - current = 0 - while current not in visited: - visited.append(current) + # Construct a symmetric adjacency matrix for the undirected graph. + adj_matrix = torch.zeros(n_nodes, n_nodes) + if masked_edge_index.shape[0] > 0: + adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 + adj_matrix[masked_edge_index[:, 1], masked_edge_index[:, 0]] = 1 - def set_diff(tensor1, tensor2): - mask = ~torch.isin(tensor1, tensor2) - return tensor1[mask] + # In an undirected ring, every vertex should have degree 2. + if not torch.all(adj_matrix.sum(dim=1) == 2): + continue - # Find an unvisited neighbor - neighbors = torch.where(adj_matrix[current] == 1)[0] - valid_neighbours = set_diff(neighbors, torch.tensor(visited)) + # Traverse the cycle starting from vertex 0. + start_vertex = 0 + visited = [start_vertex] + neighbors = torch.where(adj_matrix[start_vertex] == 1)[0] + if neighbors.numel() == 0: + continue + # Arbitrarily choose one neighbor to begin the traversal. + current = neighbors[0].item() + prev = start_vertex - # Visit the fir - if len(valid_neighbours) == 1: - current = valid_neighbours[0] - elif len(valid_neighbours) == 0: + while True: + if current == start_vertex: + break + visited.append(current) + current_neighbors = torch.where(adj_matrix[current] == 1)[0] + # Exclude the neighbor we just came from. + current_neighbors_list = [n.item() for n in current_neighbors] + possible = [n for n in current_neighbors_list if n != prev] + if len(possible) != 1: break - else: - break # TODO: This actually should never happen, should be caught on line 45. + next_node = possible[0] + prev, current = current, next_node - # Check if we visited all vertices and the last vertex connects back to start. 
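A condensed, self-contained version of the undirected ring test being built here (not from the patch): every vertex must have degree 2, and a single non-backtracking walk must visit all vertices before returning to the start.

import torch

def is_undirected_ring(adj: torch.Tensor) -> bool:
    n = adj.shape[0]
    if n == 0 or not torch.all(adj.sum(dim=1) == 2):
        return False
    prev, cur, seen = 0, int(torch.where(adj[0] == 1)[0][0]), 1
    while cur != 0:
        # Neighbors of cur, excluding the vertex we just came from.
        nxt = [int(v) for v in torch.where(adj[cur] == 1)[0] if int(v) != prev]
        if len(nxt) != 1:
            return False
        prev, cur, seen = cur, nxt[0], seen + 1
    return seen == n

ring = torch.tensor([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=torch.float)
assert is_undirected_ring(ring)  # a triangle is the 3-node ring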
- if len(visited) == n_nodes and adj_matrix[current][0] == 1: - out[i] = 1.0 + if current == start_vertex and len(visited) == n_nodes: + out[i] = 10.0 return out.view(*states.batch_shape) @@ -92,12 +100,10 @@ class RingPolicyEstimator(nn.Module): def __init__(self, n_nodes: int, is_backward: bool = False): super().__init__() self.n_nodes = n_nodes - embedding_dim = 16 - self.embedding = nn.Embedding(n_nodes, embedding_dim) - self.action_type_conv = GINConv(nn.Linear(embedding_dim, embedding_dim)) - self.edge_index_conv = GINConv(nn.Linear(embedding_dim, embedding_dim)) - self.n_nodes = n_nodes - self.edge_hidden_dim = 16 + self.embedding_dim = 32 + self.embedding = nn.Embedding(n_nodes, self.embedding_dim) + self.action_type_conv = GINConv(nn.Linear(self.embedding_dim, self.embedding_dim)) + self.edge_index_conv = GINConv(nn.Linear(self.embedding_dim, self.embedding_dim)) self.is_backward = is_backward def _group_mean(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: @@ -115,6 +121,7 @@ def _group_mean(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Te def forward(self, states_tensor: TensorDict) -> torch.Tensor: node_feature, batch_ptr = states_tensor["node_feature"], states_tensor["batch_ptr"] node_feature = self.embedding(node_feature.squeeze().int()) + batch_size = int(torch.prod(states_tensor["batch_shape"])) edge_index = torch.where( states_tensor["edge_index"][..., None] == states_tensor["node_index"] @@ -126,12 +133,15 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: edge_index = self.edge_index_conv(node_feature, edge_index.T) edge_index = edge_index.reshape( - *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim + batch_size, self.n_nodes, self.embedding_dim ) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - edge_actions = edge_index.reshape( - *states_tensor["batch_shape"], self.n_nodes * self.n_nodes + i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + batch_arange = torch.arange(batch_size) + edge_actions = edge_index[batch_arange[:, None, None], i0, i1] + edge_actions = edge_actions.reshape( + *states_tensor["batch_shape"], self.n_nodes * (self.n_nodes - 1) // 2 ) if self.is_backward: return edge_actions @@ -153,7 +163,7 @@ class RingGraphBuilding(GraphBuilding): def __init__(self, n_nodes: int): self.n_nodes = n_nodes - self.n_actions = 1 + n_nodes * n_nodes + self.n_actions = 1 + n_nodes * (n_nodes - 1) // 2 super().__init__(feature_dim=n_nodes, state_evaluator=state_evaluator) self.is_discrete = True # actions here are discrete, needed for FlowMatching @@ -200,17 +210,15 @@ def __init__(self, tensor: TensorDict): @property def forward_masks(self): forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) - forward_masks[:, :: self.n_nodes + 1] = False for i in range(len(self)): existing_edges = ( self[i].tensor["edge_index"] - self.tensor["node_index"][self.tensor["batch_ptr"][i]] ) assert torch.all(existing_edges >= 0) # TODO: convert to test. 
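The masks below flatten an undirected pair (i, j) with i < j using the triangular formula i*(2n - i - 1)//2 + (j - i - 1); that is exactly the column order of `torch.triu_indices(n, n, offset=1)`, which `convert_actions` uses to decode the flat action back into an edge. A quick check (not from the patch):

import torch

n = 4
i0, i1 = torch.triu_indices(n, n, offset=1)  # columns enumerate the 6 pairs i < j
for a, (i, j) in enumerate(zip(i0.tolist(), i1.tolist())):
    assert a == i * (2 * n - i - 1) // 2 + (j - i - 1)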
- forward_masks[ - i, - existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], - ] = False + edge = existing_edges[:, 0] * (2 * self.n_nodes - existing_edges[:, 0] - 1) // 2 + edge += (existing_edges[:, 1] - existing_edges[:, 0] - 1) + forward_masks[i, edge] = False return forward_masks.view(*self.batch_shape, self.n_actions) @@ -228,10 +236,9 @@ def backward_masks(self): self[i].tensor["edge_index"] - self.tensor["node_index"][self.tensor["batch_ptr"][i]] ) - backward_masks[ - i, - existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], - ] = True + edge = existing_edges[:, 0] * (2 * self.n_nodes - existing_edges[:, 0] - 1) // 2 + edge += (existing_edges[:, 1] - existing_edges[:, 0] - 1) + backward_masks[i, edge,] = True return backward_masks.view(*self.batch_shape, self.n_actions - 1) @@ -254,32 +261,30 @@ def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: return new_states def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions: - action_tensor = actions.tensor.squeeze(-1) + action_tensor = actions.tensor.squeeze(-1).clone() action_type = torch.where( action_tensor == self.n_actions - 1, GraphActionType.EXIT, GraphActionType.ADD_EDGE, ) + action_type[action_tensor == self.n_actions] = GraphActionType.DUMMY - # TODO: refactor. - # action = [exit, src_node_idx, dest_node_idx] <- 2n + 1 - # action = [node] <- n^2 + 1 + ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + action_tensor[action_tensor >= (self.n_actions - 1)] = 0 + ei0, ei1 = ei0[action_tensor], ei1[action_tensor] - edge_index_i0 = action_tensor // (self.n_nodes) # column - edge_index_i1 = action_tensor % (self.n_nodes) # row - - edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) offset = states.tensor["node_index"][states.tensor["batch_ptr"][:-1]] - return GraphActions( + out = GraphActions( TensorDict( { "action_type": action_type, "features": torch.ones(action_tensor.shape + (1,)), - "edge_index": edge_index + offset[:, None], + "edge_index": torch.stack([ei0, ei1], dim=-1) + offset[:, None], }, batch_size=action_tensor.shape, ) ) + return out class GraphPreprocessor(Preprocessor): @@ -332,22 +337,11 @@ def render_states(states: GraphStates): dx, dy = dx / length, dy / length circle_radius = 0.5 - head_thickness = 0.2 - start_x += dx * (circle_radius) - start_y += dy * (circle_radius) - end_x -= dx * (circle_radius + head_thickness) - end_y -= dy * (circle_radius + head_thickness) - - current_ax.arrow( - start_x, - start_y, - end_x - start_x, - end_y - start_y, - head_width=head_thickness, - head_length=head_thickness, - fc="black", - ec="black", - ) + start_x += dx * circle_radius + start_y += dy * circle_radius + end_x -= dx * circle_radius + end_y -= dy * circle_radius + current_ax.plot([start_x, end_x], [start_y, end_y], color="black") current_ax.set_title(f"State {i}, $r={rewards[i]:.2f}$") current_ax.set_xlim(-(radius + 1), radius + 1) @@ -358,33 +352,9 @@ def render_states(states: GraphStates): plt.show() - -def test(): - module = RingPolicyEstimator(4) - - states = GraphStates( - TensorDict( - { - "node_feature": torch.eye(4), - "edge_feature": torch.zeros((0, 1)), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - "node_index": torch.arange(4), - "batch_ptr": torch.tensor([0, 4]), - "batch_shape": torch.tensor([1]), - } - ) - ) - - out = module(states.tensor) - loss = torch.sum(out) - loss.backward() - print("Params:", [p for p in module.params]) - print("Gradients:", [p.grad for p in module.params]) - - if __name__ 
== "__main__": - N_NODES = 3 - N_ITERATIONS = 4096 + N_NODES = 4 + N_ITERATIONS = 256 torch.random.manual_seed(7) env = RingGraphBuilding(n_nodes=N_NODES) module_pf = RingPolicyEstimator(env.n_nodes) From 9d994dac6373c12a7f91bdff7dc4c2a343e27757 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 18 Feb 2025 11:06:20 -0500 Subject: [PATCH 064/102] merge over --- src/gfn/states.py | 11 +-- tutorials/examples/train_graph_ring.py | 99 ++++++++++++++++++++------ 2 files changed, 82 insertions(+), 28 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index b143539d..386ed5d6 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -942,15 +942,10 @@ def forward_masks(self) -> TensorDict: edge_index = torch.where( self.tensor["edge_index"][..., None] == self.tensor["node_index"] )[2].reshape(self.tensor["edge_index"].shape) - ei1 = edge_index[..., 0] - ei2 = edge_index[..., 1] + i, j = edge_index[..., 0], edge_index[..., 1] + for _ in range(len(self.batch_shape)): - ( - ei1, - ei2, - ) = ei1.unsqueeze( - 0 - ), ei2.unsqueeze(0) + (i, j) = ei1.unsqueeze(0), ei2.unsqueeze(0) # First allow nodes in the same graph to connect, then disable nodes with existing edges forward_masks["edge_index"][ diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 6a2045cf..5bae1f4f 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -10,6 +10,7 @@ from tensordict import TensorDict from torch import nn from torch_geometric.nn import GINConv +import torch.nn.functional as F from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.trajectory_balance import TBGFlowNet @@ -81,13 +82,37 @@ def set_diff(tensor1, tensor2): return out.view(*states.batch_shape) -def create_mlp(in_channels, hidden_channels, out_channels): - return nn.Sequential( - nn.Linear(in_channels, hidden_channels), - nn.ReLU(), - nn.Linear(hidden_channels, out_channels), - ) - +def create_mlp(in_channels, hidden_channels, out_channels, num_layers=1): + """ + Create a Multi-Layer Perceptron with configurable number of layers. + + Args: + in_channels (int): Number of input features + hidden_channels (int): Number of hidden features per layer + out_channels (int): Number of output features + num_layers (int): Number of hidden layers (default: 2) + + Returns: + nn.Sequential: MLP model + """ + layers = [] + + # Input layer + layers.append(nn.Linear(in_channels, hidden_channels)) + layers.append(nn.LayerNorm(hidden_channels)) + layers.append(nn.ReLU()) + + # Hidden layers + for _ in range(num_layers - 1): + layers.append(nn.Linear(hidden_channels, hidden_channels)) + layers.append(nn.LayerNorm(hidden_channels)) + layers.append(nn.ReLU()) + + # Output layer + layers.append(nn.Linear(hidden_channels, out_channels)) + + return nn.Sequential(*layers) + class RingPolicyEstimator(nn.Module): """Simple module which outputs a fixed logits for the actions, depending on the number of edges. @@ -96,18 +121,33 @@ class RingPolicyEstimator(nn.Module): n_nodes: The number of nodes in the graph. 
""" - def __init__(self, n_nodes: int, is_backward: bool = False): + def __init__(self, n_nodes: int, num_conv_layers: int = 1, is_backward: bool = False): super().__init__() - embedding_dim = 32 - self.edge_hidden_dim = 32 + embedding_dim = 64 + self.edge_hidden_dim = 64 self.is_backward = is_backward self.n_nodes = n_nodes + self.num_conv_layers = num_conv_layers + + # Node embedding layer self.embedding = nn.Embedding(n_nodes, embedding_dim) - mlp_action = create_mlp(embedding_dim, embedding_dim, embedding_dim) - self.action_type_conv = GINConv(mlp_action) - mlp_edge = create_mlp(embedding_dim, embedding_dim, embedding_dim) - self.edge_index_conv = GINConv(mlp_edge) + + # Multiple action type convolution layers + self.action_type_convs = nn.ModuleList() + for _ in range(num_conv_layers): + mlp = create_mlp(embedding_dim, embedding_dim, embedding_dim) + self.action_type_convs.append(GINConv(mlp)) + + # Multiple edge index convolution layers + self.edge_index_convs = nn.ModuleList() + for _ in range(num_conv_layers): + mlp = create_mlp(embedding_dim, embedding_dim, embedding_dim) + self.edge_index_convs.append(GINConv(mlp)) + + # Layer normalization for stability + self.action_norm = nn.LayerNorm(embedding_dim) + self.edge_norm = nn.LayerNorm(embedding_dim) def _group_mean( self, tensor: torch.Tensor, batch_ptr: torch.Tensor @@ -128,7 +168,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: states_tensor["node_feature"], states_tensor["batch_ptr"], ) - node_feature = self.embedding(node_feature.squeeze().int()) + x = self.embedding(node_feature.squeeze().int()) edge_index = torch.where( states_tensor["edge_index"][..., None] == states_tensor["node_index"] @@ -136,21 +176,36 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: states_tensor["edge_index"].shape ) # (M, 2) # edge_attrs = states_tensor["edge_feature"] + action_type = x + + # Multiple action type convolutions with residual connections. + for conv in self.action_type_convs: + action_type_new = conv(action_type, edge_index.T) + action_type = action_type + action_type_new # Residual connection + action_type = self.action_norm(action_type) + action_type = F.relu(action_type) - action_type = self.action_type_conv(node_feature, edge_index.T) action_type = self._group_mean( torch.mean(action_type, dim=-1, keepdim=True), batch_ptr ) - edge_index = self.edge_index_conv(node_feature, edge_index.T) - edge_index = edge_index.reshape( + # Multiple edge index convolutions with residual connections + edge_feature = x + for conv in self.edge_index_convs: + edge_feature_new = conv(edge_feature, edge_index.T) + edge_feature = edge_feature + edge_feature_new # Residual connection + edge_feature = self.edge_norm(edge_feature) + edge_feature = F.relu(edge_feature) + + edge_feature = edge_feature.reshape( *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim ) - edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) + edge_index = torch.einsum("bnf,bmf->bnm", edge_feature, edge_feature) edge_actions = edge_index.reshape( *states_tensor["batch_shape"], self.n_nodes * self.n_nodes ) + if self.is_backward: return edge_actions else: @@ -217,8 +272,12 @@ def __init__(self, tensor: TensorDict): @property def forward_masks(self): + # Allow all actions. forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) - forward_masks[:, :: self.n_nodes + 1] = False + + forward_masks[:, :: self.n_nodes + 1] = False # Remove self-loops. 
+ + # Remove existing edges.s for i in range(len(self)): existing_edges = ( self[i].tensor["edge_index"] From 0625f5c56dccec014a5ea34c81cf0fd523957d39 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 18 Feb 2025 15:43:22 -0500 Subject: [PATCH 065/102] allows for configurable policy capacity, and the code simultaneously handles directed and undirected versions --- tutorials/examples/train_graph_ring.py | 206 +++++++++++++++++++------ 1 file changed, 160 insertions(+), 46 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 1285de50..cacdada6 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -9,7 +9,7 @@ import torch from tensordict import TensorDict from torch import nn -from torch_geometric.nn import GINConv +from torch_geometric.nn import GINConv, GCNConv import torch.nn.functional as F from gfn.actions import Actions, GraphActions, GraphActionType @@ -20,7 +20,69 @@ from gfn.states import GraphStates -def state_evaluator(states: GraphStates) -> torch.Tensor: +def directed_reward(states: GraphStates) -> torch.Tensor: + """Compute the reward of a graph. + + Specifically, the reward is 1 if the graph is a ring, 1e-6 otherwise. + + Args: + states: A batch of graphs. + + Returns: + A tensor of rewards. + """ + eps = 1e-6 + if states.tensor["edge_index"].shape[0] == 0: + return torch.full(states.batch_shape, eps) + + out = torch.full((len(states),), eps) # Default reward. + + for i in range(len(states)): + start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] + nodes_index_range = states.tensor["node_index"][start:end] + edge_index_mask = torch.all( + states.tensor["edge_index"] >= nodes_index_range[0], dim=-1 + ) & torch.all(states.tensor["edge_index"] <= nodes_index_range[-1], dim=-1) + masked_edge_index = ( + states.tensor["edge_index"][edge_index_mask] - nodes_index_range[0] + ) + + n_nodes = nodes_index_range.shape[0] + adj_matrix = torch.zeros(n_nodes, n_nodes) + adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 + + if not torch.all(adj_matrix.sum(axis=1) == 1): + continue + + visited = [] + current = 0 + while current not in visited: + visited.append(current) + + def set_diff(tensor1, tensor2): + mask = ~torch.isin(tensor1, tensor2) + return tensor1[mask] + + # Find an unvisited neighbor + neighbors = torch.where(adj_matrix[current] == 1)[0] + valid_neighbours = set_diff(neighbors, torch.tensor(visited)) + + # Visit the fir + if len(valid_neighbours) == 1: + current = valid_neighbours[0] + elif len(valid_neighbours) == 0: + break + else: + break # TODO: This actually should never happen, should be caught on line 45. + + # Check if we visited all vertices and the last vertex connects back to start. + if len(visited) == n_nodes and adj_matrix[current][0] == 1: + out[i] = 1.0 + + return out.view(*states.batch_shape) + + +def undirected_reward(states: GraphStates) -> torch.Tensor: """Compute the reward of a graph. Specifically, the reward is 1 if the graph is an undirected ring, 1e-6 otherwise. @@ -122,41 +184,75 @@ def create_mlp(in_channels, hidden_channels, out_channels, num_layers=1): return nn.Sequential(*layers) -class RingPolicyEstimator(nn.Module): +class RingPolicyModule(nn.Module): """Simple module which outputs a fixed logits for the actions, depending on the number of edges. Args: n_nodes: The number of nodes in the graph. 
""" - def __init__(self, n_nodes: int, num_conv_layers: int = 1, is_backward: bool = False): + def __init__(self, n_nodes: int, directed: bool, num_conv_layers: int = 1, is_backward: bool = False): super().__init__() - self.edge_hidden_dim = 64 - self.embedding_dim = 32 + self.hidden_dim = self.embedding_dim = 64 self.is_backward = is_backward self.n_nodes = n_nodes self.num_conv_layers = num_conv_layers - # Node embedding layer - self.embedding = nn.Embedding(n_nodes, embedding_dim) - - # Multiple action type convolution layers - self.action_type_convs = nn.ModuleList() - for _ in range(num_conv_layers): - mlp = create_mlp(embedding_dim, embedding_dim, embedding_dim) - self.action_type_convs.append(GINConv(mlp)) - - # Multiple edge index convolution layers - self.edge_index_convs = nn.ModuleList() - for _ in range(num_conv_layers): - mlp = create_mlp(embedding_dim, embedding_dim, embedding_dim) - self.edge_index_convs.append(GINConv(mlp)) + # Node embedding layer. + self.embedding = nn.Embedding(n_nodes, self.embedding_dim) + self.action_conv_blks = nn.ModuleList() + self.edge_conv_blks = nn.ModuleList() + + if directed: + # Multiple action type convolution layers. + for i in range(num_conv_layers): + # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) + self.action_conv_blks.extend([ + GCNConv( + self.embedding_dim if i == 0 else self.hidden_dim, + self.hidden_dim, + ), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ]) + + # Multiple edge index convolution layers. + for i in range(num_conv_layers): + # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) + self.edge_conv_blks.extend([ + GCNConv( + self.embedding_dim if i == 0 else self.hidden_dim, + self.hidden_dim, + ), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ]) + else: # Undirected case. + # Multiple action type convolution layers. + for _ in range(num_conv_layers): + self.action_conv_blks.extend([ + GINConv(create_mlp(self.embedding_dim, self.hidden_dim, self.hidden_dim)), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim) + ]) + + # Multiple edge index convolution layers. + for _ in range(num_conv_layers): + self.edge_conv_blks.extend([ + GINConv(create_mlp(self.embedding_dim, self.hidden_dim, self.hidden_dim)), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim) + ]) # Layer normalization for stability - self.action_norm = nn.LayerNorm(embedding_dim) - self.edge_norm = nn.LayerNorm(embedding_dim) + self.action_norm = nn.LayerNorm(self.hidden_dim) + self.edge_norm = nn.LayerNorm(self.hidden_dim) def _group_mean( self, tensor: torch.Tensor, batch_ptr: torch.Tensor @@ -186,33 +282,49 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: states_tensor["edge_index"].shape ) # (M, 2) # edge_attrs = states_tensor["edge_feature"] - action_type = x # Multiple action type convolutions with residual connections. - for conv in self.action_type_convs: - action_type_new = conv(action_type, edge_index.T) - action_type = action_type + action_type_new # Residual connection + action_type = x + for i in range(0, len(self.action_conv_blks), 4): + + # GIN/GCN conv. + action_type_new = self.action_conv_blks[i](action_type, edge_index.T) + # First linear. + action_type_new = self.action_conv_blks[i + 1](action_type_new) + # ReLU. 
+            action_type_new = self.action_conv_blks[i + 2](action_type_new)
+            # Second linear.
+            action_type_new = self.action_conv_blks[i + 3](action_type_new)
+            # Residual connection with original input.
+            action_type = action_type_new + action_type
             action_type = self.action_norm(action_type)
-            action_type = F.relu(action_type)

         action_type = self._group_mean(
             torch.mean(action_type, dim=-1, keepdim=True), batch_ptr
         )

-        # Multiple edge index convolutions with residual connections
+        # Multiple edge index convolutions with residual connections.
         edge_feature = x
-        for conv in self.edge_index_convs:
-            edge_feature_new = conv(edge_feature, edge_index.T)
-            edge_feature = edge_feature + edge_feature_new  # Residual connection
+        for i in range(0, len(self.edge_conv_blks), 4):
+
+            # GIN/GCN conv.
+            edge_feature_new = self.edge_conv_blks[i](edge_feature, edge_index.T)
+            # First linear.
+            edge_feature_new = self.edge_conv_blks[i + 1](edge_feature_new)
+            # ReLU.
+            edge_feature_new = self.edge_conv_blks[i + 2](edge_feature_new)
+            # Second linear.
+            edge_feature_new = self.edge_conv_blks[i + 3](edge_feature_new)
+            # Residual connection with original input.
+            edge_feature = edge_feature_new + edge_feature
             edge_feature = self.edge_norm(edge_feature)
-            edge_feature = F.relu(edge_feature)

-        edge_feature = edge_feature.reshape(
-            *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim
+        # edge_feature = self._group_mean(
+        #     torch.mean(edge_feature, dim=-1, keepdim=True), batch_ptr
+        # )

-        edge_index = self.edge_index_conv(node_feature, edge_index.T)
-        edge_index = edge_index.reshape(
-            batch_size, self.n_nodes, self.embedding_dim
+        edge_feature = edge_feature.reshape(
+            *states_tensor["batch_shape"], self.n_nodes, self.hidden_dim
         )
         edge_index = torch.einsum("bnf,bmf->bnm", edge_feature, edge_feature)
@@ -241,7 +353,7 @@ class RingGraphBuilding(GraphBuilding):
     n_nodes: The number of nodes in the graph.
     """

-    def __init__(self, n_nodes: int):
+    def __init__(self, n_nodes: int, state_evaluator: callable):
         self.n_nodes = n_nodes
         self.n_actions = 1 + n_nodes * (n_nodes - 1) // 2
         super().__init__(feature_dim=n_nodes, state_evaluator=state_evaluator)
@@ -384,7 +496,7 @@ def __call__(self, states: GraphStates) -> torch.Tensor:
     return self.preprocess(states)


-def render_states(states: GraphStates):
+def render_states(states: GraphStates, state_evaluator: callable):
     """Render the states as a matplotlib plot.
Args: @@ -437,15 +549,17 @@ def render_states(states: GraphStates): plt.show() if __name__ == "__main__": - N_NODES = 4 - N_ITERATIONS = 128 - LR = 0.005 + N_NODES = 3 + N_ITERATIONS = 1000 + LR = 0.05 BATCH_SIZE = 128 + DIRECTED = True + state_evaluator = undirected_reward if not DIRECTED else directed_reward torch.random.manual_seed(7) - env = RingGraphBuilding(n_nodes=N_NODES) - module_pf = RingPolicyEstimator(env.n_nodes) - module_pb = RingPolicyEstimator(env.n_nodes, is_backward=True) + env = RingGraphBuilding(n_nodes=N_NODES, state_evaluator=state_evaluator) + module_pf = RingPolicyModule(env.n_nodes, DIRECTED) + module_pb = RingPolicyModule(env.n_nodes, DIRECTED, is_backward=True) pf = DiscretePolicyEstimator( module=module_pf, n_actions=env.n_actions, preprocessor=GraphPreprocessor() ) @@ -486,4 +600,4 @@ def render_states(states: GraphStates): print("Time:", t2 - t1) last_states = trajectories.last_states[:8] assert isinstance(last_states, GraphStates) - render_states(last_states) + render_states(last_states, state_evaluator) From dae3299b659ce97deeecc51dbd1f63db786f1afb Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 18 Feb 2025 15:43:58 -0500 Subject: [PATCH 066/102] spelling --- src/gfn/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 73997dae..c20423f4 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -220,7 +220,7 @@ def sample_trajectories( new_states = env._step(states, actions) sink_states_mask = new_states.is_sink_state - # Increment the step, determine which trajectories are finisihed, and eval + # Increment the step, determine which trajectories are finished, and eval # rewards. step += 1 From 671917b7ebcd7e2e8709304930dcb32d555b8dfb Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 18 Feb 2025 15:44:19 -0500 Subject: [PATCH 067/102] allows for configurable policy capacity, and the code simultaneously handles directed and undirected versions --- tutorials/examples/train_graph_ring.py | 108 ++++++++++++++++--------- 1 file changed, 70 insertions(+), 38 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index cacdada6..aca60f20 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -10,7 +10,6 @@ from tensordict import TensorDict from torch import nn from torch_geometric.nn import GINConv, GCNConv -import torch.nn.functional as F from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.trajectory_balance import TBGFlowNet @@ -191,7 +190,13 @@ class RingPolicyModule(nn.Module): n_nodes: The number of nodes in the graph. """ - def __init__(self, n_nodes: int, directed: bool, num_conv_layers: int = 1, is_backward: bool = False): + def __init__( + self, + n_nodes: int, + directed: bool, + num_conv_layers: int = 1, + is_backward: bool = False, + ): super().__init__() self.hidden_dim = self.embedding_dim = 64 @@ -199,7 +204,6 @@ def __init__(self, n_nodes: int, directed: bool, num_conv_layers: int = 1, is_ba self.n_nodes = n_nodes self.num_conv_layers = num_conv_layers - # Node embedding layer. self.embedding = nn.Embedding(n_nodes, self.embedding_dim) self.action_conv_blks = nn.ModuleList() @@ -209,46 +213,62 @@ def __init__(self, n_nodes: int, directed: bool, num_conv_layers: int = 1, is_ba # Multiple action type convolution layers. 
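# [Editorial aside, not part of the patch: each 4-tuple appended to the
# ModuleList below is one block, GCNConv -> Linear -> ReLU -> Linear, which
# forward() later consumes in strides of 4 with a residual connection. A
# hedged, self-contained sketch of a single block:]
import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(64, 64)
lin1, act, lin2 = torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
x = torch.randn(4, 64)  # 4 nodes with 64-dim features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # COO connectivity, shape (2, E)
out = lin2(act(lin1(conv(x, edge_index)))) + x  # residual connection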
for i in range(num_conv_layers): # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) - self.action_conv_blks.extend([ - GCNConv( - self.embedding_dim if i == 0 else self.hidden_dim, - self.hidden_dim, - ), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim), - ]) + self.action_conv_blks.extend( + [ + GCNConv( + self.embedding_dim if i == 0 else self.hidden_dim, + self.hidden_dim, + ), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ] + ) # Multiple edge index convolution layers. for i in range(num_conv_layers): # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) - self.edge_conv_blks.extend([ - GCNConv( - self.embedding_dim if i == 0 else self.hidden_dim, - self.hidden_dim, - ), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim), - ]) + self.edge_conv_blks.extend( + [ + GCNConv( + self.embedding_dim if i == 0 else self.hidden_dim, + self.hidden_dim, + ), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ] + ) else: # Undirected case. # Multiple action type convolution layers. for _ in range(num_conv_layers): - self.action_conv_blks.extend([ - GINConv(create_mlp(self.embedding_dim, self.hidden_dim, self.hidden_dim)), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim) - ]) + self.action_conv_blks.extend( + [ + GINConv( + create_mlp( + self.embedding_dim, self.hidden_dim, self.hidden_dim + ) + ), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ] + ) # Multiple edge index convolution layers. for _ in range(num_conv_layers): - self.edge_conv_blks.extend([ - GINConv(create_mlp(self.embedding_dim, self.hidden_dim, self.hidden_dim)), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim) - ]) + self.edge_conv_blks.extend( + [ + GINConv( + create_mlp( + self.embedding_dim, self.hidden_dim, self.hidden_dim + ) + ), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ] + ) # Layer normalization for stability self.action_norm = nn.LayerNorm(self.hidden_dim) @@ -412,8 +432,12 @@ def forward_masks(self): - self.tensor["node_index"][self.tensor["batch_ptr"][i]] ) assert torch.all(existing_edges >= 0) # TODO: convert to test. 
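# [Editorial aside, not part of the patch: the formula reformatted below maps
# an upper-triangular pair (i, j), i < j, to a flat action index. Rows
# 0..i-1 contribute (n-1) + (n-2) + ... = i*(2n - i - 1)/2 slots, plus the
# offset (j - i - 1) within row i. A quick check against torch.triu_indices,
# which enumerates pairs in the same row-major order:]
import torch

n = 5
i0, i1 = torch.triu_indices(n, n, offset=1)
flat = i0 * (2 * n - i0 - 1) // 2 + (i1 - i0 - 1)
assert torch.equal(flat, torch.arange(n * (n - 1) // 2))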
- edge = existing_edges[:, 0] * (2 * self.n_nodes - existing_edges[:, 0] - 1) // 2 - edge += (existing_edges[:, 1] - existing_edges[:, 0] - 1) + edge = ( + existing_edges[:, 0] + * (2 * self.n_nodes - existing_edges[:, 0] - 1) + // 2 + ) + edge += existing_edges[:, 1] - existing_edges[:, 0] - 1 forward_masks[i, edge] = False return forward_masks.view(*self.batch_shape, self.n_actions) @@ -432,9 +456,16 @@ def backward_masks(self): self[i].tensor["edge_index"] - self.tensor["node_index"][self.tensor["batch_ptr"][i]] ) - edge = existing_edges[:, 0] * (2 * self.n_nodes - existing_edges[:, 0] - 1) // 2 - edge += (existing_edges[:, 1] - existing_edges[:, 0] - 1) - backward_masks[i, edge,] = True + edge = ( + existing_edges[:, 0] + * (2 * self.n_nodes - existing_edges[:, 0] - 1) + // 2 + ) + edge += existing_edges[:, 1] - existing_edges[:, 0] - 1 + backward_masks[ + i, + edge, + ] = True return backward_masks.view(*self.batch_shape, self.n_actions - 1) @@ -548,6 +579,7 @@ def render_states(states: GraphStates, state_evaluator: callable): plt.show() + if __name__ == "__main__": N_NODES = 3 N_ITERATIONS = 1000 From 623f50e30c1975a9f03ee9f772bcd3fc7c97ac3f Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 19 Feb 2025 08:22:33 -0500 Subject: [PATCH 068/102] directed and undirected graphs now co-implemented -- but the model will never sample the exit action. --- tutorials/examples/train_graph_ring.py | 169 ++++++++++++++++++++----- 1 file changed, 134 insertions(+), 35 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index aca60f20..034f91a2 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -199,7 +199,7 @@ def __init__( ): super().__init__() self.hidden_dim = self.embedding_dim = 64 - + self.is_directed = directed self.is_backward = is_backward self.n_nodes = n_nodes self.num_conv_layers = num_conv_layers @@ -346,14 +346,28 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: edge_feature = edge_feature.reshape( *states_tensor["batch_shape"], self.n_nodes, self.hidden_dim ) + + # This is n_nodes ** 2, for each graph. edge_index = torch.einsum("bnf,bmf->bnm", edge_feature, edge_feature) - i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + # Undirected. + if self.is_directed: + i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) # Upper triangle. + i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) # Lower triangle. + + # Combine them + i0 = torch.cat([i_up, i_lo]) + i1 = torch.cat([j_up, j_lo]) + out_size = self.n_nodes ** 2 - self.n_nodes + + else: + i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + out_size = (self.n_nodes ** 2 - self.n_nodes) // 2 + + # Grab the needed elems from the adjacency matrix and reshape. batch_arange = torch.arange(batch_size) edge_actions = edge_index[batch_arange[:, None, None], i0, i1] - edge_actions = edge_actions.reshape( - *states_tensor["batch_shape"], self.n_nodes * (self.n_nodes - 1) // 2 - ) + edge_actions = edge_actions.reshape(*states_tensor["batch_shape"], out_size) if self.is_backward: return edge_actions @@ -373,11 +387,17 @@ class RingGraphBuilding(GraphBuilding): n_nodes: The number of nodes in the graph. 
""" - def __init__(self, n_nodes: int, state_evaluator: callable): + def __init__(self, n_nodes: int, state_evaluator: callable, directed: bool): self.n_nodes = n_nodes - self.n_actions = 1 + n_nodes * (n_nodes - 1) // 2 + if directed: + # all off-diagonal edges + exit. + self.n_actions = (n_nodes ** 2 - n_nodes) + 1 + else: + # bottom triangle + exit. + self.n_actions = ((n_nodes ** 2 - n_nodes) // 2) + 1 super().__init__(feature_dim=n_nodes, state_evaluator=state_evaluator) self.is_discrete = True # actions here are discrete, needed for FlowMatching + self.is_directed = directed def make_actions_class(self) -> type[Actions]: env = self @@ -423,22 +443,48 @@ def __init__(self, tensor: TensorDict): def forward_masks(self): # Allow all actions. forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) - forward_masks[:, :: self.n_nodes + 1] = False # Remove self-loops. - # Remove existing edges.s + if env.is_directed: + i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) # Upper triangle. + i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) # Lower triangle. + + # Combine them + ei0 = torch.cat([i_up, i_lo]) + ei1 = torch.cat([j_up, j_lo]) + else: + ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + + # # Adds -1 "edge" representing exit, -2 "edge" representing dummy. + # ei0 = torch.cat((ei0, torch.IntTensor([-1, -2])), dim=0) + # ei1 = torch.cat((ei1, torch.IntTensor([-1, -2])), dim=0) + + # Indexes either the second last element (exit) or la + # action_tensor[action_tensor >= (self.n_actions - 1)] = 0 + # ei0, ei1 = ei0[action_tensor], ei1[action_tensor] + + # Remove existing edges. for i in range(len(self)): existing_edges = ( self[i].tensor["edge_index"] - self.tensor["node_index"][self.tensor["batch_ptr"][i]] ) assert torch.all(existing_edges >= 0) # TODO: convert to test. - edge = ( - existing_edges[:, 0] - * (2 * self.n_nodes - existing_edges[:, 0] - 1) - // 2 - ) - edge += existing_edges[:, 1] - existing_edges[:, 0] - 1 - forward_masks[i, edge] = False + + if len(existing_edges) == 0: + edge_idx = torch.zeros(0, dtype=torch.bool) + else: + edge_idx = torch.logical_and( + existing_edges[:, 0] == ei0.unsqueeze(-1), + existing_edges[:, 1] == ei1.unsqueeze(-1), + ) + + # Collapse across the edge dimension. + if len(edge_idx.shape) == 2: + edge_idx = edge_idx.sum(1).bool() + + # Adds an unmasked exit action. + edge_idx = torch.cat((edge_idx, torch.BoolTensor([False]))) + forward_masks[i, edge_idx] = False # Disallow the addition of this edge. return forward_masks.view(*self.batch_shape, self.n_actions) @@ -448,24 +494,39 @@ def forward_masks(self, value: torch.Tensor): @property def backward_masks(self): + # Disallow all actions. backward_masks = torch.zeros( len(self), self.n_actions - 1, dtype=torch.bool ) + for i in range(len(self)): existing_edges = ( self[i].tensor["edge_index"] - self.tensor["node_index"][self.tensor["batch_ptr"][i]] ) - edge = ( - existing_edges[:, 0] - * (2 * self.n_nodes - existing_edges[:, 0] - 1) - // 2 - ) - edge += existing_edges[:, 1] - existing_edges[:, 0] - 1 - backward_masks[ - i, - edge, - ] = True + + if env.is_directed: + i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) # Upper triangle. + i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) # Lower triangle. 
+
+                    # Combine them
+                    ei0 = torch.cat([i_up, i_lo])
+                    ei1 = torch.cat([j_up, j_lo])
+                else:
+                    ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1)
+
+                if len(existing_edges) == 0:
+                    edge_idx = torch.zeros(0, dtype=torch.bool)
+                else:
+                    edge_idx = torch.logical_and(
+                        existing_edges[:, 0] == ei0.unsqueeze(-1),
+                        existing_edges[:, 1] == ei1.unsqueeze(-1),
+                    )
+                    # Collapse across the edge dimension.
+                    if len(edge_idx.shape) == 2:
+                        edge_idx = edge_idx.sum(1).bool()
+
+                backward_masks[i, edge_idx] = True  # Allow the removal of this edge.

             return backward_masks.view(*self.batch_shape, self.n_actions - 1)

@@ -488,6 +549,7 @@ def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates:
         return new_states

     def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions:
+        """Converts the action from discrete space to graph action space."""
         action_tensor = actions.tensor.squeeze(-1).clone()
         action_type = torch.where(
             action_tensor == self.n_actions - 1,
@@ -496,8 +558,27 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions
         )
         action_type[action_tensor == self.n_actions] = GraphActionType.DUMMY

-        ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1)
-        action_tensor[action_tensor >= (self.n_actions - 1)] = 0
+        # TODO: factor out into utility function.
+        if self.is_directed:
+            i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1)  # Upper triangle.
+            i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1)  # Lower triangle.
+
+            # Combine them
+            ei0 = torch.cat([i_up, i_lo])
+            ei1 = torch.cat([j_up, j_lo])
+
+            # Potentially problematic (returns [0,0,0,1,1] instead of above which returns [1,1,1,0,0]).
+            # ei0 = (action_tensor) // (self.n_nodes)
+            # ei1 = (action_tensor) % (self.n_nodes)
+        else:
+            ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1)
+
+        # Adds -1 "edge" representing exit, -2 "edge" representing dummy.
+        ei0 = torch.cat((ei0, torch.IntTensor([-1, -2])), dim=0)
+        ei1 = torch.cat((ei1, torch.IntTensor([-1, -2])), dim=0)
+
+        # Indexes either the second-to-last element (exit) or the last (dummy).
+        # action_tensor[action_tensor >= (self.n_actions - 1)] = 0
         ei0, ei1 = ei0[action_tensor], ei1[action_tensor]

         offset = states.tensor["node_index"][states.tensor["batch_ptr"][:-1]]
@@ -527,7 +608,7 @@ def __call__(self, states: GraphStates) -> torch.Tensor:
     return self.preprocess(states)


-def render_states(states: GraphStates, state_evaluator: callable):
+def render_states(states: GraphStates, state_evaluator: callable, directed: bool):
     """Render the states as a matplotlib plot.
Args: @@ -555,6 +636,7 @@ def render_states(states: GraphStates, state_evaluator: callable): edge_index = torch.where( edge_index[..., None] == states[i].tensor["node_index"] )[2].reshape(edge_index.shape) + for edge in edge_index: start_x, start_y = xs[edge[0]], ys[edge[0]] end_x, end_y = xs[edge[1]], ys[edge[1]] @@ -564,11 +646,28 @@ def render_states(states: GraphStates, state_evaluator: callable): dx, dy = dx / length, dy / length circle_radius = 0.5 + head_thickness = 0.2 + start_x += dx * circle_radius start_y += dy * circle_radius - end_x -= dx * circle_radius - end_y -= dy * circle_radius - current_ax.plot([start_x, end_x], [start_y, end_y], color="black") + if directed: + end_x -= dx * circle_radius + end_y -= dy * circle_radius + current_ax.arrow( + start_x, + start_y, + end_x - start_x, + end_y - start_y, + head_width=head_thickness, + head_length=head_thickness, + fc="black", + ec="black", + ) + + else: + end_x -= dx * (circle_radius + head_thickness) + end_y -= dy * (circle_radius + head_thickness) + current_ax.plot([start_x, end_x], [start_y, end_y], color="black") current_ax.set_title(f"State {i}, $r={rewards[i]:.2f}$") current_ax.set_xlim(-(radius + 1), radius + 1) @@ -582,14 +681,14 @@ def render_states(states: GraphStates, state_evaluator: callable): if __name__ == "__main__": N_NODES = 3 - N_ITERATIONS = 1000 + N_ITERATIONS = 500 LR = 0.05 BATCH_SIZE = 128 DIRECTED = True state_evaluator = undirected_reward if not DIRECTED else directed_reward torch.random.manual_seed(7) - env = RingGraphBuilding(n_nodes=N_NODES, state_evaluator=state_evaluator) + env = RingGraphBuilding(n_nodes=N_NODES, state_evaluator=state_evaluator, directed=DIRECTED) module_pf = RingPolicyModule(env.n_nodes, DIRECTED) module_pb = RingPolicyModule(env.n_nodes, DIRECTED, is_backward=True) pf = DiscretePolicyEstimator( @@ -632,4 +731,4 @@ def render_states(states: GraphStates, state_evaluator: callable): print("Time:", t2 - t1) last_states = trajectories.last_states[:8] assert isinstance(last_states, GraphStates) - render_states(last_states, state_evaluator) + render_states(last_states, state_evaluator, DIRECTED) From 74d9b136d80ee51850f81142a26c148858a37942 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 19 Feb 2025 08:22:47 -0500 Subject: [PATCH 069/102] black --- tutorials/examples/train_graph_ring.py | 56 ++++++++++++++++++-------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 034f91a2..3438e057 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -352,17 +352,21 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: # Undirected. if self.is_directed: - i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) # Upper triangle. - i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) # Lower triangle. + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. # Combine them i0 = torch.cat([i_up, i_lo]) i1 = torch.cat([j_up, j_lo]) - out_size = self.n_nodes ** 2 - self.n_nodes + out_size = self.n_nodes**2 - self.n_nodes else: i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) - out_size = (self.n_nodes ** 2 - self.n_nodes) // 2 + out_size = (self.n_nodes**2 - self.n_nodes) // 2 # Grab the needed elems from the adjacency matrix and reshape. 
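# [Editorial aside, not part of the patch: the pairwise scores consumed here
# come from the einsum "bnf,bmf->bnm" in forward() (patch 068). It contracts
# the feature axis, so entry (b, n, m) is the dot product of node n's and
# node m's embeddings, one n_nodes x n_nodes logit matrix per graph,
# equivalent to a batched matmul:]
import torch

f = torch.randn(2, 4, 8)  # (batch, nodes, features)
scores = torch.einsum("bnf,bmf->bnm", f, f)
assert torch.allclose(scores, f @ f.transpose(1, 2))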
batch_arange = torch.arange(batch_size) @@ -391,10 +395,10 @@ def __init__(self, n_nodes: int, state_evaluator: callable, directed: bool): self.n_nodes = n_nodes if directed: # all off-diagonal edges + exit. - self.n_actions = (n_nodes ** 2 - n_nodes) + 1 + self.n_actions = (n_nodes**2 - n_nodes) + 1 else: # bottom triangle + exit. - self.n_actions = ((n_nodes ** 2 - n_nodes) // 2) + 1 + self.n_actions = ((n_nodes**2 - n_nodes) // 2) + 1 super().__init__(feature_dim=n_nodes, state_evaluator=state_evaluator) self.is_discrete = True # actions here are discrete, needed for FlowMatching self.is_directed = directed @@ -445,8 +449,12 @@ def forward_masks(self): forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) if env.is_directed: - i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) # Upper triangle. - i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) # Lower triangle. + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. # Combine them ei0 = torch.cat([i_up, i_lo]) @@ -484,7 +492,9 @@ def forward_masks(self): # Adds an unmasked exit action. edge_idx = torch.cat((edge_idx, torch.BoolTensor([False]))) - forward_masks[i, edge_idx] = False # Disallow the addition of this edge. + forward_masks[i, edge_idx] = ( + False # Disallow the addition of this edge. + ) return forward_masks.view(*self.batch_shape, self.n_actions) @@ -506,14 +516,20 @@ def backward_masks(self): ) if env.is_directed: - i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) # Upper triangle. - i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) # Lower triangle. + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. # Combine them ei0 = torch.cat([i_up, i_lo]) ei1 = torch.cat([j_up, j_lo]) else: - ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + ei0, ei1 = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) if len(existing_edges) == 0: edge_idx = torch.zeros(0, dtype=torch.bool) @@ -526,7 +542,9 @@ def backward_masks(self): if len(edge_idx.shape) == 2: edge_idx = edge_idx.sum(1).bool() - backward_masks[i, edge_idx] = True # Allow the removal of this edge. + backward_masks[i, edge_idx] = ( + True # Allow the removal of this edge. + ) return backward_masks.view(*self.batch_shape, self.n_actions - 1) @@ -560,8 +578,12 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions # TODO: factor out into utility function. if self.is_directed: - i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) # Upper triangle. - i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) # Lower triangle. + i_up, j_up = torch.triu_indices( + self.n_nodes, self.n_nodes, offset=1 + ) # Upper triangle. + i_lo, j_lo = torch.tril_indices( + self.n_nodes, self.n_nodes, offset=-1 + ) # Lower triangle. 
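# [Editorial aside, not part of the patch: ahead of the convert_actions hunk
# below, recall the sentinel trick from patch 068. Two extra rows (-1 for
# EXIT, -2 for DUMMY) are appended to the pair-lookup tables so a single
# gather maps every discrete action id, edge or special, in one shot:]
import torch

n = 3
ei0, ei1 = torch.triu_indices(n, n, offset=1)  # 3 undirected edge actions
ei0 = torch.cat((ei0, torch.tensor([-1, -2])))
ei1 = torch.cat((ei1, torch.tensor([-1, -2])))
a = torch.tensor([0, 3])  # edge action 0 and exit (id n_actions - 1 = 3)
print(ei0[a], ei1[a])  # tensor([ 0, -1]) tensor([ 1, -1])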
# Combine them ei0 = torch.cat([i_up, i_lo]) @@ -688,7 +710,9 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool state_evaluator = undirected_reward if not DIRECTED else directed_reward torch.random.manual_seed(7) - env = RingGraphBuilding(n_nodes=N_NODES, state_evaluator=state_evaluator, directed=DIRECTED) + env = RingGraphBuilding( + n_nodes=N_NODES, state_evaluator=state_evaluator, directed=DIRECTED + ) module_pf = RingPolicyModule(env.n_nodes, DIRECTED) module_pb = RingPolicyModule(env.n_nodes, DIRECTED, is_backward=True) pf = DiscretePolicyEstimator( From 9b7ce68d053e6902b398b42a0b0a4139f47da597 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Wed, 19 Feb 2025 17:29:20 +0400 Subject: [PATCH 070/102] change docstring and fix argument --- tutorials/examples/train_graph_ring.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index aca60f20..60c2f20b 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -1,4 +1,14 @@ -"""Write ane xamples where we want to create graphs that are rings.""" +"""Train a GFlowNet to generate ring graphs. + +This example demonstrates training a GFlowNet to generate graphs that are rings - where each vertex +has exactly two neighbors and the edges form a single cycle containing all vertices. The environment +supports both directed and undirected ring generation. + +Key components: +- RingGraphBuilding: Environment for building ring graphs +- RingPolicyModule: GNN-based policy network for predicting actions +- directed_reward/undirected_reward: Reward functions for validating ring structures +""" import math import time @@ -50,7 +60,7 @@ def directed_reward(states: GraphStates) -> torch.Tensor: adj_matrix = torch.zeros(n_nodes, n_nodes) adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 - if not torch.all(adj_matrix.sum(axis=1) == 1): + if not torch.all(adj_matrix.sum(dim=1) == 1): continue visited = [] From 704cf6615f8a9226d6d155f3fcf02d1598cd85e7 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Wed, 19 Feb 2025 17:29:42 +0400 Subject: [PATCH 071/102] add the possibility of layernorm in MLP --- src/gfn/utils/modules.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index d8106783..d13028d7 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -17,6 +17,7 @@ def __init__( n_hidden_layers: Optional[int] = 2, activation_fn: Optional[Literal["relu", "tanh", "elu"]] = "relu", trunk: Optional[nn.Module] = None, + add_layer_norm: bool = False, ): """Instantiates a MLP instance. @@ -28,6 +29,8 @@ def __init__( activation_fn: Activation function. trunk: If provided, this module will be used as the trunk of the network (i.e. all layers except last layer). + add_layer_norm: If True, add a LayerNorm after each linear layer. 
+ (incompatible with `trunk` argument) """ super().__init__() self._input_dim = input_dim @@ -44,9 +47,18 @@ def __init__( activation = nn.ReLU elif activation_fn == "tanh": activation = nn.Tanh - arch = [nn.Linear(input_dim, hidden_dim), activation()] + if add_layer_norm: + arch = [ + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + activation(), + ] + else: + arch = [nn.Linear(input_dim, hidden_dim), activation()] for _ in range(n_hidden_layers - 1): arch.append(nn.Linear(hidden_dim, hidden_dim)) + if add_layer_norm: + arch.append(nn.LayerNorm(hidden_dim)) arch.append(activation()) self.trunk = nn.Sequential(*arch) self.trunk.hidden_dim = hidden_dim From 7bbd72dc7ffc188e56f6e20e2fc3d06693f01457 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Wed, 19 Feb 2025 17:31:04 +0400 Subject: [PATCH 072/102] remove create_mlp function --- tutorials/examples/train_graph_ring.py | 55 +++++++------------------- 1 file changed, 15 insertions(+), 40 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 60c2f20b..c54f5101 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -27,6 +27,7 @@ from gfn.modules import DiscretePolicyEstimator from gfn.preprocessors import Preprocessor from gfn.states import GraphStates +from gfn.utils.modules import MLP def directed_reward(states: GraphStates) -> torch.Tensor: @@ -161,38 +162,6 @@ def undirected_reward(states: GraphStates) -> torch.Tensor: return out.view(*states.batch_shape) -def create_mlp(in_channels, hidden_channels, out_channels, num_layers=1): - """ - Create a Multi-Layer Perceptron with configurable number of layers. - - Args: - in_channels (int): Number of input features - hidden_channels (int): Number of hidden features per layer - out_channels (int): Number of output features - num_layers (int): Number of hidden layers (default: 2) - - Returns: - nn.Sequential: MLP model - """ - layers = [] - - # Input layer - layers.append(nn.Linear(in_channels, hidden_channels)) - layers.append(nn.LayerNorm(hidden_channels)) - layers.append(nn.ReLU()) - - # Hidden layers - for _ in range(num_layers - 1): - layers.append(nn.Linear(hidden_channels, hidden_channels)) - layers.append(nn.LayerNorm(hidden_channels)) - layers.append(nn.ReLU()) - - # Output layer - layers.append(nn.Linear(hidden_channels, out_channels)) - - return nn.Sequential(*layers) - - class RingPolicyModule(nn.Module): """Simple module which outputs a fixed logits for the actions, depending on the number of edges. @@ -222,7 +191,6 @@ def __init__( if directed: # Multiple action type convolution layers. for i in range(num_conv_layers): - # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) self.action_conv_blks.extend( [ GCNConv( @@ -237,7 +205,6 @@ def __init__( # Multiple edge index convolution layers. 
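# [Editorial aside, not part of the patch: this patch swaps the removed
# create_mlp helper for gfn.utils.modules.MLP, using the add_layer_norm flag
# introduced in patch 071. A hedged sketch of the GINConv/MLP pairing as
# wired in the undirected branch below:]
from torch_geometric.nn import GINConv

from gfn.utils.modules import MLP

conv = GINConv(
    MLP(
        input_dim=64,
        output_dim=64,
        hidden_dim=64,
        n_hidden_layers=1,
        add_layer_norm=True,
    )
)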
for i in range(num_conv_layers): - # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) self.edge_conv_blks.extend( [ GCNConv( @@ -255,9 +222,13 @@ def __init__( self.action_conv_blks.extend( [ GINConv( - create_mlp( - self.embedding_dim, self.hidden_dim, self.hidden_dim - ) + MLP( + input_dim=self.embedding_dim, + output_dim=self.hidden_dim, + hidden_dim=self.hidden_dim, + n_hidden_layers=1, + add_layer_norm=True, + ), ), nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(), @@ -270,9 +241,13 @@ def __init__( self.edge_conv_blks.extend( [ GINConv( - create_mlp( - self.embedding_dim, self.hidden_dim, self.hidden_dim - ) + MLP( + input_dim=self.embedding_dim, + output_dim=self.hidden_dim, + hidden_dim=self.hidden_dim, + n_hidden_layers=1, + add_layer_norm=True, + ), ), nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(), From 7bdf5b5946c0fb1ea9d65d81b8f4967e15abcd52 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 19 Feb 2025 11:46:17 -0500 Subject: [PATCH 073/102] new architecture --- src/gfn/samplers.py | 1 + tutorials/examples/train_graph_ring.py | 117 ++++++++++++------------- 2 files changed, 59 insertions(+), 59 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index c20423f4..5a7e9066 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -75,6 +75,7 @@ def sample_actions( with no_conditioning_exception_handler("estimator", self.estimator): estimator_output = self.estimator(states) + print("estimator_output={}".format(estimator_output[-1])) dist = self.estimator.to_probability_distribution( states, estimator_output, **policy_kwargs ) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 3438e057..210c2f70 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -206,29 +206,30 @@ def __init__( # Node embedding layer. self.embedding = nn.Embedding(n_nodes, self.embedding_dim) - self.action_conv_blks = nn.ModuleList() - self.edge_conv_blks = nn.ModuleList() + # self.action_conv_blks = nn.ModuleList() + self.conv_blks = nn.ModuleList() + self.exit_mlp = create_mlp(self.hidden_dim, self.hidden_dim, 1) if directed: # Multiple action type convolution layers. - for i in range(num_conv_layers): - # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) - self.action_conv_blks.extend( - [ - GCNConv( - self.embedding_dim if i == 0 else self.hidden_dim, - self.hidden_dim, - ), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim), - ] - ) + # for i in range(num_conv_layers): + # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) + # self.action_conv_blks.extend( + # [ + # GCNConv( + # self.embedding_dim if i == 0 else self.hidden_dim, + # self.hidden_dim, + # ), + # nn.Linear(self.hidden_dim, self.hidden_dim), + # nn.ReLU(), + # nn.Linear(self.hidden_dim, self.hidden_dim), + # ] + # ) # Multiple edge index convolution layers. for i in range(num_conv_layers): # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) - self.edge_conv_blks.extend( + self.conv_blks.extend( [ GCNConv( self.embedding_dim if i == 0 else self.hidden_dim, @@ -241,23 +242,23 @@ def __init__( ) else: # Undirected case. # Multiple action type convolution layers. 
- for _ in range(num_conv_layers): - self.action_conv_blks.extend( - [ - GINConv( - create_mlp( - self.embedding_dim, self.hidden_dim, self.hidden_dim - ) - ), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim), - ] - ) + # for _ in range(num_conv_layers): + # self.action_conv_blks.extend( + # [ + # GINConv( + # create_mlp( + # self.embedding_dim, self.hidden_dim, self.hidden_dim + # ) + # ), + # nn.Linear(self.hidden_dim, self.hidden_dim), + # nn.ReLU(), + # nn.Linear(self.hidden_dim, self.hidden_dim), + # ] + # ) # Multiple edge index convolution layers. for _ in range(num_conv_layers): - self.edge_conv_blks.extend( + self.conv_blks.extend( [ GINConv( create_mlp( @@ -271,7 +272,7 @@ def __init__( ) # Layer normalization for stability - self.action_norm = nn.LayerNorm(self.hidden_dim) + # self.action_norm = nn.LayerNorm(self.hidden_dim) self.edge_norm = nn.LayerNorm(self.hidden_dim) def _group_mean( @@ -289,11 +290,10 @@ def _group_mean( return (cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]) / size[:, None] def forward(self, states_tensor: TensorDict) -> torch.Tensor: - node_feature, batch_ptr = ( + node_features, batch_ptr = ( states_tensor["node_feature"], states_tensor["batch_ptr"], ) - x = self.embedding(node_feature.squeeze().int()) batch_size = int(torch.prod(states_tensor["batch_shape"])) edge_index = torch.where( @@ -304,45 +304,44 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: # edge_attrs = states_tensor["edge_feature"] # Multiple action type convolutions with residual connections. - action_type = x - for i in range(0, len(self.action_conv_blks), 4): - - # GIN/GCN conv. - action_type_new = self.action_conv_blks[i](action_type, edge_index.T) - # First linear. - action_type_new = self.action_conv_blks[i + 1](action_type_new) - # ReLU. - action_type_new = self.action_conv_blks[i + 2](action_type_new) - # Second linear. - action_type_new = self.action_conv_blks[i + 3](action_type_new) - # Residual connection with original input. - action_type = action_type_new + action_type - action_type = self.action_norm(action_type) - - action_type = self._group_mean( - torch.mean(action_type, dim=-1, keepdim=True), batch_ptr - ) + # for i in range(0, len(self.action_conv_blks), 4): + + # # GIN/GCN conv. + # action_type_new = self.action_conv_blks[i](action_type, edge_index.T) + # # First linear. + # action_type_new = self.action_conv_blks[i + 1](action_type_new) + # # ReLU. + # action_type_new = self.action_conv_blks[i + 2](action_type_new) + # # Second linear. + # action_type_new = self.action_conv_blks[i + 3](action_type_new) + # # Residual connection with original input. + # action_type = action_type_new + action_type + # action_type = self.action_norm(action_type) # Multiple action type convolutions with residual connections. - edge_feature = x - for i in range(0, len(self.edge_conv_blks), 4): + node_features = self.embedding(node_features.squeeze().int()) + for i in range(0, len(self.conv_blks), 4): # GIN/GCN conv. - edge_feature_new = self.edge_conv_blks[i](edge_feature, edge_index.T) + node_feature_new = self.conv_blks[i](node_features, edge_index.T) # First linear. - edge_feature_new = self.edge_conv_blks[i + 1](edge_feature_new) + node_feature_new = self.conv_blks[i + 1](node_feature_new) # ReLU. - edge_feature_new = self.edge_conv_blks[i + 2](edge_feature_new) + node_feature_new = self.conv_blks[i + 2](node_feature_new) # Second linear. 
- edge_feature_new = self.edge_conv_blks[i + 3](edge_feature_new) + node_feature_new = self.conv_blks[i + 3](node_feature_new) # Residual connection with original input. - edge_feature = edge_feature_new + edge_feature + edge_feature = node_feature_new + node_features edge_feature = self.edge_norm(edge_feature) # edge_feature = self._group_mean( # torch.mean(edge_feature, dim=-1, keepdim=True), batch_ptr # ) + # TODO: MLP from here to exit_action. + edge_feature_means = self._group_mean(edge_feature, batch_ptr) + exit_action = self.exit_mlp(edge_feature_means) + edge_feature = edge_feature.reshape( *states_tensor["batch_shape"], self.n_nodes, self.hidden_dim ) @@ -376,7 +375,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: if self.is_backward: return edge_actions else: - return torch.cat([edge_actions, action_type], dim=-1) + return torch.cat([edge_actions, exit_action], dim=-1) class RingGraphBuilding(GraphBuilding): From af26f6e34f23701c9b7eafd252c23837a3c9543f Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 19 Feb 2025 11:52:00 -0500 Subject: [PATCH 074/102] MLP changes --- tutorials/examples/train_graph_ring.py | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 0e725bb0..024b55df 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -240,23 +240,23 @@ def __init__( # nn.Linear(self.hidden_dim, self.hidden_dim), # ] # ) - for _ in range(num_conv_layers): - self.conv_blks.extend( - [ - GINConv( - MLP( - input_dim=self.embedding_dim, - output_dim=self.hidden_dim, - hidden_dim=self.hidden_dim, - n_hidden_layers=1, - add_layer_norm=True, - ), - ), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim), - ] - ) + # for _ in range(num_conv_layers): + # self.conv_blks.extend( + # [ + # GINConv( + # MLP( + # input_dim=self.embedding_dim, + # output_dim=self.hidden_dim, + # hidden_dim=self.hidden_dim, + # n_hidden_layers=1, + # add_layer_norm=True, + # ), + # ), + # nn.Linear(self.hidden_dim, self.hidden_dim), + # nn.ReLU(), + # nn.Linear(self.hidden_dim, self.hidden_dim), + # ] + # ) # Multiple edge index convolution layers. for _ in range(num_conv_layers): From a6ffc728ec7d6493295b8701b95915df3370078e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 20 Feb 2025 00:38:26 +0100 Subject: [PATCH 075/102] fix magnitude issue --- tutorials/examples/train_graph_ring.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 024b55df..b5a51381 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -354,6 +354,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: # This is n_nodes ** 2, for each graph. edge_index = torch.einsum("bnf,bmf->bnm", edge_feature, edge_feature) + edge_index = edge_index / torch.sqrt(torch.tensor(self.hidden_dim)) # Undirected. 
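# [Editorial aside, not part of the patch: on the magnitude fix just above.
# Dot products of d-dimensional, roughly unit-variance embeddings have a
# standard deviation of about sqrt(d), so dividing by sqrt(hidden_dim) keeps
# the pairwise logits in a stable range, the same normalization used in
# scaled dot-product attention. Quick numerical check:]
import torch

d = 64
a, b = torch.randn(100_000, d), torch.randn(100_000, d)
print((a * b).sum(-1).std())  # approximately sqrt(d) = 8
print(((a * b).sum(-1) / d**0.5).std())  # approximately 1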
if self.is_directed: From 8ba06a17d46e3e0abecf1594c05b4360cfa25ee7 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 19 Feb 2025 20:01:24 -0500 Subject: [PATCH 076/102] changed docs --- tutorials/examples/train_graph_ring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 024b55df..053a3fdc 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -344,7 +344,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: # torch.mean(edge_feature, dim=-1, keepdim=True), batch_ptr # ) - # TODO: MLP from here to exit_action. + # This MLP computes the exit action. edge_feature_means = self._group_mean(edge_feature, batch_ptr) exit_action = self.exit_mlp(edge_feature_means) From a2d44d53b6519cda3c52b020a4956c14a064b6ea Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 20 Feb 2025 00:58:22 -0500 Subject: [PATCH 077/102] some bugfixes in extend of trajectories --- src/gfn/containers/trajectories.py | 2 +- src/gfn/states.py | 21 ++++++++++++++------- tutorials/examples/train_graph_ring.py | 22 ++++++++++++++++------ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 2e7730d1..fe8a662b 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -256,7 +256,7 @@ def extend(self, other: Trajectories) -> None: # TODO: The replay buffer is storing `dones` - this wastes a lot of space. self.actions.extend(other.actions) - self.states.extend(other.states) + self.states.extend(other.states) # n_trajectories comes from this. self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0) # For log_probs, we first need to make the first dimensions of self.log_probs diff --git a/src/gfn/states.py b/src/gfn/states.py index 386ed5d6..dc82b330 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -534,6 +534,7 @@ def __init__(self, tensor: TensorDict): self.edge_features_dim = tensor["edge_feature"].shape[-1] self._log_rewards: Optional[torch.Tensor] = None + # TODO: self.tensor["batch_shape"] is set wrong. @property def batch_shape(self) -> tuple: return tuple(self.tensor["batch_shape"].tolist()) @@ -633,6 +634,7 @@ def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: tensor_idx = torch.arange(len(self)).view(*self.batch_shape) + new_shape = tensor_idx[index].shape idx = tensor_idx[index].flatten() if torch.any(idx >= len(self.tensor["batch_ptr"]) - 1): @@ -672,7 +674,7 @@ def __getitem__( "edge_feature": torch.cat(edge_features), "edge_index": torch.cat(edge_indices), "batch_ptr": torch.tensor(batch_ptr), - "batch_shape": (len(idx),), + "batch_shape": (tuple(new_shape)), # TODO: this shouldn't change from len(2) to len(1). } ) ) @@ -850,12 +852,17 @@ def extend(self, other: GraphStates): ], dim=0, ) - assert torch.all( - self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:] - ) - self.tensor["batch_shape"] = ( - self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], - ) + self.batch_shape[1:] + + #self.tensor["batch_shape"] = self.tensor["batch_shape"] + other.tensor["batch_shape"] + # If self.tensor is a placeholder and all batch_dims are 0, this check won't pass. 
+ if not torch.all(self.tensor["batch_shape"] == 0): + assert torch.all( + self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:] + ) + # self.tensor["batch_shape"] = ( + # self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], + # ) + self.batch_shape[1:] + self.tensor["batch_shape"] = self.tensor["batch_shape"] + other.tensor["batch_shape"] @property def log_rewards(self) -> torch.Tensor: diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 3042a473..d9106e60 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -28,6 +28,7 @@ from gfn.preprocessors import Preprocessor from gfn.states import GraphStates from gfn.utils.modules import MLP +from gfn.containers import PrioritizedReplayBuffer, ReplayBuffer def directed_reward(states: GraphStates) -> torch.Tensor: @@ -708,11 +709,11 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool if __name__ == "__main__": - N_NODES = 3 + N_NODES = 6 N_ITERATIONS = 500 LR = 0.05 BATCH_SIZE = 128 - DIRECTED = True + DIRECTED = False state_evaluator = undirected_reward if not DIRECTED else directed_reward torch.random.manual_seed(7) @@ -733,6 +734,10 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool gflownet = TBGFlowNet(pf, pb) optimizer = torch.optim.Adam(gflownet.parameters(), lr=LR) + replay_buffer = ReplayBuffer( + env, objects_type="trajectories", capacity=1000, + ) + losses = [] t1 = time.time() @@ -742,16 +747,21 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool ) last_states = trajectories.last_states assert isinstance(last_states, GraphStates) - rews = state_evaluator(last_states) - samples = gflownet.to_training_samples(trajectories) + rewards = state_evaluator(last_states) + training_samples = gflownet.to_training_samples(trajectories) + + with torch.no_grad(): + replay_buffer.add(training_samples) + training_objects = replay_buffer.sample(n_trajectories=BATCH_SIZE) + optimizer.zero_grad() - loss = gflownet.loss(env, samples) # pyright: ignore + loss = gflownet.loss(env, training_objects) # pyright: ignore print( "Iteration", iteration, "Loss:", loss.item(), - f"rings: {torch.mean(rews > 0.1, dtype=torch.float) * 100:.0f}%", + f"rings: {torch.mean(rewards > 0.1, dtype=torch.float) * 100:.0f}%", ) loss.backward() optimizer.step() From bf92366e48e516f63dae144b6a7e682ba80e122a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 20 Feb 2025 00:58:44 -0500 Subject: [PATCH 078/102] black --- tutorials/examples/train_graph_ring.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index d9106e60..fd6b759b 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -28,7 +28,7 @@ from gfn.preprocessors import Preprocessor from gfn.states import GraphStates from gfn.utils.modules import MLP -from gfn.containers import PrioritizedReplayBuffer, ReplayBuffer +from gfn.containers import ReplayBuffer def directed_reward(states: GraphStates) -> torch.Tensor: @@ -735,7 +735,9 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool optimizer = torch.optim.Adam(gflownet.parameters(), lr=LR) replay_buffer = ReplayBuffer( - env, objects_type="trajectories", capacity=1000, + env, + objects_type="trajectories", + capacity=1000, ) losses = [] From cfd474040450c04d19bc695dd4673d52c6e0e452 
Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 20 Feb 2025 00:59:07 -0500 Subject: [PATCH 079/102] black --- src/gfn/states.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index dc82b330..3d84b2d4 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -674,7 +674,9 @@ def __getitem__( "edge_feature": torch.cat(edge_features), "edge_index": torch.cat(edge_indices), "batch_ptr": torch.tensor(batch_ptr), - "batch_shape": (tuple(new_shape)), # TODO: this shouldn't change from len(2) to len(1). + "batch_shape": ( + tuple(new_shape) + ), # TODO: this shouldn't change from len(2) to len(1). } ) ) @@ -853,7 +855,7 @@ def extend(self, other: GraphStates): dim=0, ) - #self.tensor["batch_shape"] = self.tensor["batch_shape"] + other.tensor["batch_shape"] + # self.tensor["batch_shape"] = self.tensor["batch_shape"] + other.tensor["batch_shape"] # If self.tensor is a placeholder and all batch_dims are 0, this check won't pass. if not torch.all(self.tensor["batch_shape"] == 0): assert torch.all( @@ -862,7 +864,9 @@ def extend(self, other: GraphStates): # self.tensor["batch_shape"] = ( # self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], # ) + self.batch_shape[1:] - self.tensor["batch_shape"] = self.tensor["batch_shape"] + other.tensor["batch_shape"] + self.tensor["batch_shape"] = ( + self.tensor["batch_shape"] + other.tensor["batch_shape"] + ) @property def log_rewards(self) -> torch.Tensor: From ab40005bb1caf5ac73b7f3b15bc7fbd25707ffe6 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 20 Feb 2025 00:59:29 -0500 Subject: [PATCH 080/102] removed print --- src/gfn/samplers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 5a7e9066..c20423f4 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -75,7 +75,6 @@ def sample_actions( with no_conditioning_exception_handler("estimator", self.estimator): estimator_output = self.estimator(states) - print("estimator_output={}".format(estimator_output[-1])) dist = self.estimator.to_probability_distribution( states, estimator_output, **policy_kwargs ) From 8c8304982a73ecc84e5ff4fb5c2d1d0f8b5948a8 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Thu, 20 Feb 2025 19:28:26 +0400 Subject: [PATCH 081/102] change types in env and states to allow for graph states --- src/gfn/env.py | 43 +++++++---------- src/gfn/gym/graph_building.py | 1 + src/gfn/states.py | 88 ++++++++++++++++++++++++----------- 3 files changed, 79 insertions(+), 53 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 21d401ff..cdc84ca4 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -13,7 +13,7 @@ NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) -def get_device(device_str, default_device): +def get_device(device_str, default_device) -> torch.device: return torch.device(device_str) if device_str is not None else default_device @@ -23,12 +23,12 @@ class Env(ABC): def __init__( self, - s0: torch.Tensor, + s0: torch.Tensor | TensorDict, state_shape: Tuple, action_shape: Tuple, dummy_action: torch.Tensor, exit_action: torch.Tensor, - sf: Optional[torch.Tensor] = None, + sf: Optional[torch.Tensor | TensorDict] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -55,7 +55,7 @@ def __init__( assert s0.shape == state_shape if sf is None: sf = torch.full(s0.shape, -float("inf")).to(self.device) - self.sf: torch.Tensor = sf + self.sf = sf assert self.sf.shape == state_shape 
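# [Editorial aside, not part of the patch: with s0/sf widened to
# torch.Tensor | TensorDict, tensor-only code paths now branch on the
# concrete type, as make_initial_states_tensor does later in this patch. A
# minimal illustration of the dispatch pattern, assuming only torch and
# tensordict:]
import torch
from tensordict import TensorDict

def repeat_state(s, batch_shape: tuple, state_ndim: int):
    if isinstance(s, TensorDict):
        raise NotImplementedError("TensorDict (graph) states override this path.")
    return s.repeat(*batch_shape, *((1,) * state_ndim))

assert repeat_state(torch.zeros(2), (3,), 1).shape == (3, 2)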
self.state_shape = state_shape self.action_shape = action_shape @@ -263,7 +263,9 @@ def _step( # Set to the sink state when the action is exit. new_sink_states_idx = actions.is_exit - sf_tensor = self.States.make_sink_states_tensor((new_sink_states_idx.sum(),)) + sf_tensor = self.States.make_sink_states_tensor( + (int(new_sink_states_idx.sum().item()),) + ) new_states[new_sink_states_idx] = self.States(sf_tensor) new_sink_states_idx = ~valid_states_idx | new_sink_states_idx assert new_sink_states_idx.shape == states.batch_shape @@ -361,6 +363,9 @@ class DiscreteEnv(Env, ABC): via mask tensors, that are directly attached to `States` objects. """ + s0: torch.Tensor # this tells the type checker that s0 is a torch.Tensor + sf: torch.Tensor # this tells the type checker that sf is a torch.Tensor + def __init__( self, n_actions: int, @@ -510,28 +515,14 @@ def _step(self, states: DiscreteStates, actions: Actions) -> States: return new_states def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """Returns the indices of the states in the environment. - - Args: - states: The batch of states. - - Returns: - torch.Tensor: Tensor of shape "batch_shape" containing the indices of the states. - """ - return NotImplementedError( + """Returns the indices of the states in the environment.""" + raise NotImplementedError( "The environment does not support enumeration of states" ) def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """Returns the indices of the terminating states in the environment. - - Args: - states: The batch of states. - - Returns: - torch.Tensor: Tensor of shape "batch_shape" containing the indices of the terminating states. - """ - return NotImplementedError( + """Returns the indices of the terminating states in the environment.""" + raise NotImplementedError( "The environment does not support enumeration of states" ) @@ -580,10 +571,12 @@ def terminating_states(self) -> DiscreteStates: class GraphEnv(Env): """Base class for graph-based environments.""" + sf: TensorDict # this tells the type checker that sf is a TensorDict + def __init__( self, s0: TensorDict, - sf: Optional[TensorDict] = None, + sf: TensorDict, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -591,7 +584,7 @@ def __init__( Args: s0: The initial graph state. - sf: The final graph state. + sf: The sink graph state. device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0. 
preprocessor: a Preprocessor object that converts raw graph states to a tensor diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index fbde436b..59126468 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -49,6 +49,7 @@ def __init__( ) self.state_evaluator = state_evaluator + self.feature_dim = feature_dim super().__init__( s0=s0, diff --git a/src/gfn/states.py b/src/gfn/states.py index 386ed5d6..c9665356 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -49,8 +49,10 @@ class States(ABC): """ state_shape: ClassVar[tuple[int, ...]] # Shape of one state - s0: ClassVar[torch.Tensor] # Source state of the DAG - sf: ClassVar[torch.Tensor] # Dummy state, used to pad a batch of states + s0: ClassVar[torch.Tensor | TensorDict] # Source state of the DAG + sf: ClassVar[ + torch.Tensor | TensorDict + ] # Dummy state, used to pad a batch of states make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw( NotImplementedError( "The environment does not support initialization of random states." @@ -108,14 +110,24 @@ def make_initial_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tenso """Makes a tensor with a `batch_shape` of states consisting of $s_0`$s.""" state_ndim = len(cls.state_shape) assert cls.s0 is not None and state_ndim is not None - return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) + if isinstance(cls.s0, torch.Tensor): + return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) + else: + raise NotImplementedError( + "make_initial_states_tensor is not implemented by default for TensorDicts" + ) @classmethod def make_sink_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tensor: """Makes a tensor with a `batch_shape` of states consisting of $s_f$s.""" state_ndim = len(cls.state_shape) assert cls.sf is not None and state_ndim is not None - return cls.sf.repeat(*batch_shape, *((1,) * state_ndim)) + if isinstance(cls.sf, torch.Tensor): + return cls.sf.repeat(*batch_shape, *((1,) * state_ndim)) + else: + raise NotImplementedError( + "make_sink_states_tensor is not implemented by default for TensorDicts" + ) def __len__(self): return prod(self.batch_shape) @@ -139,7 +151,9 @@ def __getitem__( return out def __setitem__( - self, index: int | Sequence[int] | Sequence[bool], states: States + self, + index: int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor, + states: States, ) -> None: """Set particular states of the batch.""" self.tensor[index] = states.tensor @@ -209,7 +223,7 @@ def extend_with_sf(self, required_first_dim: int) -> None: Args: required_first_dim: The size of the first batch dimension post-expansion. 
""" - if len(self.batch_shape) == 2: + if len(self.batch_shape) == 2 and isinstance(self.__class__.sf, torch.Tensor): if self.batch_shape[0] >= required_first_dim: return self.tensor = torch.cat( @@ -224,7 +238,7 @@ def extend_with_sf(self, required_first_dim: int) -> None: self.batch_shape = (required_first_dim, self.batch_shape[1]) else: raise ValueError( - f"extend_with_sf is not implemented for batch shapes {self.batch_shape}" + f"extend_with_sf is not implemented for graph states nor for batch shapes {self.batch_shape}" ) def compare(self, other: torch.Tensor) -> torch.Tensor: @@ -248,22 +262,32 @@ def compare(self, other: torch.Tensor) -> torch.Tensor: @property def is_initial_state(self) -> torch.Tensor: """Returns a tensor of shape `batch_shape` that is True for states that are $s_0$ of the DAG.""" - source_states_tensor = self.__class__.s0.repeat( - *self.batch_shape, *((1,) * len(self.__class__.state_shape)) - ) + if isinstance(self.__class__.s0, torch.Tensor): + source_states_tensor = self.__class__.s0.repeat( + *self.batch_shape, *((1,) * len(self.__class__.state_shape)) + ) + else: + raise NotImplementedError( + "is_initial_state is not implemented by default for TensorDicts" + ) return self.compare(source_states_tensor) @property def is_sink_state(self) -> torch.Tensor: """Returns a tensor of shape `batch_shape` that is True for states that are $s_f$ of the DAG.""" # TODO: self.__class__.sf == self.tensor -- or something similar? - sink_states = self.__class__.sf.repeat( - *self.batch_shape, *((1,) * len(self.__class__.state_shape)) - ).to(self.tensor.device) + if isinstance(self.__class__.sf, torch.Tensor): + sink_states = self.__class__.sf.repeat( + *self.batch_shape, *((1,) * len(self.__class__.state_shape)) + ).to(self.tensor.device) + else: + raise NotImplementedError( + "is_sink_state is not implemented by default for TensorDicts" + ) return self.compare(sink_states) @property - def log_rewards(self) -> torch.Tensor: + def log_rewards(self) -> torch.Tensor | None: """Returns the log rewards of the states as tensor of shape `batch_shape`.""" return self._log_rewards @@ -563,9 +587,11 @@ def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: "node_index": GraphStates.unique_node_indices(nodes.shape[0]), "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) + "batch_ptr": torch.arange( + int(np.prod(batch_shape)) + 1, device=cls.s0.device + ) * cls.s0["node_feature"].shape[0], - "batch_shape": batch_shape, + "batch_shape": torch.tensor(batch_shape, device=cls.s0.device), } ) @@ -583,10 +609,10 @@ def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), "batch_ptr": torch.arange( - np.prod(batch_shape) + 1, device=cls.sf.device + int(np.prod(batch_shape)) + 1, device=cls.sf.device ) * cls.sf["node_feature"].shape[0], - "batch_shape": batch_shape, + "batch_shape": torch.tensor(batch_shape, device=cls.sf.device), } ) return out @@ -603,20 +629,26 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: return TensorDict( { "node_feature": torch.rand( - np.prod(batch_shape) * num_nodes, node_features_dim, device=device + int(np.prod(batch_shape)) * num_nodes, + node_features_dim, + device=device, ), "node_index": 
GraphStates.unique_node_indices( - np.prod(batch_shape) * num_nodes + int(np.prod(batch_shape)) * num_nodes ), "edge_feature": torch.rand( - np.prod(batch_shape) * num_edges, edge_features_dim, device=device + int(np.prod(batch_shape)) * num_edges, + edge_features_dim, + device=device, ), "edge_index": torch.randint( - num_nodes, size=(np.prod(batch_shape) * num_edges, 2), device=device + num_nodes, + size=(int(np.prod(batch_shape)) * num_edges, 2), + device=device, ), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1, device=device) + "batch_ptr": torch.arange(int(np.prod(batch_shape)) + 1, device=device) * num_nodes, - "batch_shape": batch_shape, + "batch_shape": torch.tensor(batch_shape), } ) @@ -671,8 +703,8 @@ def __getitem__( "node_index": torch.cat(node_indices), "edge_feature": torch.cat(edge_features), "edge_index": torch.cat(edge_indices), - "batch_ptr": torch.tensor(batch_ptr), - "batch_shape": (len(idx),), + "batch_ptr": torch.tensor(batch_ptr, device=self.tensor.device), + "batch_shape": torch.tensor(len(idx), device=self.tensor.device), } ) ) @@ -820,7 +852,7 @@ def extend(self, other: GraphStates): # This renumbers nodes across batch indices such that all nodes have # a unique ID. new_indices = GraphStates.unique_node_indices( - torch.sum(common_node_indices) + int(torch.sum(common_node_indices).item()) ) # find edge_index which contains other_node_index[common_node_indices]. this is @@ -858,7 +890,7 @@ def extend(self, other: GraphStates): ) + self.batch_shape[1:] @property - def log_rewards(self) -> torch.Tensor: + def log_rewards(self) -> torch.Tensor | None: return self._log_rewards @log_rewards.setter From e2ac8be836b7a316548b6edb245e279671233b6c Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Thu, 20 Feb 2025 20:48:22 +0400 Subject: [PATCH 082/102] make batch_shape a property in general states, to make explicit the type (tuple), and to allow for inheritance from graphstates --- src/gfn/states.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index f8956770..21102c84 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -69,11 +69,19 @@ def __init__(self, tensor: torch.Tensor): assert tensor.shape[-len(self.state_shape) :] == self.state_shape self.tensor = tensor - self.batch_shape = tuple(self.tensor.shape)[: -len(self.state_shape)] + self._batch_shape = tuple(self.tensor.shape)[: -len(self.state_shape)] self._log_rewards = ( None # Useful attribute if we want to store the log-reward of the states ) + @property + def batch_shape(self) -> tuple[int, ...]: + return self._batch_shape + + @batch_shape.setter + def batch_shape(self, batch_shape: tuple[int, ...]) -> None: + self._batch_shape = batch_shape + @classmethod def from_batch_shape( cls, batch_shape: tuple[int, ...], random: bool = False, sink: bool = False From 1543e49863b75317a91b4f8169580eade74b8ad4 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Thu, 20 Feb 2025 20:49:36 +0400 Subject: [PATCH 083/102] fix batch_shape in GraphStates --- src/gfn/states.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 21102c84..67637eb2 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -44,7 +44,7 @@ class States(ABC): Attributes: tensor: Tensor representing a batch of states. - batch_shape: Sizes of the batch dimensions. + _batch_shape: Sizes of the batch dimensions. _log_rewards: Stores the log rewards of each state. 
""" @@ -566,9 +566,8 @@ def __init__(self, tensor: TensorDict): self.edge_features_dim = tensor["edge_feature"].shape[-1] self._log_rewards: Optional[torch.Tensor] = None - # TODO: self.tensor["batch_shape"] is set wrong. @property - def batch_shape(self) -> tuple: + def batch_shape(self) -> tuple[int, ...]: return tuple(self.tensor["batch_shape"].tolist()) @classmethod @@ -714,9 +713,7 @@ def __getitem__( "edge_feature": torch.cat(edge_features), "edge_index": torch.cat(edge_indices), "batch_ptr": torch.tensor(batch_ptr, device=self.tensor.device), - "batch_shape": torch.tensor( - len(idx), device=self.tensor.device - ), # TODO: this shouldn't change from len(2) to len(1). + "batch_shape": torch.tensor(new_shape, device=self.tensor.device), } ) ) From 5dd5abaa4cc63e48208a9a32e14ba8d69660df02 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Thu, 20 Feb 2025 20:50:05 +0400 Subject: [PATCH 084/102] fix batch_shape in GraphRing --- tutorials/examples/train_graph_ring.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index fd6b759b..f0dc9209 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -350,7 +350,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: exit_action = self.exit_mlp(edge_feature_means) edge_feature = edge_feature.reshape( - *states_tensor["batch_shape"], self.n_nodes, self.hidden_dim + states_tensor["batch_shape"].item(), self.n_nodes, self.hidden_dim ) # This is n_nodes ** 2, for each graph. @@ -378,7 +378,9 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: # Grab the needed elems from the adjacency matrix and reshape. batch_arange = torch.arange(batch_size) edge_actions = edge_index[batch_arange[:, None, None], i0, i1] - edge_actions = edge_actions.reshape(*states_tensor["batch_shape"], out_size) + edge_actions = edge_actions.reshape( + states_tensor["batch_shape"].item(), out_size + ) if self.is_backward: return edge_actions From cda28e3c9966f6c0cc861c65a4a4594518d9b6b4 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 20 Feb 2025 14:58:13 -0500 Subject: [PATCH 085/102] removed unused code from forward_masks --- src/gfn/states.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 67637eb2..42cee279 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -987,13 +987,13 @@ def forward_masks(self) -> TensorDict: same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & ( arange_nodes < self.tensor["batch_ptr"][1:, None] ) - edge_index = torch.where( - self.tensor["edge_index"][..., None] == self.tensor["node_index"] - )[2].reshape(self.tensor["edge_index"].shape) - i, j = edge_index[..., 0], edge_index[..., 1] + # edge_index = torch.where( + # self.tensor["edge_index"][..., None] == self.tensor["node_index"] + # )[2].reshape(self.tensor["edge_index"].shape) + # i, j = edge_index[..., 0], edge_index[..., 1] - for _ in range(len(self.batch_shape)): - (i, j) = ei1.unsqueeze(0), ei2.unsqueeze(0) + # for _ in range(len(self.batch_shape)): + # (i, j) = ei1.unsqueeze(0), ei2.unsqueeze(0) # First allow nodes in the same graph to connect, then disable nodes with existing edges forward_masks["edge_index"][ From 9ba30041bfcc907de57c445aa68a8e4f0b23ba1c Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 20 Feb 2025 23:09:28 +0100 Subject: [PATCH 086/102] fix extends --- src/gfn/states.py | 106 
+++++++++++++++++-------- tutorials/examples/train_graph_ring.py | 4 +- 2 files changed, 77 insertions(+), 33 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 42cee279..c943b0d3 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -850,7 +850,7 @@ def extend(self, other: GraphStates): self.tensor["node_feature"] = torch.cat( [self.tensor["node_feature"], other.tensor["node_feature"]], dim=0 ) - + # find if there are common node indices other_node_index = other.tensor["node_index"] other_edge_index = other.tensor["edge_index"] @@ -873,37 +873,81 @@ def extend(self, other: GraphStates): repeat_indices = new_indices[None, None].repeat(edge_mask.shape[0], 2, 1) other_edge_index[torch.any(edge_mask, dim=-1)] = repeat_indices[edge_mask] other_node_index[common_node_indices] = new_indices - - self.tensor["node_index"] = torch.cat( - [self.tensor["node_index"], other_node_index], dim=0 - ) - self.tensor["edge_feature"] = torch.cat( - [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 - ) - self.tensor["edge_index"] = torch.cat( - [self.tensor["edge_index"], other.tensor["edge_index"]], - dim=0, - ) - self.tensor["batch_ptr"] = torch.cat( - [ - self.tensor["batch_ptr"], - other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1], - ], - dim=0, - ) - - # self.tensor["batch_shape"] = self.tensor["batch_shape"] + other.tensor["batch_shape"] - # If self.tensor is a placeholder and all batch_dims are 0, this check won't pass. - if not torch.all(self.tensor["batch_shape"] == 0): - assert torch.all( - self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:] + if torch.prod(self.tensor["batch_shape"]) == 0: + # if self is empty, just copy other + self.tensor["batch_shape"] = other.tensor["batch_shape"] + self.tensor["node_index"] = other.tensor["node_index"] + self.tensor["edge_feature"] = other.tensor["edge_feature"] + self.tensor["edge_index"] = other.tensor["edge_index"] + self.tensor["batch_ptr"] = other.tensor["batch_ptr"] + + elif len(self.tensor["batch_shape"]) == 1: + self.tensor["node_index"] = torch.cat( + [self.tensor["node_index"], other_node_index], dim=0 ) - # self.tensor["batch_shape"] = ( - # self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], - # ) + self.batch_shape[1:] - self.tensor["batch_shape"] = ( - self.tensor["batch_shape"] + other.tensor["batch_shape"] - ) + self.tensor["edge_feature"] = torch.cat( + [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 + ) + self.tensor["edge_index"] = torch.cat( + [self.tensor["edge_index"], other.tensor["edge_index"]], + dim=0, + ) + self.tensor["batch_ptr"] = torch.cat( + [ + self.tensor["batch_ptr"], + other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1], + ], + dim=0, + ) + self.tensor["batch_shape"] = ( + self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], + ) + self.batch_shape[1:] + else: + # Here we handle the case where the batch shape is (T, B) + # and we want to concatenate along the batch dimension B. 
+ assert len(self.tensor["batch_shape"]) == 2 + max_len = max(self.tensor["batch_shape"][0], other.tensor["batch_shape"][0]) + + node_features = [] + node_indices = [] + edge_features = [] + edge_indices = [] + batch_ptr = [torch.tensor([0], device=self.tensor.device)] + for i in range(max_len): + # Following the logic of Base class, we want to extend with sink states + if i >= self.tensor["batch_shape"][0]: + self_i = self.make_sink_states_tensor(self.tensor["batch_shape"][1:]) + else: + self_i = self[i].tensor + if i >= other.tensor["batch_shape"][0]: + other_i = other.make_sink_states_tensor(other.tensor["batch_shape"][1:]) + else: + other_i = other[i].tensor + + node_features.append(self_i["node_feature"]) + node_indices.append(self_i["node_index"]) + edge_features.append(self_i["edge_feature"]) + edge_indices.append(self_i["edge_index"]) + batch_ptr.append(self_i["batch_ptr"][1:] + batch_ptr[-1][-1]) + + node_features.append(other_i["node_feature"]) + node_indices.append(other_i["node_index"]) + edge_features.append(other_i["edge_feature"]) + edge_indices.append(other_i["edge_index"]) + batch_ptr.append(other_i["batch_ptr"][1:] + batch_ptr[-1][-1]) + + self.tensor["node_feature"] = torch.cat(node_features, dim=0) + self.tensor["node_index"] = torch.cat(node_indices, dim=0) + self.tensor["edge_feature"] = torch.cat(edge_features, dim=0) + self.tensor["edge_index"] = torch.cat(edge_indices, dim=0) + self.tensor["batch_ptr"] = torch.cat(batch_ptr, dim=0) + + self.tensor["batch_shape"] = ( + max_len, + self.tensor["batch_shape"][1] + other.tensor["batch_shape"][1], + ) + + assert torch.prod(self.tensor["batch_shape"]) == len(self.tensor["batch_ptr"]) - 1 @property def log_rewards(self) -> torch.Tensor | None: diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index f0dc9209..de6fe937 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -350,7 +350,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: exit_action = self.exit_mlp(edge_feature_means) edge_feature = edge_feature.reshape( - states_tensor["batch_shape"].item(), self.n_nodes, self.hidden_dim + *states_tensor["batch_shape"], self.n_nodes, self.hidden_dim ) # This is n_nodes ** 2, for each graph. 
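The (T, B) branch of `extend` above pads the shorter of the two batches with sink states up to a shared trajectory length before concatenating along the batch dimension. A minimal, shape-level sketch of that rule on plain tensors, where `pad_value` stands in for the sink state s_f and `extend_along_batch` is an illustrative name, not part of the library:

import torch

def extend_along_batch(a: torch.Tensor, b: torch.Tensor, pad_value: float) -> torch.Tensor:
    # a: (T_a, B_a, *state_shape), b: (T_b, B_b, *state_shape)
    max_len = max(a.shape[0], b.shape[0])

    def pad(x: torch.Tensor) -> torch.Tensor:
        if x.shape[0] == max_len:
            return x
        # Pad the time dimension with sink states, mirroring extend_with_sf.
        pad_block = torch.full(
            (max_len - x.shape[0], *x.shape[1:]), pad_value, dtype=x.dtype
        )
        return torch.cat([x, pad_block], dim=0)

    # Both operands now share T = max_len; concatenate along the B dimension.
    return torch.cat([pad(a), pad(b)], dim=1)

# Example: a (2, 1) batch and a (3, 1) batch combine into a (3, 2) batch.
a = torch.zeros(2, 1, 4)
b = torch.ones(3, 1, 4)
assert extend_along_batch(a, b, pad_value=-1.0).shape == (3, 2, 4)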
@@ -379,7 +379,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: batch_arange = torch.arange(batch_size) edge_actions = edge_index[batch_arange[:, None, None], i0, i1] edge_actions = edge_actions.reshape( - states_tensor["batch_shape"].item(), out_size + *states_tensor["batch_shape"], out_size ) if self.is_backward: From 3515c5406be27042ae7c396e86c8d289796cebce Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 21 Feb 2025 19:27:11 -0500 Subject: [PATCH 087/102] normal replay buffer can also be prioritized --- src/gfn/containers/__init__.py | 2 +- src/gfn/containers/replay_buffer.py | 87 +++++++++++++++-------------- 2 files changed, 46 insertions(+), 43 deletions(-) diff --git a/src/gfn/containers/__init__.py b/src/gfn/containers/__init__.py index a133fe4a..1b64a853 100644 --- a/src/gfn/containers/__init__.py +++ b/src/gfn/containers/__init__.py @@ -1,3 +1,3 @@ -from .replay_buffer import PrioritizedReplayBuffer, ReplayBuffer +from .replay_buffer import NormBasedDiversePrioritizedReplayBuffer, ReplayBuffer from .trajectories import Trajectories from .transitions import Transitions diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index bb12c0db..9df46014 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -22,6 +22,7 @@ class ReplayBuffer: training_objects: the buffer of objects used for training. terminating_states: a States class representation of $s_f$. objects_type: the type of buffer (transitions, trajectories, or states). + prioritized: whether the buffer is prioritized by log_reward or not. """ def __init__( @@ -29,6 +30,7 @@ def __init__( env: Env, objects_type: Literal["transitions", "trajectories", "states"], capacity: int = 1000, + prioritized: bool = False, ): """Instantiates a replay buffer. Args: @@ -53,6 +55,7 @@ def __init__( raise ValueError(f"Unknown objects_type: {objects_type}") self._is_full = False + self.prioritized = prioritized def __repr__(self): return f"ReplayBuffer(capacity={self.capacity}, containing {len(self)} {self.objects_type})" @@ -60,27 +63,53 @@ def __repr__(self): def __len__(self): return len(self.training_objects) - def add(self, training_objects: Transitions | Trajectories | tuple[States]): + def _add_objs( + self, + training_objects: Transitions | Trajectories | tuple[States], + ): """Adds a training object to the buffer.""" terminating_states = None if isinstance(training_objects, tuple): assert self.objects_type == "states" and self.terminating_states is not None training_objects, terminating_states = training_objects # pyright: ignore - + to_add = len(training_objects) - self._is_full |= len(self) + to_add >= self.capacity + # Adds the objects to the buffer. self.training_objects.extend(training_objects) # pyright: ignore + + # Sort elements by logreward, capping the size at the defined capacity. + if self.prioritized: + + if ( + self.training_objects.log_rewards is None + or training_objects.log_rewards is None # pyright: ignore + ): + raise ValueError("log_rewards must be defined for prioritized replay.") + + # Ascending sort. + ix = torch.argsort(self.training_objects.log_rewards) # pyright: ignore + self.training_objects = self.training_objects[ix] # pyright: ignore self.training_objects = self.training_objects[ - -self.capacity : + -self.capacity : # Ascending sort, so we retain the final elements. ] # pyright: ignore + # Add the terminating states to the buffer. 
if self.terminating_states is not None: assert terminating_states is not None - self.terminating_states.extend(terminating_states) # pyright: ignore + self.terminating_states.extend(terminating_states) + + # Sort terminating states by logreward as well. + if self.prioritized: + self.terminating_states = self.terminating_states[ix] + self.terminating_states = self.terminating_states[-self.capacity :] + def add(self, training_objects: Transitions | Trajectories | tuple[States]): + """Adds a training object to the buffer.""" + self._add_objs(training_objects) + def sample(self, n_trajectories: int) -> Transitions | Trajectories | tuple[States]: """Samples `n_trajectories` training objects from the buffer.""" if self.terminating_states is not None: @@ -113,8 +142,8 @@ def load(self, directory: str): ) -class PrioritizedReplayBuffer(ReplayBuffer): - """A replay buffer of trajectories or transitions. +class NormBasedDiversePrioritizedReplayBuffer(ReplayBuffer): + """A replay buffer of trajectories or transitions with diverse trajectories. Attributes: env: the Environment instance. @@ -152,53 +181,27 @@ def __init__( super().__init__(env, objects_type, capacity) self.cutoff_distance = cutoff_distance self.p_norm_distance = p_norm_distance + self._prioritized = True - def _add_objs( - self, - training_objects: Transitions | Trajectories | tuple[States], - terminating_states: States | None = None, - ): - """Adds a training object to the buffer.""" - # Adds the objects to the buffer. - self.training_objects.extend(training_objects) # pyright: ignore - - # Sort elements by logreward, capping the size at the defined capacity. - ix = torch.argsort(self.training_objects.log_rewards) # pyright: ignore - self.training_objects = self.training_objects[ix] # pyright: ignore - self.training_objects = self.training_objects[ - -self.capacity : - ] # pyright: ignore - - # Add the terminating states to the buffer. - if self.terminating_states is not None: - assert terminating_states is not None - self.terminating_states.extend(terminating_states) - - # Sort terminating states by logreward as well. - self.terminating_states = self.terminating_states[ix] - self.terminating_states = self.terminating_states[-self.capacity :] + @property + def prioritized(self) -> bool: + return self._prioritized def add(self, training_objects: Transitions | Trajectories | tuple[States]): """Adds a training object to the buffer.""" - terminating_states = None - if isinstance(training_objects, tuple): - assert self.objects_type == "states" and self.terminating_states is not None - training_objects, terminating_states = training_objects # pyright: ignore - to_add = len(training_objects) self._is_full |= len(self) + to_add >= self.capacity # The buffer isn't full yet. if len(self.training_objects) < self.capacity: - self._add_objs(training_objects, terminating_states) + self._add_objs(training_objects) # Our buffer is full and we will prioritize diverse, high reward additions. else: - if ( - self.training_objects.log_rewards is None - or training_objects.log_rewards is None # pyright: ignore - ): - raise ValueError("log_rewards must be defined for prioritized replay.") + terminating_states = None + if isinstance(training_objects, tuple): + assert self.objects_type == "states" and self.terminating_states is not None + training_objects, terminating_states = training_objects # pyright: ignore # Sort the incoming elements by their logrewards. 
ix = torch.argsort( From b40118f0a2619cde1c96457ec7ce51fd65b0e628 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 21 Feb 2025 20:21:57 -0500 Subject: [PATCH 088/102] directed GCN working on small graphs --- tutorials/examples/train_graph_ring.py | 257 +++++++++---------------- 1 file changed, 96 insertions(+), 161 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index de6fe937..dd0913a9 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -19,7 +19,7 @@ import torch from tensordict import TensorDict from torch import nn -from torch_geometric.nn import GINConv, GCNConv +from torch_geometric.nn import GINConv, GCNConv, DirGNNConv from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.trajectory_balance import TBGFlowNet @@ -31,6 +31,9 @@ from gfn.containers import ReplayBuffer +REW_VAL = 100.0 +EPS_VAL = 1e-6 + def directed_reward(states: GraphStates) -> torch.Tensor: """Compute the reward of a graph. @@ -42,11 +45,10 @@ def directed_reward(states: GraphStates) -> torch.Tensor: Returns: A tensor of rewards. """ - eps = 1e-6 if states.tensor["edge_index"].shape[0] == 0: - return torch.full(states.batch_shape, eps) + return torch.full(states.batch_shape, EPS_VAL) - out = torch.full((len(states),), eps) # Default reward. + out = torch.full((len(states),), EPS_VAL) # Default reward. for i in range(len(states)): start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] @@ -65,8 +67,7 @@ def directed_reward(states: GraphStates) -> torch.Tensor: if not torch.all(adj_matrix.sum(dim=1) == 1): continue - visited = [] - current = 0 + visited, current = [], 0 while current not in visited: visited.append(current) @@ -78,7 +79,7 @@ def set_diff(tensor1, tensor2): neighbors = torch.where(adj_matrix[current] == 1)[0] valid_neighbours = set_diff(neighbors, torch.tensor(visited)) - # Visit the fir + # Visit the first valid neighbor. if len(valid_neighbours) == 1: current = valid_neighbours[0] elif len(valid_neighbours) == 0: @@ -88,7 +89,7 @@ def set_diff(tensor1, tensor2): # Check if we visited all vertices and the last vertex connects back to start. if len(visited) == n_nodes and adj_matrix[current][0] == 1: - out[i] = 1.0 + out[i] = REW_VAL return out.view(*states.batch_shape) @@ -104,11 +105,10 @@ def undirected_reward(states: GraphStates) -> torch.Tensor: Returns: A tensor of rewards. """ - eps = 1e-4 if states.tensor["edge_index"].shape[0] == 0: - return torch.full(states.batch_shape, eps) + return torch.full(states.batch_shape, EPS_VAL) - out = torch.full((len(states),), eps) # Default reward. + out = torch.full((len(states),), EPS_VAL) # Default reward. for i in range(len(states)): start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] @@ -158,7 +158,7 @@ def undirected_reward(states: GraphStates) -> torch.Tensor: prev, current = current, next_node if current == start_vertex and len(visited) == n_nodes: - out[i] = 100.0 + out[i] = REW_VAL return out.view(*states.batch_shape) @@ -175,10 +175,11 @@ def __init__( n_nodes: int, directed: bool, num_conv_layers: int = 1, + embedding_dim: int = 128, is_backward: bool = False, ): super().__init__() - self.hidden_dim = self.embedding_dim = 64 + self.hidden_dim = self.embedding_dim = embedding_dim self.is_directed = directed self.is_backward = is_backward self.n_nodes = n_nodes @@ -186,7 +187,6 @@ def __init__( # Node embedding layer. 
self.embedding = nn.Embedding(n_nodes, self.embedding_dim) - # self.action_conv_blks = nn.ModuleList() self.conv_blks = nn.ModuleList() self.exit_mlp = MLP( input_dim=self.hidden_dim, @@ -197,69 +197,27 @@ def __init__( ) if directed: - # Multiple action type convolution layers. - # for i in range(num_conv_layers): - # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) - # self.action_conv_blks.extend( - # [ - # GCNConv( - # self.embedding_dim if i == 0 else self.hidden_dim, - # self.hidden_dim, - # ), - # nn.Linear(self.hidden_dim, self.hidden_dim), - # nn.ReLU(), - # nn.Linear(self.hidden_dim, self.hidden_dim), - # ] - # ) - - # Multiple edge index convolution layers. for i in range(num_conv_layers): - # mlp = create_mlp(self.embedding_dim, self.embedding_dim, self.embedding_dim) self.conv_blks.extend( [ - GCNConv( - self.embedding_dim if i == 0 else self.hidden_dim, - self.hidden_dim, + DirGNNConv( + GCNConv( + self.embedding_dim if i == 0 else self.hidden_dim, + self.hidden_dim, + ), + alpha=0.5, + root_weight=True, ), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim), - ] - ) + # Process in/out components separately + nn.ModuleList([ + nn.Sequential( + nn.Linear(self.hidden_dim//2, self.hidden_dim//2), + nn.ReLU(), + nn.Linear(self.hidden_dim//2, self.hidden_dim//2) + ) for _ in range(2) # One for in-features, one for out-features. + ]) + ]) else: # Undirected case. - # Multiple action type convolution layers. - # for _ in range(num_conv_layers): - # self.action_conv_blks.extend( - # [ - # GINConv( - # create_mlp( - # self.embedding_dim, self.hidden_dim, self.hidden_dim - # ) - # ), - # nn.Linear(self.hidden_dim, self.hidden_dim), - # nn.ReLU(), - # nn.Linear(self.hidden_dim, self.hidden_dim), - # ] - # ) - # for _ in range(num_conv_layers): - # self.conv_blks.extend( - # [ - # GINConv( - # MLP( - # input_dim=self.embedding_dim, - # output_dim=self.hidden_dim, - # hidden_dim=self.hidden_dim, - # n_hidden_layers=1, - # add_layer_norm=True, - # ), - # ), - # nn.Linear(self.hidden_dim, self.hidden_dim), - # nn.ReLU(), - # nn.Linear(self.hidden_dim, self.hidden_dim), - # ] - # ) - - # Multiple edge index convolution layers. for _ in range(num_conv_layers): self.conv_blks.extend( [ @@ -272,15 +230,15 @@ def __init__( add_layer_norm=True, ), ), - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.ReLU(), - nn.Linear(self.hidden_dim, self.hidden_dim), + nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ), ] ) - # Layer normalization for stability - # self.action_norm = nn.LayerNorm(self.hidden_dim) - self.edge_norm = nn.LayerNorm(self.hidden_dim) + self.norm = nn.LayerNorm(self.hidden_dim) def _group_mean( self, tensor: torch.Tensor, batch_ptr: torch.Tensor @@ -308,79 +266,57 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: )[2].reshape( states_tensor["edge_index"].shape ) # (M, 2) - # edge_attrs = states_tensor["edge_feature"] # Multiple action type convolutions with residual connections. - # for i in range(0, len(self.action_conv_blks), 4): - - # # GIN/GCN conv. - # action_type_new = self.action_conv_blks[i](action_type, edge_index.T) - # # First linear. - # action_type_new = self.action_conv_blks[i + 1](action_type_new) - # # ReLU. - # action_type_new = self.action_conv_blks[i + 2](action_type_new) - # # Second linear. 
- # action_type_new = self.action_conv_blks[i + 3](action_type_new) - # # Residual connection with original input. - # action_type = action_type_new + action_type - # action_type = self.action_norm(action_type) + x = self.embedding(node_features.squeeze().int()) + for i in range(0, len(self.conv_blks), 2): + x_new = self.conv_blks[i](x, edge_index.T) # GIN/GCN conv. + if self.is_directed: + x_in, x_out = torch.chunk(x_new, 2, dim=-1) + # Process each component separately through its own MLP + x_in = self.conv_blks[i + 1][0](x_in) # First MLP in ModuleList. + x_out = self.conv_blks[i + 1][1](x_out) # Second MLP in ModuleList. + x_new = torch.cat([x_in, x_out], dim=-1) + else: + x_new = self.conv_blks[i + 1](x_new) # Linear -> ReLU -> Linear. - # Multiple action type convolutions with residual connections. - node_features = self.embedding(node_features.squeeze().int()) - for i in range(0, len(self.conv_blks), 4): - - # GIN/GCN conv. - node_feature_new = self.conv_blks[i](node_features, edge_index.T) - # First linear. - node_feature_new = self.conv_blks[i + 1](node_feature_new) - # ReLU. - node_feature_new = self.conv_blks[i + 2](node_feature_new) - # Second linear. - node_feature_new = self.conv_blks[i + 3](node_feature_new) - # Residual connection with original input. - edge_feature = node_feature_new + node_features - edge_feature = self.edge_norm(edge_feature) - - # edge_feature = self._group_mean( - # torch.mean(edge_feature, dim=-1, keepdim=True), batch_ptr - # ) + x = x_new + x if i > 0 else x_new # Residual connection. + x = self.norm(x) # Layernorm. # This MLP computes the exit action. - edge_feature_means = self._group_mean(edge_feature, batch_ptr) - exit_action = self.exit_mlp(edge_feature_means) + node_feature_means = self._group_mean(x, batch_ptr) + exit_action = self.exit_mlp(node_feature_means) - edge_feature = edge_feature.reshape( - *states_tensor["batch_shape"], self.n_nodes, self.hidden_dim - ) - - # This is n_nodes ** 2, for each graph. - edge_index = torch.einsum("bnf,bmf->bnm", edge_feature, edge_feature) - edge_index = edge_index / torch.sqrt(torch.tensor(self.hidden_dim)) + x = x.reshape(*states_tensor["batch_shape"], self.n_nodes, self.hidden_dim) # Undirected. if self.is_directed: - i_up, j_up = torch.triu_indices( - self.n_nodes, self.n_nodes, offset=1 - ) # Upper triangle. - i_lo, j_lo = torch.tril_indices( - self.n_nodes, self.n_nodes, offset=-1 - ) # Lower triangle. + feature_dim = self.hidden_dim // 2 + source_features = x[..., :feature_dim] + target_features = x[..., feature_dim:] - # Combine them + # Dot product between source and target features (asymmetric). + edgewise_dot_prod = torch.einsum("bnf,bmf->bnm", source_features, target_features) + edgewise_dot_prod = edgewise_dot_prod / torch.sqrt(torch.tensor(feature_dim)) + + i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) + i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) + + # Combine them. i0 = torch.cat([i_up, i_lo]) i1 = torch.cat([j_up, j_lo]) out_size = self.n_nodes**2 - self.n_nodes else: + # Dot product between all node features (symmetric). + edgewise_dot_prod = torch.einsum("bnf,bmf->bnm", x, x) + edgewise_dot_prod = edgewise_dot_prod / torch.sqrt(torch.tensor(self.hidden_dim)) i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) out_size = (self.n_nodes**2 - self.n_nodes) // 2 # Grab the needed elems from the adjacency matrix and reshape. 
- batch_arange = torch.arange(batch_size) - edge_actions = edge_index[batch_arange[:, None, None], i0, i1] - edge_actions = edge_actions.reshape( - *states_tensor["batch_shape"], out_size - ) + edge_actions = x[torch.arange(batch_size)[:, None, None], i0, i1] + edge_actions = edge_actions.reshape(*states_tensor["batch_shape"], out_size) if self.is_backward: return edge_actions @@ -471,14 +407,6 @@ def forward_masks(self): else: ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) - # # Adds -1 "edge" representing exit, -2 "edge" representing dummy. - # ei0 = torch.cat((ei0, torch.IntTensor([-1, -2])), dim=0) - # ei1 = torch.cat((ei1, torch.IntTensor([-1, -2])), dim=0) - - # Indexes either the second last element (exit) or la - # action_tensor[action_tensor >= (self.n_actions - 1)] = 0 - # ei0, ei1 = ei0[action_tensor], ei1[action_tensor] - # Remove existing edges. for i in range(len(self)): existing_edges = ( @@ -598,9 +526,6 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions ei0 = torch.cat([i_up, i_lo]) ei1 = torch.cat([j_up, j_lo]) - # Potentially problematic (returns [0,0,0,1,1] instead of above which returns [1,1,1,0,0]). - # ei0 = (action_tensor) // (self.n_nodes) - # ei1 = (action_tensor) % (self.n_nodes) else: ei0, ei1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) @@ -711,11 +636,12 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool if __name__ == "__main__": - N_NODES = 6 - N_ITERATIONS = 500 - LR = 0.05 + N_NODES = 4 + N_ITERATIONS = 1000 + LR = 0.001 BATCH_SIZE = 128 - DIRECTED = False + DIRECTED = True + USE_BUFFER = False state_evaluator = undirected_reward if not DIRECTED else directed_reward torch.random.manual_seed(7) @@ -739,7 +665,8 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool replay_buffer = ReplayBuffer( env, objects_type="trajectories", - capacity=1000, + capacity=BATCH_SIZE, + prioritized=True, ) losses = [] @@ -747,25 +674,31 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool t1 = time.time() for iteration in range(N_ITERATIONS): trajectories = gflownet.sample_trajectories( - env, n=BATCH_SIZE # pyright: ignore + env, n=BATCH_SIZE, save_logprobs=True, # pyright: ignore ) - last_states = trajectories.last_states + training_samples = gflownet.to_training_samples(trajectories) + + # Collect rewards for reporting. 
+ if isinstance(training_samples, tuple): + last_states = training_samples[1] + else: + last_states = training_samples.last_states assert isinstance(last_states, GraphStates) rewards = state_evaluator(last_states) - training_samples = gflownet.to_training_samples(trajectories) - with torch.no_grad(): - replay_buffer.add(training_samples) - training_objects = replay_buffer.sample(n_trajectories=BATCH_SIZE) + if USE_BUFFER: + with torch.no_grad(): + replay_buffer.add(training_samples) + if iteration > 20: + training_samples = training_samples[:BATCH_SIZE // 2] + buffer_samples = replay_buffer.sample(n_trajectories=BATCH_SIZE // 2) + training_samples.extend(buffer_samples) optimizer.zero_grad() - loss = gflownet.loss(env, training_objects) # pyright: ignore - print( - "Iteration", - iteration, - "Loss:", - loss.item(), - f"rings: {torch.mean(rewards > 0.1, dtype=torch.float) * 100:.0f}%", + loss = gflownet.loss(env, training_samples) # pyright: ignore + pct_rings = torch.mean(rewards > 0.1, dtype=torch.float) * 100 + print("Iteration {} - Loss: {:.02f}, rings: {:.0f}%".format( + iteration, loss.item(), pct_rings) ) loss.backward() optimizer.step() @@ -773,6 +706,8 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool t2 = time.time() print("Time:", t2 - t1) + + # This comes from the gflownet, not the buffer. last_states = trajectories.last_states[:8] assert isinstance(last_states, GraphStates) render_states(last_states, state_evaluator, DIRECTED) From de092a0ace65ba47e6cf0988eeed924627f42ecc Mon Sep 17 00:00:00 2001 From: "Omar G. Younis" Date: Mon, 24 Feb 2025 15:29:25 +0100 Subject: [PATCH 089/102] fix extend --- src/gfn/states.py | 81 +++++++------ testing/test_states.py | 250 ++++++++++++++++++++++++++++++++--------- 2 files changed, 246 insertions(+), 85 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index c943b0d3..53e5c984 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -847,41 +847,33 @@ def clone(self) -> GraphStates: def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" - self.tensor["node_feature"] = torch.cat( - [self.tensor["node_feature"], other.tensor["node_feature"]], dim=0 - ) - # find if there are common node indices - other_node_index = other.tensor["node_index"] - other_edge_index = other.tensor["edge_index"] - common_node_indices = torch.any( - self.tensor["node_index"][:, None] == other_node_index[None, :], dim=0 - ) - if torch.any(common_node_indices): - # This renumbers nodes across batch indices such that all nodes have - # a unique ID. - new_indices = GraphStates.unique_node_indices( - int(torch.sum(common_node_indices).item()) - ) + other_node_index = other.tensor["node_index"].clone() # Clone to avoid modifying original + other_edge_index = other.tensor["edge_index"].clone() # Clone to avoid modifying original + + # Always generate new indices for the other state to ensure uniqueness + new_indices = GraphStates.unique_node_indices(len(other_node_index)) + + # Update edge indices to match new node indices + for i, old_idx in enumerate(other_node_index): + other_edge_index[other_edge_index == old_idx] = new_indices[i] + + # Update node indices + other_node_index = new_indices - # find edge_index which contains other_node_index[common_node_indices]. this is - # because all new edges must be to new nodes (unique). 
- edge_mask = ( - other_edge_index[:, :, None] - == other_node_index[None, common_node_indices] - ) - repeat_indices = new_indices[None, None].repeat(edge_mask.shape[0], 2, 1) - other_edge_index[torch.any(edge_mask, dim=-1)] = repeat_indices[edge_mask] - other_node_index[common_node_indices] = new_indices if torch.prod(self.tensor["batch_shape"]) == 0: # if self is empty, just copy other + self.tensor["node_feature"] = other.tensor["node_feature"] self.tensor["batch_shape"] = other.tensor["batch_shape"] - self.tensor["node_index"] = other.tensor["node_index"] + self.tensor["node_index"] = other_node_index self.tensor["edge_feature"] = other.tensor["edge_feature"] - self.tensor["edge_index"] = other.tensor["edge_index"] + self.tensor["edge_index"] = other_edge_index self.tensor["batch_ptr"] = other.tensor["batch_ptr"] elif len(self.tensor["batch_shape"]) == 1: + self.tensor["node_feature"] = torch.cat( + [self.tensor["node_feature"], other.tensor["node_feature"]], dim=0 + ) self.tensor["node_index"] = torch.cat( [self.tensor["node_index"], other_node_index], dim=0 ) @@ -889,7 +881,7 @@ def extend(self, other: GraphStates): [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 ) self.tensor["edge_index"] = torch.cat( - [self.tensor["edge_index"], other.tensor["edge_index"]], + [self.tensor["edge_index"], other_edge_index], dim=0, ) self.tensor["batch_ptr"] = torch.cat( @@ -912,7 +904,10 @@ def extend(self, other: GraphStates): node_indices = [] edge_features = [] edge_indices = [] - batch_ptr = [torch.tensor([0], device=self.tensor.device)] + # Get device from one of the tensors + device = self.tensor["node_feature"].device + batch_ptr = [torch.tensor([0], device=device)] + for i in range(max_len): # Following the logic of Base class, we want to extend with sink states if i >= self.tensor["batch_shape"][0]: @@ -923,17 +918,33 @@ def extend(self, other: GraphStates): other_i = other.make_sink_states_tensor(other.tensor["batch_shape"][1:]) else: other_i = other[i].tensor + + # Generate new unique indices for both self_i and other_i + new_self_indices = GraphStates.unique_node_indices(len(self_i["node_index"])) + new_other_indices = GraphStates.unique_node_indices(len(other_i["node_index"])) + + # Update self_i edge indices + self_edge_index = self_i["edge_index"].clone() + for old_idx, new_idx in zip(self_i["node_index"], new_self_indices): + mask = (self_edge_index == old_idx) + self_edge_index[mask] = new_idx + + # Update other_i edge indices + other_edge_index = other_i["edge_index"].clone() + for old_idx, new_idx in zip(other_i["node_index"], new_other_indices): + mask = (other_edge_index == old_idx) + other_edge_index[mask] = new_idx node_features.append(self_i["node_feature"]) - node_indices.append(self_i["node_index"]) + node_indices.append(new_self_indices) # Use new indices edge_features.append(self_i["edge_feature"]) - edge_indices.append(self_i["edge_index"]) + edge_indices.append(self_edge_index) # Use updated edge indices batch_ptr.append(self_i["batch_ptr"][1:] + batch_ptr[-1][-1]) node_features.append(other_i["node_feature"]) - node_indices.append(other_i["node_index"]) + node_indices.append(new_other_indices) # Use new indices edge_features.append(other_i["edge_feature"]) - edge_indices.append(other_i["edge_index"]) + edge_indices.append(other_edge_index) # Use updated edge indices batch_ptr.append(other_i["batch_ptr"][1:] + batch_ptr[-1][-1]) self.tensor["node_feature"] = torch.cat(node_features, dim=0) @@ -947,7 +958,10 @@ def extend(self, other: GraphStates): 
self.tensor["batch_shape"][1] + other.tensor["batch_shape"][1], ) - assert torch.prod(self.tensor["batch_shape"]) == len(self.tensor["batch_ptr"]) - 1 + assert self.tensor["node_index"].unique().numel() == len( + self.tensor["node_index"] + ) + assert torch.prod(torch.tensor(self.tensor["batch_shape"])) == len(self.tensor["batch_ptr"]) - 1 @property def log_rewards(self) -> torch.Tensor | None: @@ -995,6 +1009,7 @@ def stack(cls, states: List[GraphStates]): """Given a list of states, stacks them along a new dimension (0).""" stacked_states = cls.from_batch_shape(0) state_batch_shape = states[0].batch_shape + assert len(state_batch_shape) == 1 for state in states: assert state.batch_shape == state_batch_shape stacked_states.extend(state) diff --git a/testing/test_states.py b/testing/test_states.py index 2961ab97..1182e6cd 100644 --- a/testing/test_states.py +++ b/testing/test_states.py @@ -1,56 +1,202 @@ +import pytest +import torch from gfn.states import GraphStates from tensordict import TensorDict -import torch -def make_graph_states(n_graphs, n_nodes, n_edges): - batch_ptr = torch.cat([torch.zeros(1), n_nodes.cumsum(0)]).int() - node_feature = torch.randn(batch_ptr[-1].item(), 10) - node_index = torch.arange(0, batch_ptr[-1].item()) - - edge_features = torch.randn(n_edges.sum(), 10) - edge_index = [] - for i, (start, end) in enumerate(zip(batch_ptr[:-1], batch_ptr[1:])): - edge_index.append(torch.randint(start, end, (n_edges[i], 2))) - edge_index = torch.cat(edge_index) - - return GraphStates( - TensorDict( - { - "node_feature": node_feature, - "edge_feature": edge_features, - "edge_index": edge_index, - "node_index": node_index, - "batch_ptr": batch_ptr, - "batch_shape": torch.tensor([n_graphs]), - } - ) - ) - - -def test_get_set(): - n_graphs = 10 - n_nodes = torch.randint(1, 10, (n_graphs,)) - n_edges = torch.randint(1, 10, (n_graphs,)) - graphs = make_graph_states(10, n_nodes, n_edges) - assert not graphs[0]._compare(graphs[9].tensor) - last_graph = graphs[9] - graphs = graphs[:-1] - graphs[0] = last_graph - assert graphs[0]._compare(last_graph.tensor) - - -def test_stack(): - GraphStates.s0 = make_graph_states(1, torch.tensor([1]), torch.tensor([0])).tensor - n_graphs = 10 - n_nodes = torch.randint(1, 10, (n_graphs,)) - n_edges = torch.randint(1, 10, (n_graphs,)) - graphs = make_graph_states(10, n_nodes, n_edges) - stacked_graphs = GraphStates.stack([graphs[0], graphs[1]]) - assert stacked_graphs.batch_shape == (2, 1) - assert stacked_graphs[0]._compare(graphs[0].tensor) - assert stacked_graphs[1]._compare(graphs[1].tensor) - - -if __name__ == "__main__": - test_get_set() +class MyGraphStates(GraphStates): + s0 = TensorDict({ + "node_feature": torch.tensor([[1.0], [2.0]]), + "node_index": torch.tensor([0, 1]), + "edge_feature": torch.tensor([[0.5]]), + "edge_index": torch.tensor([[0, 1]]), + }) + sf = TensorDict({ + "node_feature": torch.tensor([[3.0], [4.0]]), + "node_index": torch.tensor([2, 3]), + "edge_feature": torch.tensor([[0.7]]), + "edge_index": torch.tensor([[2, 3]]), + }) + +@pytest.fixture +def simple_graph_state(): + """Creates a simple graph state with 2 nodes and 1 edge""" + tensor = TensorDict({ + "node_feature": torch.tensor([[1.0], [2.0]]), + "node_index": torch.tensor([0, 1]), + "edge_feature": torch.tensor([[0.5]]), + "edge_index": torch.tensor([[0, 1]]), + "batch_ptr": torch.tensor([0, 2]), + "batch_shape": torch.tensor([1]) + }) + return MyGraphStates(tensor) + +@pytest.fixture +def empty_graph_state(): + """Creates an empty graph state""" + tensor = 
TensorDict({
        "node_feature": torch.tensor([]),
        "node_index": torch.tensor([]),
        "edge_feature": torch.tensor([]),
        "edge_index": torch.tensor([]).reshape(0, 2),
        "batch_ptr": torch.tensor([0]),
        "batch_shape": torch.tensor([0])
    })
    return MyGraphStates(tensor)

def test_extend_empty_state(empty_graph_state, simple_graph_state):
    """Test extending an empty state with a non-empty state"""
    empty_graph_state.extend(simple_graph_state)

    assert torch.equal(empty_graph_state.tensor["batch_shape"], simple_graph_state.tensor["batch_shape"])
    assert torch.equal(empty_graph_state.tensor["node_feature"], simple_graph_state.tensor["node_feature"])
    assert torch.equal(empty_graph_state.tensor["edge_index"], simple_graph_state.tensor["edge_index"])
    assert torch.equal(empty_graph_state.tensor["edge_feature"], simple_graph_state.tensor["edge_feature"])
    assert torch.equal(empty_graph_state.tensor["batch_ptr"], simple_graph_state.tensor["batch_ptr"])

def test_extend_1d_batch(simple_graph_state):
    """Test extending two 1D batch states"""
    other_state = simple_graph_state.clone()

    # The node indices should be different after extend
    original_node_indices = simple_graph_state.tensor["node_index"].clone()

    simple_graph_state.extend(other_state)

    assert simple_graph_state.tensor["batch_shape"][0] == 2
    assert len(simple_graph_state.tensor["node_feature"]) == 4
    assert len(simple_graph_state.tensor["edge_feature"]) == 2

    # Check that node indices were properly updated (should be unique)
    new_node_indices = simple_graph_state.tensor["node_index"]
    assert len(torch.unique(new_node_indices)) == len(new_node_indices)
    assert not torch.equal(new_node_indices[:2], new_node_indices[2:])

def test_extend_2d_batch():
    """Test extending two 2D batch states"""
    # Create 2D batch states (T=2, B=1)
    tensor1 = TensorDict({
        "node_feature": torch.tensor([[1.0], [2.0], [3.0], [4.0]]),
        "node_index": torch.tensor([0, 1, 2, 3]),
        "edge_feature": torch.tensor([[0.5], [0.6]]),
        "edge_index": torch.tensor([[0, 1], [2, 3]]),
        "batch_ptr": torch.tensor([0, 2, 4]),
        "batch_shape": torch.tensor([2, 1])
    })
    state1 = MyGraphStates(tensor1)

    # Create another state with different time length (T=3, B=1)
    tensor2 = TensorDict({
        "node_feature": torch.tensor([[5.0], [6.0], [7.0], [8.0], [9.0], [10.0]]),
        "node_index": torch.tensor([4, 5, 6, 7, 8, 9]),
        "edge_feature": torch.tensor([[0.7], [0.8], [0.9]]),
        "edge_index": torch.tensor([[4, 5], [6, 7], [8, 9]]),
        "batch_ptr": torch.tensor([0, 2, 4, 6]),
        "batch_shape": torch.tensor([3, 1])
    })
    state2 = MyGraphStates(tensor2)

    state1.extend(state2)

    # Check final shape should be (max_len=3, B=2)
    assert torch.equal(state1.tensor["batch_shape"], torch.tensor([3, 2]))

    # Check that we have the correct number of nodes and edges
    # Each time step has 2 nodes and 1 edge, so for 3 time steps and 2 batches:
    expected_nodes = 3 * 2 * 2  # T * nodes_per_timestep * B
    expected_edges = 3 * 1 * 2  # T * edges_per_timestep * B
    assert len(state1.tensor["node_feature"]) == expected_nodes
    assert len(state1.tensor["edge_feature"]) == expected_edges

def test_extend_with_common_indices(simple_graph_state):
    """Test extending states with common node indices"""
    # Create a state with overlapping node indices
    tensor = TensorDict({
        "node_feature": torch.tensor([[3.0], [4.0]]),
        "node_index": torch.tensor([1, 2]),  # Note: index 1 overlaps
        "edge_feature": torch.tensor([[0.7]]),
        "edge_index": torch.tensor([[1, 2]]),
"batch_ptr": torch.tensor([0, 2]), + "batch_shape": torch.tensor([1]) + }) + other_state = MyGraphStates(tensor) + + simple_graph_state.extend(other_state) + + # Check that node indices are unique after extend + assert len(torch.unique(simple_graph_state.tensor["node_index"])) == 4 + + # Check that edge indices were properly updated + edge_indices = simple_graph_state.tensor["edge_index"] + assert torch.all(edge_indices >= 0) + +def test_stack_1d_batch(): + """Test stacking multiple 1D batch states""" + # Create first state + tensor1 = TensorDict({ + "node_feature": torch.tensor([[1.0], [2.0]]), + "node_index": torch.tensor([0, 1]), + "edge_feature": torch.tensor([[0.5]]), + "edge_index": torch.tensor([[0, 1]]), + "batch_ptr": torch.tensor([0, 2]), + "batch_shape": torch.tensor([1]) + }) + state1 = MyGraphStates(tensor1) + + # Create second state with different values + tensor2 = TensorDict({ + "node_feature": torch.tensor([[3.0], [4.0]]), + "node_index": torch.tensor([0, 5]), + "edge_feature": torch.tensor([[0.7]]), + "edge_index": torch.tensor([[2, 3]]), + "batch_ptr": torch.tensor([0, 2]), + "batch_shape": torch.tensor([1]) + }) + state2 = MyGraphStates(tensor2) + + # Stack the states + stacked = MyGraphStates.stack([state1, state2]) + + # Check the batch shape is correct (2, 1) + assert torch.equal(stacked.tensor["batch_shape"], torch.tensor([2, 1])) + + # Check that node features are preserved and ordered correctly + assert torch.equal(stacked.tensor["node_feature"], + torch.tensor([[1.0], [2.0], [3.0], [4.0]])) + + # Check that edge features are preserved and ordered correctly + assert torch.equal(stacked.tensor["edge_feature"], + torch.tensor([[0.5], [0.7]])) + + # Check that node indices are unique + assert len(torch.unique(stacked.tensor["node_index"])) == 4 + + # Check batch pointers are correct + assert torch.equal(stacked.tensor["batch_ptr"], torch.tensor([0, 2, 4])) + +def test_stack_empty_states(): + """Test stacking empty states""" + # Create empty state + tensor = TensorDict({ + "node_feature": torch.tensor([]), + "node_index": torch.tensor([]), + "edge_feature": torch.tensor([]), + "edge_index": torch.tensor([]).reshape(0, 2), + "batch_ptr": torch.tensor([0]), + "batch_shape": torch.tensor([0]) + }) + empty_state = MyGraphStates(tensor) + + # Stack multiple empty states + stacked = MyGraphStates.stack([empty_state, empty_state]) + + # Check the batch shape is correct (2, 0) + assert torch.equal(stacked.tensor["batch_shape"], torch.tensor([2, 0])) + + # Check that tensors are empty + assert stacked.tensor["node_feature"].numel() == 0 + assert stacked.tensor["edge_feature"].numel() == 0 + assert stacked.tensor["edge_index"].numel() == 0 + + # Check batch pointers are correct + assert torch.equal(stacked.tensor["batch_ptr"], torch.tensor([0])) From 39f91d3776327f07e834b64ba4511cd38f363354 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Tue, 25 Feb 2025 11:11:20 +0400 Subject: [PATCH 090/102] add adjacency matrix module --- tutorials/examples/train_graph_ring.py | 180 ++++++++++++++++++++++--- 1 file changed, 160 insertions(+), 20 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index dd0913a9..259cd8b4 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -34,6 +34,7 @@ REW_VAL = 100.0 EPS_VAL = 1e-6 + def directed_reward(states: GraphStates) -> torch.Tensor: """Compute the reward of a graph. 
@@ -202,21 +203,31 @@ def __init__( [ DirGNNConv( GCNConv( - self.embedding_dim if i == 0 else self.hidden_dim, + self.embedding_dim if i == 0 else self.hidden_dim, self.hidden_dim, ), alpha=0.5, root_weight=True, ), # Process in/out components separately - nn.ModuleList([ - nn.Sequential( - nn.Linear(self.hidden_dim//2, self.hidden_dim//2), - nn.ReLU(), - nn.Linear(self.hidden_dim//2, self.hidden_dim//2) - ) for _ in range(2) # One for in-features, one for out-features. - ]) - ]) + nn.ModuleList( + [ + nn.Sequential( + nn.Linear( + self.hidden_dim // 2, self.hidden_dim // 2 + ), + nn.ReLU(), + nn.Linear( + self.hidden_dim // 2, self.hidden_dim // 2 + ), + ) + for _ in range( + 2 + ) # One for in-features, one for out-features. + ] + ), + ] + ) else: # Undirected case. for _ in range(num_conv_layers): self.conv_blks.extend( @@ -296,8 +307,12 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: target_features = x[..., feature_dim:] # Dot product between source and target features (asymmetric). - edgewise_dot_prod = torch.einsum("bnf,bmf->bnm", source_features, target_features) - edgewise_dot_prod = edgewise_dot_prod / torch.sqrt(torch.tensor(feature_dim)) + edgewise_dot_prod = torch.einsum( + "bnf,bmf->bnm", source_features, target_features + ) + edgewise_dot_prod = edgewise_dot_prod / torch.sqrt( + torch.tensor(feature_dim) + ) i_up, j_up = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) i_lo, j_lo = torch.tril_indices(self.n_nodes, self.n_nodes, offset=-1) @@ -310,7 +325,9 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: else: # Dot product between all node features (symmetric). edgewise_dot_prod = torch.einsum("bnf,bmf->bnm", x, x) - edgewise_dot_prod = edgewise_dot_prod / torch.sqrt(torch.tensor(self.hidden_dim)) + edgewise_dot_prod = edgewise_dot_prod / torch.sqrt( + torch.tensor(self.hidden_dim) + ) i0, i1 = torch.triu_indices(self.n_nodes, self.n_nodes, offset=1) out_size = (self.n_nodes**2 - self.n_nodes) // 2 @@ -635,6 +652,115 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool plt.show() +class AdjacencyPolicyModule(nn.Module): + """Simple MLP that processes flattened adjacency matrices instead of using GNN. + + Args: + n_nodes: The number of nodes in the graph. + directed: Whether the graph is directed. + embedding_dim: Dimension of embeddings. + is_backward: Whether this is a backward policy. 
+ """ + + def __init__( + self, + n_nodes: int, + directed: bool, + embedding_dim: int = 128, + is_backward: bool = False, + ): + super().__init__() + self.n_nodes = n_nodes + self.is_directed = directed + self.is_backward = is_backward + self.hidden_dim = embedding_dim + + # MLP for processing the flattened adjacency matrix + self.mlp = MLP( + input_dim=n_nodes * n_nodes, # Flattened adjacency matrix + output_dim=embedding_dim, + hidden_dim=embedding_dim, + n_hidden_layers=2, + add_layer_norm=True, + ) + + # Exit action MLP + self.exit_mlp = MLP( + input_dim=embedding_dim, + output_dim=1, + hidden_dim=embedding_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + # Edge prediction MLP + if directed: + # For directed graphs: all off-diagonal elements + out_size = n_nodes**2 - n_nodes + else: + # For undirected graphs: upper triangle without diagonal + out_size = (n_nodes**2 - n_nodes) // 2 + + self.edge_mlp = MLP( + input_dim=embedding_dim, + output_dim=out_size, + hidden_dim=embedding_dim, + n_hidden_layers=1, + add_layer_norm=True, + ) + + def forward(self, states_tensor: TensorDict) -> torch.Tensor: + # Convert the graph to adjacency matrix + batch_size = int(torch.prod(torch.tensor(states_tensor["batch_shape"]))) + adj_matrices = torch.zeros( + (batch_size, self.n_nodes, self.n_nodes), + device=states_tensor["node_feature"].device, + ) + + # Fill the adjacency matrices from edge indices + for i in range(batch_size): + start, end = ( + states_tensor["batch_ptr"][i], + states_tensor["batch_ptr"][i + 1], + ) + nodes_index_range = states_tensor["node_index"][start:end] + + # Skip if no edges + if states_tensor["edge_index"].shape[0] == 0: + continue + + # Find edges that belong to this graph + edge_index_mask = torch.all( + states_tensor["edge_index"] >= nodes_index_range[0], dim=-1 + ) & torch.all(states_tensor["edge_index"] <= nodes_index_range[-1], dim=-1) + + if torch.any(edge_index_mask): + # Get the edge indices relative to this graph's node indices + masked_edge_index = ( + states_tensor["edge_index"][edge_index_mask] - nodes_index_range[0] + ) + # Fill the adjacency matrix + if len(masked_edge_index) > 0: + adj_matrices[ + i, masked_edge_index[:, 0], masked_edge_index[:, 1] + ] = 1 + + # Flatten the adjacency matrices for the MLP + adj_matrices_flat = adj_matrices.view(batch_size, -1) + + # Process with MLP + embedding = self.mlp(adj_matrices_flat) + + # Generate edge and exit actions + edge_actions = self.edge_mlp(embedding) + exit_action = self.exit_mlp(embedding) + + if self.is_backward: + return edge_actions + else: + return torch.cat([edge_actions, exit_action], dim=-1) + + if __name__ == "__main__": N_NODES = 4 N_ITERATIONS = 1000 @@ -642,14 +768,22 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool BATCH_SIZE = 128 DIRECTED = True USE_BUFFER = False + USE_GNN = False # Set to False to use MLP with adjacency matrices instead of GNN state_evaluator = undirected_reward if not DIRECTED else directed_reward torch.random.manual_seed(7) env = RingGraphBuilding( n_nodes=N_NODES, state_evaluator=state_evaluator, directed=DIRECTED ) - module_pf = RingPolicyModule(env.n_nodes, DIRECTED) - module_pb = RingPolicyModule(env.n_nodes, DIRECTED, is_backward=True) + + # Choose model type based on USE_GNN flag + if USE_GNN: + module_pf = RingPolicyModule(env.n_nodes, DIRECTED) + module_pb = RingPolicyModule(env.n_nodes, DIRECTED, is_backward=True) + else: + module_pf = AdjacencyPolicyModule(env.n_nodes, DIRECTED) + module_pb = 
AdjacencyPolicyModule(env.n_nodes, DIRECTED, is_backward=True) + pf = DiscretePolicyEstimator( module=module_pf, n_actions=env.n_actions, preprocessor=GraphPreprocessor() ) @@ -674,7 +808,9 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool t1 = time.time() for iteration in range(N_ITERATIONS): trajectories = gflownet.sample_trajectories( - env, n=BATCH_SIZE, save_logprobs=True, # pyright: ignore + env, + n=BATCH_SIZE, + save_logprobs=True, # pyright: ignore ) training_samples = gflownet.to_training_samples(trajectories) @@ -690,15 +826,19 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool with torch.no_grad(): replay_buffer.add(training_samples) if iteration > 20: - training_samples = training_samples[:BATCH_SIZE // 2] - buffer_samples = replay_buffer.sample(n_trajectories=BATCH_SIZE // 2) + training_samples = training_samples[: BATCH_SIZE // 2] + buffer_samples = replay_buffer.sample( + n_trajectories=BATCH_SIZE // 2 + ) training_samples.extend(buffer_samples) optimizer.zero_grad() loss = gflownet.loss(env, training_samples) # pyright: ignore pct_rings = torch.mean(rewards > 0.1, dtype=torch.float) * 100 - print("Iteration {} - Loss: {:.02f}, rings: {:.0f}%".format( - iteration, loss.item(), pct_rings) + print( + "Iteration {} - Loss: {:.02f}, rings: {:.0f}%".format( + iteration, loss.item(), pct_rings + ) ) loss.backward() optimizer.step() @@ -706,7 +846,7 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool t2 = time.time() print("Time:", t2 - t1) - + # This comes from the gflownet, not the buffer. last_states = trajectories.last_states[:8] assert isinstance(last_states, GraphStates) From 80b2a9e95852accc13798e61a99939f533ce7702 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Tue, 25 Feb 2025 11:33:37 +0400 Subject: [PATCH 091/102] Simplify ring graph reward calculation logic --- tutorials/examples/train_graph_ring.py | 36 ++++++++++++-------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 259cd8b4..ea001991 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -39,6 +39,7 @@ def directed_reward(states: GraphStates) -> torch.Tensor: """Compute the reward of a graph. Specifically, the reward is 1 if the graph is a ring, 1e-6 otherwise. + A ring is a directed cycle where each node has exactly one outgoing and one incoming edge. Args: states: A batch of graphs. 
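The hunk below collapses the old neighbor-set bookkeeping into a two-part test: degree checks on the adjacency matrix, then a single walk from node 0. A minimal standalone sketch of the same logic (illustrative only; `is_directed_ring` is not part of the patch, a dense 0/1 adjacency matrix is assumed, and `argmax` stands in for `torch.where` since each row holds exactly one 1 after the degree test):

    import torch

    def is_directed_ring(adj: torch.Tensor) -> bool:
        # Degree test: exactly one outgoing and one incoming edge per node.
        n = adj.shape[0]
        if not (torch.all(adj.sum(dim=1) == 1) and torch.all(adj.sum(dim=0) == 1)):
            return False
        # Walk the unique outgoing edge from node 0; a single cycle must
        # visit all n nodes before returning to the start.
        visited, current = set(), 0
        while current not in visited:
            visited.add(current)
            current = int(torch.argmax(adj[current]))
        return len(visited) == n and current == 0

    assert is_directed_ring(torch.tensor([[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]]))
    assert not is_directed_ring(torch.eye(3))  # self-loops pass the degree test but fail the walk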
@@ -65,32 +66,29 @@ def directed_reward(states: GraphStates) -> torch.Tensor: adj_matrix = torch.zeros(n_nodes, n_nodes) adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 + # Check if each node has exactly one outgoing edge (row sum = 1) if not torch.all(adj_matrix.sum(dim=1) == 1): continue - visited, current = [], 0 + # Check that each node has exactly one incoming edge (column sum = 1) + if not torch.all(adj_matrix.sum(dim=0) == 1): + continue + + # Starting from node 0, follow edges and see if we visit all nodes + # and return to the start + visited = [] + current = 0 # Start from node 0 + while current not in visited: visited.append(current) - def set_diff(tensor1, tensor2): - mask = ~torch.isin(tensor1, tensor2) - return tensor1[mask] - - # Find an unvisited neighbor - neighbors = torch.where(adj_matrix[current] == 1)[0] - valid_neighbours = set_diff(neighbors, torch.tensor(visited)) + # Get the outgoing neighbor + current = torch.where(adj_matrix[current] == 1)[0].item() - # Visit the first valid neighbor. - if len(valid_neighbours) == 1: - current = valid_neighbours[0] - elif len(valid_neighbours) == 0: + # If we've visited all nodes and returned to 0, it's a valid ring + if len(visited) == n_nodes and current == 0: + out[i] = REW_VAL break - else: - break # TODO: This actually should never happen, should be caught on line 45. - - # Check if we visited all vertices and the last vertex connects back to start. - if len(visited) == n_nodes and adj_matrix[current][0] == 1: - out[i] = REW_VAL return out.view(*states.batch_shape) @@ -762,7 +760,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: if __name__ == "__main__": - N_NODES = 4 + N_NODES = 3 N_ITERATIONS = 1000 LR = 0.001 BATCH_SIZE = 128 From c24297d0ad806e310ffd34eb785016d18c8a776f Mon Sep 17 00:00:00 2001 From: "Omar G. 
Younis" Date: Tue, 25 Feb 2025 10:51:25 +0100 Subject: [PATCH 092/102] fix forward mask --- src/gfn/states.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 53e5c984..f8ee0e2b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -1046,20 +1046,20 @@ def forward_masks(self) -> TensorDict: same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & ( arange_nodes < self.tensor["batch_ptr"][1:, None] ) - # edge_index = torch.where( - # self.tensor["edge_index"][..., None] == self.tensor["node_index"] - # )[2].reshape(self.tensor["edge_index"].shape) - # i, j = edge_index[..., 0], edge_index[..., 1] + edge_index = torch.where( + self.tensor["edge_index"][..., None] == self.tensor["node_index"] + )[2].reshape(self.tensor["edge_index"].shape) + i, j = edge_index[..., 0], edge_index[..., 1] - # for _ in range(len(self.batch_shape)): - # (i, j) = ei1.unsqueeze(0), ei2.unsqueeze(0) + for _ in range(len(self.batch_shape)): + (i, j) = i.unsqueeze(0), j.unsqueeze(0) # First allow nodes in the same graph to connect, then disable nodes with existing edges forward_masks["edge_index"][ same_graph_mask[:, :, None] & same_graph_mask[:, None, :] ] = True torch.diagonal(forward_masks["edge_index"], dim1=-2, dim2=-1).fill_(False) - forward_masks["edge_index"][arange[..., None], ei1, ei2] = False + forward_masks["edge_index"][arange[..., None], i, j] = False forward_masks["action_type"][..., GraphActionType.ADD_EDGE] &= torch.any( forward_masks["edge_index"], dim=(-1, -2) ) From 3c390647a619d40566a79fb6504e582e3cbe5162 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Tue, 25 Feb 2025 19:31:17 +0400 Subject: [PATCH 093/102] add docstrings --- tutorials/examples/train_graph_ring.py | 233 ++++++++++++++++++++++--- 1 file changed, 207 insertions(+), 26 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index ea001991..507b776a 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -36,16 +36,23 @@ def directed_reward(states: GraphStates) -> torch.Tensor: - """Compute the reward of a graph. + """Compute reward for directed ring graphs. - Specifically, the reward is 1 if the graph is a ring, 1e-6 otherwise. - A ring is a directed cycle where each node has exactly one outgoing and one incoming edge. + This function evaluates if a graph forms a valid directed ring (directed cycle). + A valid directed ring must satisfy these conditions: + 1. Each node must have exactly one outgoing edge (row sum = 1 in adjacency matrix) + 2. Each node must have exactly one incoming edge (column sum = 1 in adjacency matrix) + 3. Following the edges must form a single cycle that includes all nodes + + The reward is binary: + - REW_VAL (100.0) for valid directed rings + - EPS_VAL (1e-6) for invalid structures Args: - states: A batch of graphs. + states: A batch of graph states to evaluate Returns: - A tensor of rewards. + A tensor of rewards with the same batch shape as states """ if states.tensor["edge_index"].shape[0] == 0: return torch.full(states.batch_shape, EPS_VAL) @@ -94,15 +101,27 @@ def directed_reward(states: GraphStates) -> torch.Tensor: def undirected_reward(states: GraphStates) -> torch.Tensor: - """Compute the reward of a graph. + """Compute reward for undirected ring graphs. + + This function evaluates if a graph forms a valid undirected ring (cycle). + A valid undirected ring must satisfy these conditions: + 1. 
Each node must have exactly two neighbors (degree = 2) + 2. The graph must form a single connected cycle including all nodes + + The reward is binary: + - REW_VAL (100.0) for valid undirected rings + - EPS_VAL (1e-6) for invalid structures - Specifically, the reward is 1 if the graph is an undirected ring, 1e-6 otherwise. + The algorithm: + 1. Checks that all nodes have degree 2 + 2. Performs a traversal starting from node 0, following edges + 3. Checks if the traversal visits all nodes and returns to start Args: - states: A batch of graphs. + states: A batch of graph states to evaluate Returns: - A tensor of rewards. + A tensor of rewards with the same batch shape as states """ if states.tensor["edge_index"].shape[0] == 0: return torch.full(states.batch_shape, EPS_VAL) @@ -330,7 +349,9 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: out_size = (self.n_nodes**2 - self.n_nodes) // 2 # Grab the needed elems from the adjacency matrix and reshape. - edge_actions = x[torch.arange(batch_size)[:, None, None], i0, i1] + edge_actions = edgewise_dot_prod[ + torch.arange(batch_size)[:, None, None], i0, i1 + ] edge_actions = edge_actions.reshape(*states_tensor["batch_shape"], out_size) if self.is_backward: @@ -340,15 +361,24 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: class RingGraphBuilding(GraphBuilding): - """Override the GraphBuilding class to create have discrete actions. + """Environment for building ring graphs with discrete action space. - Specifically, at initialization, we have n nodes. - The policy can only add edges between existing nodes or use the exit action. - The action space is thus discrete and of size n^2 + 1, where the last action is the exit action, - and the first n^2 actions are the possible edges. + This environment is specialized for creating ring graphs where each node has + exactly two neighbors and the edges form a single cycle. The environment supports + both directed and undirected graphs. + + In each state, the policy can: + 1. Add an edge between existing nodes + 2. Use the exit action to terminate graph building + + The action space is discrete, with size: + - For directed graphs: n_nodes^2 - n_nodes + 1 (all possible directed edges + exit) + - For undirected graphs: (n_nodes^2 - n_nodes)/2 + 1 (upper triangle + exit) Args: n_nodes: The number of nodes in the graph. + state_evaluator: A function that evaluates a state and returns a reward. + directed: Whether the graph should be directed. """ def __init__(self, n_nodes: int, state_evaluator: callable, directed: bool): @@ -367,6 +397,14 @@ def make_actions_class(self) -> type[Actions]: env = self class RingActions(Actions): + """Actions for building ring graphs. + + Actions are represented as discrete indices where: + - 0 to n_actions-2: Adding an edge between specific nodes + - n_actions-1: EXIT action to terminate the trajectory + - n_actions: DUMMY action (used for padding) + """ + action_shape = (1,) dummy_action = torch.tensor([env.n_actions]) exit_action = torch.tensor([env.n_actions - 1]) @@ -377,6 +415,25 @@ def make_states_class(self) -> type[GraphStates]: env = self class RingStates(GraphStates): + """Represents the state of a ring graph building process. + + This class extends GraphStates to specifically handle ring graph states. + Each state represents a graph with a fixed number of nodes where edges + are being added incrementally to form a ring structure. 
+ + The state representation consists of: + - node_feature: Node IDs for each node in the graph (shape: [n_nodes, 1]) + - edge_feature: Features for each edge (shape: [n_edges, 1]) + - edge_index: Indices representing the source and target nodes for each edge (shape: [n_edges, 2]) + + Special states: + - s0: Initial state with n_nodes and no edges + - sf: Terminal state (used as a placeholder) + + The class provides masks for both forward and backward actions to determine + which actions are valid from the current state. + """ + s0 = TensorDict( { "node_feature": torch.arange(env.n_nodes)[:, None], @@ -405,6 +462,20 @@ def __init__(self, tensor: TensorDict): @property def forward_masks(self): + """Compute masks for valid forward actions from the current state. + + A forward action is valid if: + 1. The edge doesn't already exist in the graph + 2. The edge connects two distinct nodes + + For directed graphs, all possible src->dst edges are considered. + For undirected graphs, only the upper triangular portion of the adjacency matrix is used. + + The last action is always the EXIT action, which is always valid. + + Returns: + Tensor: Boolean mask of shape [batch_size, n_actions] where True indicates valid actions + """ # Allow all actions. forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) @@ -456,6 +527,19 @@ def forward_masks(self, value: torch.Tensor): @property def backward_masks(self): + """Compute masks for valid backward actions from the current state. + + A backward action is valid if: + 1. The edge exists in the current graph (i.e., can be removed) + + For directed graphs, all existing edges are considered for removal. + For undirected graphs, only the upper triangular edges are considered. + + The EXIT action is not included in backward masks. + + Returns: + Tensor: Boolean mask of shape [batch_size, n_actions-1] where True indicates valid actions + """ # Disallow all actions. backward_masks = torch.zeros( len(self), self.n_actions - 1, dtype=torch.bool @@ -507,19 +591,51 @@ def backward_masks(self, value: torch.Tensor): return RingStates def _step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Take a step in the environment by applying actions to states. + + Args: + states: Current states batch + actions: Actions to apply + + Returns: + New states after applying the actions + """ actions = self.convert_actions(states, actions) new_states = super()._step(states, actions) assert isinstance(new_states, GraphStates) return new_states def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Take a backward step in the environment. + + Args: + states: Current states batch + actions: Actions to apply in reverse + + Returns: + New states after applying the backward actions + """ actions = self.convert_actions(states, actions) new_states = super()._backward_step(states, actions) assert isinstance(new_states, GraphStates) return new_states def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions: - """Converts the action from discrete space to graph action space.""" + """Converts the action from discrete space to graph action space. 
+ + This method maps discrete action indices to specific graph operations: + - GraphActionType.ADD_EDGE: Add an edge between specific nodes + - GraphActionType.EXIT: Terminate trajectory + - GraphActionType.DUMMY: No-op action (for padding) + + Args: + states: Current states batch + actions: Discrete actions to convert + + Returns: + Equivalent actions in the GraphActions format + """ + # TODO: factor out into utility function. action_tensor = actions.tensor.squeeze(-1).clone() action_type = torch.where( action_tensor == self.n_actions - 1, @@ -528,6 +644,7 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions ) action_type[action_tensor == self.n_actions] = GraphActionType.DUMMY + # Convert action indices to source-target node pairs # TODO: factor out into utility function. if self.is_directed: i_up, j_up = torch.triu_indices( @@ -567,7 +684,15 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions class GraphPreprocessor(Preprocessor): - """Extract the tensor from the states.""" + """Preprocessor for graph states to extract the tensor representation. + + This simple preprocessor extracts the TensorDict from GraphStates to make + it compatible with the policy networks. It doesn't perform any complex + transformations, just ensuring the tensors are accessible in the right format. + + Args: + feature_dim: The dimension of features in the graph (default: 1) + """ def __init__(self, feature_dim: int = 1): super().__init__(output_dim=feature_dim) @@ -575,15 +700,26 @@ def __init__(self, feature_dim: int = 1): def preprocess(self, states: GraphStates) -> TensorDict: return states.tensor - def __call__(self, states: GraphStates) -> torch.Tensor: + def __call__(self, states: GraphStates) -> TensorDict: return self.preprocess(states) def render_states(states: GraphStates, state_evaluator: callable, directed: bool): - """Render the states as a matplotlib plot. + """Visualize a batch of graph states as ring structures. + + This function creates a matplotlib visualization of graph states, rendering them + as circular layouts with nodes positioned evenly around a circle. For directed + graphs, edges are shown as arrows; for undirected graphs, edges are shown as lines. + + The visualization includes: + - Circular positioning of nodes + - Drawing edges between connected nodes + - Displaying the reward value for each graph Args: - states: A batch of graphs. + states: A batch of graphs to visualize + state_evaluator: Function to compute rewards for each graph + directed: Whether to render directed or undirected edges """ rewards = state_evaluator(states) fig, ax = plt.subplots(2, 4, figsize=(15, 7)) @@ -651,13 +787,25 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool class AdjacencyPolicyModule(nn.Module): - """Simple MLP that processes flattened adjacency matrices instead of using GNN. + """Policy network that processes flattened adjacency matrices to predict graph actions. + + Unlike the GNN-based RingPolicyModule, this module uses standard MLPs to process + the entire adjacency matrix as a flattened vector. This approach: + + 1. Can directly process global graph structure without message passing + 2. May be more effective for small graphs where global patterns are important + 3. 
Does not require complex graph neural network operations + + The module architecture consists of: + - An MLP to process the flattened adjacency matrix into an embedding + - An edge MLP that predicts logits for each possible edge action + - An exit MLP that predicts a logit for the exit action Args: - n_nodes: The number of nodes in the graph. - directed: Whether the graph is directed. - embedding_dim: Dimension of embeddings. - is_backward: Whether this is a backward policy. + n_nodes: Number of nodes in the graph + directed: Whether the graph is directed or undirected + embedding_dim: Dimension of internal embeddings (default: 128) + is_backward: Whether this is a backward policy (default: False) """ def __init__( @@ -708,6 +856,19 @@ def __init__( ) def forward(self, states_tensor: TensorDict) -> torch.Tensor: + """Forward pass to compute action logits from graph states. + + Process: + 1. Convert the graph representation to adjacency matrices + 2. Process the flattened adjacency matrices through the main MLP + 3. Predict logits for edge actions and exit action + + Args: + states_tensor: A TensorDict containing graph state information + + Returns: + A tensor of logits for all possible actions + """ # Convert the graph to adjacency matrix batch_size = int(torch.prod(torch.tensor(states_tensor["batch_shape"]))) adj_matrices = torch.zeros( @@ -760,7 +921,27 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: if __name__ == "__main__": - N_NODES = 3 + """ + Main execution for training a GFlowNet to generate ring graphs. + + This script demonstrates the complete workflow of training a GFlowNet + to generate valid ring structures in both directed and undirected settings. + + Configurable parameters: + - N_NODES: Number of nodes in the graph (default: 5) + - N_ITERATIONS: Number of training iterations (default: 1000) + - LR: Learning rate for optimizer (default: 0.001) + - BATCH_SIZE: Batch size for training (default: 128) + - DIRECTED: Whether to generate directed rings (default: True) + - USE_BUFFER: Whether to use a replay buffer (default: False) + - USE_GNN: Whether to use GNN-based policy (True) or MLP-based policy (False) + + The script performs the following steps: + 1. Initialize the environment and policy networks + 2. Train the GFlowNet using trajectory balance + 3. Visualize sample generated graphs + """ + N_NODES = 4 N_ITERATIONS = 1000 LR = 0.001 BATCH_SIZE = 128 From 0a1f8e1b28e0cd40c6972275e299e09368079f14 Mon Sep 17 00:00:00 2001 From: "Omar G. 
Younis" Date: Wed, 26 Feb 2025 15:26:18 +0100 Subject: [PATCH 094/102] switch from torch_geoemtric to TensorDict --- src/gfn/env.py | 16 +- src/gfn/gym/graph_building.py | 287 ++++--- src/gfn/states.py | 982 +++++++++++----------- src/gfn/utils/training.py | 2 +- testing/test_environments.py | 24 +- testing/test_samplers_and_trajectories.py | 13 +- testing/test_states.py | 470 +++++++---- 7 files changed, 989 insertions(+), 805 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index cdc84ca4..8c1992b3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,7 +2,7 @@ from typing import Optional, Tuple, Union import torch -from tensordict import TensorDict +from torch_geometric.data import Batch, Data from gfn.actions import Actions, GraphActions from gfn.preprocessors import IdentityPreprocessor, Preprocessor @@ -23,12 +23,12 @@ class Env(ABC): def __init__( self, - s0: torch.Tensor | TensorDict, + s0: torch.Tensor | Data, state_shape: Tuple, action_shape: Tuple, dummy_action: torch.Tensor, exit_action: torch.Tensor, - sf: Optional[torch.Tensor | TensorDict] = None, + sf: Optional[torch.Tensor | Data] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -275,7 +275,7 @@ def _step( new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)): + if not isinstance(new_not_done_states_tensor, (torch.Tensor, Batch)): raise Exception( "User implemented env.step function *must* return a torch.Tensor!" ) @@ -571,12 +571,12 @@ def terminating_states(self) -> DiscreteStates: class GraphEnv(Env): """Base class for graph-based environments.""" - sf: TensorDict # this tells the type checker that sf is a TensorDict + sf: Data # this tells the type checker that sf is a Data def __init__( self, - s0: TensorDict, - sf: TensorDict, + s0: Data, + sf: Data, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -592,7 +592,7 @@ def __init__( the IdentityPreprocessor is used. 
""" self.s0 = s0.to(device_str) - self.features_dim = s0["node_feature"].shape[-1] + self.features_dim = s0.x.shape[-1] self.sf = sf self.States = self.make_states_class() diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 59126468..1006f351 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -1,7 +1,7 @@ from typing import Callable, Literal, Tuple import torch -from tensordict import TensorDict +from torch_geometric.data import Data, Batch from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError @@ -29,22 +29,16 @@ def __init__( state_evaluator: Callable[[GraphStates], torch.Tensor], device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = TensorDict( - { - "node_feature": torch.zeros((0, feature_dim), dtype=torch.float32), - "edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, + s0 = Data( + x=torch.zeros((0, feature_dim), dtype=torch.float32), + edge_attr=torch.zeros((0, feature_dim), dtype=torch.float32), + edge_index=torch.zeros((2, 0), dtype=torch.long), device=device_str, ) - sf = TensorDict( - { - "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) - * float("inf"), - "edge_feature": torch.ones((0, feature_dim), dtype=torch.float32) - * float("inf"), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, + sf = Data( + x=torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), + edge_attr=torch.ones((0, feature_dim), dtype=torch.float32) * float("inf"), + edge_index=torch.zeros((2, 0), dtype=torch.long), device=device_str, ) @@ -63,7 +57,7 @@ def reset(self, batch_shape: Tuple | int) -> GraphStates: assert isinstance(states, GraphStates) return states - def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: + def step(self, states: GraphStates, actions: GraphActions) -> Batch: """Step function for the GraphBuilding environment. 
Args: @@ -74,8 +68,6 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: """ if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") - if len(actions) == 0: - return states.tensor action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) @@ -91,20 +83,34 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: ) if action_type == GraphActionType.ADD_EDGE: - states.tensor["edge_feature"] = torch.cat( - [states.tensor["edge_feature"], actions.features], dim=0 - ) - states.tensor["edge_index"] = torch.cat( - [ - states.tensor["edge_index"], - actions.edge_index, - ], - dim=0, - ) + # Get the data list from the batch + data_list = states.tensor.to_data_list() + + # Add edges to each graph + for i, (src, dst) in enumerate(actions.edge_index): + # Get the graph to modify + graph = data_list[i] + + # Add the new edge + graph.edge_index = torch.cat([ + graph.edge_index, + torch.tensor([[src], [dst]], device=graph.edge_index.device) + ], dim=1) + + # Add the edge feature + graph.edge_attr = torch.cat([ + graph.edge_attr, + actions.features[i].unsqueeze(0) + ], dim=0) + + # Create a new batch from the updated data list + new_tensor = Batch.from_data_list(data_list) + new_tensor.batch_shape = states.tensor.batch_shape + states.tensor = new_tensor return states.tensor - def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDict: + def backward_step(self, states: GraphStates, actions: GraphActions) -> Batch: """Backward step function for the GraphBuilding environment. Args: @@ -120,112 +126,161 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) + + # Get the data list from the batch + data_list = states.tensor.to_data_list() + if action_type == GraphActionType.ADD_NODE: - is_equal = torch.any( - torch.all( - states.tensor["node_feature"][:, None] == actions.features, dim=-1 - ), - dim=-1, - ) - states.tensor["node_feature"] = states.tensor["node_feature"][~is_equal] - elif action_type == GraphActionType.ADD_EDGE: - assert actions.edge_index is not None - is_equal = torch.all( - states.tensor["edge_index"] == actions.edge_index[:, None], dim=-1 - ) - is_equal = torch.any(is_equal, dim=0) - states.tensor["edge_feature"] = states.tensor["edge_feature"][~is_equal] - states.tensor["edge_index"] = states.tensor["edge_index"][~is_equal] + # Remove nodes with matching features + for i, features in enumerate(actions.features): + graph = data_list[i] + + # Find nodes with matching features + is_equal = torch.all(graph.x == features.unsqueeze(0), dim=1) + + if torch.any(is_equal): + # Remove the first matching node + node_idx = torch.where(is_equal)[0][0].item() + + # Remove the node + mask = torch.ones(graph.num_nodes, dtype=torch.bool, device=graph.x.device) + mask[node_idx] = False + + # Update node features + graph.x = graph.x[mask] - return states.tensor + elif action_type == GraphActionType.ADD_EDGE: + # Remove edges with matching indices + for i, (src, dst) in enumerate(actions.edge_index): + graph = data_list[i] + + # Find the edge to remove + edge_mask = ~((graph.edge_index[0] == src) & (graph.edge_index[1] == dst)) + + # Remove the edge + graph.edge_index = graph.edge_index[:, edge_mask] + graph.edge_attr = graph.edge_attr[edge_mask] + + # Create a new batch from the updated data list + new_batch = Batch.from_data_list(data_list) + + # Preserve the 
batch shape + new_batch.batch_shape = torch.tensor(states.batch_shape, device=states.tensor.x.device) + + return new_batch def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False ) -> bool: - add_node_mask = actions.action_type == GraphActionType.ADD_NODE - if not torch.any(add_node_mask): - add_node_out = True - else: - node_feature = states[add_node_mask].tensor["node_feature"] - equal_nodes_per_batch = torch.all( - node_feature == actions[add_node_mask].features[:, None], dim=-1 - ).sum(dim=-1) - if backward: # TODO: check if no edge is connected? - add_node_out = torch.all(equal_nodes_per_batch == 1) + """Check if actions are valid for the given states. + + Args: + states: Current graph states. + actions: Actions to validate. + backward: Whether this is a backward step. + + Returns: + True if all actions are valid, False otherwise. + """ + # Get the data list from the batch + data_list = states.tensor.to_data_list() + + for i, features in enumerate(actions.features[actions.action_type == GraphActionType.ADD_NODE]): + graph = data_list[i] + + # Check if a node with these features already exists + equal_nodes = torch.all(graph.x == features.unsqueeze(0), dim=1) + + if backward: + # For backward actions, we need at least one matching node + if not torch.any(equal_nodes): + return False else: - add_node_out = torch.all(equal_nodes_per_batch == 0) - - add_edge_mask = actions.action_type == GraphActionType.ADD_EDGE - if not torch.any(add_edge_mask): - add_edge_out = True - else: - add_edge_states = states.tensor - add_edge_actions = actions[add_edge_mask].edge_index - if torch.any(add_edge_actions[:, 0] == add_edge_actions[:, 1]): - return False - if add_edge_states["node_feature"].shape[0] == 0: - return False - node_exists = torch.isin(add_edge_actions, add_edge_states["node_index"]) - if not torch.all(node_exists): + # For forward actions, we should not have any matching nodes + if torch.any(equal_nodes): + return False + + for i, (src, dst) in enumerate(actions.edge_index[actions.action_type == GraphActionType.ADD_EDGE]): + graph = data_list[i] + + # Check if src and dst are valid node indices + if src >= graph.num_nodes or dst >= graph.num_nodes or src == dst: return False - - equal_edges_per_batch = torch.all( - add_edge_states["edge_index"] == add_edge_actions[:, None], - dim=-1, - ).sum(dim=-1) + + # Check if the edge already exists + edge_exists = torch.any((graph.edge_index[0] == src) & (graph.edge_index[1] == dst)) + if backward: - add_edge_out = torch.all(equal_edges_per_batch != 0) + # For backward actions, the edge must exist + if not edge_exists: + return False else: - add_edge_out = torch.all(equal_edges_per_batch == 0) + # For forward actions, the edge must not exist + if edge_exists: + return False - return bool(add_node_out) and bool(add_edge_out) + return True def _add_node( self, - tensor_dict: TensorDict, + tensor: Batch, batch_indices: torch.Tensor, nodes_to_add: torch.Tensor, - ) -> TensorDict: + ) -> Batch: + """Add nodes to graphs in a batch. + + Args: + tensor_dict: The current batch of graphs. + batch_indices: Indices of graphs to add nodes to. + nodes_to_add: Features of nodes to add. + + Returns: + Updated batch of graphs. 
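+
+        Example (illustrative sketch; `env`, `states`, and a feature size of
+        8 are assumed):
+
+            # Append one random node each to graphs 0 and 2 of the batch,
+            # as done internally for ADD_NODE actions.
+            new_batch = env._add_node(
+                states.tensor,
+                batch_indices=torch.tensor([0, 2]),
+                nodes_to_add=torch.rand(2, 8),
+            )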
+ """ if isinstance(batch_indices, list): batch_indices = torch.tensor(batch_indices) if len(batch_indices) != len(nodes_to_add): raise ValueError( "Number of batch indices must match number of node feature lists" - ) - - modified_dict = tensor_dict.clone() - node_feature_dim = modified_dict["node_feature"].shape[1] - + ) + # Get the data list from the batch + data_list = tensor.to_data_list() + + # Add nodes to the specified graphs for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): - tensor_dict["batch_ptr"][graph_idx] - end_ptr = tensor_dict["batch_ptr"][graph_idx + 1] + # Get the graph to modify + graph = data_list[graph_idx] + + # Ensure new_nodes is 2D new_nodes = torch.atleast_2d(new_nodes) - if new_nodes.shape[1] != node_feature_dim: - raise ValueError( - f"Node features must have dimension {node_feature_dim}" - ) - - # Update batch pointers for subsequent graphs + + # Check feature dimension + if new_nodes.shape[1] != graph.x.shape[1]: + raise ValueError(f"Node features must have dimension {graph.x.shape[1]}") + + # Generate unique indices for new nodes num_new_nodes = new_nodes.shape[0] - modified_dict["batch_ptr"][graph_idx + 1 :] += num_new_nodes - - # Expand node features - modified_dict["node_feature"] = torch.cat( - [ - modified_dict["node_feature"][:end_ptr], - new_nodes, - modified_dict["node_feature"][end_ptr:], - ] - ) - modified_dict["node_index"] = torch.cat( - [ - modified_dict["node_index"][:end_ptr], - GraphStates.unique_node_indices(num_new_nodes), - modified_dict["node_index"][end_ptr:], - ] - ) - - return modified_dict + new_indices = GraphStates.unique_node_indices(num_new_nodes) + + # Add new nodes to the graph + graph.x = torch.cat([graph.x, new_nodes], dim=0) + + # Add node indices if they exist + if hasattr(graph, "node_index"): + graph.node_index = torch.cat([graph.node_index, new_indices], dim=0) + else: + # If node_index doesn't exist, create it + graph.node_index = torch.cat([ + torch.arange(graph.num_nodes - num_new_nodes, device=graph.x.device), + new_indices + ], dim=0) + + # Create a new batch from the updated data list + new_batch = Batch.from_data_list(data_list) + + # Preserve the batch shape + new_batch.batch_shape = tensor.batch_shape + return new_batch def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. @@ -239,16 +294,6 @@ def reward(self, final_states: GraphStates) -> torch.Tensor: """ return self.state_evaluator(final_states) - @property - def log_partition(self) -> float: - "Returns the logarithm of the partition function." - raise NotImplementedError - - @property - def true_dist_pmf(self) -> torch.Tensor: - "Returns a one-dimensional tensor representing the true distribution." - raise NotImplementedError - def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: """Generates random states tensor of shape (*batch_shape, feature_dim).""" return self.States.from_batch_shape(batch_shape) diff --git a/src/gfn/states.py b/src/gfn/states.py index f8ee0e2b..78296a6c 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -7,7 +7,7 @@ import numpy as np import torch -from tensordict import TensorDict +from torch_geometric.data import Data, Batch from gfn.actions import GraphActionType @@ -48,11 +48,10 @@ class States(ABC): _log_rewards: Stores the log rewards of each state. 
""" - state_shape: ClassVar[tuple[int, ...]] # Shape of one state - s0: ClassVar[torch.Tensor | TensorDict] # Source state of the DAG - sf: ClassVar[ - torch.Tensor | TensorDict - ] # Dummy state, used to pad a batch of states + state_shape: ClassVar[tuple[int, ...]] + s0: ClassVar[torch.Tensor] + sf: ClassVar[torch.Tensor] + make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw( NotImplementedError( "The environment does not support initialization of random states." @@ -541,39 +540,42 @@ class GraphStates(States): graph objects as states. """ - s0: ClassVar[TensorDict] - sf: ClassVar[TensorDict] + s0: ClassVar[Data] + sf: ClassVar[Data] _next_node_index = 0 - def __init__(self, tensor: TensorDict): - REQUIRED_KEYS = { - "node_feature", - "node_index", - "edge_feature", - "edge_index", - "batch_ptr", - "batch_shape", - } - if not all(key in tensor for key in REQUIRED_KEYS): - raise ValueError( - f"TensorDict must contain all required keys: {REQUIRED_KEYS}" - ) - - assert tensor["node_index"].unique().numel() == len(tensor["node_index"]) + def __init__(self, tensor: Batch): + """Initialize the GraphStates with a PyG Batch object. + + Args: + tensor: A PyG Batch object representing a batch of graphs. + """ self.tensor = tensor - self.node_features_dim = tensor["node_feature"].shape[-1] - self.edge_features_dim = tensor["edge_feature"].shape[-1] + if not hasattr(self.tensor, 'batch_shape'): + self.tensor.batch_shape = self.tensor.batch_size self._log_rewards: Optional[torch.Tensor] = None @property def batch_shape(self) -> tuple[int, ...]: - return tuple(self.tensor["batch_shape"].tolist()) + """Returns the batch shape as a tuple.""" + return tuple(self.tensor.batch_shape.tolist()) + @classmethod def from_batch_shape( cls, batch_shape: int | Tuple, random: bool = False, sink: bool = False ) -> GraphStates: + """Create a GraphStates object with the given batch shape. + + Args: + batch_shape: Shape of the batch dimensions. + random: Initialize states randomly. + sink: States initialized with s_f (the sink state). + + Returns: + A GraphStates object with the specified batch shape. + """ if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") if random: @@ -585,545 +587,537 @@ def from_batch_shape( return cls(tensor) @classmethod - def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: + def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> Batch: + """Makes a batch of graphs consisting of s0 states. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns: + A PyG Batch object containing copies of the initial state. 
+ """ batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - nodes = cls.s0["node_feature"].repeat(np.prod(batch_shape), 1) - - return TensorDict( - { - "node_feature": nodes, - "node_index": GraphStates.unique_node_indices(nodes.shape[0]), - "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), - "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange( - int(np.prod(batch_shape)) + 1, device=cls.s0.device - ) - * cls.s0["node_feature"].shape[0], - "batch_shape": torch.tensor(batch_shape, device=cls.s0.device), - } - ) + num_graphs = int(np.prod(batch_shape)) + + # Create a list of Data objects by copying s0 + data_list = [cls.s0.clone() for _ in range(num_graphs)] + + # Create a batch from the list + batch = Batch.from_data_list(data_list) + + # Store the batch shape for later reference + batch.batch_shape = torch.tensor(batch_shape, device=cls.s0.x.device) + + return batch @classmethod - def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: + def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> Batch: + """Makes a batch of graphs consisting of sf states. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns: + A PyG Batch object containing copies of the sink state. + """ if cls.sf is None: raise NotImplementedError("Sink state is not defined") batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - nodes = cls.sf["node_feature"].repeat(np.prod(batch_shape), 1) - out = TensorDict( - { - "node_feature": nodes, - "node_index": GraphStates.unique_node_indices(nodes.shape[0]), - "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), - "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange( - int(np.prod(batch_shape)) + 1, device=cls.sf.device - ) - * cls.sf["node_feature"].shape[0], - "batch_shape": torch.tensor(batch_shape, device=cls.sf.device), - } - ) - return out + num_graphs = int(np.prod(batch_shape)) + + # Create a list of Data objects by copying sf + data_list = [cls.sf.clone() for _ in range(num_graphs)] + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [Data( + x=torch.zeros(0, cls.sf.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.sf.edge_attr.size(1)) + )] + + # Create a batch from the list + batch = Batch.from_data_list(data_list) + + # Store the batch shape for later reference + batch.batch_shape = torch.tensor(batch_shape, device=cls.sf.x.device) + + return batch @classmethod - def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: + def make_random_states_tensor(cls, batch_shape: int | Tuple) -> Batch: + """Makes a batch of random graph states. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns: + A PyG Batch object containing random graph states. 
+ """ batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - - num_nodes = np.random.randint(10) - num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2) - node_features_dim = cls.s0["node_feature"].shape[-1] - edge_features_dim = cls.s0["edge_feature"].shape[-1] - device = cls.s0.device - return TensorDict( - { - "node_feature": torch.rand( - int(np.prod(batch_shape)) * num_nodes, - node_features_dim, - device=device, - ), - "node_index": GraphStates.unique_node_indices( - int(np.prod(batch_shape)) * num_nodes - ), - "edge_feature": torch.rand( - int(np.prod(batch_shape)) * num_edges, - edge_features_dim, - device=device, - ), - "edge_index": torch.randint( - num_nodes, - size=(int(np.prod(batch_shape)) * num_edges, 2), - device=device, - ), - "batch_ptr": torch.arange(int(np.prod(batch_shape)) + 1, device=device) - * num_nodes, - "batch_shape": torch.tensor(batch_shape), - } - ) + num_graphs = int(np.prod(batch_shape)) + device = cls.s0.x.device + + data_list = [] + for _ in range(num_graphs): + # Create a random graph with random number of nodes + num_nodes = np.random.randint(1, 10) + + # Create random node features + x = torch.rand(num_nodes, cls.s0.x.size(1), device=device) + + # Create random edges (not all possible edges to keep it sparse) + num_edges = np.random.randint(0, num_nodes * (num_nodes - 1) // 2 + 1) + if num_edges > 0 and num_nodes > 1: + # Generate random source and target nodes + edge_index = torch.zeros(2, num_edges, dtype=torch.long, device=device) + for i in range(num_edges): + src, dst = np.random.choice(num_nodes, 2, replace=False) + edge_index[0, i] = src + edge_index[1, i] = dst + + # Create random edge features + edge_attr = torch.rand(num_edges, cls.s0.edge_attr.size(1), device=device) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + else: + # No edges + data = Data(x=x, edge_index=torch.zeros(2, 0, dtype=torch.long, device=device), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1), device=device)) + + data_list.append(data) + + # Create a batch from the list + batch = Batch.from_data_list(data_list) + + # Store the batch shape for later reference + batch.batch_shape = torch.tensor(batch_shape, device=device) + + return batch def __len__(self) -> int: + """Returns the number of graphs in the batch.""" return int(np.prod(self.batch_shape)) def __repr__(self): + """Returns a string representation of the GraphStates object.""" return ( - f"{self.__class__.__name__} object of batch shape {self.tensor['batch_shape']} and " - f"node feature dim {self.node_features_dim} and edge feature dim {self.edge_features_dim}" + f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"node feature dim {self.tensor.x.size(1)} and edge feature dim {self.tensor.edge_attr.size(1)}" ) def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: + """Get a subset of the GraphStates. + + Args: + index: Index or indices to select. + + Returns: + A new GraphStates object containing the selected graphs. 
+ """ + # Convert the index to a list of indices tensor_idx = torch.arange(len(self)).view(*self.batch_shape) - new_shape = tensor_idx[index].shape - idx = tensor_idx[index].flatten() - - if torch.any(idx >= len(self.tensor["batch_ptr"]) - 1): - raise ValueError("Graph index out of bounds") - - # TODO: explain batch_ptr and node_index semantics - start_ptrs = self.tensor["batch_ptr"][:-1][idx] - end_ptrs = self.tensor["batch_ptr"][1:][idx] - - node_features = [torch.empty(0, self.node_features_dim)] - node_indices = [torch.empty(0, dtype=torch.long)] - edge_features = [torch.empty(0, self.edge_features_dim)] - edge_indices = [torch.empty(0, 2, dtype=torch.long)] - batch_ptr = [0] - - for start, end in zip(start_ptrs, end_ptrs): - node_features.append(self.tensor["node_feature"][start:end]) - node_indices.append(self.tensor["node_index"][start:end]) - batch_ptr.append(batch_ptr[-1] + end - start) - - # Find edges for this graph - if self.tensor["node_index"].numel() > 0: - edge_mask = ( - self.tensor["edge_index"][:, 0] >= self.tensor["node_index"][start] - ) & ( - self.tensor["edge_index"][:, 0] - <= self.tensor["node_index"][end - 1] - ) - edge_features.append(self.tensor["edge_feature"][edge_mask]) - edge_indices.append(self.tensor["edge_index"][edge_mask]) - - out = self.__class__( - TensorDict( - { - "node_feature": torch.cat(node_features), - "node_index": torch.cat(node_indices), - "edge_feature": torch.cat(edge_features), - "edge_index": torch.cat(edge_indices), - "batch_ptr": torch.tensor(batch_ptr, device=self.tensor.device), - "batch_shape": torch.tensor(new_shape, device=self.tensor.device), - } - ) - ) - + if isinstance(index, int): + new_shape = (1,) + else: + new_shape = tensor_idx[index].shape + indices = tensor_idx[index].flatten().tolist() + + # Get the selected graphs from the batch + selected_graphs = self.tensor.index_select(indices) + if len(selected_graphs) == 0: + assert np.prod(new_shape) == 0 + selected_graphs = [Data( + x=torch.zeros(0, self.tensor.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, self.tensor.edge_attr.size(1)) + )] + + # Create a new batch from the selected graphs + new_batch = Batch.from_data_list(selected_graphs) + new_batch.batch_shape = torch.tensor(new_shape, device=self.tensor.batch_shape.device) + + # Create a new GraphStates object + out = self.__class__(new_batch) + + # Copy log rewards if they exist if self._log_rewards is not None: - out._log_rewards = self._log_rewards[idx] - - assert out.tensor["node_index"].unique().numel() == len( - out.tensor["node_index"] - ) - + out._log_rewards = self._log_rewards[indices] + return out def __setitem__(self, index: int | Sequence[int], graph: GraphStates): + """Set a subset of the GraphStates. + + Args: + index: Index or indices to set. + graph: GraphStates object containing the new graphs. """ - Set particular states of the Batch - """ - # This is to convert index to type int (linear indexing). 
- idx = torch.arange(len(self)).view(*self.batch_shape) - idx = idx[index].flatten() - - # Validate indices - if torch.any(idx >= len(self.tensor["batch_ptr"]) - 1): - raise ValueError("Target graph index out of bounds") - - # Source graph details - source_tensor_dict = graph.tensor - source_num_graphs = torch.prod(source_tensor_dict["batch_shape"]) - - # Validate source and target indices match - if len(idx) != source_num_graphs: - raise ValueError( - "Number of source graphs must match number of target indices" - ) - - for i, graph_idx in enumerate(idx): - # Get start and end pointers for the current graph - start_ptr = self.tensor["batch_ptr"][graph_idx] - end_ptr = self.tensor["batch_ptr"][graph_idx + 1] - source_start_ptr = source_tensor_dict["batch_ptr"][i] - source_end_ptr = source_tensor_dict["batch_ptr"][i + 1] - - new_nodes = source_tensor_dict["node_feature"][ - source_start_ptr:source_end_ptr - ] - new_nodes = torch.atleast_2d(new_nodes) - - if new_nodes.shape[1] != self.node_features_dim: - raise ValueError( - f"Node features must have dimension {self.node_features_dim}" - ) - - # Concatenate node features - self.tensor["node_feature"] = torch.cat( - [ - self.tensor["node_feature"][ - :start_ptr - ], # Nodes before the current graph - new_nodes, # New nodes to add - self.tensor["node_feature"][ - end_ptr: - ], # Nodes after the current graph - ] - ) - - edge_mask = torch.empty(0, dtype=torch.bool) - if self.tensor["edge_index"].numel() > 0: - edge_mask = torch.all( - self.tensor["edge_index"] > self.tensor["node_index"][end_ptr - 1], - dim=-1, - ) - edge_mask |= torch.all( - self.tensor["edge_index"] < self.tensor["node_index"][start_ptr], - dim=-1, - ) - - edge_to_add_mask = torch.all( - source_tensor_dict["edge_index"] - >= source_tensor_dict["node_index"][source_start_ptr], - dim=-1, - ) - edge_to_add_mask &= torch.all( - source_tensor_dict["edge_index"] - <= source_tensor_dict["node_index"][source_end_ptr - 1], - dim=-1, - ) - self.tensor["edge_index"] = torch.cat( - [ - self.tensor["edge_index"][edge_mask], - source_tensor_dict["edge_index"][edge_to_add_mask], - ], - dim=0, - ) - self.tensor["edge_feature"] = torch.cat( - [ - self.tensor["edge_feature"][edge_mask], - source_tensor_dict["edge_feature"][edge_to_add_mask], - ], - dim=0, - ) - - self.tensor["node_index"] = torch.cat( - [ - self.tensor["node_index"][:start_ptr], - source_tensor_dict["node_index"][source_start_ptr:source_end_ptr], - self.tensor["node_index"][end_ptr:], - ] - ) - # Update batch pointers - shift = new_nodes.shape[0] - (end_ptr - start_ptr) - self.tensor["batch_ptr"][graph_idx + 1 :] += shift - - assert self.tensor["node_index"].unique().numel() == len( - self.tensor["node_index"] - ) + # Convert the index to a list of indices + batch_shape = self.batch_shape + if isinstance(index, int): + indices = [index] + else: + tensor_idx = torch.arange(len(self)).view(*batch_shape) + indices = tensor_idx[index].flatten().tolist() + + # Get the data list from the current batch + data_list = self.tensor.to_data_list() + + # Get the data list from the new graphs + new_data_list = graph.tensor.to_data_list() + + # Replace the selected graphs + for i, idx in enumerate(indices): + if i < len(new_data_list): + data_list[idx] = new_data_list[i] + + # Create a new batch from the updated data list + self.tensor = Batch.from_data_list(data_list) + + # Preserve the batch shape + self.tensor.batch_shape = torch.tensor(batch_shape, device=self.tensor.x.device) @property - def device(self) -> torch.device | None: - return 
self.tensor.device + def device(self) -> torch.device: + """Returns the device of the tensor.""" + return self.tensor.x.device def to(self, device: torch.device) -> GraphStates: - """ - Moves and/or casts the graph states to the specified device + """Moves the GraphStates to the specified device. + + Args: + device: The device to move to. + + Returns: + The GraphStates object on the specified device. """ self.tensor = self.tensor.to(device) + if self._log_rewards is not None: + self._log_rewards = self._log_rewards.to(device) return self def clone(self) -> GraphStates: - """Returns a *detached* clone of the current instance using deepcopy.""" - return deepcopy(self) + """Returns a detached clone of the current instance. + + Returns: + A new GraphStates object with the same data. + """ + # Create a deep copy of the batch + data_list = [data.clone() for data in self.tensor.to_data_list()] + new_batch = Batch.from_data_list(data_list) + new_batch.batch_shape = self.tensor.batch_shape.clone() + + # Create a new GraphStates object + out = self.__class__(new_batch) + + # Copy log rewards if they exist + if self._log_rewards is not None: + out._log_rewards = self._log_rewards.clone() + + return out def extend(self, other: GraphStates): - """Concatenates to another GraphStates object along the batch dimension""" - # find if there are common node indices - other_node_index = other.tensor["node_index"].clone() # Clone to avoid modifying original - other_edge_index = other.tensor["edge_index"].clone() # Clone to avoid modifying original + """Concatenates to another GraphStates object along the batch dimension. - # Always generate new indices for the other state to ensure uniqueness - new_indices = GraphStates.unique_node_indices(len(other_node_index)) + Args: + other: GraphStates object to concatenate with. 
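+
+        Example (illustrative sketch, for the 1-D batch case):
+
+            a = MyGraphStates.from_batch_shape(2)
+            b = MyGraphStates.from_batch_shape(3)
+            a.extend(b)
+            assert a.batch_shape == (5,)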
+ """ + if len(self) == 0: + # If self is empty, just copy other + self.tensor = other.tensor.clone() + if other._log_rewards is not None: + self._log_rewards = other._log_rewards.clone() + return - # Update edge indices to match new node indices - for i, old_idx in enumerate(other_node_index): - other_edge_index[other_edge_index == old_idx] = new_indices[i] + # Get the data lists + self_data_list = self.tensor.to_data_list() + other_data_list = other.tensor.to_data_list() - # Update node indices - other_node_index = new_indices - - if torch.prod(self.tensor["batch_shape"]) == 0: - # if self is empty, just copy other - self.tensor["node_feature"] = other.tensor["node_feature"] - self.tensor["batch_shape"] = other.tensor["batch_shape"] - self.tensor["node_index"] = other_node_index - self.tensor["edge_feature"] = other.tensor["edge_feature"] - self.tensor["edge_index"] = other_edge_index - self.tensor["batch_ptr"] = other.tensor["batch_ptr"] - - elif len(self.tensor["batch_shape"]) == 1: - self.tensor["node_feature"] = torch.cat( - [self.tensor["node_feature"], other.tensor["node_feature"]], dim=0 - ) - self.tensor["node_index"] = torch.cat( - [self.tensor["node_index"], other_node_index], dim=0 - ) - self.tensor["edge_feature"] = torch.cat( - [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 - ) - self.tensor["edge_index"] = torch.cat( - [self.tensor["edge_index"], other_edge_index], - dim=0, - ) - self.tensor["batch_ptr"] = torch.cat( - [ - self.tensor["batch_ptr"], - other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1], - ], - dim=0, - ) - self.tensor["batch_shape"] = ( - self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], - ) + self.batch_shape[1:] - else: - # Here we handle the case where the batch shape is (T, B) - # and we want to concatenate along the batch dimension B. 
- assert len(self.tensor["batch_shape"]) == 2 - max_len = max(self.tensor["batch_shape"][0], other.tensor["batch_shape"][0]) - - node_features = [] - node_indices = [] - edge_features = [] - edge_indices = [] - # Get device from one of the tensors - device = self.tensor["node_feature"].device - batch_ptr = [torch.tensor([0], device=device)] + # Update the batch shape + if len(self.batch_shape) == 1: + # Create a new batch + new_batch_shape = (self.batch_shape[0] + other.batch_shape[0],) + self.tensor = Batch.from_data_list(self_data_list + other_data_list) + self.tensor.batch_shape = torch.tensor(new_batch_shape, device=self.tensor.x.device) + else: + # Handle the case where batch_shape is (T, B) + # and we want to concatenate along the B dimension + assert len(self.batch_shape) == 2 + max_len = max(self.batch_shape[0], other.batch_shape[0]) - for i in range(max_len): - # Following the logic of Base class, we want to extend with sink states - if i >= self.tensor["batch_shape"][0]: - self_i = self.make_sink_states_tensor(self.tensor["batch_shape"][1:]) - else: - self_i = self[i].tensor - if i >= other.tensor["batch_shape"][0]: - other_i = other.make_sink_states_tensor(other.tensor["batch_shape"][1:]) - else: - other_i = other[i].tensor - - # Generate new unique indices for both self_i and other_i - new_self_indices = GraphStates.unique_node_indices(len(self_i["node_index"])) - new_other_indices = GraphStates.unique_node_indices(len(other_i["node_index"])) - - # Update self_i edge indices - self_edge_index = self_i["edge_index"].clone() - for old_idx, new_idx in zip(self_i["node_index"], new_self_indices): - mask = (self_edge_index == old_idx) - self_edge_index[mask] = new_idx - - # Update other_i edge indices - other_edge_index = other_i["edge_index"].clone() - for old_idx, new_idx in zip(other_i["node_index"], new_other_indices): - mask = (other_edge_index == old_idx) - other_edge_index[mask] = new_idx - - node_features.append(self_i["node_feature"]) - node_indices.append(new_self_indices) # Use new indices - edge_features.append(self_i["edge_feature"]) - edge_indices.append(self_edge_index) # Use updated edge indices - batch_ptr.append(self_i["batch_ptr"][1:] + batch_ptr[-1][-1]) - - node_features.append(other_i["node_feature"]) - node_indices.append(new_other_indices) # Use new indices - edge_features.append(other_i["edge_feature"]) - edge_indices.append(other_edge_index) # Use updated edge indices - batch_ptr.append(other_i["batch_ptr"][1:] + batch_ptr[-1][-1]) - - self.tensor["node_feature"] = torch.cat(node_features, dim=0) - self.tensor["node_index"] = torch.cat(node_indices, dim=0) - self.tensor["edge_feature"] = torch.cat(edge_features, dim=0) - self.tensor["edge_index"] = torch.cat(edge_indices, dim=0) - self.tensor["batch_ptr"] = torch.cat(batch_ptr, dim=0) - - self.tensor["batch_shape"] = ( - max_len, - self.tensor["batch_shape"][1] + other.tensor["batch_shape"][1], - ) - - assert self.tensor["node_index"].unique().numel() == len( - self.tensor["node_index"] - ) - assert torch.prod(torch.tensor(self.tensor["batch_shape"])) == len(self.tensor["batch_ptr"]) - 1 + # We need to extend both batches to the same length T + if self.batch_shape[0] < max_len: + self_extension = self.make_sink_states_tensor( + (max_len - self.batch_shape[0], self.batch_shape[1]) + ) + self_data_list = self_data_list + self_extension.to_data_list() + + if other.batch_shape[0] < max_len: + other_extension = other.make_sink_states_tensor( + (max_len - other.batch_shape[0], other.batch_shape[1]) + ) + 
other_data_list = other_data_list + other_extension.to_data_list()
+
+ # Now both have the same length T. Interleave per time step so the
+ # flat data list stays in row-major (T, B) order: for each step t,
+ # self's B1 graphs come first, followed by other's B2 graphs.
+ B1, B2 = self.batch_shape[1], other.batch_shape[1]
+ merged_data_list = []
+ for t in range(max_len):
+ merged_data_list.extend(self_data_list[t * B1 : (t + 1) * B1])
+ merged_data_list.extend(other_data_list[t * B2 : (t + 1) * B2])
+ batch_shape = (max_len, B1 + B2)
+ self.tensor = Batch.from_data_list(merged_data_list)
+ self.tensor.batch_shape = torch.tensor(batch_shape, device=self.tensor.x.device)
+
+ # Combine log rewards if they exist
+ if self._log_rewards is not None and other._log_rewards is not None:
+ self._log_rewards = torch.cat([self._log_rewards, other._log_rewards], dim=0)
+ elif other._log_rewards is not None:
+ self._log_rewards = other._log_rewards.clone()

@property
def log_rewards(self) -> torch.Tensor | None:
+ """Returns the log rewards of the states."""
return self._log_rewards

@log_rewards.setter
def log_rewards(self, log_rewards: torch.Tensor) -> None:
+ """Sets the log rewards of the states.
+
+ Args:
+ log_rewards: Tensor of shape `batch_shape` representing the log rewards.
+ """
+ assert log_rewards.shape == self.batch_shape
self._log_rewards = log_rewards

- def _compare(self, other: TensorDict) -> torch.Tensor:
- out = torch.zeros(len(self.tensor["batch_ptr"]) - 1, dtype=torch.bool)
- for i in range(len(self.tensor["batch_ptr"]) - 1):
- start, end = self.tensor["batch_ptr"][i], self.tensor["batch_ptr"][i + 1]
- if end - start != len(other["node_feature"]):
- out[i] = False
- else:
- out[i] = torch.all(
- self.tensor["node_feature"][start:end] == other["node_feature"]
- )
- edge_mask = torch.all(
- (self.tensor["edge_index"] >= self.tensor["node_index"][start])
- & (self.tensor["edge_index"] <= self.tensor["node_index"][end - 1]),
- dim=-1,
- )
- edge_index = self.tensor["edge_index"][edge_mask]
- out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all(
- edge_index == other["edge_index"]
- )
- edge_feature = self.tensor["edge_feature"][edge_mask]
- out[i] &= len(edge_feature) == len(other["edge_feature"]) and torch.all(
- edge_feature == other["edge_feature"]
- )
+ def _compare(self, other: Data) -> torch.Tensor:
+ """Compares the current batch of graphs with another graph.
+
+ Args:
+ other: A PyG Data object to compare with.
+
+ Returns:
+ A boolean tensor indicating which graphs in the batch are equal to other.
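+
+ Example (a sketch; `MyGraphStates` is a hypothetical concrete subclass
+ defining `s0`):
+ >>> states = MyGraphStates.from_batch_shape((2,))
+ >>> states._compare(MyGraphStates.s0) # tensor([True, True])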
+ """ + out = torch.zeros(len(self), dtype=torch.bool, device=self.device) + + # Get the data list from the batch + data_list = self.tensor.to_data_list() + + for i, data in enumerate(data_list): + # Check if the number of nodes is the same + if data.num_nodes != other.num_nodes: + continue + + # Check if node features are the same + if not torch.all(data.x == other.x): + continue + + # Check if the number of edges is the same + if data.edge_index.size(1) != other.edge_index.size(1): + continue + + # Check if edge indices are the same (this is more complex due to potential reordering) + # We'll use a simple heuristic: sort edges and compare + data_edges = data.edge_index.t().tolist() + other_edges = other.edge_index.t().tolist() + data_edges.sort() + other_edges.sort() + if data_edges != other_edges: + continue + + # Check if edge attributes are the same (after sorting) + data_edge_attr = data.edge_attr[torch.argsort(data.edge_index[0] * data.num_nodes + data.edge_index[1])] + other_edge_attr = other.edge_attr[torch.argsort(other.edge_index[0] * other.num_nodes + other.edge_index[1])] + if not torch.all(data_edge_attr == other_edge_attr): + continue + + # If all checks pass, the graphs are equal + out[i] = True + return out.view(self.batch_shape) @property def is_sink_state(self) -> torch.Tensor: + """Returns a tensor that is True for states that are sf.""" return self._compare(self.sf) @property def is_initial_state(self) -> torch.Tensor: + """Returns a tensor that is True for states that are s0.""" return self._compare(self.s0) @classmethod - def stack(cls, states: List[GraphStates]): - """Given a list of states, stacks them along a new dimension (0).""" - stacked_states = cls.from_batch_shape(0) + def stack(cls, states: List[GraphStates]) -> GraphStates: + """Given a list of states, stacks them along a new dimension (0). + + Args: + states: List of GraphStates objects to stack. + + Returns: + A new GraphStates object with the stacked states. 
+ """ + # Check that all states have the same batch shape state_batch_shape = states[0].batch_shape - assert len(state_batch_shape) == 1 - for state in states: - assert state.batch_shape == state_batch_shape - stacked_states.extend(state) - - stacked_states.tensor["batch_shape"] = (len(states),) + state_batch_shape - assert stacked_states.tensor["node_index"].unique().numel() == len( - stacked_states.tensor["node_index"] + assert all(state.batch_shape == state_batch_shape for state in states) + + # Get all data lists + all_data_lists = [state.tensor.to_data_list() for state in states] + + # Flatten the list of lists + flat_data_list = [data for data_list in all_data_lists for data in data_list] + + # Create a new batch + batch = Batch.from_data_list(flat_data_list) + + # Set the batch shape + batch.batch_shape = torch.tensor( + (len(states),) + state_batch_shape, + device=states[0].device ) - return stacked_states + + # Create a new GraphStates object + out = cls(batch) + + # Stack log rewards if they exist + if all(state._log_rewards is not None for state in states): + out._log_rewards = torch.stack([state._log_rewards for state in states], dim=0) + + return out @property - def forward_masks(self) -> TensorDict: - n_nodes = self.tensor["batch_ptr"][1:] - self.tensor["batch_ptr"][:-1] - ei_mask_shape = ( - len(self.tensor["node_feature"]), - len(self.tensor["node_feature"]), - ) - forward_masks = TensorDict( - { - "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), - "features": torch.ones( - self.batch_shape + (self.node_features_dim,), dtype=torch.bool - ), - "edge_index": torch.zeros( - self.batch_shape + ei_mask_shape, dtype=torch.bool - ), - } - ) # TODO: edge_index mask is very memory consuming... - forward_masks["action_type"][..., GraphActionType.ADD_EDGE] = n_nodes > 1 - forward_masks["action_type"][..., GraphActionType.EXIT] = n_nodes >= 1 - - arange = torch.arange(len(self)).view(self.batch_shape) - arange_nodes = torch.arange(len(self.tensor["node_feature"]))[None, :] - same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & ( - arange_nodes < self.tensor["batch_ptr"][1:, None] - ) - edge_index = torch.where( - self.tensor["edge_index"][..., None] == self.tensor["node_index"] - )[2].reshape(self.tensor["edge_index"].shape) - i, j = edge_index[..., 0], edge_index[..., 1] - - for _ in range(len(self.batch_shape)): - (i, j) = i.unsqueeze(0), j.unsqueeze(0) - - # First allow nodes in the same graph to connect, then disable nodes with existing edges - forward_masks["edge_index"][ - same_graph_mask[:, :, None] & same_graph_mask[:, None, :] - ] = True - torch.diagonal(forward_masks["edge_index"], dim1=-2, dim2=-1).fill_(False) - forward_masks["edge_index"][arange[..., None], i, j] = False - forward_masks["action_type"][..., GraphActionType.ADD_EDGE] &= torch.any( - forward_masks["edge_index"], dim=(-1, -2) + def forward_masks(self) -> dict: + """Returns masks denoting allowed forward actions. + + Returns: + A dictionary containing masks for different action types. 
+ """ + # Get the data list from the batch + data_list = self.tensor.to_data_list() + N = self.tensor.x.size(0) + + # Initialize masks + action_type_mask = torch.ones(self.batch_shape + (3,), dtype=torch.bool, device=self.device) + features_mask = torch.ones( + self.batch_shape + (self.tensor.x.size(1),), + dtype=torch.bool, + device=self.device ) - return forward_masks + edge_index_masks = torch.ones((len(data_list), N, N), dtype=torch.bool, device=self.device) + + + # For each graph in the batch + for i, data in enumerate(data_list): + # Flatten the batch index + flat_idx = i + + # ADD_NODE is always allowed + action_type_mask[flat_idx, GraphActionType.ADD_NODE] = True + + # ADD_EDGE is allowed only if there are at least 2 nodes + action_type_mask[flat_idx, GraphActionType.ADD_EDGE] = data.num_nodes > 1 + + # EXIT is always allowed + action_type_mask[flat_idx, GraphActionType.EXIT] = True + + # Create edge_index mask as a dense representation (NxN matrix) + start_n = 0 + for i, data in enumerate(data_list): + # For each graph, create a dense mask for potential edges + n = data.num_nodes + + edge_mask = torch.ones((n, n), dtype=torch.bool, device=self.device) + # Remove self-loops by setting diagonal to False + edge_mask.fill_diagonal_(False) + + # Exclude existing edges + if data.edge_index.size(1) > 0: + for j in range(data.edge_index.size(1)): + src, dst = data.edge_index[0, j], data.edge_index[1, j] + edge_mask[src, dst] = False + + edge_index_masks[i, start_n:(start_n + n), start_n:(start_n + n)] = edge_mask + start_n += n + + # Update ADD_EDGE mask based on whether there are valid edges to add + action_type_mask[flat_idx, GraphActionType.ADD_EDGE] &= edge_mask.any() + + return { + "action_type": action_type_mask, + "features": features_mask, + "edge_index": edge_index_masks + } @property - def backward_masks(self) -> TensorDict: - n_nodes = self.tensor["batch_ptr"][1:] - self.tensor["batch_ptr"][:-1] - n_edges = torch.count_nonzero( - ( - self.tensor["edge_index"][None, :, 0] - >= self.tensor["batch_ptr"][:-1, None] - ) - & ( - self.tensor["edge_index"][None, :, 0] - < self.tensor["batch_ptr"][1:, None] - ) - & ( - self.tensor["edge_index"][None, :, 1] - >= self.tensor["batch_ptr"][:-1, None] - ) - & ( - self.tensor["edge_index"][None, :, 1] - < self.tensor["batch_ptr"][1:, None] - ), - dim=-1, - ) - ei_mask_shape = ( - len(self.tensor["node_feature"]), - len(self.tensor["node_feature"]), + def backward_masks(self) -> dict: + """Returns masks denoting allowed backward actions. + + Returns: + A dictionary containing masks for different action types. + """ + # Get the data list from the batch + data_list = self.tensor.to_data_list() + N = self.tensor.x.size(0) + + # Initialize masks + action_type_mask = torch.ones(self.batch_shape + (3,), dtype=torch.bool, device=self.device) + features_mask = torch.ones( + self.batch_shape + (self.tensor.x.size(1),), + dtype=torch.bool, + device=self.device ) - backward_masks = TensorDict( - { - "action_type": torch.ones(self.batch_shape + (3,), dtype=torch.bool), - "features": torch.ones( - self.batch_shape + (self.node_features_dim,), dtype=torch.bool - ), - "edge_index": torch.zeros( - self.batch_shape + ei_mask_shape, dtype=torch.bool - ), - } - ) # TODO: edge_index mask is very memory consuming... 
- backward_masks["action_type"][..., GraphActionType.ADD_NODE] = n_nodes >= 1 - backward_masks["action_type"][..., GraphActionType.ADD_EDGE] = n_edges - backward_masks["action_type"][..., GraphActionType.EXIT] = n_nodes >= 1 - - # Allow only existing edges - arange = torch.arange(len(self)).view(self.batch_shape) - ei1 = self.tensor["edge_index"][..., 0] - ei2 = self.tensor["edge_index"][..., 1] - for _ in range(len(self.batch_shape)): - ( - ei1, - ei2, - ) = ei1.unsqueeze( - 0 - ), ei2.unsqueeze(0) - backward_masks["edge_index"][arange[..., None], ei1, ei2] = False - return backward_masks + edge_index_masks = torch.zeros((len(data_list), N, N), dtype=torch.bool, device=self.device) + + # For each graph in the batch + for i, data in enumerate(data_list): + # Flatten the batch index + flat_idx = i + + # ADD_NODE is allowed if there's at least one node (can remove a node) + action_type_mask[flat_idx, GraphActionType.ADD_NODE] = data.num_nodes >= 1 + + # ADD_EDGE is allowed if there's at least one edge (can remove an edge) + action_type_mask[flat_idx, GraphActionType.ADD_EDGE] = data.edge_index.size(1) > 0 + + # EXIT is allowed if there's at least one node + action_type_mask[flat_idx, GraphActionType.EXIT] = data.num_nodes >= 1 + + # Create edge_index mask for backward actions (existing edges that can be removed) + start_n = 0 + for i, data in enumerate(data_list): + # For backward actions, we can only remove existing edges + n = data.num_nodes + edge_mask = torch.zeros((n, n), dtype=torch.bool, device=self.device) + + # Include only existing edges + if data.edge_index.size(1) > 0: + for j in range(data.edge_index.size(1)): + src, dst = data.edge_index[0, j].item(), data.edge_index[1, j].item() + edge_mask[src, dst] = True + + edge_index_masks[i, start_n:(start_n + n), start_n:(start_n + n)] = edge_mask + start_n += n + + return { + "action_type": action_type_mask, + "features": features_mask, + "edge_index": edge_index_masks + } - @classmethod - def unique_node_indices(cls, num_new_nodes: int) -> torch.Tensor: + @staticmethod + def unique_node_indices(num_nodes: int) -> torch.Tensor: + """Generate unique node indices for new nodes. + + Args: + num_nodes: Number of new nodes to generate indices for. + + Returns: + A tensor of unique node indices. 
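+
+ Example (a sketch; the exact values depend on the shared class-level
+ counter, so these outputs are illustrative only):
+ >>> GraphStates.unique_node_indices(2) # e.g. tensor([0, 1])
+ >>> GraphStates.unique_node_indices(3) # e.g. tensor([2, 3, 4])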
+ """ + # Generate sequential indices starting from the next available index indices = torch.arange( - cls._next_node_index, cls._next_node_index + num_new_nodes + GraphStates._next_node_index, + GraphStates._next_node_index + num_nodes, + dtype=torch.long ) - cls._next_node_index += num_new_nodes + + # Update the next available index + GraphStates._next_node_index += num_nodes + return indices diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py index 36afb450..d42b61dc 100644 --- a/src/gfn/utils/training.py +++ b/src/gfn/utils/training.py @@ -132,7 +132,7 @@ def states_actions_tns_to_traj( # stack is a class method, so actions[0] is just to access a class instance and is not particularly relevant actions = actions[0].stack(actions) log_rewards = env.log_reward(states[-2]) - states = states[0].stack_states(states) + states = states[0].stack(states) when_is_done = torch.tensor([len(states_tns) - 1]) log_probs = None diff --git a/testing/test_environments.py b/testing/test_environments.py index f01d566d..22137a3d 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -379,17 +379,17 @@ def test_graph_env(): states = env.step(states, actions) states = env.States(states) - assert states.tensor["node_feature"].shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) + assert states.tensor.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): first_node_mask = ( - torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 + torch.arange(len(states.tensor.x)) // BATCH_SIZE == 0 ) actions = action_cls( TensorDict( { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": states.tensor["node_feature"][first_node_mask], + "features": states.tensor.x[first_node_mask], }, batch_size=BATCH_SIZE, ) @@ -411,14 +411,14 @@ def test_graph_env(): states = env.step(states, actions) for i in range(NUM_NODES - 1): - node_is = torch.arange(BATCH_SIZE) * NUM_NODES + i - node_js = torch.arange(BATCH_SIZE) * NUM_NODES + i + 1 + # node_is = torch.arange(BATCH_SIZE) * NUM_NODES + i + # node_js = torch.arange(BATCH_SIZE) * NUM_NODES + i + 1 actions = action_cls( TensorDict( { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.stack([node_is, node_js], dim=1), + "edge_index": torch.tensor([[i, i + 1]] * BATCH_SIZE), }, batch_size=BATCH_SIZE, ) @@ -442,15 +442,15 @@ def test_graph_env(): assert torch.all(sf_states.is_sink_state) env.reward(sf_states) - num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE + num_edges_per_batch = len(states.tensor.edge_attr) // BATCH_SIZE for i in reversed(range(num_edges_per_batch)): - edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) + edge_idx = torch.arange(i, (i + 1) * BATCH_SIZE, i + 1) actions = action_cls( TensorDict( { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx], + "features": states.tensor.edge_attr[edge_idx], + "edge_index": states.tensor.edge_index[:, edge_idx].T - states.tensor.ptr[:-1, None], }, batch_size=BATCH_SIZE, ) @@ -479,7 +479,7 @@ def test_graph_env(): TensorDict( { "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": states.tensor["node_feature"][edge_idx], + "features": states.tensor.x[edge_idx], }, batch_size=BATCH_SIZE, ) @@ -487,7 +487,7 @@ def test_graph_env(): states = 
env.backward_step(states, actions) states = env.States(states) - assert states.tensor["node_feature"].shape == (0, FEATURE_DIM) + assert states.tensor.x.shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): actions = action_cls( diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index c5b6e685..2cb1cbb5 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -376,24 +376,21 @@ def __init__(self, feature_dim: int): self.edge_index_conv = GCNConv(feature_dim, 8) def forward(self, states: GraphStates) -> TensorDict: - node_feature = states.tensor["node_feature"].reshape(-1, self.feature_dim) - edge_index = torch.where( - states.tensor["edge_index"][..., None] == states.tensor["node_index"] - )[2].reshape(states.tensor["edge_index"].shape) + node_feature = states.tensor.x.reshape(-1, self.feature_dim) - if states.tensor["node_feature"].shape[0] == 0: + if states.tensor.x.shape[0] == 0: action_type = torch.zeros((len(states), 3)) action_type[:, GraphActionType.ADD_NODE] = 1 features = torch.zeros((len(states), self.feature_dim)) else: - action_type = self.action_type_conv(node_feature, edge_index.T) + action_type = self.action_type_conv(node_feature, states.tensor.edge_index) action_type = action_type.reshape( len(states), -1, action_type.shape[-1] ).mean(dim=1) - features = self.features_conv(node_feature, edge_index.T) + features = self.features_conv(node_feature, states.tensor.edge_index) features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) - edge_index = self.edge_index_conv(node_feature, edge_index.T) + edge_index = self.edge_index_conv(node_feature, states.tensor.edge_index) edge_index = torch.einsum("nf,mf->nm", edge_index, edge_index) edge_index = edge_index[None].repeat(len(states), 1, 1) diff --git a/testing/test_states.py b/testing/test_states.py index 1182e6cd..f03f6373 100644 --- a/testing/test_states.py +++ b/testing/test_states.py @@ -1,202 +1,350 @@ import pytest import torch +from torch_geometric.data import Data, Batch from gfn.states import GraphStates -from tensordict import TensorDict +from gfn.actions import GraphActionType class MyGraphStates(GraphStates): - s0 = TensorDict({ - "node_feature": torch.tensor([[1.0], [2.0]]), - "node_index": torch.tensor([0, 1]), - "edge_feature": torch.tensor([[0.5]]), - "edge_index": torch.tensor([[0, 1]]), - }) - sf = TensorDict({ - "node_feature": torch.tensor([[3.0], [4.0]]), - "node_index": torch.tensor([2, 3]), - "edge_feature": torch.tensor([[0.7]]), - "edge_index": torch.tensor([[2, 3]]), - }) + # Initial state: a graph with 2 nodes and 1 edge + s0 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + + # Sink state: a graph with 2 nodes and 1 edge (different from s0) + sf = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]) + ) + @pytest.fixture def simple_graph_state(): """Creates a simple graph state with 2 nodes and 1 edge""" - tensor = TensorDict({ - "node_feature": torch.tensor([[1.0], [2.0]]), - "node_index": torch.tensor([0, 1]), - "edge_feature": torch.tensor([[0.5]]), - "edge_index": torch.tensor([[0, 1]]), - "batch_ptr": torch.tensor([0, 2]), - "batch_shape": torch.tensor([1]) - }) - return MyGraphStates(tensor) + data = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch = 
Batch.from_data_list([data]) + batch.batch_shape = torch.tensor([1]) + return MyGraphStates(batch) + @pytest.fixture def empty_graph_state(): - """Creates an empty graph state""" - tensor = TensorDict({ - "node_feature": torch.tensor([]), - "node_index": torch.tensor([]), - "edge_feature": torch.tensor([]), - "edge_index": torch.tensor([]).reshape(0, 2), - "batch_ptr": torch.tensor([0]), - "batch_shape": torch.tensor([0]) - }) - return MyGraphStates(tensor) + """Creates an empty GraphStates object""" + # Create an empty batch + batch = Batch() + batch.x = torch.zeros((0, 1)) + batch.edge_index = torch.zeros((2, 0), dtype=torch.long) + batch.edge_attr = torch.zeros((0, 1)) + batch.batch = torch.zeros((0,), dtype=torch.long) + batch.batch_shape = torch.tensor([0]) + return MyGraphStates(batch) + def test_extend_empty_state(empty_graph_state, simple_graph_state): """Test extending an empty state with a non-empty state""" empty_graph_state.extend(simple_graph_state) - assert torch.equal(empty_graph_state.tensor["batch_shape"], simple_graph_state.tensor["batch_shape"]) - assert torch.equal(empty_graph_state.tensor["node_feature"], simple_graph_state.tensor["node_feature"]) - assert torch.equal(empty_graph_state.tensor["edge_index"], simple_graph_state.tensor["edge_index"]) - assert torch.equal(empty_graph_state.tensor["edge_feature"], simple_graph_state.tensor["edge_feature"]) - assert torch.equal(empty_graph_state.tensor["batch_ptr"], simple_graph_state.tensor["batch_ptr"]) + # Check that the empty state now has the same content as the simple state + assert torch.equal(empty_graph_state.tensor.batch_shape, simple_graph_state.tensor.batch_shape) + assert torch.equal(empty_graph_state.tensor.x, simple_graph_state.tensor.x) + assert torch.equal(empty_graph_state.tensor.edge_index, simple_graph_state.tensor.edge_index) + assert torch.equal(empty_graph_state.tensor.edge_attr, simple_graph_state.tensor.edge_attr) + assert torch.equal(empty_graph_state.tensor.batch, simple_graph_state.tensor.batch) + def test_extend_1d_batch(simple_graph_state): """Test extending two 1D batch states""" other_state = simple_graph_state.clone() - # The node indices should be different after extend - original_node_indices = simple_graph_state.tensor["node_index"].clone() + # Store original number of nodes and edges + original_num_nodes = simple_graph_state.tensor.num_nodes + original_num_edges = simple_graph_state.tensor.num_edges simple_graph_state.extend(other_state) - assert simple_graph_state.tensor["batch_shape"][0] == 2 - assert len(simple_graph_state.tensor["node_feature"]) == 4 - assert len(simple_graph_state.tensor["edge_feature"]) == 2 + # Check batch shape is updated + assert simple_graph_state.tensor.batch_shape[0] == 2 + + # Check number of nodes and edges doubled + assert simple_graph_state.tensor.num_nodes == 2 * original_num_nodes + assert simple_graph_state.tensor.num_edges == 2 * original_num_edges - # Check that node indices were properly updated (should be unique) - new_node_indices = simple_graph_state.tensor["node_index"] - assert len(torch.unique(new_node_indices)) == len(new_node_indices) - assert not torch.equal(new_node_indices[:2], new_node_indices[2:]) + # Check that batch indices are properly updated + batch_indices = simple_graph_state.tensor.batch + assert torch.equal(batch_indices[:original_num_nodes], torch.zeros(original_num_nodes, dtype=torch.long)) + assert torch.equal(batch_indices[original_num_nodes:], torch.ones(original_num_nodes, dtype=torch.long)) + def test_extend_2d_batch(): 
"""Test extending two 2D batch states""" - # Create 2D batch states (T=2, B=1) - tensor1 = TensorDict({ - "node_feature": torch.tensor([[1.0], [2.0], [3.0], [4.0]]), - "node_index": torch.tensor([0, 1, 2, 3]), - "edge_feature": torch.tensor([[0.5], [0.6]]), - "edge_index": torch.tensor([[0, 1], [2, 3]]), - "batch_ptr": torch.tensor([0, 2, 4]), - "batch_shape": torch.tensor([2, 1]) - }) - state1 = MyGraphStates(tensor1) - - # Create another state with different time length (T=3, B=1) - tensor2 = TensorDict({ - "node_feature": torch.tensor([[5.0], [6.0], [7.0], [8.0], [9.0], [10.0]]), - "node_index": torch.tensor([4, 5, 6, 7, 8, 9]), - "edge_feature": torch.tensor([[0.7], [0.8], [0.9]]), - "edge_index": torch.tensor([[4, 5], [6, 7], [8, 9]]), - "batch_ptr": torch.tensor([0, 2, 4, 6]), - "batch_shape": torch.tensor([3, 1]) - }) - state2 = MyGraphStates(tensor2) + # Create first state (T=2, B=1) + data1 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + data2 = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.6]]) + ) + batch1 = Batch.from_data_list([data1, data2]) + batch1.batch_shape = torch.tensor([2, 1]) + state1 = MyGraphStates(batch1) + + # Create second state (T=3, B=1) + data3 = Data( + x=torch.tensor([[5.0], [6.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]) + ) + data4 = Data( + x=torch.tensor([[7.0], [8.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.8]]) + ) + data5 = Data( + x=torch.tensor([[9.0], [10.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.9]]) + ) + batch2 = Batch.from_data_list([data3, data4, data5]) + batch2.batch_shape = torch.tensor([3, 1]) + state2 = MyGraphStates(batch2) + # Extend state1 with state2 state1.extend(state2) # Check final shape should be (max_len=3, B=2) - assert torch.equal(state1.tensor["batch_shape"], torch.tensor([3, 2])) + assert torch.equal(state1.tensor.batch_shape, torch.tensor([3, 2])) # Check that we have the correct number of nodes and edges - # Each time step has 2 nodes and 1 edge, so for 3 time steps and 2 batches: - expected_nodes = 3 * 2 * 2 # T * nodes_per_timestep * B - expected_edges = 3 * 1 * 2 # T * edges_per_timestep * B - assert len(state1.tensor["node_feature"]) == expected_nodes - assert len(state1.tensor["edge_feature"]) == expected_edges - -def test_extend_with_common_indices(simple_graph_state): - """Test extending states with common node indices""" - # Create a state with overlapping node indices - tensor = TensorDict({ - "node_feature": torch.tensor([[3.0], [4.0]]), - "node_index": torch.tensor([1, 2]), # Note: index 1 overlaps - "edge_feature": torch.tensor([[0.7]]), - "edge_index": torch.tensor([[1, 2]]), - "batch_ptr": torch.tensor([0, 2]), - "batch_shape": torch.tensor([1]) - }) - other_state = MyGraphStates(tensor) + # Each graph has 2 nodes and 1 edge + # For 3 time steps and 2 batches, we should have: + expected_nodes = 3 * 2 * 2 # T * nodes_per_graph * B + expected_edges = 3 * 1 * 2 # T * edges_per_graph * B + + # The actual count might be higher due to padding with sink states + assert state1.tensor.num_nodes >= expected_nodes + assert state1.tensor.num_edges >= expected_edges - simple_graph_state.extend(other_state) - # Check that node indices are unique after extend - assert len(torch.unique(simple_graph_state.tensor["node_index"])) == 4 - - # Check that edge indices were properly updated - edge_indices 
= simple_graph_state.tensor["edge_index"] - assert torch.all(edge_indices >= 0) - -def test_stack_1d_batch(): - """Test stacking multiple 1D batch states""" - # Create first state - tensor1 = TensorDict({ - "node_feature": torch.tensor([[1.0], [2.0]]), - "node_index": torch.tensor([0, 1]), - "edge_feature": torch.tensor([[0.5]]), - "edge_index": torch.tensor([[0, 1]]), - "batch_ptr": torch.tensor([0, 2]), - "batch_shape": torch.tensor([1]) - }) - state1 = MyGraphStates(tensor1) - - # Create second state with different values - tensor2 = TensorDict({ - "node_feature": torch.tensor([[3.0], [4.0]]), - "node_index": torch.tensor([0, 5]), - "edge_feature": torch.tensor([[0.7]]), - "edge_index": torch.tensor([[2, 3]]), - "batch_ptr": torch.tensor([0, 2]), - "batch_shape": torch.tensor([1]) - }) - state2 = MyGraphStates(tensor2) +def test_getitem(): + """Test indexing into GraphStates""" + # Create a batch with 3 graphs + data1 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + data2 = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.6]]) + ) + data3 = Data( + x=torch.tensor([[5.0], [6.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.7]]) + ) + batch = Batch.from_data_list([data1, data2, data3]) + batch.batch_shape = torch.tensor([3]) + states = MyGraphStates(batch) + + # Get a single graph + single_state = states[1] + assert single_state.tensor.batch_shape[0] == 1 + assert single_state.tensor.num_nodes == 2 + assert torch.allclose(single_state.tensor.x, torch.tensor([[3.0], [4.0]])) + + # Get multiple graphs + multi_state = states[[0, 2]] + assert multi_state.tensor.batch_shape[0] == 2 + assert multi_state.tensor.num_nodes == 4 + + # Check first graph in selection + first_graph = multi_state.tensor.get_example(0) + assert torch.allclose(first_graph.x, torch.tensor([[1.0], [2.0]])) + + # Check second graph in selection + second_graph = multi_state.tensor.get_example(1) + assert torch.allclose(second_graph.x, torch.tensor([[5.0], [6.0]])) + + +def test_clone(simple_graph_state): + """Test cloning a GraphStates object""" + cloned = simple_graph_state.clone() + + # Check that the clone has the same content + assert torch.equal(cloned.tensor.batch_shape, simple_graph_state.tensor.batch_shape) + assert torch.equal(cloned.tensor.x, simple_graph_state.tensor.x) + assert torch.equal(cloned.tensor.edge_index, simple_graph_state.tensor.edge_index) + assert torch.equal(cloned.tensor.edge_attr, simple_graph_state.tensor.edge_attr) + + # Modify the clone and check that the original is unchanged + cloned.tensor.x[0, 0] = 99.0 + assert cloned.tensor.x[0, 0] == 99.0 + assert simple_graph_state.tensor.x[0, 0] == 1.0 + + +def test_is_initial_state(): + """Test is_initial_state property""" + # Create a batch with s0 and a different graph + s0 = MyGraphStates.s0.clone() + different = Data( + x=torch.tensor([[5.0], [6.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.9]]) + ) + batch = Batch.from_data_list([s0, different]) + batch.batch_shape = torch.tensor([2]) + states = MyGraphStates(batch) + + # Check is_initial_state + is_initial = states.is_initial_state + assert is_initial[0] == True + assert is_initial[1] == False + + +def test_is_sink_state(): + """Test is_sink_state property""" + # Create a batch with sf and a different graph + sf = MyGraphStates.sf.clone() + different = Data( + x=torch.tensor([[5.0], [6.0]]), + 
edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.9]]) + ) + batch = Batch.from_data_list([sf, different]) + batch.batch_shape = torch.tensor([2]) + states = MyGraphStates(batch) + + # Check is_sink_state + is_sink = states.is_sink_state + assert is_sink[0] == True + assert is_sink[1] == False + + +def test_from_batch_shape(): + """Test creating states from batch shape""" + # Create states with initial state + states = MyGraphStates.from_batch_shape((3,)) + assert states.tensor.batch_shape[0] == 3 + assert states.tensor.num_nodes == 6 # 3 graphs * 2 nodes per graph + + # Check all graphs are s0 + is_initial = states.is_initial_state + assert torch.all(is_initial) + + # Create states with sink state + sink_states = MyGraphStates.from_batch_shape((2,), sink=True) + assert sink_states.tensor.batch_shape[0] == 2 + assert sink_states.tensor.num_nodes == 4 # 2 graphs * 2 nodes per graph + + # Check all graphs are sf + is_sink = sink_states.is_sink_state + assert torch.all(is_sink) + + +def test_forward_masks(): + """Test forward_masks property""" + # Create a graph with 2 nodes and 1 edge + data = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch = Batch.from_data_list([data]) + batch.batch_shape = torch.tensor([1]) + states = MyGraphStates(batch) + + # Get forward masks + masks = states.forward_masks + + # Check action type mask + assert masks["action_type"].shape == (1, 3) + assert masks["action_type"][0, GraphActionType.ADD_NODE] == True # Can add node + assert masks["action_type"][0, GraphActionType.ADD_EDGE] == True # Can add edge (2 nodes) + assert masks["action_type"][0, GraphActionType.EXIT] == True # Can exit + + # Check features mask + assert masks["features"].shape == (1, 1) # 1 feature dimension + assert masks["features"][0, 0] == True # All features allowed + + # Check edge_index masks + assert len(masks["edge_index"]) == 1 # 1 graph + assert torch.all(masks["edge_index"][0] == torch.tensor([[False, False], [True, False]])) + +def test_backward_masks(): + """Test backward_masks property""" + # Create a graph with 2 nodes and 1 edge + data = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch = Batch.from_data_list([data]) + batch.batch_shape = torch.tensor([1]) + states = MyGraphStates(batch) + + # Get backward masks + masks = states.backward_masks + + # Check action type mask + assert masks["action_type"].shape == (1, 3) + assert masks["action_type"][0, GraphActionType.ADD_NODE] == True # Can remove node + assert masks["action_type"][0, GraphActionType.ADD_EDGE] == True # Can remove edge + assert masks["action_type"][0, GraphActionType.EXIT] == True # Can exit + + # Check features mask + assert masks["features"].shape == (1, 1) # 1 feature dimension + assert masks["features"][0, 0] == True # All features allowed + + # Check edge_index masks + assert len(masks["edge_index"]) == 1 # 1 graph + assert torch.all(masks["edge_index"][0] == torch.tensor([[False, True], [False, False]])) + + +def test_stack(): + """Test stacking GraphStates objects""" + # Create two states + data1 = Data( + x=torch.tensor([[1.0], [2.0]]), + edge_index=torch.tensor([[0], [1]]), + edge_attr=torch.tensor([[0.5]]) + ) + batch1 = Batch.from_data_list([data1]) + batch1.batch_shape = torch.tensor([1]) + state1 = MyGraphStates(batch1) + + data2 = Data( + x=torch.tensor([[3.0], [4.0]]), + edge_index=torch.tensor([[0], [1]]), + 
edge_attr=torch.tensor([[0.7]]) + ) + batch2 = Batch.from_data_list([data2]) + batch2.batch_shape = torch.tensor([1]) + state2 = MyGraphStates(batch2) + # Stack the states stacked = MyGraphStates.stack([state1, state2]) - - # Check the batch shape is correct (2, 1) - assert torch.equal(stacked.tensor["batch_shape"], torch.tensor([2, 1])) - - # Check that node features are preserved and ordered correctly - assert torch.equal(stacked.tensor["node_feature"], - torch.tensor([[1.0], [2.0], [3.0], [4.0]])) - - # Check that edge features are preserved and ordered correctly - assert torch.equal(stacked.tensor["edge_feature"], - torch.tensor([[0.5], [0.7]])) - - # Check that node indices are unique - assert len(torch.unique(stacked.tensor["node_index"])) == 4 - - # Check batch pointers are correct - assert torch.equal(stacked.tensor["batch_ptr"], torch.tensor([0, 2, 4])) - -def test_stack_empty_states(): - """Test stacking empty states""" - # Create empty state - tensor = TensorDict({ - "node_feature": torch.tensor([]), - "node_index": torch.tensor([]), - "edge_feature": torch.tensor([]), - "edge_index": torch.tensor([]).reshape(0, 2), - "batch_ptr": torch.tensor([0]), - "batch_shape": torch.tensor([0]) - }) - empty_state = MyGraphStates(tensor) - - # Stack multiple empty states - stacked = MyGraphStates.stack([empty_state, empty_state]) - - # Check the batch shape is correct (2, 0) - assert torch.equal(stacked.tensor["batch_shape"], torch.tensor([2, 0])) - - # Check that tensors are empty - assert stacked.tensor["node_feature"].numel() == 0 - assert stacked.tensor["edge_feature"].numel() == 0 - assert stacked.tensor["edge_index"].numel() == 0 - - # Check batch pointers are correct - assert torch.equal(stacked.tensor["batch_ptr"], torch.tensor([0])) + + # Check the batch shape + assert torch.equal(stacked.tensor.batch_shape, torch.tensor([2, 1])) + + # Check the number of nodes and edges + assert stacked.tensor.num_nodes == 4 # 2 states * 2 nodes + assert stacked.tensor.num_edges == 2 # 2 states * 1 edge + + # Check the batch indices + batch_indices = stacked.tensor.batch + assert torch.equal(batch_indices[:2], torch.zeros(2, dtype=torch.long)) + assert torch.equal(batch_indices[2:], torch.ones(2, dtype=torch.long)) From 382be452bbbd0dff648fadfec4191e1139310157 Mon Sep 17 00:00:00 2001 From: "Omar G. 
Younis" Date: Thu, 27 Feb 2025 15:45:50 +0100 Subject: [PATCH 095/102] update train_graph_ring --- src/gfn/states.py | 14 +++ tutorials/examples/train_graph_ring.py | 152 ++++++++----------------- 2 files changed, 62 insertions(+), 104 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 78296a6c..c99854f9 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -601,6 +601,13 @@ def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> Batch: # Create a list of Data objects by copying s0 data_list = [cls.s0.clone() for _ in range(num_graphs)] + + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [Data( + x=torch.zeros(0, cls.s0.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1)) + )] # Create a batch from the list batch = Batch.from_data_list(data_list) @@ -686,6 +693,13 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> Batch: data_list.append(data) + if len(data_list) == 0: # If batch_shape is 0, create a single empty graph + data_list = [Data( + x=torch.zeros(0, cls.s0.x.size(1)), + edge_index=torch.zeros(2, 0, dtype=torch.long), + edge_attr=torch.zeros(0, cls.s0.edge_attr.size(1)) + )] + # Create a batch from the list batch = Batch.from_data_list(data_list) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index ea001991..e74aa012 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -19,6 +19,7 @@ import torch from tensordict import TensorDict from torch import nn +from torch_geometric.data import Data from torch_geometric.nn import GINConv, GCNConv, DirGNNConv from gfn.actions import Actions, GraphActions, GraphActionType @@ -47,24 +48,15 @@ def directed_reward(states: GraphStates) -> torch.Tensor: Returns: A tensor of rewards. """ - if states.tensor["edge_index"].shape[0] == 0: + if states.tensor.edge_index.numel() == 0: return torch.full(states.batch_shape, EPS_VAL) out = torch.full((len(states),), EPS_VAL) # Default reward. for i in range(len(states)): - start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] - nodes_index_range = states.tensor["node_index"][start:end] - edge_index_mask = torch.all( - states.tensor["edge_index"] >= nodes_index_range[0], dim=-1 - ) & torch.all(states.tensor["edge_index"] <= nodes_index_range[-1], dim=-1) - masked_edge_index = ( - states.tensor["edge_index"][edge_index_mask] - nodes_index_range[0] - ) - - n_nodes = nodes_index_range.shape[0] - adj_matrix = torch.zeros(n_nodes, n_nodes) - adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 + graph = states[i] + adj_matrix = torch.zeros(graph.tensor.num_nodes, graph.tensor.num_nodes) + adj_matrix[graph.tensor.edge_index[0], graph.tensor.edge_index[1]] = 1 # Check if each node has exactly one outgoing edge (row sum = 1) if not torch.all(adj_matrix.sum(dim=1) == 1): @@ -86,7 +78,7 @@ def directed_reward(states: GraphStates) -> torch.Tensor: current = torch.where(adj_matrix[current] == 1)[0].item() # If we've visited all nodes and returned to 0, it's a valid ring - if len(visited) == n_nodes and current == 0: + if len(visited) == graph.tensor.num_nodes and current == 0: out[i] = REW_VAL break @@ -104,30 +96,18 @@ def undirected_reward(states: GraphStates) -> torch.Tensor: Returns: A tensor of rewards. 
""" - if states.tensor["edge_index"].shape[0] == 0: + if states.tensor.edge_index.numel() == 0: return torch.full(states.batch_shape, EPS_VAL) out = torch.full((len(states),), EPS_VAL) # Default reward. for i in range(len(states)): - start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] - nodes_index_range = states.tensor["node_index"][start:end] - edge_index_mask = torch.all( - states.tensor["edge_index"] >= nodes_index_range[0], dim=-1 - ) & torch.all(states.tensor["edge_index"] <= nodes_index_range[-1], dim=-1) - masked_edge_index = ( - states.tensor["edge_index"][edge_index_mask] - nodes_index_range[0] - ) - - n_nodes = nodes_index_range.shape[0] - if n_nodes == 0: + graph = states[i] + if graph.tensor.num_nodes == 0: continue - - # Construct a symmetric adjacency matrix for the undirected graph. - adj_matrix = torch.zeros(n_nodes, n_nodes) - if masked_edge_index.shape[0] > 0: - adj_matrix[masked_edge_index[:, 0], masked_edge_index[:, 1]] = 1 - adj_matrix[masked_edge_index[:, 1], masked_edge_index[:, 0]] = 1 + adj_matrix = torch.zeros(graph.tensor.num_nodes, graph.tensor.num_nodes) + adj_matrix[graph.tensor.edge_index[0], graph.tensor.edge_index[1]] = 1 + adj_matrix[graph.tensor.edge_index[1], graph.tensor.edge_index[0]] = 1 # In an undirected ring, every vertex should have degree 2. if not torch.all(adj_matrix.sum(dim=1) == 2): @@ -156,7 +136,7 @@ def undirected_reward(states: GraphStates) -> torch.Tensor: next_node = possible[0] prev, current = current, next_node - if current == start_vertex and len(visited) == n_nodes: + if current == start_vertex and len(visited) == graph.tensor.num_nodes: out[i] = REW_VAL return out.view(*states.batch_shape) @@ -265,21 +245,15 @@ def _group_mean( def forward(self, states_tensor: TensorDict) -> torch.Tensor: node_features, batch_ptr = ( - states_tensor["node_feature"], - states_tensor["batch_ptr"], + states_tensor.x, + states_tensor.ptr, ) - batch_size = int(torch.prod(states_tensor["batch_shape"])) - - edge_index = torch.where( - states_tensor["edge_index"][..., None] == states_tensor["node_index"] - )[2].reshape( - states_tensor["edge_index"].shape - ) # (M, 2) + batch_size = int(torch.prod(states_tensor.batch_shape)) # Multiple action type convolutions with residual connections. x = self.embedding(node_features.squeeze().int()) for i in range(0, len(self.conv_blks), 2): - x_new = self.conv_blks[i](x, edge_index.T) # GIN/GCN conv. + x_new = self.conv_blks[i](x, states_tensor.edge_index) # GIN/GCN conv. if self.is_directed: x_in, x_out = torch.chunk(x_new, 2, dim=-1) # Process each component separately through its own MLP @@ -296,7 +270,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: node_feature_means = self._group_mean(x, batch_ptr) exit_action = self.exit_mlp(node_feature_means) - x = x.reshape(*states_tensor["batch_shape"], self.n_nodes, self.hidden_dim) + x = x.reshape(*states_tensor.batch_shape, self.n_nodes, self.hidden_dim) # Undirected. if self.is_directed: @@ -331,7 +305,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: # Grab the needed elems from the adjacency matrix and reshape. 
edge_actions = x[torch.arange(batch_size)[:, None, None], i0, i1] - edge_actions = edge_actions.reshape(*states_tensor["batch_shape"], out_size) + edge_actions = edge_actions.reshape(*states_tensor.batch_shape, out_size) if self.is_backward: return edge_actions @@ -377,27 +351,21 @@ def make_states_class(self) -> type[GraphStates]: env = self class RingStates(GraphStates): - s0 = TensorDict( - { - "node_feature": torch.arange(env.n_nodes)[:, None], - "edge_feature": torch.ones((0, 1)), - "edge_index": torch.ones((0, 2), dtype=torch.long), - }, - batch_size=(), + s0 = Data( + x=torch.arange(env.n_nodes)[:, None], + edge_attr=torch.ones((0, 1)), + edge_index=torch.ones((2, 0), dtype=torch.long), ) - sf = TensorDict( - { - "node_feature": -torch.ones(env.n_nodes)[:, None], - "edge_feature": torch.zeros((0, 1)), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, - batch_size=(), + sf = Data( + x=-torch.ones(env.n_nodes)[:, None], + edge_attr=torch.zeros((0, 1)), + edge_index=torch.zeros((2, 0), dtype=torch.long), ) def __init__(self, tensor: TensorDict): self.tensor = tensor - self.node_features_dim = tensor["node_feature"].shape[-1] - self.edge_features_dim = tensor["edge_feature"].shape[-1] + self.node_features_dim = tensor.x.shape[-1] + self.edge_features_dim = tensor.edge_attr.shape[-1] self._log_rewards: Optional[float] = None self.n_nodes = env.n_nodes @@ -424,23 +392,20 @@ def forward_masks(self): # Remove existing edges. for i in range(len(self)): - existing_edges = ( - self[i].tensor["edge_index"] - - self.tensor["node_index"][self.tensor["batch_ptr"][i]] - ) + existing_edges = self[i].tensor.edge_index assert torch.all(existing_edges >= 0) # TODO: convert to test. - if len(existing_edges) == 0: + if existing_edges.numel() == 0: edge_idx = torch.zeros(0, dtype=torch.bool) else: edge_idx = torch.logical_and( - existing_edges[:, 0] == ei0.unsqueeze(-1), - existing_edges[:, 1] == ei1.unsqueeze(-1), + existing_edges[0] == ei0, + existing_edges[1] == ei1, ) # Collapse across the edge dimension. if len(edge_idx.shape) == 2: - edge_idx = edge_idx.sum(1).bool() + edge_idx = edge_idx.sum(0).bool() # Adds an unmasked exit action. 
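# The appended False fills the exit slot, so it is never flagged as an
# existing edge and the exit action stays available.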
edge_idx = torch.cat((edge_idx, torch.BoolTensor([False]))) @@ -463,8 +428,8 @@ def backward_masks(self): for i in range(len(self)): existing_edges = ( - self[i].tensor["edge_index"] - - self.tensor["node_index"][self.tensor["batch_ptr"][i]] + self[i].tensor.edge_index + - self.tensor.node_index[self.tensor.batch_ptr[i]] ) if env.is_directed: @@ -552,13 +517,12 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions # action_tensor[action_tensor >= (self.n_actions - 1)] = 0 ei0, ei1 = ei0[action_tensor], ei1[action_tensor] - offset = states.tensor["node_index"][states.tensor["batch_ptr"][:-1]] out = GraphActions( TensorDict( { "action_type": action_type, "features": torch.ones(action_tensor.shape + (1,)), - "edge_index": torch.stack([ei0, ei1], dim=-1) + offset[:, None], + "edge_index": torch.stack([ei0, ei1], dim=-1), }, batch_size=action_tensor.shape, ) @@ -590,7 +554,7 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool for i in range(8): current_ax = ax[i // 4, i % 4] state = states[i] - n_circles = state.tensor["node_feature"].shape[0] + n_circles = state.tensor.x.shape[0] radius = 5 xs, ys = [], [] for j in range(n_circles): @@ -603,7 +567,7 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool patches.Circle((x, y), 0.5, facecolor="none", edgecolor="black") ) - edge_index = states[i].tensor["edge_index"] + edge_index = states[i].tensor.edge_index edge_index = torch.where( edge_index[..., None] == states[i].tensor["node_index"] )[2].reshape(edge_index.shape) @@ -709,39 +673,19 @@ def __init__( def forward(self, states_tensor: TensorDict) -> torch.Tensor: # Convert the graph to adjacency matrix - batch_size = int(torch.prod(torch.tensor(states_tensor["batch_shape"]))) + batch_size = int(states_tensor.batch_size) adj_matrices = torch.zeros( (batch_size, self.n_nodes, self.n_nodes), - device=states_tensor["node_feature"].device, + device=states_tensor.x.device, ) # Fill the adjacency matrices from edge indices - for i in range(batch_size): - start, end = ( - states_tensor["batch_ptr"][i], - states_tensor["batch_ptr"][i + 1], - ) - nodes_index_range = states_tensor["node_index"][start:end] - - # Skip if no edges - if states_tensor["edge_index"].shape[0] == 0: - continue - - # Find edges that belong to this graph - edge_index_mask = torch.all( - states_tensor["edge_index"] >= nodes_index_range[0], dim=-1 - ) & torch.all(states_tensor["edge_index"] <= nodes_index_range[-1], dim=-1) - - if torch.any(edge_index_mask): - # Get the edge indices relative to this graph's node indices - masked_edge_index = ( - states_tensor["edge_index"][edge_index_mask] - nodes_index_range[0] - ) - # Fill the adjacency matrix - if len(masked_edge_index) > 0: - adj_matrices[ - i, masked_edge_index[:, 0], masked_edge_index[:, 1] - ] = 1 + if states_tensor.edge_index.numel() > 0: + adj_matrices[ + torch.arange(batch_size)[:, None, None], + states_tensor.edge_index[0] - states_tensor.ptr[:-1], + states_tensor.edge_index[1] - states_tensor.ptr[:-1], + ] = 1 # Flatten the adjacency matrices for the MLP adj_matrices_flat = adj_matrices.view(batch_size, -1) @@ -766,7 +710,7 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: BATCH_SIZE = 128 DIRECTED = True USE_BUFFER = False - USE_GNN = False # Set to False to use MLP with adjacency matrices instead of GNN + USE_GNN = True # Set to False to use MLP with adjacency matrices instead of GNN state_evaluator = undirected_reward if not DIRECTED else directed_reward 
torch.random.manual_seed(7) From 939bdab9fcbc956aa9fb350422704c6dd49478e5 Mon Sep 17 00:00:00 2001 From: "Omar G. Younis" Date: Thu, 27 Feb 2025 17:05:12 +0100 Subject: [PATCH 096/102] fix all problems --- src/gfn/gym/graph_building.py | 64 +++++++++++++------------- src/gfn/modules.py | 2 +- src/gfn/utils/distributions.py | 13 ++---- tutorials/examples/train_graph_ring.py | 27 +++++------ 4 files changed, 50 insertions(+), 56 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 1006f351..2ba32b31 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -68,9 +68,11 @@ def step(self, states: GraphStates, actions: GraphActions) -> Batch: """ if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") + if len(actions) == 0: + return states.tensor action_type = actions.action_type[0] - assert torch.all(actions.action_type == action_type) + assert torch.all(actions.action_type == action_type) # TODO: allow different action types if action_type == GraphActionType.EXIT: return self.States.make_sink_states_tensor(states.batch_shape) @@ -185,39 +187,39 @@ def is_action_valid( # Get the data list from the batch data_list = states.tensor.to_data_list() - for i, features in enumerate(actions.features[actions.action_type == GraphActionType.ADD_NODE]): + for i in range(len(actions)): graph = data_list[i] - - # Check if a node with these features already exists - equal_nodes = torch.all(graph.x == features.unsqueeze(0), dim=1) - - if backward: - # For backward actions, we need at least one matching node - if not torch.any(equal_nodes): - return False - else: - # For forward actions, we should not have any matching nodes - if torch.any(equal_nodes): - return False + if actions.action_type[i] == GraphActionType.ADD_NODE: + # Check if a node with these features already exists + equal_nodes = torch.all(graph.x == actions.features[i].unsqueeze(0), dim=1) + + if backward: + # For backward actions, we need at least one matching node + if not torch.any(equal_nodes): + return False + else: + # For forward actions, we should not have any matching nodes + if torch.any(equal_nodes): + return False - for i, (src, dst) in enumerate(actions.edge_index[actions.action_type == GraphActionType.ADD_EDGE]): - graph = data_list[i] - - # Check if src and dst are valid node indices - if src >= graph.num_nodes or dst >= graph.num_nodes or src == dst: - return False - - # Check if the edge already exists - edge_exists = torch.any((graph.edge_index[0] == src) & (graph.edge_index[1] == dst)) - - if backward: - # For backward actions, the edge must exist - if not edge_exists: - return False - else: - # For forward actions, the edge must not exist - if edge_exists: + elif actions.action_type[i] == GraphActionType.ADD_EDGE: + src, dst = actions.edge_index[i] + + # Check if src and dst are valid node indices + if src >= graph.num_nodes or dst >= graph.num_nodes or src == dst: return False + + # Check if the edge already exists + edge_exists = torch.any((graph.edge_index[0] == src) & (graph.edge_index[1] == dst)) + + if backward: + # For backward actions, the edge must exist + if not edge_exists: + return False + else: + # For forward actions, the edge must not exist + if edge_exists: + return False return True diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 5ac7b77f..342b4632 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -537,7 +537,7 @@ def to_probability_distribution( ) * edge_index_probs + epsilon * 
uniform_dist_probs
 edge_index_probs[torch.isnan(edge_index_probs)] = 1
 dists["edge_index"] = CategoricalIndexes(
- probs=edge_index_probs, node_indexes=states.tensor["node_index"]
+ probs=edge_index_probs, n_nodes=states.tensor.num_nodes
 )

 dists["features"] = Normal(module_output["features"], temperature)

diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py
index 48db3f58..3ea029cf 100644
--- a/src/gfn/utils/distributions.py
+++ b/src/gfn/utils/distributions.py
@@ -73,17 +73,16 @@ def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor:
 class CategoricalIndexes(Categorical):
 """Samples indexes from a categorical distribution."""

- def __init__(self, probs: torch.Tensor, node_indexes: torch.Tensor):
+ def __init__(self, probs: torch.Tensor, n_nodes: int):
 """Initializes the distribution.

 Args:
 probs: The probabilities of the categorical distribution.
 n_nodes: The number of nodes in the graph.
 """
- self.node_indexes = node_indexes
+ self.n_nodes = n_nodes
 assert probs.shape == (
 probs.shape[0],
- node_indexes.shape[0] * node_indexes.shape[0],
+ n_nodes * n_nodes,
 )
 super().__init__(probs)
@@ -91,17 +90,15 @@ def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
 samples = super().sample(sample_shape)
 out = torch.stack(
 [
- samples // self.node_indexes.shape[0],
- samples % self.node_indexes.shape[0],
+ samples // self.n_nodes,
+ samples % self.n_nodes,
 ],
 dim=-1,
 )
- out = self.node_indexes.index_select(0, out.flatten()).reshape(*out.shape)
 return out

 def log_prob(self, value):
- value = value[..., 0] * self.node_indexes.shape[0] + value[..., 1]
- value = torch.bucketize(value, self.node_indexes)
+ value = value[..., 0] * self.n_nodes + value[..., 1]
 return super().log_prob(value)

diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py
index e74aa012..7e0dfbed 100644
--- a/tutorials/examples/train_graph_ring.py
+++ b/tutorials/examples/train_graph_ring.py
@@ -19,7 +19,7 @@ import torch
 from tensordict import TensorDict
 from torch import nn
-from torch_geometric.data import Data
+from torch_geometric.data import Batch, Data
 from torch_geometric.nn import GINConv, GCNConv, DirGNNConv

 from gfn.actions import Actions, GraphActions, GraphActionType
@@ -399,8 +399,8 @@ def forward_masks(self):
 edge_idx = torch.zeros(0, dtype=torch.bool)
 else:
 edge_idx = torch.logical_and(
- existing_edges[0] == ei0,
- existing_edges[1] == ei1,
+ existing_edges[0][..., None] == ei0[None],
+ existing_edges[1][..., None] == ei1[None],
 )

 # Collapse across the edge dimension.
@@ -427,10 +427,7 @@ def backward_masks(self):
 )

 for i in range(len(self)):
- existing_edges = (
- self[i].tensor.edge_index
- - self.tensor.node_index[self.tensor.batch_ptr[i]]
- )
+ existing_edges = self[i].tensor.edge_index

 if env.is_directed:
 i_up, j_up = torch.triu_indices(
@@ -452,12 +449,12 @@
 edge_idx = torch.zeros(0, dtype=torch.bool)
 else:
 edge_idx = torch.logical_and(
- existing_edges[:, 0] == ei0.unsqueeze(-1),
- existing_edges[:, 1] == ei1.unsqueeze(-1),
+ existing_edges[0][..., None] == ei0[None],
+ existing_edges[1][..., None] == ei1[None],
 )
 # Collapse across the edge dimension.
 if len(edge_idx.shape) == 2:
- edge_idx = edge_idx.sum(1).bool()
+ edge_idx = edge_idx.sum(0).bool()

 backward_masks[i, edge_idx] = (
 True # Allow the removal of this edge.
@@ -671,7 +668,7 @@ def __init__( add_layer_norm=True, ) - def forward(self, states_tensor: TensorDict) -> torch.Tensor: + def forward(self, states_tensor: Batch) -> torch.Tensor: # Convert the graph to adjacency matrix batch_size = int(states_tensor.batch_size) adj_matrices = torch.zeros( @@ -681,11 +678,9 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: # Fill the adjacency matrices from edge indices if states_tensor.edge_index.numel() > 0: - adj_matrices[ - torch.arange(batch_size)[:, None, None], - states_tensor.edge_index[0] - states_tensor.ptr[:-1], - states_tensor.edge_index[1] - states_tensor.ptr[:-1], - ] = 1 + for i in range(batch_size): + eis = states_tensor[i].edge_index + adj_matrices[i, eis[0], eis[1]] = 1 # Flatten the adjacency matrices for the MLP adj_matrices_flat = adj_matrices.view(batch_size, -1) From 0f57485031b07a71a176ee4ac3d4bc1aad10714f Mon Sep 17 00:00:00 2001 From: "Omar G. Younis" Date: Thu, 27 Feb 2025 17:13:26 +0100 Subject: [PATCH 097/102] remove node_index --- src/gfn/states.py | 24 ------------------------ tutorials/examples/train_graph_ring.py | 7 ++----- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index c99854f9..42a4043f 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -543,8 +543,6 @@ class GraphStates(States): s0: ClassVar[Data] sf: ClassVar[Data] - _next_node_index = 0 - def __init__(self, tensor: Batch): """Initialize the GraphStates with a PyG Batch object. @@ -1113,25 +1111,3 @@ def backward_masks(self) -> dict: "features": features_mask, "edge_index": edge_index_masks } - - @staticmethod - def unique_node_indices(num_nodes: int) -> torch.Tensor: - """Generate unique node indices for new nodes. - - Args: - num_nodes: Number of new nodes to generate indices for. - - Returns: - A tensor of unique node indices. - """ - # Generate sequential indices starting from the next available index - indices = torch.arange( - GraphStates._next_node_index, - GraphStates._next_node_index + num_nodes, - dtype=torch.long - ) - - # Update the next available index - GraphStates._next_node_index += num_nodes - - return indices diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 7e0dfbed..e348bc8c 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -551,7 +551,7 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool for i in range(8): current_ax = ax[i // 4, i % 4] state = states[i] - n_circles = state.tensor.x.shape[0] + n_circles = state.tensor.num_nodes radius = 5 xs, ys = [], [] for j in range(n_circles): @@ -565,11 +565,8 @@ def render_states(states: GraphStates, state_evaluator: callable, directed: bool ) edge_index = states[i].tensor.edge_index - edge_index = torch.where( - edge_index[..., None] == states[i].tensor["node_index"] - )[2].reshape(edge_index.shape) - for edge in edge_index: + for edge in edge_index.T: start_x, start_y = xs[edge[0]], ys[edge[0]] end_x, end_y = xs[edge[1]], ys[edge[1]] dx = end_x - start_x From 911157c9ee1aa2345d601ae2223726a058ef4329 Mon Sep 17 00:00:00 2001 From: "Omar G. 
Younis" Date: Thu, 27 Feb 2025 17:15:50 +0100 Subject: [PATCH 098/102] move dep --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 32a3e1ec..cbd533a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ numpy = ">=1.21.2" python = "^3.10" torch = ">=1.9.0" tensordict = ">=0.6.1" +torch_geometric = ">=2.6.1" # dev dependencies. black = { version = "24.3", optional = true } @@ -64,7 +65,6 @@ dev = [ "sphinx", "tox", "flake8", - "torch_geometric", ] scripts = ["tqdm", "wandb", "scikit-learn", "scipy", "matplotlib"] From 8806cda07b4015ac59e6e0bbc0358c6fb8fe5864 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Sat, 1 Mar 2025 02:53:12 +0900 Subject: [PATCH 099/102] fix pyproject.toml --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cbd533a4..ec4dd579 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ einops = ">=0.6.1" numpy = ">=1.21.2" python = "^3.10" -torch = ">=1.9.0" +torch = ">=2.6.0" tensordict = ">=0.6.1" torch_geometric = ">=2.6.1" @@ -49,7 +49,6 @@ wandb = { version = "*", optional = true } scikit-learn = {version = "*", optional = true } scipy = { version = "*", optional = true } matplotlib = { version = "*", optional = true } -torch_geometric = { version = ">=2.6.1", optional = true } [tool.poetry.extras] dev = [ @@ -86,7 +85,6 @@ all = [ "tox", "tqdm", "wandb", - "torch_geometric", ] [tool.poetry.urls] From ce287eb6024c550c3edfcd4c73054f66067559f3 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Sat, 1 Mar 2025 03:03:44 +0900 Subject: [PATCH 100/102] change state type hinting from Tensordict to torch_geometric Data --- src/gfn/preprocessors.py | 4 ++-- src/gfn/states.py | 8 ++++---- tutorials/examples/train_graph_ring.py | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index f37e87f9..bee77249 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -2,7 +2,7 @@ from typing import Callable import torch -from tensordict import TensorDict +from torch_geometric.data import Batch from gfn.states import GraphStates, States @@ -79,5 +79,5 @@ class GraphPreprocessor(Preprocessor): def __init__(self) -> None: super().__init__(-1) # TODO: review output_dim API - def preprocess(self, states: GraphStates) -> TensorDict: + def preprocess(self, states: GraphStates) -> Batch: return states.tensor diff --git a/src/gfn/states.py b/src/gfn/states.py index 42a4043f..3ac4988f 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -121,7 +121,7 @@ def make_initial_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tenso return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) else: raise NotImplementedError( - "make_initial_states_tensor is not implemented by default for TensorDicts" + f"make_initial_states_tensor is not implemented by default for {cls.__name__}" ) @classmethod @@ -133,7 +133,7 @@ def make_sink_states_tensor(cls, batch_shape: tuple[int, ...]) -> torch.Tensor: return cls.sf.repeat(*batch_shape, *((1,) * state_ndim)) else: raise NotImplementedError( - "make_sink_states_tensor is not implemented by default for TensorDicts" + f"make_sink_states_tensor is not implemented by default for {cls.__name__}" ) def __len__(self): @@ -275,7 +275,7 @@ def is_initial_state(self) -> torch.Tensor: ) else: raise NotImplementedError( - "is_initial_state is not implemented by default for TensorDicts" + 
f"is_initial_state is not implemented by default for {self.__class__.__name__}" ) return self.compare(source_states_tensor) @@ -289,7 +289,7 @@ def is_sink_state(self) -> torch.Tensor: ).to(self.tensor.device) else: raise NotImplementedError( - "is_sink_state is not implemented by default for TensorDicts" + f"is_sink_state is not implemented by default for {self.__class__.__name__}" ) return self.compare(sink_states) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 0292af54..7ba96010 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -262,7 +262,7 @@ def _group_mean( size = batch_ptr[1:] - batch_ptr[:-1] return (cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]) / size[:, None] - def forward(self, states_tensor: TensorDict) -> torch.Tensor: + def forward(self, states_tensor: Batch) -> torch.Tensor: node_features, batch_ptr = ( states_tensor.x, states_tensor.ptr, @@ -418,7 +418,7 @@ class RingStates(GraphStates): edge_index=torch.zeros((2, 0), dtype=torch.long), ) - def __init__(self, tensor: TensorDict): + def __init__(self, tensor: Batch): self.tensor = tensor self.node_features_dim = tensor.x.shape[-1] self.edge_features_dim = tensor.edge_attr.shape[-1] @@ -646,7 +646,7 @@ def convert_actions(self, states: GraphStates, actions: Actions) -> GraphActions class GraphPreprocessor(Preprocessor): """Preprocessor for graph states to extract the tensor representation. - This simple preprocessor extracts the TensorDict from GraphStates to make + This simple preprocessor extracts the torch_geometric Batch from GraphStates to make it compatible with the policy networks. It doesn't perform any complex transformations, just ensuring the tensors are accessible in the right format. @@ -657,10 +657,10 @@ class GraphPreprocessor(Preprocessor): def __init__(self, feature_dim: int = 1): super().__init__(output_dim=feature_dim) - def preprocess(self, states: GraphStates) -> TensorDict: + def preprocess(self, states: GraphStates) -> Batch: return states.tensor - def __call__(self, states: GraphStates) -> TensorDict: + def __call__(self, states: GraphStates) -> Batch: return self.preprocess(states) @@ -821,7 +821,7 @@ def forward(self, states_tensor: Batch) -> torch.Tensor: 3. Predict logits for edge actions and exit action Args: - states_tensor: A TensorDict containing graph state information + states_tensor: A torch_geometric Batch containing graph state information Returns: A tensor of logits for all possible actions From 7607945d4f7eb06a6d0031822d4a7199f47dddf6 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Sat, 1 Mar 2025 03:05:09 +0900 Subject: [PATCH 101/102] add settings that achieve 95% in the ring generation (directed) with four nodes --- tutorials/examples/train_graph_ring.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tutorials/examples/train_graph_ring.py b/tutorials/examples/train_graph_ring.py index 7ba96010..b7003f7b 100644 --- a/tutorials/examples/train_graph_ring.py +++ b/tutorials/examples/train_graph_ring.py @@ -877,9 +877,9 @@ def forward(self, states_tensor: Batch) -> torch.Tensor: 3. 
Visualize sample generated graphs """ N_NODES = 4 - N_ITERATIONS = 1000 + N_ITERATIONS = 200 LR = 0.001 - BATCH_SIZE = 128 + BATCH_SIZE = 1024 DIRECTED = True USE_BUFFER = False USE_GNN = True # Set to False to use MLP with adjacency matrices instead of GNN @@ -909,6 +909,7 @@ def forward(self, states_tensor: Batch) -> torch.Tensor: ) gflownet = TBGFlowNet(pf, pb) optimizer = torch.optim.Adam(gflownet.parameters(), lr=LR) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1) replay_buffer = ReplayBuffer( env, @@ -925,6 +926,7 @@ def forward(self, states_tensor: Batch) -> torch.Tensor: env, n=BATCH_SIZE, save_logprobs=True, # pyright: ignore + epsilon=0.2 * (1 - iteration / N_ITERATIONS), ) training_samples = gflownet.to_training_samples(trajectories) @@ -956,6 +958,7 @@ def forward(self, states_tensor: Batch) -> torch.Tensor: ) loss.backward() optimizer.step() + scheduler.step() losses.append(loss.item()) t2 = time.time() From e918ac7c840b8476778d4c1caeed7de55688cf03 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Mon, 3 Mar 2025 23:50:04 +0900 Subject: [PATCH 102/102] rename test_state to test_graph_states --- testing/{test_states.py => test_graph_states.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename testing/{test_states.py => test_graph_states.py} (100%) diff --git a/testing/test_states.py b/testing/test_graph_states.py similarity index 100% rename from testing/test_states.py rename to testing/test_graph_states.py
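Several of the changes above (the per-graph adjacency loop, `_group_mean`, and the TensorDict-to-Batch type hints) lean on the same few `torch_geometric.data.Batch` idioms. The following self-contained sketch illustrates them, assuming `torch_geometric>=2.6.1` as pinned in the pyproject change; the toy graphs and dimensions are made up for illustration:

```python
import torch
from torch_geometric.data import Batch, Data

# Two toy graphs with 1-dimensional node features.
g1 = Data(x=torch.ones(3, 1), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.zeros(2, 1), edge_index=torch.tensor([[0], [1]]))
batch = Batch.from_data_list([g1, g2])

# Integer indexing returns a single graph with *local* edge indices,
# which is why the manual node_index offsets could be removed.
assert batch[1].edge_index.tolist() == [[0], [1]]

# Dense adjacency per graph, as in the MLP policy's forward pass.
n = 3  # pad every graph to the max node count
adj = torch.zeros(batch.num_graphs, n, n)
for i in range(batch.num_graphs):
    ei = batch[i].edge_index
    adj[i, ei[0], ei[1]] = 1

# ptr-based per-graph mean of node features (the cumsum trick in _group_mean).
x, ptr = batch.x, batch.ptr
cumsum = torch.cat([x.new_zeros(1, x.shape[-1]), x.cumsum(0)])
size = (ptr[1:] - ptr[:-1]).unsqueeze(-1)
group_mean = (cumsum[ptr[1:]] - cumsum[ptr[:-1]]) / size  # shape (2, 1)
```

The same padded adjacency can also be produced in one call with `torch_geometric.utils.to_dense_adj(batch.edge_index, batch.batch)`; the explicit loop in the patch trades that for independence from the extra utility.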
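The training settings from PATCH 101 combine a stepped learning-rate decay with linearly annealed epsilon-greedy exploration. A stripped-down sketch of how the two schedules interact over a run (the model and loss below are placeholders, not the tutorial's GFlowNet):

```python
import torch

N_ITERATIONS = 200
model = torch.nn.Linear(8, 1)  # stand-in for the GFlowNet policy modules
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Decay the learning rate 10x every 100 iterations, as in the patch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

for iteration in range(N_ITERATIONS):
    # Linear annealing: epsilon starts at 0.2 and decays to ~0 by the end.
    epsilon = 0.2 * (1 - iteration / N_ITERATIONS)
    # ... here the real script samples trajectories with epsilon-noisy actions
    # and computes the trajectory-balance loss; a placeholder loss stands in:
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
```

With `step_size=100` and `gamma=0.1`, the learning rate drops from 1e-3 to 1e-4 halfway through the 200 iterations, by which point epsilon has annealed from 0.2 down to 0.1, so later updates are both smaller and closer to on-policy.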