diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c78d30..4b0fcda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,44 @@ All notable changes to this project will be documented in this file. +## [1.0.0] - 2025-09-04 + +### Bug Fixes + +- Change `info` to `core` for mid-level lydata columns. +- Use bug-fixed lydata `.ly.enhance()` method. + +### Documentation + +- Update documentation for `integrate` and `evidence` commands. +- Improve `data collect` description. + +### Features + +- Add `integrate` command for thermodynamic integration. Thanks [@noemibuehrer]! +- Add command spawning a web server for interactive data collection. + +### Miscellaneous Tasks + +- Add missing links to changelog. +- Add CITATION.cff. + +### Testing + +- Update tests for new lydata. + +### Build + +- Add uvicorn, fastapi to deps. +- Require at least lydata 0.4.0. + +### Change + +- Make compatible with new lyDATA version. +- Centralize inverse temperature schedule generation. +- Store selected log-level globally. +- Disable properties in collector. + ## [1.0.0rc3] - 2025-07-22 ### Documentation @@ -356,7 +394,7 @@ over diagnosis times cannot be converted to a `DistributionConfig`. BREAKING CHANGES: `generate` command is better configurable - (**config**) Merge sample/sampling configs. - Use lydata's `ModalityConfig`.\ - Since the [lydata](https://github.com/rmnldwg/lydata) package is + Since the [lydata](https://github.com/lycosystem/lydata) package is evolving quickly, I added it as a dependency and moved the first bit of code over there. - Enable use of lydata to load patient data. @@ -871,6 +909,8 @@ returns `None` instead. Fixes [#11] ## [0.5.3] - 2022-08-22 +[1.0.0]: https://github.com/lycosystem/lyscripts/compare/1.0.0rc3...1.0.0 +[1.0.0rc3]: https://github.com/lycosystem/lyscripts/compare/1.0.0rc2...1.0.0rc3 [1.0.0rc2]: https://github.com/lycosystem/lyscripts/compare/1.0.0rc1...1.0.0rc2 [1.0.0rc1]: https://github.com/lycosystem/lyscripts/compare/1.0.0.a7...1.0.0rc1 [1.0.0.a7]: https://github.com/lycosystem/lyscripts/compare/1.0.0.a6...1.0.0.a7 @@ -931,8 +971,10 @@ returns `None` instead. Fixes [#11] [#70]: https://github.com/lycosystem/lyscripts/issues/70 [#72]: https://github.com/lycosystem/lyscripts/issues/72 [#74]: https://github.com/lycosystem/lyscripts/issues/74 +[#75]: https://github.com/lycosystem/lyscripts/issues/75 [`emcee`]: https://emcee.readthedocs.io/en/stable/ [`rich`]: https://rich.readthedocs.io/en/latest/ [`rich_argparse`]: https://github.com/hamdanal/rich_argparse [LyProX]: https://lyprox.org +[@noemibuehrer]: https://github.com/noemibuehrer diff --git a/CITATION.cff b/CITATION.cff new file mode 100755 index 0000000..70404a2 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,25 @@ +# This CITATION.cff file was generated with cffinit. +# Visit https://bit.ly/cffinit to generate yours today! + +cff-version: 1.2.0 +title: lyscripts +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - given-names: Roman + family-names: Ludwig + orcid: 'https://orcid.org/0000-0001-9434-328X' + affiliation: University Hospital Zurich +repository-code: 'https://github.com/lycosystem/lyscripts' +url: 'https://lyscripts.readthedocs.io' +abstract: >- + Scripts for reproducible research on lymphatic tumor + progression in head and neck cancer.
+keywords: + - cancer + - metastasis + - lymphatic system + - head and neck +license: MIT diff --git a/README.md b/README.md index e757f19..e64323b 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ ## What are these `lyscripts`? -This package provides convenient scripts for performing inference and learning regarding the lymphatic spread of head & neck cancer. Essentially, it provides a *command line interface* (CLI) to the [lymph](https://github.com/lycosystem/lymph) library and the [lydata](https://github.com/rmnldwg/lydata) repository that stores lymphatic progression data. +This package provides convenient scripts for performing inference and learning regarding the lymphatic spread of head & neck cancer. Essentially, it provides a *command line interface* (CLI) to the [lymph](https://github.com/lycosystem/lymph) library and the [lydata](https://github.com/lycosystem/lydata) repository that stores lymphatic progression data. We are making these "convenience" scripts public, because doing so is one necessary requirement to making our research easily and fully reproducible. There exists another repository, [lynference](https://github.com/lycosystem/lynference), where we stored the pipelines that produced our published results in a persistent way. @@ -41,7 +41,11 @@ Simply run lyscripts --help ``` -in your terminal and let the output guide you through the functions of the program. +in your terminal to display the help text for the main command. It will list all available subcommands, each of which can in turn be called with `--help` (e.g., `lyscripts data --help`) to get more information on its use and arguments. + +For example, one subcommand is `lyscripts data collect`, which launches a small web server where a user can enter patient records on lymphatic involvement in head and neck cancer one row at a time and then download them as a standardized CSV file. + + You can also refer to the [documentation] for a written-down version of all these help texts and even more context on how and why to use the provided commands. diff --git a/docs/source/data/collect.rst b/docs/source/data/collect.rst new file mode 100644 index 0000000..8c8b08b --- /dev/null +++ b/docs/source/data/collect.rst @@ -0,0 +1,13 @@ +.. currentmodule:: lyscripts.data.collect + +Collect lyDATA Tables Interactively +=================================== + +.. automodule:: lyscripts.data.collect + :members: + :show-inheritance: + +Command Help +------------ + +.. program-output:: lyscripts data collect --help diff --git a/docs/source/data/init.rst b/docs/source/data/init.rst index e53ee60..014bb85 100644 --- a/docs/source/data/init.rst +++ b/docs/source/data/init.rst @@ -19,6 +19,7 @@ Submodules .. toctree:: :maxdepth: 1 + collect lyproxify join split diff --git a/docs/source/index.rst b/docs/source/index.rst index 1f36252..647e4e1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -8,7 +8,7 @@ Introduction
.. include:: ../../README.md - :end-line: 30 + :end-before: :parser: myst_parser.sphinx_ diff --git a/pyproject.toml b/pyproject.toml index f3f409b..1f095be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,10 @@ dependencies = [ "pydantic", "pydantic-settings >= 2.7.0, != 2.9.1, != 2.9.0", "numpydantic", - "lydata >= 0.3.3", "loguru", + "fastapi", + "uvicorn", + "lydata >= 0.4.0", ] dynamic = ["version"] @@ -81,9 +83,18 @@ dev = [ [project.scripts] lyscripts = "lyscripts:main" +[tool.setuptools] +include-package-data = true + [tool.setuptools.packages.find] where = ["src"] +[tool.setuptools.package-data] +"lyscripts" = [ + "data/collect/collector.js", + "data/collect/index.html", +] + [tool.setuptools_scm] write_to = "src/lyscripts/_version.py" local_scheme = "no-local-version" @@ -191,10 +202,7 @@ skip_tags = "v0.1.0-beta.1" ignore_tags = "" # sort the tags topologically topo_order = false -# sort the commits inside sections by oldest/newest order -sort_commits = "oldest" - -[tool.uv.sources] -lydata = { path = "../lydata-package", editable = true } # limit the number of commits included in the changelog. # limit_commits = 42 +# sort the commits inside sections by oldest/newest order +sort_commits = "oldest" diff --git a/schemas/ly.json b/schemas/ly.json index d14eb3f..561127c 100644 --- a/schemas/ly.json +++ b/schemas/ly.json @@ -308,18 +308,46 @@ "type": "string" }, "repo_name": { - "default": "rmnldwg/lydata", + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": "lycosystem/lydata", "description": "GitHub `repository/owner`.", - "minLength": 1, - "title": "Repo Name", - "type": "string" + "title": "Repo Name" }, "ref": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], "default": "main", "description": "Branch/tag/commit of the repo.", - "minLength": 1, - "title": "Ref", - "type": "string" + "title": "Ref" + }, + "local_dataset_dir": { + "anyOf": [ + { + "format": "directory-path", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Path to directory containing all the dataset subdirectories. So, e.g. if `local_dataset_dir` is `~/datasets` and the dataset is `2023-clb-multisite`, then the CSV file is expected to be at `~/datasets/2023-clb-multisite/data.csv`.", + "title": "Local Dataset Dir" + } }, "required": [ @@ -501,6 +529,19 @@ "title": "Relative Thresh", "type": "number" }, + "burnin_steps": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of burn-in steps to take. 
If None, burn-in runs until convergence.", + "title": "Burnin Steps" + }, "num_steps": { "anyOf": [ { @@ -623,6 +664,52 @@ ], "title": "ScenarioConfig", "type": "object" + }, + "ScheduleConfig": { + "description": "Configuration for generating a schedule of inverse temperatures.", + "properties": { + "method": { + "default": "power", + "description": "Method to generate the inverse temperature schedule.", + "enum": [ + "geometric", + "linear", + "power" + ], + "title": "Method", + "type": "string" + }, + "num": { + "default": 32, + "description": "Number of inverse temperatures in the schedule.", + "title": "Num", + "type": "integer" + }, + "power": { + "default": 4.0, + "description": "If a power schedule is chosen, use this as power.", + "title": "Power", + "type": "number" + }, + "values": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "List of inverse temperatures to use instead of generating a schedule. If a list is provided, the other parameters are ignored.", + "title": "Values" + } + }, + "title": "ScheduleConfig", + "type": "object" } }, "description": "Settings for generating a JSON schema for lyscripts configuration files.", @@ -685,6 +772,10 @@ }, "title": "Scenarios", "type": "array" + }, + "schedule": { + "$ref": "#/$defs/ScheduleConfig", + "default": null } }, "required": [ diff --git a/src/lyscripts/__init__.py b/src/lyscripts/__init__.py index f6b7de8..e9dbcb1 100644 --- a/src/lyscripts/__init__.py +++ b/src/lyscripts/__init__.py @@ -17,7 +17,7 @@ CliSubCommand, ) -from lyscripts import compute, data, sample, schedule # noqa: F401 +from lyscripts import compute, data, integrate, sample, schedule # noqa: F401 from lyscripts._version import version from lyscripts.cli import assemble_main, configure_logging from lyscripts.utils import console @@ -51,6 +51,7 @@ class LyscriptsCLI(BaseSettings): sample: CliSubCommand[sample.SampleCLI] compute: CliSubCommand[compute.ComputeCLI] schedule: CliSubCommand[schedule.ScheduleCLI] + integrate: CliSubCommand[integrate.IntegrateCLI] def __init__(self, **kwargs): """Add logging configuration to the lyscripts CLI.""" diff --git a/src/lyscripts/cli.py b/src/lyscripts/cli.py index 26870c7..b6448d0 100644 --- a/src/lyscripts/cli.py +++ b/src/lyscripts/cli.py @@ -8,6 +8,8 @@ .. _loguru: https://loguru.readthedocs.io/en/stable """ +import inspect +import logging from collections.abc import Callable from typing import Literal @@ -17,6 +19,8 @@ from rich.logging import RichHandler from rich_argparse import ArgumentDefaultsRichHelpFormatter +_current_log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + def assemble_main( settings_cls: type[BaseSettings], @@ -80,11 +84,40 @@ def configure_logging( """ logger.enable("lyscripts") logger.enable("lydata") - log_level = somewhat_safely_get_loglevel(argv=argv) + global _current_log_level + _current_log_level = somewhat_safely_get_loglevel(argv=argv) logger.remove() handler = RichHandler(console=console) logger.add( sink=handler, - level=log_level, + level=_current_log_level, format="{message}", ) + + +class InterceptHandler(logging.Handler): + """Intercept logging messages and redirect them to Loguru.""" + + def emit(self, record: logging.LogRecord) -> None: + """Intercept the log record and redirect it to Loguru.""" + # Get corresponding Loguru level if it exists. 
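+        # Loguru levels reuse the stdlib names ("DEBUG", "INFO", "WARNING",
+        # ...), so this lookup normally succeeds; records with a custom level
+        # name fall back to the plain numeric ``record.levelno`` below.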
+ try: + level: str | int = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the logged message. + frame, depth = inspect.currentframe(), 0 + while frame: + filename = frame.f_code.co_filename + is_logging = filename == logging.__file__ + is_frozen = "importlib" in filename and "_bootstrap" in filename + if depth > 0 and not (is_logging or is_frozen): + break + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, + record.getMessage(), + ) diff --git a/src/lyscripts/compute/__init__.py b/src/lyscripts/compute/__init__.py index 84f0e1a..c2bbdff 100644 --- a/src/lyscripts/compute/__init__.py +++ b/src/lyscripts/compute/__init__.py @@ -5,16 +5,17 @@ from pydantic_settings import BaseSettings, CliApp, CliSubCommand -from lyscripts.compute import posteriors, prevalences, priors, risks +from lyscripts.compute import posteriors, prevalences, priors, risks, evidence class ComputeCLI(BaseSettings): - """Compute priors, posteriors, risks, and prevalences from model samples.""" + """Compute priors, posteriors, risks, prevalences and model evidence from model samples.""" priors: CliSubCommand[priors.PriorsCLI] posteriors: CliSubCommand[posteriors.PosteriorsCLI] risks: CliSubCommand[risks.RisksCLI] prevalences: CliSubCommand[prevalences.PrevalencesCLI] + evidence: CliSubCommand[evidence.EvidenceCLI] def cli_cmd(self) -> None: """Start the ``compute`` subcommand.""" diff --git a/src/lyscripts/compute/evidence.py b/src/lyscripts/compute/evidence.py new file mode 100644 index 0000000..1cd51e2 --- /dev/null +++ b/src/lyscripts/compute/evidence.py @@ -0,0 +1,204 @@ +"""Compute the model evidence from MCMC samples. + +Given the samples drawn during thermodynamic integration and their respective log +likelihoods, compute the model log evidence and the Bayesian Information Criterion. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import emcee +import h5py +import numpy as np +import pandas as pd +from loguru import logger +from pydantic import Field +from scipy.integrate import trapezoid + +from lyscripts.cli import assemble_main +from lyscripts.configs import ( + BaseCLI, + DataConfig, + SamplingConfig, + ScheduleConfig, +) + +RNG = np.random.default_rng() + + +def comp_bic(log_probs: np.ndarray, num_params: int, num_data: int) -> float: + r"""Compute the negative one half of the Bayesian Information Criterion (BIC). + + The BIC is defined as [^1] + $$ BIC = k \\ln{n} - 2 \\ln{\\hat{L}} $$ + where $k$ is the number of parameters ``num_params``, $n$ the number of datapoints + ``num_data`` and $\\hat{L}$ the maximum likelihood estimate of the ``log_prob``. + It is constructed such that the following is an + approximation of the model evidence: + $$ p(D \\mid m) \\approx \\exp{\\left( - BIC / 2 \\right)} $$ + which is why this function returns the negative one half of it. + + [^1]: https://en.wikipedia.org/wiki/Bayesian_information_criterion + """ + return np.max(log_probs) - num_params * np.log(num_data) / 2.0 + + +def compute_evidence( + temp_schedule: np.ndarray, + log_probs: np.ndarray, + num: int = 1000, +) -> tuple[float, float]: + """Compute the evidence and its standard deviation. + + Given a ``temp_schedule`` of inverse temperatures and corresponding sets of + ``log_probs``, draw ``num`` "paths" of log-probabilities and compute the evidence + for each using trapezoidal integration. 
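+
+    For instance, with made-up numbers where all walkers agree, every drawn
+    path is identical and the result reduces to a single trapezoid:
+
+    >>> temps = np.array([0.0, 1.0])
+    >>> log_probs = np.tile([[-10.0], [-5.0]], (1, 4))
+    >>> mean, std = compute_evidence(temps, log_probs, num=10)
+    >>> float(mean), float(std)
+    (-7.5, 0.0)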
+ + The evidence is then the mean of those ``num`` integrations, while the error is + their standard deviation. + """ + integrals = np.zeros(shape=num) + for i in range(num): + rand_idx = RNG.choice(log_probs.shape[1], size=log_probs.shape[0]) + drawn_accuracy = log_probs[np.arange(log_probs.shape[0]), rand_idx].copy() + integrals[i] = trapezoid(y=drawn_accuracy, x=temp_schedule) + return np.mean(integrals), np.std(integrals) + + +def compute_ti_results( + settings: EvidenceCLI, + temp_schedule: np.ndarray, + metrics: dict, + ndim: int, + h5_file: Path, +) -> tuple[np.ndarray, np.ndarray]: + """Compute the results in case of a thermodynamic integration run.""" + num_temps = len(temp_schedule) + + if num_temps != len(h5_file["ti"]): + raise RuntimeError( + f"Parameters suggest temp schedule of length {num_temps}, " + f"but {len(h5_file['ti'])} are stored", + ) + + nwalker = ndim * settings.sampling.walkers_per_dim + nsteps = settings.sampling.num_steps + ti_log_probs = np.zeros(shape=(num_temps, nsteps * nwalker)) + + for i, run in enumerate(h5_file["ti"]): + reader = emcee.backends.HDFBackend( + settings.sampling.storage_file, + name=f"ti/{run}", + read_only=True, + ) + ti_log_probs[i] = reader.get_blobs(flat=True)["log_prob"] + + evidence, evidence_std = compute_evidence(temp_schedule, ti_log_probs) + metrics["evidence"] = evidence + metrics["evidence_std"] = evidence_std + + return temp_schedule, ti_log_probs + + +class EvidenceCLI(BaseCLI): + """Compute model evidence from thermodynamic integration samples.""" + + data: DataConfig + sampling: SamplingConfig + schedule: ScheduleConfig = Field( + description="Configuration for generating inverse temperature schedule.", + ) + plots: Path = Field( + default="./plots", + description="Path of the CSV file storing the β vs accuracy table for plotting.", + ) + metrics: Path = Field( + default="./metrics.json", + description="Path to metrics file.", + ) + + def cli_cmd(self) -> None: + """Start the ``evidence`` subcommand. + + Given the MCMC samples from thermodynamic integration provided by the + ``sampling`` argument and the corresponding inverse temperature schedule, + specified in the ``schedule`` argument, the model evidence is computed using + the functions :py:func:`compute_ti_results` and :py:func:`compute_evidence`. + Furthermore, the BIC is evaluated.
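+
+        The computation rests on the standard thermodynamic integration
+        identity for the log-evidence,
+
+        .. math::
+
+            \\ln Z = \\int_0^1 \\mathbb{E}_\\beta \\left[ \\ln L(\\theta) \\right] \\mathrm{d}\\beta,
+
+        where :math:`\\mathbb{E}_\\beta` denotes the expectation over samples
+        drawn at inverse temperature :math:`\\beta`.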
+ """ + data = self.data.load() + + metrics = {} + + temp_schedule = self.schedule.get_schedule() + + with h5py.File(self.sampling.storage_file, mode="r") as h5_file: + # Get ndim from the HDF5 backend + backend = emcee.backends.HDFBackend( + self.sampling.storage_file, + read_only=True, + name=self.sampling.dataset, + ) + ndim = backend.shape[1] + logger.info(f"Inferred {ndim} parameters from stored samples") + + # if TI has been performed, compute the evidence + if "ti" in h5_file: + temp_schedule, ti_log_probs = compute_ti_results( + settings=self, + temp_schedule=temp_schedule, + metrics=metrics, + ndim=ndim, + h5_file=h5_file, + ) + + logger.info( + "Computed results of thermodynamic integration with " + f"{len(temp_schedule)} steps", + ) + + # store inverse temperatures and log-probs in CSV file + self.plots.parent.mkdir(parents=True, exist_ok=True) + + beta_vs_accuracy = pd.DataFrame( + np.array( + [ + temp_schedule, + np.mean(ti_log_probs, axis=1), + np.std(ti_log_probs, axis=1), + ], + ).T, + columns=["β", "accuracy", "std"], + ) + beta_vs_accuracy.to_csv(self.plots, index=False) + logger.info(f"Plotted β vs accuracy at {self.plots}") + + # use blobs, because also for TI, this is the unscaled log-prob + final_log_probs = backend.get_blobs()["log_prob"] + logger.info( + f"Opened samples from emcee backend from {self.sampling.storage_file}", + ) + + # store metrics in JSON file + self.metrics.parent.mkdir(parents=True, exist_ok=True) + self.metrics.touch(exist_ok=True) + + metrics["BIC"] = comp_bic( + log_probs=final_log_probs, + num_params=ndim, + num_data=len(data), + ) + metrics["max_llh"] = np.max(final_log_probs) + metrics["mean_llh"] = np.mean(final_log_probs) + + with open(self.metrics, mode="w", encoding="utf-8") as metrics_file: + json.dump(metrics, metrics_file) + + logger.info(f"Wrote out metrics to {self.metrics}") + + +if __name__ == "__main__": + main = assemble_main(settings_cls=EvidenceCLI, prog_name="compute evidence") + main() diff --git a/src/lyscripts/compute/prevalences.py b/src/lyscripts/compute/prevalences.py index 8771eb9..89b84c9 100644 --- a/src/lyscripts/compute/prevalences.py +++ b/src/lyscripts/compute/prevalences.py @@ -13,7 +13,8 @@ import pandas as pd from loguru import logger from lydata import C, Q -from lydata.accessor import NoneQ, QueryPortion +from lydata.accessor import QueryPortion +from lydata.querier import NoneQ from lydata.utils import is_old from lymph import models from pydantic import Field @@ -140,7 +141,7 @@ def observe_prevalence( QueryPortion(match=np.int64(7), total=np.int64(79)) """ mapping = mapping or DataConfig.model_fields["mapping"].default_factory() - key = ("tumor", "1", "t_stage") if is_old(data) else ("tumor", "info", "t_stage") + key = ("tumor", "1", "t_stage") if is_old(data) else ("tumor", "core", "t_stage") data[key] = data.ly.t_stage.map(mapping) has_t_stage = C("t_stage").isin(scenario_config.t_stages) diff --git a/src/lyscripts/configs.py b/src/lyscripts/configs.py index f76726e..ebae236 100644 --- a/src/lyscripts/configs.py +++ b/src/lyscripts/configs.py @@ -432,6 +432,12 @@ class SamplingConfig(BaseModel): default=0.05, description="Relative threshold for convergence.", ) + burnin_steps: int | None = Field( + default=None, + description=( + "Number of burn-in steps to take. If None, burn-in runs until convergence." 
+ ), + ) num_steps: int | None = Field( default=100, description=("Number of steps to take in the MCMC sampling."), @@ -461,6 +467,75 @@ def load(self, thin: int = 1) -> np.ndarray: ) +def geometric_schedule(num: int, *_a) -> np.ndarray: + """Create a geometric sequence of ``num`` numbers from 0 to 1.""" + log_seq = np.logspace(0.0, 1.0, num) + shifted_seq = log_seq - 1.0 + return shifted_seq / 9.0 + + +def linear_schedule(num: int, *_a) -> np.ndarray: + """Create a linear sequence of ``num`` numbers from 0 to 1. + + Equivalent to the :py:func:`power_schedule` with ``power=1``. + """ + return np.linspace(0.0, 1.0, num) + + +def power_schedule(num: int, power: float, *_a) -> np.ndarray: + """Create a power sequence of ``num`` numbers from 0 to 1. + + This is essentially a :py:func:`linear_schedule` of ``num`` numbers from 0 to 1, + but each number is raised to the power of ``power``. + """ + lin_seq = np.linspace(0.0, 1.0, num) + return lin_seq**power + + +SCHEDULES = { + "geometric": geometric_schedule, + "linear": linear_schedule, + "power": power_schedule, +} + + +class ScheduleConfig(BaseModel): + """Configuration for generating a schedule of inverse temperatures.""" + + method: Literal["geometric", "linear", "power"] = Field( + default="power", + description="Method to generate the inverse temperature schedule.", + ) + num: int = Field( + default=32, + description="Number of inverse temperatures in the schedule.", + ) + power: float = Field( + default=4.0, + description="If a power schedule is chosen, use this as power.", + ) + values: list[float] | None = Field( + default=None, + description=( + "List of inverse temperatures to use instead of generating a schedule. " + "If a list is provided, the other parameters are ignored." + ), + ) + + def get_schedule(self) -> np.ndarray: + """Get the inverse temperature schedule as a numpy array.""" + if self.values is not None: + logger.debug("Using provided inverse temperature values.") + schedule = np.array(self.values) + else: + logger.debug(f"Generating inverse temperature schedule with {self.method}.") + func = SCHEDULES[self.method] + schedule = func(self.num, self.power) + + logger.info(f"Generated inverse temperature schedule: {schedule}") + return schedule + + def map_to_optional_bool(value: Any) -> Any: """Try to convert the options in the `PatternType` to a boolean value.""" if value in [True, "involved", 1]:
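To make the schedule semantics concrete, here is a minimal sketch using the new `ScheduleConfig` (outputs computed from the definitions above):

```python
from lyscripts.configs import ScheduleConfig

# The default method is "power": a linspace from 0 to 1, raised to `power`.
schedule = ScheduleConfig(num=5, power=4.0).get_schedule()
print(schedule)  # [0.         0.00390625 0.0625     0.31640625 1.        ]

# Explicit `values` short-circuit generation; the other fields are ignored.
fixed = ScheduleConfig(values=[0.0, 0.5, 1.0]).get_schedule()
print(fixed)  # [0.  0.5 1. ]
```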
diff --git a/src/lyscripts/data/__init__.py b/src/lyscripts/data/__init__.py index cd1fb8d..29ffe69 100644 --- a/src/lyscripts/data/__init__.py +++ b/src/lyscripts/data/__init__.py @@ -7,16 +7,23 @@ the installed datasets provided by the `lydata`_ package and directly from the associated `GitHub repository`_. +Another cool feature is the built-in mini web application that allows collecting nodal +involvement data interactively and in the same standardized format as we have published +in the past, both on `LyProX`_ and in our `GitHub repository`_. It can be launched by +running `lyscripts data collect` in the terminal. See the docs for the +:py:mod:`lyscripts.data.collect` submodule for more information. + .. _Make: https://www.gnu.org/software/make/ .. _DVC: https://dvc.org .. _LyProX: https://lyprox.org .. _lydata: https://lydata.readthedocs.io -.. _GitHub repository: https://github.com/rmnldwg/lydata +.. _GitHub repository: https://github.com/lycosystem/lydata """ from pydantic_settings import BaseSettings, CliApp, CliSubCommand from lyscripts.data import ( # noqa: F401 + collect, enhance, fetch, generate, @@ -32,6 +39,7 @@ class DataCLI(BaseSettings): """Work with lymphatic progression data through this CLI.""" + collect: CliSubCommand[collect.CollectorCLI] lyproxify: CliSubCommand[lyproxify.LyproxifyCLI] join: CliSubCommand[join.JoinCLI] split: CliSubCommand[split.SplitCLI] diff --git a/src/lyscripts/data/collect/__init__.py b/src/lyscripts/data/collect/__init__.py new file mode 100644 index 0000000..425c728 --- /dev/null +++ b/src/lyscripts/data/collect/__init__.py @@ -0,0 +1,150 @@ +"""Submodule to collect data interactively using a simple web interface. + +With the simple command + +.. code-block:: bash + + lyscripts data collect + +one can start a very basic web server that serves an interactive UI at +``http://localhost:8000/``. There, one can enter patient, tumor, and lymphatic +involvement data one by one. When completed, the "submit" button will parse, validate, +and convert the data to serve a downloadable CSV file. + +The resulting CSV file is in the correct format to be used in `LyProX`_ and for +inference using our `lymph-model`_ library. + +.. _LyProX: https://lyprox.org +.. _lymph-model: https://lymph-model.readthedocs.io +""" + +import io +import logging +from pathlib import Path +from typing import Any + +import lydata +import lydata.validator +import pandas as pd +from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse +from loguru import logger +from pydantic import Field, RootModel +from starlette.responses import FileResponse, HTMLResponse + +from lyscripts.cli import InterceptHandler, _current_log_level +from lyscripts.configs import BaseCLI + +app = FastAPI( + title="lyDATA Collector", + description=( + "A simple web interface to collect data for the lyDATA datasets. " + "This is a prototype and not intended for production use." + ), + version=lydata.__version__, +) + +BASE_DIR = Path(__file__).parent +modalities = lydata.schema.get_default_modalities() +RecordModel = lydata.schema.create_full_record_model(modalities, model_name="Record") +ROOT_MODEL = RootModel[list[RecordModel]] + + +@app.get("/") +def serve_index_html() -> HTMLResponse: + """Serve the ``index.html`` file at the URL's root.""" + with open(BASE_DIR / "index.html") as file: + content = file.read() + return HTMLResponse(content=content) + + +@app.get("/schema") +def serve_schema() -> dict[str, Any]: + """Serve the JSON schema for the patient and tumor records.""" + return ROOT_MODEL.model_json_schema() + + +@app.get("/collector.js") +def serve_collector_js() -> FileResponse: + """Serve the ``collector.js`` file under ``"http://{host}:{port}/collector.js"``. + + This frontend JavaScript file loads the `JSON-Editor`_ library and initializes it + using the schema returned by the :py:func:`serve_schema` function. + + .. _JSON-Editor: https://github.com/json-editor/json-editor/ + """ + return FileResponse(BASE_DIR / "collector.js") + + +@app.post("/submit") +async def process(data: ROOT_MODEL) -> StreamingResponse: + """Process the submitted data to a DataFrame. + + `FastAPI`_ will automatically parse the received JSON data into the list of + instances of the pydantic type defined by the + :py:func:`lydata.schema.create_full_record_model` function. + + From this list, we create a pandas DataFrame and return it as a downloadable CSV + file. + + .. 
_FastAPI: https://fastapi.tiangolo.com/ + """ + logger.info(f"Received data: {data.root}") + + if len(data.root) == 0: + logger.warning("No records provided in the data.") + raise HTTPException( + status_code=400, + detail="No records provided in the data.", + ) + + flattened_records = [] + + for record in data.root: + flattened_record = lydata.validator.flatten(record) + logger.debug(f"Flattened record: {flattened_record}") + flattened_records.append(flattened_record) + + df = pd.DataFrame(flattened_records) + df.columns = pd.MultiIndex.from_tuples(flattened_record.keys()) + logger.info(df.patient.core.head()) + + buffer = io.StringIO() + df.to_csv(buffer, index=False) + buffer.seek(0) + logger.success("Data prepared for download") + return StreamingResponse( + buffer, + media_type="text/csv", + headers={"Content-Disposition": "attachment; filename=lydata_records.csv"}, + ) + + +class CollectorCLI(BaseCLI): + """Serve a FastAPI web app for collecting involvement patterns as CSV files.""" + + hostname: str = Field( + default="localhost", + description="Hostname to run the FastAPI app on.", + ) + port: int = Field( + default=8000, + description="Port to run the FastAPI app on.", + ) + + def cli_cmd(self) -> None: + """Run the FastAPI app.""" + logger.debug(self.model_dump_json(indent=2)) + import uvicorn + + # Intercept standard logging and redirect it to Loguru + logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True) + logger.enable("lydata") + + uvicorn.run( + app, + host=self.hostname, + port=self.port, + log_level=_current_log_level.lower(), + log_config=None, + ) diff --git a/src/lyscripts/data/collect/collector.js b/src/lyscripts/data/collect/collector.js new file mode 100644 index 0000000..3d8d5d9 --- /dev/null +++ b/src/lyscripts/data/collect/collector.js @@ -0,0 +1,151 @@ +/** + * Client-side helper functions for collecting user input through JSONEditor, + * validating it against a fetched JSON Schema, submitting the validated data + * to the backend, and presenting a downloadable CSV returned by the server. + * + * NOTE: Functionality is intentionally unchanged; only readability and + * documentation have been improved. + */ + +/** + * Ensure an alert element (used to display validation errors) exists. + * Creates and appends it if missing. + * + * @returns {HTMLDivElement} The existing or newly created alert element. + */ +function ensureAlertExists() { + let alertElement = document.querySelector('.alert'); + if (!alertElement) { + alertElement = document.createElement('div'); + } + alertElement.className = 'alert alert-danger'; + const editorHolder = document.getElementById('editor_holder'); + editorHolder.appendChild(alertElement); + return alertElement; +} + +/** + * Remove an existing validation alert if present. + */ +function ensureAlertRemoved() { + const existingAlert = document.querySelector('.alert'); + if (existingAlert) { + console.log('Clearing existing alert'); + existingAlert.remove(); + } +} + +/** + * Remove an existing download button (if it exists) to avoid duplicates. + */ +function ensureDownloadButtonRemoved() { + const existingButton = document.getElementById('download_link'); + if (existingButton) { + console.log('Clearing existing download button'); + existingButton.remove(); + } +} + +/** + * Create (or replace) a download button for a CSV blob returned by the server. + * + * @param {Blob} blob - The CSV data blob to make downloadable. 
+ */ +function createDownloadButton(blob) { + ensureDownloadButtonRemoved(); + + const url = window.URL.createObjectURL(blob); + const downloadLink = document.createElement('a'); + downloadLink.id = 'download_link'; + downloadLink.href = url; + downloadLink.textContent = 'Download CSV'; + downloadLink.className = 'btn btn-success'; + downloadLink.download = 'lydata_records.csv'; + + document.getElementById('editor_holder').appendChild(downloadLink); + console.log('Download button created:', downloadLink); +} + +/** + * Send validated editor data to the backend for processing. Expects a CSV blob + * in response which is then exposed via a generated download button. + * + * @param {JSONEditor} editor - The JSONEditor instance from which to read data. + */ +async function sendEditorData(editor) { + const data = editor.getValue(); + console.log('Sending data:', data); + + try { + const response = await fetch('/submit', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(data) + }); + + if (!response.ok) { + // Try to parse error details from JSON, fallback to text + let errorMsg = 'Unknown error'; + try { + const err = await response.json(); + errorMsg = err.detail || err.message || errorMsg; + } catch { + errorMsg = await response.text(); + } + throw new Error(errorMsg); + } + + const blob = await response.blob(); + console.log('Received processed data as blob:', blob); + createDownloadButton(blob); + } catch (error) { + ensureDownloadButtonRemoved(); + console.error('Error submitting data:', error); + const alert = ensureAlertExists(); + alert.textContent = 'Error submitting data: ' + error.message; + alert.classList.add('alert-danger'); + } +} + +/** + * Validate the editor content. If there are validation errors they are + * displayed in an alert; otherwise the data is submitted to the backend. + * + * @param {JSONEditor} editor - The JSONEditor instance to validate & submit. + */ +function processEditor(editor) { + const errors = editor.validate(); + + if (errors.length) { + console.error('Validation errors:', errors); + const alert = ensureAlertExists(); + alert.textContent = 'Validation errors: ' + errors.map(e => e.message).join(', '); + } else { + console.log('Data successfully validated'); + ensureAlertRemoved(); + sendEditorData(editor); + } +} + +// Fetch the JSON Schema to initialize the editor +fetch('/schema') + .then(response => response.json()) + .then(schema => { + const element = document.getElementById('editor_holder'); + const options = { + disable_edit_json: true, + theme: 'bootstrap5', + iconlib: 'bootstrap', + object_layout: 'grid', + disable_properties: true, + schema: schema + }; + const editor = new JSONEditor(element, options); + + // Bind the submit button to validation + submission flow + document.getElementById('submit').addEventListener('click', () => { + console.log('Submit button clicked'); + processEditor(editor); + }); + }) + .catch(error => console.error('Error loading schema:', error)); diff --git a/src/lyscripts/data/collect/index.html b/src/lyscripts/data/collect/index.html new file mode 100644 index 0000000..931dacf --- /dev/null +++ b/src/lyscripts/data/collect/index.html @@ -0,0 +1,26 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+  <meta charset="utf-8">
+  <title>Basic JSON Editor Example</title>
+  <!-- Styling and editor libraries; theme/iconlib must match the options in collector.js -->
+  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5/dist/css/bootstrap.min.css">
+  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap-icons/font/bootstrap-icons.min.css">
+  <script src="https://cdn.jsdelivr.net/npm/@json-editor/json-editor@latest/dist/jsoneditor.min.js"></script>
+</head>
+
+<body>
+  <div class="container">
+    <h1>LyDATA Collector</h1>
+    <div id="editor_holder"></div>
+    <button id="submit" class="btn btn-primary">Submit</button>
+  </div>
+  <script src="collector.js"></script>
+</body>
+
+</html>
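With the static files in place, a quick smoke test of the app is possible from Python. This is a sketch that assumes the server was started with the defaults (`lyscripts data collect`, i.e. on `localhost:8000`):

```python
import json
import urllib.request

# The /schema endpoint returns the JSON schema from which JSONEditor
# renders the record form (see `serve_schema` above).
with urllib.request.urlopen("http://localhost:8000/schema") as response:
    schema = json.load(response)

# Print the first part of the schema to confirm the server is up.
print(json.dumps(schema, indent=2)[:400])
```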
diff --git a/src/lyscripts/data/enhance.py b/src/lyscripts/data/enhance.py index 0662072..87b1e23 100644 --- a/src/lyscripts/data/enhance.py +++ b/src/lyscripts/data/enhance.py @@ -1,13 +1,15 @@ """Enhance the dataset by inferring additional columns from the data. -This is a command-line interface to the -:py:func:`~lydata.utils.infer_and_combine_levels` function. +This is a command-line interface to the methods +:py:meth:`~lydata.accessor.LyDataAccessor.combine` and +:py:meth:`~lydata.accessor.LyDataAccessor.augment` of the +:py:class:`~lydata.accessor.LyDataAccessor` class. """ from typing import Literal from loguru import logger -from lydata import infer_and_combine_levels +from lydata.accessor import LyDataFrame from lydata.utils import ModalityConfig from lyscripts.cli import assemble_main @@ -21,7 +23,6 @@ class EnhanceCLI(BaseCLI): input: DataConfig modalities: dict[str, ModalityConfig] | None = None method: Literal["max_llh", "rank"] = "max_llh" - sides: list[Literal["ipsi", "contra"]] = ["ipsi", "contra"] lnl_subdivisions: dict[str, list[str]] = { "I": ["a", "b"], "II": ["a", "b"], @@ -33,29 +34,18 @@ def cli_cmd(self) -> None: """Infer additional columns from the data and save the enhanced dataset. This basically provides a CLI to the - :py:func:`~lydata.utils.infer_and_combine_levels` function. See its docs for + :py:meth:`~lydata.accessor.LyDataAccessor.augment` method. See its docs for more details on what exactly is happening here. """ logger.debug(self.model_dump_json(indent=2)) - data = self.input.load() - modality_names = list(self.modalities.keys()) if self.modalities else None - - infer_lvls_kwargs = { - "modalities": modality_names, - "sides": self.sides, - "subdivisions": self.lnl_subdivisions, - } - enhanced = infer_and_combine_levels( - dataset=data, - infer_superlevels_kwargs=infer_lvls_kwargs, - infer_sublevels_kwargs=infer_lvls_kwargs, - combine_kwargs={ - "modalities": self.modalities, - "method": self.method, - }, + data: LyDataFrame = self.input.load() + data = data.ly.enhance( + modalities=self.modalities, + method=self.method, + subdivisions=self.lnl_subdivisions, ) - save_table_to_csv(file_path=self.output_file, table=enhanced) + save_table_to_csv(file_path=self.output_file, table=data) if __name__ == "__main__":
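The command is now a thin wrapper around lydata's accessor. As a minimal sketch of the equivalent library call (the CSV path is hypothetical; the three-level header matches the lyDATA column format):

```python
import pandas as pd
import lydata  # noqa: F401  (importing lydata registers the `.ly` accessor)

# Load a lyDATA-style table with its three-level column index ...
data = pd.read_csv("2023-clb-multisite.csv", header=[0, 1, 2])

# ... and let the accessor infer sub-/super-levels and combine modalities.
enhanced = data.ly.enhance(method="max_llh")
enhanced.to_csv("enhanced.csv", index=False)
```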
diff --git a/src/lyscripts/data/filter.py b/src/lyscripts/data/filter.py index bf64566..3e09fe8 100644 --- a/src/lyscripts/data/filter.py +++ b/src/lyscripts/data/filter.py @@ -1,7 +1,7 @@ """Filter a dataset according to some common criteria. This is essentially a command line interface to building a -:py:class:`query object <lydata.accessor.Q>` and applying it to the dataset. +:py:class:`query object <lydata.querier.Q>` and applying it to the dataset. """ from pathlib import Path @@ -68,7 +68,7 @@ def model_post_init(self, __context): def cli_cmd(self): """Execute the ``filter`` command. - This command uses the :py:class:`~lydata.accessor.Q` objects of the `lydata`_ + This command uses the :py:class:`~lydata.querier.Q` objects of the `lydata`_ library to filter the dataset according to the given criteria. .. _lydata: https://lydata.readthedocs.io diff --git a/src/lyscripts/data/join.py b/src/lyscripts/data/join.py index 6ecd8df..d7eb62e 100644 --- a/src/lyscripts/data/join.py +++ b/src/lyscripts/data/join.py @@ -55,7 +55,7 @@ def cli_cmd(self) -> None: lyscripts data join --configs datasets.ly.yaml --output-file joined.csv .. _pydantic: https://docs.pydantic.dev/latest/ - .. _lydata Github repo: https://github.com/rmnldwg/lydata + .. _lydata Github repo: https://github.com/lycosystem/lydata """ joined = None diff --git a/src/lyscripts/data/lyproxify.py b/src/lyscripts/data/lyproxify.py index 3b14a63..d810824 100644 --- a/src/lyscripts/data/lyproxify.py +++ b/src/lyscripts/data/lyproxify.py @@ -226,7 +226,7 @@ def transform_to_lyprox( .. code-block:: python column_map = { - ("patient", "info", "age"): { + ("patient", "core", "age"): { "func": compute_age_from_raw, "kwargs": {"randomize": False}, "columns": ["birthday", "date of diagnosis"] @@ -237,7 +237,7 @@ values of the columns ``"birthday"`` and ``"date of diagnosis"`` as positional arguments, and the keyword argument ``"randomize"`` is set to ``False``. The function then returns the patient's age, which is subsequently stored in the column - ``("patient", "info", "age")``. + ``("patient", "core", "age")``. Note that the ``column_map`` dictionary must have either a ``"default"`` key or ``"func"`` along with ``"columns"`` and ``"kwargs"``, depending on the function diff --git a/src/lyscripts/integrate.py b/src/lyscripts/integrate.py new file mode 100644 index 0000000..c213116 --- /dev/null +++ b/src/lyscripts/integrate.py @@ -0,0 +1,163 @@ +"""Perform thermodynamic integration to evaluate the model evidence. + +Using the functions provided by the `sample` module, this script implements +thermodynamic integration (TI) in order to compute the model evidence. +This is done by sampling the model parameters at different inverse temperatures +following a specified schedule. +""" + +from __future__ import annotations + +import os +from typing import Any + +import emcee +import h5py +import numpy as np +from loguru import logger +from lydata.utils import ModalityConfig +from pydantic import Field + +import lyscripts.sample as sample_module # Import the module to set its global MODEL +from lyscripts.cli import assemble_main +from lyscripts.configs import ( + BaseCLI, + DataConfig, + DistributionConfig, + GraphConfig, + ModelConfig, + SamplingConfig, + ScheduleConfig, + add_distributions, + add_modalities, + construct_model, +) +from lyscripts.utils import get_hdf5_backend + + +def init_ti_sampler( + settings: IntegrateCLI, + temp_idx: int, + ndim: int, + inv_temp: float, + pool: Any, +) -> emcee.EnsembleSampler: + """Initialize the ``emcee.EnsembleSampler`` for TI with the given ``settings``.""" + nwalkers = ndim * settings.sampling.walkers_per_dim + backend = get_hdf5_backend( + file_path=settings.sampling.storage_file, + dataset=f"ti/{temp_idx + 1:0>2d}", + nwalkers=nwalkers, + ndim=ndim, + ) + return emcee.EnsembleSampler( + nwalkers=nwalkers, + ndim=ndim, + log_prob_fn=sample_module.log_prob_fn, + kwargs={"inverse_temp": inv_temp}, + moves=[(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)], + backend=backend, + pool=pool, + blobs_dtype=[("log_prob", np.float64)], + parameter_names=list(MODEL.get_named_params().keys()), + ) + + +class IntegrateCLI(BaseCLI): + """Perform thermodynamic integration to compute the model evidence.""" + + graph: GraphConfig + model: ModelConfig = ModelConfig() + distributions: dict[str, DistributionConfig] = Field( + default={}, + description=( + "Mapping of model T-categories to predefined distributions over " + "diagnosis times." + ), + ) + modalities: dict[str, ModalityConfig] = Field( + default={}, + description=( + "Maps names of diagnostic modalities to their specificity/sensitivity."
+ ), + ) + data: DataConfig + sampling: SamplingConfig + schedule: ScheduleConfig = Field( + description="Configuration for generating inverse temperature schedule.", + ) + + def cli_cmd(self) -> None: + """Start the ``integrate`` subcommand. + + The model construction and setup are done analogously to the + ``sample`` command. Afterwards, an :py:class:`emcee.EnsembleSampler` + is initialized (see :py:func:`init_ti_sampler`) and :py:func:`run_sampling`, + implemented in the ``sample`` module, is executed twice for each TI step: + once for the burn-in phase and once for the actual sampling phase. + Thereby, the log likelihood is scaled by the respective inverse + temperature of that step. All necessary settings for the sampling + are passed by the ``sampling`` argument, except for the inverse + temperatures, which are provided by the ``schedule`` argument. + """ + # as recommended in https://emcee.readthedocs.io/en/stable/tutorials/parallel/# + os.environ["OMP_NUM_THREADS"] = "1" + + logger.debug(self.model_dump_json(indent=2)) + + # ugly, but necessary for pickling + global MODEL + MODEL = construct_model(self.model, self.graph) + MODEL = add_distributions(MODEL, self.distributions) + MODEL = add_modalities(MODEL, self.modalities) + MODEL.load_patient_data(**self.data.get_load_kwargs()) + ndim = MODEL.get_num_dims() + + # set MODEL in the sample module's namespace so log_prob_fn can access it + sample_module.MODEL = MODEL + + schedule = self.schedule.get_schedule() + + # emcee does not support numpy's new random number generator yet. + np.random.seed(self.sampling.seed) # noqa: NPY002 + + with sample_module.get_pool(self.sampling.cores) as pool: + for idx, inv_temp in enumerate(schedule): + sampler = init_ti_sampler( + settings=self, + temp_idx=idx, + ndim=ndim, + inv_temp=inv_temp, + pool=pool, + ) + + sample_module.run_sampling( + description=f"Burn-in phase: TI step {idx + 1}/{len(schedule)}", + sampler=sampler, + num_steps=self.sampling.burnin_steps, + check_interval=self.sampling.check_interval, + trust_factor=self.sampling.trust_factor, + relative_thresh=self.sampling.relative_thresh, + history_file=self.sampling.history_file, + ) + + sample_module.run_sampling( + description=f"Sampling phase: TI step {idx + 1}/{len(schedule)}", + sampler=sampler, + num_steps=self.sampling.num_steps, + reset_backend=True, + check_interval=self.sampling.num_steps, + thin_by=self.sampling.thin_by, + ) + # copy last sampling round over to a group in the HDF5 file called "mcmc" + with h5py.File(self.sampling.storage_file, mode="r+") as h5_file: + h5_file.copy( + f"ti/{len(schedule):0>2d}", + h5_file, + name=self.sampling.dataset, + ) + + +if __name__ == "__main__": + main = assemble_main(settings_cls=IntegrateCLI, prog_name="integrate") + main() diff --git a/src/lyscripts/sample.py b/src/lyscripts/sample.py index d154dfb..1fe7096 100644 --- a/src/lyscripts/sample.py +++ b/src/lyscripts/sample.py @@ -129,7 +129,10 @@ def log_prob_fn(theta: ParamsType, inverse_temp: float = 1.0) -> tuple[float, fl An inverse temperature ``inverse_temp`` can be provided for thermodynamic integration.
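+
+    Returns a pair: the tempered log-probability that drives the sampler, and
+    the raw log-likelihood, which is stored as a blob (hence the
+    ``blobs_dtype`` change below) so the evidence can be computed later. The
+    guard against an infinite likelihood avoids ``0 * -inf = nan`` at inverse
+    temperature zero.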
""" - return inverse_temp * MODEL.likelihood(given_params=theta), inverse_temp + llh = MODEL.likelihood(given_params=theta) + if np.isinf(llh): # to prevent the case of 0 * inf = NaN + return -np.inf, -np.inf + return inverse_temp * llh, llh def ensure_initial_state(sampler: emcee.EnsembleSampler) -> np.ndarray: @@ -339,7 +342,7 @@ def init_sampler(settings: SampleCLI, ndim: int, pool: Any) -> emcee.EnsembleSam moves=[(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)], backend=backend, pool=pool, - blobs_dtype=[("inverse_temp", np.float64)], + blobs_dtype=[("log_prob", np.float64)], parameter_names=list(MODEL.get_named_params().keys()), ) @@ -402,6 +405,7 @@ def cli_cmd(self) -> None: run_sampling( description="Burn-in phase", sampler=sampler, + num_steps=self.sampling.burnin_steps, check_interval=self.sampling.check_interval, trust_factor=self.sampling.trust_factor, relative_thresh=self.sampling.relative_thresh, @@ -411,6 +415,7 @@ def cli_cmd(self) -> None: description="Sampling phase", sampler=sampler, num_steps=self.sampling.num_steps, + check_interval=self.sampling.num_steps, reset_backend=True, thin_by=self.sampling.thin_by, ) diff --git a/src/lyscripts/schedule.py b/src/lyscripts/schedule.py index 16b367f..99fa334 100644 --- a/src/lyscripts/schedule.py +++ b/src/lyscripts/schedule.py @@ -10,72 +10,20 @@ :math:`\beta_i^k` where :math:`k` could e.g. be 5. """ -from typing import Literal - -import numpy as np from loguru import logger -from pydantic import Field from lyscripts.cli import assemble_main -from lyscripts.configs import BaseCLI - - -def geometric_schedule(num: int, *_a) -> np.ndarray: - """Create a geometric sequence of ``num`` numbers from 0 to 1.""" - log_seq = np.logspace(0.0, 1.0, num) - shifted_seq = log_seq - 1.0 - return shifted_seq / 9.0 - - -def linear_schedule(num: int, *_a) -> np.ndarray: - """Create a linear sequence of ``num`` numbers from 0 to 1. - - Equivalent to the :py:func:`power_schedule` with ``power=1``. - """ - return np.linspace(0.0, 1.0, num) - +from lyscripts.configs import BaseCLI, ScheduleConfig -def power_schedule(num: int, power: float, *_a) -> np.ndarray: - """Create a power sequence of ``num`` numbers from 0 to 1. - This is essentially a :py:func:`linear_schedule` of ``num`` numbers from 0 to 1, - but each number is raised to the power of ``power``. 
- """ - lin_seq = np.linspace(0.0, 1.0, num) - return lin_seq**power - - -SCHEDULES = { - "geometric": geometric_schedule, - "linear": linear_schedule, - "power": power_schedule, -} - - -class ScheduleCLI(BaseCLI): +class ScheduleCLI(ScheduleConfig, BaseCLI): """Generate an inverse temperature schedule for thermodynamic integration.""" - method: Literal["geometric", "linear", "power"] = Field( - default="geometric", - description="Choose the method to distribute the inverse temperatures.", - ) - num: int = Field( - default=32, - description="Number of inverse temperatures in the schedule.", - ) - power: float = Field( - default=4, - description="If a power schedule is chosen, use this as power.", - ) - def cli_cmd(self) -> None: """Start the ``schedule`` command.""" logger.debug(self.model_dump_json(indent=2)) - func = SCHEDULES[self.method] - schedule = func(self.num, self.power) - - for inv_temp in schedule: + for inv_temp in self.get_schedule(): # print is necessary to allow piping the output print(inv_temp) # noqa: T201 diff --git a/src/lyscripts/schema.py b/src/lyscripts/schema.py index cab7c68..7470af3 100644 --- a/src/lyscripts/schema.py +++ b/src/lyscripts/schema.py @@ -53,6 +53,7 @@ class SchemaSettings(BaseModel): model: configs.ModelConfig = None sampling: configs.SamplingConfig = None scenarios: list[configs.ScenarioConfig] = [] + schedule: configs.ScheduleConfig = None def main() -> None: diff --git a/tests/compute/prevalences_test.py b/tests/compute/prevalences_test.py index 471b78d..3fe2151 100644 --- a/tests/compute/prevalences_test.py +++ b/tests/compute/prevalences_test.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from lydata import infer_and_combine_levels, load_datasets +from lydata import load_datasets from lyscripts.compute.prevalences import observe_prevalence from lyscripts.configs import DiagnosisConfig, ScenarioConfig @@ -24,7 +24,7 @@ def scenario_config() -> ScenarioConfig: def data() -> pd.DataFrame: """Load one of the lyDATA datasets.""" data = next(load_datasets(year=2021, institution="usz")) - return infer_and_combine_levels(data) + return data.ly.enhance() def test_observe_prevalence( @@ -37,5 +37,5 @@ def test_observe_prevalence( scenario_config=scenario_config, ) - assert portion.match == 67 + assert portion.match == 66 assert portion.total == 150