diff options
| author | Tomasz Kramkowski <tomasz@kramkow.ski> | 2023-12-17 12:53:21 +0000 | 
|---|---|---|
| committer | Tomasz Kramkowski <tomasz@kramkow.ski> | 2023-12-17 12:53:21 +0000 | 
| commit | a259269cfece3f812d1a0a5c4a837e6a96a3d2a3 (patch) | |
| tree | 78b2c9acbed7ad32b58f7089648cd068293b1fd6 | |
| parent | c30b831007c0e663d545c55cab87e7d257aa0293 (diff) | |
| download | aoc2023-a259269cfece3f812d1a0a5c4a837e6a96a3d2a3.tar.gz aoc2023-a259269cfece3f812d1a0a5c4a837e6a96a3d2a3.tar.xz aoc2023-a259269cfece3f812d1a0a5c4a837e6a96a3d2a3.zip  | |
day 17
| -rw-r--r-- | 17.py | 79 | ||||
| -rw-r--r-- | utils.py | 67 | 
2 files changed, 146 insertions, 0 deletions
@@ -0,0 +1,79 @@ +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from functools import total_ordering +from sys import stdin +from typing import Self + +from utils import Dir, a_star + + +@total_ordering +@dataclass +class Pos: +    x: int +    y: int +    direction: Dir | None = None + +    def __eq__(self, other: object) -> bool: +        if isinstance(other, Pos): +            return (self.x, self.y) == (other.x, other.y) +        return NotImplemented + +    def __lt__(self, other: object) -> bool: +        if isinstance(other, Pos): +            return (self.x, self.y) < (other.x, other.y) +        return NotImplemented + +    def __hash__(self) -> int: +        return hash((self.x, self.y, self.direction)) + +    def manhattan(self, other: Self) -> int: +        return abs(self.x - other.x) + abs(self.y - other.y) + + +def make_neighbours( +    width: int, height: int, min_dist: int, max_dist: int +) -> Callable[[Pos], Iterator[Pos]]: +    def neighbours(p: Pos) -> Iterator[Pos]: +        if p.direction not in {Dir.NORTH, Dir.SOUTH}: +            for dy in range(min_dist, max_dist + 1): +                if p.y - dy < height and p.y - dy >= 0: +                    yield Pos(p.x, p.y - dy, Dir.NORTH) +                if p.y + dy < height and p.y + dy >= 0: +                    yield Pos(p.x, p.y + dy, Dir.SOUTH) +        if p.direction not in {Dir.EAST, Dir.WEST}: +            for dx in range(min_dist, max_dist + 1): +                if p.x + dx < width and p.x + dx >= 0: +                    yield Pos(p.x + dx, p.y, Dir.EAST) +                if p.x - dx < width and p.x - dx >= 0: +                    yield Pos(p.x - dx, p.y, Dir.WEST) + +    return neighbours + + +inp = [line.rstrip("\n") for line in stdin] + + +def normalize(n: int) -> int: +    return min(1, max(-1, n)) + + +def d(a: Pos, b: Pos) -> int: +    ndx = normalize(b.x - a.x) +    ndy = normalize(b.y - a.y) +    total = 0 +    if ndx != 0: +        for x in range(a.x + ndx, b.x + ndx, ndx): +            total += int(inp[a.y][x]) +    elif ndy != 0: +        for y in range(a.y + ndy, b.y + ndy, ndy): +            total += int(inp[y][a.x]) +    return total + + +goal = Pos(len(inp[0]) - 1, len(inp) - 1) + +p1neighbours = make_neighbours(len(inp[0]), len(inp), 1, 3) +print(a_star(Pos(0, 0), goal, p1neighbours, goal.manhattan, d)) +p2neighbours = make_neighbours(len(inp[0]), len(inp), 4, 10) +print(a_star(Pos(0, 0), goal, p2neighbours, goal.manhattan, d)) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..afdd191 --- /dev/null +++ b/utils.py @@ -0,0 +1,67 @@ +# pyright: strict +from collections import defaultdict +from collections.abc import Callable, Iterator +from enum import Enum, auto +from heapq import heappop, heappush +from typing import TypeVar + + +class Dir(Enum): +    NORTH = auto() +    EAST = auto() +    SOUTH = auto() +    WEST = auto() + + +T = TypeVar("T") + + +def a_star( +    start: T, +    goal: T | Callable[[T], bool], +    neighbours: Callable[[T], Iterator[T]], +    h: Callable[[T], float], +    d: Callable[[T, T], float], +): +    vtoi: dict[T, int] = dict() +    itov: dict[int, T] = dict() +    i = 0 + +    def intern(v: T) -> int: +        nonlocal i +        ret = vtoi.get(v) +        if ret is None: +            ret = i +            vtoi[v] = i +            itov[i] = v +            i += 1 +        return ret + +    open_set: set[T] = {start} +    open_heapq: list[tuple[float, int]] = [(0, intern(start))] +    came_from: dict[T, T] = dict() +    g_score: defaultdict[T, float] = defaultdict(lambda: float("inf")) +    g_score[start] = 0 +    f_score: dict[T, float] = dict() +    f_score[start] = h(start) +    while open_set: +        while True: +            f, current = heappop(open_heapq) +            current = itov[current] +            if current in open_set: +                break +        if callable(goal): +            if goal(current): +                return f +        elif current == goal: +            return f +        open_set.remove(current) +        for neighbour in neighbours(current): +            tentative_g = g_score[current] + d(current, neighbour) +            if tentative_g < g_score[neighbour]: +                came_from[neighbour] = current +                g_score[neighbour] = tentative_g +                f_score[neighbour] = tentative_g + h(neighbour) +                open_set.add(neighbour) +                heappush(open_heapq, (f_score[neighbour], intern(neighbour))) +    return None  | 
