summaryrefslogtreecommitdiffstats
path: root/5.py
blob: bcaf85c629ea1d6715a4cea18c40898ed43a3c59 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# pyright: strict
from collections.abc import Iterator
from dataclasses import dataclass
from sys import stdin


@dataclass(frozen=True)
class Range:
    """A half-open interval of integers ``[start, start + length)``."""

    start: int
    length: int


@dataclass
class MappedRange:
    """One entry of a map.

    The ``length`` values beginning at ``source`` translate to the same number
    of values beginning at ``destination``.
    """

    destination: int
    source: int
    length: int


@dataclass
class Mapping:
    """Translation from the ``source`` category to the ``destination`` category.

    Values covered by an entry of ``ranges`` are shifted by that entry; values
    not covered by any entry map to themselves.
    """

    source: str
    destination: str
    ranges: list[MappedRange]

    def map(self, n: int) -> int:
        """Translate ``n`` through the first entry (in list order) containing it."""
        for r in self.ranges:
            if r.source <= n < r.source + r.length:
                return n - r.source + r.destination
        return n

    def map_range(self, t: Range) -> Iterator[Range]:
        """Translate the interval ``t``, yielding the translated pieces.

        ``t`` is split at entry boundaries: pieces inside an entry are shifted
        into destination space, pieces outside every entry are yielded
        unchanged. The lengths of the yielded pieces sum to ``t.length``.
        An empty ``t`` (``length <= 0``) yields nothing.

        NOTE: entries are processed sorted by ``source``, whereas ``map`` uses
        the first matching entry in list order; the two agree when entries do
        not overlap.
        """
        start = t.start
        length = t.length
        if length <= 0:
            return
        for r in sorted(self.ranges, key=lambda mr: mr.source):
            if start + length <= r.source:
                # Everything left lies before this (and every later) entry.
                # Break rather than return so the final yield passes it through.
                break
            if r.length <= 0 or start >= r.source + r.length:
                # Empty entry, or one lying entirely before what is left.
                continue
            if start < r.source:
                # Unmapped gap before this entry: pass it through unchanged.
                gap = r.source - start
                yield Range(start, gap)
                start += gap
                length -= gap
            # Overlap with this entry: shift it into destination space.
            overlap = min(length, r.length - (start - r.source))
            yield Range(start - r.source + r.destination, overlap)
            start += overlap
            length -= overlap
            if length <= 0:
                return
        # Whatever remains is not covered by any entry.
        yield Range(start, length)


# The input is blank-line separated: a "seeds: ..." line first, then one
# section per map ("<src>-to-<dst> map:" followed by one entry per line).
seed_section, *map_sections = stdin.read().rstrip("\n").split("\n\n")
seeds = [int(token) for token in seed_section.removeprefix("seeds: ").split()]

# Index each map by its source category so the chain can be walked by name.
mappings: dict[str, Mapping] = {}
for section in map_sections:
    title, body = section.split("\n", maxsplit=1)
    src_name, dst_name = title.removesuffix(" map:").split("-to-")
    entries = [
        MappedRange(*(int(field) for field in line.split()))
        for line in body.split("\n")
    ]
    mappings[src_name] = Mapping(src_name, dst_name, entries)


def recmap(mappings: dict[str, Mapping], current: str, value: int) -> int:
    """Translate *value* along the chain of mappings starting at *current*.

    Each step looks up *current* in *mappings*, translates the value, and moves
    on to that mapping's destination category. The chain ends at the first
    category with no entry in *mappings*, and the value reached there is
    returned.
    """
    if current in mappings:
        step = mappings[current]
        return recmap(mappings, step.destination, step.map(value))
    return value


def recmap_range(
    mappings: dict[str, Mapping], current: str, range_: Range
) -> Iterator[Range]:
    """Translate the interval *range_* along the mapping chain starting at *current*.

    Each mapping may split the interval into several pieces, and every piece is
    sent through the rest of the chain independently. The pieces are yielded
    once a category with no entry in *mappings* is reached.
    """
    if current not in mappings:
        yield range_
        return
    mapping = mappings[current]
    for piece in mapping.map_range(range_):
        # Delegate directly instead of re-yielding each item by hand.
        yield from recmap_range(mappings, mapping.destination, piece)


# Part 1: every seed number is translated individually.
print(min(recmap(mappings, "seed", seed) for seed in seeds))

# Part 2: the seed list holds (start, length) pairs, each describing a range of
# seeds; translate whole ranges and report the smallest resulting start.
seed_ranges = [Range(seeds[k], seeds[k + 1]) for k in range(0, len(seeds), 2)]
print(
    min(
        piece.start
        for seed_range in seed_ranges
        for piece in recmap_range(mappings, "seed", seed_range)
    )
)