
Graph states#252

Merged
josephdviviano merged 171 commits into master from graph-states
Mar 7, 2025

Conversation

@josephdviviano
Collaborator

A continuation of the PR here:

#210

Description:
Unlike the current States object that necessitates appending dummy states to batch trajectories of varying lengths, our approach aims to support Trajectories through a nested Batch object representation. The Data class in Torch Geometric represents the graph structure, while the Batch class, which encapsulates batching of Data objects and their efficient indexing, represents the GraphStates object.

The current implementation of Trajectory supports the indexing dimensions (Num time steps, Num trajectories, State size). By using a nested Batch-of-Batch object to represent state Trajectories, the indexing would inherently take the form (Num trajectories, Num time steps, State size). This approach requires implementing logic within __getitem__() and __setitem__() to handle the reordering internally.
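The shape of this design can be sketched with plain Python lists standing in for PyTorch Geometric Data/Batch objects (the graph placeholders below are hypothetical, for illustration only):

```python
# Sketch only: strings stand in for per-step graph Data objects, and the
# outer list stands in for the nested Batch. Trajectories of different
# lengths coexist without dummy padding states.
trajectories = [
    ["g0_t0", "g0_t1", "g0_t2"],  # trajectory 0: three time steps
    ["g1_t0", "g1_t1"],           # trajectory 1: two time steps, no padding
]

# Indexing naturally takes the form (num_trajectories, num_timesteps):
first_traj_last_state = trajectories[0][2]
second_traj_num_steps = len(trajectories[1])
```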

@hyeok9855
Collaborator

hyeok9855 commented Mar 5, 2025

This issue is resolved in the following commits!


While adding more tests, I found that GraphState's __getitem__ and __setitem__ behave differently from torch.Tensor's. Specifically, ours keeps dimensions even when indexing with an integer.
For example,

a = torch.tensor([1, 2, 3])
# a.shape: torch.Size([3])
# a[0].shape: torch.Size([])

b = torch.tensor([[1, 2], [4, 5]])
# b.shape: torch.Size([2, 2])
# b[0].shape: torch.Size([2])
# b[1, 1].shape: torch.Size([])

states = GraphState(...)
states.batch_shape = (3,)
# states[0].batch_shape = (1,)   <--- rather than () or None

states_2d = GraphState(...)
states_2d.batch_shape = (2, 2)
# states_2d[0].batch_shape = ???   <--- What's the desired shape?
# states_2d[1, 1].batch_shape = ???   <--- What's the desired shape?

My last commit assumes states_2d[0].batch_shape = (1, 2) and states_2d[1, 1].batch_shape = (1,), but this doesn't seem to be an optimal design choice.

Any suggestions on this?

@josephdviviano
Collaborator Author

Thanks @hyeok9855 - good catch. I think we should match the behaviour of indexing a tensor exactly, so that downstream functions can be agnostic (as much as possible) as to whether self.tensor is a torch.Tensor or a GeometricBatch / GeometricData instance.
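One way to match tensor semantics is to compute the post-indexing batch shape with exactly the rule torch.Tensor uses: integer indices drop a dimension, slices keep it. A minimal, library-free sketch of that rule (the helper name `indexed_batch_shape` is hypothetical, not part of the codebase):

```python
def indexed_batch_shape(batch_shape, index):
    """Batch shape after tensor-style indexing: ints drop a dim, slices keep it."""
    if not isinstance(index, tuple):
        index = (index,)
    remaining = list(batch_shape)
    result = []
    for idx in index:
        dim = remaining.pop(0)
        if isinstance(idx, int):
            continue  # integer indexing removes this dimension
        elif isinstance(idx, slice):
            # slice keeps the dimension, with its resolved length
            result.append(len(range(*idx.indices(dim))))
        else:
            raise NotImplementedError(f"unsupported index type: {type(idx)}")
    return tuple(result + remaining)

# Matches torch.Tensor behaviour:
# indexed_batch_shape((2, 2), 0)          -> (2,)
# indexed_batch_shape((2, 2), (1, 1))     -> ()
# indexed_batch_shape((3,), slice(0, 1))  -> (1,)
```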

@josephdviviano (Collaborator, Author) left a comment

I've checked this over and pushed a few minuscule changes. Otherwise LGTM.

@hyeok9855 (Collaborator) left a comment

Approved!

I left a few concerns, which I think we should resolve promptly in follow-up PRs.

Comment on lines +125 to +131
assert isinstance(
    self.env.s0, torch.Tensor
), "not supported for Graph trajectories."
assert isinstance(
    self.env.sf, torch.Tensor
), "not supported for Graph trajectories."

Collaborator

TODO: To support GraphStates, maybe utilize States.__repr__?

Comment on lines +500 to +502
assert isinstance(
    self.env.s0, torch.Tensor
), "reverse_backward_trajectories not supported for Graph trajectories"
Collaborator

TODO: To support GraphStates and GraphAction, I think we should implement .transpose in each state or action class.

Collaborator

Could you elaborate on this one please?

Collaborator Author

I'm concerned that .transpose() on GraphStates might be a very slow operation (will essentially need a for loop if I understand correctly). We could do it, but it would be the kind of operation we shouldn't encourage... I'm not sure if it's a good idea for this reason.
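For intuition, a transpose of a nested batch has to re-gather every element, since there is no strided-view shortcut as with a torch.Tensor. A sketch with placeholder strings standing in for graph objects (names hypothetical):

```python
# A (num_timesteps=2, num_trajectories=3) step-major nested layout.
steps = [
    ["g0_t0", "g1_t0", "g2_t0"],
    ["g0_t1", "g1_t1", "g2_t1"],
]

# Transposing to (num_trajectories, num_timesteps) must touch every element:
# O(T * N) gathers, each of which would copy/rebatch a graph in practice.
transposed = [
    [steps[t][n] for t in range(len(steps))]
    for n in range(len(steps[0]))
]
```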

Comment on lines +78 to +84
@property
def batch_shape(self) -> tuple[int, ...]:
    return self._batch_shape

@batch_shape.setter
def batch_shape(self, batch_shape: tuple[int, ...]) -> None:
    self._batch_shape = batch_shape
Collaborator

I think .batch_shape should be equivalent to .tensor.shape (or .tensor.batch_shape for graphs).
If so, why do we need a setter for batch_shape?
My suggestion is:

@property
def batch_shape(self) -> tuple[int, ...]:
    return tuple(self.tensor.shape)[: -len(self.state_shape)]
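For illustration, with hypothetical shapes, the suggested property derives the batch shape by stripping the trailing state dimensions from the tensor's shape:

```python
# Hypothetical shapes: a tensor of shape (3, 4, 2, 2) whose trailing (2, 2)
# dims are the per-state shape, leaving (3, 4) as the derived batch shape.
tensor_shape = (3, 4, 2, 2)
state_shape = (2, 2)
batch_shape = tuple(tensor_shape)[: -len(state_shape)]
# batch_shape == (3, 4)
```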

Collaborator

This is a good point. I'm not sure why we have a setter for this. I suggest we investigate it in another issue.

@saleml (Collaborator) left a comment

Thank you for the great work.

@josephdviviano josephdviviano merged commit 8725410 into master Mar 7, 2025
2 checks passed
@josephdviviano josephdviviano deleted the graph-states branch March 11, 2025 15:48


5 participants