"""Count start-to-end paths through the day-12 cave graph.

Each input line is an undirected edge ``a-b`` between two caves. A path may
enter a cave whose name is lowercase at most once (``start`` only as its
origin), while caves with non-lowercase names may be entered any number of
times. Prints the path count without and then with the "one lowercase cave
may be visited twice" relaxation.
"""

from collections import defaultdict
from functools import cache

from utils import open_day

adj = defaultdict(set)
with open_day(12) as fh:
    for line in fh:
        edge = line.strip()
        if not edge:
            # Tolerate blank lines (e.g. an extra empty line at the end of
            # the input) instead of failing on the two-way unpack below.
            continue
        a, b = edge.split('-')
        adj[a].add(b)
        adj[b].add(a)


def count_paths(adj, look_twice=False):
    """Return the number of distinct paths from ``'start'`` to ``'end'``.

    Args:
        adj: Mapping from each cave name to the set of its neighbours.
        look_twice: If true, at most one lowercase cave (other than
            ``start`` and ``end``) may be visited twice on a path; every
            other lowercase cave is still limited to a single visit.

    Returns:
        The number of paths, as an ``int``.

    Note:
        Assumes no two non-lowercase caves are adjacent (presumably
        guaranteed by the puzzle input); otherwise the path count is
        unbounded and the recursion ends in ``RecursionError``.
    """

    def neighbours_total(node, seen, saw_once):
        # Sum the path counts over every neighbour that may still be entered.
        return sum(paths_from(n, seen, saw_once) for n in adj[node] - seen)

    @cache
    def paths_from(node, seen, saw_once=None):
        # seen: frozenset of lowercase caves that may not be entered again.
        # saw_once: the lowercase cave chosen as the one allowed a second
        # visit, or None while no cave has been chosen.
        if node == 'end':
            # A chosen cave is only added to ``seen`` on its second visit.
            # If it was never revisited, this path duplicates one already
            # counted on the branch that did not choose it, so reject it.
            return 1 if saw_once is None or saw_once in seen else 0
        total = 0
        if node.islower():
            if look_twice and saw_once is None:
                # Branch 1: choose this cave as the one visited twice and
                # leave it out of ``seen`` so it can be entered once more.
                total += neighbours_total(node, seen, node)
            # Branch 2 (the only option otherwise): close this cave.
            # frozenset | set yields a frozenset, keeping the key hashable.
            seen = seen | {node}
        return total + neighbours_total(node, seen, saw_once)

    return sum(paths_from(n, frozenset({'start'})) for n in adj['start'])


print(count_paths(adj))
print(count_paths(adj, True))