from __future__ import annotations from math import sqrt, ceil, floor from utils import parse_day from functools import cache, total_ordering from enum import Enum, auto from dataclasses import dataclass @total_ordering @dataclass(frozen=True) class Interval: class Type(Enum): START = auto() END = auto() class Axis(Enum): Y = auto() X = auto() time: int type: Interval.Type axis: Interval.Axis value: int def __lt__(self, other: Interval) -> bool: def cmptuple(o: Interval) -> tuple[int, bool, bool]: is_end: bool = o.type == Interval.Type.END return o.time, is_end, (o.axis == Interval.Axis.Y) ^ is_end, o.value if not isinstance(other, Interval): return NotImplemented return cmptuple(self) < cmptuple(other) tx_min, tx_max, ty_min, ty_max = \ map(int, parse_day(17, r'.* x=(-?\d+)..(-?\d+), y=(-?\d+)..(-?\d+)')) vx_min = int(sqrt(tx_min * 2) - 1) vx_max = tx_max vy_min = ty_min vy_max = abs(ty_min) - 1 @cache def sum_n(n): return n * (n + 1) // 2 print(sum_n(vy_max)) @cache def time_at_pos(v, d): """d = -t(t - 2v - 1) / 2 solved for t""" middle = sqrt((2 * v + 1) ** 2 - 8 * d) return (-middle + 2 * v + 1) / 2, (middle + 2 * v + 1) / 2 @cache def time_at_xpos(v, d): if d > sum_n(v): return float('inf') return time_at_pos(v, d)[0] def y_times(vy0): start = ceil(time_at_pos(vy0, ty_max)[1]) end = floor(time_at_pos(vy0, ty_min)[1]) return start, end def x_times(vx0): start = time_at_xpos(vx0, tx_min) if start != float('inf'): start = ceil(start) end = time_at_xpos(vx0, tx_max) if end != float('inf'): end = floor(end) return start, end intervals = [] for vx0 in range(vx_min, vx_max + 1): start, end = x_times(vx0) if end < start: continue intervals.append(Interval(start, Interval.Type.START, Interval.Axis.X, vx0)) intervals.append(Interval(end, Interval.Type.END, Interval.Axis.X, vx0)) for vy0 in range(vy_min, vy_max + 1): start, end = y_times(vy0) if end < start: continue intervals.append(Interval(start, Interval.Type.START, Interval.Axis.Y, vy0)) intervals.append(Interval(end, Interval.Type.END, Interval.Axis.Y, vy0)) intervals.sort() active_vx0 = set() active_vy0 = set() count = 0 for interval in intervals: match interval: case Interval(type=Interval.Type.END, axis=axis, value=value): match axis: case Interval.Axis.X: active_vx0.remove(value) case Interval.Axis.Y: active_vy0.remove(value) case Interval(type=Interval.Type.START, axis=axis, value=value): match axis: case Interval.Axis.X: count += len(active_vy0) active_vx0.add(value) case Interval.Axis.Y: count += len(active_vx0) active_vy0.add(value) print(count)