diff options
Diffstat (limited to '16.py')
-rw-r--r-- | 16.py | 75 |
1 files changed, 75 insertions, 0 deletions
@@ -0,0 +1,75 @@ +from utils import open_day +from functools import cache +from dataclasses import dataclass +from itertools import count +import re + +@dataclass +class Node: + flow: int + neighbours: list[tuple[str, int]] + +regex = re.compile(r'^Valve (..) has flow rate=([0-9]+); tunnels? leads? to valves? (.*)$') +inp = {} +with open_day(16) as f: + for line in f: + m = regex.match(line) + assert(m) + valve, flow, neighbours = m.group(1, 2, 3) + inp[valve] = (int(flow), neighbours.split(', ')) + +nodes = {} +for valve, (flow, neighbours) in inp.items(): + if valve != 'AA' and flow == 0: continue + actual_neighbours = [] + for n in neighbours: + prev = valve + for cost in count(1): + if n == 'AA' or inp[n][0] != 0: + break + l, r = inp[n][1] + if l == prev: + prev = n + n = r + else: + prev = n + n = l + actual_neighbours.append((n, cost)) + nodes[valve] = Node(flow, actual_neighbours) + +@cache +def sum_flow(open_valves): + return sum(nodes[n].flow for n in open_valves) + +@cache +def recurse(open_valves, current, flow, time_left): + cflow = sum_flow(open_valves) + cnode = nodes[current] + best = 0 + for neighbour, cost in cnode.neighbours: + if cost >= time_left: + best = max(best, flow + cflow * time_left) + else: + best = max( + best, + recurse(open_valves, neighbour, + flow + cflow * cost, time_left - cost) + ) + if current not in open_valves and time_left > 0 and cnode.flow > 0: + flow += cflow + open_valves |= {current} + cflow = sum_flow(open_valves) + best = max(best, flow) + time_left -= 1 + for neighbour, cost in cnode.neighbours: + if cost >= time_left: + best = max(best, flow + cflow * time_left) + else: + best = max( + best, + recurse(open_valves, neighbour, + flow + cflow * cost, time_left - cost) + ) + return best + +print(recurse(frozenset(), 'AA', 0, 30)) |