diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..32b9ca1a8f 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -83,7 +83,7 @@ // "psutil": [""] "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 - // "scikit-misc": [""], + "scikit-misc": [""], }, // Combinations of libraries/python versions can be excluded/included @@ -104,7 +104,7 @@ // - environment_type // Environment type, as above. // - sys_platform - // Platform, as in sys.platform. Possible values for the common + // Platform, as in sys.platform. Possible values for the commonπ // cases: 'linux2', 'win32', 'cygwin', 'darwin'. // // "exclude": [ diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 9633c8e208..318a8ce82c 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -34,7 +34,7 @@ class PreprocessingSuite: # noqa: D101 def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" - for dataset, layer in product(*self.params): + for dataset, layer in product(*self.params[:2]): adata, _ = get_dataset(dataset, layer=layer) adata.write_h5ad(f"{dataset}_{layer}.h5ad") @@ -47,17 +47,6 @@ def time_pca(self, *_) -> None: def peakmem_pca(self, *_) -> None: sc.pp.pca(self.adata, svd_solver="arpack") - def time_highly_variable_genes(self, *_) -> None: - # the default flavor runs on log-transformed data - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - - def peakmem_highly_variable_genes(self, *_) -> None: - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - # regress_out is very slow for this dataset @skip_when(dataset={"pbmc3k"}) def time_regress_out(self, *_) -> None: @@ -72,3 +61,29 @@ def time_scale(self, *_) -> None: def peakmem_scale(self, *_) -> None: sc.pp.scale(self.adata, max_value=10) + + +class HVGSuite: # noqa: D101 + params = (["seurat_v3", "cell_ranger", "seurat"],) + param_names = ("flavor",) + + def setup_cache(self) -> None: + """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" + adata, _ = get_dataset("pbmc3k") + adata.write_h5ad("pbmc3k.h5ad") + + def setup(self, flavor) -> None: + self.adata = ad.read_h5ad("pbmc3k.h5ad") + sc.pp.filter_genes(self.adata, min_cells=3) + self.flavor = flavor + + def time_highly_variable_genes(self, *_) -> None: + # the default flavor runs on log-transformed data + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) + + def peakmem_highly_variable_genes(self, *_) -> None: + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) diff --git a/pyproject.toml b/pyproject.toml index c61961606c..16fcae67d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,6 +198,8 @@ filterwarnings = [ "ignore:.*'(parseAll)'.*'(parse_all)':DeprecationWarning", # igraph vs leidenalg warning "ignore:The `igraph` implementation of leiden clustering:UserWarning", + "ignore:Detected unsupported threading environment:UserWarning", + "ignore:Cannot cache compiled function", ] [tool.coverage.run] diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 26d5a7763e..845b1fd21e 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -13,7 +13,7 @@ from fast_array_utils import stats from .. import logging as logg -from .._compat import CSBase, CSRBase, DaskArray, old_positionals, warn +from .._compat import CSBase, CSRBase, DaskArray, njit, old_positionals, warn from .._settings import Verbosity, settings from .._utils import ( check_nonnegative_integers, @@ -98,8 +98,7 @@ def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray] ) -# parallel=False needed for accuracy -@numba.njit(cache=True, parallel=False) # noqa: TID251 +@njit def _sum_and_sum_squares_clipped( indices: NDArray[np.integer], data: NDArray[np.floating], @@ -108,13 +107,34 @@ def _sum_and_sum_squares_clipped( clip_val: NDArray[np.float64], nnz: int, ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64) - batch_counts_sum = np.zeros(n_cols, dtype=np.float64) + """ + Parallel implementation using thread-local buffers to avoid race conditions. + + Previous implementation used parallel=False due to race condition on shared arrays. + This version uses explicit thread-local reduction to restore both correctness + and parallelism. + """ + # Thread-local accumulators for parallel reduction + n_threads = numba.get_num_threads() + squared_local = np.zeros((n_threads, n_cols), dtype=np.float64) + sum_local = np.zeros((n_threads, n_cols), dtype=np.float64) + + # Parallel accumulation into thread-local buffers (no race condition) for i in numba.prange(nnz): + tid = numba.get_thread_id() idx = indices[i] element = min(np.float64(data[i]), clip_val[idx]) - squared_batch_counts_sum[idx] += element**2 - batch_counts_sum[idx] += element + squared_local[tid, idx] += element**2 + sum_local[tid, idx] += element + + # Reduction phase: combine thread-local results + squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64) + batch_counts_sum = np.zeros(n_cols, dtype=np.float64) + + for t in range(n_threads): + for j in range(n_cols): + squared_batch_counts_sum[j] += squared_local[t, j] + batch_counts_sum[j] += sum_local[t, j] return squared_batch_counts_sum, batch_counts_sum