summaryrefslogtreecommitdiffstats
path: root/utils.py
blob: 44fd07a95f0e5c0a4d5db0ac33e3d6807809f510 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations
from collections import deque
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from math import hypot, sqrt
from re import match as re_match
from sys import argv
from typing import Generic, Optional, TextIO, TypeVar, Union, cast

NumT = TypeVar('NumT', float, int)

@dataclass(frozen=True)
class Point2D(Generic[NumT]):
    """Immutable 2-D point/vector with int or float coordinates (see NumT).

    Supports component-wise ``+`` and ``-`` between points and ``abs()``.
    As a frozen dataclass it compares by value and is hashable, so points
    can be used as dict keys and set members.
    """
    x: NumT
    y: NumT

    def __str__(self) -> str:
        """Render the point as ``(x, y)``."""
        return f'({self.x}, {self.y})'

    # For non-Point2D operands the binary operators return NotImplemented so
    # Python can try the reflected operation or raise TypeError. Type checkers
    # accept NotImplemented in these dunders, so the declared return type is
    # simply the point type.
    def __add__(self, other: Point2D[NumT]) -> Point2D[NumT]:
        """Return the component-wise sum ``self + other``."""
        if isinstance(other, Point2D):
            return Point2D(self.x + other.x, self.y + other.y)
        return NotImplemented

    def __sub__(self, other: Point2D[NumT]) -> Point2D[NumT]:
        """Return the component-wise difference ``self - other``."""
        if isinstance(other, Point2D):
            return Point2D(self.x - other.x, self.y - other.y)
        return NotImplemented

    def __abs__(self) -> Point2D[NumT]:
        """Return the point with both coordinates replaced by their absolute values."""
        return Point2D(cast(NumT, abs(self.x)), cast(NumT, abs(self.y)))

def norm_1(p: Point2D[NumT]) -> NumT:
    """Return the Manhattan (L1) length of *p*, i.e. ``|x| + |y|``."""
    dx, dy = abs(p.x), abs(p.y)
    return cast(NumT, dx + dy)

def norm_2(p: Point2D) -> float:
    """Return the Euclidean (L2) length of *p*.

    Uses math.hypot rather than ``sqrt(x**2 + y**2)``. Explicit squaring
    overflows for large coordinates: floats such as 1e200 square to inf, and
    ints whose square exceeds the float range raise OverflowError in sqrt.
    hypot avoids that intermediate overflow.
    """
    return hypot(p.x, p.y)

def norm_inf(p: Point2D[NumT]) -> NumT:
    """Return the Chebyshev (L-infinity) length of *p*, i.e. ``max(|x|, |y|)``."""
    dx, dy = abs(p.x), abs(p.y)
    return cast(NumT, max(dx, dy))

def inbounds(p: Point2D[NumT], a: Point2D[NumT], _b: Optional[Point2D[NumT]] = None) -> bool:
    """Return True if *p* lies inside the half-open box ``[lo, hi)``.

    ``inbounds(p, lo, hi)`` uses both given corners.
    ``inbounds(p, hi)`` (third argument absent or not a Point2D) uses the
    origin as the lower corner. The origin is built from int zeros when
    ``a.x`` is an int and from float zeros otherwise.
    """
    if not isinstance(_b, Point2D):
        zero = 0 if isinstance(a.x, int) else 0.0
        # Shift the single given corner into the upper-bound slot.
        a, _b = cast(Point2D[NumT], Point2D(zero, zero)), a
    lo, hi = a, _b
    return lo.x <= p.x < hi.x and lo.y <= p.y < hi.y

def adjacent(p: Point2D[int], diagonal: bool = True) -> Iterator[Point2D[int]]:
    """Yield the grid neighbours of *p* (no bounds checking).

    With ``diagonal=True`` all 8 surrounding cells are produced; otherwise
    only the 4 orthogonal ones. Order: x offset -1..1 outer, y offset -1..1
    inner.
    """
    steps = (-1, 0, 1)
    yield from (
        Point2D(p.x + dx, p.y + dy)
        for dx in steps
        for dy in steps
        # Skip p itself, and skip diagonal cells unless they were requested.
        if (dx, dy) != (0, 0) and (diagonal or dx == 0 or dy == 0)
    )

def adjacent_bounded(
    p: Point2D[int], bound: Point2D[int], diagonal: bool = True
) -> Iterator[Point2D[int]]:
    """Return an iterator over the neighbours of *p* that fall in ``[0, bound)``.

    Neighbours come from adjacent() (same order and *diagonal* meaning);
    each is kept only if inbounds(neighbour, bound) holds.
    """
    neighbours = adjacent(p, diagonal)
    return (q for q in neighbours if inbounds(q, bound))

T = TypeVar('T')
def sliding_window(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
    """Yield every run of *n* consecutive items of *iterable* as a tuple.

    Example: sliding_window('ABCDEFG', 4) -> ABCD BCDE CDEF DEFG.
    Nothing is yielded when the input has fewer than *n* items.
    """
    source: Iterator[T] = iter(iterable)
    # Prime the buffer with the first n items; maxlen makes each later
    # append evict the oldest element, sliding the window by one.
    buf: deque[T] = deque(islice(source, n), maxlen=n)
    primed = len(buf) == n
    if primed:
        yield tuple(buf)
    for item in source:
        buf.append(item)
        yield tuple(buf)

def open_day(n: int) -> TextIO:
    """Open the puzzle input for day *n* as a text file.

    If exactly one command-line argument was given, it is used as the path
    of the input file; otherwise ``{n}.in`` (relative to the current working
    directory) is opened. The caller owns the returned handle and should
    close it, e.g. with a ``with`` statement.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.
    """
    path = argv[1] if len(argv) == 2 else f'{n}.in'
    return open(path, encoding='utf-8')

def parse_day(n: int, regex: str) -> tuple[Optional[str], ...]:
    """Read day *n*'s input and return the capture groups of *regex*.

    Trailing whitespace is stripped from the input before matching. The
    pattern is applied with re.match, so it is anchored at the start of the
    input but not at its end. Groups that did not participate in the match
    are returned as None.

    Raises:
        ValueError: if *regex* does not match at the start of the input.
    """
    with open_day(n) as f:
        text = f.read().rstrip()
    m = re_match(regex, text)
    # Without this check a mismatch surfaces as an opaque
    # "'NoneType' object has no attribute 'groups'" error.
    if m is None:
        raise ValueError(f'day {n} input does not match pattern {regex!r}')
    return m.groups()