diff --git a/mplexporter/exporter.py b/mplexporter/exporter.py index ac0cc8f..f59237b 100644 --- a/mplexporter/exporter.py +++ b/mplexporter/exporter.py @@ -52,8 +52,8 @@ def run(self, fig): self.crawl_fig(fig) @staticmethod - def process_transform(transform, ax=None, data=None, return_trans=False, - force_trans=None): + def process_transform(transform, ax=None, fig=None, data=None, + return_trans=False, force_trans=None): """Process the transform and convert data to figure or data coordinates Parameters @@ -62,6 +62,8 @@ def process_transform(transform, ax=None, data=None, return_trans=False, The transform applied to the data ax : matplotlib Axes object (optional) The axes the data is associated with + fig : matplotlib Figure object (optional) + The figure the data is associated with data : ndarray (optional) The array of data to be transformed. return_trans : bool (optional) @@ -91,6 +93,7 @@ def process_transform(transform, ax=None, data=None, return_trans=False, transform = force_trans code = "display" + fig_ref = ax.figure if ax is not None else fig if ax is not None: for (c, trans) in [("data", ax.transData), ("axes", ax.transAxes), @@ -99,6 +102,12 @@ def process_transform(transform, ax=None, data=None, return_trans=False, if transform.contains_branch(trans): code, transform = (c, transform - trans) break + elif fig_ref is not None: + for (c, trans) in [("figure", fig_ref.transFigure), + ("display", transforms.IdentityTransform())]: + if transform.contains_branch(trans): + code, transform = (c, transform - trans) + break if data is not None: if return_trans: @@ -115,6 +124,12 @@ def crawl_fig(self, fig): """Crawl the figure and process all axes""" with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)): + if getattr(fig, "_suptitle", None) is not None: + self.draw_figure_text(fig, fig._suptitle, text_type="suptitle") + for text in fig.texts: + if text is not getattr(fig, "_suptitle", None): + self.draw_figure_text(fig, text) + for ax in fig.axes: self.crawl_ax(ax) @@ -149,6 +164,20 @@ def crawl_ax(self, ax): if props['visible']: self.crawl_legend(ax, legend) + def draw_figure_text(self, fig, text, text_type=None): + """Process a figure-level matplotlib text object""" + content = text.get_text() + if content: + transform = text.get_transform() + position = text.get_position() + coords, position = self.process_transform(transform, None, fig, + position) + style = utils.get_text_style(text) + self.renderer.draw_figure_text(text=content, position=position, + coordinates=coords, + text_type=text_type, + style=style, mplobj=text) + def crawl_legend(self, ax, legend): """ Recursively look through objects in legend children @@ -184,7 +213,8 @@ def crawl_legend(self, ax, legend): def draw_line(self, ax, line, force_trans=None): """Process a matplotlib line and call renderer.draw_line""" coordinates, data = self.process_transform(line.get_transform(), - ax, line.get_xydata(), + ax=ax, + data=line.get_xydata(), force_trans=force_trans) linestyle = utils.get_line_style(line) if (linestyle['dasharray'] is None @@ -208,8 +238,9 @@ def draw_text(self, ax, text, force_trans=None, text_type=None): if content: transform = text.get_transform() position = text.get_position() - coords, position = self.process_transform(transform, ax, - position, + coords, position = self.process_transform(transform, + ax=ax, + data=position, force_trans=force_trans) style = utils.get_text_style(text) self.renderer.draw_text(text=content, position=position, @@ -222,7 +253,8 @@ def draw_patch(self, ax, patch, force_trans=None): vertices, pathcodes = utils.SVG_path(patch.get_path()) transform = patch.get_transform() coordinates, vertices = self.process_transform(transform, - ax, vertices, + ax=ax, + data=vertices, force_trans=force_trans) linestyle = utils.get_path_style(patch, fill=patch.get_fill()) self.renderer.draw_path(data=vertices, @@ -239,13 +271,14 @@ def draw_collection(self, ax, collection, offsets, paths) = _collections_prepare_points(collection, ax) offset_coords, offsets = self.process_transform( - transOffset, ax, offsets, force_trans=force_offsettrans) + transOffset, ax=ax, data=offsets, force_trans=force_offsettrans) path_coords = self.process_transform( - transform, ax, force_trans=force_pathtrans) + transform, ax=ax, force_trans=force_pathtrans) processed_paths = [utils.SVG_path(path) for path in paths] processed_paths = [(self.process_transform( - transform, ax, path[0], force_trans=force_pathtrans)[1], path[1]) + transform, ax=ax, data=path[0], + force_trans=force_pathtrans)[1], path[1]) for path in processed_paths] path_transforms = collection.get_transforms() @@ -260,6 +293,7 @@ def draw_collection(self, ax, collection, styles = {'linewidth': collection.get_linewidths(), 'facecolor': collection.get_facecolors(), 'edgecolor': collection.get_edgecolors(), + 'dasharray': utils.get_dasharray_list(collection), 'alpha': collection._alpha, 'zorder': collection.get_zorder()} diff --git a/mplexporter/renderers/base.py b/mplexporter/renderers/base.py index 83543f5..86258bf 100644 --- a/mplexporter/renderers/base.py +++ b/mplexporter/renderers/base.py @@ -158,6 +158,11 @@ def draw_marked_line(self, data, coordinates, linestyle, markerstyle, if markerstyle is not None: self.draw_markers(data, coordinates, markerstyle, label, mplobj) + def draw_figure_text(self, text, position, coordinates, style, + text_type=None, mplobj=None): + """Figure-level text; renderers that care can override.""" + pass + def draw_line(self, data, coordinates, style, label, mplobj=None): """ Draw a line. By default, draw the line via the draw_path() command. @@ -204,8 +209,12 @@ def _iter_path_collection(paths, path_transforms, offsets, styles): if np.size(facecolor) == 0: facecolor = ['none'] + dasharray = styles.get('dasharray', None) + if dasharray is None or np.size(dasharray) == 0: + dasharray = ['none'] + elements = [paths, path_transforms, offsets, - edgecolor, styles['linewidth'], facecolor] + edgecolor, styles['linewidth'], facecolor, dasharray] it = itertools return it.islice(py3k.zip(*py3k.map(it.cycle, elements)), N) @@ -258,7 +267,7 @@ def draw_path_collection(self, paths, path_coordinates, path_transforms, for tup in self._iter_path_collection(paths, path_transforms, offsets, styles): - (path, path_transform, offset, ec, lw, fc) = tup + (path, path_transform, offset, ec, lw, fc, da) = tup vertices, pathcodes = path path_transform = transforms.Affine2D(path_transform) vertices = path_transform.transform(vertices) @@ -268,7 +277,7 @@ def draw_path_collection(self, paths, path_coordinates, path_transforms, style = {"edgecolor": utils.export_color(ec), "facecolor": utils.export_color(fc), "edgewidth": lw, - "dasharray": "10,0", + "dasharray": da, "alpha": styles['alpha'], "zorder": styles['zorder']} self.draw_path(data=vertices, coordinates=path_coordinates, diff --git a/mplexporter/renderers/fake_renderer.py b/mplexporter/renderers/fake_renderer.py index 2c4c708..6d4c3e9 100644 --- a/mplexporter/renderers/fake_renderer.py +++ b/mplexporter/renderers/fake_renderer.py @@ -35,6 +35,10 @@ def open_legend(self, legend, props): def close_legend(self, legend): self.output += " closing legend\n" + def draw_figure_text(self, text, position, coordinates, style, + text_type=None, mplobj=None): + self.output += " draw figure text '{0}' {1}\n".format(text, text_type) + def draw_text(self, text, position, coordinates, style, text_type=None, mplobj=None): self.output += " draw text '{0}' {1}\n".format(text, text_type) diff --git a/mplexporter/tests/__init__.py b/mplexporter/tests/__init__.py index 7aa97e7..11a2bf8 100644 --- a/mplexporter/tests/__init__.py +++ b/mplexporter/tests/__init__.py @@ -1,9 +1,6 @@ import os -MPLBE = os.environ.get('MPLBE', 'Agg') - -if MPLBE: - import matplotlib +import matplotlib +if MPLBE := os.environ.get('MPLBE', 'Agg'): matplotlib.use(MPLBE) - import matplotlib.pyplot as plt diff --git a/mplexporter/tests/test_utils.py b/mplexporter/tests/test_utils.py index 51dad80..057fcbf 100644 --- a/mplexporter/tests/test_utils.py +++ b/mplexporter/tests/test_utils.py @@ -1,5 +1,5 @@ -from numpy.testing import assert_allclose, assert_equal -from . import plt +from numpy.testing import assert_, assert_allclose, assert_equal +from . import plt, matplotlib from .. import utils @@ -33,3 +33,41 @@ def test_axis_w_fixed_formatter(): # NOTE: Issue #471 # assert_equal(props['tickformat'], labels) + +def test_funcformatter_exports_major_ticklabels(): + # Test both log and normal cases: + for scale, x, y in [ + ("linear", [0, 1], [0, 1]), + ("log", [1, 10, 100], [0, 1, 2]), + ]: + fig, ax = plt.subplots() + ax.set_xscale(scale) + ax.plot([1, 10, 100], [0, 1, 2]) + ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, pos: f"t{pos}")) + + props = utils.get_axis_properties(ax.xaxis) + plt.close(fig) + + assert_equal(props["scale"], scale) + assert_equal(props["tickformat_formatter"], "func") + assert_(props["tickvalues"] is not None) + assert_equal(len(props["tickvalues"]), len(props["tickformat"])) + assert_equal(props["tickformat"][:3], ["t0", "t1", "t2"]) + + +def test_custom_formatter_exports_major_ticklabels(): + class CustomFormatter(matplotlib.ticker.Formatter): + def __call__(self, x, pos=None): + return f"c{pos}" + + fig, ax = plt.subplots() + ax.plot([0, 1, 2], [0, 1, 2]) + ax.xaxis.set_major_formatter(CustomFormatter()) + + props = utils.get_axis_properties(ax.xaxis) + plt.close(fig) + + assert_equal(props["tickformat_formatter"], "func") + assert_(props["tickvalues"] is not None) + assert_equal(len(props["tickvalues"]), len(props["tickformat"])) + assert_equal(props["tickformat"][:3], ["c0", "c1", "c2"]) diff --git a/mplexporter/utils.py b/mplexporter/utils.py index f9467d7..3f27be5 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -45,6 +45,18 @@ def _many_to_one(input_dict): ('', ' ', 'None', 'none'): None}) +def _dasharray_from_linestyle(ls): + if ls is None: + return LINESTYLES['solid'] + if isinstance(ls, tuple) and len(ls) == 2: # NOTE: No support for offset yet. + return ','.join(str(val) for val in ls[1]) if ls[1] else LINESTYLES['solid'] + dasharray = LINESTYLES.get(ls, 'not found') + if dasharray == 'not found': + warnings.warn(f"line style '{ls}' not understood: defaulting to solid") + dasharray = LINESTYLES['solid'] + return dasharray + + def get_dasharray(obj): """Get an SVG dash array for the given matplotlib linestyle @@ -59,16 +71,27 @@ def get_dasharray(obj): dasharray : string The HTML/SVG dasharray code associated with the object. """ - if obj.__dict__.get('_dashSeq', None) is not None: - return ','.join(map(str, obj._dashSeq)) + if dashseq := getattr(obj, '_dashSeq', None): + return _dasharray_from_linestyle(dashseq) + + ls = obj.get_linestyle() + if isinstance(ls, (list, tuple)) and not isinstance(ls, str): + ls = ls[0] if len(ls) else None + return _dasharray_from_linestyle(ls) + + +def get_dasharray_list(collection): + """Return a list of SVG dash arrays for a matplotlib Collection""" + linestyles = None + if hasattr(collection, "get_dashes"): + linestyles = collection.get_dashes() + elif hasattr(collection, "get_linestyle"): + linestyles = collection.get_linestyle() else: - ls = obj.get_linestyle() - dasharray = LINESTYLES.get(ls, 'not found') - if dasharray == 'not found': - warnings.warn("line style '{0}' not understood: " - "defaulting to solid line.".format(ls)) - dasharray = LINESTYLES['solid'] - return dasharray + return None + if not isinstance(linestyles, (list, tuple)): + linestyles = [linestyles] + return [_dasharray_from_linestyle(ls) for ls in linestyles] PATH_DICT = {Path.LINETO: 'L', @@ -186,6 +209,41 @@ def get_text_style(text): return style +def is_py_only_formatter(formatter): + return ( + isinstance(formatter, ticker.FuncFormatter) or + (isinstance(formatter, ticker.Formatter) and + not formatter.__class__.__module__.startswith("matplotlib.")) + ) + + +def _tick_format_props(formatter, tickvalues, labels): + if isinstance(formatter, ticker.NullFormatter): + return "", "" + if isinstance(formatter, ticker.StrMethodFormatter): + convertor = StrMethodTickFormatterConvertor(formatter) + return convertor.output, "str_method" + if isinstance(formatter, ticker.PercentFormatter): + return { + "xmax": formatter.xmax, + "decimals": formatter.decimals, + "symbol": formatter.symbol, + }, "percent" + if hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): + return [text.get_text() for text in labels], "index" + if isinstance(formatter, ticker.FixedFormatter): + return list(formatter.seq), "fixed" + if is_py_only_formatter(formatter) and tickvalues: + if hasattr(formatter, "set_locs"): + formatter.set_locs(tickvalues) + if hasattr(formatter, "format_ticks"): + return list(formatter.format_ticks(tickvalues)), "func" + return [formatter(value, i) for i, value in enumerate(tickvalues)], "func" + if not any(label.get_visible() for label in labels): + return "", "" + return None, "" + + def get_axis_properties(axis): """Return the property dictionary for a matplotlib.Axis instance""" props = {} @@ -206,42 +264,24 @@ def get_axis_properties(axis): # Use tick values if appropriate locator = axis.get_major_locator() + formatter = axis.get_major_formatter() props['nticks'] = len(locator()) - if isinstance(locator, ticker.FixedLocator): + if (isinstance(locator, ticker.FixedLocator) or + is_py_only_formatter(formatter)): props['tickvalues'] = list(locator()) else: props['tickvalues'] = None + minor_locator = axis.get_minor_locator() + props['minor_tickvalues'] = list(axis.get_minorticklocs()) if minor_locator else None + props['minorticklength'] = axis._minor_tick_kw.get('size', None) + props['majorticklength'] = axis._major_tick_kw.get('size', None) + # Find tick formats - props['tickformat_formatter'] = "" - formatter = axis.get_major_formatter() - if isinstance(formatter, ticker.NullFormatter): - props['tickformat'] = "" - elif isinstance(formatter, ticker.StrMethodFormatter): - convertor = StrMethodTickFormatterConvertor(formatter) - props['tickformat'] = convertor.output - props['tickformat_formatter'] = "str_method" - elif isinstance(formatter, ticker.PercentFormatter): - props['tickformat'] = { - "xmax": formatter.xmax, - "decimals": formatter.decimals, - "symbol": formatter.symbol, - } - props['tickformat_formatter'] = "percent" - elif hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): - # IndexFormatter was dropped in matplotlib 3.5 - props['tickformat'] = [text.get_text() for text in axis.get_ticklabels()] - props['tickformat_formatter'] = "index" - elif isinstance(formatter, ticker.FixedFormatter): - props['tickformat'] = list(formatter.seq) - props['tickformat_formatter'] = "fixed" - elif isinstance(formatter, ticker.FuncFormatter) and props['tickvalues']: - props['tickformat'] = [formatter(value) for value in props['tickvalues']] - props['tickformat_formatter'] = "func" - elif not any(label.get_visible() for label in axis.get_ticklabels()): - props['tickformat'] = "" - else: - props['tickformat'] = None + props['minor_tickformat'], props['minor_tickformat_formatter'] = _tick_format_props( + axis.get_minor_formatter(), props['minor_tickvalues'], axis.get_minorticklabels()) + props['tickformat'], props['tickformat_formatter'] = _tick_format_props( + formatter, props['tickvalues'], axis.get_ticklabels()) # Get axis scale props['scale'] = axis.get_scale() @@ -255,6 +295,7 @@ def get_axis_properties(axis): # Get associated grid props['grid'] = get_grid_style(axis) + props['minor_grid'] = get_grid_style(axis, which='minor') # get axis visibility props['visible'] = axis.get_visible() @@ -262,19 +303,24 @@ def get_axis_properties(axis): return props -def get_grid_style(axis): - gridlines = axis.get_gridlines() - if axis._major_tick_kw['gridOn'] and len(gridlines) > 0: - color = export_color(gridlines[0].get_color()) - alpha = gridlines[0].get_alpha() - dasharray = get_dasharray(gridlines[0]) - return dict(gridOn=True, - color=color, - dasharray=dasharray, - alpha=alpha) - else: +def get_grid_style(axis, which='major'): + tick_kw = axis._minor_tick_kw if which == 'minor' else axis._major_tick_kw + + if not tick_kw.get('gridOn'): return {"gridOn": False} + rc = matplotlib.rcParams + color = export_color(tick_kw.get('grid_color', tick_kw.get('grid_c', rc['grid.color']))) + alpha = tick_kw.get('grid_alpha', rc['grid.alpha']) + dasharray = _dasharray_from_linestyle(tick_kw.get('grid_linestyle', tick_kw.get('grid_ls', rc['grid.linestyle']))) + linewidth = tick_kw.get('grid_linewidth', tick_kw.get('grid_lw', rc['grid.linewidth'])) + + return dict(gridOn=True, + color=color, + dasharray=dasharray, + linewidth=linewidth, + alpha=alpha) + def get_figure_properties(fig): return {'figwidth': fig.get_figwidth(),