from __future__ import annotations

from dataclasses import dataclass
from functools import reduce
from operator import eq, gt, lt, mul

from bitstring import BitStream

from utils import open_day

# Tag value that marks a literal-value packet; every other tag is an operator
# whose contents are sub-packets.
LITERAL_TAG = 4


@dataclass
class Packet:
    """One decoded packet.

    ``contents`` is the integer value for a literal packet (tag 4) and the
    list of decoded sub-packets for every other tag.
    """

    version: int
    tag: int
    contents: list[Packet] | int


def get_packet(bits):
    """Read one packet, including all nested sub-packets, from ``bits``.

    Reading starts at the stream's current position. On return the stream is
    positioned immediately after the packet that was read.

    Args:
        bits: A ``BitStream`` positioned at the start of a packet header.

    Returns:
        The decoded ``Packet``.
    """
    version, tag = bits.readlist('2*uint:3')
    if tag == LITERAL_TAG:
        # Literal value: a sequence of 5-bit groups. Each group is a
        # continuation flag followed by 4 value bits, most significant first.
        value = 0
        while True:
            keep_going, nybble = bits.readlist('bool,uint:4')
            value = (value << 4) | nybble
            if not keep_going:
                break
        return Packet(version, tag, value)

    # Operator packet: a 1-bit length-type ID selects how sub-packets are
    # delimited.
    ltid = bits.read('uint:1')
    contents = []
    if ltid == 0:
        # Next 15 bits give the total bit length of all sub-packets.
        length = bits.read('uint:15')
        start = bits.pos
        while bits.pos - start < length:
            contents.append(get_packet(bits))
    else:
        # Next 11 bits give the number of directly contained sub-packets.
        count = bits.read('uint:11')
        contents = [get_packet(bits) for _ in range(count)]
    return Packet(version, tag, contents)


def sum_versions(packet):
    """Return the sum of the version numbers of ``packet`` and all descendants."""
    if packet.tag == LITERAL_TAG:
        return packet.version
    return packet.version + sum(sum_versions(p) for p in packet.contents)


def _sub_values(packet):
    """Lazily yield the evaluated value of each sub-packet, in order."""
    return (evaluate(q) for q in packet.contents)


def _binary_op(packet, op):
    """Compare the values of exactly two sub-packets with ``op``; return 1 or 0.

    Raises:
        ValueError: if the packet does not have exactly two sub-packets.
            This is an explicit check rather than ``assert``, so it still
            runs under ``python -O``.
    """
    if len(packet.contents) != 2:
        raise ValueError(
            f"tag {packet.tag} packet needs exactly 2 sub-packets, "
            f"got {len(packet.contents)}"
        )
    left, right = _sub_values(packet)
    return int(op(left, right))


# Dispatch table indexed by packet tag. Tags are read as 3-bit unsigned ints,
# so every value 0-7 has an entry. The table is built once here rather than
# on every recursive evaluate() call.
_ACTIONS = (
    lambda p: sum(_sub_values(p)),             # 0: sum
    lambda p: reduce(mul, _sub_values(p), 1),  # 1: product
    lambda p: min(_sub_values(p)),             # 2: minimum
    lambda p: max(_sub_values(p)),             # 3: maximum
    lambda p: p.contents,                      # 4: literal value
    lambda p: _binary_op(p, gt),               # 5: greater than
    lambda p: _binary_op(p, lt),               # 6: less than
    lambda p: _binary_op(p, eq),               # 7: equal to
)


def evaluate(packet):
    """Recursively compute the integer value of ``packet``.

    Literal packets evaluate to their stored value. Operator packets apply
    sum, product, min, max or a two-operand comparison (returning 1/0) to the
    values of their sub-packets, according to the tag.
    """
    return _ACTIONS[packet.tag](packet)


with open_day(16) as f:
    bits = BitStream(bytes.fromhex(f.read().rstrip()))

packet = get_packet(bits)
print(sum_versions(packet))
print(evaluate(packet))