summaryrefslogtreecommitdiffstats
path: root/12.py
blob: c46fbe6d3877dfb4a5f3c319bfe22463e7470a4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from functools import cache
from collections import defaultdict
from utils import open_day

# Undirected cave graph: cave name -> set of directly connected cave names.
adj = defaultdict(set)

# open_day comes from the project's utils module; here it is used as a context
# manager yielding text lines, presumably this day's puzzle input (confirm in
# utils). Each non-blank line is one edge written "a-b".
with open_day(12) as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. an extra trailing newline); splitting
            # them would fail to unpack into two names.
            continue
        a, b = line.split('-')
        # Record the edge in both directions because the graph is undirected.
        adj[a].add(b)
        adj[b].add(a)

def count_paths(adj, look_twice=False):
    """Count the distinct paths from 'start' to 'end' through a cave graph.

    A cave whose name satisfies ``str.islower()`` is "small" and may be
    entered at most once per path; any other cave may be re-entered.
    'start' is never re-entered, and a path stops as soon as it reaches
    'end'.

    Args:
        adj: Mapping from cave name to an iterable of neighbouring cave
            names (edges should be listed in both directions). Caves
            missing from the mapping are treated as having no neighbours;
            the mapping itself is never modified.
        look_twice: If true, each path may additionally enter one small
            cave (other than 'start' and 'end') a second time.

    Returns:
        The number of distinct paths.

    Raises:
        ValueError: If the search can loop back to the same state without
            entering a small cave (e.g. two non-small caves are adjacent),
            so the walk would never terminate.
    """
    # Search states currently on the recursion stack. Meeting one again
    # means the walk returned to the same cave without entering any small
    # cave, i.e. it could bounce between big caves forever.
    active = set()

    def descend(node, seen, saw_once):
        # Sum the completions over every neighbour not barred by *seen*.
        return sum(f(n, seen, saw_once)
                   for n in adj.get(node, ()) if n not in seen)

    @cache
    def f(node, seen, saw_once):
        # Number of ways to finish a path that has just entered *node*.
        #   seen:     frozenset of small caves that may not be entered again.
        #   saw_once: small cave nominated for a second visit, or None.
        if node == 'end':
            # A nominated cave is added to *seen* only on its second visit.
            # If that never happened, the same path is already counted by
            # the branch that nominated nothing, so it must not count here.
            return 1 if saw_once is None or saw_once in seen else 0
        state = (node, seen, saw_once)
        if state in active:
            raise ValueError(
                f'unbounded walk: cave {node!r} can be re-entered without '
                'visiting a small cave (are two big caves adjacent?)')
        active.add(state)
        try:
            total = 0
            if node.islower():
                if look_twice and saw_once is None:
                    # Nominate this cave for a second visit: keep it out of
                    # *seen* so a later step may enter it once more.
                    total += descend(node, seen, node)
                seen = seen | frozenset((node,))
            return total + descend(node, seen, saw_once)
        finally:
            active.discard(state)

    # 'start' is pre-marked as seen so no path can re-enter it; saw_once is
    # always passed explicitly so every call uses the same cache-key shape.
    return descend('start', frozenset(('start',)), None)

# Report the path count twice: first with look_twice off (every small cave
# at most once), then with it on (one small cave may be entered twice).
for allow_repeat in (False, True):
    print(count_paths(adj, allow_repeat))