Skip to content

Commit 243d6e2

Browse files
committed
Initial version of successor DAG with ladder node placement.
1 parent 7e3fa88 commit 243d6e2

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed
39.6 KB
Loading
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (C) 2023-present The Project Contributors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from cl.runtime.context.testing_context import TestingContext
17+
from dataclasses import dataclass
18+
from typing import List, Optional
19+
import networkx as nx
20+
import matplotlib.pyplot as plt
21+
from matplotlib.patches import Rectangle
22+
from runtime.testing.pytest.pytest_fixtures import local_dir_fixture
23+
24+
def test_smoke(local_dir_fixture):
25+
with TestingContext() as context:
26+
27+
@dataclass
28+
class Node:
29+
title: str
30+
successors: Optional[List['Node']] = None
31+
32+
# Define the nodes with successors
33+
staff_a = Node(title="Staff A")
34+
staff_b = Node(title="Staff B")
35+
staff_c = Node(title="Staff C")
36+
staff_d = Node(title="Staff D")
37+
team_1 = Node(title="Team A Lead", successors=[staff_a, staff_b])
38+
team_2 = Node(title="Team B Lead", successors=[staff_c, staff_d])
39+
ceo = Node(title="CEO", successors=[team_1, team_2])
40+
41+
# Create a directed graph
42+
G = nx.DiGraph()
43+
44+
# Initialize position dictionary and label dictionary
45+
pos = {}
46+
labels = {}
47+
48+
# Starting coordinates for the CEO
49+
x_start = 10
50+
y_start = 10
51+
x_offset = 6 # Horizontal distance between nodes
52+
y_offset = 2 # Vertical distance between each successor
53+
54+
current_y = 0
55+
56+
# Function to recursively add nodes and edges to the graph
57+
def add_nodes_recursive(graph, node, current_id, x, y, pos, labels) -> int:
58+
# Add the current node to the graph
59+
graph.add_node(current_id)
60+
pos[current_id] = (x, y)
61+
labels[current_id] = node.title
62+
63+
# Add successors recursively
64+
if node.successors:
65+
for i, successor in enumerate(node.successors):
66+
successor_id = len(pos) # Create a new unique ID for each successor
67+
# Add edge from the current node to the successor
68+
graph.add_edge(current_id, successor_id)
69+
# Position each successor progressively lower
70+
y = y - (i + 1) * y_offset # Adjust vertical spacing between successors
71+
y = add_nodes_recursive(graph, successor, successor_id, x + x_offset, y, pos,
72+
labels) # Adjust horizontal spacing
73+
return y
74+
75+
# Add CEO node and its successors recursively
76+
add_nodes_recursive(G, ceo, 0, x_start, y_start, pos, labels)
77+
78+
# Increase the canvas size using figsize (width, height in inches)
79+
fig, ax = plt.subplots(figsize=(12, 8)) # Adjust this to make the canvas bigger
80+
81+
# Define a function to manually position the arrows
82+
def draw_edges_with_custom_arrows(graph, pos, ax):
83+
for edge in graph.edges():
84+
start_node, end_node = edge
85+
86+
# Get the positions of the nodes
87+
start_x, start_y = pos[start_node]
88+
end_x, end_y = pos[end_node]
89+
90+
# Define the exit point (right of the start node) and entry point (left of the end node)
91+
exit_x = start_x + 1.5 # Right side of the start node (assuming box width of 3)
92+
entry_x = end_x - 1.5 # Left side of the end node (assuming box width of 3)
93+
94+
# Draw the arrow
95+
ax.annotate(
96+
'', xy=(entry_x, end_y), xytext=(exit_x, start_y),
97+
arrowprops=dict(arrowstyle='-|>', lw=1.5, color='black',
98+
connectionstyle='arc3,rad=0.0') # Straight arrow
99+
)
100+
101+
# Call the function to draw custom arrows
102+
draw_edges_with_custom_arrows(G, pos, ax)
103+
104+
# Draw the labels
105+
nx.draw_networkx_labels(G, pos, labels, font_size=10, font_color="black")
106+
107+
# Manually draw boxes (rectangles) around the nodes
108+
for node, (x, y) in pos.items():
109+
# Define the size of each box (width and height can be adjusted)
110+
width = 3
111+
height = 1.5
112+
# Draw a rectangle centered on the node's position
113+
rect = Rectangle((x - width / 2, y - height / 2), width, height,
114+
linewidth=1, edgecolor='black', facecolor='lightblue')
115+
ax.add_patch(rect)
116+
117+
# Dynamically calculate plot limits to ensure all boxes fit
118+
x_values = [x for x, y in pos.values()]
119+
y_values = [y for x, y in pos.values()]
120+
121+
# Adjust the limits based on the positions and box sizes
122+
x_margin = width / 2 + 1 # Add margin for the box width and extra space
123+
y_margin = height / 2 + 1 # Add margin for the box height and extra space
124+
125+
ax.set_xlim(min(x_values) - x_margin, max(x_values) + x_margin)
126+
ax.set_ylim(min(y_values) - y_margin, max(y_values) + y_margin)
127+
128+
# Remove the default axes for a cleaner look
129+
ax.set_axis_off()
130+
131+
# Add a title and display the plot
132+
plt.title("Orgchart")
133+
plt.savefig("test_successor_dag.png")
134+
135+
if __name__ == "__main__":
136+
pytest.main([__file__])

0 commit comments

Comments
 (0)