diff --git a/mplexporter/tests/__init__.py b/mplexporter/tests/__init__.py index 7aa97e7..11a2bf8 100644 --- a/mplexporter/tests/__init__.py +++ b/mplexporter/tests/__init__.py @@ -1,9 +1,6 @@ import os -MPLBE = os.environ.get('MPLBE', 'Agg') - -if MPLBE: - import matplotlib +import matplotlib +if MPLBE := os.environ.get('MPLBE', 'Agg'): matplotlib.use(MPLBE) - import matplotlib.pyplot as plt diff --git a/mplexporter/tests/test_utils.py b/mplexporter/tests/test_utils.py index 51dad80..05da716 100644 --- a/mplexporter/tests/test_utils.py +++ b/mplexporter/tests/test_utils.py @@ -1,5 +1,5 @@ -from numpy.testing import assert_allclose, assert_equal -from . import plt +from numpy.testing import assert_, assert_allclose, assert_equal +from . import plt, matplotlib from .. import utils @@ -33,3 +33,23 @@ def test_axis_w_fixed_formatter(): # NOTE: Issue #471 # assert_equal(props['tickformat'], labels) + +def test_funcformatter_exports_major_ticklabels(): + # Test both log and normal cases: + for scale, x, y in [ + ("linear", [0, 1], [0, 1]), + ("log", [1, 10, 100], [0, 1, 2]), + ]: + fig, ax = plt.subplots() + ax.set_xscale(scale) + ax.plot([1, 10, 100], [0, 1, 2]) + ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, pos: f"t{pos}")) + + props = utils.get_axis_properties(ax.xaxis) + plt.close(fig) + + assert_equal(props["scale"], scale) + assert_equal(props["tickformat_formatter"], "func") + assert_(props["tickvalues"] is not None) + assert_equal(len(props["tickvalues"]), len(props["tickformat"])) + assert_equal(props["tickformat"][:3], ["t0", "t1", "t2"]) diff --git a/mplexporter/utils.py b/mplexporter/utils.py index f9467d7..d02b4b0 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -186,6 +186,29 @@ def get_text_style(text): return style +def _tick_format_props(formatter, tickvalues, labels): + if isinstance(formatter, ticker.NullFormatter): + return "", "" + if isinstance(formatter, ticker.StrMethodFormatter): + convertor = StrMethodTickFormatterConvertor(formatter) + return convertor.output, "str_method" + if isinstance(formatter, ticker.PercentFormatter): + return { + "xmax": formatter.xmax, + "decimals": formatter.decimals, + "symbol": formatter.symbol, + }, "percent" + if hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): + return [text.get_text() for text in labels], "index" + if isinstance(formatter, ticker.FixedFormatter): + return list(formatter.seq), "fixed" + if isinstance(formatter, ticker.FuncFormatter) and tickvalues: + return [formatter(value, i) for i, value in enumerate(tickvalues)], "func" + if not any(label.get_visible() for label in labels): + return "", "" + return None, "" + + def get_axis_properties(axis): """Return the property dictionary for a matplotlib.Axis instance""" props = {} @@ -207,41 +230,21 @@ def get_axis_properties(axis): # Use tick values if appropriate locator = axis.get_major_locator() props['nticks'] = len(locator()) - if isinstance(locator, ticker.FixedLocator): + if isinstance(locator, ticker.FixedLocator) or isinstance(axis.get_major_formatter(), ticker.FuncFormatter): props['tickvalues'] = list(locator()) else: props['tickvalues'] = None + minor_locator = axis.get_minor_locator() + props['minor_tickvalues'] = list(axis.get_minorticklocs()) if minor_locator else None + props['minorticklength'] = axis._minor_tick_kw.get('size', None) + props['majorticklength'] = axis._major_tick_kw.get('size', None) + # Find tick formats - props['tickformat_formatter'] = "" - formatter = axis.get_major_formatter() - if isinstance(formatter, ticker.NullFormatter): - props['tickformat'] = "" - elif isinstance(formatter, ticker.StrMethodFormatter): - convertor = StrMethodTickFormatterConvertor(formatter) - props['tickformat'] = convertor.output - props['tickformat_formatter'] = "str_method" - elif isinstance(formatter, ticker.PercentFormatter): - props['tickformat'] = { - "xmax": formatter.xmax, - "decimals": formatter.decimals, - "symbol": formatter.symbol, - } - props['tickformat_formatter'] = "percent" - elif hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): - # IndexFormatter was dropped in matplotlib 3.5 - props['tickformat'] = [text.get_text() for text in axis.get_ticklabels()] - props['tickformat_formatter'] = "index" - elif isinstance(formatter, ticker.FixedFormatter): - props['tickformat'] = list(formatter.seq) - props['tickformat_formatter'] = "fixed" - elif isinstance(formatter, ticker.FuncFormatter) and props['tickvalues']: - props['tickformat'] = [formatter(value) for value in props['tickvalues']] - props['tickformat_formatter'] = "func" - elif not any(label.get_visible() for label in axis.get_ticklabels()): - props['tickformat'] = "" - else: - props['tickformat'] = None + props['minor_tickformat'], props['minor_tickformat_formatter'] = _tick_format_props( + axis.get_minor_formatter(), props['minor_tickvalues'], axis.get_minorticklabels()) + props['tickformat'], props['tickformat_formatter'] = _tick_format_props( + axis.get_major_formatter(), props['tickvalues'], axis.get_ticklabels()) # Get axis scale props['scale'] = axis.get_scale() @@ -255,6 +258,7 @@ def get_axis_properties(axis): # Get associated grid props['grid'] = get_grid_style(axis) + props['minor_grid'] = get_grid_style(axis, which='minor') # get axis visibility props['visible'] = axis.get_visible() @@ -262,19 +266,24 @@ def get_axis_properties(axis): return props -def get_grid_style(axis): - gridlines = axis.get_gridlines() - if axis._major_tick_kw['gridOn'] and len(gridlines) > 0: - color = export_color(gridlines[0].get_color()) - alpha = gridlines[0].get_alpha() - dasharray = get_dasharray(gridlines[0]) - return dict(gridOn=True, - color=color, - dasharray=dasharray, - alpha=alpha) - else: +def get_grid_style(axis, which='major'): + tick_kw = axis._minor_tick_kw if which == 'minor' else axis._major_tick_kw + + if not tick_kw.get('gridOn'): return {"gridOn": False} + rc = matplotlib.rcParams + color = export_color(tick_kw.get('grid_color', tick_kw.get('grid_c', rc['grid.color']))) + alpha = tick_kw.get('grid_alpha', rc['grid.alpha']) + dasharray = _dasharray_from_linestyle(tick_kw.get('grid_linestyle', tick_kw.get('grid_ls', rc['grid.linestyle']))) + linewidth = tick_kw.get('grid_linewidth', tick_kw.get('grid_lw', rc['grid.linewidth'])) + + return dict(gridOn=True, + color=color, + dasharray=dasharray, + linewidth=linewidth, + alpha=alpha) + def get_figure_properties(fig): return {'figwidth': fig.get_figwidth(),