50 changes: 0 additions & 50 deletions .github/workflows/ci.yml

This file was deleted.

14 changes: 14 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,14 @@
name: pre-commit

on:
pull_request:
push:
branches: [main]

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.1
22 changes: 22 additions & 0 deletions .github/workflows/unit.yml
@@ -0,0 +1,22 @@
name: unit

on:
pull_request:
push:
branches: [main]


jobs:
unit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5

- name: Install the project
run: uv sync --all-extras --dev

- name: Run tests
run: uv run pytest -svv --cov=cortex tests
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,3 +3,5 @@
__pycache__
*.egg-info
docs/build
temp
.coverage
9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
@@ -17,10 +17,11 @@ repos:
rev: 0.6.0
hooks:
- id: nbstripout
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.2.1
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.7
hooks:
# Run the linter.
- id: ruff
name: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
20 changes: 4 additions & 16 deletions README.md
@@ -17,23 +17,11 @@ Rather than tack on auxiliary abstractions to a single input --> single task mod

Note: conda/mamba are no longer actively supported. We recommend using `uv` for package management.

Note: Support for Python versions > 3.10 is blocked by the current dependency on `pytorch-lightning==1.9.5`.

```bash
uv venv -n ~/.venv/cortex --python 3.10
uv sync
source ~/.venv/cortex/bin/activate
uv pip install pytorch-cortex
```


If you have a package version issue we provide pinned versions of all dependencies in `requirements.txt`.
To update the frozen dependencies run

```bash
uv pip freeze > requirements.txt
```


## Running

Use `cortex_train_model --config-name <CONFIG_NAME>` to train, e.g.:
@@ -57,14 +45,14 @@ Contributions are welcome!
### Install dev requirements and pre-commit hooks

```bash
python -m pip install -r requirements-dev.in
pre-commit install
uv sync --dev
uv run pre-commit install
```

### Testing

```bash
python -m pytest -v --cov-report term-missing --cov=./cortex ./tests
uv run pytest -v --cov-report term-missing --cov=./cortex ./tests
```

### Build and browse docs locally
Empty file removed README.rst
Empty file.
5 changes: 4 additions & 1 deletion cortex/cmdline/train_cortex_model.py
@@ -64,7 +64,9 @@ def execute(cfg):
ckpt_file = None
ckpt_cfg = None

if os.path.exists(ckpt_file) and cfg.save_ckpt:
ckpt_exists = os.path.exists(ckpt_file)

if ckpt_exists and cfg.save_ckpt:
msg = f"checkpoint already exists at {ckpt_file} and will be overwritten!"
warnings.warn(msg, UserWarning, stacklevel=2)

@@ -76,6 +78,7 @@
model,
train_dataloaders=model.get_dataloader(split="train"),
val_dataloaders=model.get_dataloader(split="val"),
ckpt_path=ckpt_file if ckpt_exists else None,
)

# save model
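
The net effect of this change is that an existing checkpoint is handed to `Trainer.fit` via `ckpt_path`, so an interrupted run resumes rather than silently restarting from scratch. A minimal sketch of the resulting control flow, assuming the surrounding Lightning setup; the helper name below is illustrative and not part of the codebase:

```python
import os
import warnings


def resolve_ckpt_path(ckpt_file: str | None, save_ckpt: bool) -> str | None:
    """Sketch: reuse an existing checkpoint on disk, warning that it will be overwritten."""
    ckpt_exists = ckpt_file is not None and os.path.exists(ckpt_file)
    if ckpt_exists and save_ckpt:
        warnings.warn(
            f"checkpoint already exists at {ckpt_file} and will be overwritten!",
            UserWarning,
            stacklevel=2,
        )
    # Trainer.fit(ckpt_path=None) trains from scratch; passing a path resumes
    # optimizer, epoch, and callback state from the previous run.
    return ckpt_file if ckpt_exists else None


# usage sketch:
# trainer.fit(
#     model,
#     train_dataloaders=model.get_dataloader(split="train"),
#     val_dataloaders=model.get_dataloader(split="val"),
#     ckpt_path=resolve_ckpt_path(ckpt_file, cfg.save_ckpt),
# )
```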
4 changes: 1 addition & 3 deletions cortex/data/dataset/_data_frame_dataset.py
@@ -53,9 +53,7 @@ def __init__(
remove_archive=True,
)
else:
raise ValueError(
f"Dataset not found at {path}. " "If `download` is `True`, the dataset will be downloaded."
)
raise ValueError(f"Dataset not found at {path}. If `download` is `True`, the dataset will be downloaded.")
self._data = self._read_data(path, dedup=dedup, train=train, random_seed=random_seed, **kwargs)

def _read_data(self, path: str, dedup: bool, train: bool, random_seed: int, **kwargs: Any) -> DataFrame:
8 changes: 7 additions & 1 deletion cortex/logging/_wandb_setup.py
@@ -1,3 +1,4 @@
import uuid
from typing import MutableMapping

import wandb
@@ -30,7 +31,12 @@ def wandb_setup(cfg: DictConfig):
mode=cfg.wandb_mode,
group=cfg.exp_name,
)
cfg["job_name"] = wandb.run.name

if cfg["wandb_mode"] == "online":
cfg["job_name"] = wandb.run.name
else:
cfg["job_name"] = uuid.uuid4().hex[:8]

cfg["__version__"] = cortex.__version__
log_cfg = flatten_config(OmegaConf.to_container(cfg, resolve=True))
wandb.config.update(log_cfg)
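
This guards against `wandb.run.name` not being usable when the run is not syncing online (offline or disabled modes do not receive a server-assigned run name), falling back to a random short id. A standalone sketch of the fallback; the function name and plain-dict config are assumptions for illustration:

```python
import uuid

import wandb


def assign_job_name(cfg: dict) -> None:
    """Sketch: use the W&B-assigned run name online, a random 8-char id otherwise."""
    if cfg["wandb_mode"] == "online":
        # Only online runs get a human-readable, server-assigned name.
        cfg["job_name"] = wandb.run.name
    else:
        cfg["job_name"] = uuid.uuid4().hex[:8]
```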
2 changes: 1 addition & 1 deletion cortex/model/elemental/_ddp_standardize.py
@@ -81,7 +81,7 @@ def forward(self, Y: Tensor, Yvar: Optional[Tensor] = None) -> tuple[Tensor, Opt
f"Y.shape[:-2]={Y.shape[:-2]}."
)
if Y.size(-1) != self._m:
raise RuntimeError(f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected " f"{self._m}.")
raise RuntimeError(f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected {self._m}.")
stdvs = Y.std(dim=-2, keepdim=True)
stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
means = Y.mean(dim=-2, keepdim=True)
30 changes: 24 additions & 6 deletions cortex/model/tree/_seq_model_tree.py
@@ -8,8 +8,8 @@
import pandas as pd
import torch
from botorch.models.transforms.outcome import OutcomeTransform
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch import nn

from cortex.model import online_weight_update_
@@ -47,6 +47,10 @@ def __init__(
self._eval_state_dict = None
self._w_avg_step_count = 1

# for accumulating step outputs in Lightning 2.x
self.training_step_outputs = []
self.validation_step_outputs = []

# decoupled multi-task training requires manual optimization
self.automatic_optimization = False
self.save_hyperparameters(
@@ -82,7 +86,6 @@ def get_dataloader(self, split="train"):
else:
raise ValueError(f"Invalid split {split}")

# change val to max_size when lightning upgraded to >1.9.5
mode = "min_size" if split == "train" else "max_size_cycle"
return CombinedLoader(loaders, mode=mode)

@@ -132,6 +135,10 @@ def training_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[in
step_metrics.update(
{f"{task_key}/train_batch_size": np.mean(batch_sizes) for task_key, batch_sizes in batch_size.items()}
)

# Append metrics to accumulate across steps (for Lightning 2.x)
self.training_step_outputs.append(step_metrics)

return step_metrics

def training_step_end(self, step_metrics):
@@ -157,8 +164,9 @@ def training_step_end(self, step_metrics):

return step_metrics

def training_epoch_end(self, step_metrics):
step_metrics = pd.DataFrame.from_records(step_metrics)
def on_train_epoch_end(self):
# In Lightning 2.x, we need to process the accumulated outputs manually
step_metrics = pd.DataFrame.from_records(self.training_step_outputs)
step_metrics = step_metrics.mean().to_dict()

task_keys = set()
@@ -172,6 +180,9 @@ def training_epoch_end(self, step_metrics):
del task_metrics[f"{t_key}/train_batch_size"]
self.log_dict(task_metrics, prog_bar=True, batch_size=batch_size)

# Clear the outputs list
self.training_step_outputs.clear()

def _weight_average_update(
self,
w_avg_step_count: int,
@@ -252,10 +263,14 @@ def validation_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[
)
step_metrics[f"{task_key}/val_batch_size"] = len(task_batch)

# Append metrics to accumulate across steps (for Lightning 2.x)
self.validation_step_outputs.append(step_metrics)

return step_metrics

def validation_epoch_end(self, step_metrics):
step_metrics = pd.DataFrame.from_records(step_metrics)
def on_validation_epoch_end(self):
# In Lightning 2.x, we need to process the accumulated outputs manually
step_metrics = pd.DataFrame.from_records(self.validation_step_outputs)
step_metrics = step_metrics.mean().to_dict()

task_keys = set()
@@ -269,6 +284,9 @@ def validation_epoch_end(self, step_metrics):
del task_metrics[f"{t_key}/val_batch_size"]
self.log_dict(task_metrics, prog_bar=True, logger=True, batch_size=batch_size)

# Clear the outputs list
self.validation_step_outputs.clear()

def finetune(
self,
cfg: DictConfig,
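
The edits in this file follow the standard Lightning 1.x → 2.x migration: `training_epoch_end`/`validation_epoch_end` no longer receive the step outputs, so each step appends its metrics to a list on the module and the new `on_train_epoch_end`/`on_validation_epoch_end` hooks aggregate and clear that list; `CombinedLoader` likewise moves to `lightning.pytorch.utilities.combined_loader`. A minimal, self-contained sketch of the pattern; the module and metric names here are illustrative, not taken from cortex:

```python
import lightning.pytorch as pl
import torch


class EpochAggregatingModule(pl.LightningModule):
    """Minimal sketch of the Lightning 2.x epoch-end pattern used above."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        # Lightning 2.x no longer passes step outputs to epoch-end hooks,
        # so accumulate them manually on the module.
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # detach so the accumulated list does not retain the autograd graph
        self.training_step_outputs.append({"train_loss": loss.detach()})
        return loss

    def on_train_epoch_end(self):
        # aggregate the accumulated step metrics, log, then clear for the next epoch
        avg = torch.stack([o["train_loss"] for o in self.training_step_outputs]).mean()
        self.log("train_loss_epoch", avg, prog_bar=True)
        self.training_step_outputs.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```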
47 changes: 41 additions & 6 deletions pyproject.toml
@@ -1,13 +1,36 @@
[project]
name = "pytorch-cortex"
dynamic = ["version", "readme", "dependencies", "optional-dependencies"]
dynamic = ["version"]
readme = "README.md"
description = "A modular architecture for deep learning systems."
authors = [{name = "Samuel Stanton", email = "stanton.samuel@gene.com"}]
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3",
]
requires-python = ">=3.10"
dependencies = [
"boto3>=1.37.1",
"botorch>=0.9.4",
"cachetools>=5.5.2",
"edlib>=1.3.9.post1",
"hydra-core>=1.2.0",
"lightning>=2.0",
"numpy>=2",
"omegaconf>=2.3.0",
"pandas>=2",
"pyarrow>=19.0.1",
"pytorch-warmup>=0.2.0",
"s3fs>=2025.3.2",
"tabulate>=0.9.0",
"torch>=2.5.1",
"torchvision",
"tqdm>=4.67.1",
"transformers>=4.24.0",
"universal-pathlib>=0.2.6",
"wandb>=0.19.9",
]

[project.scripts]
cortex_train_model = "cortex.cmdline.train_cortex_model:main"
@@ -19,11 +42,6 @@ cortex_generate_designs = "cortex.cmdline.generate_designs:main"
requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.in"]}
optional-dependencies = {dev = {file = ["requirements-dev.in"]}}
readme = {file = "README.rst"}

[tool.setuptools.packages.find]
include = ["cortex*"]

@@ -39,6 +57,11 @@ fallback_version = "0.0.0"

[tool.ruff]
line-length = 120
extend-exclude = [
"*.ipynb",
"**/torchinductor/**/*.py",
"notebooks",
]

[tool.ruff.lint]
select = [
@@ -60,3 +83,15 @@ ignore = [
"__init__.py" = [
"F401", # MODULE IMPORTED BUT UNUSED
]

[dependency-groups]
dev = [
"ipykernel>=6.29.5",
"ipython>=8.34.0",
"pre-commit>=4.2.0",
"pytest>=8.3.5",
"pytest-cov>=6.1.0",
"sphinx>=8.1.3",
"sphinx-autoapi>=3.6.0",
"sphinx-rtd-theme>=3.0.2",
]
7 changes: 0 additions & 7 deletions requirements-dev.in

This file was deleted.
