summaryrefslogtreecommitdiffstats
path: root/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'utils.py')
-rw-r--r--utils.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..afdd191
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,67 @@
+# pyright: strict
+from collections import defaultdict
+from collections.abc import Callable, Iterator
+from enum import Enum, auto
+from heapq import heappop, heappush
+from typing import TypeVar
+
+
+class Dir(Enum):
+ NORTH = auto()
+ EAST = auto()
+ SOUTH = auto()
+ WEST = auto()
+
+
+T = TypeVar("T")
+
+
+def a_star(
+ start: T,
+ goal: T | Callable[[T], bool],
+ neighbours: Callable[[T], Iterator[T]],
+ h: Callable[[T], float],
+ d: Callable[[T, T], float],
+):
+ vtoi: dict[T, int] = dict()
+ itov: dict[int, T] = dict()
+ i = 0
+
+ def intern(v: T) -> int:
+ nonlocal i
+ ret = vtoi.get(v)
+ if ret is None:
+ ret = i
+ vtoi[v] = i
+ itov[i] = v
+ i += 1
+ return ret
+
+ open_set: set[T] = {start}
+ open_heapq: list[tuple[float, int]] = [(0, intern(start))]
+ came_from: dict[T, T] = dict()
+ g_score: defaultdict[T, float] = defaultdict(lambda: float("inf"))
+ g_score[start] = 0
+ f_score: dict[T, float] = dict()
+ f_score[start] = h(start)
+ while open_set:
+ while True:
+ f, current = heappop(open_heapq)
+ current = itov[current]
+ if current in open_set:
+ break
+ if callable(goal):
+ if goal(current):
+ return f
+ elif current == goal:
+ return f
+ open_set.remove(current)
+ for neighbour in neighbours(current):
+ tentative_g = g_score[current] + d(current, neighbour)
+ if tentative_g < g_score[neighbour]:
+ came_from[neighbour] = current
+ g_score[neighbour] = tentative_g
+ f_score[neighbour] = tentative_g + h(neighbour)
+ open_set.add(neighbour)
+ heappush(open_heapq, (f_score[neighbour], intern(neighbour)))
+ return None