1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
from functools import cache
from collections import defaultdict
from utils import open_day
# Undirected cave graph: each cave name maps to the set of caves it shares a
# tunnel with. Each non-blank input line is one tunnel written as "a-b".
# NOTE(review): open_day(12) presumably opens the day-12 puzzle input — confirm
# against utils.
adj = defaultdict(set)
with open_day(12) as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. an extra trailing newline) rather than
            # failing to unpack the result of splitting an empty string.
            continue
        a, b = line.split('-')
        # Tunnels are bidirectional, so record the edge in both directions.
        adj[a].add(b)
        adj[b].add(a)
def count_paths(adj, look_twice=False):
    """Count the distinct paths from 'start' to 'end' through a cave graph.

    A cave whose name is all lowercase (``str.islower()``) is "small" and may
    be entered at most once per path. Any other cave is "big" and may be
    revisited freely. 'start' is never re-entered, and a path ends as soon as
    it reaches 'end'.

    Args:
        adj: Mapping from each cave name to the set of caves adjacent to it.
            Every cave reached during the walk must be a key; a
            ``defaultdict(set)`` satisfies this.
        look_twice: If true, additionally allow one small cave per path to
            be entered a second time.

    Returns:
        The number of distinct paths.

    Raises:
        ValueError: If the walk reaches a big cave adjacent to another big
            cave (or to itself). Such a pair can be traversed back and forth
            forever, so the number of paths is unbounded.
    """
    def descend(node, seen, saw_once):
        # Total path count over every neighbour of `node` not yet closed off.
        return sum(walk(n, seen, saw_once) for n in adj[node] - seen)

    @cache
    def walk(node, seen, saw_once=None):
        # seen: frozenset of caves that may not be entered again.
        # saw_once: the small cave reserved for a second visit, or None.
        if node == 'end':
            # A path that reserved a cave for a second visit but never
            # re-entered it is also produced by the branch that reserved
            # nothing; count it only once (here, via saw_once is None).
            return 1 if saw_once is None or saw_once in seen else 0
        total = 0
        if node.islower():
            if look_twice and saw_once is None:
                # Branch: keep `node` open so it can be entered once more.
                total += descend(node, seen, node)
            seen = seen | frozenset((node,))
        else:
            # Big caves never enter `seen`, so two adjacent big caves would
            # recurse without end; fail clearly instead of RecursionError.
            for n in adj[node]:
                if not n.islower():
                    raise ValueError(
                        f"caves {node!r} and {n!r} are both big and adjacent;"
                        " the number of paths is unbounded"
                    )
        return total + descend(node, seen, saw_once)

    return sum(walk(n, frozenset(('start',))) for n in adj['start'])
# Report both counts: first with every small cave limited to one visit,
# then allowing a single small cave to be entered twice per path.
print(count_paths(adj, look_twice=False))
print(count_paths(adj, look_twice=True))
|