4 changes: 2 additions & 2 deletions benchmarks/asv.conf.json
@@ -83,7 +83,7 @@
// "psutil": [""]
"pooch": [""],
"scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
// "scikit-misc": [""],
"scikit-misc": [""],
},

// Combinations of libraries/python versions can be excluded/included
@@ -104,7 +104,7 @@
 // - environment_type
 //     Environment type, as above.
 // - sys_platform
 //     Platform, as in sys.platform. Possible values for the common
 //     cases: 'linux2', 'win32', 'cygwin', 'darwin'.
 //
 // "exclude": [
39 changes: 27 additions & 12 deletions benchmarks/benchmarks/preprocessing_log.py
@@ -34,7 +34,7 @@ class PreprocessingSuite:  # noqa: D101

     def setup_cache(self) -> None:
         """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
-        for dataset, layer in product(*self.params):
+        for dataset, layer in product(*self.params[:2]):
             adata, _ = get_dataset(dataset, layer=layer)
             adata.write_h5ad(f"{dataset}_{layer}.h5ad")
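Assuming `PreprocessingSuite.params` carries more axes than just dataset and layer (its definition sits above this hunk), the slice keeps `setup_cache` from writing the same dataset/layer file once per extra combination. A sketch with made-up values:

# Sketch with made-up values; the suite's real `params` is not shown here.
from itertools import product

params = (["pbmc3k", "pbmc68k_reduced"], [None, "off-axis"], ["hypothetical-axis"])

# product(*params) would yield each (dataset, layer) pair once per value of
# the extra axis; product(*params[:2]) caches each pair exactly once.
for dataset, layer in product(*params[:2]):
    print(f"{dataset}_{layer}.h5ad")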

@@ -47,17 +47,6 @@ def time_pca(self, *_) -> None:
     def peakmem_pca(self, *_) -> None:
         sc.pp.pca(self.adata, svd_solver="arpack")

-    def time_highly_variable_genes(self, *_) -> None:
-        # the default flavor runs on log-transformed data
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
-    def peakmem_highly_variable_genes(self, *_) -> None:
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
     # regress_out is very slow for this dataset
     @skip_when(dataset={"pbmc3k"})
     def time_regress_out(self, *_) -> None:
@@ -72,3 +61,29 @@ def time_scale(self, *_) -> None:

     def peakmem_scale(self, *_) -> None:
         sc.pp.scale(self.adata, max_value=10)
+
+
+class HVGSuite:  # noqa: D101
+    params = (["seurat_v3", "cell_ranger", "seurat"],)
+    param_names = ("flavor",)
+
+    def setup_cache(self) -> None:
+        """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
+        adata, _ = get_dataset("pbmc3k")
+        adata.write_h5ad("pbmc3k.h5ad")
+
+    def setup(self, flavor) -> None:
+        self.adata = ad.read_h5ad("pbmc3k.h5ad")
+        sc.pp.filter_genes(self.adata, min_cells=3)
+        self.flavor = flavor
+
+    def time_highly_variable_genes(self, *_) -> None:
+        # the default flavor runs on log-transformed data
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor
+        )
+
+    def peakmem_highly_variable_genes(self, *_) -> None:
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor
+        )
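For reviewers unfamiliar with asv's parameterization: `setup_cache` runs once per suite (its on-disk output is shared), while `setup` runs before every measurement with one value from `params`. A rough sketch of the calling convention, illustrative only and not actual asv internals:

# Illustrative driver loop; asv's real machinery handles repeats and timing.
suite = HVGSuite()
suite.setup_cache()  # once per suite: downloads pbmc3k, writes pbmc3k.h5ad

for flavor in suite.params[0]:  # "seurat_v3", "cell_ranger", "seurat"
    suite.setup(flavor)  # fresh AnnData and gene filter for each flavor
    suite.time_highly_variable_genes(flavor)  # asv repeats and times this
    suite.setup(flavor)
    suite.peakmem_highly_variable_genes(flavor)  # asv tracks peak memory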
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -198,6 +198,8 @@ filterwarnings = [
"ignore:.*'(parseAll)'.*'(parse_all)':DeprecationWarning",
# igraph vs leidenalg warning
"ignore:The `igraph` implementation of leiden clustering:UserWarning",
"ignore:Detected unsupported threading environment:UserWarning",
"ignore:Cannot cache compiled function",
]
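The two new filters silence warnings that presumably come from numba's threading and caching layers once the parallel njit path below is in play; the second entry names no category, so it matches any warning class. At runtime they are roughly equivalent to:

# Rough runtime equivalent of the two new ini entries (pytest parses the
# "action:message:category" syntax itself; the message part matches the
# start of the warning text).
import warnings

warnings.filterwarnings(
    "ignore",
    message="Detected unsupported threading environment",
    category=UserWarning,
)
# No category given in the ini entry, so any warning class matches:
warnings.filterwarnings("ignore", message="Cannot cache compiled function")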

[tool.coverage.run]
34 changes: 27 additions & 7 deletions src/scanpy/preprocessing/_highly_variable_genes.py
@@ -13,7 +13,7 @@
 from fast_array_utils import stats

 from .. import logging as logg
-from .._compat import CSBase, CSRBase, DaskArray, old_positionals, warn
+from .._compat import CSBase, CSRBase, DaskArray, njit, old_positionals, warn
 from .._settings import Verbosity, settings
 from .._utils import (
     check_nonnegative_integers,
@@ -98,8 +98,7 @@ def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
 )


-# parallel=False needed for accuracy
-@numba.njit(cache=True, parallel=False)  # noqa: TID251
+@njit
 def _sum_and_sum_squares_clipped(
     indices: NDArray[np.integer],
     data: NDArray[np.floating],
@@ -108,13 +107,34 @@ def _sum_and_sum_squares_clipped(
     clip_val: NDArray[np.float64],
     nnz: int,
 ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
-    squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
-    batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+    """Compute clipped per-column sums and sums of squares in parallel.
+
+    The previous implementation required ``parallel=False`` because all
+    threads accumulated into shared arrays (a race condition). This version
+    uses an explicit thread-local reduction, restoring both correctness and
+    parallelism.
+    """
+    # Thread-local accumulators for parallel reduction
+    n_threads = numba.get_num_threads()
+    squared_local = np.zeros((n_threads, n_cols), dtype=np.float64)
+    sum_local = np.zeros((n_threads, n_cols), dtype=np.float64)
+
+    # Parallel accumulation into thread-local buffers (no race condition)
     for i in numba.prange(nnz):
+        tid = numba.get_thread_id()
         idx = indices[i]
         element = min(np.float64(data[i]), clip_val[idx])
-        squared_batch_counts_sum[idx] += element**2
-        batch_counts_sum[idx] += element
+        squared_local[tid, idx] += element**2
+        sum_local[tid, idx] += element
+
+    # Reduction phase: combine thread-local results
+    squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+    batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+
+    for t in range(n_threads):
+        for j in range(n_cols):
+            squared_batch_counts_sum[j] += squared_local[t, j]
+            batch_counts_sum[j] += sum_local[t, j]
+
     return squared_batch_counts_sum, batch_counts_sum
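The same pattern, reduced to a self-contained script for checking the thread-local reduction outside scanpy. Plain `numba.njit(parallel=True)` stands in here for the `njit` wrapper imported from `.._compat` above, which is assumed to enable parallelism; `numba.get_thread_id()` needs a recent numba release:

# Minimal, runnable sketch of the thread-local reduction pattern.
import numba
import numpy as np


@numba.njit(parallel=True)
def clipped_colsum(indices, data, n_cols, clip_val):
    # One accumulator row per thread: no two threads write the same row.
    local = np.zeros((numba.get_num_threads(), n_cols), dtype=np.float64)
    for i in numba.prange(len(data)):
        j = indices[i]
        local[numba.get_thread_id(), j] += min(np.float64(data[i]), clip_val[j])
    # Serial reduction keeps the combine step race-free.
    out = np.zeros(n_cols, dtype=np.float64)
    for t in range(local.shape[0]):
        for j in range(n_cols):
            out[j] += local[t, j]
    return out


rng = np.random.default_rng(0)
n_cols, nnz = 16, 10_000
indices = rng.integers(0, n_cols, nnz)
data = rng.random(nnz)
clip = np.full(n_cols, 0.5)

# Check against a pure-NumPy reference (equal up to float summation order).
expected = np.bincount(indices, weights=np.minimum(data, clip[indices]), minlength=n_cols)
np.testing.assert_allclose(clipped_colsum(indices, data, n_cols, clip), expected)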
