From 8d9dfef6f387dc0d3609b085477348ebca8202de Mon Sep 17 00:00:00 2001 From: Jonathan Dickinson Date: Sat, 7 Dec 2024 10:47:52 -0800 Subject: [PATCH 1/7] Add RLE plot to dds and utils --- pydeseq2/dds.py | 31 ++++++++++++++++++++++++ pydeseq2/utils.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index b421a504..4eda0010 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -21,6 +21,7 @@ from pydeseq2.preprocessing import deseq2_norm_transform from pydeseq2.utils import build_design_matrix from pydeseq2.utils import dispersion_trend +from pydeseq2.utils import make_rle_plot from pydeseq2.utils import make_scatter from pydeseq2.utils import mean_absolute_deviation from pydeseq2.utils import n_or_more_replicates @@ -1010,6 +1011,36 @@ def plot_dispersions( **kwargs, ) + def plot_rle( + self, + normalize: bool = False, + save_path: Optional[str] = None, + **kwargs, + ): + """Plot ratio of log expressions for each sample. + + Useful for visualizing sample to sample variation. + + Parameters + ---------- + normalize : bool, optional + Whether to normalize the counts before plotting. (default: ``False``). + + save_path : str or None + The path where to save the plot. If left None, the plot won't be saved + (default: ``None``). + + **kwargs + Keyword arguments for the scatter plot. + """ + make_rle_plot( + count_matrix=self.X, + normalize=normalize, + sample_ids=self.obsm["design_matrix"].index, + save_path=save_path, + **kwargs, + ) + def _fit_parametric_dispersion_trend(self, vst: bool = False): r"""Fit the dispersion curve according to a parametric model. diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 95d93ac1..12268fc0 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -1608,3 +1608,64 @@ def lowess( delta = (1 - delta**2) ** 2 return yest + + +def make_rle_plot( + count_matrix: np.array, + sample_ids: np.array, + normalize: bool = False, + save_path: Optional[str] = None, + **kwargs, +) -> None: + """ + Create a ratio of log expression plot using matplotlib. + + Parameters + ---------- + count_matrix : np.ndarray + An mxn matrix of count data, where m is the number of samples (rows), + and n is the number of genes (columns). + + sample_ids : np.ndarray + An array of sample identifiers. + + normalize : bool + Whether to normalize the count matrix before plotting. (default: ``False``). + + save_path : str or None + The path where to save the plot. If left None, the plot won't be saved + (default: ``None``). + + **kwargs : + Additional keyword arguments passed to matplotlib's boxplot function. + """ + if normalize: + print("Plotting normalized RLE plot...") + geometric_mean = np.exp(np.mean(np.log(count_matrix + 1), axis=0)) + size_factors = np.median(count_matrix / geometric_mean, axis=1) + count_matrix = count_matrix / size_factors[:, np.newaxis] + + plt.rcParams.update({"font.size": 10}) + + fig, ax = plt.subplots(figsize=(15, 8), dpi=600) + + # Calculate median expression across samples + gene_medians = np.median(count_matrix, axis=0) + rle_values = np.log2(count_matrix / gene_medians) + + kwargs.setdefault("alpha", 0.5) + boxprops = {"facecolor": "lightgray", "alpha": kwargs.pop("alpha")} + + ax.boxplot(rle_values.T, patch_artist=True, boxprops=boxprops, **kwargs) + + ax.axhline(0, color="red", linestyle="--", linewidth=1, alpha=0.5, zorder=3) + ax.set_xlabel("Sample") + ax.set_ylabel("Relative Log Expression") + ax.set_xticks(np.arange(len(sample_ids))) + ax.set_xticklabels(sample_ids, rotation=90) + plt.tight_layout() + + if save_path: + plt.savefig(save_path, bbox_inches="tight") + else: + plt.show() From 7012b83fc137dc6be76be8640a3086a3ba6b5d2e Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Wed, 11 Dec 2024 09:12:57 +0100 Subject: [PATCH 2/7] docs: fix typehint --- pydeseq2/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 12268fc0..874e026d 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -1622,11 +1622,11 @@ def make_rle_plot( Parameters ---------- - count_matrix : np.ndarray + count_matrix : ndarray An mxn matrix of count data, where m is the number of samples (rows), and n is the number of genes (columns). - sample_ids : np.ndarray + sample_ids : ndarray An array of sample identifiers. normalize : bool From 9c60aee4523d6f669ba640fa550da4eacf18bc9f Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Tue, 17 Dec 2024 09:53:49 +0100 Subject: [PATCH 3/7] test: add plot_rle test --- tests/test_pydeseq2.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_pydeseq2.py b/tests/test_pydeseq2.py index 18e57011..b4563238 100644 --- a/tests/test_pydeseq2.py +++ b/tests/test_pydeseq2.py @@ -875,3 +875,16 @@ def assert_res_almost_equal(py_res, r_res, tol=0.02): ).max() < tol assert (abs(r_res.pvalue - py_res.pvalue) / r_res.pvalue).max() < tol assert (abs(r_res.padj - py_res.padj) / r_res.padj).max() < tol + + +def test_plot_rle(train_counts, train_metadata): + """Test that the RLE plot is generated without error.""" + + dds = DeseqDataSet( + counts=train_counts, + metadata=train_metadata, + design_factors="condition", + ) + + dds.plot_rle(normalize=False) + dds.plot_rle(normalize=True) From 5141dea862cc5455d24aef0ee8737ce097625718 Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Tue, 17 Dec 2024 09:54:04 +0100 Subject: [PATCH 4/7] refactor: use obs_names --- pydeseq2/dds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index 4eda0010..a1fd390e 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -1036,7 +1036,7 @@ def plot_rle( make_rle_plot( count_matrix=self.X, normalize=normalize, - sample_ids=self.obsm["design_matrix"].index, + sample_ids=self.obs_names, save_path=save_path, **kwargs, ) From 6e797e2df14d559c74239f8df75be6a54cdbe1dd Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Tue, 17 Dec 2024 09:56:42 +0100 Subject: [PATCH 5/7] refactor: fix type hints and remove print --- pydeseq2/dds.py | 4 ++-- pydeseq2/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index a1fd390e..5f1bfae8 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -1014,7 +1014,7 @@ def plot_dispersions( def plot_rle( self, normalize: bool = False, - save_path: Optional[str] = None, + save_path: str | None = None, **kwargs, ): """Plot ratio of log expressions for each sample. @@ -1026,7 +1026,7 @@ def plot_rle( normalize : bool, optional Whether to normalize the counts before plotting. (default: ``False``). - save_path : str or None + save_path : str, optional The path where to save the plot. If left None, the plot won't be saved (default: ``None``). diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 874e026d..8b9f4404 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -1614,7 +1614,7 @@ def make_rle_plot( count_matrix: np.array, sample_ids: np.array, normalize: bool = False, - save_path: Optional[str] = None, + save_path: str | None = None, **kwargs, ) -> None: """ @@ -1632,7 +1632,7 @@ def make_rle_plot( normalize : bool Whether to normalize the count matrix before plotting. (default: ``False``). - save_path : str or None + save_path : str, optional The path where to save the plot. If left None, the plot won't be saved (default: ``None``). From d563a3e1d6f374447c52c5beddc94deaab988bca Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Tue, 17 Dec 2024 09:59:00 +0100 Subject: [PATCH 6/7] docs: add plots to docs --- docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst | 1 + docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst | 1 + pydeseq2/dds.py | 2 +- pydeseq2/utils.py | 1 - 4 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst b/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst index e4388bba..376125d2 100644 --- a/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst +++ b/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst @@ -18,6 +18,7 @@ ~DeseqDataSet.fit_genewise_dispersions ~DeseqDataSet.fit_size_factors ~DeseqDataSet.plot_dispersions + ~DeseqDataSet.plot_rle ~DeseqDataSet.refit ~DeseqDataSet.vst diff --git a/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst b/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst index 22275d83..f55afe29 100644 --- a/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst +++ b/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst @@ -12,6 +12,7 @@ ~DeseqStats.lfc_shrink ~DeseqStats.run_wald_test ~DeseqStats.summary + ~DeseqStats.plot_MA diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index 5f1bfae8..eb759d2c 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -1017,7 +1017,7 @@ def plot_rle( save_path: str | None = None, **kwargs, ): - """Plot ratio of log expressions for each sample. + """Plot ratio of log expressions (RLE) for each sample. Useful for visualizing sample to sample variation. diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 8b9f4404..f4e9f80d 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -1640,7 +1640,6 @@ def make_rle_plot( Additional keyword arguments passed to matplotlib's boxplot function. """ if normalize: - print("Plotting normalized RLE plot...") geometric_mean = np.exp(np.mean(np.log(count_matrix + 1), axis=0)) size_factors = np.median(count_matrix / geometric_mean, axis=1) count_matrix = count_matrix / size_factors[:, np.newaxis] From 6c93fed3271f7fe1703fe9eb78582ad858aac6d1 Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Tue, 17 Dec 2024 10:05:22 +0100 Subject: [PATCH 7/7] fix: test --- tests/test_pydeseq2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pydeseq2.py b/tests/test_pydeseq2.py index b4563238..2ba15821 100644 --- a/tests/test_pydeseq2.py +++ b/tests/test_pydeseq2.py @@ -883,7 +883,7 @@ def test_plot_rle(train_counts, train_metadata): dds = DeseqDataSet( counts=train_counts, metadata=train_metadata, - design_factors="condition", + design="~condition", ) dds.plot_rle(normalize=False)