summaryrefslogtreecommitdiffstats
path: root/19.py
blob: c33ee43ba8ff1a8f37ab16e49486771260b7f7ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from collections.abc import Iterator
from dataclasses import dataclass, fields
from sys import stdin
from typing import Self, Type, TypeAlias


@dataclass(frozen=True, slots=True)
class Constraint:
    split: int
    less: bool

    @classmethod
    def from_str(cls: Type[Self], s: str) -> Self:
        return cls(int(s[1:]), s[0] == "<")


@dataclass(frozen=True, slots=True)
class Range:
    """A run of ``length`` consecutive integers beginning at ``start``."""

    start: int
    length: int

    def split(self, c: Constraint) -> tuple[Self, Self]:
        """Divide this range into the part matching ``c`` and the remainder.

        Returns:
            ``(matching, remainder)``. Either piece may have length 0.
        """
        # First value that belongs to the upper piece: the threshold itself
        # for "<", one past the threshold for ">".
        boundary = c.split if c.less else c.split + 1
        lower_len = min(self.length, max(0, boundary - self.start))

        cls = self.__class__
        lower = cls(self.start, lower_len)
        upper = cls(self.start + lower_len, self.length - lower_len)

        # "<" matches the lower piece, ">" matches the upper piece.
        return (lower, upper) if c.less else (upper, lower)


@dataclass(frozen=True, slots=True)
class XMASRange:
    """One ``Range`` for each of the four categories ``x``, ``m``, ``a``, ``s``."""

    x: Range
    m: Range
    a: Range
    s: Range

    @property
    def product(self) -> int:
        """Product of the four range lengths."""
        total = 1
        for axis in (self.x, self.m, self.a, self.s):
            total *= axis.length
        return total

    def split(self, c: tuple[str, Constraint]) -> tuple[Self, Self]:
        """Split along one category, leaving the other three unchanged.

        Args:
            c: ``(category, constraint)``, where ``category`` names the field
                to split and ``constraint`` is passed to that field's
                ``split``.

        Returns:
            ``(matching, remainder)``, differing only in the named category.
        """
        category, constraint = c
        current = {f.name: getattr(self, f.name) for f in fields(self)}
        matched, unmatched = getattr(self, category).split(constraint)

        cls = self.__class__
        # Later keys win, so the named category is overridden in each copy.
        return (
            cls(**{**current, category: matched}),
            cls(**{**current, category: unmatched}),
        )


# A rule is either a bare destination string (unconditional), or a pair
# ((category, constraint), destination) for a conditional rule.
Rule: TypeAlias = tuple[tuple[str, Constraint], str] | str


def combinations(
    r: XMASRange, rules: Iterator[Rule], tables: dict[str, list[Rule]]
) -> int:
    rule = next(rules)
    if isinstance(rule, str):
        if rule == "A":
            return r.product
        if rule == "R":
            return 0
        return combinations(r, iter(tables[rule]), tables)
    constraint, target = rule
    accept, reject = r.split(constraint)
    if target == "A":
        left = accept.product
    elif target == "R":
        left = 0
    else:
        left = combinations(accept, iter(tables[target]), tables)
    return left + combinations(reject, rules, tables)


def parse_rule(r: str) -> Rule:
    """Parse one rule.

    Text without a colon is returned unchanged as a bare destination.
    Text of the form ``<category><constraint>:<destination>`` (for example
    ``a<2006:qkq``) becomes ``((category, Constraint), destination)``.
    """
    if ":" in r:
        condition, destination = r.split(":")
        # The first character names the category; the rest is the constraint.
        category = condition[0]
        return (category, Constraint.from_str(condition[1:])), destination
    return r


def parse_table(t: str) -> tuple[str, list[Rule]]:
    """Parse ``name{rule,rule,...}`` into ``(name, [parsed rules])``."""
    name, body = t.rstrip("}").split("{")
    parsed = list(map(parse_rule, body.split(",")))
    return name, parsed


def parse_part(p: str) -> XMASRange:
    """Parse a part such as ``{x=787,m=2655,a=1222,s=2876}``.

    Each ``key=value`` entry becomes a length-1 ``Range`` starting at the
    value, passed to ``XMASRange`` under the entry's key.
    """
    entries = (entry.split("=") for entry in p.strip("{}").split(","))
    return XMASRange(**{key: Range(int(value), 1) for key, value in entries})


# The input has two sections separated by a blank line: tables, then parts.
table_text, part_text = stdin.read().rstrip("\n").split("\n\n")
tables = dict(map(parse_table, table_text.split("\n")))
parts = list(map(parse_part, part_text.split("\n")))

# Part one: every part is a single point, so a count of 1 means it was
# accepted; add up the four values of each accepted part.
p1 = 0
for part in parts:
    accepted = combinations(part, iter(tables["in"]), tables)
    if accepted == 1:
        p1 += sum(axis.start for axis in (part.x, part.m, part.a, part.s))
print(p1)

# Part two: count accepted combinations with every category spanning 1..4000.
full = Range(1, 4000)
print(combinations(XMASRange(full, full, full, full), iter(tables["in"]), tables))