summaryrefslogtreecommitdiffstats
path: root/17.py
blob: 88665b3a4d98a773632774086be0795fd1c7adcd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import annotations
from math import sqrt, ceil, floor
from utils import parse_day
from functools import cache, total_ordering
from enum import Enum, auto
from dataclasses import dataclass

@total_ordering
@dataclass(frozen=True)
class Interval:
    """One endpoint (START or END) of a closed time interval on one axis.

    Instances are sortable so that a list of endpoints can be processed in
    sweep order. The ordering is, in decreasing priority:

    1. ``time`` ascending;
    2. START before END at the same time, so two closed intervals that share
       only a single time step are still both open at that step;
    3. at the same time and type: X before Y for START, Y before X for END;
    4. ``value`` ascending as the final tie-break.

    Equality and hashing are the field-wise ones generated by ``dataclass``;
    ``total_ordering`` derives the remaining comparisons from ``__lt__``.
    """

    class Type(Enum):
        START = auto()
        END = auto()

    class Axis(Enum):
        Y = auto()
        X = auto()

    # Step number of the endpoint; float so that float('inf') can mark an
    # interval that never closes.
    time: float
    type: Interval.Type
    axis: Interval.Axis
    # The initial velocity this endpoint belongs to.
    value: int

    def _sort_key(self) -> tuple[float, bool, bool, int]:
        """Return the tuple whose natural ordering defines the sweep order."""
        is_end = self.type is Interval.Type.END
        # XOR flips the axis tie-break for END events: X,Y on START; Y,X on END.
        axis_rank = (self.axis is Interval.Axis.Y) ^ is_end
        return self.time, is_end, axis_rank, self.value

    def __lt__(self, other: object) -> bool:
        """Compare by sweep order; defer to ``other`` if it is not an Interval."""
        if not isinstance(other, Interval):
            return NotImplemented
        return self._sort_key() < other._sort_key()

# The four integers captured by the regex: x-range then y-range.
# NOTE(review): presumably the target-area bounds of the day-17 input as
# returned by parse_day -- confirm against utils.parse_day.
tx_min, tx_max, ty_min, ty_max = (
    int(bound)
    for bound in parse_day(17, r'.* x=(-?\d+)..(-?\d+), y=(-?\d+)..(-?\d+)')
)

# Search window for the initial velocities.
# sum_n(v) >= tx_min requires v > sqrt(2 * tx_min) - 1, so this is a safe
# (slightly low) lower bound on vx.
vx_min = int(sqrt(2 * tx_min) - 1)
# A larger vx would already be past tx_max after the first step.
vx_max = tx_max
# A lower vy would already be below ty_min after the first step.
vy_min = ty_min
# NOTE(review): assumes ty_min < 0 (target below the launch height) -- a probe
# launched upward comes back through y=0 and a faster one would then step
# past ty_min.
vy_max = abs(ty_min) - 1

@cache
def sum_n(n: int) -> int:
    """Return the n-th triangular number, i.e. 1 + 2 + ... + n for n >= 0."""
    doubled = n * (n + 1)
    return doubled // 2

# A probe launched upward at vy climbs vy + (vy - 1) + ... + 1 = sum_n(vy)
# before falling, so this prints the peak height for the largest vy in the
# search window (presumably the part-one answer).
print(sum_n(vy_max))

@cache
def time_at_pos(v, d):
    """Solve d = -t(t - 2v - 1) / 2 for t.

    Returns both roots of the quadratic as floats, smaller one first.
    Raises ValueError (from sqrt) when the discriminant is negative.
    """
    root = sqrt((2 * v + 1) ** 2 - 8 * d)
    earlier = (-root + 2 * v + 1) / 2
    later = (root + 2 * v + 1) / 2
    return earlier, later

@cache
def time_at_xpos(v, d):
    """Return the (fractional) time at which distance d is first reached.

    sum_n(v) is treated as the farthest distance ever covered with initial
    velocity v; for any d beyond it the result is float('inf'). Otherwise
    the smaller root from time_at_pos is returned.
    """
    if d > sum_n(v):
        return float('inf')
    first_crossing, _ = time_at_pos(v, d)
    return first_crossing

def y_times(vy0):
    """Return (start, end) whole-step bounds for vy0 between ty_max and ty_min.

    Both bounds use the larger root of time_at_pos: start is the ceiling of
    the time for ty_max, end the floor of the time for ty_min.
    NOTE(review): assumes the target lies below the launch height, so the
    larger root is the crossing on the way down -- confirm.
    """
    _, reaches_upper = time_at_pos(vy0, ty_max)
    _, reaches_lower = time_at_pos(vy0, ty_min)
    return ceil(reaches_upper), floor(reaches_lower)

def x_times(vx0):
    """Return (start, end) whole-step bounds during which x is in [tx_min, tx_max].

    start is the first step at which tx_min has been reached, or float('inf')
    if the probe stops short of it. end is the last step at which the probe
    has not passed tx_max, or float('inf') if it comes to rest without ever
    passing tx_max. Callers should discard results where end < start.
    """
    start = time_at_xpos(vx0, tx_min)
    if start != float('inf'):
        start = ceil(start)

    # The probe comes to rest at x = sum_n(vx0). If that resting point is at
    # or before tx_max -- including exactly ON tx_max -- it never leaves the
    # x-range, so the interval is open-ended. (Treating the "exactly on the
    # edge" case as a finite end at step vx0 would drop every trajectory whose
    # y-window opens later.)
    if sum_n(vx0) <= tx_max:
        end = float('inf')
    else:
        end = floor(time_at_xpos(vx0, tx_max))
    return start, end

# Build one START and one END event per initial velocity whose time window
# is non-empty, for each axis independently.
intervals = []
for vx0 in range(vx_min, vx_max + 1):
    first, last = x_times(vx0)
    if first > last:
        continue
    intervals += [
        Interval(first, Interval.Type.START, Interval.Axis.X, vx0),
        Interval(last, Interval.Type.END, Interval.Axis.X, vx0),
    ]
for vy0 in range(vy_min, vy_max + 1):
    first, last = y_times(vy0)
    if first > last:
        continue
    intervals += [
        Interval(first, Interval.Type.START, Interval.Axis.Y, vy0),
        Interval(last, Interval.Type.END, Interval.Axis.Y, vy0),
    ]

# Sweep the events in Interval order, tracking which velocities are currently
# "open" on each axis. Every time a window opens it overlaps exactly the
# windows currently open on the other axis, so each overlapping (vx, vy)
# pair is counted once: when the later of its two windows opens.
intervals.sort()
active_vx0 = set()
active_vy0 = set()
count = 0
for event in intervals:
    on_x = event.axis == Interval.Axis.X
    on_y = event.axis == Interval.Axis.Y
    if event.type == Interval.Type.END:
        if on_x:
            active_vx0.remove(event.value)
        elif on_y:
            active_vy0.remove(event.value)
    elif event.type == Interval.Type.START:
        if on_x:
            count += len(active_vy0)
            active_vx0.add(event.value)
        elif on_y:
            count += len(active_vx0)
            active_vy0.add(event.value)
print(count)