Skip to content

Commit 7d663ef

Browse files
authored
Merge pull request #497 from that-ar-guy/a-star
added a star
2 parents 8697b61 + 939e357 commit 7d663ef

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import Tuple, List, Dict
2+
from heapq import heappop, heappush
3+
import math
4+
from data_classes import Graph
5+
from logger_config import logger
6+
7+
def heuristic(node: Tuple[int, int], goal: Tuple[int, int]) -> float:
8+
"""Euclidean distance heuristic."""
9+
return math.sqrt((node[0] - goal[0])**2 + (node[1] - goal[1])**2)
10+
11+
def a_star(graph: Graph, start: Tuple[int, int], end: Tuple[int, int]) -> Tuple[List[Tuple[int, int]], float]:
12+
try:
13+
open_set = []
14+
heappush(open_set, (0, start))
15+
g_score: Dict[Tuple[int, int], float] = {vertex: float('inf') for vertex in graph.vertices}
16+
g_score[start] = 0
17+
came_from: Dict[Tuple[int, int], Tuple[int, int]] = {}
18+
19+
while open_set:
20+
_, current = heappop(open_set)
21+
22+
if current == end:
23+
path = []
24+
while current in came_from:
25+
path.insert(0, current)
26+
current = came_from[current]
27+
path.insert(0, start)
28+
return path, g_score[end]
29+
30+
for neighbor, weight in graph.edges.get(current, {}).items():
31+
tentative_g_score = g_score[current] + weight
32+
if tentative_g_score < g_score[neighbor]:
33+
g_score[neighbor] = tentative_g_score
34+
f_score = tentative_g_score + heuristic(neighbor, end)
35+
heappush(open_set, (f_score, neighbor))
36+
came_from[neighbor] = current
37+
38+
raise ValueError("Path not found")
39+
except Exception as e:
40+
logger.error(f"Error in A* algorithm: {e}")
41+
raise
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
from data_classes import Graph
3+
from a_star import a_star
4+
import sys
5+
import os
6+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7+
class TestAStar(unittest.TestCase):
8+
def setUp(self):
9+
vertices = {(0, 0), (1, 0), (1, 1), (2, 1), (2, 2)}
10+
edges = {
11+
(0, 0): {(1, 0): 1},
12+
(1, 0): {(1, 1): 1, (0, 0): 1},
13+
(1, 1): {(2, 1): 1, (1, 0): 1},
14+
(2, 1): {(2, 2): 1, (1, 1): 1},
15+
(2, 2): {(2, 1): 1}
16+
}
17+
self.graph = Graph(vertices, edges)
18+
self.graph.vertices = {
19+
(0, 0), (1, 0), (1, 1), (2, 1), (2, 2)
20+
}
21+
self.graph.edges = {
22+
(0, 0): {(1, 0): 1},
23+
(1, 0): {(1, 1): 1, (0, 0): 1},
24+
(1, 1): {(2, 1): 1, (1, 0): 1},
25+
(2, 1): {(2, 2): 1, (1, 1): 1},
26+
(2, 2): {(2, 1): 1}
27+
}
28+
29+
def test_valid_path(self):
30+
path, cost = a_star(self.graph, (0, 0), (2, 2))
31+
self.assertEqual(path, [(0, 0), (1, 0), (1, 1), (2, 1), (2, 2)])
32+
self.assertEqual(cost, 4)
33+
34+
def test_no_path(self):
35+
self.graph.edges = {
36+
(0, 0): {(1, 0): 1},
37+
(1, 0): {(0, 0): 1},
38+
(2, 2): {} # Disconnected node
39+
}
40+
with self.assertRaises(ValueError):
41+
a_star(self.graph, (0, 0), (2, 2))
42+
43+
def test_same_start_end(self):
44+
path, cost = a_star(self.graph, (1, 1), (1, 1))
45+
self.assertEqual(path, [(1, 1)])
46+
self.assertEqual(cost, 0)
47+
48+
if __name__ == "__main__":
49+
unittest.main()

0 commit comments

Comments
 (0)