summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tomasz@kramkow.ski>2022-12-18 13:14:00 +0000
committerTomasz Kramkowski <tomasz@kramkow.ski>2022-12-18 13:19:18 +0000
commit2b6b5944d325846f1490e7f521ead7148639a1a0 (patch)
tree07d618e554e8621342e56c2d3798fa0b535b2af2
parent4224d099d7ed66dcb8514c462e3e5db389fe1c4e (diff)
downloadaoc2022-2b6b5944d325846f1490e7f521ead7148639a1a0.tar.gz
aoc2022-2b6b5944d325846f1490e7f521ead7148639a1a0.tar.xz
aoc2022-2b6b5944d325846f1490e7f521ead7148639a1a0.zip
16 remove sum_flow
-rw-r--r--16.py51
1 files changed, 24 insertions, 27 deletions
diff --git a/16.py b/16.py
index 7a4bcb5..cd27325 100644
--- a/16.py
+++ b/16.py
@@ -47,32 +47,29 @@ for valve, (flow, neighbours) in inp.items():
actual_neighbours.append((intern(n), cost))
nodes[intern(valve)] = Node(flow, actual_neighbours)
-@cache
-def sum_flow(open_valves):
- total = 0
- for i in count():
- if not open_valves: break
- if open_valves & 1: total += nodes[i].flow
- open_valves >>= 1
- return total
+def solve(nodes, open_valves=1, time_left=30):
+ nodemask = 2 ** len(nodes) - 1
+ recurse_cache = dict()
+ def recurse(open_valves, current, flow, time_left, path):
+ key = (current, flow, time_left)
+ if key in recurse_cache: return recurse_cache[key]
+ best = flow, path
+ if open_valves == nodemask:
+ recurse_cache[key] = best
+ return best
+ cnode = nodes[current]
+ def check(open_valves, current, flow, time_left):
+ nonlocal best
+ new = recurse(open_valves, current, flow, time_left, path + [current])
+ if new[0] > best[0]: best = new
+ for neighbour, cost in cnode.neighbours:
+ if cost >= time_left: continue
+ check(open_valves, neighbour, flow, time_left - cost)
+ if not (open_valves & (1 << current)) and time_left > 0:
+ check(open_valves | 1 << current, current, flow + cnode.flow * (time_left - 1), time_left - 1)
+ recurse_cache[key] = best
+ return best
+ return recurse(open_valves, 0, 0, time_left, [])
-recurse_cache = dict()
-def recurse(open_valves=1, current=0, flow=0, time_left=30, path=[]):
- key = (current, flow, time_left)
- if key in recurse_cache: return recurse_cache[key]
- cnode = nodes[current]
- best = flow, path
- def check(open_valves, current, flow, time_left):
- nonlocal best
- new = recurse(open_valves, current, flow, time_left, path + [current])
- if new[0] > best[0]: best = new
- for neighbour, cost in cnode.neighbours:
- if cost >= time_left: continue
- check(open_valves, neighbour, flow, time_left - cost)
- if not (open_valves & (1 << current)) and time_left > 0:
- check(open_valves | 1 << current, current, flow + cnode.flow * (time_left - 1), time_left - 1)
- recurse_cache[key] = best
- return best
-
-v, p = recurse()
+v, p = solve(nodes)
print(v, [extern(s) for s in p])