Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
171 commits
Select commit Hold shift + click to select a range
9ae28b2
including Graphs as States for torchgfn
alip67 Nov 6, 2024
de6ab1c
add GraphEnv
younik Nov 7, 2024
24e23e8
add deps and reformat
younik Nov 7, 2024
1f7b220
add test, fix errors, add valid action check
younik Nov 8, 2024
63e4f1c
fix formatting
younik Nov 8, 2024
8034fb2
add GraphAction
younik Nov 14, 2024
d179671
fix batching mechanism
younik Nov 14, 2024
e018f4e
Merge branch 'GFNOrg:master' into graph-states
alip67 Nov 15, 2024
7ff96d5
add support for EXIT action
younik Nov 16, 2024
cf482da
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik Nov 16, 2024
dacbbf7
add GraphActionPolicyEstimator
younik Nov 19, 2024
98ea448
Merge branch 'GFNOrg:master' into graph-states
alip67 Nov 19, 2024
e74e500
Sampler integration work
younik Nov 22, 2024
a862bb4
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik Nov 22, 2024
5e64c84
use TensorDict
younik Nov 26, 2024
81f8b71
solve some errors
younik Nov 28, 2024
34781ef
use tensordict in actions
younik Nov 28, 2024
3e584f2
handle sf
younik Dec 2, 2024
d5e438f
remove Data
younik Dec 3, 2024
fba5d50
categorical action type
younik Dec 6, 2024
478bd14
change batching
younik Dec 10, 2024
dd80f28
fix stacking
younik Dec 11, 2024
616551c
fix graph stacking
younik Dec 11, 2024
77611d4
fix test graph env
younik Dec 12, 2024
5874ff6
add ring example
younik Dec 19, 2024
9d42332
remove check edge_features
younik Dec 20, 2024
2d44242
fix GraphStates set
younik Dec 20, 2024
173d4fb
remove debug
younik Dec 20, 2024
7265857
fix add_edge action
younik Dec 20, 2024
2b3208f
fix edge_index after get
younik Dec 20, 2024
b84246f
push updated code
younik Dec 22, 2024
fa0d22a
add rendering
younik Dec 27, 2024
27d192a
fix gradient propagation
younik Jan 6, 2025
5d99739
Merge remote-tracking branch 'origin/master' into graph-states
younik Jan 12, 2025
f4fc3ab
fix formatting
younik Jan 12, 2025
8f1c62c
address comments
younik Jan 12, 2025
6482834
fix test
younik Jan 12, 2025
6db601d
fix test
younik Jan 13, 2025
c7f8243
fix pre-commit
younik Jan 13, 2025
c3df427
Merge remote-tracking branch 'origin/master' into graph-states
younik Jan 13, 2025
78b729a
fix merging issues
younik Jan 13, 2025
38dd2b0
fix toml
younik Jan 13, 2025
12c49b7
add dep & address issue
younik Jan 13, 2025
fe237ed
fix toml
younik Jan 13, 2025
9bbc48d
fix pyproject.toml
younik Jan 13, 2025
5e4fc4e
address comments
younik Jan 14, 2025
705b4cc
add tests for action
younik Jan 15, 2025
d765330
fix test after added dummy action
younik Jan 15, 2025
4ee6987
add GraphPreprocessor
younik Jan 15, 2025
fe9713c
added TODO
younik Jan 15, 2025
1425eb6
add complete masks
younik Jan 19, 2025
36c42ec
pre-commit hook
younik Jan 19, 2025
5747e97
address comments
younik Jan 19, 2025
e9f9951
pre-commit
younik Jan 19, 2025
406cfca
address comments
younik Jan 20, 2025
08e519b
fix ring example
younik Jan 24, 2025
17e07ad
make edge_index global
younik Jan 29, 2025
46e3698
make edge_index global
younik Jan 29, 2025
1c33d98
Merge remote-tracking branch 'origin/graph-states-fix' into graph-states
younik Jan 29, 2025
e6d909b
fix test_env
younik Jan 29, 2025
da66adb
add global edge + pair programming session
younik Feb 4, 2025
7130281
pair programming session
younik Feb 5, 2025
1fa01df
fix node_index
younik Feb 6, 2025
48ceafd
add comments
younik Feb 10, 2025
2ace3ca
Merge remote-tracking branch 'origin/master' into graph-states
younik Feb 10, 2025
2271732
fix linter
younik Feb 10, 2025
4be8f9b
add two tested RingPolicyEstimator
younik Feb 11, 2025
e7465b8
push tentatives
younik Feb 11, 2025
38041e8
trying TBGFN
younik Feb 12, 2025
2f4b3f9
increased capacity of RingPolicyEstimator
josephdviviano Feb 14, 2025
d0313d1
make undirected graph
younik Feb 17, 2025
f29432e
Merge remote-tracking branch 'alip67/graph-states' into graph-states
younik Feb 17, 2025
9d994da
merge over
josephdviviano Feb 18, 2025
6e1bbc2
merged
josephdviviano Feb 18, 2025
0625f5c
allows for configurable policy capacity, and the code simultaneously …
josephdviviano Feb 18, 2025
dae3299
spelling
josephdviviano Feb 18, 2025
671917b
allows for configurable policy capacity, and the code simultaneously …
josephdviviano Feb 18, 2025
623f50e
directed and undirected graphs now co-implemented -- but the model wi…
josephdviviano Feb 19, 2025
74d9b13
black
josephdviviano Feb 19, 2025
9b7ce68
change docstring and fix argument
saleml Feb 19, 2025
704cf66
add the possibility of layernorm in MLP
saleml Feb 19, 2025
7bbd72d
remove create_mlp function
saleml Feb 19, 2025
3e84ebd
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
saleml Feb 19, 2025
7bdf5b5
new architecture
josephdviviano Feb 19, 2025
7e99d68
MLP changes
josephdviviano Feb 19, 2025
af26f6e
MLP changes
josephdviviano Feb 19, 2025
a6ffc72
fix magnitude issue
younik Feb 19, 2025
8ba06a1
changed docs
josephdviviano Feb 20, 2025
201553b
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
josephdviviano Feb 20, 2025
a2d44d5
some bugfixes in extend of trajectories
josephdviviano Feb 20, 2025
bf92366
black
josephdviviano Feb 20, 2025
cfd4740
black
josephdviviano Feb 20, 2025
ab40005
removed print
josephdviviano Feb 20, 2025
8c83049
change types in env and states to allow for graph states
saleml Feb 20, 2025
7263ff2
merge recent changes
saleml Feb 20, 2025
e2ac8be
make batch_shape a property in general states, to make explicit the t…
saleml Feb 20, 2025
1543e49
fix batch_shape in GraphStates
saleml Feb 20, 2025
5dd5aba
fix batch_shape in GraphRing
saleml Feb 20, 2025
cda28e3
removed unused code from forward_masks
josephdviviano Feb 20, 2025
9ba3004
fix extends
younik Feb 20, 2025
3515c54
normal replay buffer can also be prioritized
josephdviviano Feb 22, 2025
b40118f
directed GCN working on small graphs
josephdviviano Feb 22, 2025
f7cf25e
Merge branch 'master' into pr/alip67/210
saleml Feb 23, 2025
799531f
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
saleml Feb 23, 2025
de092a0
fix extend
younik Feb 24, 2025
39f91d3
add adjacency matrix module
saleml Feb 25, 2025
80b2a9e
Simplify ring graph reward calculation logic
saleml Feb 25, 2025
c24297d
fix forward mask
younik Feb 25, 2025
3c39064
add docstrings
saleml Feb 25, 2025
e43c937
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
saleml Feb 25, 2025
4ae403f
tweaks
josephdviviano Feb 25, 2025
4a8da05
black
josephdviviano Feb 25, 2025
0fdaa3c
merged
josephdviviano Feb 25, 2025
0a1f8e1
switch from torch_geometric to TensorDict
younik Feb 26, 2025
382be45
update train_graph_ring
younik Feb 27, 2025
939bdab
fix all problems
younik Feb 27, 2025
0f57485
remove node_index
younik Feb 27, 2025
911157c
move dep
younik Feb 27, 2025
313562e
Merge branch 'graph-states' into graph-states
younik Feb 27, 2025
8806cda
fix pyproject.toml
hyeok9855 Feb 28, 2025
ce287eb
change state type hinting from Tensordict to torch_geometric Data
hyeok9855 Feb 28, 2025
7607945
add settings that achieve 95% in the ring generation (directed) with …
hyeok9855 Feb 28, 2025
e918ac7
rename test_state to test_graph_states
hyeok9855 Mar 3, 2025
13d5b6f
hyperparams
josephdviviano Mar 3, 2025
5f12f86
hyperparams
josephdviviano Mar 3, 2025
fb4445c
renamed Batch to GeometricBatch, Data to GeometricData
josephdviviano Mar 3, 2025
461fa2d
renamed Batch to GeometricBatch, Data to GeometricData
josephdviviano Mar 3, 2025
0af5747
docstring
josephdviviano Mar 4, 2025
e717d04
fixed imports
josephdviviano Mar 4, 2025
be148ae
merge
josephdviviano Mar 4, 2025
3372702
Merge pull request #250 from younik/graph-states
hyeok9855 Mar 4, 2025
3a3e41d
Merge branch 'graph-state-pyg' of github.com:GFNOrg/torchgfn into for…
josephdviviano Mar 4, 2025
8e496d8
batch_shape as tuple for all Graph things
hyeok9855 Mar 4, 2025
2b076e0
remove since it's redundant with
hyeok9855 Mar 4, 2025
eb7c9ca
apply black
hyeok9855 Mar 4, 2025
51518ab
apply black
hyeok9855 Mar 4, 2025
f90fa7f
resolve flake-8 issues
hyeok9855 Mar 4, 2025
67349b6
fix ring exp
hyeok9855 Mar 4, 2025
66ed460
a trivial change
josephdviviano Mar 4, 2025
2644a06
from_batch_shape inherited from parent class
josephdviviano Mar 4, 2025
3b44bb1
is_discrete is inherited from parent
josephdviviano Mar 4, 2025
3cf06f1
calling super instead
josephdviviano Mar 4, 2025
2e9b035
merge conflicts resolves
josephdviviano Mar 4, 2025
96394d7
black
josephdviviano Mar 4, 2025
e1fad74
black pyright and formatting
josephdviviano Mar 4, 2025
5105837
black pyright and formatting
josephdviviano Mar 4, 2025
93ebcc1
removed lines
josephdviviano Mar 4, 2025
7da0ab5
saved changes
josephdviviano Mar 4, 2025
f5b8c98
merged master
josephdviviano Mar 4, 2025
310e8d7
massive merge of pyright / black / testing errors -- one outstanding …
josephdviviano Mar 4, 2025
06f87d2
flake
josephdviviano Mar 5, 2025
b8a1aec
isort
josephdviviano Mar 5, 2025
efcb28b
isort
josephdviviano Mar 5, 2025
50b3cda
black
josephdviviano Mar 5, 2025
4abeadd
Merge pull request #253 from GFNOrg/graph-state-pyg
josephdviviano Mar 5, 2025
f3b1a17
add assertion for batch_shape in GraphStates
hyeok9855 Mar 5, 2025
2a307f1
use fixture in test_graph_states.py
hyeok9855 Mar 5, 2025
9c1a79b
add tests for and with 2d batch shape
hyeok9855 Mar 5, 2025
873981a
notes
josephdviviano Mar 5, 2025
282ebb8
remove redundancy and some minor refactorings
hyeok9855 Mar 6, 2025
afe2368
do not allow None batch_shape for env.reset
hyeok9855 Mar 6, 2025
9f6cf71
Fix the GraphStates indexing to match the torch.Tensor indexing
hyeok9855 Mar 6, 2025
1c60068
Add tests for GraphStates.log_rewards and make sure to use the setter…
hyeok9855 Mar 6, 2025
51fc0ba
Add tests for GraphStates.log_rewards (missed in the last commit)
hyeok9855 Mar 6, 2025
5f45d8f
added name back
josephdviviano Mar 6, 2025
d16aa61
cleanup of replay buffer
josephdviviano Mar 6, 2025
1c61f05
added not implemented methods
josephdviviano Mar 6, 2025
a6cd815
minor refactorings
hyeok9855 Mar 6, 2025
06cecfc
remove useless pyright ignore
hyeok9855 Mar 6, 2025
6c8b54e
fixed E704
josephdviviano Mar 6, 2025
cf14c8e
merge
josephdviviano Mar 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ einops = ">=0.6.1"
numpy = ">=1.21.2"
python = "^3.10"
torch = ">=2.6.0"
tensordict = ">=0.6.1"
torch_geometric = ">=2.6.1"

# dev dependencies.
black = { version = "24.3", optional = true }
Expand Down Expand Up @@ -99,17 +101,22 @@ include = '\.pyi?$'
extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g'''

[tool.pyright]
pythonVersion = "3.10"
include = ["src/gfn", "tutorials/examples", "testing"] # Removed ** globstars
exclude = [
"**/node_modules",
"**/__pycache__",
"**/.*", # Exclude dot files and folders
]

strict = [

]
# This is required as the CI pre-commit does not dl the module (i.e. numpy)
# Therefore, we have to ignore missing imports
# Removed "strict": [], as it's redundant with typeCheckingMode

typeCheckingMode = "basic"
pythonVersion = "3.10"

# Removed enableTypeIgnoreComments, not available in pyproject.toml, and bad practice.

Expand All @@ -124,6 +131,9 @@ reportUntypedFunctionDecorator = "none"
reportMissingTypeStubs = false
reportUnboundVariable = "warning"
reportGeneralTypeIssues = "none"
reportAttributeAccessIssue = false

[tool.pytest.ini_options]
reportOptionalMemberAccess = "error"
reportArgumentType = "error" #This setting doesn't exist, removed.

Expand Down
146 changes: 143 additions & 3 deletions src/gfn/actions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations # This allows to use the class name in type hints

import enum
from abc import ABC
from math import prod
from typing import ClassVar, List, Sequence

import torch
from tensordict import TensorDict


class Actions(ABC):
Expand Down Expand Up @@ -134,7 +136,7 @@ def extend_with_dummy_actions(self, required_first_dim: int) -> None:
"extend_with_dummy_actions is only implemented for bi-dimensional actions."
)

def compare(self, other: torch.Tensor) -> torch.Tensor:
def _compare(self, other: torch.Tensor) -> torch.Tensor:
"""Compares the actions to a tensor of actions.

Args:
Expand All @@ -161,12 +163,150 @@ def is_dummy(self) -> torch.Tensor:
dummy_actions_tensor = self.__class__.dummy_action.repeat(
*self.batch_shape, *((1,) * len(self.__class__.action_shape))
)
return self.compare(dummy_actions_tensor)
return self._compare(dummy_actions_tensor)

@property
def is_exit(self) -> torch.Tensor:
"""Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
exit_actions_tensor = self.__class__.exit_action.repeat(
*self.batch_shape, *((1,) * len(self.__class__.action_shape))
)
return self.compare(exit_actions_tensor)
return self._compare(exit_actions_tensor)


class GraphActionType(enum.IntEnum):
    """Kinds of actions available in graph-building environments.

    Integer-valued so action types can be stored directly in tensors
    (see GraphActions, which keeps them under the "action_type" key).
    """

    ADD_NODE = 0  # add a node; payload is a feature vector
    ADD_EDGE = 1  # add an edge; payload is a feature vector plus an edge_index pair
    EXIT = 2  # terminate the trajectory; carries no payload
    DUMMY = 3  # padding action (used to align batched trajectories); carries no payload


class GraphActions(Actions):
    """Actions for graph-based environments.

    Each action is one of:
    - ADD_NODE: add a node with given features
    - ADD_EDGE: add an edge between two nodes with given features
    - EXIT: terminate the trajectory
    - DUMMY: padding action used to align batched trajectories

    Attributes:
        features_dim: Dimension of the node/edge feature vectors.
        tensor: TensorDict containing:
            - "action_type": GraphActionType values, shape `batch_shape`
            - "features": node/edge features, shape `(*batch_shape, features_dim)`
            - "edge_index": source/target node pairs, shape `(*batch_shape, 2)`
    """

    features_dim: ClassVar[int]

    def __init__(self, tensor: TensorDict):
        """Initializes a GraphActions object.

        Args:
            tensor: a TensorDict with the following keys:
                - "action_type": tensor of shape `batch_shape` holding
                  GraphActionType values.
                - "features" (optional): tensor of shape
                  `(*batch_shape, features_dim)` with node/edge features. May be
                  omitted only when every action is EXIT or DUMMY; zeros are
                  substituted in that case.
                - "edge_index" (optional): tensor of shape `(*batch_shape, 2)`
                  giving the source/target node of the edge to add. May be
                  omitted only when no action is ADD_EDGE; zeros are substituted
                  in that case.
        """
        self.batch_shape = tensor["action_type"].shape

        features = tensor.get("features", None)
        if features is None:
            # Features are only meaningful for ADD_NODE / ADD_EDGE actions, so
            # they may only be omitted when every action is EXIT or DUMMY.
            assert torch.all(
                torch.logical_or(
                    tensor["action_type"] == GraphActionType.EXIT,
                    tensor["action_type"] == GraphActionType.DUMMY,
                )
            )
            features = torch.zeros((*self.batch_shape, self.features_dim))

        edge_index = tensor.get("edge_index", None)
        if edge_index is None:
            # An edge index is only meaningful for ADD_EDGE actions.
            assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE)
            edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long)

        self.tensor = TensorDict(
            {
                "action_type": tensor["action_type"],
                "features": features,
                "edge_index": edge_index,
            },
            batch_size=self.batch_shape,
        )

    def __repr__(self):
        return f"""GraphAction object with {self.batch_shape} actions."""

    def _compare(self, other: GraphActions) -> torch.Tensor:
        """Compares the actions to another GraphActions object, element-wise.

        Two actions are equal when their types match and, where the type makes
        the payload relevant, the payload matches too: features are ignored for
        EXIT/DUMMY actions, and edge_index is ignored for everything except
        ADD_EDGE.

        Args:
            other: GraphActions object to compare against (same batch_shape).

        Returns: boolean tensor of shape `batch_shape` indicating whether the
            actions are equal.
        """
        # Bug fix: the previous implementation collapsed the whole batch with
        # un-dimensioned torch.all (yielding a scalar instead of a batch_shape
        # tensor) and, due to `|`/`==` precedence, compared boolean tensors
        # against GraphActionType values. Compare element-wise instead, reducing
        # only over the trailing feature / edge-index dimension.
        action_type = self.tensor["action_type"]
        same_type = action_type == other.tensor["action_type"]
        same_features = torch.all(
            self.tensor["features"] == other.tensor["features"], dim=-1
        )
        # Payload-free action types: features do not participate in equality.
        features_irrelevant = (action_type == GraphActionType.EXIT) | (
            action_type == GraphActionType.DUMMY
        )
        same_edge = torch.all(
            self.tensor["edge_index"] == other.tensor["edge_index"], dim=-1
        )
        # Only ADD_EDGE actions carry a meaningful edge_index.
        edge_irrelevant = action_type != GraphActionType.ADD_EDGE
        return (
            same_type
            & (same_features | features_irrelevant)
            & (same_edge | edge_irrelevant)
        )

    @property
    def is_exit(self) -> torch.Tensor:
        """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
        return self.action_type == GraphActionType.EXIT

    @property
    def is_dummy(self) -> torch.Tensor:
        """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions."""
        return self.action_type == GraphActionType.DUMMY

    @property
    def action_type(self) -> torch.Tensor:
        """Returns the action type tensor (shape `batch_shape`)."""
        return self.tensor["action_type"]

    @property
    def features(self) -> torch.Tensor:
        """Returns the features tensor (shape `(*batch_shape, features_dim)`)."""
        return self.tensor["features"]

    @property
    def edge_index(self) -> torch.Tensor:
        """Returns the edge index tensor (shape `(*batch_shape, 2)`)."""
        return self.tensor["edge_index"]

    @classmethod
    def make_dummy_actions(cls, batch_shape: tuple[int, ...]) -> GraphActions:
        """Creates a GraphActions object of dummy actions with the given batch shape."""
        return cls(
            TensorDict(
                {
                    "action_type": torch.full(
                        batch_shape, fill_value=GraphActionType.DUMMY
                    ),
                },
                batch_size=batch_shape,
            )
        )

    @classmethod
    def make_exit_actions(cls, batch_shape: tuple[int, ...]) -> GraphActions:
        """Creates a GraphActions object of exit actions with the given batch shape."""
        return cls(
            TensorDict(
                {
                    "action_type": torch.full(
                        batch_shape, fill_value=GraphActionType.EXIT
                    ),
                },
                batch_size=batch_shape,
            )
        )
4 changes: 2 additions & 2 deletions src/gfn/containers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .base import Container
from .replay_buffer import PrioritizedReplayBuffer, ReplayBuffer
from .replay_buffer import NormBasedDiversePrioritizedReplayBuffer, ReplayBuffer
from .state_pairs import StatePairs
from .trajectories import Trajectories
from .transitions import Transitions

__all__ = [
"PrioritizedReplayBuffer",
"NormBasedDiversePrioritizedReplayBuffer",
"ReplayBuffer",
"StatePairs",
"Trajectories",
Expand Down
Loading