diff --git a/README.md b/README.md index 5bd22a4..d598eea 100644 --- a/README.md +++ b/README.md @@ -24,17 +24,33 @@ Play around with it and raise Github issues if anything fails conda env create --file=processing_env.yml ``` - For kilosorting you will also need: - 1. Install kilosort and the GUI, run `python -m pip install kilosort[gui]`. If you're on a zsh server, you may need to use `python -m pip install "kilosort[gui]"` - 2. You can also just install the minimal version of kilosort with python -m pip install kilosort. - 3. Next, if the CPU version of pytorch was installed (will happen on Windows), remove it with `pip uninstall torch` - 4. Then install the GPU version of pytorch `conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia` + For kilosorting, install Kilosort in a **separate** conda env called `kilosort` (not in `bnd`). + The `bnd` pipeline will invoke it via `conda run -n kilosort ...`. + + 1. Create and activate the env: + ```shell + conda create -n kilosort python=3.10 pip + conda activate kilosort + ``` + 2. Install Kilosort (see https://github.com/MouseLand/Kilosort): + ```shell + python -m pip install "kilosort[gui]" + ``` + Or minimal: + ```shell + python -m pip install kilosort + ``` + 3. Install GPU-enabled PyTorch (example): + ```shell + conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia + ``` + + If your env is not named `kilosort`, set `BND_KILOSORT_ENV` to the env name before running `bnd`. If you installed the base environment and want to update later on: ```shell conda env update --file=processing_env.yml ``` - And then do the kilosort step 4. Create your configuration file: ```shell bnd init # Provide the path to local and remote data storage diff --git a/bnd/pipeline/kilosort.py b/bnd/pipeline/kilosort.py index b291779..5b23a23 100644 --- a/bnd/pipeline/kilosort.py +++ b/bnd/pipeline/kilosort.py @@ -1,11 +1,12 @@ import os +import json +import shutil +import subprocess +import tempfile +import textwrap from configparser import ConfigParser from pathlib import Path -import torch -from kilosort import run_kilosort -from kilosort.utils import PROBE_DIR, download_probes - from ..logger import set_logging from ..config import Config, _load_config from ..config import find_file @@ -13,6 +14,154 @@ logger = set_logging(__name__) +_KILOSORT_RUNNER_CODE = textwrap.dedent( + """ + import json + import sys + from pathlib import Path + + params_path = Path(sys.argv[1]) + params = json.loads(params_path.read_text()) + + from kilosort import run_kilosort + from kilosort.utils import PROBE_DIR, download_probes + + probe_name = params["probe_name"] + + if not PROBE_DIR.exists(): + download_probes() + + if not any(PROBE_DIR.glob(probe_name)): + download_probes() + + run_kilosort( + settings=params["settings"], + probe_name=probe_name, + data_dir=params["data_dir"], + results_dir=params["results_dir"], + save_preprocessed_copy=params.get("save_preprocessed_copy", False), + verbose_console=params.get("verbose_console", False), + ) +""" +).strip() + + +def _get_kilosort_env_name() -> str: + return ( + os.environ.get("BND_KILOSORT_ENV") + or os.environ.get("KILOSORT_CONDA_ENV") + or "kilosort" + ) + + +def _find_conda_runner() -> str: + conda_exe = os.environ.get("CONDA_EXE") + if conda_exe and Path(conda_exe).exists(): + return conda_exe + + for candidate in ("conda", "mamba", "micromamba"): + resolved = shutil.which(candidate) + if resolved: + return resolved + + raise FileNotFoundError( + "Could not find a conda runner executable (tried CONDA_EXE, conda, mamba, micromamba)." + ) + + +def _run_in_conda_env( + env_name: str, args: list[str], *, capture_output: bool = False +) -> subprocess.CompletedProcess: + runner = _find_conda_runner() + cmd = [runner, "run", "-n", env_name, *args] + + # Workaround for WSL/DrvFs temp-file oddities (e.g., ftruncate -> ENOENT) when + # TEMP/TMP point to `/mnt/c/...`. Force a sane temp dir for subprocesses. + env = os.environ.copy() + if Path("/tmp").exists(): + env["TMPDIR"] = "/tmp" + env["TEMP"] = "/tmp" + env["TMP"] = "/tmp" + try: + return subprocess.run( + cmd, + check=True, + capture_output=capture_output, + text=capture_output, + env=env, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Command failed in conda env '{env_name}': {cmd} (exit code {e.returncode})." + ) from e + + +def _check_kilosort_cuda(env_name: str) -> tuple[bool, str | None]: + code = textwrap.dedent( + """ + import torch + + if torch.cuda.is_available(): + print("CUDA_AVAILABLE") + print(torch.cuda.get_device_name(0)) + else: + print("CUDA_NOT_AVAILABLE") + """ + ).strip() + + proc = _run_in_conda_env(env_name, ["python", "-c", code], capture_output=True) + lines = [line.strip() for line in (proc.stdout or "").splitlines() if line.strip()] + if not lines: + return False, None + + if lines[0] == "CUDA_AVAILABLE": + return True, lines[1] if len(lines) > 1 else None + + return False, None + + +def _run_kilosort_in_env( + *, + env_name: str, + settings: dict, + probe_name: str, + data_dir: Path, + results_dir: Path, +) -> None: + payload = dict( + settings=settings, + probe_name=probe_name, + data_dir=str(data_dir), + results_dir=str(results_dir), + save_preprocessed_copy=False, + verbose_console=False, + ) + + tmp_path: Path | None = None + try: + with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp: + json.dump(payload, tmp) + tmp_path = Path(tmp.name) + + _run_in_conda_env( + env_name, ["python", "-c", _KILOSORT_RUNNER_CODE, str(tmp_path)] + ) + + except Exception as e: + raise RuntimeError( + f"Failed to run Kilosort in the separate conda env '{env_name}'. " + "Make sure it exists and has the `kilosort` package installed. " + "You can override the env name via BND_KILOSORT_ENV." + ) from e + + finally: + if tmp_path: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + + def read_metadata(filepath: Path) -> dict: """Parse a section-less INI file (eg NPx metadata file) and return a dictionary of key-value pairs.""" with open(filepath, "r") as f: @@ -120,27 +269,19 @@ def run_kilosort_on_stream( ) ksort_output_path.mkdir(parents=True, exist_ok=True) - if not PROBE_DIR.exists(): - logger.info("Probe directory not found, downloading probes") - download_probes() - - if any(PROBE_DIR.glob(f"{probe_name}")): - # Sometimes the gateway can throw an error so just double check. - download_probes() - # Check if the metadata file is complete # when SpikeGLX crashes, metadata misses some values. _fix_session_ap_metadata(meta_file_path) # Find out which probe type we have probe_name = _read_probe_type(meta_file_path) - _ = run_kilosort( + env_name = _get_kilosort_env_name() + _run_kilosort_in_env( + env_name=env_name, settings=sorter_params, probe_name=probe_name, data_dir=probe_folder_path, results_dir=ksort_output_path, - save_preprocessed_copy=False, - verbose_console=False, ) return @@ -210,11 +351,17 @@ def run_kilosort_on_session(session_path: Path) -> None: else: ephys_recording_folders = config.get_subdirectories_from_pattern(session_path, "*_g?") - # Check kilosort is installed in environment - if torch.cuda.is_available(): - logger.info(f"CUDA is available. GPU device: {torch.cuda.get_device_name(0)}") + env_name = _get_kilosort_env_name() + cuda_available, device_name = _check_kilosort_cuda(env_name) + if cuda_available: + if device_name: + logger.info(f"CUDA is available in '{env_name}'. GPU device: {device_name}") + else: + logger.info(f"CUDA is available in '{env_name}'.") else: - logger.warning("CUDA is not available. GPU computations will not be enabled.") + logger.warning( + f"CUDA is not available in '{env_name}'. GPU computations will not be enabled." + ) if len(ephys_recording_folders) > 1: raise ValueError( "It seems you are trying to run kilosort without GPU. Look at the README on instrucstions of how to do this. "