summaryrefslogtreecommitdiffstats
path: root/utils.py
blob: c7ab871aa471b0a88487e9717c3c6816951b4cb9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations
from collections import deque, defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from heapq import heappop, heappush
from itertools import islice
from math import hypot, sqrt
from re import match as re_match
from sys import argv
from typing import TypeVar, Generic, TextIO, Union, Optional, cast

# Coordinate type used by Point2D and the norm helpers: constrained to float or int.
NumT = TypeVar('NumT', float, int)

@dataclass(frozen=True)
class Point2D(Generic[NumT]):
    """Immutable 2-D point/vector with int or float coordinates.

    As a frozen dataclass, instances compare by value and are hashable, so
    they can be used as dict keys and set members.
    """
    x: NumT
    y: NumT

    def __str__(self) -> str:
        return f'({self.x}, {self.y})'

    # The binary operators return NotImplemented for non-Point2D operands so
    # Python can try the reflected operation and otherwise raise TypeError.
    # The return annotation names only the successful result: NotImplemented
    # is not a bool, and type checkers special-case it in binary dunders.
    def __add__(self, other: Point2D[NumT]) -> Point2D[NumT]:
        """Return the component-wise sum of two points."""
        if isinstance(other, Point2D):
            return Point2D(self.x + other.x, self.y + other.y)
        return NotImplemented

    def __sub__(self, other: Point2D[NumT]) -> Point2D[NumT]:
        """Return the component-wise difference of two points."""
        if isinstance(other, Point2D):
            return Point2D(self.x - other.x, self.y - other.y)
        return NotImplemented

    def __abs__(self) -> Point2D[NumT]:
        """Return a point whose coordinates are the absolute values of ours.

        Note this is component-wise and yields a Point2D, not a scalar length;
        use norm_1 / norm_2 / norm_inf for lengths.
        """
        return Point2D(cast(NumT, abs(self.x)), cast(NumT, abs(self.y)))

def norm_1(p: Point2D[NumT]) -> NumT:
    """Return the Manhattan (L1) norm of *p*, i.e. |x| + |y|."""
    dx, dy = abs(p.x), abs(p.y)
    return cast(NumT, dx + dy)

def norm_2(p: Point2D) -> float:
    """Return the Euclidean (L2) norm of *p* as a float.

    math.hypot is used instead of sqrt(x**2 + y**2): the intermediate squares
    of very large coordinates overflow to inf (e.g. x = y = 1e200) and those
    of very small ones underflow to 0, while hypot scales internally and
    returns the correct magnitude.
    """
    return hypot(p.x, p.y)

def norm_inf(p: Point2D[NumT]) -> NumT:
    """Return the Chebyshev (L-infinity) norm of *p*, i.e. max(|x|, |y|)."""
    magnitudes = map(abs, (p.x, p.y))
    return cast(NumT, max(magnitudes))

def inbounds(p: Point2D[NumT], a: Point2D[NumT], _b: Optional[Point2D[NumT]] = None) -> bool:
    """Return True if *p* lies in the half-open box [lo, hi) on both axes.

    Two call forms are supported:

    * ``inbounds(p, hi)``: the lower corner is the origin, as int zeros if
      ``hi.x`` is an int and float zeros otherwise.
    * ``inbounds(p, lo, hi)``: both corners are given explicitly.

    Lower bounds are inclusive and upper bounds exclusive.
    """
    if not isinstance(_b, Point2D):
        # Single-bound form: *a* is really the exclusive upper corner.
        origin = Point2D(0, 0) if isinstance(a.x, int) else Point2D(0.0, 0.0)
        lo, hi = cast(Point2D[NumT], origin), a
    else:
        lo, hi = a, _b
    return lo.x <= p.x < hi.x and lo.y <= p.y < hi.y

def adjacent(p: Point2D[int], diagonal: bool = True) -> Iterator[Point2D[int]]:
    """Yield the grid neighbours of *p*.

    With *diagonal* true this is all 8 surrounding cells; otherwise only the
    4 orthogonal ones. Cells are produced column by column (dx = -1, 0, 1),
    and within a column by dy = -1, 0, 1, skipping *p* itself. No bounds
    checking is done.
    """
    steps = (-1, 0, 1)
    for step_x in steps:
        for step_y in steps:
            if step_x == 0 and step_y == 0:
                continue  # the centre cell is not its own neighbour
            if step_x and step_y and not diagonal:
                continue  # both axes move: a diagonal step
            yield Point2D(p.x + step_x, p.y + step_y)

def adjacent_bounded(p: Point2D[int], bound: Point2D[int], diagonal: bool = True) \
        -> Iterator[Point2D[int]]:
    """Return an iterator over adjacent(p, diagonal) restricted to the grid.

    A neighbour is kept only if inbounds(neighbour, bound) holds, i.e. it
    lies in the half-open box from the origin (inclusive) to *bound*
    (exclusive).
    """
    def _on_grid(cell: Point2D[int]) -> bool:
        return inbounds(cell, bound)

    return filter(_on_grid, adjacent(p, diagonal))

def a_star(start, goal, neighbours, h, d):
    """Return the cost of a cheapest path from *start* to *goal*, or None.

    A* search over an implicitly defined graph.

    start, goal  -- vertices; must be hashable (they are used as dict keys
                    and set members). The search stops when *goal* is popped.
    neighbours   -- callable: vertex -> iterable of adjacent vertices.
    h            -- heuristic callable: vertex -> estimated remaining cost to
                    *goal*. As usual for A*, the result is guaranteed minimal
                    only if *h* never overestimates.
    d            -- callable: (u, v) -> cost of the edge from u to v.

    Returns the accumulated edge cost g(goal), not the heap priority
    g(goal) + h(goal), so the answer is correct even when h(goal) != 0.
    Returns None if the open set is exhausted without reaching *goal*.
    """
    # Vertices are interned to integer ids so heap entries with equal
    # priority tie-break on the id and vertices never need to be orderable.
    vtoi = {}
    itov = {}

    def intern(v):
        ret = vtoi.get(v)
        if ret is None:
            ret = len(vtoi)
            vtoi[v] = ret
            itov[ret] = v
        return ret

    open_set = {start}
    open_heapq = [(h(start), intern(start))]
    g_score = defaultdict(lambda: float('inf'))
    g_score[start] = 0
    while open_set:
        # Lazy deletion: a vertex may have several heap entries, and entries
        # for vertices already expanded stay behind, so skip any entry whose
        # vertex is no longer open.
        while True:
            _, current_id = heappop(open_heapq)
            current = itov[current_id]
            if current in open_set:
                break
        if current == goal:
            return g_score[current]
        open_set.remove(current)
        for neighbour in neighbours(current):
            tentative_g = g_score[current] + d(current, neighbour)
            if tentative_g < g_score[neighbour]:
                # Found a cheaper route; (re)open the neighbour, even if it
                # was expanded before.
                g_score[neighbour] = tentative_g
                open_set.add(neighbour)
                heappush(open_heapq, (tentative_g + h(neighbour), intern(neighbour)))
    return None

T = TypeVar('T')
def sliding_window(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
    """Yield overlapping tuples of *n* consecutive items from *iterable*.

    Nothing is yielded when *iterable* holds fewer than *n* items.
    Example: sliding_window('ABCDEFG', 4) -> ABCD BCDE CDEF DEFG
    """
    items: Iterator[T] = iter(iterable)
    # Prime the buffer with up to n items; it only counts as a window if
    # the iterable actually supplied all n of them.
    buf: deque[T] = deque(islice(items, n), maxlen=n)
    if len(buf) == n:
        yield tuple(buf)
    # maxlen=n makes each append evict the oldest item, sliding the window.
    for nxt in items:
        buf.append(nxt)
        yield tuple(buf)

def open_day(n: int) -> TextIO:
    """Open the input file for day *n* for reading as text.

    If the script was run with exactly one command-line argument, that
    argument is used as the path; otherwise ``{n}.in`` (relative to the
    current working directory) is opened.

    The caller owns the returned file object and should close it, e.g. by
    using it in a ``with`` statement.
    """
    path = argv[1] if len(argv) == 2 else f'{n}.in'
    # Explicit UTF-8: text-mode open() otherwise decodes with the platform's
    # locale encoding, so the same file could be read differently per OS.
    return open(path, encoding='utf-8')

def parse_day(n: int, regex: str) -> tuple[Optional[str], ...]:
    """Read day *n*'s input (via open_day) and return the groups of *regex*.

    The whole file is read and trailing whitespace is stripped. *regex* is
    then applied with ``re.match``, which anchors only at the start, so it
    need not consume the entire text. Groups that did not participate in the
    match are returned as None.

    Raises ValueError if *regex* does not match. Previously the None result
    of re.match surfaced as an opaque AttributeError.
    """
    with open_day(n) as f:
        text = f.read().rstrip()
    m = re_match(regex, text)
    if m is None:
        raise ValueError(f'input for day {n} does not match pattern {regex!r}')
    return m.groups()