from collections.abc import Iterator from dataclasses import dataclass, fields from sys import stdin from typing import Self, Type, TypeAlias @dataclass(frozen=True, slots=True) class Constraint: split: int less: bool @classmethod def from_str(cls: Type[Self], s: str) -> Self: return cls(int(s[1:]), s[0] == "<") @dataclass(frozen=True, slots=True) class Range: start: int length: int def split(self, c: Constraint) -> tuple[Self, Self]: if c.less: length = min(self.length, max(0, c.split - self.start)) return self.__class__(self.start, length), self.__class__( self.start + length, self.length - length ) else: length = min(self.length, max(0, c.split - self.start + 1)) return self.__class__( self.start + length, self.length - length ), self.__class__(self.start, length) @dataclass(frozen=True, slots=True) class XMASRange: x: Range m: Range a: Range s: Range @property def product(self) -> int: return self.x.length * self.m.length * self.a.length * self.s.length def split(self, c: tuple[str, Constraint]) -> tuple[Self, Self]: base = dict((field.name, getattr(self, field.name)) for field in fields(self)) target, constraint = c leftattr, rightattr = getattr(self, target).split(constraint) base[target] = leftattr left = self.__class__(**base) base[target] = rightattr right = self.__class__(**base) return left, right Rule: TypeAlias = tuple[tuple[str, Constraint], str] | str def combinations( r: XMASRange, rules: Iterator[Rule], tables: dict[str, list[Rule]] ) -> int: rule = next(rules) if isinstance(rule, str): if rule == "A": return r.product if rule == "R": return 0 return combinations(r, iter(tables[rule]), tables) constraint, target = rule accept, reject = r.split(constraint) if target == "A": left = accept.product elif target == "R": left = 0 else: left = combinations(accept, iter(tables[target]), tables) return left + combinations(reject, rules, tables) def parse_rule(r: str) -> Rule: if ":" not in r: return r cond, target = r.split(":") attr, constraint = cond[0], cond[1:] return (attr, Constraint.from_str(constraint)), target def parse_table(t: str) -> tuple[str, list[Rule]]: key, rules = t.rstrip("}").split("{") return key, [parse_rule(r) for r in rules.split(",")] def parse_part(p: str) -> XMASRange: d = dict() for kv in p.strip("{}").split(","): k, v = kv.split("=") d[k] = Range(int(v), 1) return XMASRange(**d) tables, parts = stdin.read().rstrip("\n").split("\n\n") tables = dict(parse_table(t) for t in tables.split("\n")) parts = [parse_part(p) for p in parts.split("\n")] p1 = 0 for part in parts: if combinations(part, iter(tables["in"]), tables) == 1: p1 += part.x.start + part.m.start + part.a.start + part.s.start print(p1) print( combinations( XMASRange(Range(1, 4000), Range(1, 4000), Range(1, 4000), Range(1, 4000)), iter(tables["in"]), tables, ) )