summaryrefslogtreecommitdiffstats
path: root/8.py
blob: 0bd67eda6ea0e622d0ea515108c233bf190d23fa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from collections import Counter
from itertools import chain
from utils import open_day
from enum import IntEnum, auto

fset = frozenset  # shorthand: the lookup tables below build many frozensets

class Seg(IntEnum):
    """The seven display segments, valued 0 (A) through 6 (G).

    Being an ``IntEnum``, a member can be used directly as a list index,
    e.g. ``seg[Seg.B]``.
    """
    A = 0
    B = 1
    C = 2
    D = 3
    E = 4
    F = 5
    G = 6

# dig[d]: the segments lit when the display shows digit d, written as
# segment letters and converted to frozensets of Seg members.
dig = [
    fset(Seg[letter] for letter in lit)
    for lit in (
        'ABCEFG',   # 0
        'CF',       # 1
        'ACDEG',    # 2
        'ACDFG',    # 3
        'BCDF',     # 4
        'ABDFG',    # 5
        'ABDEFG',   # 6
        'ACF',      # 7
        'ABCDEFG',  # 8
        'ABCDFG',   # 9
    )
]

# seg[s]: the inverse table -- the set of digits (plain ints) whose display
# lights segment s.  Derived from dig so the two tables can never disagree.
seg = [
    fset(digit for digit, lit in enumerate(dig) if segment in lit)
    for segment in Seg
]

def _single(wires):
    """Return the sole element of the set *wires*.

    Raises:
        ValueError: if *wires* does not hold exactly one element.  This is an
            explicit check rather than an ``assert`` because ``python -O``
            strips asserts, which would let an arbitrary wire through and
            silently corrupt the answer.
    """
    if len(wires) != 1:
        raise ValueError(f'expected exactly one wire, got {sorted(wires)!r}')
    (only,) = wires
    return only


def _pattern_of_length(patterns, n):
    """Return a pattern from *patterns* that has exactly *n* wires.

    Only called with the lengths of digits 1, 4, 7 and 8, which no other
    digit in ``dig`` shares, so well-formed input has exactly one match.

    Raises:
        ValueError: if no pattern has length *n*.
    """
    for pattern in patterns:
        if len(pattern) == n:
            return pattern
    raise ValueError(f'no pattern with {n} wires')


def _decode_line(line):
    """Decode one input line and return its output digits, in order.

    *line* is ``'<patterns> | <outputs>'``, where each whitespace-separated
    token is a string of scrambled wire letters.

    The wiring is recovered as follows:
      * B, E and F are the only segments lit by a unique number of digits
        (``len(seg[...])``), so counting how many patterns each wire appears
        in identifies them directly;
      * A is the wire in 7 but not in 1;
      * C is the wire in 1 that is not F;
      * D is the wire in 4 that is in neither 1 nor B;
      * G is the one wire left over.

    Raises:
        ValueError: if the line does not split into patterns and outputs, a
            required pattern length is missing, or a segment cannot be pinned
            to exactly one wire.
        KeyError: if the wire frequencies or an output pattern do not match
            any known segment/digit (inconsistent input).
    """
    patterns_text, outputs_text = line.rstrip().split(' | ')
    patterns = fset(fset(p) for p in patterns_text.split())

    # frequency -> wire.  Frequencies shared by two segments (A/C, D/G)
    # collapse to one entry, but only the unique ones (B, E, F) are read.
    wire_by_freq = {
        freq: wire
        for wire, freq in Counter(chain.from_iterable(patterns)).items()
    }

    cf = _pattern_of_length(patterns, len(dig[1]))
    bcdf = _pattern_of_length(patterns, len(dig[4]))
    acf = _pattern_of_length(patterns, len(dig[7]))
    abcdefg = _pattern_of_length(patterns, len(dig[8]))

    segmap = {
        Seg.B: wire_by_freq[len(seg[Seg.B])],
        Seg.E: wire_by_freq[len(seg[Seg.E])],
        Seg.F: wire_by_freq[len(seg[Seg.F])],
    }
    segmap[Seg.A] = _single(acf - cf)
    segmap[Seg.C] = _single(cf - {segmap[Seg.F]})
    segmap[Seg.D] = _single(bcdf - cf - {segmap[Seg.B]})
    segmap[Seg.G] = _single(abcdefg - set(segmap.values()))

    # Scrambled wire set -> digit, using the now-known wiring.
    digit_by_wires = {
        fset(segmap[s] for s in lit): digit for digit, lit in enumerate(dig)
    }
    return [digit_by_wires[fset(o)] for o in outputs_text.split()]


counts = Counter()  # occurrences of each decoded output digit value
count = 0           # sum of every line's decoded output value
with open_day(8) as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines (e.g. a trailing empty line)
        digits = _decode_line(line)
        counts.update(digits)
        # Read the output digits as one base-10 number, most significant first.
        count += sum(d * 10 ** i for i, d in enumerate(reversed(digits)))
# Digits 1, 4, 7 and 8 are those with a segment count no other digit has.
print(counts[1] + counts[4] + counts[7] + counts[8])
print(count)