Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions mplexporter/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import os

MPLBE = os.environ.get('MPLBE', 'Agg')

if MPLBE:
import matplotlib
import matplotlib
if MPLBE := os.environ.get('MPLBE', 'Agg'):
matplotlib.use(MPLBE)

import matplotlib.pyplot as plt
24 changes: 22 additions & 2 deletions mplexporter/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from numpy.testing import assert_allclose, assert_equal
from . import plt
from numpy.testing import assert_, assert_allclose, assert_equal
from . import plt, matplotlib
from .. import utils


Expand Down Expand Up @@ -33,3 +33,23 @@ def test_axis_w_fixed_formatter():
# NOTE: Issue #471
# assert_equal(props['tickformat'], labels)


def test_funcformatter_exports_major_ticklabels():
# Test both log and normal cases:
for scale, x, y in [
("linear", [0, 1], [0, 1]),
("log", [1, 10, 100], [0, 1, 2]),
]:
fig, ax = plt.subplots()
ax.set_xscale(scale)
ax.plot([1, 10, 100], [0, 1, 2])
ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, pos: f"t{pos}"))

props = utils.get_axis_properties(ax.xaxis)
plt.close(fig)

assert_equal(props["scale"], scale)
assert_equal(props["tickformat_formatter"], "func")
assert_(props["tickvalues"] is not None)
assert_equal(len(props["tickvalues"]), len(props["tickformat"]))
assert_equal(props["tickformat"][:3], ["t0", "t1", "t2"])
91 changes: 50 additions & 41 deletions mplexporter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,29 @@ def get_text_style(text):
return style


def _tick_format_props(formatter, tickvalues, labels):
if isinstance(formatter, ticker.NullFormatter):
return "", ""
if isinstance(formatter, ticker.StrMethodFormatter):
convertor = StrMethodTickFormatterConvertor(formatter)
return convertor.output, "str_method"
if isinstance(formatter, ticker.PercentFormatter):
return {
"xmax": formatter.xmax,
"decimals": formatter.decimals,
"symbol": formatter.symbol,
}, "percent"
if hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter):
return [text.get_text() for text in labels], "index"
if isinstance(formatter, ticker.FixedFormatter):
return list(formatter.seq), "fixed"
if isinstance(formatter, ticker.FuncFormatter) and tickvalues:
return [formatter(value, i) for i, value in enumerate(tickvalues)], "func"
if not any(label.get_visible() for label in labels):
return "", ""
return None, ""


def get_axis_properties(axis):
"""Return the property dictionary for a matplotlib.Axis instance"""
props = {}
Expand All @@ -207,41 +230,21 @@ def get_axis_properties(axis):
# Use tick values if appropriate
locator = axis.get_major_locator()
props['nticks'] = len(locator())
if isinstance(locator, ticker.FixedLocator):
if isinstance(locator, ticker.FixedLocator) or isinstance(axis.get_major_formatter(), ticker.FuncFormatter):
props['tickvalues'] = list(locator())
else:
props['tickvalues'] = None

minor_locator = axis.get_minor_locator()
props['minor_tickvalues'] = list(axis.get_minorticklocs()) if minor_locator else None
props['minorticklength'] = axis._minor_tick_kw.get('size', None)
props['majorticklength'] = axis._major_tick_kw.get('size', None)

# Find tick formats
props['tickformat_formatter'] = ""
formatter = axis.get_major_formatter()
if isinstance(formatter, ticker.NullFormatter):
props['tickformat'] = ""
elif isinstance(formatter, ticker.StrMethodFormatter):
convertor = StrMethodTickFormatterConvertor(formatter)
props['tickformat'] = convertor.output
props['tickformat_formatter'] = "str_method"
elif isinstance(formatter, ticker.PercentFormatter):
props['tickformat'] = {
"xmax": formatter.xmax,
"decimals": formatter.decimals,
"symbol": formatter.symbol,
}
props['tickformat_formatter'] = "percent"
elif hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter):
# IndexFormatter was dropped in matplotlib 3.5
props['tickformat'] = [text.get_text() for text in axis.get_ticklabels()]
props['tickformat_formatter'] = "index"
elif isinstance(formatter, ticker.FixedFormatter):
props['tickformat'] = list(formatter.seq)
props['tickformat_formatter'] = "fixed"
elif isinstance(formatter, ticker.FuncFormatter) and props['tickvalues']:
props['tickformat'] = [formatter(value) for value in props['tickvalues']]
props['tickformat_formatter'] = "func"
elif not any(label.get_visible() for label in axis.get_ticklabels()):
props['tickformat'] = ""
else:
props['tickformat'] = None
props['minor_tickformat'], props['minor_tickformat_formatter'] = _tick_format_props(
axis.get_minor_formatter(), props['minor_tickvalues'], axis.get_minorticklabels())
props['tickformat'], props['tickformat_formatter'] = _tick_format_props(
axis.get_major_formatter(), props['tickvalues'], axis.get_ticklabels())

# Get axis scale
props['scale'] = axis.get_scale()
Expand All @@ -255,26 +258,32 @@ def get_axis_properties(axis):

# Get associated grid
props['grid'] = get_grid_style(axis)
props['minor_grid'] = get_grid_style(axis, which='minor')

# get axis visibility
props['visible'] = axis.get_visible()

return props


def get_grid_style(axis):
gridlines = axis.get_gridlines()
if axis._major_tick_kw['gridOn'] and len(gridlines) > 0:
color = export_color(gridlines[0].get_color())
alpha = gridlines[0].get_alpha()
dasharray = get_dasharray(gridlines[0])
return dict(gridOn=True,
color=color,
dasharray=dasharray,
alpha=alpha)
else:
def get_grid_style(axis, which='major'):
tick_kw = axis._minor_tick_kw if which == 'minor' else axis._major_tick_kw

if not tick_kw.get('gridOn'):
return {"gridOn": False}

rc = matplotlib.rcParams
color = export_color(tick_kw.get('grid_color', tick_kw.get('grid_c', rc['grid.color'])))
alpha = tick_kw.get('grid_alpha', rc['grid.alpha'])
dasharray = _dasharray_from_linestyle(tick_kw.get('grid_linestyle', tick_kw.get('grid_ls', rc['grid.linestyle'])))
linewidth = tick_kw.get('grid_linewidth', tick_kw.get('grid_lw', rc['grid.linewidth']))

return dict(gridOn=True,
color=color,
dasharray=dasharray,
linewidth=linewidth,
alpha=alpha)


def get_figure_properties(fig):
return {'figwidth': fig.get_figwidth(),
Expand Down