|
| 1 | +# Copyright (C) 2023-present The Project Contributors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import pytest |
| 16 | +from cl.runtime.context.testing_context import TestingContext |
| 17 | +from dataclasses import dataclass |
| 18 | +from typing import List, Optional |
| 19 | +import networkx as nx |
| 20 | +import matplotlib.pyplot as plt |
| 21 | +from matplotlib.patches import Rectangle |
| 22 | +from runtime.testing.pytest.pytest_fixtures import local_dir_fixture |
| 23 | + |
| 24 | +def test_smoke(local_dir_fixture): |
| 25 | + with TestingContext() as context: |
| 26 | + |
| 27 | + @dataclass |
| 28 | + class Node: |
| 29 | + title: str |
| 30 | + successors: Optional[List['Node']] = None |
| 31 | + |
| 32 | + # Define the nodes with successors |
| 33 | + staff_a = Node(title="Staff A") |
| 34 | + staff_b = Node(title="Staff B") |
| 35 | + staff_c = Node(title="Staff C") |
| 36 | + staff_d = Node(title="Staff D") |
| 37 | + team_1 = Node(title="Team A Lead", successors=[staff_a, staff_b]) |
| 38 | + team_2 = Node(title="Team B Lead", successors=[staff_c, staff_d]) |
| 39 | + ceo = Node(title="CEO", successors=[team_1, team_2]) |
| 40 | + |
| 41 | + # Create a directed graph |
| 42 | + G = nx.DiGraph() |
| 43 | + |
| 44 | + # Initialize position dictionary and label dictionary |
| 45 | + pos = {} |
| 46 | + labels = {} |
| 47 | + |
| 48 | + # Starting coordinates for the CEO |
| 49 | + x_start = 10 |
| 50 | + y_start = 10 |
| 51 | + x_offset = 6 # Horizontal distance between nodes |
| 52 | + y_offset = 2 # Vertical distance between each successor |
| 53 | + |
| 54 | + current_y = 0 |
| 55 | + |
| 56 | + # Function to recursively add nodes and edges to the graph |
| 57 | + def add_nodes_recursive(graph, node, current_id, x, y, pos, labels) -> int: |
| 58 | + # Add the current node to the graph |
| 59 | + graph.add_node(current_id) |
| 60 | + pos[current_id] = (x, y) |
| 61 | + labels[current_id] = node.title |
| 62 | + |
| 63 | + # Add successors recursively |
| 64 | + if node.successors: |
| 65 | + for i, successor in enumerate(node.successors): |
| 66 | + successor_id = len(pos) # Create a new unique ID for each successor |
| 67 | + # Add edge from the current node to the successor |
| 68 | + graph.add_edge(current_id, successor_id) |
| 69 | + # Position each successor progressively lower |
| 70 | + y = y - (i + 1) * y_offset # Adjust vertical spacing between successors |
| 71 | + y = add_nodes_recursive(graph, successor, successor_id, x + x_offset, y, pos, |
| 72 | + labels) # Adjust horizontal spacing |
| 73 | + return y |
| 74 | + |
| 75 | + # Add CEO node and its successors recursively |
| 76 | + add_nodes_recursive(G, ceo, 0, x_start, y_start, pos, labels) |
| 77 | + |
| 78 | + # Increase the canvas size using figsize (width, height in inches) |
| 79 | + fig, ax = plt.subplots(figsize=(12, 8)) # Adjust this to make the canvas bigger |
| 80 | + |
| 81 | + # Define a function to manually position the arrows |
| 82 | + def draw_edges_with_custom_arrows(graph, pos, ax): |
| 83 | + for edge in graph.edges(): |
| 84 | + start_node, end_node = edge |
| 85 | + |
| 86 | + # Get the positions of the nodes |
| 87 | + start_x, start_y = pos[start_node] |
| 88 | + end_x, end_y = pos[end_node] |
| 89 | + |
| 90 | + # Define the exit point (right of the start node) and entry point (left of the end node) |
| 91 | + exit_x = start_x + 1.5 # Right side of the start node (assuming box width of 3) |
| 92 | + entry_x = end_x - 1.5 # Left side of the end node (assuming box width of 3) |
| 93 | + |
| 94 | + # Draw the arrow |
| 95 | + ax.annotate( |
| 96 | + '', xy=(entry_x, end_y), xytext=(exit_x, start_y), |
| 97 | + arrowprops=dict(arrowstyle='-|>', lw=1.5, color='black', |
| 98 | + connectionstyle='arc3,rad=0.0') # Straight arrow |
| 99 | + ) |
| 100 | + |
| 101 | + # Call the function to draw custom arrows |
| 102 | + draw_edges_with_custom_arrows(G, pos, ax) |
| 103 | + |
| 104 | + # Draw the labels |
| 105 | + nx.draw_networkx_labels(G, pos, labels, font_size=10, font_color="black") |
| 106 | + |
| 107 | + # Manually draw boxes (rectangles) around the nodes |
| 108 | + for node, (x, y) in pos.items(): |
| 109 | + # Define the size of each box (width and height can be adjusted) |
| 110 | + width = 3 |
| 111 | + height = 1.5 |
| 112 | + # Draw a rectangle centered on the node's position |
| 113 | + rect = Rectangle((x - width / 2, y - height / 2), width, height, |
| 114 | + linewidth=1, edgecolor='black', facecolor='lightblue') |
| 115 | + ax.add_patch(rect) |
| 116 | + |
| 117 | + # Dynamically calculate plot limits to ensure all boxes fit |
| 118 | + x_values = [x for x, y in pos.values()] |
| 119 | + y_values = [y for x, y in pos.values()] |
| 120 | + |
| 121 | + # Adjust the limits based on the positions and box sizes |
| 122 | + x_margin = width / 2 + 1 # Add margin for the box width and extra space |
| 123 | + y_margin = height / 2 + 1 # Add margin for the box height and extra space |
| 124 | + |
| 125 | + ax.set_xlim(min(x_values) - x_margin, max(x_values) + x_margin) |
| 126 | + ax.set_ylim(min(y_values) - y_margin, max(y_values) + y_margin) |
| 127 | + |
| 128 | + # Remove the default axes for a cleaner look |
| 129 | + ax.set_axis_off() |
| 130 | + |
| 131 | + # Add a title and display the plot |
| 132 | + plt.title("Orgchart") |
| 133 | + plt.savefig("test_successor_dag.png") |
| 134 | + |
| 135 | +if __name__ == "__main__": |
| 136 | + pytest.main([__file__]) |
0 commit comments