Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions mplexporter/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def run(self, fig):
self.crawl_fig(fig)

@staticmethod
def process_transform(transform, ax=None, data=None, return_trans=False,
force_trans=None):
def process_transform(transform, ax=None, fig=None, data=None,
return_trans=False, force_trans=None):
"""Process the transform and convert data to figure or data coordinates

Parameters
Expand All @@ -62,6 +62,8 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
The transform applied to the data
ax : matplotlib Axes object (optional)
The axes the data is associated with
fig : matplotlib Figure object (optional)
The figure the data is associated with
data : ndarray (optional)
The array of data to be transformed.
return_trans : bool (optional)
Expand Down Expand Up @@ -91,6 +93,7 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
transform = force_trans

code = "display"
fig_ref = ax.figure if ax is not None else fig
if ax is not None:
for (c, trans) in [("data", ax.transData),
("axes", ax.transAxes),
Expand All @@ -99,6 +102,12 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
if transform.contains_branch(trans):
code, transform = (c, transform - trans)
break
elif fig_ref is not None:
for (c, trans) in [("figure", fig_ref.transFigure),
("display", transforms.IdentityTransform())]:
if transform.contains_branch(trans):
code, transform = (c, transform - trans)
break

if data is not None:
if return_trans:
Expand All @@ -115,6 +124,12 @@ def crawl_fig(self, fig):
"""Crawl the figure and process all axes"""
with self.renderer.draw_figure(fig=fig,
props=utils.get_figure_properties(fig)):
if getattr(fig, "_suptitle", None) is not None:
self.draw_figure_text(fig, fig._suptitle, text_type="suptitle")
for text in fig.texts:
if text is not getattr(fig, "_suptitle", None):
self.draw_figure_text(fig, text)

for ax in fig.axes:
self.crawl_ax(ax)

Expand Down Expand Up @@ -149,6 +164,20 @@ def crawl_ax(self, ax):
if props['visible']:
self.crawl_legend(ax, legend)

def draw_figure_text(self, fig, text, text_type=None):
"""Process a figure-level matplotlib text object"""
content = text.get_text()
if content:
transform = text.get_transform()
position = text.get_position()
coords, position = self.process_transform(transform, None, fig,
position)
style = utils.get_text_style(text)
self.renderer.draw_figure_text(text=content, position=position,
coordinates=coords,
text_type=text_type,
style=style, mplobj=text)

def crawl_legend(self, ax, legend):
"""
Recursively look through objects in legend children
Expand Down Expand Up @@ -184,7 +213,8 @@ def crawl_legend(self, ax, legend):
def draw_line(self, ax, line, force_trans=None):
"""Process a matplotlib line and call renderer.draw_line"""
coordinates, data = self.process_transform(line.get_transform(),
ax, line.get_xydata(),
ax=ax,
data=line.get_xydata(),
force_trans=force_trans)
linestyle = utils.get_line_style(line)
if (linestyle['dasharray'] is None
Expand All @@ -208,8 +238,9 @@ def draw_text(self, ax, text, force_trans=None, text_type=None):
if content:
transform = text.get_transform()
position = text.get_position()
coords, position = self.process_transform(transform, ax,
position,
coords, position = self.process_transform(transform,
ax=ax,
data=position,
force_trans=force_trans)
style = utils.get_text_style(text)
self.renderer.draw_text(text=content, position=position,
Expand All @@ -222,7 +253,8 @@ def draw_patch(self, ax, patch, force_trans=None):
vertices, pathcodes = utils.SVG_path(patch.get_path())
transform = patch.get_transform()
coordinates, vertices = self.process_transform(transform,
ax, vertices,
ax=ax,
data=vertices,
force_trans=force_trans)
linestyle = utils.get_path_style(patch, fill=patch.get_fill())
self.renderer.draw_path(data=vertices,
Expand All @@ -239,13 +271,14 @@ def draw_collection(self, ax, collection,
offsets, paths) = _collections_prepare_points(collection, ax)

offset_coords, offsets = self.process_transform(
transOffset, ax, offsets, force_trans=force_offsettrans)
transOffset, ax=ax, data=offsets, force_trans=force_offsettrans)
path_coords = self.process_transform(
transform, ax, force_trans=force_pathtrans)
transform, ax=ax, force_trans=force_pathtrans)

processed_paths = [utils.SVG_path(path) for path in paths]
processed_paths = [(self.process_transform(
transform, ax, path[0], force_trans=force_pathtrans)[1], path[1])
transform, ax=ax, data=path[0],
force_trans=force_pathtrans)[1], path[1])
for path in processed_paths]

path_transforms = collection.get_transforms()
Expand All @@ -260,6 +293,7 @@ def draw_collection(self, ax, collection,
styles = {'linewidth': collection.get_linewidths(),
'facecolor': collection.get_facecolors(),
'edgecolor': collection.get_edgecolors(),
'dasharray': utils.get_dasharray_list(collection),
'alpha': collection._alpha,
'zorder': collection.get_zorder()}

Expand Down
15 changes: 12 additions & 3 deletions mplexporter/renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def draw_marked_line(self, data, coordinates, linestyle, markerstyle,
if markerstyle is not None:
self.draw_markers(data, coordinates, markerstyle, label, mplobj)

def draw_figure_text(self, text, position, coordinates, style,
text_type=None, mplobj=None):
"""Figure-level text; renderers that care can override."""
pass

def draw_line(self, data, coordinates, style, label, mplobj=None):
"""
Draw a line. By default, draw the line via the draw_path() command.
Expand Down Expand Up @@ -204,8 +209,12 @@ def _iter_path_collection(paths, path_transforms, offsets, styles):
if np.size(facecolor) == 0:
facecolor = ['none']

dasharray = styles.get('dasharray', None)
if dasharray is None or np.size(dasharray) == 0:
dasharray = ['none']

elements = [paths, path_transforms, offsets,
edgecolor, styles['linewidth'], facecolor]
edgecolor, styles['linewidth'], facecolor, dasharray]

it = itertools
return it.islice(py3k.zip(*py3k.map(it.cycle, elements)), N)
Expand Down Expand Up @@ -258,7 +267,7 @@ def draw_path_collection(self, paths, path_coordinates, path_transforms,

for tup in self._iter_path_collection(paths, path_transforms,
offsets, styles):
(path, path_transform, offset, ec, lw, fc) = tup
(path, path_transform, offset, ec, lw, fc, da) = tup
vertices, pathcodes = path
path_transform = transforms.Affine2D(path_transform)
vertices = path_transform.transform(vertices)
Expand All @@ -268,7 +277,7 @@ def draw_path_collection(self, paths, path_coordinates, path_transforms,
style = {"edgecolor": utils.export_color(ec),
"facecolor": utils.export_color(fc),
"edgewidth": lw,
"dasharray": "10,0",
"dasharray": da,
"alpha": styles['alpha'],
"zorder": styles['zorder']}
self.draw_path(data=vertices, coordinates=path_coordinates,
Expand Down
4 changes: 4 additions & 0 deletions mplexporter/renderers/fake_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def open_legend(self, legend, props):
def close_legend(self, legend):
self.output += " closing legend\n"

def draw_figure_text(self, text, position, coordinates, style,
text_type=None, mplobj=None):
self.output += " draw figure text '{0}' {1}\n".format(text, text_type)

def draw_text(self, text, position, coordinates, style,
text_type=None, mplobj=None):
self.output += " draw text '{0}' {1}\n".format(text, text_type)
Expand Down
7 changes: 2 additions & 5 deletions mplexporter/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import os

MPLBE = os.environ.get('MPLBE', 'Agg')

if MPLBE:
import matplotlib
import matplotlib
if MPLBE := os.environ.get('MPLBE', 'Agg'):
matplotlib.use(MPLBE)

import matplotlib.pyplot as plt
42 changes: 40 additions & 2 deletions mplexporter/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from numpy.testing import assert_allclose, assert_equal
from . import plt
from numpy.testing import assert_, assert_allclose, assert_equal
from . import plt, matplotlib
from .. import utils


Expand Down Expand Up @@ -33,3 +33,41 @@ def test_axis_w_fixed_formatter():
# NOTE: Issue #471
# assert_equal(props['tickformat'], labels)


def test_funcformatter_exports_major_ticklabels():
# Test both log and normal cases:
for scale, x, y in [
("linear", [0, 1], [0, 1]),
("log", [1, 10, 100], [0, 1, 2]),
]:
fig, ax = plt.subplots()
ax.set_xscale(scale)
ax.plot([1, 10, 100], [0, 1, 2])
ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, pos: f"t{pos}"))

props = utils.get_axis_properties(ax.xaxis)
plt.close(fig)

assert_equal(props["scale"], scale)
assert_equal(props["tickformat_formatter"], "func")
assert_(props["tickvalues"] is not None)
assert_equal(len(props["tickvalues"]), len(props["tickformat"]))
assert_equal(props["tickformat"][:3], ["t0", "t1", "t2"])


def test_custom_formatter_exports_major_ticklabels():
class CustomFormatter(matplotlib.ticker.Formatter):
def __call__(self, x, pos=None):
return f"c{pos}"

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 2])
ax.xaxis.set_major_formatter(CustomFormatter())

props = utils.get_axis_properties(ax.xaxis)
plt.close(fig)

assert_equal(props["tickformat_formatter"], "func")
assert_(props["tickvalues"] is not None)
assert_equal(len(props["tickvalues"]), len(props["tickformat"]))
assert_equal(props["tickformat"][:3], ["c0", "c1", "c2"])
Loading