1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
from __future__ import annotations
from math import sqrt, ceil, floor
from utils import parse_day
from functools import cache, total_ordering
from enum import Enum, auto
from dataclasses import dataclass
@total_ordering
@dataclass(frozen=True)
class Interval:
    """One endpoint (START or END) of the time window of a launch velocity.

    Instances are sweep-line events: sorting a list of them yields the order
    in which windows open and close.  Equality and hashing come from the
    dataclass (all four fields); ordering comes from ``__lt__`` below, with
    the remaining comparisons filled in by ``total_ordering``.
    """

    class Type(Enum):
        """Which end of a window the event marks."""
        START = auto()
        END = auto()

    class Axis(Enum):
        """Which velocity component the window belongs to."""
        Y = auto()
        X = auto()

    # Step at which the event happens.  Usually an int, but callers also pass
    # float('inf') for windows that never open or never close.
    time: int | float
    type: Interval.Type
    axis: Interval.Axis
    # The initial velocity this window was computed for.
    value: int

    def _sort_key(self) -> tuple[int | float, bool, bool, int]:
        """Build the tuple that defines the event order.

        Events are ordered by time; at equal times every START precedes every
        END, so windows that merely touch at one step still overlap.  Within
        that, X comes before Y among STARTs and Y before X among ENDs, and
        the velocity value is the final tie-breaker.
        """
        is_end = self.type == Interval.Type.END
        return (
            self.time,
            is_end,
            (self.axis == Interval.Axis.Y) ^ is_end,
            self.value,
        )

    def __lt__(self, other: Interval) -> bool:
        """Compare by ``_sort_key``; defer to ``other`` for foreign types."""
        if not isinstance(other, Interval):
            return NotImplemented
        return self._sort_key() < other._sort_key()
# Target rectangle bounds, taken from the four integers the regex captures.
# NOTE(review): parse_day presumably loads the day-17 input and returns the
# capture groups as strings -- confirm against utils.
tx_min, tx_max, ty_min, ty_max = (
    int(bound)
    for bound in parse_day(17, r'.* x=(-?\d+)..(-?\d+), y=(-?\d+)..(-?\d+)')
)

# Range of initial velocities worth trying.
# sqrt(2 * tx_min) - 1 sits below the smallest v with v * (v + 1) / 2 >= tx_min,
# so no x velocity that can reach the target is skipped.
# NOTE(review): requires tx_min >= 0, otherwise sqrt raises ValueError.
vx_min = int(sqrt(tx_min * 2) - 1)
vx_max = tx_max
# NOTE(review): the y limits look like they assume the target lies below
# y = 0 (ty_min negative) -- confirm against the puzzle input.
vy_max = abs(ty_min) - 1
vy_min = ty_min
@cache
def sum_n(n):
    """Return the n-th triangular number, 0 + 1 + ... + n, in closed form."""
    doubled = n * (n + 1)
    return doubled // 2
# Print the triangular number of the largest y velocity considered.  Under
# the step model documented in time_at_pos (d = -t(t - 2v - 1) / 2) this is
# the height reached at t == v, i.e. the peak of that launch.
# NOTE(review): presumably the part-one answer -- confirm.
print(sum_n(vy_max))
@cache
def time_at_pos(v, d):
    """Return both times t at which a coordinate launched at speed v equals d.

    The position after t steps is modelled as d = -t(t - 2v - 1) / 2, which
    rearranges to t**2 - (2v + 1)t + 2d = 0.  The two roots are returned as
    floats, smaller one first.  sqrt raises ValueError when the discriminant
    is negative, i.e. when d is never reached.
    """
    two_v = 2 * v
    root = sqrt((two_v + 1) ** 2 - 8 * d)
    # Keep the additions in this left-to-right order so the floating-point
    # results are reproducible bit for bit.
    earlier = (-root + two_v + 1) / 2
    later = (root + two_v + 1) / 2
    return earlier, later
@cache
def time_at_xpos(v, d):
    """Return the first time at which x == d for initial x velocity v.

    Distances beyond sum_n(v) are treated as unreachable and yield
    float('inf'); otherwise the smaller root from time_at_pos is returned
    (a float, not necessarily a whole step).
    """
    farthest = sum_n(v)
    return float('inf') if d > farthest else time_at_pos(v, d)[0]
def y_times(vy0):
    """Return the inclusive (first, last) whole steps with y inside the target.

    Uses the later root of time_at_pos for each bound: the first step is the
    ceiling of the time y reaches ty_max, the last step is the floor of the
    time y reaches ty_min.  The window is empty when last < first.
    NOTE(review): taking the later root assumes the target is crossed on the
    way down -- confirm the target is always below the launch height.
    """
    _, enter_time = time_at_pos(vy0, ty_max)
    first_step = ceil(enter_time)
    _, exit_time = time_at_pos(vy0, ty_min)
    last_step = floor(exit_time)
    return first_step, last_step
def x_times(vx0):
    """Return the inclusive (first, last) whole steps with x inside the target.

    ``first`` is the ceiling of the time x reaches tx_min, or float('inf')
    when tx_min is never reached.  ``last`` is the floor of the time x
    reaches tx_max, or float('inf') when the probe never moves past tx_max.
    The window is empty when last < first.
    """
    never = float('inf')

    start = time_at_xpos(vx0, tx_min)
    if start != never:
        start = ceil(start)

    # The probe travels at most sum_n(vx0) horizontally and then stays put.
    # If that resting point is not beyond tx_max, the probe never leaves the
    # x-range again.  This includes resting exactly on tx_max, where
    # time_at_xpos reports a finite time (the step the probe arrives) that
    # must not be taken as the end of the window.
    if tx_max >= sum_n(vx0):
        end = never
    else:
        end = floor(time_at_xpos(vx0, tx_max))
    return start, end
intervals = []
for vx0 in range(vx_min, vx_max + 1):
start, end = x_times(vx0)
if end < start: continue
intervals.append(Interval(start, Interval.Type.START, Interval.Axis.X, vx0))
intervals.append(Interval(end, Interval.Type.END, Interval.Axis.X, vx0))
for vy0 in range(vy_min, vy_max + 1):
start, end = y_times(vy0)
if end < start: continue
intervals.append(Interval(start, Interval.Type.START, Interval.Axis.Y, vy0))
intervals.append(Interval(end, Interval.Type.END, Interval.Axis.Y, vy0))
intervals.sort()
active_vx0 = set()
active_vy0 = set()
count = 0
for interval in intervals:
match interval:
case Interval(type=Interval.Type.END, axis=axis, value=value):
match axis:
case Interval.Axis.X: active_vx0.remove(value)
case Interval.Axis.Y: active_vy0.remove(value)
case Interval(type=Interval.Type.START, axis=axis, value=value):
match axis:
case Interval.Axis.X:
count += len(active_vy0)
active_vx0.add(value)
case Interval.Axis.Y:
count += len(active_vx0)
active_vy0.add(value)
print(count)
|