Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ def regress_out(
layer: str | None = None,
n_jobs: int | None = None,
copy: bool = False,
add_intercept: bool = False,
) -> AnnData | None:
"""Regress out (mostly) unwanted sources of variation.

Expand All @@ -686,14 +687,16 @@ def regress_out(
adata
The annotated data matrix.
keys
Keys for observation annotation on which to regress on.
Keys for observation annotation on which to regress.
layer
If provided, which element of layers to regress on.
If provided, which element of layers to use in regression.
n_jobs
Number of jobs for parallel computation.
`None` means using :attr:`scanpy.settings.n_jobs`.
copy
Determines whether a copy of `adata` is returned.
add_intercept
If `True`, add the fitted intercept back to the residuals in order to transform the results back into gene-count space.

Returns
-------
Expand Down Expand Up @@ -787,7 +790,10 @@ def regress_out(
# TODO: figure out how to test that this doesn't oversubscribe resources
res = Parallel(n_jobs=n_jobs)(
delayed(_regress_out_chunk)(
data_chunk, regres, variable_is_categorical=variable_is_categorical
data_chunk,
regres,
variable_is_categorical=variable_is_categorical,
add_intercept=add_intercept,
)
for data_chunk, regres in zip(chunk_list, regressors_chunk, strict=False)
)
Expand All @@ -806,6 +812,7 @@ def _regress_out_chunk(
regressors: pd.DataFrame | NDArray[np.floating],
*,
variable_is_categorical: bool,
add_intercept: bool,
) -> NDArray[np.floating]:
import statsmodels.api as sm
import statsmodels.tools.sm_exceptions as sme
Expand All @@ -822,6 +829,8 @@ def _regress_out_chunk(
regres = np.c_[np.ones(regressors.shape[0]), regressors[:, col_index]]
else:
regres = regressors
if add_intercept: # add constant to regres to get intercept in results
regres = sm.add_constant(regres)

try:
with warnings.catch_warnings():
Expand All @@ -830,6 +839,8 @@ def _regress_out_chunk(
data_chunk[:, col_index], regres, family=sm.families.Gaussian()
).fit()
new_column = result.resid_response
if add_intercept: # calculate result as resid + intercept
new_column += result.params.iloc[0]
except (sme.PerfectSeparationError, sme.PerfectSeparationWarning):
logg.warning("Encountered perfect separation, setting to 0 as in R.")
new_column = np.zeros(data_chunk.shape[0])
Expand Down
Loading