Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ dist
*.egg-info
*~
*.png
*.swp
76 changes: 64 additions & 12 deletions daft.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,13 @@ class Node(object):
:param aspect: (optional)
The aspect ratio width/height for elliptical nodes; default 1.

:param observed: (optional)
Should this be a conditioned variable?
:param rectangle: (optional)
If `True` node has a rectangular shape.

:param double: (optional)
Double lines. This must be ``"inner"`` or ``"outer"``.
Node is shown as double shapes with the second shape plotted inside
or outside of the standard one, respectively.

:param fixed: (optional)
Should this be a fixed (not permitted to vary) variable?
Expand All @@ -174,13 +179,16 @@ class Node(object):

"""
def __init__(self, name, content, x, y, scale=1, aspect=None,
observed=False, fixed=False,
observed=False, fixed=False, rectangle=False,
double = "",
offset=[0, 0], plot_params={}, label_params=None):
# Node style.
assert not (observed and fixed), \
"A node cannot be both 'observed' and 'fixed'."
self.observed = observed
self.fixed = fixed
self.rectangle = rectangle
self.double = double

# Metadata.
self.name = name
Expand Down Expand Up @@ -256,33 +264,44 @@ def render(self, ctx):
aspect = ctx.aspect

# Set up an observed node. Note the fc INSANITY.
if self.observed:
if self.observed or self.double =="inner" or self.double=="outer":
# Update the plotting parameters depending on the style of
# observed node.
h = float(diameter)
w = aspect * float(diameter)
if ctx.observed_style == "shaded":
if ctx.observed_style == "shaded" and self.observed:
p["fc"] = "0.7"
elif ctx.observed_style == "outer":
elif ctx.observed_style == "outer" or self.double == "outer":
h = diameter + 0.1 * diameter
w = aspect * diameter + 0.1 * diameter
elif ctx.observed_style == "inner":
elif ctx.observed_style == "inner" or self.double == "inner":
h = diameter - 0.1 * diameter
w = aspect * diameter - 0.1 * diameter
p["fc"] = fc

# Draw the background ellipse.
bg = Ellipse(xy=ctx.convert(self.x, self.y),
if self.rectangle:
xy = np.array(ctx.convert(self.x, self.y)) - diameter/2.0
xy = xy + (diameter-w)/float(2)
bg = Rectangle(xy=xy, width=w, height=h, **p)
else:
bg = Ellipse(xy=ctx.convert(self.x, self.y),
width=w, height=h, **p)
ax.add_artist(bg)

# Reset the face color.
p["fc"] = fc

# Draw the foreground ellipse.
if ctx.observed_style == "inner" and not self.fixed:
if ctx.observed_style == "inner" and not self.fixed and \
(self.observed or self.double == "inner"):
p["fc"] = "none"
el = Ellipse(xy=ctx.convert(self.x, self.y),
if self.rectangle:
xy = np.array(ctx.convert(self.x, self.y)) - diameter/2.0
el = Rectangle(xy=xy, width=diameter * aspect,
height=diameter, **p)
else:
el = Ellipse(xy=ctx.convert(self.x, self.y),
width=diameter * aspect, height=diameter, **p)
ax.add_artist(el)

Expand Down Expand Up @@ -351,9 +370,12 @@ def _get_coords(self, ctx):
dist1 = np.sqrt(dy * dy + dx * dx / float(a1 ** 2))
dist2 = np.sqrt(dy * dy + dx * dx / float(a2 ** 2))

radius1 = 0.5 * ctx.node_unit * self.node1.scale
radius2 = 0.5 * ctx.node_unit * self.node2.scale

# Compute the fractional effect of the radii of the nodes.
alpha1 = 0.5 * ctx.node_unit * self.node1.scale / dist1
alpha2 = 0.5 * ctx.node_unit * self.node2.scale / dist2
alpha1 = radius1 / dist1
alpha2 = radius2 / dist2

# Get the coordinates of the starting position.
x0, y0 = x1 + alpha1 * dx, y1 + alpha1 * dy
Expand All @@ -362,6 +384,24 @@ def _get_coords(self, ctx):
dx0 = dx * (1. - alpha1 - alpha2)
dy0 = dy * (1. - alpha1 - alpha2)

if self.node1.rectangle or self.node2.rectangle:
# calc displacement of the edge (in direction of the edge)
length, angle = cart2polar(dx0, dy0)
if abs(angle)>np.pi*0.25 and abs(angle)<np.pi*0.75: # upper & lower
angle2 = angle - np.pi/2.0
else:
angle2 = angle
displacement = radius1/abs(np.cos(angle2)) - radius1

if self.node1.rectangle:
displacement_xy = polar2cart(displacement, angle)
x0 += displacement_xy[0]
y0 += displacement_xy[1]
if self.node2.rectangle and self.node1.rectangle:
dx0, dy0 = polar2cart(length-2*displacement, angle)
else:
dx0, dy0 = polar2cart(length-displacement, angle)

return x0, y0, dx0, dy0

def render(self, ctx):
Expand Down Expand Up @@ -621,3 +661,15 @@ def _pop_multiple(d, default, *args):
return default

return results[0][1]

def polar2cart(r, theta):
"""polar coordinates to cartesian coordinates"""
x = r * np.cos(theta)
y = r * np.sin(theta)
return x, y

def cart2polar(x,y):
""" cartesian coordinates to polar coordinates"""
r = np.sqrt(x**2 + y**2)
theta = np.arctan2(y, x)
return r, theta