summary | refs | log | tree | commit | diff | stats
path: root/8.py
blob: 99b5502e276a6c8840678cee1a0c9817511ed9db (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from collections import Counter
from itertools import chain
from utils import open_day
from enum import IntEnum

# Short alias: almost every collection in this script is a frozenset.
fset = frozenset

# The seven display segments, valued A=0 .. G=6 so that a member can be used
# directly as a list index.
Seg = IntEnum('Seg', {letter: value for value, letter in enumerate('ABCDEFG')})

# Segments lit for each digit 0-9, written as segment letters; the position
# in the tuple is the digit being displayed.
_LIT_SEGMENTS = (
    'ABCEFG',   # 0
    'CF',       # 1
    'ACDEG',    # 2
    'ACDFG',    # 3
    'BCDF',     # 4
    'ABDFG',    # 5
    'ABDEFG',   # 6
    'ACF',      # 7
    'ABCDEFG',  # 8
    'ABCDFG',   # 9
)

# dig[d] is the frozenset of Seg members that are lit when digit d is shown.
dig = [fset(Seg[letter] for letter in letters) for letters in _LIT_SEGMENTS]

# seg[s] is the frozenset of digits (as ints) whose display lights segment s.
seg = [
    fset(digit for digit, lit in enumerate(dig) if segment in lit)
    for segment in Seg
]

def _decode_entry(entry):
    """Decode one input entry into the digits shown on its display.

    ``entry`` is a non-empty line of the form
    ``"<signal patterns> | <output patterns>"``, where every pattern is a
    word made of scrambled wire letters.

    Returns the output digits as a list of ints, most significant first.

    Raises ``ValueError`` if the entry does not contain exactly one
    ``' | '`` separator, and ``KeyError`` if the patterns do not describe a
    consistent wiring (a needed length/count is missing or an output word
    matches no digit).
    """
    signals, outputs = entry.split(' | ')
    signals = fset(fset(word) for word in signals.split())

    # Map "number of patterns a wire appears in" -> that wire.  Several wires
    # share a count and overwrite each other here; only the counts belonging
    # to segments B, E and F are looked up below.
    wire_by_uses = {
        uses: fset((wire,))
        for wire, uses in Counter(chain.from_iterable(signals)).items()
    }
    # Map pattern length -> pattern.  Lengths shared by several digits
    # collide; only the lengths of digits 1, 4, 7 and 8 are looked up below.
    signal_by_len = {len(signal): signal for signal in signals}

    CF      = signal_by_len[len(dig[1])]
    BCDF    = signal_by_len[len(dig[4])]
    ACF     = signal_by_len[len(dig[7])]
    ABCDEFG = signal_by_len[len(dig[8])]
    B = wire_by_uses[len(seg[Seg.B])]
    E = wire_by_uses[len(seg[Seg.E])]
    F = wire_by_uses[len(seg[Seg.F])]
    # The remaining segments follow by set difference; each result is a
    # one-element frozenset holding the wire that drives that segment.
    A = ACF - CF
    C = CF - F
    D = BCDF - CF - B
    G = ABCDEFG - A - BCDF - E

    # Indexed by Seg value (A=0 .. G=6).
    wiring = [A, B, C, D, E, F, G]
    # Scrambled pattern of each digit under this entry's wiring -> digit.
    digit_of = {
        fset(next(iter(wiring[s])) for s in lit): value
        for value, lit in enumerate(dig)
    }
    return [digit_of[fset(word)] for word in outputs.split()]


# counts: how often each digit appears across all outputs.
# count: running sum of every output read as a decimal number.
counts = Counter()
count = 0
with open_day(8) as f:
    for line in f:
        entry = line.rstrip()
        if not entry:
            # Blank line (e.g. stray trailing newline): nothing to decode,
            # and unpacking the split below would raise ValueError.
            continue
        digits = _decode_entry(entry)
        counts.update(digits)
        # The last digit is the least significant, hence the reversal.
        count += sum(d * 10 ** i for i, d in enumerate(reversed(digits)))
print(counts[1] + counts[4] + counts[7] + counts[8])
print(count)