From 2b6b5944d325846f1490e7f521ead7148639a1a0 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Sun, 18 Dec 2022 13:14:00 +0000 Subject: 16 remove sum_flow --- 16.py | 51 ++++++++++++++++++++++++--------------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/16.py b/16.py index 7a4bcb5..cd27325 100644 --- a/16.py +++ b/16.py @@ -47,32 +47,29 @@ for valve, (flow, neighbours) in inp.items(): actual_neighbours.append((intern(n), cost)) nodes[intern(valve)] = Node(flow, actual_neighbours) -@cache -def sum_flow(open_valves): - total = 0 - for i in count(): - if not open_valves: break - if open_valves & 1: total += nodes[i].flow - open_valves >>= 1 - return total +def solve(nodes, open_valves=1, time_left=30): + nodemask = 2 ** len(nodes) - 1 + recurse_cache = dict() + def recurse(open_valves, current, flow, time_left, path): + key = (current, flow, time_left) + if key in recurse_cache: return recurse_cache[key] + best = flow, path + if open_valves == nodemask: + recurse_cache[key] = best + return best + cnode = nodes[current] + def check(open_valves, current, flow, time_left): + nonlocal best + new = recurse(open_valves, current, flow, time_left, path + [current]) + if new[0] > best[0]: best = new + for neighbour, cost in cnode.neighbours: + if cost >= time_left: continue + check(open_valves, neighbour, flow, time_left - cost) + if not (open_valves & (1 << current)) and time_left > 0: + check(open_valves | 1 << current, current, flow + cnode.flow * (time_left - 1), time_left - 1) + recurse_cache[key] = best + return best + return recurse(open_valves, 0, 0, time_left, []) -recurse_cache = dict() -def recurse(open_valves=1, current=0, flow=0, time_left=30, path=[]): - key = (current, flow, time_left) - if key in recurse_cache: return recurse_cache[key] - cnode = nodes[current] - best = flow, path - def check(open_valves, current, flow, time_left): - nonlocal best - new = recurse(open_valves, current, flow, time_left, path + [current]) - if new[0] > best[0]: best = new - for neighbour, cost in cnode.neighbours: - if cost >= time_left: continue - check(open_valves, neighbour, flow, time_left - cost) - if not (open_valves & (1 << current)) and time_left > 0: - check(open_valves | 1 << current, current, flow + cnode.flow * (time_left - 1), time_left - 1) - recurse_cache[key] = best - return best - -v, p = recurse() +v, p = solve(nodes) print(v, [extern(s) for s in p]) -- cgit v1.2.3-54-g00ecf