From 1a4fae7cc4c66d52b96cd4a60b5d01df00944aa3 Mon Sep 17 00:00:00 2001 From: Maren Buettner Date: Thu, 2 May 2024 09:38:03 -0700 Subject: [PATCH 1/4] :sparkles: add colorblocks to baseplot --- scanpy/plotting/_baseplot_class.py | 170 ++++++++++++++++++++++++++++- 1 file changed, 167 insertions(+), 3 deletions(-) diff --git a/scanpy/plotting/_baseplot_class.py b/scanpy/plotting/_baseplot_class.py index d3b3acff02..4b81cfc60c 100644 --- a/scanpy/plotting/_baseplot_class.py +++ b/scanpy/plotting/_baseplot_class.py @@ -9,8 +9,12 @@ from warnings import warn import numpy as np +import pandas as pd from matplotlib import gridspec from matplotlib import pyplot as plt +from matplotlib.colors import BoundaryNorm, ListedColormap +from matplotlib import colormaps + from .. import logging as logg from .._compat import old_positionals @@ -19,9 +23,10 @@ if TYPE_CHECKING: from anndata import AnnData - from matplotlib.axes import Axes + from matplotlib.axes import Axes from matplotlib.colors import Normalize + _VarNames = Union[str, Sequence[str]] doc_common_groupby_plot_args = """\ @@ -384,6 +389,103 @@ def add_totals( "color": color, } return self + + def add_colorblocks( + self, + show: bool | None = True, + sort: Literal["ascending", "descending"] = None, + size: float | None = 0.5, + color: ColorLike | Sequence[ColorLike] | None = None, + ) -> BasePlot: + r"""\ + Show colorblocks of `groupby` category. + + The annotation colorblocks is by default shown on the right side of the plot or on top + if the axes are swapped. + + + Parameters + ---------- + show + Boolean to turn on (True) or off (False) 'add_colorblocks' + sort + Set to either 'ascending' or 'descending' to reorder the categories + by category name + size + size of the annotation colorblocks. Corresponds to width when shown on + the right of the plot, or height when shown on top. The unit is the same + as in matplotlib (inches). + color + Color for the bar plots or list of colors for each of the bar plots. + By default, each bar plot uses the colors assigned in + `adata.uns[{groupby}_colors]`. + + + Returns + ------- + Returns `self` for method chaining. + + + Examples + -------- + >>> import scanpy as sc + >>> adata = sc.datasets.pbmc68k_reduced() + >>> markers = {'T-cell': 'CD3D', 'B-cell': 'CD79A', 'myeloid': 'CST3'} + >>> plot = sc.pl._baseplot_class.BasePlot(adata, markers, groupby='bulk_labels').add_colorblocks() + >>> plot.plot_group_extra['counts_df'] # doctest: +SKIP + bulk_labels + CD4+/CD25 T Reg 68 + CD4+/CD45RA+/CD25- Naive T 8 + CD4+/CD45RO+ Memory 19 + CD8+ Cytotoxic T 54 + CD8+/CD45RA+ Naive Cytotoxic 43 + CD14+ Monocyte 129 + CD19+ B 95 + CD34+ 13 + CD56+ NK 31 + Dendritic 240 + Name: count, dtype: int64 + """ + self.group_extra_size = size + + if not show: + # hide colorblocks + self.plot_group_extra = None + self.group_extra_size = 0 + return self + + _sort = True if sort is not None else False + _ascending = True if sort == "ascending" else False + #counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending) + + # determine groupby label positions such that they appear + # centered next/below to the color code rectangle assigned to the category + value_sum = 0 + #ticks = [] # list of centered position of the labels + labels = [] + label2code = {} # dictionary of numerical values asigned to each label + for code, (label, value) in enumerate( + self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending).items() + ): + #ticks.append(value_sum + (value / 2)) + labels.append(label) + #value_sum += value + label2code[label] = code + + counts_df = pd.Series(label2code) + + if _sort: + self.categories_order = counts_df.index + + self.plot_group_extra = { + "kind": "group_colors", + "width": size, + "sort": sort, + "counts_df": counts_df, + "color": color, + } + return self + @old_positionals("cmap") def style(self, *, cmap: str | None = DEFAULT_COLORMAP) -> BasePlot: @@ -457,6 +559,62 @@ def get_axes(self) -> dict[str, Axes]: self.make_figure() return self.ax_dict + def _plot_colorblocks( + self, group_color_ax: Axes, orientation: Literal["top", "right"] + ): + """ + Makes the annotation plot for group labels + """ + + params = self.plot_group_extra + counts_df = params["counts_df"] + if self.categories_order is not None: + counts_df = counts_df.loc[self.categories_order] + if params["color"] is None: + if f"{self.groupby}_colors" in self.adata.uns: + color = ListedColormap(self.adata.uns[f"{self.groupby}_colors"], + f"{self.groupby}_cmap") + else: + color = plt.get_cmap("tab20") + else: + if params["color"] in list(colormaps): + color = plt.get_cmap(params["color"]) + else: #if it is a list of colors + color = ListedColormap(params["color"], f"{self.groupby}_cmap") + #color scaling + norm = BoundaryNorm(np.arange(color.N + 1) - 0.5, color.N) + + if orientation == "top": + group_color_ax.imshow( + counts_df.values.reshape(1,-1), + aspect="auto", + extent = [0, len(counts_df), 1, 0], + cmap=color, + norm=norm + ) + + elif orientation == "right": + group_color_ax.imshow( + counts_df.values.reshape(-1,1), + aspect="auto", + extent = [0, 1, len(counts_df), 0], + cmap=color, + norm=norm + ) + + # remove x/y ticks and labels + group_color_ax.tick_params(axis="x", bottom=False, labelbottom=False) + group_color_ax.tick_params(axis="y", bottom=False, labelbottom=False) + + # remove surrounding lines + group_color_ax.spines["right"].set_visible(False) + group_color_ax.spines["top"].set_visible(False) + group_color_ax.spines["left"].set_visible(False) + group_color_ax.spines["bottom"].set_visible(False) + + group_color_ax.grid(False) + group_color_ax.axis("off") + def _plot_totals( self, total_barplot_ax: Axes, orientation: Literal["top", "right"] ): @@ -723,11 +881,14 @@ def make_figure(self): # second row is for brackets (if needed), # third row is for mainplot and dendrogram/totals (legend goes in gs[0,1] # defined earlier) + wspace_adj = self.wspace if self.plot_group_extra is None and self.are_axes_swapped is False else 0.05 + hspace_adj = 0.05 if self.plot_group_extra is not None and self.are_axes_swapped is True else 0 + mainplot_gs = gridspec.GridSpecFromSubplotSpec( nrows=3, ncols=2, - wspace=self.wspace, - hspace=0.0, + wspace=wspace_adj, + hspace=hspace_adj, subplot_spec=gs[0, 0], width_ratios=width_ratios, height_ratios=height_ratios, @@ -761,6 +922,9 @@ def make_figure(self): ) if self.plot_group_extra["kind"] == "group_totals": self._plot_totals(group_extra_ax, group_extra_orientation) + if self.plot_group_extra["kind"] == "group_colors": + self._plot_colorblocks(group_extra_ax, group_extra_orientation) + return_ax_dict["group_extra_ax"] = group_extra_ax From 1bb912d119771f5e91587cc083febcb92dfb00ed Mon Sep 17 00:00:00 2001 From: Maren Buettner Date: Thu, 2 May 2024 09:43:03 -0700 Subject: [PATCH 2/4] :lipstick: Run precommit --- scanpy/plotting/_baseplot_class.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/scanpy/plotting/_baseplot_class.py b/scanpy/plotting/_baseplot_class.py index 4b81cfc60c..5767771d74 100644 --- a/scanpy/plotting/_baseplot_class.py +++ b/scanpy/plotting/_baseplot_class.py @@ -416,7 +416,7 @@ def add_colorblocks( the right of the plot, or height when shown on top. The unit is the same as in matplotlib (inches). color - Color for the bar plots or list of colors for each of the bar plots. + Colormap or list of colors for each of the colorblocks. By default, each bar plot uses the colors assigned in `adata.uns[{groupby}_colors]`. @@ -433,18 +433,17 @@ def add_colorblocks( >>> markers = {'T-cell': 'CD3D', 'B-cell': 'CD79A', 'myeloid': 'CST3'} >>> plot = sc.pl._baseplot_class.BasePlot(adata, markers, groupby='bulk_labels').add_colorblocks() >>> plot.plot_group_extra['counts_df'] # doctest: +SKIP - bulk_labels - CD4+/CD25 T Reg 68 - CD4+/CD45RA+/CD25- Naive T 8 - CD4+/CD45RO+ Memory 19 - CD8+ Cytotoxic T 54 - CD8+/CD45RA+ Naive Cytotoxic 43 - CD14+ Monocyte 129 - CD19+ B 95 - CD34+ 13 - CD56+ NK 31 - Dendritic 240 - Name: count, dtype: int64 + CD4+/CD25 T Reg 0 + CD4+/CD45RA+/CD25- Naive T 1 + CD4+/CD45RO+ Memory 2 + CD8+ Cytotoxic T 3 + CD8+/CD45RA+ Naive Cytotoxic 4 + CD14+ Monocyte 5 + CD19+ B 6 + CD34+ 7 + CD56+ NK 8 + Dendritic 9 + dtype: int64 """ self.group_extra_size = size @@ -460,8 +459,6 @@ def add_colorblocks( # determine groupby label positions such that they appear # centered next/below to the color code rectangle assigned to the category - value_sum = 0 - #ticks = [] # list of centered position of the labels labels = [] label2code = {} # dictionary of numerical values asigned to each label for code, (label, value) in enumerate( From abeb58b98a339f722ed98c0f5113b476d40273a6 Mon Sep 17 00:00:00 2001 From: Maren Buettner Date: Thu, 2 May 2024 09:45:58 -0700 Subject: [PATCH 3/4] :lipstick: more checks --- scanpy/plotting/_baseplot_class.py | 51 ++++++++++++++++-------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/scanpy/plotting/_baseplot_class.py b/scanpy/plotting/_baseplot_class.py index 5767771d74..58c395a4b7 100644 --- a/scanpy/plotting/_baseplot_class.py +++ b/scanpy/plotting/_baseplot_class.py @@ -10,11 +10,9 @@ import numpy as np import pandas as pd -from matplotlib import gridspec +from matplotlib import colormaps, gridspec from matplotlib import pyplot as plt from matplotlib.colors import BoundaryNorm, ListedColormap -from matplotlib import colormaps - from .. import logging as logg from .._compat import old_positionals @@ -23,7 +21,7 @@ if TYPE_CHECKING: from anndata import AnnData - from matplotlib.axes import Axes + from matplotlib.axes import Axes from matplotlib.colors import Normalize @@ -389,7 +387,7 @@ def add_totals( "color": color, } return self - + def add_colorblocks( self, show: bool | None = True, @@ -455,7 +453,7 @@ def add_colorblocks( _sort = True if sort is not None else False _ascending = True if sort == "ascending" else False - #counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending) + # counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending) # determine groupby label positions such that they appear # centered next/below to the color code rectangle assigned to the category @@ -464,9 +462,9 @@ def add_colorblocks( for code, (label, value) in enumerate( self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending).items() ): - #ticks.append(value_sum + (value / 2)) + # ticks.append(value_sum + (value / 2)) labels.append(label) - #value_sum += value + # value_sum += value label2code[label] = code counts_df = pd.Series(label2code) @@ -483,7 +481,6 @@ def add_colorblocks( } return self - @old_positionals("cmap") def style(self, *, cmap: str | None = DEFAULT_COLORMAP) -> BasePlot: """\ @@ -562,41 +559,42 @@ def _plot_colorblocks( """ Makes the annotation plot for group labels """ - + params = self.plot_group_extra counts_df = params["counts_df"] if self.categories_order is not None: counts_df = counts_df.loc[self.categories_order] if params["color"] is None: if f"{self.groupby}_colors" in self.adata.uns: - color = ListedColormap(self.adata.uns[f"{self.groupby}_colors"], - f"{self.groupby}_cmap") + color = ListedColormap( + self.adata.uns[f"{self.groupby}_colors"], f"{self.groupby}_cmap" + ) else: color = plt.get_cmap("tab20") else: if params["color"] in list(colormaps): color = plt.get_cmap(params["color"]) - else: #if it is a list of colors + else: # if it is a list of colors color = ListedColormap(params["color"], f"{self.groupby}_cmap") - #color scaling + # color scaling norm = BoundaryNorm(np.arange(color.N + 1) - 0.5, color.N) if orientation == "top": group_color_ax.imshow( - counts_df.values.reshape(1,-1), + counts_df.values.reshape(1, -1), aspect="auto", - extent = [0, len(counts_df), 1, 0], + extent=[0, len(counts_df), 1, 0], cmap=color, - norm=norm + norm=norm, ) elif orientation == "right": group_color_ax.imshow( - counts_df.values.reshape(-1,1), + counts_df.values.reshape(-1, 1), aspect="auto", - extent = [0, 1, len(counts_df), 0], + extent=[0, 1, len(counts_df), 0], cmap=color, - norm=norm + norm=norm, ) # remove x/y ticks and labels @@ -878,8 +876,16 @@ def make_figure(self): # second row is for brackets (if needed), # third row is for mainplot and dendrogram/totals (legend goes in gs[0,1] # defined earlier) - wspace_adj = self.wspace if self.plot_group_extra is None and self.are_axes_swapped is False else 0.05 - hspace_adj = 0.05 if self.plot_group_extra is not None and self.are_axes_swapped is True else 0 + wspace_adj = ( + self.wspace + if self.plot_group_extra is None and self.are_axes_swapped is False + else 0.05 + ) + hspace_adj = ( + 0.05 + if self.plot_group_extra is not None and self.are_axes_swapped is True + else 0 + ) mainplot_gs = gridspec.GridSpecFromSubplotSpec( nrows=3, @@ -921,7 +927,6 @@ def make_figure(self): self._plot_totals(group_extra_ax, group_extra_orientation) if self.plot_group_extra["kind"] == "group_colors": self._plot_colorblocks(group_extra_ax, group_extra_orientation) - return_ax_dict["group_extra_ax"] = group_extra_ax From efcf8df4ba98b38150e3a39a4b7bf4dd18df5a54 Mon Sep 17 00:00:00 2001 From: Maren Buettner Date: Thu, 2 May 2024 13:27:19 -0700 Subject: [PATCH 4/4] :fire: remove redundant code --- docs/release-notes/1.10.2.md | 1 + scanpy/plotting/_baseplot_class.py | 10 ---------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/docs/release-notes/1.10.2.md b/docs/release-notes/1.10.2.md index c0822c5200..1524ab8044 100644 --- a/docs/release-notes/1.10.2.md +++ b/docs/release-notes/1.10.2.md @@ -4,6 +4,7 @@ ``` * Add performance benchmarking {pr}`2977` {smaller}`R Shrestha`, {smaller}`P Angerer` +* Add annotation colorblock to Baseplot {pr}`3043` {smaller}`M Büttner` ```{rubric} Docs ``` diff --git a/scanpy/plotting/_baseplot_class.py b/scanpy/plotting/_baseplot_class.py index 58c395a4b7..96664092df 100644 --- a/scanpy/plotting/_baseplot_class.py +++ b/scanpy/plotting/_baseplot_class.py @@ -597,16 +597,6 @@ def _plot_colorblocks( norm=norm, ) - # remove x/y ticks and labels - group_color_ax.tick_params(axis="x", bottom=False, labelbottom=False) - group_color_ax.tick_params(axis="y", bottom=False, labelbottom=False) - - # remove surrounding lines - group_color_ax.spines["right"].set_visible(False) - group_color_ax.spines["top"].set_visible(False) - group_color_ax.spines["left"].set_visible(False) - group_color_ax.spines["bottom"].set_visible(False) - group_color_ax.grid(False) group_color_ax.axis("off")