14 changes: 4 additions & 10 deletions .github/workflows/quality.yml → .github/workflows/linting.yml
@@ -1,9 +1,9 @@
-name: Code Quality
+name: Linting

on: [push, pull_request]

jobs:
-  qualitycheck:
+  lint:
    runs-on: ubuntu-latest

    steps:
@@ -37,19 +37,13 @@ jobs:
echo "$PWD/.venv/bin" >> $GITHUB_PATH

- name: Install dependencies
run: uv pip install -e ".[dev]"
run: uv sync

- name: Run ruff (linter)
run: ruff check

- name: Run ruff (formatter)
run: ruff format --check

- name: Run isort
run: isort --check --profile black .

- name: Run mypy
run: mypy .

- name: Run tests
run: pytest -v
run: mypy .
42 changes: 42 additions & 0 deletions .github/workflows/testing.yml
@@ -0,0 +1,42 @@
name: Tests

on: [pull_request]
jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python 3.13
        uses: actions/setup-python@v4
        with:
          python-version: '3.13'

Comment on lines +13 to +16
⚠️ Potential issue

Update actions/setup-python to v5

The action version is outdated and may not run properly on newer GitHub Actions runners.

-    - name: Set up Python 3.13
-      uses: actions/setup-python@v4
-      with:
-        python-version: '3.13'
+    - name: Set up Python 3.13
+      uses: actions/setup-python@v5
+      with:
+        python-version: '3.13'
🧰 Tools
🪛 actionlint (1.7.7)

13-13: the runner of "actions/setup-python@v4" action is too old to run on GitHub Actions. update the action's version to fix this issue

(action)

🤖 Prompt for AI Agents
In .github/workflows/testing.yml around lines 13 to 16, the setup-python action
is using version v4, which is outdated. Update the version from
actions/setup-python@v4 to actions/setup-python@v5 to ensure compatibility with
newer GitHub Actions runners and improve reliability.

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH

      - name: Cache dependencies
        uses: actions/cache@v4
        with:
          path: |
            ~/.cache/uv
            ~/.uv
            .venv
          key: ${{ runner.os }}-uv-${{ hashFiles('pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-uv-

      - name: Create and activate virtual environment
        run: |
          uv venv
          echo "$PWD/.venv/bin" >> $GITHUB_PATH

      - name: Install dependencies
        run: uv sync

      - name: Run tests
        run: pytest -v
11 changes: 3 additions & 8 deletions .pre-commit-config.yaml
@@ -16,15 +16,10 @@ repos:
        language: python
        types: [python]

-      - id: isort
-        name: isort
-        entry: isort
-        language: python
-        types: [python]
-        args: [--profile=black]
-
      - id: mypy
        name: mypy
        entry: mypy
        language: system
-        types: [python]
+        types: [python]
+        exclude: ^(tests/)
+        args: ["--config-file=pyproject.toml"]
27 changes: 13 additions & 14 deletions data/process.py
@@ -33,7 +33,8 @@ class PaRoutesDataset:
    def __init__(self, data_path: Path, filename: str, verbose: bool = True) -> None:
        self.data_path = data_path
        self.filename = filename
-        self.dataset = json.load(open(data_path.joinpath(filename), "r"))
+        with open(data_path.joinpath(filename)) as f:
+            self.dataset = json.load(f)

        self.verbose = verbose

@@ -94,18 +95,16 @@ def prepare_final_dataset_v2(
                if n_sms is not None and sm_count + 1 >= n_sms:
                    break
        print(f"Created dataset with {len(products)} entries")
-        pickle.dump(
-            (products, starting_materials, path_strings, n_steps_list),
-            open(save_path, "wb"),
-        )
+        with open(save_path, "wb") as f:
+            pickle.dump((products, starting_materials, path_strings, n_steps_list), f)
        return non_permuted_paths


# ------- Dataset Processing -------
print("--- Processing of the PaRoutes dataset begins!")
print("-- starting to canonicalize n1 and n5 stocks")
-n1_stock = open(data_path / "n1-stock.txt").read().splitlines()
-n5_stock = open(data_path / "n5-stock.txt").read().splitlines()
+n1_stock = open(data_path / "n1-stock.txt").read().splitlines()  # noqa: SIM115
+n5_stock = open(data_path / "n5-stock.txt").read().splitlines()  # noqa: SIM115

n1_stock_canon = [canonicalize_smiles(smi) for smi in n1_stock]
n5_stock_canon = [canonicalize_smiles(smi) for smi in n5_stock]
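
Two related cleanups meet in this hunk: the `json.load(open(...))` and `pickle.dump(..., open(...))` one-liners were rewritten as `with` blocks, while the chained `.read().splitlines()` reads keep the one-liner form and suppress ruff's SIM115 (flake8-simplify: `open()` outside a context manager) instead. A sketch of the trade-off, using a hypothetical `stock.txt`:

```python
from pathlib import Path

path = Path("stock.txt")  # hypothetical file for illustration

# Flagged by SIM115: the handle is closed only when the file object is
# garbage-collected, and not promptly if the chained call raises.
lines = open(path).read().splitlines()  # noqa: SIM115

# Rule-compliant form: the handle is closed deterministically when the
# block exits, whether read() succeeds or raises.
with open(path) as f:
    lines = f.read().splitlines()
```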
@@ -128,17 +127,17 @@ def prepare_final_dataset_v2(
    n_perms=n_perms,
    n_sms=n_sms,
)
-pickle.dump(n1_path_set, open(save_path / f"n1_nperms={perm_suffix}_nsms={sm_suffix}_path_set.pkl", "wb"))
+pickle.dump(n1_path_set, open(save_path / f"n1_nperms={perm_suffix}_nsms={sm_suffix}_path_set.pkl", "wb"))  # noqa: SIM115

print("-- starting to process n5 Routes")
n5_routes_obj = PaRoutesDataset(data_path, "n5-routes.json")
n5_path_set = n5_routes_obj.prepare_final_dataset_v2(
    save_path / f"n5_dataset_nperms={perm_suffix}_nsms={sm_suffix}.pkl", n_perms=n_perms, n_sms=n_sms
)
-pickle.dump(n5_path_set, open(save_path / f"n5_nperms={perm_suffix}_nsms={sm_suffix}_path_set.pkl", "wb"))
+pickle.dump(n5_path_set, open(save_path / f"n5_nperms={perm_suffix}_nsms={sm_suffix}_path_set.pkl", "wb"))  # noqa: SIM115

-n1_path_set = pickle.load(open(save_path / "n1_nperms=all_nsms=1_path_set.pkl", "rb"))
-n5_path_set = pickle.load(open(save_path / "n5_nperms=all_nsms=1_path_set.pkl", "rb"))
+n1_path_set = pickle.load(open(save_path / "n1_nperms=all_nsms=1_path_set.pkl", "rb"))  # noqa: SIM115
+n5_path_set = pickle.load(open(save_path / "n5_nperms=all_nsms=1_path_set.pkl", "rb"))  # noqa: SIM115

print("-- starting to process All Routes")
all_routes_obj = PaRoutesDataset(data_path, "all_routes.json")
@@ -188,8 +187,8 @@ def prepare_final_dataset_v2(


def remove_sm_from_ds(load_path: Path, save_path: Path) -> None:
-    products, _, path_strings, n_steps_lists = pickle.load(open(load_path, "rb"))
-    pickle.dump((products, path_strings, n_steps_lists), open(save_path, "wb"))
+    products, _, path_strings, n_steps_lists = pickle.load(open(load_path, "rb"))  # noqa: SIM115
+    pickle.dump((products, path_strings, n_steps_lists), open(save_path, "wb"))  # noqa: SIM115


remove_sm_from_ds(
@@ -219,5 +218,5 @@ def remove_sm_from_ds(load_path: Path, save_path: Path) -> None:
train_ds_dict = convert_list_of_dicts_to_dict_of_lists(train_ds)
val_ds_dict = convert_list_of_dicts_to_dict_of_lists(val_ds)

-save_dataset_sm(train_ds_dict, save_path / f"{train_fname.split('.')[0]}_train={1-val_frac}.pkl")
+save_dataset_sm(train_ds_dict, save_path / f"{train_fname.split('.')[0]}_train={1 - val_frac}.pkl")
save_dataset_sm(val_ds_dict, save_path / f"{train_fname.split('.')[0]}_val={val_frac}.pkl")
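
The `{1-val_frac}` → `{1 - val_frac}` edit is whitespace-only inside an f-string replacement field; it matches what newer `ruff format` releases emit (the dev pin below moves to `ruff>=0.12.5`), since they format expressions inside f-strings like ordinary expressions. A before/after sketch with a hypothetical `val_frac`:

```python
val_frac = 0.1  # hypothetical split fraction

before = f"dataset_train={1-val_frac}.pkl"  # older source formatting
after = f"dataset_train={1 - val_frac}.pkl"  # what newer ruff format produces

assert before == after  # identical at runtime; only the source text changed
```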
64 changes: 40 additions & 24 deletions pyproject.toml
@@ -3,19 +3,13 @@ name = "directmultistep"
version = "1.1.2"
requires-python = ">=3.11"
dependencies = [
"numpy==1.26.4",
"pandas==1.5.3",
"pdoc3==0.11.5",
"plotly==5.24.1",
"lightning==2.2.5",
"pyyaml==6.0.1",
"rdkit==2023.9.3",
"torch==2.3.0",
"torchmetrics==1.6.0",
"tqdm==4.67.1",
"svgwrite==1.4.3",
"svglib==1.5.1",
"tomli>=2.2.1",
"numpy>=1.26.4",
"lightning>=2.2.5",
"pyyaml>=6.0.1",
"rdkit>=2023.9.3",
"torch>=2.3.0",
"tqdm>=4.67.1",

]
authors = [
{ name = "Anton Morgunov", email = "anton@ischemist.com" },
@@ -28,32 +22,44 @@ readme = "README.md"
Homepage = "https://github.com/batistagroup/DirectMultiStep"
Issues = "https://github.com/batistagroup/DirectMultiStep/issues"

+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+package-dir = { "" = "src" }
+
+[tool.setuptools.package-data]
+"directmultistep.model.default_configs" = ["*.yaml"]
+
-[project.optional-dependencies]
+[dependency-groups]
dev = [
    "ipykernel>=6.29.5",
    "nbformat>=5.10.4",
    "rich>=13.9.4",
    "kaleido==0.2.1",
-    "pre-commit==4.0.1",
-    "mkdocs==1.6.1",
-    "mkdocstrings-python==1.12.2",
-    "mkdocs-material==9.5.49",
+    "pre-commit>=4.0.1",
+    "mkdocs>=1.6.1",
+    "mkdocstrings-python>=1.12.2",
+    "mkdocs-material>=9.5.49",
    "material-plausible-plugin>=0.3.0",
-    "pytest==8.3.4",
-    "ruff==0.4.7",
-    "mypy==1.13.0",
-    "isort==5.13.2",
-    "typing-extensions==4.12.2",
-    "mypy-extensions==1.0.0",
+    "pytest>=8.3.4",
+    "ruff>=0.12.5",
+    "mypy>=1.17.0",
+    "typing-extensions>=4.12.2",
+    "mypy-extensions>=1.0.0",
    "types-pyyaml",
    "types-tqdm",
    "types-requests",
]

+[project.optional-dependencies]
+vis = [
+    "svgwrite>=1.4.3",
+    "svglib>=1.5.1",
+    "plotly>=5.24.1",
+]

[tool.mypy]
strict = true
ignore_missing_imports = true
@@ -66,4 +72,14 @@ ignore_errors = true

[tool.ruff]
line-length = 120
+lint.select = [
+    "E",   # pycodestyle
+    "F",   # Pyflakes
+    "UP",  # pyupgrade
+    "B",   # flake8-bugbear
+    "SIM", # flake8-simplify
+    "I",   # isort
+]
+lint.ignore = ["E501"]
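
For reference, a small illustrative snippet (standard ruff rule codes, not code from this repo) showing the kind of issue each newly selected family catches:

```python
from typing import List  # UP035: `typing.List` is deprecated, use `list`

import os, sys  # E401: multiple imports on one line; F401: `sys` is unused


def run(args: List[str] = []) -> List[str]:  # UP006: use `list[str]`; B006: mutable default
    data = open("cache.txt").read()  # SIM115: open() outside a context manager
    words = data.split()  # F841: local variable assigned but never used
    args.append(os.getcwd())
    return args
```

Rule `I` (import sorting) would also reorder the import block, replacing the standalone isort hook removed elsewhere in this diff.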
8 changes: 4 additions & 4 deletions scripts/solve_compounds.py
@@ -42,22 +42,22 @@

logger.info("Loading targets and stock compounds")
if target_name == "uspto_190":
-    with open(COMPOUND_PATH / "uspto_190.txt", "r") as f:
+    with open(COMPOUND_PATH / "uspto_190.txt") as f:
        targets = f.read().splitlines()
elif target_name == "chembl":
-    with open(COMPOUND_PATH / "chembl_targets.json", "r") as f:
+    with open(COMPOUND_PATH / "chembl_targets.json") as f:
        targets = json.load(f)
else:
    logger.error(f"{target_name} is not a valid target name")
    raise Exception("Not valid target_name")

# eMols is available at https://github.com/binghong-ml/retro_star
# make sure to canonicalize the SMILES strings before using them
-with open(COMPOUND_PATH / "eMolecules.txt", "r") as f:
+with open(COMPOUND_PATH / "eMolecules.txt") as f:
    emol_stock_set = set(f.read().splitlines())
# buyables-stock is available at https://github.com/jihye-roh/higherlev_retro
# make sure to canonicalize the SMILES strings before using them
-with open(COMPOUND_PATH / "buyables-stock.txt", "r") as f:
+with open(COMPOUND_PATH / "buyables-stock.txt") as f:
    buyables_stock_set = set(f.read().splitlines())

chunk_size = len(targets) // num_part
12 changes: 6 additions & 6 deletions src/directmultistep/analysis/paper/linear_vs_convergent.py
@@ -330,15 +330,15 @@ def create_comparative_bar_plots(

    colors = style.colors_blue + style.colors_purple + style.colors_red

-    for cat, pos in zip(categories, positions):
+    for cat, pos in zip(categories, positions, strict=True):
        x = list(results[0][cat].keys())
        x.sort(key=lambda k: int(k.split()[-1]))

        if k_vals is not None:
            k_vals_str = [f"Top {k}" for k in k_vals]
            x = [k for k in x if k in k_vals_str]

-        for i, (result, name) in enumerate(zip(results, trace_names)):
+        for i, (result, name) in enumerate(zip(results, trace_names, strict=True)):
            y = [float(result[cat][k].strip('%')) for k in x]

            fig.add_trace(
@@ -408,7 +408,7 @@ def create_accuracy_by_length_plot(
    cset = style.publication_colors
    colors = [cset["primary_blue"], cset["dark_blue"], cset["purple"], cset["dark_purple"]]

-    for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs)):
+    for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs, strict=True)):
        paths_name = config.processed_paths_name

        with open(path / paths_name, "rb") as f:
@@ -469,7 +469,7 @@ def create_accuracy_by_length_subplots(
    cset = style.publication_colors
    colors = [cset["primary_blue"], cset["dark_blue"], cset["purple"], cset["dark_purple"]]

-    for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs)):
+    for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs, strict=True)):
        paths_name = config.processed_paths_name

        with open(path / paths_name, "rb") as f:
@@ -482,7 +482,7 @@
"non_convergent": (analyzer.convergent_idxs, 3),
}

for route_type, (ignore_ids, row) in route_types.items():
for _route_type, (ignore_ids, row) in route_types.items():
lengths, step_stats = RouteAnalyzer._calculate_accuracy_by_length_data(
predicted_routes, dataset, k_vals, ignore_ids=ignore_ids
)
@@ -552,7 +552,7 @@ def get_predictions_by_type(routes: PathsProcessedType) -> tuple[list[int], list
    # Create subplot titles
    def create_subtitle(stage: str, predictions: list[int]) -> str:
        mean, median, mean_f, median_f = calculate_prediction_stats(predictions)
-        base = f"{stage}<br><span style=\"font-size:{FONT_SIZES['subplot_title']-4}px\">"
+        base = f'{stage}<br><span style="font-size:{FONT_SIZES["subplot_title"] - 4}px">'
        stats = f"mean={mean:.1f}, median={median:.1f}"
        if show_filtered_stats:
            stats += f" (μ*={mean_f:.1f}, m*={median_f:.1f})"
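
The recurring edit in this file, `zip(...)` → `zip(..., strict=True)`, opts into the Python 3.10+ keyword that flake8-bugbear's B905 rule (part of the newly selected `B` family) recommends: a length mismatch between the zipped iterables now raises instead of silently truncating the plotted data. A minimal sketch:

```python
names = ["Top 1", "Top 5", "Top 10"]
accuracies = [0.61, 0.82]  # one entry short -- silently dropped before

# Default zip stops at the shortest input, quietly losing "Top 10":
assert list(zip(names, accuracies)) == [("Top 1", 0.61), ("Top 5", 0.82)]

# With strict=True (Python 3.10+), the mismatch surfaces immediately:
try:
    list(zip(names, accuracies, strict=True))
except ValueError as err:
    print(err)  # "zip() argument 2 is shorter than argument 1"
```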