Merged
Commits
120 commits
9ae28b2
including Graphs as States for torchgfn
alip67 Nov 6, 2024
de6ab1c
add GraphEnv
younik Nov 7, 2024
24e23e8
add deps and reformat
younik Nov 7, 2024
1f7b220
add test, fix errors, add valid action check
younik Nov 8, 2024
63e4f1c
fix formatting
younik Nov 8, 2024
8034fb2
add GraphAction
younik Nov 14, 2024
d179671
fix batching mechanism
younik Nov 14, 2024
e018f4e
Merge branch 'GFNOrg:master' into graph-states
alip67 Nov 15, 2024
7ff96d5
add support for EXIT action
younik Nov 16, 2024
cf482da
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik Nov 16, 2024
dacbbf7
add GraphActionPolicyEstimator
younik Nov 19, 2024
98ea448
Merge branch 'GFNOrg:master' into graph-states
alip67 Nov 19, 2024
e74e500
Sampler integration work
younik Nov 22, 2024
a862bb4
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik Nov 22, 2024
5e64c84
use TensorDict
younik Nov 26, 2024
81f8b71
solve some errors
younik Nov 28, 2024
34781ef
use tensordict in actions
younik Nov 28, 2024
3e584f2
handle sf
younik Dec 2, 2024
d5e438f
remove Data
younik Dec 3, 2024
fba5d50
categorical action type
younik Dec 6, 2024
478bd14
change batching
younik Dec 10, 2024
dd80f28
fix stacking
younik Dec 11, 2024
616551c
fix graph stacking
younik Dec 11, 2024
77611d4
fix test graph env
younik Dec 12, 2024
5874ff6
add ring example
younik Dec 19, 2024
9d42332
remove check edge_features
younik Dec 20, 2024
2d44242
fix GraphStates set
younik Dec 20, 2024
173d4fb
remove debug
younik Dec 20, 2024
7265857
fix add_edge action
younik Dec 20, 2024
2b3208f
fix edge_index after get
younik Dec 20, 2024
b84246f
push updated code
younik Dec 22, 2024
fa0d22a
add rendering
younik Dec 27, 2024
27d192a
fix gradient propagation
younik Jan 6, 2025
5d99739
Merge remote-tracking branch 'origin/master' into graph-states
younik Jan 12, 2025
f4fc3ab
fix formatting
younik Jan 12, 2025
8f1c62c
address comments
younik Jan 12, 2025
6482834
fix test
younik Jan 12, 2025
6db601d
fix test
younik Jan 13, 2025
c7f8243
fix pre-commit
younik Jan 13, 2025
c3df427
Merge remote-tracking branch 'origin/master' into graph-states
younik Jan 13, 2025
78b729a
fix merging issues
younik Jan 13, 2025
38dd2b0
fix toml
younik Jan 13, 2025
12c49b7
add dep & address issue
younik Jan 13, 2025
fe237ed
fix toml
younik Jan 13, 2025
9bbc48d
fix pyproject.toml
younik Jan 13, 2025
5e4fc4e
address comments
younik Jan 14, 2025
705b4cc
add tests for action
younik Jan 15, 2025
d765330
fix test after added dummy action
younik Jan 15, 2025
4ee6987
add GraphPreprocessor
younik Jan 15, 2025
fe9713c
added TODO
younik Jan 15, 2025
1425eb6
add complete masks
younik Jan 19, 2025
36c42ec
pre-commit hook
younik Jan 19, 2025
5747e97
adress comments
younik Jan 19, 2025
e9f9951
pre-commit
younik Jan 19, 2025
406cfca
address comments
younik Jan 20, 2025
08e519b
fix ring example
younik Jan 24, 2025
17e07ad
make edge_index global
younik Jan 29, 2025
46e3698
make edge_index global
younik Jan 29, 2025
1c33d98
Merge remote-tracking branch 'origin/graph-states-fix' into graph-states
younik Jan 29, 2025
e6d909b
fix test_env
younik Jan 29, 2025
da66adb
add global edge + pair programming session
younik Feb 4, 2025
7130281
pair programming session
younik Feb 5, 2025
1fa01df
fix node_index
younik Feb 6, 2025
48ceafd
add comments
younik Feb 10, 2025
2ace3ca
Merge remote-tracking branch 'origin/master' into graph-states
younik Feb 10, 2025
2271732
fix linter
younik Feb 10, 2025
4be8f9b
add two tested RingPolicyEstimator
younik Feb 11, 2025
e7465b8
push tentatives
younik Feb 11, 2025
38041e8
trying TBGFN
younik Feb 12, 2025
2f4b3f9
increased capacity of RingPolicyEstimator
josephdviviano Feb 14, 2025
d0313d1
make undirected graph
younik Feb 17, 2025
f29432e
Merge remote-tracking branch 'alip67/graph-states' into graph-states
younik Feb 17, 2025
9d994da
merge over
josephdviviano Feb 18, 2025
6e1bbc2
merged
josephdviviano Feb 18, 2025
0625f5c
allows for configurable policy capacity, and the code simultaneously …
josephdviviano Feb 18, 2025
dae3299
spelling
josephdviviano Feb 18, 2025
671917b
allows for configurable policy capacity, and the code simultaneously …
josephdviviano Feb 18, 2025
623f50e
directed and undirected graphs now co-implemented -- but the model wi…
josephdviviano Feb 19, 2025
74d9b13
black
josephdviviano Feb 19, 2025
9b7ce68
change docstring and fix argument
saleml Feb 19, 2025
704cf66
add the possibility of layernorm in MLP
saleml Feb 19, 2025
7bbd72d
remove create_mlp function
saleml Feb 19, 2025
3e84ebd
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
saleml Feb 19, 2025
7bdf5b5
new architecture
josephdviviano Feb 19, 2025
7e99d68
MLP changes
josephdviviano Feb 19, 2025
af26f6e
MLP changes
josephdviviano Feb 19, 2025
a6ffc72
fix magnitude issue
younik Feb 19, 2025
8ba06a1
changed docs
josephdviviano Feb 20, 2025
201553b
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
josephdviviano Feb 20, 2025
a2d44d5
some bugfixes in extend of trajectories
josephdviviano Feb 20, 2025
bf92366
black
josephdviviano Feb 20, 2025
cfd4740
black
josephdviviano Feb 20, 2025
ab40005
removed print
josephdviviano Feb 20, 2025
8c83049
change types in env and states to allow for graph states
saleml Feb 20, 2025
7263ff2
merge recent chances
saleml Feb 20, 2025
e2ac8be
make batch_shape a property in general states, to make explicit the t…
saleml Feb 20, 2025
1543e49
fix batch_shape in GraphStates
saleml Feb 20, 2025
5dd5aba
fix batch_shape in GraphRing
saleml Feb 20, 2025
cda28e3
removed unused code from forward_masks
josephdviviano Feb 20, 2025
9ba3004
fix extends
younik Feb 20, 2025
3515c54
normal replay buffer can also be prioritized
josephdviviano Feb 22, 2025
b40118f
directed GCN working on small graphs
josephdviviano Feb 22, 2025
f7cf25e
Merge branch 'master' into pr/alip67/210
saleml Feb 23, 2025
799531f
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
saleml Feb 23, 2025
de092a0
fix extend
younik Feb 24, 2025
39f91d3
add adjacency matrix module
saleml Feb 25, 2025
80b2a9e
Simplify ring graph reward calculation logic
saleml Feb 25, 2025
c24297d
fix forward mask
younik Feb 25, 2025
3c39064
add docstrings
saleml Feb 25, 2025
e43c937
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
saleml Feb 25, 2025
0a1f8e1
switch from torch_geoemtric to TensorDict
younik Feb 26, 2025
382be45
update train_graph_ring
younik Feb 27, 2025
939bdab
fix all problems
younik Feb 27, 2025
0f57485
remove node_index
younik Feb 27, 2025
911157c
move dep
younik Feb 27, 2025
313562e
Merge branch 'graph-states' into graph-states
younik Feb 27, 2025
8806cda
fix pyproject.toml
hyeok9855 Feb 28, 2025
ce287eb
change state type hinting from Tensordict to torch_geometric Data
hyeok9855 Feb 28, 2025
7607945
add settings that achieve 95% in the ring generation (directed) with …
hyeok9855 Feb 28, 2025
e918ac7
rename test_state to test_graph_states
hyeok9855 Mar 3, 2025
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ classifiers = [
einops = ">=0.6.1"
numpy = ">=1.21.2"
python = "^3.10"
torch = ">=1.9.0"
torch = ">=2.6.0"
tensordict = ">=0.6.1"
torch_geometric = ">=2.6.1"

# dev dependencies.
black = { version = "24.3", optional = true }
Expand Down
155 changes: 155 additions & 0 deletions src/gfn/actions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations # This allows to use the class name in type hints

import enum
from abc import ABC
from math import prod
from typing import ClassVar, Sequence

import torch
from tensordict import TensorDict


class Actions(ABC):
Expand Down Expand Up @@ -170,3 +172,156 @@ def is_exit(self) -> torch.Tensor:
*self.batch_shape, *((1,) * len(self.__class__.action_shape))
)
return self.compare(exit_actions_tensor)


class GraphActionType(enum.IntEnum):
ADD_NODE = 0
ADD_EDGE = 1
EXIT = 2
DUMMY = 3


class GraphActions(Actions):
"""Actions for graph-based environments.

Each action is one of:
- ADD_NODE: Add a node with given features
- ADD_EDGE: Add an edge between two nodes with given features
- EXIT: Terminate the trajectory
- DUMMY: Placeholder action used to pad batches where no real action applies

Attributes:
features_dim: Dimension of node/edge features
tensor: TensorDict containing:
- action_type: Type of action (GraphActionType)
- features: Features for nodes/edges
- edge_index: Source/target nodes for edges
"""

features_dim: ClassVar[int]

def __init__(self, tensor: TensorDict):
"""Initializes a GraphActions object.

Args:
tensor: a TensorDict with the following keys:
- "action_type": a tensor of GraphActionType values indicating the type of each action.
- "features": a tensor of shape (*batch_shape, features_dim) holding the node or edge features, depending on the action type. May be omitted for EXIT and DUMMY actions.
- "edge_index": a tensor of shape (*batch_shape, 2) representing the edge to add. Must be defined if and only if the action type is GraphActionType.ADD_EDGE.
"""
self.batch_shape = tensor["action_type"].shape
features = tensor.get("features", None)
if features is None:
assert torch.all(
torch.logical_or(
tensor["action_type"] == GraphActionType.EXIT,
tensor["action_type"] == GraphActionType.DUMMY,
)
)
features = torch.zeros((*self.batch_shape, self.features_dim))
edge_index = tensor.get("edge_index", None)
if edge_index is None:
assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE)
edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long)

self.tensor = TensorDict(
{
"action_type": tensor["action_type"],
"features": features,
"edge_index": edge_index,
},
batch_size=self.batch_shape,
)

def __repr__(self):
return f"GraphActions object with batch_shape {self.batch_shape}."

@property
def device(self) -> torch.device | None:
"""Returns the device of the features tensor."""
return self.tensor.device

def __len__(self) -> int:
"""Returns the number of actions in the batch."""
return int(prod(self.batch_shape))

def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions:
"""Get particular actions of the batch."""
return GraphActions(self.tensor[index])

def __setitem__(
self, index: int | Sequence[int] | Sequence[bool], action: GraphActions
) -> None:
"""Set particular actions of the batch."""
self.tensor[index] = action.tensor

def compare(self, other: GraphActions) -> torch.Tensor:
"""Compares the actions to another GraphActions object.

Args:
other: the GraphActions object to compare against.

Returns: a boolean tensor of shape batch_shape indicating whether the actions are equal.
"""
equal = self.tensor == other.tensor
return (
equal["action_type"]
& ((self.action_type == GraphActionType.EXIT) | equal["features"].all(dim=-1))
& ((self.action_type != GraphActionType.ADD_EDGE) | equal["edge_index"].all(dim=-1))
)

@property
def is_exit(self) -> torch.Tensor:
"""Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
return self.action_type == GraphActionType.EXIT

@property
def is_dummy(self) -> torch.Tensor:
"""Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions."""
return self.action_type == GraphActionType.DUMMY

@property
def action_type(self) -> torch.Tensor:
"""Returns the action type tensor."""
return self.tensor["action_type"]

@property
def features(self) -> torch.Tensor:
"""Returns the features tensor."""
return self.tensor["features"]

@property
def edge_index(self) -> torch.Tensor:
"""Returns the edge index tensor."""
return self.tensor["edge_index"]

@classmethod
def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions:
"""Creates a GraphActions object of dummy actions with the given batch shape."""
return cls(
TensorDict(
{
"action_type": torch.full(
batch_shape, fill_value=GraphActionType.DUMMY
),
},
batch_size=batch_shape,
)
)

@classmethod
def make_exit_actions(cls, batch_shape: tuple[int]) -> GraphActions:
"""Creates a GraphActions object of exit actions with the given batch shape."""
return cls(
TensorDict(
{
"action_type": torch.full(
batch_shape, fill_value=GraphActionType.EXIT
),
},
batch_size=batch_shape,
)
)
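The defaulting rules in `GraphActions.__init__` above can be summarized without the torch/tensordict dependencies: EXIT and DUMMY actions may omit `features` (zeros are filled in), and `edge_index` is only required — and only meaningful — for ADD_EDGE. The sketch below is a minimal, dependency-free mirror of that validation; `fill_action_defaults` and its `features_dim=4` default are hypothetical names for illustration, not part of the torchgfn API.

```python
from enum import IntEnum


class GraphActionType(IntEnum):
    ADD_NODE = 0
    ADD_EDGE = 1
    EXIT = 2
    DUMMY = 3


def fill_action_defaults(action_type, features=None, edge_index=None, features_dim=4):
    """Mirror of the GraphActions.__init__ defaulting logic (plain-Python sketch)."""
    # Only EXIT/DUMMY actions may omit features; they default to zeros.
    if features is None:
        assert action_type in (GraphActionType.EXIT, GraphActionType.DUMMY)
        features = [0.0] * features_dim
    # edge_index is only meaningful for ADD_EDGE; others get a placeholder.
    if edge_index is None:
        assert action_type != GraphActionType.ADD_EDGE
        edge_index = [0, 0]
    return {"action_type": action_type, "features": features, "edge_index": edge_index}


exit_action = fill_action_defaults(GraphActionType.EXIT)
add_edge = fill_action_defaults(
    GraphActionType.ADD_EDGE, features=[1.0, 0.0, 0.0, 0.0], edge_index=[0, 1]
)
```

In the real class the same batch-wide checks are done with `torch.all` over a `TensorDict`, but the per-action invariants are the ones shown here.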
2 changes: 1 addition & 1 deletion src/gfn/containers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .replay_buffer import PrioritizedReplayBuffer, ReplayBuffer
from .replay_buffer import NormBasedDiversePrioritizedReplayBuffer, ReplayBuffer
from .trajectories import Trajectories
from .transitions import Transitions
87 changes: 45 additions & 42 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ class ReplayBuffer:
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
objects_type: the type of buffer (transitions, trajectories, or states).
prioritized: whether the buffer is prioritized by log_reward or not.
"""

def __init__(
self,
env: Env,
objects_type: Literal["transitions", "trajectories", "states"],
capacity: int = 1000,
prioritized: bool = False,
):
"""Instantiates a replay buffer.
Args:
Expand All @@ -53,34 +55,61 @@ def __init__(
raise ValueError(f"Unknown objects_type: {objects_type}")

self._is_full = False
self.prioritized = prioritized

def __repr__(self):
return f"ReplayBuffer(capacity={self.capacity}, containing {len(self)} {self.objects_type})"

def __len__(self):
return len(self.training_objects)

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
def _add_objs(
self,
training_objects: Transitions | Trajectories | tuple[States],
):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects # pyright: ignore

to_add = len(training_objects)

self._is_full |= len(self) + to_add >= self.capacity

# Adds the objects to the buffer.
self.training_objects.extend(training_objects) # pyright: ignore

# Sort elements by logreward, capping the size at the defined capacity.
if self.prioritized:

if (
self.training_objects.log_rewards is None
or training_objects.log_rewards is None # pyright: ignore
):
raise ValueError("log_rewards must be defined for prioritized replay.")

# Ascending sort.
ix = torch.argsort(self.training_objects.log_rewards) # pyright: ignore
self.training_objects = self.training_objects[ix] # pyright: ignore
self.training_objects = self.training_objects[
-self.capacity :
-self.capacity : # Ascending sort, so we retain the final elements.
] # pyright: ignore

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states) # pyright: ignore
self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
if self.prioritized:
self.terminating_states = self.terminating_states[ix]

self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
self._add_objs(training_objects)

def sample(self, n_trajectories: int) -> Transitions | Trajectories | tuple[States]:
"""Samples `n_trajectories` training objects from the buffer."""
if self.terminating_states is not None:
Expand Down Expand Up @@ -113,8 +142,8 @@ def load(self, directory: str):
)


class PrioritizedReplayBuffer(ReplayBuffer):
"""A replay buffer of trajectories or transitions.
class NormBasedDiversePrioritizedReplayBuffer(ReplayBuffer):
"""A prioritized replay buffer of trajectories or transitions that also encourages diversity.

Attributes:
env: the Environment instance.
Expand Down Expand Up @@ -152,53 +181,27 @@ def __init__(
super().__init__(env, objects_type, capacity)
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance
self._prioritized = True

def _add_objs(
self,
training_objects: Transitions | Trajectories | tuple[States],
terminating_states: States | None = None,
):
"""Adds a training object to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects) # pyright: ignore

# Sort elements by logreward, capping the size at the defined capacity.
ix = torch.argsort(self.training_objects.log_rewards) # pyright: ignore
self.training_objects = self.training_objects[ix] # pyright: ignore
self.training_objects = self.training_objects[
-self.capacity :
] # pyright: ignore

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]
@property
def prioritized(self) -> bool:
return self._prioritized

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects # pyright: ignore

to_add = len(training_objects)
self._is_full |= len(self) + to_add >= self.capacity

# The buffer isn't full yet.
if len(self.training_objects) < self.capacity:
self._add_objs(training_objects, terminating_states)
self._add_objs(training_objects)

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
if (
self.training_objects.log_rewards is None
or training_objects.log_rewards is None # pyright: ignore
):
raise ValueError("log_rewards must be defined for prioritized replay.")
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects # pyright: ignore

# Sort the incoming elements by their logrewards.
ix = torch.argsort(
Expand Down
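The retention rule that the refactored `_add_objs` applies when `prioritized=True` is an ascending argsort by log-reward followed by keeping the last `capacity` elements. The helper below is a hypothetical, list-based sketch of that pattern (the buffer itself indexes `Trajectories` objects with `torch.argsort`):

```python
def prioritized_retain(log_rewards, capacity):
    """Return indices of the `capacity` highest-log-reward items,
    using the buffer's ascending-argsort-then-truncate pattern."""
    ix = sorted(range(len(log_rewards)), key=lambda i: log_rewards[i])  # ascending
    return ix[-capacity:]  # keep the tail, i.e. the highest log-rewards


kept = prioritized_retain([0.1, 2.0, -1.0, 0.5], capacity=2)  # -> [3, 1]
```

Because the same permutation `ix` is reused for `terminating_states`, training objects and their terminal states stay aligned after truncation.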
4 changes: 2 additions & 2 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
assert (
log_probs.shape == (self.max_length, self.n_trajectories)
and log_probs.dtype == torch.float
)
), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}"
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs: torch.Tensor = log_probs
Expand Down Expand Up @@ -256,7 +256,7 @@ def extend(self, other: Trajectories) -> None:

# TODO: The replay buffer is storing `dones` - this wastes a lot of space.
self.actions.extend(other.actions)
self.states.extend(other.states)
self.states.extend(other.states) # n_trajectories comes from this.
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)

# For log_probs, we first need to make the first dimensions of self.log_probs
Expand Down