Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,33 @@ Play around with it and raise Github issues if anything fails
conda env create --file=processing_env.yml
```

For kilosorting you will also need:
1. Install kilosort and the GUI, run `python -m pip install kilosort[gui]`. If you're on a zsh server, you may need to use `python -m pip install "kilosort[gui]"`
2. You can also just install the minimal version of kilosort with python -m pip install kilosort.
3. Next, if the CPU version of pytorch was installed (will happen on Windows), remove it with `pip uninstall torch`
4. Then install the GPU version of pytorch `conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia`
For kilosorting, install Kilosort in a **separate** conda env called `kilosort` (not in `bnd`).
The `bnd` pipeline will invoke it via `conda run -n kilosort ...`.

1. Create and activate the env:
```shell
conda create -n kilosort python=3.10 pip
conda activate kilosort
```
2. Install Kilosort (see https://github.com/MouseLand/Kilosort):
```shell
python -m pip install "kilosort[gui]"
```
Or minimal:
```shell
python -m pip install kilosort
```
3. Install GPU-enabled PyTorch (example):
```shell
conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia
```

If your env is not named `kilosort`, set `BND_KILOSORT_ENV` to the env name before running `bnd`.

If you installed the base environment and want to update later on:
```shell
conda env update --file=processing_env.yml
```
And then do the kilosort step
4. Create your configuration file:
```shell
bnd init # Provide the path to local and remote data storage
Expand Down
185 changes: 166 additions & 19 deletions bnd/pipeline/kilosort.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,167 @@
import os
import json
import shutil
import subprocess
import tempfile
import textwrap
from configparser import ConfigParser
from pathlib import Path

import torch
from kilosort import run_kilosort
from kilosort.utils import PROBE_DIR, download_probes

from ..logger import set_logging
from ..config import Config, _load_config
from ..config import find_file

logger = set_logging(__name__)


_KILOSORT_RUNNER_CODE = textwrap.dedent(
"""
import json
import sys
from pathlib import Path

params_path = Path(sys.argv[1])
params = json.loads(params_path.read_text())

from kilosort import run_kilosort
from kilosort.utils import PROBE_DIR, download_probes

probe_name = params["probe_name"]

if not PROBE_DIR.exists():
download_probes()

if not any(PROBE_DIR.glob(probe_name)):
download_probes()

run_kilosort(
settings=params["settings"],
probe_name=probe_name,
data_dir=params["data_dir"],
results_dir=params["results_dir"],
save_preprocessed_copy=params.get("save_preprocessed_copy", False),
verbose_console=params.get("verbose_console", False),
)
"""
).strip()


def _get_kilosort_env_name() -> str:
return (
os.environ.get("BND_KILOSORT_ENV")
or os.environ.get("KILOSORT_CONDA_ENV")
or "kilosort"
)


def _find_conda_runner() -> str:
conda_exe = os.environ.get("CONDA_EXE")
if conda_exe and Path(conda_exe).exists():
return conda_exe

for candidate in ("conda", "mamba", "micromamba"):
resolved = shutil.which(candidate)
if resolved:
return resolved

raise FileNotFoundError(
"Could not find a conda runner executable (tried CONDA_EXE, conda, mamba, micromamba)."
)


def _run_in_conda_env(
env_name: str, args: list[str], *, capture_output: bool = False
) -> subprocess.CompletedProcess:
runner = _find_conda_runner()
cmd = [runner, "run", "-n", env_name, *args]

# Workaround for WSL/DrvFs temp-file oddities (e.g., ftruncate -> ENOENT) when
# TEMP/TMP point to `/mnt/c/...`. Force a sane temp dir for subprocesses.
env = os.environ.copy()
if Path("/tmp").exists():
env["TMPDIR"] = "/tmp"
env["TEMP"] = "/tmp"
env["TMP"] = "/tmp"
try:
return subprocess.run(
cmd,
check=True,
capture_output=capture_output,
text=capture_output,
env=env,
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Command failed in conda env '{env_name}': {cmd} (exit code {e.returncode})."
) from e


def _check_kilosort_cuda(env_name: str) -> tuple[bool, str | None]:
code = textwrap.dedent(
"""
import torch

if torch.cuda.is_available():
print("CUDA_AVAILABLE")
print(torch.cuda.get_device_name(0))
else:
print("CUDA_NOT_AVAILABLE")
"""
).strip()

proc = _run_in_conda_env(env_name, ["python", "-c", code], capture_output=True)
lines = [line.strip() for line in (proc.stdout or "").splitlines() if line.strip()]
if not lines:
return False, None

if lines[0] == "CUDA_AVAILABLE":
return True, lines[1] if len(lines) > 1 else None

return False, None


def _run_kilosort_in_env(
*,
env_name: str,
settings: dict,
probe_name: str,
data_dir: Path,
results_dir: Path,
) -> None:
payload = dict(
settings=settings,
probe_name=probe_name,
data_dir=str(data_dir),
results_dir=str(results_dir),
save_preprocessed_copy=False,
verbose_console=False,
)

tmp_path: Path | None = None
try:
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
json.dump(payload, tmp)
tmp_path = Path(tmp.name)

_run_in_conda_env(
env_name, ["python", "-c", _KILOSORT_RUNNER_CODE, str(tmp_path)]
)

except Exception as e:
raise RuntimeError(
f"Failed to run Kilosort in the separate conda env '{env_name}'. "
"Make sure it exists and has the `kilosort` package installed. "
"You can override the env name via BND_KILOSORT_ENV."
) from e

finally:
if tmp_path:
try:
tmp_path.unlink(missing_ok=True)
except Exception:
pass


def read_metadata(filepath: Path) -> dict:
"""Parse a section-less INI file (eg NPx metadata file) and return a dictionary of key-value pairs."""
with open(filepath, "r") as f:
Expand Down Expand Up @@ -120,27 +269,19 @@ def run_kilosort_on_stream(
)
ksort_output_path.mkdir(parents=True, exist_ok=True)

if not PROBE_DIR.exists():
logger.info("Probe directory not found, downloading probes")
download_probes()

if any(PROBE_DIR.glob(f"{probe_name}")):
# Sometimes the gateway can throw an error so just double check.
download_probes()

# Check if the metadata file is complete
# when SpikeGLX crashes, metadata misses some values.
_fix_session_ap_metadata(meta_file_path)
# Find out which probe type we have
probe_name = _read_probe_type(meta_file_path)

_ = run_kilosort(
env_name = _get_kilosort_env_name()
_run_kilosort_in_env(
env_name=env_name,
settings=sorter_params,
probe_name=probe_name,
data_dir=probe_folder_path,
results_dir=ksort_output_path,
save_preprocessed_copy=False,
verbose_console=False,
)
return

Expand Down Expand Up @@ -210,11 +351,17 @@ def run_kilosort_on_session(session_path: Path) -> None:

else:
ephys_recording_folders = config.get_subdirectories_from_pattern(session_path, "*_g?")
# Check kilosort is installed in environment
if torch.cuda.is_available():
logger.info(f"CUDA is available. GPU device: {torch.cuda.get_device_name(0)}")
env_name = _get_kilosort_env_name()
cuda_available, device_name = _check_kilosort_cuda(env_name)
if cuda_available:
if device_name:
logger.info(f"CUDA is available in '{env_name}'. GPU device: {device_name}")
else:
logger.info(f"CUDA is available in '{env_name}'.")
else:
logger.warning("CUDA is not available. GPU computations will not be enabled.")
logger.warning(
f"CUDA is not available in '{env_name}'. GPU computations will not be enabled."
)
if len(ephys_recording_folders) > 1:
raise ValueError(
"It seems you are trying to run kilosort without GPU. Look at the README on instrucstions of how to do this. "
Expand Down
Loading