summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tomasz@kramkow.ski>2023-12-17 12:53:21 +0000
committerTomasz Kramkowski <tomasz@kramkow.ski>2023-12-17 12:53:21 +0000
commita259269cfece3f812d1a0a5c4a837e6a96a3d2a3 (patch)
tree78b2c9acbed7ad32b58f7089648cd068293b1fd6
parentc30b831007c0e663d545c55cab87e7d257aa0293 (diff)
downloadaoc2023-a259269cfece3f812d1a0a5c4a837e6a96a3d2a3.tar.gz
aoc2023-a259269cfece3f812d1a0a5c4a837e6a96a3d2a3.tar.xz
aoc2023-a259269cfece3f812d1a0a5c4a837e6a96a3d2a3.zip
day 17
-rw-r--r--17.py79
-rw-r--r--utils.py67
2 files changed, 146 insertions, 0 deletions
diff --git a/17.py b/17.py
new file mode 100644
index 0000000..2b463f8
--- /dev/null
+++ b/17.py
@@ -0,0 +1,79 @@
+from collections.abc import Callable, Iterator
+from dataclasses import dataclass
+from functools import total_ordering
+from sys import stdin
+from typing import Self
+
+from utils import Dir, a_star
+
+
+@total_ordering
+@dataclass
+class Pos:
+ x: int
+ y: int
+ direction: Dir | None = None
+
+ def __eq__(self, other: object) -> bool:
+ if isinstance(other, Pos):
+ return (self.x, self.y) == (other.x, other.y)
+ return NotImplemented
+
+ def __lt__(self, other: object) -> bool:
+ if isinstance(other, Pos):
+ return (self.x, self.y) < (other.x, other.y)
+ return NotImplemented
+
+ def __hash__(self) -> int:
+ return hash((self.x, self.y, self.direction))
+
+ def manhattan(self, other: Self) -> int:
+ return abs(self.x - other.x) + abs(self.y - other.y)
+
+
+def make_neighbours(
+ width: int, height: int, min_dist: int, max_dist: int
+) -> Callable[[Pos], Iterator[Pos]]:
+ def neighbours(p: Pos) -> Iterator[Pos]:
+ if p.direction not in {Dir.NORTH, Dir.SOUTH}:
+ for dy in range(min_dist, max_dist + 1):
+ if p.y - dy < height and p.y - dy >= 0:
+ yield Pos(p.x, p.y - dy, Dir.NORTH)
+ if p.y + dy < height and p.y + dy >= 0:
+ yield Pos(p.x, p.y + dy, Dir.SOUTH)
+ if p.direction not in {Dir.EAST, Dir.WEST}:
+ for dx in range(min_dist, max_dist + 1):
+ if p.x + dx < width and p.x + dx >= 0:
+ yield Pos(p.x + dx, p.y, Dir.EAST)
+ if p.x - dx < width and p.x - dx >= 0:
+ yield Pos(p.x - dx, p.y, Dir.WEST)
+
+ return neighbours
+
+
+inp = [line.rstrip("\n") for line in stdin]
+
+
+def normalize(n: int) -> int:
+ return min(1, max(-1, n))
+
+
+def d(a: Pos, b: Pos) -> int:
+ ndx = normalize(b.x - a.x)
+ ndy = normalize(b.y - a.y)
+ total = 0
+ if ndx != 0:
+ for x in range(a.x + ndx, b.x + ndx, ndx):
+ total += int(inp[a.y][x])
+ elif ndy != 0:
+ for y in range(a.y + ndy, b.y + ndy, ndy):
+ total += int(inp[y][a.x])
+ return total
+
+
+goal = Pos(len(inp[0]) - 1, len(inp) - 1)
+
+p1neighbours = make_neighbours(len(inp[0]), len(inp), 1, 3)
+print(a_star(Pos(0, 0), goal, p1neighbours, goal.manhattan, d))
+p2neighbours = make_neighbours(len(inp[0]), len(inp), 4, 10)
+print(a_star(Pos(0, 0), goal, p2neighbours, goal.manhattan, d))
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..afdd191
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,67 @@
+# pyright: strict
+from collections import defaultdict
+from collections.abc import Callable, Iterator
+from enum import Enum, auto
+from heapq import heappop, heappush
+from typing import TypeVar
+
+
+class Dir(Enum):
+ NORTH = auto()
+ EAST = auto()
+ SOUTH = auto()
+ WEST = auto()
+
+
+T = TypeVar("T")
+
+
+def a_star(
+ start: T,
+ goal: T | Callable[[T], bool],
+ neighbours: Callable[[T], Iterator[T]],
+ h: Callable[[T], float],
+ d: Callable[[T, T], float],
+):
+ vtoi: dict[T, int] = dict()
+ itov: dict[int, T] = dict()
+ i = 0
+
+ def intern(v: T) -> int:
+ nonlocal i
+ ret = vtoi.get(v)
+ if ret is None:
+ ret = i
+ vtoi[v] = i
+ itov[i] = v
+ i += 1
+ return ret
+
+ open_set: set[T] = {start}
+ open_heapq: list[tuple[float, int]] = [(0, intern(start))]
+ came_from: dict[T, T] = dict()
+ g_score: defaultdict[T, float] = defaultdict(lambda: float("inf"))
+ g_score[start] = 0
+ f_score: dict[T, float] = dict()
+ f_score[start] = h(start)
+ while open_set:
+ while True:
+ f, current = heappop(open_heapq)
+ current = itov[current]
+ if current in open_set:
+ break
+ if callable(goal):
+ if goal(current):
+ return f
+ elif current == goal:
+ return f
+ open_set.remove(current)
+ for neighbour in neighbours(current):
+ tentative_g = g_score[current] + d(current, neighbour)
+ if tentative_g < g_score[neighbour]:
+ came_from[neighbour] = current
+ g_score[neighbour] = tentative_g
+ f_score[neighbour] = tentative_g + h(neighbour)
+ open_set.add(neighbour)
+ heappush(open_heapq, (f_score[neighbour], intern(neighbour)))
+ return None