From 4c867d0024d19ffcb57bb2dbb7d6ab10603a579f Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Fri, 16 Dec 2022 16:58:00 +0000 Subject: 16 faster --- 16.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/16.py b/16.py index a8819f3..c004091 100644 --- a/16.py +++ b/16.py @@ -1,7 +1,7 @@ from utils import open_day from functools import cache from dataclasses import dataclass -from itertools import count +from itertools import count, chain import re @dataclass @@ -9,6 +9,15 @@ class Node: flow: int neighbours: list[tuple[str, int]] +ids = {'AA': 0} +nextid = 1 +def intern(n): + global nextid + if n in ids: return ids[n] + ids[n] = nextid + nextid += 1 + return nextid - 1 + regex = re.compile(r'^Valve (..) has flow rate=([0-9]+); tunnels? leads? to valves? (.*)$') inp = {} with open_day(16) as f: @@ -16,16 +25,16 @@ with open_day(16) as f: m = regex.match(line) assert(m) valve, flow, neighbours = m.group(1, 2, 3) - inp[valve] = (int(flow), neighbours.split(', ')) + inp[intern(valve)] = (int(flow), [intern(n) for n in neighbours.split(', ')]) nodes = {} for valve, (flow, neighbours) in inp.items(): - if valve != 'AA' and flow == 0: continue + if valve != 0 and flow == 0: continue actual_neighbours = [] for n in neighbours: prev = valve for cost in count(1): - if n == 'AA' or inp[n][0] != 0: break + if n == 0 or inp[n][0] != 0: break l, r = inp[n][1] nnext = r if l == prev else l prev = n @@ -35,21 +44,27 @@ for valve, (flow, neighbours) in inp.items(): @cache def sum_flow(open_valves): - return sum(nodes[n].flow for n in open_valves) + total = 0 + for i in count(): + if not open_valves: break + if open_valves & 1: total += nodes[i].flow + open_valves >>= 1 + return total @cache -def recurse(open_valves=frozenset(), current='AA', flow=0, time_left=30): +def recurse(open_valves=0, current=0, flow=0, time_left=30): cflow = sum_flow(open_valves) cnode = nodes[current] best = flow + cflow * time_left + def check(open_valves, current, flow, time_left): + nonlocal best + new = recurse(open_valves, current, flow, time_left) + if new > best: best = new for neighbour, cost in cnode.neighbours: if cost >= time_left: continue - new = recurse(open_valves, neighbour, - flow + cflow * cost, time_left - cost) - best = max(best, new) - if current not in open_valves and time_left > 0 and cnode.flow > 0: - best = max(best, recurse(open_valves | {current}, current, - flow + cflow, time_left - 1)) + check(open_valves, neighbour, flow + cflow * cost, time_left - cost) + if not (open_valves & (1 << current)) and time_left > 0 and cnode.flow > 0: + check(open_valves | 1 << current, current, flow + cflow, time_left - 1) return best print(recurse()) -- cgit v1.2.3-54-g00ecf