summaryrefslogtreecommitdiffstats
path: root/16.py
blob: b5ef399ee8a7d736693ed3cda7e262e50c630478 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations
from utils import open_day
from bitstring import BitStream
from dataclasses import dataclass
from functools import reduce
from operator import mul, gt, lt, eq

@dataclass
class Packet:
    """One decoded packet produced by ``get_packet``.

    ``contents`` holds the literal value (an int) when ``tag == 4``;
    for every other type ID it holds the list of decoded sub-packets.
    """
    version: int  # 3-bit version field read from the packet header
    tag: int  # 3-bit type ID; 4 marks a literal value, others are operators
    contents: list[Packet] | int  # literal value, or the sub-packets of an operator

# Load the puzzle input (presumably the day-16 input file opened by the
# project's open_day helper): hexadecimal text with trailing whitespace
# stripped, decoded to raw bytes and wrapped in a BitStream so the parser
# below can consume it bit by bit.
with open_day(16) as puzzle_input:
    hex_text = puzzle_input.read().rstrip()
    bits = BitStream(bytes.fromhex(hex_text))

def get_packet(bits):
    """Read one packet (and, recursively, all of its sub-packets) from *bits*.

    Reading starts at the stream's current position and leaves the position
    just past the packet. A literal packet (type ID 4) gets an int as its
    ``contents``; any other type gets the list of its sub-packets.
    """
    # Header: 3-bit version followed by 3-bit type ID.
    version, tag = bits.readlist('2*uint:3')
    if tag == 4:
        return Packet(version, tag, _read_literal(bits))
    return Packet(version, tag, _read_operands(bits))

def _read_literal(bits):
    """Decode a literal from 5-bit groups: a continue flag then 4 data bits."""
    total = 0
    more = True
    while more:
        more, chunk = bits.readlist('bool,uint:4')
        total = (total << 4) | chunk
    return total

def _read_operands(bits):
    """Read an operator packet's sub-packets according to its length type ID."""
    if bits.read('uint:1') == 0:
        # Length type 0: the next 15 bits give the total bit length of the
        # sub-packets, so keep reading until that many bits are consumed.
        bit_length = bits.read('uint:15')
        stop = bits.pos + bit_length
        children = []
        while bits.pos < stop:
            children.append(get_packet(bits))
        return children
    # Length type 1: the next 11 bits give the number of sub-packets.
    count = bits.read('uint:11')
    return [get_packet(bits) for _ in range(count)]

def sum_versions(packet):
    """Return the total of the version numbers of *packet* and every packet nested in it.

    Literal packets (type ID 4) have no children, so they contribute only
    their own version.
    """
    nested = 0 if packet.tag == 4 else sum(map(sum_versions, packet.contents))
    return packet.version + nested

# Aggregating operators (type IDs 0-3): each receives an iterable of the
# sub-packet values. Built once at import time instead of on every
# recursive evaluate() call.
_AGGREGATES = {
    0: sum,
    1: lambda values: reduce(mul, values, 1),
    2: min,
    3: max,
}

# Comparison operators (type IDs 5-7): exactly two operands; the boolean
# result is reported as 0 or 1.
_COMPARISONS = {
    5: gt,
    6: lt,
    7: eq,
}

def evaluate(packet):
    """Compute the value of the expression tree rooted at *packet*.

    Type IDs: 0 sum, 1 product, 2 minimum, 3 maximum, 4 literal value,
    5 greater-than, 6 less-than, 7 equal-to. Comparisons evaluate to 1 when
    the relation holds and 0 otherwise.

    Raises:
        ValueError: if the type ID is not one of 0-7, or a comparison packet
            does not have exactly two sub-packets.
    """
    tag = packet.tag
    if tag == 4:
        # Literal: the decoded value is stored directly in contents.
        return packet.contents
    if tag in _COMPARISONS:
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(packet.contents) != 2:
            raise ValueError(
                f"comparison packet (type {tag}) needs exactly 2 sub-packets, "
                f"got {len(packet.contents)}"
            )
        left = evaluate(packet.contents[0])
        right = evaluate(packet.contents[1])
        return int(_COMPARISONS[tag](left, right))
    if tag in _AGGREGATES:
        return _AGGREGATES[tag](evaluate(child) for child in packet.contents)
    # A dict lookup (unlike list indexing) cannot map a negative tag onto a
    # real operator, so anything unknown is rejected here.
    raise ValueError(f"unknown packet type ID: {tag!r}")

# Decode the whole transmission once, then print two answers in order:
# first the total of all version numbers, then the evaluated expression.
packet = get_packet(bits)
for solve in (sum_versions, evaluate):
    print(solve(packet))