From 29e1125fa2b780f535802fbd5ca67ecbb5c6fdf4 Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Fri, 28 Nov 2025 17:53:51 +0100 Subject: [PATCH 1/7] Support text elements in figure container. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is mainly for mpld3, exporting figure-level text objects (such as fig.suptitle) would never get exported. It looks like figure-level text objects (and objects in general) was completely missing here. This adds text objects, I might add more in the future. The alternative was to put figure text objects into the first axis object, but that's extremely hacky, so I went for the larger but proper fix instead. I did this together with gpt-5.1-codex, not alone. Here is what it has to say: - Exporter now emits figure-level text (suptitle + fig.text) via a dedicated draw_figure_text call before crawling axes; figure transforms passed directly to process_transform instead of shoving text into the first axes. - Renderer API gains a draw_figure_text hook (no-op default in base) so non-mpld3 renderers don’t break; FakeRenderer already implements it. - Figure JSON now carries a texts array and MPLD3Renderer serializes figure-level text entries with proper coordinates/attrs; tests cover presence/positions of exported figure texts. --- mplexporter/exporter.py | 51 +++++++++++++++++++++----- mplexporter/renderers/base.py | 5 +++ mplexporter/renderers/fake_renderer.py | 4 ++ 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/mplexporter/exporter.py b/mplexporter/exporter.py index ac0cc8f..9992690 100644 --- a/mplexporter/exporter.py +++ b/mplexporter/exporter.py @@ -52,8 +52,8 @@ def run(self, fig): self.crawl_fig(fig) @staticmethod - def process_transform(transform, ax=None, data=None, return_trans=False, - force_trans=None): + def process_transform(transform, ax=None, fig=None, data=None, + return_trans=False, force_trans=None): """Process the transform and convert data to figure or data coordinates Parameters @@ -62,6 +62,8 @@ def process_transform(transform, ax=None, data=None, return_trans=False, The transform applied to the data ax : matplotlib Axes object (optional) The axes the data is associated with + fig : matplotlib Figure object (optional) + The figure the data is associated with data : ndarray (optional) The array of data to be transformed. return_trans : bool (optional) @@ -91,6 +93,7 @@ def process_transform(transform, ax=None, data=None, return_trans=False, transform = force_trans code = "display" + fig_ref = ax.figure if ax is not None else fig if ax is not None: for (c, trans) in [("data", ax.transData), ("axes", ax.transAxes), @@ -99,6 +102,12 @@ def process_transform(transform, ax=None, data=None, return_trans=False, if transform.contains_branch(trans): code, transform = (c, transform - trans) break + elif fig_ref is not None: + for (c, trans) in [("figure", fig_ref.transFigure), + ("display", transforms.IdentityTransform())]: + if transform.contains_branch(trans): + code, transform = (c, transform - trans) + break if data is not None: if return_trans: @@ -115,6 +124,12 @@ def crawl_fig(self, fig): """Crawl the figure and process all axes""" with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)): + if getattr(fig, "_suptitle", None) is not None: + self.draw_figure_text(fig, fig._suptitle, text_type="suptitle") + for text in fig.texts: + if text is not getattr(fig, "_suptitle", None): + self.draw_figure_text(fig, text) + for ax in fig.axes: self.crawl_ax(ax) @@ -149,6 +164,20 @@ def crawl_ax(self, ax): if props['visible']: self.crawl_legend(ax, legend) + def draw_figure_text(self, fig, text, text_type=None): + """Process a figure-level matplotlib text object""" + content = text.get_text() + if content: + transform = text.get_transform() + position = text.get_position() + coords, position = self.process_transform(transform, None, fig, + position) + style = utils.get_text_style(text) + self.renderer.draw_figure_text(text=content, position=position, + coordinates=coords, + text_type=text_type, + style=style, mplobj=text) + def crawl_legend(self, ax, legend): """ Recursively look through objects in legend children @@ -184,7 +213,8 @@ def crawl_legend(self, ax, legend): def draw_line(self, ax, line, force_trans=None): """Process a matplotlib line and call renderer.draw_line""" coordinates, data = self.process_transform(line.get_transform(), - ax, line.get_xydata(), + ax=ax, + data=line.get_xydata(), force_trans=force_trans) linestyle = utils.get_line_style(line) if (linestyle['dasharray'] is None @@ -208,8 +238,9 @@ def draw_text(self, ax, text, force_trans=None, text_type=None): if content: transform = text.get_transform() position = text.get_position() - coords, position = self.process_transform(transform, ax, - position, + coords, position = self.process_transform(transform, + ax=ax, + data=position, force_trans=force_trans) style = utils.get_text_style(text) self.renderer.draw_text(text=content, position=position, @@ -222,7 +253,8 @@ def draw_patch(self, ax, patch, force_trans=None): vertices, pathcodes = utils.SVG_path(patch.get_path()) transform = patch.get_transform() coordinates, vertices = self.process_transform(transform, - ax, vertices, + ax=ax, + data=vertices, force_trans=force_trans) linestyle = utils.get_path_style(patch, fill=patch.get_fill()) self.renderer.draw_path(data=vertices, @@ -239,13 +271,14 @@ def draw_collection(self, ax, collection, offsets, paths) = _collections_prepare_points(collection, ax) offset_coords, offsets = self.process_transform( - transOffset, ax, offsets, force_trans=force_offsettrans) + transOffset, ax=ax, data=offsets, force_trans=force_offsettrans) path_coords = self.process_transform( - transform, ax, force_trans=force_pathtrans) + transform, ax=ax, force_trans=force_pathtrans) processed_paths = [utils.SVG_path(path) for path in paths] processed_paths = [(self.process_transform( - transform, ax, path[0], force_trans=force_pathtrans)[1], path[1]) + transform, ax=ax, data=path[0], + force_trans=force_pathtrans)[1], path[1]) for path in processed_paths] path_transforms = collection.get_transforms() diff --git a/mplexporter/renderers/base.py b/mplexporter/renderers/base.py index 83543f5..fd8337d 100644 --- a/mplexporter/renderers/base.py +++ b/mplexporter/renderers/base.py @@ -158,6 +158,11 @@ def draw_marked_line(self, data, coordinates, linestyle, markerstyle, if markerstyle is not None: self.draw_markers(data, coordinates, markerstyle, label, mplobj) + def draw_figure_text(self, text, position, coordinates, style, + text_type=None, mplobj=None): + """Figure-level text; renderers that care can override.""" + pass + def draw_line(self, data, coordinates, style, label, mplobj=None): """ Draw a line. By default, draw the line via the draw_path() command. diff --git a/mplexporter/renderers/fake_renderer.py b/mplexporter/renderers/fake_renderer.py index 2c4c708..6d4c3e9 100644 --- a/mplexporter/renderers/fake_renderer.py +++ b/mplexporter/renderers/fake_renderer.py @@ -35,6 +35,10 @@ def open_legend(self, legend, props): def close_legend(self, legend): self.output += " closing legend\n" + def draw_figure_text(self, text, position, coordinates, style, + text_type=None, mplobj=None): + self.output += " draw figure text '{0}' {1}\n".format(text, text_type) + def draw_text(self, text, position, coordinates, style, text_type=None, mplobj=None): self.output += " draw text '{0}' {1}\n".format(text, text_type) From 80fb62cb2b2fee3e200214fbd470b49656cc6065 Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Mon, 1 Dec 2025 10:20:11 +0100 Subject: [PATCH 2/7] dashing line collections (hlines, vlines, ...) line-style was previously ignored for line collections, this fixes it. Note that I did not add support for offset in line-style, which has not existed in the first place, and I have no need for it. There's a corresponding commit/PR in mpld3 coming. This was done together with gpt-5.1-codex (but I reviewed and edited a lot), here is what it has to say: Handle collection dasharrays in exporter - derive dash arrays from collection linestyles/dashes - carry dasharrays through export and render_path_collection so hlines/vlines/grid keep their patterns --- mplexporter/exporter.py | 1 + mplexporter/renderers/base.py | 10 ++++++--- mplexporter/utils.py | 41 +++++++++++++++++++++++++++-------- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/mplexporter/exporter.py b/mplexporter/exporter.py index 9992690..f59237b 100644 --- a/mplexporter/exporter.py +++ b/mplexporter/exporter.py @@ -293,6 +293,7 @@ def draw_collection(self, ax, collection, styles = {'linewidth': collection.get_linewidths(), 'facecolor': collection.get_facecolors(), 'edgecolor': collection.get_edgecolors(), + 'dasharray': utils.get_dasharray_list(collection), 'alpha': collection._alpha, 'zorder': collection.get_zorder()} diff --git a/mplexporter/renderers/base.py b/mplexporter/renderers/base.py index fd8337d..86258bf 100644 --- a/mplexporter/renderers/base.py +++ b/mplexporter/renderers/base.py @@ -209,8 +209,12 @@ def _iter_path_collection(paths, path_transforms, offsets, styles): if np.size(facecolor) == 0: facecolor = ['none'] + dasharray = styles.get('dasharray', None) + if dasharray is None or np.size(dasharray) == 0: + dasharray = ['none'] + elements = [paths, path_transforms, offsets, - edgecolor, styles['linewidth'], facecolor] + edgecolor, styles['linewidth'], facecolor, dasharray] it = itertools return it.islice(py3k.zip(*py3k.map(it.cycle, elements)), N) @@ -263,7 +267,7 @@ def draw_path_collection(self, paths, path_coordinates, path_transforms, for tup in self._iter_path_collection(paths, path_transforms, offsets, styles): - (path, path_transform, offset, ec, lw, fc) = tup + (path, path_transform, offset, ec, lw, fc, da) = tup vertices, pathcodes = path path_transform = transforms.Affine2D(path_transform) vertices = path_transform.transform(vertices) @@ -273,7 +277,7 @@ def draw_path_collection(self, paths, path_coordinates, path_transforms, style = {"edgecolor": utils.export_color(ec), "facecolor": utils.export_color(fc), "edgewidth": lw, - "dasharray": "10,0", + "dasharray": da, "alpha": styles['alpha'], "zorder": styles['zorder']} self.draw_path(data=vertices, coordinates=path_coordinates, diff --git a/mplexporter/utils.py b/mplexporter/utils.py index f9467d7..a8c6828 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -45,6 +45,18 @@ def _many_to_one(input_dict): ('', ' ', 'None', 'none'): None}) +def _dasharray_from_linestyle(ls): + if ls is None: + return LINESTYLES['solid'] + if isinstance(ls, tuple) and len(ls) == 2: # NOTE: No support for offset yet. + return ','.join(str(val) for val in ls[1]) if ls[1] else LINESTYLES['solid'] + dasharray = LINESTYLES.get(ls, 'not found') + if dasharray == 'not found': + warnings.warn(f"line style '{ls}' not understood: defaulting to solid") + dasharray = LINESTYLES['solid'] + return dasharray + + def get_dasharray(obj): """Get an SVG dash array for the given matplotlib linestyle @@ -59,16 +71,27 @@ def get_dasharray(obj): dasharray : string The HTML/SVG dasharray code associated with the object. """ - if obj.__dict__.get('_dashSeq', None) is not None: - return ','.join(map(str, obj._dashSeq)) + if dashseq := getattr(obj, '_dashSeq', None): + return _dasharray_from_linestyle(dashseq) + + ls = obj.get_linestyle() + if isinstance(ls, (list, tuple)) and not isinstance(ls, str): + ls = ls[0] if len(ls) else None + return _dasharray_from_linestyle(ls) + + +def get_dasharray_list(collection): + """Return a list of SVG dash arrays for a matplotlib Collection""" + linestyles = None + if hasattr(collection, "get_dashes"): + linestyles = collection.get_dashes() + elif hasattr(collection, "get_linestyle"): + linestyles = collection.get_linestyle() else: - ls = obj.get_linestyle() - dasharray = LINESTYLES.get(ls, 'not found') - if dasharray == 'not found': - warnings.warn("line style '{0}' not understood: " - "defaulting to solid line.".format(ls)) - dasharray = LINESTYLES['solid'] - return dasharray + return None + if not isinstance(linestyles, (list, tuple)): + linestyles = [linestyles] + return [_dasharray_from_linestyle(ls) for ls in linestyles] PATH_DICT = {Path.LINETO: 'L', From 65369f1972ff888f95fb71e682ba25b539a61d47 Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Mon, 1 Dec 2025 11:18:28 +0100 Subject: [PATCH 3/7] line-width was missing from gridlines. --- mplexporter/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mplexporter/utils.py b/mplexporter/utils.py index a8c6828..10bff48 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -291,9 +291,11 @@ def get_grid_style(axis): color = export_color(gridlines[0].get_color()) alpha = gridlines[0].get_alpha() dasharray = get_dasharray(gridlines[0]) + linewidth = gridlines[0].get_linewidth() return dict(gridOn=True, color=color, dasharray=dasharray, + linewidth=linewidth, alpha=alpha) else: return {"gridOn": False} From cfdb693f596bd1d3a94f94a9740e36b71dcf2482 Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Mon, 1 Dec 2025 23:21:49 +0100 Subject: [PATCH 4/7] Also export minor grid. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This will allow to render both major and minor gridlines in mpld3, which fixes the following issue: https://github.com/mpld3/mpld3/issues/527 As per usual, disclaimer that I co-developed this with gpt-5.1-codex, having it figure out the issues and give implementation recommendations, with me testing, verifying, and tidying up the code. Here's what it has to say, especially wrt the change in API call: - include minor tick values/length and minor grid style in axis props so minor ticks/grids render in mpld3 - read grid color/linewidth/linestyle from tick kwargs (and rcParams fallback) instead of inspecting gridlines[0], avoiding the get_gridlines(which=…) API that isn’t available on matplotlib 3.10” Rationale for the kw/rc approach: `Axis.get_gridlines()` doesn’t accept `which` on matplotlib 3.10, so probing `gridlines[0]` for minor/major fails. Pulling style from the tick keyword dict (which matplotlib populates with `grid_*` fields when you call `ax.grid(...)`) plus `rcParams` defaults gives the same style without needing `get_gridlines(which=…)`, keeping compatibility and matching user-set grid styles. (I verified, indeed get_gridlines does not allow specifying which ones - seems like an omission in matplotlib API to me) --- mplexporter/utils.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/mplexporter/utils.py b/mplexporter/utils.py index 10bff48..5683c64 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -235,6 +235,10 @@ def get_axis_properties(axis): else: props['tickvalues'] = None + minor_locator = axis.get_minor_locator() + props['minor_tickvalues'] = list(axis.get_minorticklocs()) if minor_locator else None + props['minorticklength'] = axis._minor_tick_kw.get('size', None) + # Find tick formats props['tickformat_formatter'] = "" formatter = axis.get_major_formatter() @@ -278,6 +282,7 @@ def get_axis_properties(axis): # Get associated grid props['grid'] = get_grid_style(axis) + props['minor_grid'] = get_grid_style(axis, which='minor') # get axis visibility props['visible'] = axis.get_visible() @@ -285,21 +290,24 @@ def get_axis_properties(axis): return props -def get_grid_style(axis): - gridlines = axis.get_gridlines() - if axis._major_tick_kw['gridOn'] and len(gridlines) > 0: - color = export_color(gridlines[0].get_color()) - alpha = gridlines[0].get_alpha() - dasharray = get_dasharray(gridlines[0]) - linewidth = gridlines[0].get_linewidth() - return dict(gridOn=True, - color=color, - dasharray=dasharray, - linewidth=linewidth, - alpha=alpha) - else: +def get_grid_style(axis, which='major'): + tick_kw = axis._minor_tick_kw if which == 'minor' else axis._major_tick_kw + + if not tick_kw.get('gridOn'): return {"gridOn": False} + rc = matplotlib.rcParams + color = export_color(tick_kw.get('grid_color', tick_kw.get('grid_c', rc['grid.color']))) + alpha = tick_kw.get('grid_alpha', rc['grid.alpha']) + dasharray = _dasharray_from_linestyle(tick_kw.get('grid_linestyle', tick_kw.get('grid_ls', rc['grid.linestyle']))) + linewidth = tick_kw.get('grid_linewidth', tick_kw.get('grid_lw', rc['grid.linewidth'])) + + return dict(gridOn=True, + color=color, + dasharray=dasharray, + linewidth=linewidth, + alpha=alpha) + def get_figure_properties(fig): return {'figwidth': fig.get_figwidth(), From 0e95a2700f829eeeba30ff99f38886177011e5bd Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Tue, 2 Dec 2025 19:13:34 +0100 Subject: [PATCH 5/7] Bring minor and major ticks/tickabels to parity Minor ticklabels were missing altogether, --- mplexporter/utils.py | 57 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/mplexporter/utils.py b/mplexporter/utils.py index 5683c64..b3bd69c 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -209,6 +209,29 @@ def get_text_style(text): return style +def _tick_format_props(formatter, tickvalues, labels): + if isinstance(formatter, ticker.NullFormatter): + return "", "" + if isinstance(formatter, ticker.StrMethodFormatter): + convertor = StrMethodTickFormatterConvertor(formatter) + return convertor.output, "str_method" + if isinstance(formatter, ticker.PercentFormatter): + return { + "xmax": formatter.xmax, + "decimals": formatter.decimals, + "symbol": formatter.symbol, + }, "percent" + if hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): + return [text.get_text() for text in labels], "index" + if isinstance(formatter, ticker.FixedFormatter): + return list(formatter.seq), "fixed" + if isinstance(formatter, ticker.FuncFormatter) and tickvalues: + return [formatter(value) for value in tickvalues], "func" + if not any(label.get_visible() for label in labels): + return "", "" + return None, "" + + def get_axis_properties(axis): """Return the property dictionary for a matplotlib.Axis instance""" props = {} @@ -238,37 +261,13 @@ def get_axis_properties(axis): minor_locator = axis.get_minor_locator() props['minor_tickvalues'] = list(axis.get_minorticklocs()) if minor_locator else None props['minorticklength'] = axis._minor_tick_kw.get('size', None) + props['majorticklength'] = axis._major_tick_kw.get('size', None) # Find tick formats - props['tickformat_formatter'] = "" - formatter = axis.get_major_formatter() - if isinstance(formatter, ticker.NullFormatter): - props['tickformat'] = "" - elif isinstance(formatter, ticker.StrMethodFormatter): - convertor = StrMethodTickFormatterConvertor(formatter) - props['tickformat'] = convertor.output - props['tickformat_formatter'] = "str_method" - elif isinstance(formatter, ticker.PercentFormatter): - props['tickformat'] = { - "xmax": formatter.xmax, - "decimals": formatter.decimals, - "symbol": formatter.symbol, - } - props['tickformat_formatter'] = "percent" - elif hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): - # IndexFormatter was dropped in matplotlib 3.5 - props['tickformat'] = [text.get_text() for text in axis.get_ticklabels()] - props['tickformat_formatter'] = "index" - elif isinstance(formatter, ticker.FixedFormatter): - props['tickformat'] = list(formatter.seq) - props['tickformat_formatter'] = "fixed" - elif isinstance(formatter, ticker.FuncFormatter) and props['tickvalues']: - props['tickformat'] = [formatter(value) for value in props['tickvalues']] - props['tickformat_formatter'] = "func" - elif not any(label.get_visible() for label in axis.get_ticklabels()): - props['tickformat'] = "" - else: - props['tickformat'] = None + props['minor_tickformat'], props['minor_tickformat_formatter'] = _tick_format_props( + axis.get_minor_formatter(), props['minor_tickvalues'], axis.get_minorticklabels()) + props['tickformat'], props['tickformat_formatter'] = _tick_format_props( + axis.get_major_formatter(), props['tickvalues'], axis.get_ticklabels()) # Get axis scale props['scale'] = axis.get_scale() From 1d200a9050f824a4d3d21e2eca885e729c9ddb1f Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Tue, 16 Dec 2025 09:28:51 +0100 Subject: [PATCH 6/7] Make FuncFormatter work with non-fixed locator The previous fix to make FuncFormatter work, in https://github.com/mpld3/mplexporter/pull/67/files and https://github.com/mpld3/mpld3/commit/e7fa282e72328e25ee2b644c4123606a6e86b9a6 had two issues still: 1) the mpl FuncFormatter API also takes the index as second argument, which was missing. 2) it exported `tickvalues` only in the FixedLocator case, so these were missing for non-fixed FuncFormatter and hence the FuncFormatter codepath, which also requires them to be present, was never hit with non-fixed locators. Also add a test, and let tests import `matplotlib` via the backend-setting mechanism, not only `plt`. This fixes both. Again, this was figured out and helped with gpt-codex and my careful review+cleanup. Codex even identified the two original commits I'm linking above :) --- mplexporter/tests/__init__.py | 7 ++----- mplexporter/tests/test_utils.py | 24 ++++++++++++++++++++++-- mplexporter/utils.py | 4 ++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/mplexporter/tests/__init__.py b/mplexporter/tests/__init__.py index 7aa97e7..11a2bf8 100644 --- a/mplexporter/tests/__init__.py +++ b/mplexporter/tests/__init__.py @@ -1,9 +1,6 @@ import os -MPLBE = os.environ.get('MPLBE', 'Agg') - -if MPLBE: - import matplotlib +import matplotlib +if MPLBE := os.environ.get('MPLBE', 'Agg'): matplotlib.use(MPLBE) - import matplotlib.pyplot as plt diff --git a/mplexporter/tests/test_utils.py b/mplexporter/tests/test_utils.py index 51dad80..05da716 100644 --- a/mplexporter/tests/test_utils.py +++ b/mplexporter/tests/test_utils.py @@ -1,5 +1,5 @@ -from numpy.testing import assert_allclose, assert_equal -from . import plt +from numpy.testing import assert_, assert_allclose, assert_equal +from . import plt, matplotlib from .. import utils @@ -33,3 +33,23 @@ def test_axis_w_fixed_formatter(): # NOTE: Issue #471 # assert_equal(props['tickformat'], labels) + +def test_funcformatter_exports_major_ticklabels(): + # Test both log and normal cases: + for scale, x, y in [ + ("linear", [0, 1], [0, 1]), + ("log", [1, 10, 100], [0, 1, 2]), + ]: + fig, ax = plt.subplots() + ax.set_xscale(scale) + ax.plot([1, 10, 100], [0, 1, 2]) + ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, pos: f"t{pos}")) + + props = utils.get_axis_properties(ax.xaxis) + plt.close(fig) + + assert_equal(props["scale"], scale) + assert_equal(props["tickformat_formatter"], "func") + assert_(props["tickvalues"] is not None) + assert_equal(len(props["tickvalues"]), len(props["tickformat"])) + assert_equal(props["tickformat"][:3], ["t0", "t1", "t2"]) diff --git a/mplexporter/utils.py b/mplexporter/utils.py index b3bd69c..5b33c88 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -226,7 +226,7 @@ def _tick_format_props(formatter, tickvalues, labels): if isinstance(formatter, ticker.FixedFormatter): return list(formatter.seq), "fixed" if isinstance(formatter, ticker.FuncFormatter) and tickvalues: - return [formatter(value) for value in tickvalues], "func" + return [formatter(value, i) for i, value in enumerate(tickvalues)], "func" if not any(label.get_visible() for label in labels): return "", "" return None, "" @@ -253,7 +253,7 @@ def get_axis_properties(axis): # Use tick values if appropriate locator = axis.get_major_locator() props['nticks'] = len(locator()) - if isinstance(locator, ticker.FixedLocator): + if isinstance(locator, ticker.FixedLocator) or isinstance(axis.get_major_formatter(), ticker.FuncFormatter): props['tickvalues'] = list(locator()) else: props['tickvalues'] = None From 267610315f11fdbfd2e44cd46a2e841b86d23767 Mon Sep 17 00:00:00 2001 From: lucasb-eyer Date: Wed, 7 Jan 2026 09:37:15 +0100 Subject: [PATCH 7/7] Freeze labels for Python-only formatters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Treat Python-only formatters (FuncFormatter, and formatter subclasses) as export-time only by precomputing tick labels and exporting tickvalues, so mpld3 doesn’t drop custom formatter subclasses. This unifies FuncFormatter and non-matplotlib Formatter subclasses under a single check, preserving existing d3 behavior for built-in formatters like Date/Log/Scalar. Adds tests for custom formatters in both exporter copies. As a disclaimer, as usual, I noticed an issue in my use, and fixed it together with gpt-5.2-codex, manually reviewing and cleaning up the changes. --- mplexporter/tests/test_utils.py | 18 ++++++++++++++++++ mplexporter/utils.py | 20 +++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/mplexporter/tests/test_utils.py b/mplexporter/tests/test_utils.py index 05da716..057fcbf 100644 --- a/mplexporter/tests/test_utils.py +++ b/mplexporter/tests/test_utils.py @@ -53,3 +53,21 @@ def test_funcformatter_exports_major_ticklabels(): assert_(props["tickvalues"] is not None) assert_equal(len(props["tickvalues"]), len(props["tickformat"])) assert_equal(props["tickformat"][:3], ["t0", "t1", "t2"]) + + +def test_custom_formatter_exports_major_ticklabels(): + class CustomFormatter(matplotlib.ticker.Formatter): + def __call__(self, x, pos=None): + return f"c{pos}" + + fig, ax = plt.subplots() + ax.plot([0, 1, 2], [0, 1, 2]) + ax.xaxis.set_major_formatter(CustomFormatter()) + + props = utils.get_axis_properties(ax.xaxis) + plt.close(fig) + + assert_equal(props["tickformat_formatter"], "func") + assert_(props["tickvalues"] is not None) + assert_equal(len(props["tickvalues"]), len(props["tickformat"])) + assert_equal(props["tickformat"][:3], ["c0", "c1", "c2"]) diff --git a/mplexporter/utils.py b/mplexporter/utils.py index 5b33c88..3f27be5 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -209,6 +209,14 @@ def get_text_style(text): return style +def is_py_only_formatter(formatter): + return ( + isinstance(formatter, ticker.FuncFormatter) or + (isinstance(formatter, ticker.Formatter) and + not formatter.__class__.__module__.startswith("matplotlib.")) + ) + + def _tick_format_props(formatter, tickvalues, labels): if isinstance(formatter, ticker.NullFormatter): return "", "" @@ -225,7 +233,11 @@ def _tick_format_props(formatter, tickvalues, labels): return [text.get_text() for text in labels], "index" if isinstance(formatter, ticker.FixedFormatter): return list(formatter.seq), "fixed" - if isinstance(formatter, ticker.FuncFormatter) and tickvalues: + if is_py_only_formatter(formatter) and tickvalues: + if hasattr(formatter, "set_locs"): + formatter.set_locs(tickvalues) + if hasattr(formatter, "format_ticks"): + return list(formatter.format_ticks(tickvalues)), "func" return [formatter(value, i) for i, value in enumerate(tickvalues)], "func" if not any(label.get_visible() for label in labels): return "", "" @@ -252,8 +264,10 @@ def get_axis_properties(axis): # Use tick values if appropriate locator = axis.get_major_locator() + formatter = axis.get_major_formatter() props['nticks'] = len(locator()) - if isinstance(locator, ticker.FixedLocator) or isinstance(axis.get_major_formatter(), ticker.FuncFormatter): + if (isinstance(locator, ticker.FixedLocator) or + is_py_only_formatter(formatter)): props['tickvalues'] = list(locator()) else: props['tickvalues'] = None @@ -267,7 +281,7 @@ def get_axis_properties(axis): props['minor_tickformat'], props['minor_tickformat_formatter'] = _tick_format_props( axis.get_minor_formatter(), props['minor_tickvalues'], axis.get_minorticklabels()) props['tickformat'], props['tickformat_formatter'] = _tick_format_props( - axis.get_major_formatter(), props['tickvalues'], axis.get_ticklabels()) + formatter, props['tickvalues'], axis.get_ticklabels()) # Get axis scale props['scale'] = axis.get_scale()