summaryrefslogtreecommitdiffstats
path: root/3.py
blob: eed75a3bbc3dabb28c7a9801ffc4f731f17f6eda (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# pyright: strict
import re
from collections import defaultdict
from collections.abc import Iterator
from functools import reduce
from operator import mul
from sys import stdin

# One match per maximal run of ASCII digits, i.e. one number in the grid.
num_re = re.compile(r"[0-9]+")

# The whole grid: every stdin line with trailing whitespace (incl. "\n") removed.
inp = list(map(str.rstrip, stdin))


def find_adjacent_symbols(
    inp: list[str], ly: int, sx: int, ex: int
) -> Iterator[tuple[int, int]]:
    """Yield the ``(x, y)`` positions of symbols touching a horizontal span.

    A symbol is any character that is neither ``"."`` nor a digit. The span
    lies on row ``ly`` from column ``sx`` (inclusive) to ``ex`` (exclusive),
    matching ``re.Match.span()``. Every cell in the one-cell border around
    the span (including diagonals) is examined; the span's own cells are
    skipped.

    Each row is bounded by its own length, so ragged input (e.g. a trailing
    blank line read from stdin, or rows of differing width) neither raises
    ``IndexError`` nor hides symbols beyond the first row's width.

    Positions are yielded column by column (x outer, y inner).
    """
    first_row = max(0, ly - 1)
    end_row = min(ly + 2, len(inp))
    for x in range(max(0, sx - 1), ex + 1):
        for y in range(first_row, end_row):
            row = inp[y]
            if x >= len(row):
                # This row is too short to reach column x.
                continue
            if y == ly and sx <= x < ex:
                # Skip the characters of the span itself.
                continue
            ch = row[x]
            if ch != "." and not ch.isdigit():
                yield x, y


# Scan every number in the grid. A number counts toward part 1 when at least
# one symbol touches it; each "*" it touches also records the number, so part 2
# can multiply the pair of numbers around every "*" touched by exactly two.
p1 = 0
p2_gears: defaultdict[tuple[int, int], list[int]] = defaultdict(list)
for row_idx, row in enumerate(inp):
    for number in num_re.finditer(row):
        start, end = number.span()
        part_value = int(number.group())
        touched = list(find_adjacent_symbols(inp, row_idx, start, end))
        for sym_x, sym_y in touched:
            if inp[sym_y][sym_x] == "*":
                p2_gears[sym_x, sym_y].append(part_value)
        if touched:
            p1 += part_value

print(p1)
gear_ratios = (reduce(mul, nums) for nums in p2_gears.values() if len(nums) == 2)
print(sum(gear_ratios))