diff options
author | Tomasz Kramkowski <tomasz@kramkow.ski> | 2022-12-16 16:58:00 +0000 |
---|---|---|
committer | Tomasz Kramkowski <tomasz@kramkow.ski> | 2022-12-16 16:58:00 +0000 |
commit | 4c867d0024d19ffcb57bb2dbb7d6ab10603a579f (patch) | |
tree | fff86b4e926052b6a95514a8539f850d169ca300 | |
parent | 67116c06713074e3782ddea2f3658490bf22bf67 (diff) | |
download | aoc2022-4c867d0024d19ffcb57bb2dbb7d6ab10603a579f.tar.gz aoc2022-4c867d0024d19ffcb57bb2dbb7d6ab10603a579f.tar.xz aoc2022-4c867d0024d19ffcb57bb2dbb7d6ab10603a579f.zip |
16 faster
-rw-r--r-- | 16.py | 39 |
1 files changed, 27 insertions, 12 deletions
@@ -1,7 +1,7 @@ from utils import open_day from functools import cache from dataclasses import dataclass -from itertools import count +from itertools import count, chain import re @dataclass @@ -9,6 +9,15 @@ class Node: flow: int neighbours: list[tuple[str, int]] +ids = {'AA': 0} +nextid = 1 +def intern(n): + global nextid + if n in ids: return ids[n] + ids[n] = nextid + nextid += 1 + return nextid - 1 + regex = re.compile(r'^Valve (..) has flow rate=([0-9]+); tunnels? leads? to valves? (.*)$') inp = {} with open_day(16) as f: @@ -16,16 +25,16 @@ with open_day(16) as f: m = regex.match(line) assert(m) valve, flow, neighbours = m.group(1, 2, 3) - inp[valve] = (int(flow), neighbours.split(', ')) + inp[intern(valve)] = (int(flow), [intern(n) for n in neighbours.split(', ')]) nodes = {} for valve, (flow, neighbours) in inp.items(): - if valve != 'AA' and flow == 0: continue + if valve != 0 and flow == 0: continue actual_neighbours = [] for n in neighbours: prev = valve for cost in count(1): - if n == 'AA' or inp[n][0] != 0: break + if n == 0 or inp[n][0] != 0: break l, r = inp[n][1] nnext = r if l == prev else l prev = n @@ -35,21 +44,27 @@ for valve, (flow, neighbours) in inp.items(): @cache def sum_flow(open_valves): - return sum(nodes[n].flow for n in open_valves) + total = 0 + for i in count(): + if not open_valves: break + if open_valves & 1: total += nodes[i].flow + open_valves >>= 1 + return total @cache -def recurse(open_valves=frozenset(), current='AA', flow=0, time_left=30): +def recurse(open_valves=0, current=0, flow=0, time_left=30): cflow = sum_flow(open_valves) cnode = nodes[current] best = flow + cflow * time_left + def check(open_valves, current, flow, time_left): + nonlocal best + new = recurse(open_valves, current, flow, time_left) + if new > best: best = new for neighbour, cost in cnode.neighbours: if cost >= time_left: continue - new = recurse(open_valves, neighbour, - flow + cflow * cost, time_left - cost) - best = max(best, new) - if current not in open_valves and time_left > 0 and cnode.flow > 0: - best = max(best, recurse(open_valves | {current}, current, - flow + cflow, time_left - 1)) + check(open_valves, neighbour, flow + cflow * cost, time_left - cost) + if not (open_valves & (1 << current)) and time_left > 0 and cnode.flow > 0: + check(open_valves | 1 << current, current, flow + cflow, time_left - 1) return best print(recurse()) |