
GraphAnnDataModule with custom split #57

@FrancescaDr


Description of feature

Currently, the GraphAnnDataModule splits the PyG Data object randomly according to predefined train, val and test sizes. It would be good to let the user 1) set the train, val and test sizes and 2) stratify the split by group.

Regarding 2): if the Data objects come from specific groups (stored in .obs), the Data objects of each group should be distributed evenly across the train, val and test sets, so that every group is represented in each set.

At the moment, to get a custom split the user first needs to split the AnnData, e.g. by adding a .obs['split'] column that assigns each node to train, val or test, then load a Data object for each split and pass them separately to the GraphAnnDataModule. For example, this is what I currently do:

```python
# add a new .obs['split'] column assigning each .obs['FOV'], stratified by .obs['condition'], to train, val or test
split_adata(adata, split_obs='FOV', stratify_group='condition')
# load the train, val and test data separately according to the split column
train_datas = load_geome(adata, split='train')
val_datas = load_geome(adata, split='val')
test_datas = load_geome(adata, split='test')
# create the data module from the train, val and test datasets
dm = GraphAnnDataModule(datas=[train_datas, val_datas, test_datas])
```
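The issue does not show `split_adata`. The following is only a minimal sketch of what such a helper might do, assuming each FOV belongs to exactly one condition. The name `stratified_group_split` and its `fractions` and `seed` parameters are hypothetical and not part of any existing package. It assigns whole FOVs (so no FOV is split across sets) and shuffles within each condition separately, so every condition contributes FOVs to train, val and test in roughly the requested proportions.

```python
import random
from collections import defaultdict


def stratified_group_split(groups, strata, fractions=(0.7, 0.15, 0.15), seed=0):
    """Hypothetical helper: assign each group (e.g. an FOV) to 'train', 'val' or 'test'.

    `groups` and `strata` are parallel sequences with one entry per cell,
    e.g. adata.obs['FOV'] and adata.obs['condition']. Whole groups are
    assigned, and the assignment is done separately within each stratum.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")

    # Record the stratum of each group; assumes a group never spans two strata.
    group_to_stratum = {}
    for g, s in zip(groups, strata):
        if group_to_stratum.setdefault(g, s) != s:
            raise ValueError(f"group {g!r} spans several strata")

    # Collect the groups belonging to each stratum.
    by_stratum = defaultdict(list)
    for g, s in group_to_stratum.items():
        by_stratum[s].append(g)

    rng = random.Random(seed)
    assignment = {}
    for s in sorted(by_stratum, key=str):
        # Sort first so the shuffle is reproducible for a given seed.
        members = sorted(by_stratum[s], key=str)
        rng.shuffle(members)
        n = len(members)
        n_train = round(fractions[0] * n)
        # Guard against rounding pushing train + val past n; the rest is test.
        n_val = min(round(fractions[1] * n), n - n_train)
        for i, g in enumerate(members):
            if i < n_train:
                assignment[g] = "train"
            elif i < n_train + n_val:
                assignment[g] = "val"
            else:
                assignment[g] = "test"
    return assignment
```

Assuming `adata` is an AnnData object, the result could be written back as a split column with `split = stratified_group_split(adata.obs['FOV'], adata.obs['condition'])` followed by `adata.obs['split'] = adata.obs['FOV'].map(split)`. Because whole FOVs are assigned and counts are rounded per condition, a condition with only a few FOVs may end up with no FOV in val or test.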

To reduce the number of steps, it would be nice to configure the split directly in the GraphAnnDataModule. Alternatively, the user could name a .obs['split'] column that the module uses to split the datas instead of applying a RandomNodeSplit.
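For the alternative, one possible sketch turns a per-node .obs['split'] column into boolean train/val/test masks, mirroring the train_mask, val_mask and test_mask attributes that PyG's RandomNodeSplit transform attaches to a Data object. It assumes the node order of the Data object matches the row order of adata.obs; `split_masks` is a hypothetical helper name, not an existing API.

```python
def split_masks(split_labels, names=("train", "val", "test")):
    """Hypothetical helper: build one boolean mask per split from per-node labels.

    `split_labels` is one label per node, e.g. adata.obs['split'], in the
    same order as the nodes of the Data object.
    """
    labels = list(split_labels)
    # Reject labels that would silently end up in no mask.
    unknown = set(labels) - set(names)
    if unknown:
        raise ValueError(f"unexpected split labels: {sorted(map(str, unknown))}")
    # e.g. {'train_mask': [...], 'val_mask': [...], 'test_mask': [...]}
    return {f"{name}_mask": [label == name for label in labels] for name in names}
```

With PyTorch installed, each list could then be attached to the Data object, e.g. `data.train_mask = torch.tensor(masks['train_mask'])`, so downstream code that expects RandomNodeSplit's masks keeps working.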

Labels: enhancement (New feature or request)
