Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 80 additions & 42 deletions scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def clustermap(
@doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args)
def stacked_violin(adata, var_names, groupby=None, log=False, use_raw=None, num_categories=7,
figsize=None, dendrogram=False, gene_symbols=None,
var_group_positions=None, var_group_labels=None,
var_group_positions=None, var_group_labels=None, standard_scale=None,
var_group_rotation=None, layer=None, stripplot=False, jitter=False, size=1,
scale='width', order=None, swap_axes=False, show=None, save=None,
row_palette='muted', **kwds):
Expand Down Expand Up @@ -796,6 +796,9 @@ def stacked_violin(adata, var_names, groupby=None, log=False, use_raw=None, num_
should be a valid seaborn palette name or a valid matplotlib colormap
(see https://seaborn.pydata.org/generated/seaborn.color_palette.html). Alternatively,
a single color name or hex value can be passed. E.g. 'red' or '#cc33ff'
standard_scale : {{'var', 'obs'}}, optional (default: None)
Whether or not to standardize that dimension between 0 and 1, meaning for each variable or observation,
subtract the minimum and divide each by its maximum.
swap_axes: `bool`, optional (default: `False`)
By default, the x axis contains `var_names` (e.g. genes) and the y axis the `groupby` categories.
By setting `swap_axes` then x are the `groupby` categories and y the `var_names`. When swapping
Expand All @@ -822,6 +825,18 @@ def stacked_violin(adata, var_names, groupby=None, log=False, use_raw=None, num_
categories, obs_tidy = _prepare_dataframe(adata, var_names, groupby, use_raw, log, num_categories,
gene_symbols=gene_symbols, layer=layer)

if standard_scale == 'obs':
obs_tidy = obs_tidy.sub(obs_tidy.min(1), axis=0)
obs_tidy = obs_tidy.div(obs_tidy.max(1), axis=0).fillna(0)
elif standard_scale == 'var':
obs_tidy -= obs_tidy.min(0)
obs_tidy /= obs_tidy.max(0).fillna(0)
elif standard_scale is None:
pass
else:
logg.warn('Unknown type for standard_scale, ignored')


if 'color' in kwds:
row_palette = kwds['color']
# remove color from kwds in case is set to avoid an error caused by
Expand Down Expand Up @@ -1045,7 +1060,8 @@ def rename_cols_to_int(value):
@doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args)
def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categories=7,
dendrogram=False, gene_symbols=None, var_group_positions=None, var_group_labels=None,
var_group_rotation=None, layer=None, swap_axes=False, show_gene_labels=None, show=None, save=None, figsize=None, **kwds):
var_group_rotation=None, layer=None, standard_scale=None, swap_axes=False,
show_gene_labels=None, show=None, save=None, figsize=None, **kwds):
"""\
Heatmap of the expression values of genes.

Expand All @@ -1058,6 +1074,9 @@ def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
Parameters
----------
{common_plot_args}
standard_scale : {{'var', 'obs'}}, optional (default: None)
Whether or not to standardize that dimension between 0 and 1, meaning for each variable or observation,
subtract the minimum and divide each by its maximum.
swap_axes: `bool`, optional (default: `False`)
By default, the x axis contains `var_names` (e.g. genes) and the y axis the `groupby`
categories (if any). By setting `swap_axes` then x are the `groupby` categories and y the `var_names`.
Expand All @@ -1080,26 +1099,21 @@ def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
if use_raw is None and adata.raw is not None: use_raw = True
if isinstance(var_names, str):
var_names = [var_names]
if not use_raw and layer is None:
# this most likely will used a scaled version of the data
# and thus is better to use a diverging scale
param_set = False
if 'vmin' not in kwds:
kwds['vmin'] = -3
param_set = True
if 'vmax' not in kwds:
kwds['vmax'] = 3
param_set = True
if 'cmap' not in kwds:
kwds['cmap'] = 'bwr'
param_set = True
if param_set:
logg.info('Divergent color map has been automatically set to plot non-raw data. Use '
'`vmin`, `vmax` and `cmap` to adjust the plot.')

categories, obs_tidy = _prepare_dataframe(adata, var_names, groupby, use_raw, log, num_categories,
gene_symbols=gene_symbols, layer=layer)

if standard_scale == 'obs':
obs_tidy = obs_tidy.sub(obs_tidy.min(1), axis=0)
obs_tidy = obs_tidy.div(obs_tidy.max(1), axis=0).fillna(0)
elif standard_scale == 'var':
obs_tidy -= obs_tidy.min(0)
obs_tidy /= obs_tidy.max(0).fillna(0)
elif standard_scale is None:
pass
else:
logg.warn('Unknown type for standard_scale, ignored')

if groupby is None or len(categories) <= 1:
categorical = False
# dendrogram can only be computed between groupby categories
Expand Down Expand Up @@ -1305,7 +1319,7 @@ def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
@doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args)
def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categories=7,
color_map='Reds', dot_max=None, dot_min=None, figsize=None, dendrogram=False,
gene_symbols=None, var_group_positions=None,
gene_symbols=None, var_group_positions=None, standard_scale=None, smallest_dot=0.,
var_group_labels=None, var_group_rotation=None, layer=None, show=None, save=None, **kwds):
"""\
Makes a *dot plot* of the expression values of `var_names`.
Expand Down Expand Up @@ -1334,6 +1348,12 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
If none, the minimum dot size is set to 0. If given,
the value should be a number between 0 and 1. All fractions smaller than dot_min are clipped to
this value.
standard_scale : {{'var', 'group'}}, optional (default: None)
Whether or not to standardize that dimension between 0 and 1, meaning for each variable or group,
subtract the minimum and divide each by its maximum.
smallest_dot : `float` optional (default: 0.)
If none, the smallest dot has size 0. All expression levels with `dot_min` are plotted with
`smallest_dot` dot size.

{show_save_ax}
**kwds : keyword arguments
Expand Down Expand Up @@ -1362,6 +1382,17 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
# 1. compute mean value
mean_obs = obs_tidy.groupby(level=0).mean()

if standard_scale == 'group':
mean_obs = mean_obs.sub(mean_obs.min(1), axis=0)
mean_obs = mean_obs.div(mean_obs.max(1), axis=0).fillna(0)
elif standard_scale == 'var':
mean_obs -= mean_obs.min(0)
mean_obs /= mean_obs.max(0).fillna(0)
elif standard_scale is None:
pass
else:
logg.warn('Unknown type for standard_scale, ignored')

# 2. compute fraction of cells having value >0
# transform obs_tidy into boolean matrix
obs_bool = obs_tidy.astype(bool)
Expand Down Expand Up @@ -1489,6 +1520,7 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
frac = ((frac - dot_min) / old_range)

size = (frac * 10) ** 2
size += smallest_dot
import matplotlib.colors

normalize = matplotlib.colors.Normalize(vmin=kwds.get('vmin'), vmax=kwds.get('vmax'))
Expand Down Expand Up @@ -1542,13 +1574,18 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
else:
fracs_values = fracs_legends
size = (fracs_values * 10) ** 2
size += smallest_dot
color = [cmap(normalize(value)) for value in np.repeat(max(mean_flat) * 0.7, len(size))]

# plot size bar
size_legend = fig.add_subplot(axs3[0])

size_legend.scatter(np.repeat(0, len(size)), range(len(size)), s=size, color=color)
size_legend.set_yticks(range(len(size)))
labels = ["{:.0%}".format(x) for x in fracs_legends]
if dot_max < 1:
labels[-1] = ">" + labels[-1]
size_legend.set_yticklabels(labels)
size_legend.set_yticklabels(["{:.0%}".format(x) for x in fracs_legends])

size_legend.tick_params(axis='y', left=False, labelleft=False, labelright=True)
Expand All @@ -1573,7 +1610,9 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor
@doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args)
def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categories=7,
figsize=None, dendrogram=False, gene_symbols=None, var_group_positions=None, var_group_labels=None,
var_group_rotation=None, layer=None, swap_axes=False, show=None, save=None, **kwds):
var_group_rotation=None, layer=None, standard_scale=None, swap_axes=False, show=None,
save=None, **kwds):

"""\
Creates a heatmap of the mean expression values per cluster of each var_names
If groupby is not given, the matrixplot assumes that all data belongs to a single
Expand All @@ -1582,6 +1621,9 @@ def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_cate
Parameters
----------
{common_plot_args}
standard_scale : {{'var', 'group'}}, optional (default: None)
Whether or not to standardize that dimension between 0 and 1, meaning for each variable or group,
subtract the minimum and divide each by its maximum.
{show_save_ax}
**kwds : keyword arguments
Are passed to `matplotlib.pyplot.pcolor`.
Expand All @@ -1601,22 +1643,6 @@ def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_cate
if use_raw is None and adata.raw is not None: use_raw = True
if isinstance(var_names, str):
var_names = [var_names]
if use_raw is False:
# this most likely will used a scaled version of the data
# and thus is better to use a diverging scale
param_set = False
if 'vmin' not in kwds:
kwds['vmin'] = -3
param_set = True
if 'vmax' not in kwds:
kwds['vmax'] = 3
param_set = True
if 'cmap' not in kwds:
kwds['cmap'] = 'bwr'
param_set = True
if param_set:
logg.info('Divergent color map has been automatically set to plot non-raw data. Use '
'`vmin`, `vmax` and `cmap` to adjust the plot.')

categories, obs_tidy = _prepare_dataframe(adata, var_names, groupby, use_raw, log, num_categories,
gene_symbols=gene_symbols, layer=layer)
Expand All @@ -1626,6 +1652,17 @@ def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_cate

mean_obs = obs_tidy.groupby(level=0).mean()

if standard_scale == 'group':
mean_obs = mean_obs.sub(mean_obs.min(1), axis=0)
mean_obs = mean_obs.div(mean_obs.max(1), axis=0).fillna(0)
elif standard_scale == 'var':
mean_obs -= mean_obs.min(0)
mean_obs /= mean_obs.max(0).fillna(0)
elif standard_scale is None:
pass
else:
logg.warn('Unknown type for standard_scale, ignored')

if dendrogram:
dendro_data = _reorder_categories_after_dendrogram(adata, groupby, dendrogram,
var_names=var_names,
Expand Down Expand Up @@ -1850,7 +1887,7 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False,
groupby_height = 0.24
num_rows = len(var_names) + 2 # +1 because of dendrogram on top and categories at bottom
if figsize is None:
width = 10
width = 12
track_height = 0.25
else:
width, height = figsize
Expand All @@ -1862,7 +1899,7 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False,
obs_tidy = obs_tidy.T

fig = pl.figure(figsize=(width, height))
axs = gridspec.GridSpec(ncols=2, nrows=num_rows, wspace=0.3 / width,
axs = gridspec.GridSpec(ncols=2, nrows=num_rows, wspace=1.0 / width,
hspace=0, height_ratios=height_ratios,
width_ratios=[width, 0.14])
axs_list = []
Expand All @@ -1877,7 +1914,8 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False,
axs_list.append(ax)
for cat_idx, category in enumerate(categories):
x_start, x_end = x_values[cat_idx]
ax.fill_between(range(x_start, x_end), 0, obs_tidy.iloc[idx, x_start:x_end], lw=0.1, color=groupby_colors[cat_idx])
ax.fill_between(range(x_start, x_end), 0, obs_tidy.iloc[idx, x_start:x_end], lw=0.1,
color=groupby_colors[cat_idx])

# remove the xticks labels except for the last processed plot.
# Because the plots share the x axis it is redundant and less compact to plot the
Expand All @@ -1894,9 +1932,9 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False,
ymin, ymax = ax.get_ylim()
ymax = int(ymax)
ax.set_yticks([ymax])
tt = ax.set_yticklabels([str(ymax)], ha='right', va='top')

ax.tick_params(axis='y', labelsize='x-small', right=True, left=False, pad=-5,
tt = ax.set_yticklabels([str(ymax)], ha='left', va='top')
ax.spines['right'].set_position(('axes', 1.01))
ax.tick_params(axis='y', labelsize='x-small', right=True, left=False, length=2,
which='both', labelright=True, labelleft=False, direction='in')
ax.set_ylabel(var, rotation=0, fontsize='small', ha='right', va='bottom')
ax.yaxis.set_label_coords(-0.005, 0.1)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_heatmap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_heatmap2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_heatmap_gene_symbols.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_heatmap_swap_axes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_matrixplot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_matrixplot2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_matrixplot_gene_symbols.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_matrixplot_swap_axes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_ranked_genes_tracksplot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_tracksplot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_tracksplot_gene_symbols.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 30 additions & 1 deletion scanpy/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ def test_heatmap():
num_categories=4, figsize=(4.5, 5), show=False)
save_and_compare_images('master_heatmap2')

# test var/obs standardization and layer
adata.layers['test'] = -1 * adata.X.copy()
sc.pl.heatmap(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False,
standard_scale='var', layer='test')
save_and_compare_images('master_heatmap_std_scale_var', tolerance=15)

sc.pl.heatmap(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False,
standard_scale='obs')
save_and_compare_images('master_heatmap_std_scale_obs', tolerance=15)


def test_dotplot():
adata = sc.datasets.krumsiek11()
Expand All @@ -71,6 +81,15 @@ def test_dotplot():
figsize=(7, 2.5), dendrogram=True, show=False)
save_and_compare_images('master_dotplot3', tolerance=15)

# test var/group standardization smallest_dot
sc.pl.dotplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False,
standard_scale='var', smallest_dot=40)
save_and_compare_images('master_dotplot_std_scale_var', tolerance=15)

sc.pl.dotplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False,
standard_scale='group', smallest_dot=10)
save_and_compare_images('master_dotplot_std_scale_group', tolerance=15)


def test_matrixplot():
adata = sc.datasets.krumsiek11()
Expand All @@ -81,6 +100,16 @@ def test_matrixplot():
sc.pl.matrixplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False, swap_axes=True)
save_and_compare_images('master_matrixplot_swap_axes', tolerance=15)

# test var/group standardization and layer
adata.layers['test'] = -1 * adata.X.copy()
sc.pl.matrixplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True,
show=False, standard_scale='var', layer='test', cmap='Blues_r')
save_and_compare_images('master_matrixplot_std_scale_var', tolerance=15)

sc.pl.matrixplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True,
show=False, standard_scale='group', swap_axes=True)
save_and_compare_images('master_matrixplot_std_scale_group', tolerance=15)

# test matrixplot numeric column and alternative cmap
adata.obs['Gata2'] = adata.X[:, 0]
sc.pl.matrixplot(adata, adata.var_names, 'Gata2', use_raw=False,
Expand Down Expand Up @@ -150,7 +179,7 @@ def test_rank_genes_groups():

# test ranked genes using heatmap (swap_axes=True show_gene_labels=False)
sc.pl.rank_genes_groups_heatmap(pbmc, n_genes=20, swap_axes=True, use_raw=False,
show_gene_labels=False, show=False)
show_gene_labels=False, show=False, vmin=-3, vmax=3, cmap='bwr')
save_and_compare_images('master_ranked_genes_heatmap_swap_axes', tolerance=tolerance)

# test ranked genes using stacked violin plots
Expand Down