Description of feature
Currently, the GraphAnnDataModule randomly splits the PyG Data object according to pre-defined train, val and test sizes. It would be good to let the user 1) define the train, val and test sizes and 2) use group stratification.
Regarding the second point: if, for example, the Data objects come from specific groups (saved in .obs), then the Data objects of each group should be equally distributed across the train, val and test sets.
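A minimal sketch of the group stratification described above, using only the standard library. The function name `assign_group_splits`, its parameters and the default fractions are hypothetical and not part of the package; it assumes each group (e.g. an FOV) belongs to exactly one stratum (e.g. a condition).

```python
import random
from collections import defaultdict


def assign_group_splits(group_to_stratum, fractions=(0.8, 0.1, 0.1), seed=0):
    """Assign whole groups to 'train', 'val' or 'test', stratified by stratum.

    group_to_stratum: dict mapping a group id (e.g. FOV) to its stratum
    (e.g. condition). Hypothetical helper, for illustration only.
    """
    rng = random.Random(seed)

    # Collect the groups that belong to each stratum (sorted for reproducibility).
    by_stratum = defaultdict(list)
    for group, stratum in sorted(group_to_stratum.items()):
        by_stratum[stratum].append(group)

    assignment = {}
    for groups in by_stratum.values():
        rng.shuffle(groups)
        n_train = round(fractions[0] * len(groups))
        n_val = round(fractions[1] * len(groups))
        for i, group in enumerate(groups):
            if i < n_train:
                assignment[group] = "train"
            elif i < n_train + n_val:
                assignment[group] = "val"
            else:
                # Whatever remains after train and val goes to test.
                assignment[group] = "test"
    return assignment


# Toy example: 10 FOVs per condition, two conditions.
fov_condition = {f"fov_{c}_{i}": c for c in ("ctrl", "treated") for i in range(10)}
splits = assign_group_splits(fov_condition)
```

The resulting mapping could then be written back per node, for instance with something like `adata.obs['split'] = adata.obs['FOV'].map(splits)`, assuming `.obs['FOV']` holds the group id of each node.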
At the moment, to get a custom split the user first needs to split the AnnData, e.g. by adding an .obs['split'] column that assigns each node to train, val or test, then load a Data object for each split and pass them separately to the GraphAnnDataModule. For example, what I am doing at the moment is:
# add new .obs['split'] assigning each .obs['FOV'] stratified per .obs['condition'] to train, val or test dataset
split_adata(adata, split_obs = 'FOV', stratify_group = 'condition')
# load train, val and test data separately according to split variable
train_datas = load_geome(adata, split = 'train')
val_datas = load_geome(adata, split = 'val')
test_datas = load_geome(adata, split = 'test')
# create data module using train, val and test datasets
dm = GraphAnnDataModule(datas = [train_datas, val_datas, test_datas])
To reduce the number of steps, it would be nice to load the split directly through the GraphAnnDataModule with customised settings. Alternatively, one could indicate an .obs['split'] variable that should be used to split the data instead of a RandomNodeSplit.
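As a rough illustration of the second alternative, the sketch below turns a per-node split column into boolean masks named like the `train_mask`, `val_mask` and `test_mask` attributes that RandomNodeSplit adds. The helper `masks_from_split` is hypothetical and not an existing function of the package; plain Python lists stand in for the `.obs['split']` column and for the mask tensors.

```python
def masks_from_split(split_labels, names=("train", "val", "test")):
    """Build one boolean mask per split from per-node split labels.

    split_labels: sequence with one label per node, e.g. the values of
    adata.obs['split']. Hypothetical helper, for illustration only.
    """
    unknown = set(split_labels) - set(names)
    if unknown:
        raise ValueError(f"Unexpected split labels: {sorted(unknown)}")
    return {
        f"{name}_mask": [label == name for label in split_labels]
        for name in names
    }


# Toy example: four nodes with a user-defined split column.
labels = ["train", "val", "train", "test"]
masks = masks_from_split(labels)
```

In a real data module the lists would presumably be converted to boolean tensors and attached to the Data object, so that downstream code which expects the RandomNodeSplit masks keeps working unchanged.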