diff --git a/scanpy/plotting/_anndata.py b/scanpy/plotting/_anndata.py index ef6ae9119b..688eae82a8 100755 --- a/scanpy/plotting/_anndata.py +++ b/scanpy/plotting/_anndata.py @@ -761,7 +761,7 @@ def clustermap( @doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args) def stacked_violin(adata, var_names, groupby=None, log=False, use_raw=None, num_categories=7, figsize=None, dendrogram=False, gene_symbols=None, - var_group_positions=None, var_group_labels=None, + var_group_positions=None, var_group_labels=None, standard_scale=None, var_group_rotation=None, layer=None, stripplot=False, jitter=False, size=1, scale='width', order=None, swap_axes=False, show=None, save=None, row_palette='muted', **kwds): @@ -796,6 +796,9 @@ def stacked_violin(adata, var_names, groupby=None, log=False, use_raw=None, num_ should be a valid seaborn palette name or a valic matplotlib colormap (see https://seaborn.pydata.org/generated/seaborn.color_palette.html). Alternatively, a single color name or hex value can be passed. E.g. 'red' or '#cc33ff' + standard_scale : {{'var', 'obs'}}, optional (default: None) + Whether or not to standardize that dimension between 0 and 1, meaning for each variable or observation, + subtract the minimum and divide each by its maximum. swap_axes: `bool`, optional (default: `False`) By default, the x axis contains `var_names` (e.g. genes) and the y axis the `groupby` categories. By setting `swap_axes` then x are the `groupby` categories and y the `var_names`. When swapping @@ -822,6 +825,18 @@ def stacked_violin(adata, var_names, groupby=None, log=False, use_raw=None, num_ categories, obs_tidy = _prepare_dataframe(adata, var_names, groupby, use_raw, log, num_categories, gene_symbols=gene_symbols, layer=layer) + if standard_scale == 'obs': + obs_tidy = obs_tidy.sub(obs_tidy.min(1), axis=0) + obs_tidy = obs_tidy.div(obs_tidy.max(1), axis=0).fillna(0) + elif standard_scale == 'var': + obs_tidy -= obs_tidy.min(0) + obs_tidy /= obs_tidy.max(0).fillna(0) + elif standard_scale is None: + pass + else: + logg.warn('Unknown type for standard_scale, ignored') + + if 'color' in kwds: row_palette = kwds['color'] # remove color from kwds in case is set to avoid an error caused by @@ -1045,7 +1060,8 @@ def rename_cols_to_int(value): @doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args) def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categories=7, dendrogram=False, gene_symbols=None, var_group_positions=None, var_group_labels=None, - var_group_rotation=None, layer=None, swap_axes=False, show_gene_labels=None, show=None, save=None, figsize=None, **kwds): + var_group_rotation=None, layer=None, standard_scale=None, swap_axes=False, + show_gene_labels=None, show=None, save=None, figsize=None, **kwds): """\ Heatmap of the expression values of genes. @@ -1058,6 +1074,9 @@ def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categor Parameters ---------- {common_plot_args} + standard_scale : {{'var', 'obs'}}, optional (default: None) + Whether or not to standardize that dimension between 0 and 1, meaning for each variable or observation, + subtract the minimum and divide each by its maximum. swap_axes: `bool`, optional (default: `False`) By default, the x axis contains `var_names` (e.g. genes) and the y axis the `groupby` categories (if any). By setting `swap_axes` then x are the `groupby` categories and y the `var_names`. @@ -1080,26 +1099,21 @@ def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categor if use_raw is None and adata.raw is not None: use_raw = True if isinstance(var_names, str): var_names = [var_names] - if not use_raw and layer is None: - # this most likely will used a scaled version of the data - # and thus is better to use a diverging scale - param_set = False - if 'vmin' not in kwds: - kwds['vmin'] = -3 - param_set = True - if 'vmax' not in kwds: - kwds['vmax'] = 3 - param_set = True - if 'cmap' not in kwds: - kwds['cmap'] = 'bwr' - param_set = True - if param_set: - logg.info('Divergent color map has been automatically set to plot non-raw data. Use ' - '`vmin`, `vmax` and `cmap` to adjust the plot.') categories, obs_tidy = _prepare_dataframe(adata, var_names, groupby, use_raw, log, num_categories, gene_symbols=gene_symbols, layer=layer) + if standard_scale == 'obs': + obs_tidy = obs_tidy.sub(obs_tidy.min(1), axis=0) + obs_tidy = obs_tidy.div(obs_tidy.max(1), axis=0).fillna(0) + elif standard_scale == 'var': + obs_tidy -= obs_tidy.min(0) + obs_tidy /= obs_tidy.max(0).fillna(0) + elif standard_scale is None: + pass + else: + logg.warn('Unknown type for standard_scale, ignored') + if groupby is None or len(categories) <= 1: categorical = False # dendrogram can only be computed between groupby categories @@ -1305,7 +1319,7 @@ def heatmap(adata, var_names, groupby=None, use_raw=None, log=False, num_categor @doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args) def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categories=7, color_map='Reds', dot_max=None, dot_min=None, figsize=None, dendrogram=False, - gene_symbols=None, var_group_positions=None, + gene_symbols=None, var_group_positions=None, standard_scale=None, smallest_dot=0., var_group_labels=None, var_group_rotation=None, layer=None, show=None, save=None, **kwds): """\ Makes a *dot plot* of the expression values of `var_names`. @@ -1334,6 +1348,12 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor If none, the minimum dot size is set to 0. If given, the value should be a number between 0 and 1. All fractions smaller than dot_min are clipped to this value. + standard_scale : {{'var', 'group'}}, optional (default: None) + Whether or not to standardize that dimension between 0 and 1, meaning for each variable or group, + subtract the minimum and divide each by its maximum. + smallest_dot : `float` optional (default: 0.) + If none, the smallest dot has size 0. All expression levels with `dot_min` are potted with + `smallest_dot` dot size. {show_save_ax} **kwds : keyword arguments @@ -1362,6 +1382,17 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor # 1. compute mean value mean_obs = obs_tidy.groupby(level=0).mean() + if standard_scale == 'group': + mean_obs = mean_obs.sub(mean_obs.min(1), axis=0) + mean_obs = mean_obs.div(mean_obs.max(1), axis=0).fillna(0) + elif standard_scale == 'var': + mean_obs -= mean_obs.min(0) + mean_obs /= mean_obs.max(0).fillna(0) + elif standard_scale is None: + pass + else: + logg.warn('Unknown type for standard_scale, ignored') + # 2. compute fraction of cells having value >0 # transform obs_tidy into boolean matrix obs_bool = obs_tidy.astype(bool) @@ -1489,6 +1520,7 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor frac = ((frac - dot_min) / old_range) size = (frac * 10) ** 2 + size += smallest_dot import matplotlib.colors normalize = matplotlib.colors.Normalize(vmin=kwds.get('vmin'), vmax=kwds.get('vmax')) @@ -1542,6 +1574,7 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor else: fracs_values = fracs_legends size = (fracs_values * 10) ** 2 + size += smallest_dot color = [cmap(normalize(value)) for value in np.repeat(max(mean_flat) * 0.7, len(size))] # plot size bar @@ -1549,6 +1582,10 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor size_legend.scatter(np.repeat(0, len(size)), range(len(size)), s=size, color=color) size_legend.set_yticks(range(len(size))) + labels = ["{:.0%}".format(x) for x in fracs_legends] + if dot_max < 1: + labels[-1] = ">" + labels[-1] + size_legend.set_yticklabels(labels) size_legend.set_yticklabels(["{:.0%}".format(x) for x in fracs_legends]) size_legend.tick_params(axis='y', left=False, labelleft=False, labelright=True) @@ -1573,7 +1610,9 @@ def dotplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categor @doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args) def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_categories=7, figsize=None, dendrogram=False, gene_symbols=None, var_group_positions=None, var_group_labels=None, - var_group_rotation=None, layer=None, swap_axes=False, show=None, save=None, **kwds): + var_group_rotation=None, layer=None, standard_scale=None, swap_axes=False, show=None, + save=None, **kwds): + """\ Creates a heatmap of the mean expression values per cluster of each var_names If groupby is not given, the matrixplot assumes that all data belongs to a single @@ -1582,6 +1621,9 @@ def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_cate Parameters ---------- {common_plot_args} + standard_scale : {{'var', 'group'}}, optional (default: None) + Whether or not to standardize that dimension between 0 and 1, meaning for each variable or group, + subtract the minimum and divide each by its maximum. {show_save_ax} **kwds : keyword arguments Are passed to `matplotlib.pyplot.pcolor`. @@ -1601,22 +1643,6 @@ def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_cate if use_raw is None and adata.raw is not None: use_raw = True if isinstance(var_names, str): var_names = [var_names] - if use_raw is False: - # this most likely will used a scaled version of the data - # and thus is better to use a diverging scale - param_set = False - if 'vmin' not in kwds: - kwds['vmin'] = -3 - param_set = True - if 'vmax' not in kwds: - kwds['vmax'] = 3 - param_set = True - if 'cmap' not in kwds: - kwds['cmap'] = 'bwr' - param_set = True - if param_set: - logg.info('Divergent color map has been automatically set to plot non-raw data. Use ' - '`vmin`, `vmax` and `cmap` to adjust the plot.') categories, obs_tidy = _prepare_dataframe(adata, var_names, groupby, use_raw, log, num_categories, gene_symbols=gene_symbols, layer=layer) @@ -1626,6 +1652,17 @@ def matrixplot(adata, var_names, groupby=None, use_raw=None, log=False, num_cate mean_obs = obs_tidy.groupby(level=0).mean() + if standard_scale == 'group': + mean_obs = mean_obs.sub(mean_obs.min(1), axis=0) + mean_obs = mean_obs.div(mean_obs.max(1), axis=0).fillna(0) + elif standard_scale == 'var': + mean_obs -= mean_obs.min(0) + mean_obs /= mean_obs.max(0).fillna(0) + elif standard_scale is None: + pass + else: + logg.warn('Unknown type for standard_scale, ignored') + if dendrogram: dendro_data = _reorder_categories_after_dendrogram(adata, groupby, dendrogram, var_names=var_names, @@ -1850,7 +1887,7 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False, groupby_height = 0.24 num_rows = len(var_names) + 2 # +1 because of dendrogram on top and categories at bottom if figsize is None: - width = 10 + width = 12 track_height = 0.25 else: width, height = figsize @@ -1862,7 +1899,7 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False, obs_tidy = obs_tidy.T fig = pl.figure(figsize=(width, height)) - axs = gridspec.GridSpec(ncols=2, nrows=num_rows, wspace=0.3 / width, + axs = gridspec.GridSpec(ncols=2, nrows=num_rows, wspace=1.0 / width, hspace=0, height_ratios=height_ratios, width_ratios=[width, 0.14]) axs_list = [] @@ -1877,7 +1914,8 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False, axs_list.append(ax) for cat_idx, category in enumerate(categories): x_start, x_end = x_values[cat_idx] - ax.fill_between(range(x_start, x_end), 0, obs_tidy.iloc[idx, x_start:x_end], lw=0.1, color=groupby_colors[cat_idx]) + ax.fill_between(range(x_start, x_end), 0, obs_tidy.iloc[idx, x_start:x_end], lw=0.1, + color=groupby_colors[cat_idx]) # remove the xticks labels except for the last processed plot. # Because the plots share the x axis it is redundant and less compact to plot the @@ -1894,9 +1932,9 @@ def tracksplot(adata, var_names, groupby, use_raw=None, log=False, ymin, ymax = ax.get_ylim() ymax = int(ymax) ax.set_yticks([ymax]) - tt = ax.set_yticklabels([str(ymax)], ha='right', va='top') - - ax.tick_params(axis='y', labelsize='x-small', right=True, left=False, pad=-5, + tt = ax.set_yticklabels([str(ymax)], ha='left', va='top') + ax.spines['right'].set_position(('axes', 1.01)) + ax.tick_params(axis='y', labelsize='x-small', right=True, left=False, length=2, which='both', labelright=True, labelleft=False, direction='in') ax.set_ylabel(var, rotation=0, fontsize='small', ha='right', va='bottom') ax.yaxis.set_label_coords(-0.005, 0.1) diff --git a/scanpy/tests/_images/master_dotplot_std_scale_group.png b/scanpy/tests/_images/master_dotplot_std_scale_group.png new file mode 100644 index 0000000000..2a40a1fa33 Binary files /dev/null and b/scanpy/tests/_images/master_dotplot_std_scale_group.png differ diff --git a/scanpy/tests/_images/master_dotplot_std_scale_var.png b/scanpy/tests/_images/master_dotplot_std_scale_var.png new file mode 100644 index 0000000000..0f4d53c040 Binary files /dev/null and b/scanpy/tests/_images/master_dotplot_std_scale_var.png differ diff --git a/scanpy/tests/_images/master_heatmap.png b/scanpy/tests/_images/master_heatmap.png index 78c5c9fbf7..60c735164c 100644 Binary files a/scanpy/tests/_images/master_heatmap.png and b/scanpy/tests/_images/master_heatmap.png differ diff --git a/scanpy/tests/_images/master_heatmap2.png b/scanpy/tests/_images/master_heatmap2.png index 157677e366..71bc4ee97d 100644 Binary files a/scanpy/tests/_images/master_heatmap2.png and b/scanpy/tests/_images/master_heatmap2.png differ diff --git a/scanpy/tests/_images/master_heatmap_gene_symbols.png b/scanpy/tests/_images/master_heatmap_gene_symbols.png index 1f0fb3a266..72e9863e61 100644 Binary files a/scanpy/tests/_images/master_heatmap_gene_symbols.png and b/scanpy/tests/_images/master_heatmap_gene_symbols.png differ diff --git a/scanpy/tests/_images/master_heatmap_std_scale_obs.png b/scanpy/tests/_images/master_heatmap_std_scale_obs.png new file mode 100644 index 0000000000..c08270611c Binary files /dev/null and b/scanpy/tests/_images/master_heatmap_std_scale_obs.png differ diff --git a/scanpy/tests/_images/master_heatmap_std_scale_var.png b/scanpy/tests/_images/master_heatmap_std_scale_var.png new file mode 100644 index 0000000000..c64c298ba0 Binary files /dev/null and b/scanpy/tests/_images/master_heatmap_std_scale_var.png differ diff --git a/scanpy/tests/_images/master_heatmap_swap_axes.png b/scanpy/tests/_images/master_heatmap_swap_axes.png index 8d3713394c..a051526ae9 100644 Binary files a/scanpy/tests/_images/master_heatmap_swap_axes.png and b/scanpy/tests/_images/master_heatmap_swap_axes.png differ diff --git a/scanpy/tests/_images/master_matrixplot.png b/scanpy/tests/_images/master_matrixplot.png index e042fb105d..e6dbad60d0 100644 Binary files a/scanpy/tests/_images/master_matrixplot.png and b/scanpy/tests/_images/master_matrixplot.png differ diff --git a/scanpy/tests/_images/master_matrixplot2.png b/scanpy/tests/_images/master_matrixplot2.png index ecf7a39570..a4aefbfc32 100644 Binary files a/scanpy/tests/_images/master_matrixplot2.png and b/scanpy/tests/_images/master_matrixplot2.png differ diff --git a/scanpy/tests/_images/master_matrixplot_gene_symbols.png b/scanpy/tests/_images/master_matrixplot_gene_symbols.png index 3205f8c8d7..6f04ca21c5 100644 Binary files a/scanpy/tests/_images/master_matrixplot_gene_symbols.png and b/scanpy/tests/_images/master_matrixplot_gene_symbols.png differ diff --git a/scanpy/tests/_images/master_matrixplot_std_scale_group.png b/scanpy/tests/_images/master_matrixplot_std_scale_group.png new file mode 100644 index 0000000000..b7e29e47e6 Binary files /dev/null and b/scanpy/tests/_images/master_matrixplot_std_scale_group.png differ diff --git a/scanpy/tests/_images/master_matrixplot_std_scale_var.png b/scanpy/tests/_images/master_matrixplot_std_scale_var.png new file mode 100644 index 0000000000..4dbebadf30 Binary files /dev/null and b/scanpy/tests/_images/master_matrixplot_std_scale_var.png differ diff --git a/scanpy/tests/_images/master_matrixplot_swap_axes.png b/scanpy/tests/_images/master_matrixplot_swap_axes.png index a1139af08f..70bd223f2e 100644 Binary files a/scanpy/tests/_images/master_matrixplot_swap_axes.png and b/scanpy/tests/_images/master_matrixplot_swap_axes.png differ diff --git a/scanpy/tests/_images/master_ranked_genes_tracksplot.png b/scanpy/tests/_images/master_ranked_genes_tracksplot.png index 8f2fe934c5..20ea4f65d0 100644 Binary files a/scanpy/tests/_images/master_ranked_genes_tracksplot.png and b/scanpy/tests/_images/master_ranked_genes_tracksplot.png differ diff --git a/scanpy/tests/_images/master_tracksplot.png b/scanpy/tests/_images/master_tracksplot.png index f6e002d43b..c43f7e6c03 100644 Binary files a/scanpy/tests/_images/master_tracksplot.png and b/scanpy/tests/_images/master_tracksplot.png differ diff --git a/scanpy/tests/_images/master_tracksplot_gene_symbols.png b/scanpy/tests/_images/master_tracksplot_gene_symbols.png index f2a6e102d9..622a88dc24 100644 Binary files a/scanpy/tests/_images/master_tracksplot_gene_symbols.png and b/scanpy/tests/_images/master_tracksplot_gene_symbols.png differ diff --git a/scanpy/tests/test_plotting.py b/scanpy/tests/test_plotting.py index d19464d9d1..4c1296c15d 100644 --- a/scanpy/tests/test_plotting.py +++ b/scanpy/tests/test_plotting.py @@ -48,6 +48,16 @@ def test_heatmap(): num_categories=4, figsize=(4.5, 5), show=False) save_and_compare_images('master_heatmap2') + # test var/obs standardization and layer + adata.layers['test'] = -1 * adata.X.copy() + sc.pl.heatmap(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False, + standard_scale='var', layer='test') + save_and_compare_images('master_heatmap_std_scale_var', tolerance=15) + + sc.pl.heatmap(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False, + standard_scale='obs') + save_and_compare_images('master_heatmap_std_scale_obs', tolerance=15) + def test_dotplot(): adata = sc.datasets.krumsiek11() @@ -71,6 +81,15 @@ def test_dotplot(): figsize=(7, 2.5), dendrogram=True, show=False) save_and_compare_images('master_dotplot3', tolerance=15) + # test var/group standardization smallest_dot + sc.pl.dotplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False, + standard_scale='var', smallest_dot=40) + save_and_compare_images('master_dotplot_std_scale_var', tolerance=15) + + sc.pl.dotplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False, + standard_scale='group', smallest_dot=10) + save_and_compare_images('master_dotplot_std_scale_group', tolerance=15) + def test_matrixplot(): adata = sc.datasets.krumsiek11() @@ -81,6 +100,16 @@ def test_matrixplot(): sc.pl.matrixplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, show=False, swap_axes=True) save_and_compare_images('master_matrixplot_swap_axes', tolerance=15) + # test var/group standardization and layer + adata.layers['test'] = -1 * adata.X.copy() + sc.pl.matrixplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, + show=False, standard_scale='var', layer='test', cmap='Blues_r') + save_and_compare_images('master_matrixplot_std_scale_var', tolerance=15) + + sc.pl.matrixplot(adata, adata.var_names, 'cell_type', use_raw=False, dendrogram=True, + show=False, standard_scale='group', swap_axes=True) + save_and_compare_images('master_matrixplot_std_scale_group', tolerance=15) + # test matrixplot numeric column and alternative cmap adata.obs['Gata2'] = adata.X[:, 0] sc.pl.matrixplot(adata, adata.var_names, 'Gata2', use_raw=False, @@ -150,7 +179,7 @@ def test_rank_genes_groups(): # test ranked genes using heatmap (swap_axes=True show_gene_labels=False) sc.pl.rank_genes_groups_heatmap(pbmc, n_genes=20, swap_axes=True, use_raw=False, - show_gene_labels=False, show=False) + show_gene_labels=False, show=False, vmin=-3, vmax=3, cmap='bwr') save_and_compare_images('master_ranked_genes_heatmap_swap_axes', tolerance=tolerance) # test ranked genes using stacked violin plots