diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index d0e2dac04c..bd584d3b73 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -674,6 +674,7 @@ def regress_out( layer: str | None = None, n_jobs: int | None = None, copy: bool = False, + add_intercept: bool = False, ) -> AnnData | None: """Regress out (mostly) unwanted sources of variation. @@ -686,14 +687,16 @@ def regress_out( adata The annotated data matrix. keys - Keys for observation annotation on which to regress on. + Keys for observation annotation on which to regress. layer - If provided, which element of layers to regress on. + If provided, which element of layers to use in regression. n_jobs Number of jobs for parallel computation. `None` means using :attr:`scanpy.settings.n_jobs`. copy Determines whether a copy of `adata` is returned. + add_intercept + If True, regress_out will add intercept back to residuals in order to transform results back into gene-count space. Returns ------- @@ -787,7 +790,10 @@ def regress_out( # TODO: figure out how to test that this doesn't oversubscribe resources res = Parallel(n_jobs=n_jobs)( delayed(_regress_out_chunk)( - data_chunk, regres, variable_is_categorical=variable_is_categorical + data_chunk, + regres, + variable_is_categorical=variable_is_categorical, + add_intercept=add_intercept, ) for data_chunk, regres in zip(chunk_list, regressors_chunk, strict=False) ) @@ -806,6 +812,7 @@ def _regress_out_chunk( regressors: pd.DataFrame | NDArray[np.floating], *, variable_is_categorical: bool, + add_intercept: bool, ) -> NDArray[np.floating]: import statsmodels.api as sm import statsmodels.tools.sm_exceptions as sme @@ -822,6 +829,8 @@ def _regress_out_chunk( regres = np.c_[np.ones(regressors.shape[0]), regressors[:, col_index]] else: regres = regressors + if add_intercept: # add constant to regres to get intercept in results + regres = sm.add_constant(regres) try: with warnings.catch_warnings(): @@ -830,6 +839,8 @@ def _regress_out_chunk( data_chunk[:, col_index], regres, family=sm.families.Gaussian() ).fit() new_column = result.resid_response + if add_intercept: # calculate result as resid + intercept + new_column += result.params.iloc[0] except (sme.PerfectSeparationError, sme.PerfectSeparationWarning): logg.warning("Encountered perfect separation, setting to 0 as in R.") new_column = np.zeros(data_chunk.shape[0])