diff options
Diffstat (limited to 'utils.py')
-rw-r--r-- | utils.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..afdd191 --- /dev/null +++ b/utils.py @@ -0,0 +1,67 @@ +# pyright: strict +from collections import defaultdict +from collections.abc import Callable, Iterator +from enum import Enum, auto +from heapq import heappop, heappush +from typing import TypeVar + + +class Dir(Enum): + NORTH = auto() + EAST = auto() + SOUTH = auto() + WEST = auto() + + +T = TypeVar("T") + + +def a_star( + start: T, + goal: T | Callable[[T], bool], + neighbours: Callable[[T], Iterator[T]], + h: Callable[[T], float], + d: Callable[[T, T], float], +): + vtoi: dict[T, int] = dict() + itov: dict[int, T] = dict() + i = 0 + + def intern(v: T) -> int: + nonlocal i + ret = vtoi.get(v) + if ret is None: + ret = i + vtoi[v] = i + itov[i] = v + i += 1 + return ret + + open_set: set[T] = {start} + open_heapq: list[tuple[float, int]] = [(0, intern(start))] + came_from: dict[T, T] = dict() + g_score: defaultdict[T, float] = defaultdict(lambda: float("inf")) + g_score[start] = 0 + f_score: dict[T, float] = dict() + f_score[start] = h(start) + while open_set: + while True: + f, current = heappop(open_heapq) + current = itov[current] + if current in open_set: + break + if callable(goal): + if goal(current): + return f + elif current == goal: + return f + open_set.remove(current) + for neighbour in neighbours(current): + tentative_g = g_score[current] + d(current, neighbour) + if tentative_g < g_score[neighbour]: + came_from[neighbour] = current + g_score[neighbour] = tentative_g + f_score[neighbour] = tentative_g + h(neighbour) + open_set.add(neighbour) + heappush(open_heapq, (f_score[neighbour], intern(neighbour))) + return None |