diff --git a/marc_db/ingest.py b/marc_db/ingest.py index 7ffcda0..4c444cd 100644 --- a/marc_db/ingest.py +++ b/marc_db/ingest.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, Iterable, Optional, Tuple, Callable +from typing import Callable, Dict, Iterable, Optional, Union import pandas as pd from sqlalchemy.orm import Session @@ -38,6 +38,12 @@ def _ensure_required_columns(df: pd.DataFrame, required: Iterable[str]): raise ValueError(f"Missing required column(s): {', '.join(missing)}") +def _load_dataframe(data: Optional[Union[pd.DataFrame, Path, str]]) -> Optional[pd.DataFrame]: + if data is None or isinstance(data, pd.DataFrame): + return data + return pd.read_csv(Path(data), sep="\t") + + def _ingest_isolates(df: pd.DataFrame, session: Session): isolate_cols = [ "SampleID", @@ -202,6 +208,12 @@ def ingest_from_tsvs( """ created_session = False + isolates = _load_dataframe(isolates) + assemblies = _load_dataframe(assemblies) + assembly_qcs = _load_dataframe(assembly_qcs) + taxonomic_assignments = _load_dataframe(taxonomic_assignments) + contaminants = _load_dataframe(contaminants) + antimicrobials = _load_dataframe(antimicrobials) if session is None: session = get_session() created_session = True diff --git a/tests/test_ingest.py b/tests/test_ingest.py index 015badd..4944f71 100644 --- a/tests/test_ingest.py +++ b/tests/test_ingest.py @@ -57,3 +57,19 @@ def test_conflicting_duplicate_rows(): session.close() engine.dispose() + + +def test_ingest_accepts_path_strings(): + engine = create_engine("sqlite:///:memory:") + Session = sessionmaker(bind=engine) + session = Session() + Base.metadata.create_all(engine) + + tsv_path = str(data_dir / "test_multi_aliquot.tsv") + ingest_from_tsvs(isolates=tsv_path, yes=True, session=session) + + assert len(get_isolates(session)) == 2 + assert len(get_aliquots(session)) == 5 + + session.close() + engine.dispose()