Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/cell_load/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ class ExperimentConfig:
# Fewshot perturbation assignments (dataset.celltype -> {split: [perts]})
fewshot: dict[str, dict[str, list[str]]]

# Path to h5ad CSV summary file for gene/protein embedding mapping
h5ad_csv_path: str = ""

# Thresholds for raw count heuristics
RAW_COUNT_HEURISTIC_THRESHOLD: int = 35
EXPONENTIATED_UMIS_LIMIT: int = 5_000_000

@classmethod
def from_toml(cls, toml_path: str) -> "ExperimentConfig":
"""Load configuration from TOML file."""
Expand All @@ -36,6 +43,11 @@ def from_toml(cls, toml_path: str) -> "ExperimentConfig":
training=config.get("training", {}),
zeroshot=config.get("zeroshot", {}),
fewshot=config.get("fewshot", {}),
h5ad_csv_path=config.get("h5ad_csv_path", ""),
RAW_COUNT_HEURISTIC_THRESHOLD=config.get(
"RAW_COUNT_HEURISTIC_THRESHOLD", 1000
),
EXPONENTIATED_UMIS_LIMIT=config.get("EXPONENTIATED_UMIS_LIMIT", 1000000),
)

def get_all_datasets(self) -> Set[str]:
Expand Down
59 changes: 59 additions & 0 deletions src/cell_load/data_modules/cell_sentence_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from torch.utils.data import DataLoader
from cell_load.dataset.cell_sentence_dataset import FilteredGenesCounts
from cell_load.dataset.cell_sentence_dataset import CellSentenceCollator


def create_dataloader(
cfg,
workers=1,
data_dir=None,
datasets=None,
shape_dict=None,
adata=None,
adata_name=None,
shuffle=False,
sentence_collator=None,
):
"""
Expected to be used for inference
Either datasets and shape_dict or adata and adata_name should be provided
"""
if datasets is None and adata is None:
raise ValueError(
"Either datasets and shape_dict or adata and adata_name should be provided"
)

if adata is not None:
shuffle = False

if data_dir:
cfg.model.data_dir = data_dir
# ? utils.get_dataset_cfg(cfg).data_dir = data_dir

dataset = FilteredGenesCounts(
cfg,
datasets=datasets,
shape_dict=shape_dict,
adata=adata,
adata_name=adata_name,
)
if sentence_collator is None:
sentence_collator = CellSentenceCollator(
cfg,
valid_gene_mask=dataset.valid_gene_index,
ds_emb_mapping_inference=dataset.ds_emb_map,
is_train=False,
)

# validation should not use cell augmentations
sentence_collator.training = False

dataloader = DataLoader(
dataset,
batch_size=cfg.model.batch_size,
shuffle=shuffle,
collate_fn=sentence_collator,
num_workers=workers,
persistent_workers=True,
)
return dataloader
Loading