diff --git a/.gitignore b/.gitignore index 3311046..1cc3bdb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist *.egg-info *~ *.png +*.swp diff --git a/README.rst b/README.rst index 672f9f2..c7a67ff 100644 --- a/README.rst +++ b/README.rst @@ -1,10 +1,13 @@ +This a fork of the daft project by Daniel Foreman-Mackey, David W. Hogg, and +others (https://github.com/dfm/daft ). + +`Oliver Lindemann `_ + .. image:: https://raw.github.com/davidwhogg/daft/master/images/logo.png **Daft** is a Python package that uses `matplotlib `_ to render pixel-perfect *probabilistic graphical models* for publication in a journal or on the internet. With a short Python script and an intuitive model-building syntax you can design directed and undirected graphs and save -them in any formats that matplotlib supports. - -Get more information at: `daft-pgm.org `_ -************************************************************** +them in any formats that matplotlib supports. Get more information at: +`daft-pgm.org `_ diff --git a/daft.py b/daft.py index dd27d07..c4942be 100644 --- a/daft.py +++ b/daft.py @@ -1,8 +1,14 @@ -__all__ = ["PGM", "Node", "Edge", "Plate"] +"""This a fork of the daft project by Daniel Foreman-Mackey, David W. Hogg, +and others (https://github.com/dfm/daft ). + +O. Lindemann (https://github.com/lindemann09/daft) +""" +__all__ = ["PGM", "Node", "Edge", "Plate"] -__version__ = "0.0.3" +__version__ = "0.1" +__author__ = "Oliver Lindemann" import matplotlib.pyplot as plt from matplotlib.patches import Ellipse @@ -57,6 +63,7 @@ def __init__(self, shape, origin=[0, 0], self._nodes = {} self._edges = [] self._plates = [] + self._annotations = [] self._ctx = _rendering_context(shape=shape, origin=origin, grid_unit=grid_unit, @@ -110,6 +117,15 @@ def add_plate(self, plate): self._plates.append(plate) return None + def add_annotation(self, position, text, **kwargs): + """ + Add an annotation text to the model. + + """ + + self._annotations.append([position, text, kwargs]) + return + def render(self): """ Render the :class:`Plate`, :class:`Edge` and :class:`Node` objects in @@ -129,8 +145,25 @@ def render(self): for name, node in self._nodes.iteritems(): node.render(self._ctx) + for pos, text, kwargs in self._annotations: + self.ax.annotate(text, self._ctx.convert(pos[0], pos[1]), + **kwargs) + return self.ax + def add(self, plate_or_note): + """ + Add a :class: `Plate` or `Node` to the model. + + """ + if type(plate_or_note) == Node: + self.add_node(plate_or_note) + elif type(plate_or_note) == Plate: + self.add_plate(plate_or_note) + else: + raise RuntimeError("Known object to be added to model") + + class Node(object): """ @@ -155,8 +188,13 @@ class Node(object): :param aspect: (optional) The aspect ratio width/height for elliptical nodes; default 1. - :param observed: (optional) - Should this be a conditioned variable? + :param rectangle: (optional) + If `True` node has a rectangular shape. + + :param double: (optional) + Double lines. This must be ``"inner"`` or ``"outer"``. + Node is shown as double shapes with the second shape plotted inside + or outside of the standard one, respectively. :param fixed: (optional) Should this be a fixed (not permitted to vary) variable? @@ -172,15 +210,25 @@ class Node(object): A dictionary of parameters to pass to the :class:`matplotlib.patches.Ellipse` constructor. + :param space_double_line: (optional) + Distance between the two line for double lines notes. + default = 0.15 + """ + def __init__(self, name, content, x, y, scale=1, aspect=None, - observed=False, fixed=False, - offset=[0, 0], plot_params={}, label_params=None): + observed=False, fixed=False, rectangle=False, + double = "", + offset=[0, 0], plot_params={}, label_params=None, + space_double_line=0.15): # Node style. assert not (observed and fixed), \ "A node cannot be both 'observed' and 'fixed'." self.observed = observed self.fixed = fixed + self.rectangle = rectangle + self.double = double + self.space_double_line = space_double_line # Metadata. self.name = name @@ -256,23 +304,28 @@ def render(self, ctx): aspect = ctx.aspect # Set up an observed node. Note the fc INSANITY. - if self.observed: + if self.observed or self.double =="inner" or self.double=="outer": # Update the plotting parameters depending on the style of # observed node. h = float(diameter) w = aspect * float(diameter) - if ctx.observed_style == "shaded": + if ctx.observed_style == "shaded" and self.observed: p["fc"] = "0.7" - elif ctx.observed_style == "outer": - h = diameter + 0.1 * diameter - w = aspect * diameter + 0.1 * diameter - elif ctx.observed_style == "inner": - h = diameter - 0.1 * diameter - w = aspect * diameter - 0.1 * diameter + elif ctx.observed_style == "outer" or self.double == "outer": + h = diameter + self.space_double_line + w = aspect * diameter + self.space_double_line + elif ctx.observed_style == "inner" or self.double == "inner": + h = diameter - self.space_double_line + w = aspect * diameter - self.space_double_line p["fc"] = fc # Draw the background ellipse. - bg = Ellipse(xy=ctx.convert(self.x, self.y), + if self.rectangle: + xy = np.array(ctx.convert(self.x, self.y)) - diameter/2.0 + xy = xy + (diameter-w)/float(2) + bg = Rectangle(xy=xy, width=w, height=h, **p) + else: + bg = Ellipse(xy=ctx.convert(self.x, self.y), width=w, height=h, **p) ax.add_artist(bg) @@ -280,9 +333,15 @@ def render(self, ctx): p["fc"] = fc # Draw the foreground ellipse. - if ctx.observed_style == "inner" and not self.fixed: + if ctx.observed_style == "inner" and not self.fixed and \ + (self.observed or self.double == "inner"): p["fc"] = "none" - el = Ellipse(xy=ctx.convert(self.x, self.y), + if self.rectangle: + xy = np.array(ctx.convert(self.x, self.y)) - diameter/2.0 + el = Rectangle(xy=xy, width=diameter * aspect, + height=diameter, **p) + else: + el = Ellipse(xy=ctx.convert(self.x, self.y), width=diameter * aspect, height=diameter, **p) ax.add_artist(el) @@ -351,9 +410,12 @@ def _get_coords(self, ctx): dist1 = np.sqrt(dy * dy + dx * dx / float(a1 ** 2)) dist2 = np.sqrt(dy * dy + dx * dx / float(a2 ** 2)) + radius1 = 0.5 * ctx.node_unit * self.node1.scale + radius2 = 0.5 * ctx.node_unit * self.node2.scale + # Compute the fractional effect of the radii of the nodes. - alpha1 = 0.5 * ctx.node_unit * self.node1.scale / dist1 - alpha2 = 0.5 * ctx.node_unit * self.node2.scale / dist2 + alpha1 = radius1 / dist1 + alpha2 = radius2 / dist2 # Get the coordinates of the starting position. x0, y0 = x1 + alpha1 * dx, y1 + alpha1 * dy @@ -362,6 +424,24 @@ def _get_coords(self, ctx): dx0 = dx * (1. - alpha1 - alpha2) dy0 = dy * (1. - alpha1 - alpha2) + if self.node1.rectangle or self.node2.rectangle: + # calc displacement of the edge (in direction of the edge) + length, angle = cart2polar(dx0, dy0) + if abs(angle)>np.pi*0.25 and abs(angle)