Open
Labels
enhancement (New feature or request)
Description of feature
When the GraphAnnDataModule is initialised, it provides a graph_loader and a spatial_node_loader. Currently, both load the entire dataset as batches. Self-supervised learning strategies need some nodes to be masked during training. Because the datasets are imbalanced, the masking or sampling should take into account the differing proportions of, e.g., cell types. I made an initial draft of a spatial node loader that adds a .mask attribute to each PyG data object.
# Imports this draft relies on (not shown in the original snippet; assuming
# DataLoader is the PyG one from torch_geometric.loader).
import random
from typing import List

import torch
from torch_geometric.data.data import BaseData
from torch_geometric.loader import DataLoader


def smallest_data_batch_length(self, data_list: List[BaseData]) -> int:
    """Returns the number of nodes in the smallest graph from the list of BaseData."""
    return min(data.num_nodes for data in data_list)


def _spatial_node_loader(self,
                         data_list: List[BaseData],
                         shuffle: bool = False,
                         **kwargs) -> DataLoader:
    """Adds a boolean node mask to each Data object. TODO: load each graph multiple times with a different mask.

    Args:
    ----
        data_list: list of PyTorch Geometric data objects, one per graph
        shuffle (bool, optional): whether to shuffle the data. Defaults to False.
        kwargs: arguments passed to the DataLoader

    Returns
    -------
        DataLoader: the node dataloader
    """
    # The mask size is derived from the smallest graph and reused for every graph.
    smallest_length = self.smallest_data_batch_length(data_list)
    num_nodes_to_mask = int(smallest_length * self.pct_mask_nodes)
    if num_nodes_to_mask == 0:  # must mask at least one node
        num_nodes_to_mask = 1

    for data in data_list:
        if data.num_nodes < num_nodes_to_mask:
            raise ValueError("Cannot sample more nodes than available in any graph.")
        # Randomly select the nodes to mask (modifies each data object in place)
        mask_indices = random.sample(range(data.num_nodes), num_nodes_to_mask)
        data.mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        data.mask[mask_indices] = True

    return DataLoader(
        dataset=data_list,
        shuffle=shuffle,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        # collate_fn=collate_fn,
        **kwargs,
    )
Some challenges:
- My current draft does not consider the differing proportions of, e.g., cell types, which means that some cell types may never be sampled and predicted. I can think of two solutions: 1. stratify the mask on a column in .obs, or 2. remember the cells that have already been sampled so that the next mask covers cells that haven't been sampled yet. Personally, I would prefer the second option because then all information / cells are used for prediction or learning.
- Graphs are of different sizes. Currently I handle this by using the smallest graph as the reference and drawing the same number of random indices for every graph. The problem is that if the smallest graph is tiny, the mask covers only 1 node per graph, even when other graphs in the batch have >1k nodes.
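To make the two options and the graph-size problem concrete, here is a rough sketch, not a proposal for the final implementation. The names `num_to_mask`, `stratified_mask_indices` and `CoverageMasker` are hypothetical and do not exist in the package. The sketch assumes the mask size is computed per graph from that graph's own node count rather than from the smallest graph, and it uses only the standard library so it can be tried in isolation.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence, Set


def num_to_mask(num_nodes: int, pct: float) -> int:
    """Per-graph mask size: a fraction of this graph's own nodes, at least one."""
    if num_nodes < 1:
        raise ValueError("Cannot mask nodes in an empty graph.")
    if not 0 < pct <= 1:
        raise ValueError("pct must be in (0, 1].")
    return max(1, int(num_nodes * pct))


def stratified_mask_indices(labels: Sequence[Hashable], pct: float,
                            rng: random.Random) -> List[int]:
    """Option 1: mask the same fraction of every label (e.g. cell type)."""
    by_label: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    picked: List[int] = []
    for members in by_label.values():
        # Every label contributes at least one node, so rare types are not skipped.
        picked.extend(rng.sample(members, num_to_mask(len(members), pct)))
    return sorted(picked)


class CoverageMasker:
    """Option 2: prefer nodes of one graph that have not been masked yet."""

    def __init__(self, num_nodes: int, pct: float, seed: Optional[int] = None):
        self.num_nodes = num_nodes
        self.k = num_to_mask(num_nodes, pct)
        self.rng = random.Random(seed)
        self.unseen: Set[int] = set(range(num_nodes))

    def next_indices(self) -> List[int]:
        """Return k distinct node indices, drawing from unmasked nodes first."""
        pool = sorted(self.unseen)
        picked = self.rng.sample(pool, min(self.k, len(pool)))
        self.unseen.difference_update(picked)
        if len(picked) < self.k:
            # Pool exhausted: start a new pass and top up from the other nodes.
            chosen = set(picked)
            rest = [i for i in range(self.num_nodes) if i not in chosen]
            extra = self.rng.sample(rest, self.k - len(picked))
            picked.extend(extra)
            self.unseen = set(rest) - set(extra)
        if not self.unseen:
            # Every node has been masked once; begin the next pass.
            self.unseen = set(range(self.num_nodes))
        return sorted(picked)
```

In the loader, one `CoverageMasker` would be kept per graph, and the returned indices would be used exactly like `mask_indices` in the draft above (`data.mask[mask_indices] = True`). With option 2, a graph with 1k nodes and a graph with 10 nodes each mask the same fraction of their own nodes, and every node is masked once before any node is masked a second time.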