switch from TensorDict to torch_geometric
This issue is resolved in the following commits! While adding more tests, I found that the desired `batch_shape` after indexing is ambiguous:

```python
a = torch.tensor([1, 2, 3])
# a.shape: torch.Size([3])
# a[0].shape: torch.Size([])

b = torch.tensor([[1, 2], [4, 5]])
# b.shape: torch.Size([2, 2])
# b[0].shape: torch.Size([2])
# b[1, 1].shape: torch.Size([])

states = GraphState(...)
states.batch_shape = (3,)
# states[0].batch_shape = (1,)  <--- rather than () or None

states_2d = GraphState(...)
states_2d.batch_shape = (2, 2)
# states_2d[0].batch_shape = ???     <--- What's the desired shape?
# states_2d[1, 1].batch_shape = ???  <--- What's the desired shape?
```

My last commit assumes … Any suggestions on this?
Thanks @hyeok9855 - good catch. I think we should match the behaviour of indexing a tensor exactly, so that downstream functions can be agnostic (as much as possible) as to whether they are handling a regular `States` or a `GraphStates`.
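For reference, a minimal runnable sketch of the tensor behaviour being proposed as the target (the `states_2d` expectations in the comments are illustrative, since no `GraphStates` object is constructed here):

```python
import torch

# Plain-tensor indexing, which GraphStates.batch_shape would mirror exactly.
b = torch.tensor([[1, 2], [4, 5]])    # analogue of batch_shape == (2, 2)
print(b[0].shape)     # torch.Size([2])  -> states_2d[0].batch_shape == (2,)
print(b[1, 1].shape)  # torch.Size([])   -> states_2d[1, 1].batch_shape == ()
```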
josephdviviano left a comment:
I've checked this over and pushed a few minuscule changes. Otherwise LGTM.
hyeok9855 left a comment:
Approved!
I left a few concerns, which I think we should address promptly in follow-up PRs.
```python
assert isinstance(
    self.env.s0, torch.Tensor
), "not supported for Graph trajectories."
assert isinstance(
    self.env.sf, torch.Tensor
), "not supported for Graph trajectories."
```
TODO: To support GraphStates, maybe utilize States.__repr__?
```python
assert isinstance(
    self.env.s0, torch.Tensor
), "reverse_backward_trajectories not supported for Graph trajectories"
```
TODO: To support GraphStates and GraphAction, I think we should implement .transpose in each state or action class.
Could you elaborate on this one please?
I'm concerned that .transpose() on GraphStates might be a very slow operation (will essentially need a for loop if I understand correctly). We could do it, but it would be the kind of operation we shouldn't encourage... I'm not sure if it's a good idea for this reason.
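To illustrate the concern, here is a rough sketch (assuming graph states stored as a flat torch_geometric `Batch` in (time, trajectory) row-major order; `T`, `N`, and `make_graph` are made up for the example) of what a transpose would entail: every graph gets unpacked and the batch rebuilt with a Python-level loop.

```python
import torch
from torch_geometric.data import Batch, Data

def make_graph() -> Data:
    # Toy graph standing in for one state.
    return Data(
        x=torch.randn(3, 4),
        edge_index=torch.tensor([[0, 1], [1, 2]], dtype=torch.long),
    )

T, N = 5, 8  # (num_timesteps, num_trajectories), flattened row-major
flat = Batch.from_data_list([make_graph() for _ in range(T * N)])

# "Transposing" to (N, T) order: unpack every Data object and re-collate.
graphs = flat.to_data_list()
transposed = Batch.from_data_list(
    [graphs[t * N + n] for n in range(N) for t in range(T)]
)
```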
```python
@property
def batch_shape(self) -> tuple[int, ...]:
    return self._batch_shape

@batch_shape.setter
def batch_shape(self, batch_shape: tuple[int, ...]) -> None:
    self._batch_shape = batch_shape
```
I think the .batch_shape should be equivalent to .tensor.shape or .tensor.batch_shape for graphs.
If so, why do we need this setter for the batch_shape?
My suggestion is:
```python
@property
def batch_shape(self) -> tuple[int, ...]:
    return tuple(self.tensor.shape)[: -len(self.state_shape)]
```
This is a good point. I'm not sure why we have a setter for this. I suggest we investigate this in another issue.
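As a sanity check on the suggestion above, a small sketch (assuming a dense tensor-backed `States` with `state_shape == (2,)`; the variable names are illustrative):

```python
import torch

state_shape = (2,)
tensor = torch.zeros(3, 4, *state_shape)  # a (3, 4) batch of 2-dimensional states

# batch_shape derived from the tensor itself, so no setter is required.
batch_shape = tuple(tensor.shape)[: -len(state_shape)]
print(batch_shape)  # (3, 4)
```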
saleml left a comment:
Thank you for the great work.
A continuation of the PR here:
#210
Description:
Unlike the current States object, which necessitates appending dummy states to batch trajectories of varying lengths, our approach aims to support Trajectories through a nested Batch object representation. In Torch Geometric, the Data class represents the graph structure, while the Batch class, which encapsulates the batching of Data objects and their efficient indexing, represents the GraphStates object.
The current implementation of Trajectory supports the indexing dimensions (Num timesteps, Num trajectories, State size). By using a nested Batch-of-Batch object to represent state Trajectories, the indexing would inherently take the form (Num trajectories, Num timesteps, State size). This approach requires implementing logic within `__getitem__()` and `__setitem__()` to handle the reordering internally.
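A minimal sketch of the nested representation described above (assuming torch_geometric; `random_graph` and the trajectory lengths are illustrative, not part of the library):

```python
import torch
from torch_geometric.data import Batch, Data

def random_graph(num_nodes: int = 4) -> Data:
    # Toy graph standing in for one state.
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long)
    return Data(x=torch.randn(num_nodes, 8), edge_index=edge_index)

# One Batch per trajectory, collating that trajectory's states along time.
# Trajectories of different lengths need no dummy padding states.
trajectories = [
    Batch.from_data_list([random_graph() for _ in range(num_steps)])
    for num_steps in (3, 5)
]

# Indexing naturally takes the form (Num trajectories, Num timesteps, ...):
state = trajectories[1][2]  # third state of the second trajectory (a Data object)
print(state.x.shape)        # torch.Size([4, 8])
```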