Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ dist
*.egg-info
*~
*.png
*.swp
11 changes: 7 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
This a fork of the daft project by Daniel Foreman-Mackey, David W. Hogg, and
others (https://github.com/dfm/daft ).

`Oliver Lindemann <https://github.com/lindemann09>`_

.. image:: https://raw.github.com/davidwhogg/daft/master/images/logo.png

**Daft** is a Python package that uses `matplotlib <http://matplotlib.org/>`_
to render pixel-perfect *probabilistic graphical models* for publication
in a journal or on the internet. With a short Python script and an intuitive
model-building syntax you can design directed and undirected graphs and save
them in any formats that matplotlib supports.

Get more information at: `daft-pgm.org <http://daft-pgm.org>`_
**************************************************************
them in any formats that matplotlib supports. Get more information at:
`daft-pgm.org <http://daft-pgm.org>`_
131 changes: 112 additions & 19 deletions daft.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
__all__ = ["PGM", "Node", "Edge", "Plate"]
"""This a fork of the daft project by Daniel Foreman-Mackey, David W. Hogg,
and others (https://github.com/dfm/daft ).

O. Lindemann (https://github.com/lindemann09/daft)
"""

__all__ = ["PGM", "Node", "Edge", "Plate"]

__version__ = "0.0.3"

__version__ = "0.1"
__author__ = "Oliver Lindemann"

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
Expand Down Expand Up @@ -57,6 +63,7 @@ def __init__(self, shape, origin=[0, 0],
self._nodes = {}
self._edges = []
self._plates = []
self._annotations = []

self._ctx = _rendering_context(shape=shape, origin=origin,
grid_unit=grid_unit,
Expand Down Expand Up @@ -110,6 +117,15 @@ def add_plate(self, plate):
self._plates.append(plate)
return None

def add_annotation(self, position, text, **kwargs):
"""
Add an annotation text to the model.

"""

self._annotations.append([position, text, kwargs])
return

def render(self):
"""
Render the :class:`Plate`, :class:`Edge` and :class:`Node` objects in
Expand All @@ -129,8 +145,25 @@ def render(self):
for name, node in self._nodes.iteritems():
node.render(self._ctx)

for pos, text, kwargs in self._annotations:
self.ax.annotate(text, self._ctx.convert(pos[0], pos[1]),
**kwargs)

return self.ax

def add(self, plate_or_note):
"""
Add a :class: `Plate` or `Node` to the model.

"""
if type(plate_or_note) == Node:
self.add_node(plate_or_note)
elif type(plate_or_note) == Plate:
self.add_plate(plate_or_note)
else:
raise RuntimeError("Known object to be added to model")



class Node(object):
"""
Expand All @@ -155,8 +188,13 @@ class Node(object):
:param aspect: (optional)
The aspect ratio width/height for elliptical nodes; default 1.

:param observed: (optional)
Should this be a conditioned variable?
:param rectangle: (optional)
If `True` node has a rectangular shape.

:param double: (optional)
Double lines. This must be ``"inner"`` or ``"outer"``.
Node is shown as double shapes with the second shape plotted inside
or outside of the standard one, respectively.

:param fixed: (optional)
Should this be a fixed (not permitted to vary) variable?
Expand All @@ -172,15 +210,25 @@ class Node(object):
A dictionary of parameters to pass to the
:class:`matplotlib.patches.Ellipse` constructor.

:param space_double_line: (optional)
Distance between the two line for double lines notes.
default = 0.15

"""

def __init__(self, name, content, x, y, scale=1, aspect=None,
observed=False, fixed=False,
offset=[0, 0], plot_params={}, label_params=None):
observed=False, fixed=False, rectangle=False,
double = "",
offset=[0, 0], plot_params={}, label_params=None,
space_double_line=0.15):
# Node style.
assert not (observed and fixed), \
"A node cannot be both 'observed' and 'fixed'."
self.observed = observed
self.fixed = fixed
self.rectangle = rectangle
self.double = double
self.space_double_line = space_double_line

# Metadata.
self.name = name
Expand Down Expand Up @@ -256,33 +304,44 @@ def render(self, ctx):
aspect = ctx.aspect

# Set up an observed node. Note the fc INSANITY.
if self.observed:
if self.observed or self.double =="inner" or self.double=="outer":
# Update the plotting parameters depending on the style of
# observed node.
h = float(diameter)
w = aspect * float(diameter)
if ctx.observed_style == "shaded":
if ctx.observed_style == "shaded" and self.observed:
p["fc"] = "0.7"
elif ctx.observed_style == "outer":
h = diameter + 0.1 * diameter
w = aspect * diameter + 0.1 * diameter
elif ctx.observed_style == "inner":
h = diameter - 0.1 * diameter
w = aspect * diameter - 0.1 * diameter
elif ctx.observed_style == "outer" or self.double == "outer":
h = diameter + self.space_double_line
w = aspect * diameter + self.space_double_line
elif ctx.observed_style == "inner" or self.double == "inner":
h = diameter - self.space_double_line
w = aspect * diameter - self.space_double_line
p["fc"] = fc

# Draw the background ellipse.
bg = Ellipse(xy=ctx.convert(self.x, self.y),
if self.rectangle:
xy = np.array(ctx.convert(self.x, self.y)) - diameter/2.0
xy = xy + (diameter-w)/float(2)
bg = Rectangle(xy=xy, width=w, height=h, **p)
else:
bg = Ellipse(xy=ctx.convert(self.x, self.y),
width=w, height=h, **p)
ax.add_artist(bg)

# Reset the face color.
p["fc"] = fc

# Draw the foreground ellipse.
if ctx.observed_style == "inner" and not self.fixed:
if ctx.observed_style == "inner" and not self.fixed and \
(self.observed or self.double == "inner"):
p["fc"] = "none"
el = Ellipse(xy=ctx.convert(self.x, self.y),
if self.rectangle:
xy = np.array(ctx.convert(self.x, self.y)) - diameter/2.0
el = Rectangle(xy=xy, width=diameter * aspect,
height=diameter, **p)
else:
el = Ellipse(xy=ctx.convert(self.x, self.y),
width=diameter * aspect, height=diameter, **p)
ax.add_artist(el)

Expand Down Expand Up @@ -351,9 +410,12 @@ def _get_coords(self, ctx):
dist1 = np.sqrt(dy * dy + dx * dx / float(a1 ** 2))
dist2 = np.sqrt(dy * dy + dx * dx / float(a2 ** 2))

radius1 = 0.5 * ctx.node_unit * self.node1.scale
radius2 = 0.5 * ctx.node_unit * self.node2.scale

# Compute the fractional effect of the radii of the nodes.
alpha1 = 0.5 * ctx.node_unit * self.node1.scale / dist1
alpha2 = 0.5 * ctx.node_unit * self.node2.scale / dist2
alpha1 = radius1 / dist1
alpha2 = radius2 / dist2

# Get the coordinates of the starting position.
x0, y0 = x1 + alpha1 * dx, y1 + alpha1 * dy
Expand All @@ -362,6 +424,24 @@ def _get_coords(self, ctx):
dx0 = dx * (1. - alpha1 - alpha2)
dy0 = dy * (1. - alpha1 - alpha2)

if self.node1.rectangle or self.node2.rectangle:
# calc displacement of the edge (in direction of the edge)
length, angle = cart2polar(dx0, dy0)
if abs(angle)>np.pi*0.25 and abs(angle)<np.pi*0.75: # upper & lower
angle2 = angle - np.pi/2.0
else:
angle2 = angle
displacement = radius1/abs(np.cos(angle2)) - radius1

if self.node1.rectangle:
displacement_xy = polar2cart(displacement, angle)
x0 += displacement_xy[0]
y0 += displacement_xy[1]
if self.node2.rectangle and self.node1.rectangle:
dx0, dy0 = polar2cart(length-2*displacement, angle)
else:
dx0, dy0 = polar2cart(length-displacement, angle)

return x0, y0, dx0, dy0

def render(self, ctx):
Expand Down Expand Up @@ -417,6 +497,7 @@ class Plate(object):

:param rect:
The rectangle describing the plate bounds in model coordinates.
[left, bottom, width, height]

:param label: (optional)
A string to annotate the plate.
Expand Down Expand Up @@ -621,3 +702,15 @@ def _pop_multiple(d, default, *args):
return default

return results[0][1]

def polar2cart(r, theta):
"""polar coordinates to cartesian coordinates"""
x = r * np.cos(theta)
y = r * np.sin(theta)
return x, y

def cart2polar(x,y):
""" cartesian coordinates to polar coordinates"""
r = np.sqrt(x**2 + y**2)
theta = np.arctan2(y, x)
return r, theta