summaryrefslogtreecommitdiffstats
path: root/21.py
blob: 58d20019462dc7dad45fc8f738b172918df2f213 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from collections import deque
from itertools import groupby, islice
from sys import stdin

# Read the map one row per line, dropping trailing whitespace. Trailing blank
# lines (e.g. an extra newline at EOF) are removed too: otherwise they would
# count as extra, empty rows of the map.
inp = [line.rstrip() for line in stdin]
while inp and not inp[-1]:
    inp.pop()

# Locate the start tile; (sx, sy) is its (column, row).
for sy, row in enumerate(inp):
    sx = row.find("S")
    if sx != -1:
        break
else:
    raise ValueError("no start tile 'S' found in the input map")


def all_equal(iterable):
    """Return True if every element of *iterable* compares equal.

    An empty iterable is vacuously all-equal. At most two runs of equal
    elements are pulled from ``groupby``, so consumption stops at the first
    element that differs from the first one.

    NOTE(review): not referenced anywhere else in this script.
    """
    first_two_runs = islice(groupby(iterable), 2)
    return len(list(first_two_runs)) <= 1


def solve(inp: list[str], start: tuple[int, int]) -> list[int]:
    queue = deque()
    queue.append((start[0], start[1], 0))
    seen = set()
    history = list()
    while queue:
        px, py, depth = queue.popleft()
        if (px, py) in seen:
            continue
        seen.add((px, py))
        if len(history) <= depth:
            history.append(0)
        history[depth] += 1
        for dx in (-1, 1):
            nx = px + dx
            if 0 <= nx < len(inp[0]) and inp[py][nx] != "#":
                queue.append((nx, py, depth + 1))
        for dy in (-1, 1):
            ny = py + dy
            if 0 <= ny < len(inp) and inp[ny][px] != "#":
                queue.append((px, ny, depth + 1))
    return history


# closed() uses len(inp) as both the width and the height of a tile, so the
# map must be square. Raise explicitly: an `assert` is stripped by `python -O`
# and the old check only looked at the first row.
if not inp or any(len(row) != len(inp) for row in inp):
    raise ValueError("input map must be a non-empty square grid")

# BFS distance histories for one copy of the map. Each starts from the cell
# of a tile that is nearest to S when that tile lies in the named direction
# from the start tile: the nearest corner for diagonal tiles, and the middle
# of the nearest edge (S's column or row) for tiles straight out from S.
nw = solve(inp, (len(inp[0]) - 1, len(inp) - 1))  # bottom-right corner
n = solve(inp, (sx, len(inp) - 1))  # bottom row, S's column
ne = solve(inp, (0, len(inp) - 1))  # bottom-left corner
w = solve(inp, (len(inp[0]) - 1, sy))  # right column, S's row
c = solve(inp, (sx, sy))  # S itself: the start tile
e = solve(inp, (0, sy))  # left column, S's row
sw = solve(inp, (len(inp[0]) - 1, 0))  # top-right corner
s = solve(inp, (sx, 0))  # top row, S's column
se = solve(inp, (0, 0))  # top-left corner


def sum_offt(a: list[int], offset: int, point: int | None = None) -> int:
    if point is None:
        point = len(a)
    else:
        point += 1
    return sum(islice(a, offset, point, 2))


def closed(target: int) -> int:
    """Count the cells that end a walk of exactly *target* steps from S.

    A cell counts when its distance is at most the step budget and has the
    same parity as it (a walk can always waste steps in pairs). The map is
    treated as repeating in every direction. Instead of simulating, the total
    is assembled in closed form from the module-level BFS histories:
    ``c`` for the start tile, ``n``/``e``/``s``/``w`` for tiles straight out
    from it, and ``nw``/``ne``/``sw``/``se`` for diagonal tiles.
    ``sum_offt`` does the per-tile parity/cut-off counting.

    Assumes (not checked here): the map is square (``len(inp)`` is used as
    both width and height) and S sits at its centre, ``radius`` cells from
    each edge. NOTE(review): the straight-line distances used below
    presumably also require S's row/column and the map border to be free of
    "#" — confirm against the input.

    Args:
        target: number of steps walked from S.

    Returns:
        The number of cells that can be the endpoint of such a walk.
    """
    base_offt = target % 2  # parity of target: which distances count in the start tile
    radius = (len(inp) - 1) // 2  # steps from S to the edge of its own tile

    # Arms: tiles straight N/E/S/W of the start tile that are counted in full
    # (sum_offt with no cut-off). arm_len is how many such tiles each arm has.
    # They are split between two parity offsets: ceil(arm_len / 2) tiles use
    # a_arms_offt and floor(arm_len / 2) use b_arms_offt.
    arm_len = max(target - len(inp) - radius - 1, 0) // len(inp)
    # The first arm tile is entered radius + 1 steps after leaving S.
    a_arms_offt = (base_offt + radius + 1) % 2
    a_arms_count = (arm_len + 1) // 2
    b_arms_offt = (a_arms_offt + 1) % 2
    b_arms_count = arm_len // 2
    arms_sum = sum(
        sum_offt(d, a_arms_offt) * a_arms_count
        + sum_offt(d, b_arms_offt) * b_arms_count
        for d in (n, e, s, w)
    )

    # Corners: diagonal tiles counted in full. corner_len is the number of
    # whole tile widths left after the first len(inp) steps. e_corner_count
    # tiles per quadrant use parity base_offt; o_corner_count tiles use the
    # opposite parity.
    corner_len = max(target - len(inp), 0) // len(inp)
    e_corner_count = (corner_len // 2) ** 2
    o_corner_len = max(corner_len - 1, 0) // 2
    o_corner_count = o_corner_len * (o_corner_len + 1)
    corner_sum = sum(
        sum_offt(d, base_offt) * e_corner_count
        + sum_offt(d, (base_offt + 1) % 2) * o_corner_count
        for d in (nw, ne, sw, se)
    )

    # Tips: partially reached tiles at the far end of each arm. They are
    # counted only up to a distance cut-off (the *_point values, reduced
    # modulo 2 * len(inp)). Each arm gets at most one "a" tip and one "b"
    # tip (the counts are 0 or 1).
    a_tips_point = max(target - radius - 1, 0) % (2 * len(inp))
    a_tips_offt = (radius + base_offt + 1) % 2
    a_tips_count = 1 if target + radius >= len(inp) else 0
    b_tips_point = max(target - radius - len(inp) - 1, 0) % (2 * len(inp))
    b_tips_offt = (radius + base_offt) % 2
    b_tips_count = 1 if target - radius > len(inp) else 0

    tips_sum = sum(
        sum_offt(d, a_tips_offt, a_tips_point) * a_tips_count
        + sum_offt(d, b_tips_offt, b_tips_point) * b_tips_count
        for d in (n, e, s, w)
    )

    # Sides: partially reached diagonal tiles, presumably those along the
    # boundary of the reachable region. They are counted with the corner
    # histories up to a cut-off. Group "a" uses parity base_offt, group "b"
    # the opposite parity, each with its own cut-off and multiplicity.
    a_sides_point = max(target - len(inp), 0) % (2 * len(inp))
    a_sides_count = max((target + len(inp) - 1) // (2 * len(inp)) * 2 - 1, 0)
    a_sides_offt = base_offt
    a_sides_sum = (
        sum(sum_offt(d, a_sides_offt, a_sides_point) for d in (nw, ne, sw, se))
        * a_sides_count
    )

    b_sides_point = target % (2 * len(inp))
    b_sides_count = max(target - 1, 0) // (2 * len(inp)) * 2
    b_sides_offt = (base_offt + 1) % 2
    b_sides_sum = (
        sum(sum_offt(d, b_sides_offt, b_sides_point) for d in (nw, ne, sw, se))
        * b_sides_count
    )

    # Start tile (cut off at target) + full corners + full arms + partial
    # tips + partial sides.
    return (
        sum_offt(c, base_offt, target)
        + corner_sum
        + arms_sum
        + tips_sum
        + a_sides_sum
        + b_sides_sum
    )


# Print the answer for each step budget, one per line
# (64 and 26501365 — presumably the two puzzle parts).
for steps in (64, 26501365):
    print(closed(steps))