4 changes: 2 additions & 2 deletions benchmarks/asv.conf.json
```diff
@@ -83,7 +83,7 @@
// "psutil": [""]
"pooch": [""],
"scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
// "scikit-misc": [""],
"scikit-misc": [""],
},

// Combinations of libraries/python versions can be excluded/included
@@ -104,7 +104,7 @@
 // - environment_type
 //     Environment type, as above.
 // - sys_platform
-//     Platform, as in sys.platform. Possible values for the common
+//     Platform, as in sys.platform. Possible values for the common
 //     cases: 'linux2', 'win32', 'cygwin', 'darwin'.
 //
 // "exclude": [
```
39 changes: 27 additions & 12 deletions benchmarks/benchmarks/preprocessing_log.py
```diff
@@ -34,7 +34,7 @@ class PreprocessingSuite:  # noqa: D101
 
     def setup_cache(self) -> None:
         """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
-        for dataset, layer in product(*self.params):
+        for dataset, layer in product(*self.params[:2]):
             adata, _ = get_dataset(dataset, layer=layer)
             adata.write_h5ad(f"{dataset}_{layer}.h5ad")
 
@@ -47,17 +47,6 @@ def time_pca(self, *_) -> None:
     def peakmem_pca(self, *_) -> None:
         sc.pp.pca(self.adata, svd_solver="arpack")
 
-    def time_highly_variable_genes(self, *_) -> None:
-        # the default flavor runs on log-transformed data
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
-    def peakmem_highly_variable_genes(self, *_) -> None:
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
     # regress_out is very slow for this dataset
     @skip_when(dataset={"pbmc3k"})
     def time_regress_out(self, *_) -> None:
@@ -72,3 +61,29 @@ def time_scale(self, *_) -> None:
 
     def peakmem_scale(self, *_) -> None:
         sc.pp.scale(self.adata, max_value=10)
+
+
+class HVGSuite:  # noqa: D101
+    params = (["seurat_v3", "cell_ranger", "seurat"],)
+    param_names = ("flavor",)
+
+    def setup_cache(self) -> None:
+        """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
+        adata, _ = get_dataset("pbmc3k")
+        adata.write_h5ad("pbmc3k.h5ad")
+
+    def setup(self, flavor) -> None:
+        self.adata = ad.read_h5ad("pbmc3k.h5ad")
+        sc.pp.filter_genes(self.adata, min_cells=3)
+        self.flavor = flavor
+
+    def time_highly_variable_genes(self, *_) -> None:
+        # the default flavor runs on log-transformed data
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor
+        )
+
+    def peakmem_highly_variable_genes(self, *_) -> None:
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor
+        )
```
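Splitting the highly-variable-genes benchmarks into their own parameterized suite is what ties this PR together: `flavor="seurat_v3"` exercises the numba kernel rewritten below, and it needs scikit-misc at runtime, which is presumably why `scikit-misc` is re-enabled in `asv.conf.json` above. A minimal sketch of driving the new suite by hand, mimicking asv's `setup_cache` → `setup` → benchmark sequence (the import path and working directory are assumptions from the repo layout; under asv itself, something like `asv run -b HVGSuite` would select it):

```python
# Sketch: exercise HVGSuite outside the asv harness.
# Assumes cwd == benchmarks/benchmarks/ so the module and the cached
# pbmc3k.h5ad resolve the same way asv would see them.
from preprocessing_log import HVGSuite

suite = HVGSuite()
suite.setup_cache()  # downloads pbmc3k once and writes pbmc3k.h5ad
for flavor in suite.params[0]:
    suite.setup(flavor)  # fresh AnnData per flavor, as asv does
    suite.time_highly_variable_genes(flavor)
```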
2 changes: 2 additions & 0 deletions pyproject.toml
```diff
@@ -198,6 +198,8 @@ filterwarnings = [
"ignore:.*'(parseAll)'.*'(parse_all)':DeprecationWarning",
# igraph vs leidenalg warning
"ignore:The `igraph` implementation of leiden clustering:UserWarning",
"ignore:Detected unsupported threading environment:UserWarning",
"ignore:Cannot cache compiled function",
]

[tool.coverage.run]
Expand Down
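Both new filters target numba warnings: "Detected unsupported threading environment" and "Cannot cache compiled function" are what numba emits when its threading layer or on-disk cache is unavailable (e.g. in CI sandboxes), which the parallel, cached kernel below can presumably now trigger. A rough `warnings`-module equivalent of the two entries, for reference (pytest's ini-style filters treat the message as a literal prefix, so the regexes here are escaped literals anchored at the start):

```python
import re
import warnings

# Equivalent of "ignore:Detected unsupported threading environment:UserWarning"
warnings.filterwarnings(
    "ignore",
    message=re.escape("Detected unsupported threading environment"),
    category=UserWarning,
)
# Equivalent of "ignore:Cannot cache compiled function" (any warning category)
warnings.filterwarnings("ignore", message=re.escape("Cannot cache compiled function"))
```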
69 changes: 47 additions & 22 deletions src/scanpy/preprocessing/_highly_variable_genes.py
```diff
@@ -13,7 +13,7 @@
 from fast_array_utils import stats
 
 from .. import logging as logg
-from .._compat import CSBase, CSRBase, DaskArray, old_positionals, warn
+from .._compat import CSBase, CSRBase, DaskArray, njit, old_positionals, warn
 from .._settings import Verbosity, settings
 from .._utils import (
     check_nonnegative_integers,
@@ -92,31 +92,50 @@ def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     return _sum_and_sum_squares_clipped(
         batch_counts.indices,
         batch_counts.data,
+        batch_counts.indptr,
+        n_rows=batch_counts.shape[0],
         n_cols=batch_counts.shape[1],
         clip_val=clip_val,
-        nnz=batch_counts.nnz,
     )
 
 
-# parallel=False needed for accuracy
-@numba.njit(cache=True, parallel=False)  # noqa: TID251
+@njit
 def _sum_and_sum_squares_clipped(
-    indices: NDArray[np.integer],
-    data: NDArray[np.floating],
+    indices: np.ndarray,
+    data: np.ndarray,
+    indptr: np.ndarray,
     *,
+    n_rows: int,
     n_cols: int,
-    clip_val: NDArray[np.float64],
-    nnz: int,
-) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
-    squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
-    batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
-    for i in numba.prange(nnz):
-        idx = indices[i]
-        element = min(np.float64(data[i]), clip_val[idx])
-        squared_batch_counts_sum[idx] += element**2
-        batch_counts_sum[idx] += element
+    clip_val: np.ndarray,
+):
+    n_threads = numba.get_num_threads()
 
-    return squared_batch_counts_sum, batch_counts_sum
+    # Each thread gets its own private buffer to avoid race conditions
+    sum_local = np.zeros((n_threads, n_cols), dtype=np.float64)
+    squared_local = np.zeros((n_threads, n_cols), dtype=np.float64)
+
+    # We parallelize over the rows of the sparse matrix
+    for tid in numba.prange(n_threads):
+        for r in range(tid, n_rows, n_threads):
+            for i in range(indptr[r], indptr[r + 1]):
+                col_idx = indices[i]
+                val = np.float64(data[i])
+                element = min(val, clip_val[col_idx])
+                # Use the thread's private buffer slice
+                sum_local[tid, col_idx] += element
+                squared_local[tid, col_idx] += element**2
+
+    # Reduction phase (merging the thread buffers)
+    final_sum = np.zeros(n_cols, dtype=np.float64)
+    final_squared = np.zeros(n_cols, dtype=np.float64)
+
+    for t in range(n_threads):
+        for c in range(n_cols):
+            final_sum[c] += sum_local[t, c]
+            final_squared[c] += squared_local[t, c]
+
+    return final_squared, final_sum
```
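The old kernel parallelized over nonzeros, and its scattered writes to shared column slots are why it had to run with `parallel=False`; the rewrite partitions rows across threads, gives each thread a private row of accumulators, and merges them afterwards, making the accumulation race-free at the cost of an O(n_threads × n_cols) reduction. A minimal sketch (not part of the PR) checking the new kernel against a dense NumPy reference; scipy and the private import path are assumptions based on the file touched above:

```python
import numpy as np
import scipy.sparse as sp

from scanpy.preprocessing._highly_variable_genes import _sum_and_sum_squares_clipped

x = sp.random(200, 50, density=0.1, format="csr", random_state=0)
clip_val = np.full(x.shape[1], 0.5)  # force some clipping (values are U(0, 1))

squared, summed = _sum_and_sum_squares_clipped(
    x.indices,
    x.data,
    x.indptr,
    n_rows=x.shape[0],
    n_cols=x.shape[1],
    clip_val=clip_val,
)

# Dense reference: clipping explicit zeros is a no-op whenever clip_val >= 0,
# so column sums over the clipped dense matrix match the sparse kernel.
clipped = np.minimum(x.toarray(), clip_val)
np.testing.assert_allclose(summed, clipped.sum(axis=0))
np.testing.assert_allclose(squared, (clipped**2).sum(axis=0))
```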


```diff
@@ -176,13 +195,19 @@ def _highly_variable_genes_seurat_v3(  # noqa: PLR0912, PLR0915
     )
 
     norm_gene_vars = []
-    for b in np.unique(batch_info):
+    unique_batches = np.unique(batch_info)
+    n_batches = len(unique_batches)
+
+    for b in unique_batches:
         data_batch = data[batch_info == b]
 
-        mean, var = stats.mean_var(data_batch, axis=0, correction=1)
-        # These get computed anyway for loess
-        if isinstance(mean, DaskArray):
-            mean, var = mean.compute(), var.compute()
+        if n_batches > 1:
+            mean, var = stats.mean_var(data_batch, axis=0, correction=1)
+            # Compute Dask arrays since loess requires in-memory data
+            if isinstance(mean, DaskArray):
+                mean, var = mean.compute(), var.compute()
+        else:
+            mean, var = df["means"].to_numpy(), df["variances"].to_numpy()
         not_const = var > 0
         estimat_var = np.zeros(data.shape[1], dtype=np.float64)
```
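This last hunk reads as a small optimization on top: with a single batch, the per-batch mean and variance are just the whole-matrix statistics, which (judging from the `df["means"]` / `df["variances"]` lookups) the function has already computed earlier, so the `else` branch reuses them instead of calling `stats.mean_var` a second time. The Dask `.compute()` stays in the multi-batch path because loess needs in-memory arrays.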
