Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3042.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add annotation colorblock to Baseplot {pr}`3043` {smaller}`M Büttner`
159 changes: 156 additions & 3 deletions src/scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from warnings import warn

import numpy as np
import pandas as pd
from matplotlib import colormaps, gridspec
from matplotlib import pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap

from .. import logging as logg
from .._compat import old_positionals
Expand All @@ -26,7 +28,6 @@
from collections.abc import Sequence
from typing import Literal, Self

import pandas as pd
from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.colors import Colormap, Normalize
Expand Down Expand Up @@ -409,6 +410,101 @@ def add_totals(
}
return self

def add_colorblocks(
self,
*,
show: bool | None = True,
sort: Literal["ascending", "descending"] = None,
size: float | None = 0.5,
color: ColorLike | Sequence[ColorLike] | None = None,
) -> BasePlot:
r"""Show colorblocks of `groupby` category.

The annotation colorblocks is by default shown on the right side of the plot or on top
if the axes are swapped.


Parameters
----------
show
Boolean to turn on (True) or off (False) 'add_colorblocks'
sort
Set to either 'ascending' or 'descending' to reorder the categories
by category name
size
size of the annotation colorblocks. Corresponds to width when shown on
the right of the plot, or height when shown on top. The unit is the same
as in matplotlib (inches).
color
Colormap or list of colors for each of the colorblocks.
By default, each bar plot uses the colors assigned in
`adata.uns[{groupby}_colors]`.


Returns
-------
Returns `self` for method chaining.


Examples
--------
>>> import scanpy as sc
>>> adata = sc.datasets.pbmc68k_reduced()
>>> markers = {"T-cell": "CD3D", "B-cell": "CD79A", "myeloid": "CST3"}
>>> plot = sc.pl._baseplot_class.BasePlot(
... adata, markers, groupby="bulk_labels"
... ).add_colorblocks()
>>> plot.plot_group_extra["counts_df"] # doctest: +SKIP
CD4+/CD25 T Reg 0
CD4+/CD45RA+/CD25- Naive T 1
CD4+/CD45RO+ Memory 2
CD8+ Cytotoxic T 3
CD8+/CD45RA+ Naive Cytotoxic 4
CD14+ Monocyte 5
CD19+ B 6
CD34+ 7
CD56+ NK 8
Dendritic 9
dtype: int64
"""
self.group_extra_size = size

if not show:
# hide colorblocks
self.plot_group_extra = None
self.group_extra_size = 0
return self

_sort = sort is not None
_ascending = sort == "ascending"
# counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending)

# determine groupby label positions such that they appear
# centered next/below to the color code rectangle assigned to the category
labels = []
label2code = {} # dictionary of numerical values asigned to each label
for code, (label, value) in enumerate(
self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending).items()
):
# ticks.append(value_sum + (value / 2))
labels.append(label)
# value_sum += value
label2code[label] = code

counts_df = pd.Series(label2code)

if _sort:
self.categories_order = counts_df.index

self.plot_group_extra = {
"kind": "group_colors",
"width": size,
"sort": sort,
"counts_df": counts_df,
"color": color,
}
return self

@old_positionals("cmap")
def style(self, *, cmap: Colormap | str | None | Empty = _empty) -> Self:
r"""Set visual style parameters.
Expand Down Expand Up @@ -480,6 +576,50 @@ def get_axes(self) -> dict[str, Axes]:
self.make_figure()
return self.ax_dict

def _plot_colorblocks(
self, group_color_ax: Axes, orientation: Literal["top", "right"]
):
"""Make the annotation plot for group labels."""
params = self.plot_group_extra
counts_df = params["counts_df"]
if self.categories_order is not None:
counts_df = counts_df.loc[self.categories_order]
if params["color"] is None:
if f"{self.groupby}_colors" in self.adata.uns:
color = ListedColormap(
self.adata.uns[f"{self.groupby}_colors"], f"{self.groupby}_cmap"
)
else:
color = plt.get_cmap("tab20")
else:
if params["color"] in list(colormaps):
color = plt.get_cmap(params["color"])
else: # if it is a list of colors
color = ListedColormap(params["color"], f"{self.groupby}_cmap")
# color scaling
norm = BoundaryNorm(np.arange(color.N + 1) - 0.5, color.N)

if orientation == "top":
group_color_ax.imshow(
counts_df.values.reshape(1, -1),
aspect="auto",
extent=[0, len(counts_df), 1, 0],
cmap=color,
norm=norm,
)

elif orientation == "right":
group_color_ax.imshow(
counts_df.values.reshape(-1, 1),
aspect="auto",
extent=[0, 1, len(counts_df), 0],
cmap=color,
norm=norm,
)

group_color_ax.grid(visible=False)
group_color_ax.axis("off")

def _plot_totals(
self, total_barplot_ax: Axes, orientation: Literal["top", "right"]
):
Expand Down Expand Up @@ -740,11 +880,22 @@ def make_figure(self):
# second row is for brackets (if needed),
# third row is for mainplot and dendrogram/totals (legend goes in gs[0,1]
# defined earlier)
wspace_adj = (
self.wspace
if self.plot_group_extra is None and self.are_axes_swapped is False
else 0.05
)
hspace_adj = (
0.05
if self.plot_group_extra is not None and self.are_axes_swapped is True
else 0
)

mainplot_gs = gridspec.GridSpecFromSubplotSpec(
nrows=3,
ncols=2,
wspace=self.wspace,
hspace=0.0,
wspace=wspace_adj,
hspace=hspace_adj,
subplot_spec=gs[0, 0],
width_ratios=width_ratios,
height_ratios=height_ratios,
Expand Down Expand Up @@ -778,6 +929,8 @@ def make_figure(self):
)
if self.plot_group_extra["kind"] == "group_totals":
self._plot_totals(group_extra_ax, group_extra_orientation)
if self.plot_group_extra["kind"] == "group_colors":
self._plot_colorblocks(group_extra_ax, group_extra_orientation)

return_ax_dict["group_extra_ax"] = group_extra_ax

Expand Down
Loading