diff --git a/docs/release-notes/3042.feature.md b/docs/release-notes/3042.feature.md new file mode 100644 index 0000000000..6545501263 --- /dev/null +++ b/docs/release-notes/3042.feature.md @@ -0,0 +1 @@ +Add annotation colorblock to Baseplot {pr}`3043` {smaller}`M Büttner` diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index c829f36efb..7a6b9e12ac 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -7,8 +7,10 @@ from warnings import warn import numpy as np +import pandas as pd from matplotlib import colormaps, gridspec from matplotlib import pyplot as plt +from matplotlib.colors import BoundaryNorm, ListedColormap from .. import logging as logg from .._compat import old_positionals @@ -26,7 +28,6 @@ from collections.abc import Sequence from typing import Literal, Self - import pandas as pd from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import Colormap, Normalize @@ -409,6 +410,101 @@ def add_totals( } return self + def add_colorblocks( + self, + *, + show: bool | None = True, + sort: Literal["ascending", "descending"] = None, + size: float | None = 0.5, + color: ColorLike | Sequence[ColorLike] | None = None, + ) -> BasePlot: + r"""Show colorblocks of `groupby` category. + + The annotation colorblocks is by default shown on the right side of the plot or on top + if the axes are swapped. + + + Parameters + ---------- + show + Boolean to turn on (True) or off (False) 'add_colorblocks' + sort + Set to either 'ascending' or 'descending' to reorder the categories + by category name + size + size of the annotation colorblocks. Corresponds to width when shown on + the right of the plot, or height when shown on top. The unit is the same + as in matplotlib (inches). + color + Colormap or list of colors for each of the colorblocks. + By default, each bar plot uses the colors assigned in + `adata.uns[{groupby}_colors]`. + + + Returns + ------- + Returns `self` for method chaining. + + + Examples + -------- + >>> import scanpy as sc + >>> adata = sc.datasets.pbmc68k_reduced() + >>> markers = {"T-cell": "CD3D", "B-cell": "CD79A", "myeloid": "CST3"} + >>> plot = sc.pl._baseplot_class.BasePlot( + ... adata, markers, groupby="bulk_labels" + ... ).add_colorblocks() + >>> plot.plot_group_extra["counts_df"] # doctest: +SKIP + CD4+/CD25 T Reg 0 + CD4+/CD45RA+/CD25- Naive T 1 + CD4+/CD45RO+ Memory 2 + CD8+ Cytotoxic T 3 + CD8+/CD45RA+ Naive Cytotoxic 4 + CD14+ Monocyte 5 + CD19+ B 6 + CD34+ 7 + CD56+ NK 8 + Dendritic 9 + dtype: int64 + """ + self.group_extra_size = size + + if not show: + # hide colorblocks + self.plot_group_extra = None + self.group_extra_size = 0 + return self + + _sort = sort is not None + _ascending = sort == "ascending" + # counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending) + + # determine groupby label positions such that they appear + # centered next/below to the color code rectangle assigned to the category + labels = [] + label2code = {} # dictionary of numerical values asigned to each label + for code, (label, value) in enumerate( + self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending).items() + ): + # ticks.append(value_sum + (value / 2)) + labels.append(label) + # value_sum += value + label2code[label] = code + + counts_df = pd.Series(label2code) + + if _sort: + self.categories_order = counts_df.index + + self.plot_group_extra = { + "kind": "group_colors", + "width": size, + "sort": sort, + "counts_df": counts_df, + "color": color, + } + return self + @old_positionals("cmap") def style(self, *, cmap: Colormap | str | None | Empty = _empty) -> Self: r"""Set visual style parameters. @@ -480,6 +576,50 @@ def get_axes(self) -> dict[str, Axes]: self.make_figure() return self.ax_dict + def _plot_colorblocks( + self, group_color_ax: Axes, orientation: Literal["top", "right"] + ): + """Make the annotation plot for group labels.""" + params = self.plot_group_extra + counts_df = params["counts_df"] + if self.categories_order is not None: + counts_df = counts_df.loc[self.categories_order] + if params["color"] is None: + if f"{self.groupby}_colors" in self.adata.uns: + color = ListedColormap( + self.adata.uns[f"{self.groupby}_colors"], f"{self.groupby}_cmap" + ) + else: + color = plt.get_cmap("tab20") + else: + if params["color"] in list(colormaps): + color = plt.get_cmap(params["color"]) + else: # if it is a list of colors + color = ListedColormap(params["color"], f"{self.groupby}_cmap") + # color scaling + norm = BoundaryNorm(np.arange(color.N + 1) - 0.5, color.N) + + if orientation == "top": + group_color_ax.imshow( + counts_df.values.reshape(1, -1), + aspect="auto", + extent=[0, len(counts_df), 1, 0], + cmap=color, + norm=norm, + ) + + elif orientation == "right": + group_color_ax.imshow( + counts_df.values.reshape(-1, 1), + aspect="auto", + extent=[0, 1, len(counts_df), 0], + cmap=color, + norm=norm, + ) + + group_color_ax.grid(visible=False) + group_color_ax.axis("off") + def _plot_totals( self, total_barplot_ax: Axes, orientation: Literal["top", "right"] ): @@ -740,11 +880,22 @@ def make_figure(self): # second row is for brackets (if needed), # third row is for mainplot and dendrogram/totals (legend goes in gs[0,1] # defined earlier) + wspace_adj = ( + self.wspace + if self.plot_group_extra is None and self.are_axes_swapped is False + else 0.05 + ) + hspace_adj = ( + 0.05 + if self.plot_group_extra is not None and self.are_axes_swapped is True + else 0 + ) + mainplot_gs = gridspec.GridSpecFromSubplotSpec( nrows=3, ncols=2, - wspace=self.wspace, - hspace=0.0, + wspace=wspace_adj, + hspace=hspace_adj, subplot_spec=gs[0, 0], width_ratios=width_ratios, height_ratios=height_ratios, @@ -778,6 +929,8 @@ def make_figure(self): ) if self.plot_group_extra["kind"] == "group_totals": self._plot_totals(group_extra_ax, group_extra_orientation) + if self.plot_group_extra["kind"] == "group_colors": + self._plot_colorblocks(group_extra_ax, group_extra_orientation) return_ax_dict["group_extra_ax"] = group_extra_ax