GraphAnnDataModule with masking option for spatial_node_loader #60

@FrancescaDr

Description of feature

When the GraphAnnDataModule is initialised, it exposes a graph_loader and a spatial_node_loader. Currently, they load the entire dataset as batches. For self-supervised learning strategies it is necessary to mask some nodes during training. Because of imbalanced datasets, the masking or sampling should take into account, e.g., the different cell type proportions. I made an initial draft for a spatial node loader that adds a .mask attribute to the PyG data object.

import random
from typing import List

import torch
from torch_geometric.data.data import BaseData
from torch_geometric.loader import DataLoader

    def smallest_data_batch_length(self, data_list: List[BaseData]) -> int:
        """Returns the number of nodes in the smallest graph from the list of BaseData."""
        return min(data.num_nodes for data in data_list)

    def _spatial_node_loader(self,
                             data_list: List[BaseData],
                             shuffle: bool = False,
                             **kwargs) -> DataLoader:
        """Adds a node mask to each Data object. TODO: load each graph multiple times with a different mask.

        Args:
        ----
        data_list: list of PyTorch Geometric Data objects
        shuffle (bool, optional): whether to shuffle the data. Defaults to False.
        kwargs: additional arguments passed to the DataLoader

        Returns
        -------
            DataLoader: the node dataloader
        """
        smallest_length = self.smallest_data_batch_length(data_list)
        num_nodes_to_mask = int(smallest_length * self.pct_mask_nodes)
        if num_nodes_to_mask == 0:  # must mask at least one node
            num_nodes_to_mask = 1

        for data in data_list:
            if data.num_nodes < num_nodes_to_mask:
                raise ValueError("Cannot sample more nodes than available in any graph.")

            # Randomly select nodes to mask
            mask_indices = random.sample(range(data.num_nodes), num_nodes_to_mask)
            data.mask = torch.zeros(data.num_nodes, dtype=torch.bool)
            data.mask[mask_indices] = True

        return DataLoader(
            dataset=data_list,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            **kwargs,
        )
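For context, the `.mask` attribute produced above would typically be consumed by a masked-node objective during training. The sketch below illustrates this with plain torch and a hypothetical stand-in `Batch` class (the real object would be a PyG batch); `masked_reconstruction_loss` and the identity "model" are illustrative names, not part of the module.

```python
import torch

# Hypothetical stand-in for a PyG Data batch: node features x and the
# boolean .mask added by _spatial_node_loader.
class Batch:
    def __init__(self, x, mask):
        self.x = x
        self.mask = mask

def masked_reconstruction_loss(model, batch):
    """Zero out the features of masked nodes, then score the model only
    on the masked positions (a common masked-autoencoding objective)."""
    x_in = batch.x.clone()
    x_in[batch.mask] = 0.0          # hide the masked nodes from the model
    x_hat = model(x_in)             # model reconstructs all node features
    return torch.nn.functional.mse_loss(x_hat[batch.mask], batch.x[batch.mask])

# Toy usage: an identity "model" over 5 nodes with 3 features.
batch = Batch(x=torch.randn(5, 3),
              mask=torch.tensor([True, False, False, True, False]))
loss = masked_reconstruction_loss(torch.nn.Identity(), batch)
```

The loss is computed only over `batch.mask`, so nodes that are never masked never contribute a training signal, which is exactly the imbalance concern raised below.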

Some challenges:

  1. My current draft does not consider the different, e.g., cell type proportions, which means that some cell types are never sampled and predicted. I can think of two solutions: (1) mask with stratification on an .obs column, or (2) remember which cells have already been sampled and next time mask cells that haven't been sampled yet. Personally, I would prefer the second option, because then all information / cells are eventually used for prediction or learning.
  2. Graphs are of different sizes. Currently I solve this by using the smallest graph as the reference and deriving the number of nodes to mask from it. The problem is that this can mean masking only 1 node per graph, even for graphs in the batch with >1k nodes.
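Both challenges could be addressed in one helper: derive the mask size per graph rather than from the smallest graph, and keep a per-graph set of already-masked indices so every node is eventually covered (option 2 of challenge 1). The sketch below is a hypothetical helper, not existing code; `pct_mask` and `seen` are assumed names.

```python
import random
import torch

def coverage_aware_mask(num_nodes: int, pct_mask: float, seen: set,
                        rng=random) -> torch.Tensor:
    """Mask a fraction of *this* graph's nodes (challenge 2), preferring
    nodes not yet masked in earlier passes (challenge 1, option 2).
    `seen` is a per-graph set carried across calls."""
    num_to_mask = max(1, int(num_nodes * pct_mask))
    unseen = [i for i in range(num_nodes) if i not in seen]
    if len(unseen) < num_to_mask:
        seen.clear()                 # full coverage reached: start a new pass
        unseen = list(range(num_nodes))
    picked = rng.sample(unseen, num_to_mask)
    seen.update(picked)
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[picked] = True
    return mask

# Toy usage: three passes over a 10-node graph, masking 30% each time;
# the three masks are disjoint until the graph is fully covered.
seen = set()
epoch_masks = [coverage_aware_mask(10, 0.3, seen) for _ in range(3)]
```

Stratification on .obs (option 1) could still be layered on top by sampling `picked` within strata, but the coverage-tracking variant alone already guarantees that no cell is permanently excluded.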

Labels: enhancement (New feature or request)