summaryrefslogtreecommitdiffstats
path: root/utils.py
blob: 2328cc2385fbf605eceb116a0a944785dda6bc5d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# pyright: strict
from collections import defaultdict
from collections.abc import Callable, Iterator
from enum import Enum, auto
from heapq import heappop, heappush
from typing import TypeVar


class Dir(Enum):
    """The four compass directions, declared in clockwise order from NORTH.

    Members carry the integer values 1-4, matching what ``enum.auto()``
    would assign.
    """

    NORTH = 1
    EAST = 2
    SOUTH = 3
    WEST = 4


T = TypeVar("T")


def a_star(
    start: T,
    goal: T | Callable[[T], bool],
    neighbours: Callable[[T], Iterator[T]],
    h: Callable[[T], float],
    d: Callable[[T, T], float],
):
    vtoi: dict[T, int] = dict()
    itov: dict[int, T] = dict()
    i = 0

    def intern(v: T) -> int:
        nonlocal i
        ret = vtoi.get(v)
        if ret is None:
            ret = i
            vtoi[v] = i
            itov[i] = v
            i += 1
        return ret

    open_heapq: list[tuple[float, int]] = [(h(start), intern(start))]
    open_set: set[int] = {intern(start)}
    g_score: defaultdict[T, float] = defaultdict(lambda: float("inf"), [(start, 0)])
    while open_set:
        while True:
            f, current = heappop(open_heapq)
            if current in open_set:
                break
        open_set.remove(current)
        current = itov[current]
        if callable(goal):
            if goal(current):
                return f
        elif current == goal:
            return f
        for neighbour in neighbours(current):
            tentative_g = g_score[current] + d(current, neighbour)
            if tentative_g < g_score[neighbour]:
                g_score[neighbour] = tentative_g
                f_score = tentative_g + h(neighbour)
                heappush(open_heapq, (f_score, intern(neighbour)))
                open_set.add(intern(neighbour))
    return None