diff options
| -rw-r--r-- | 17.py | 71 | 
1 files changed, 60 insertions, 11 deletions
| @@ -1,7 +1,33 @@ +from __future__ import annotations  from math import sqrt, ceil, floor  from utils import parse_day -from collections import defaultdict -from itertools import chain +from functools import cache, total_ordering +from enum import Enum, auto +from dataclasses import dataclass + +@total_ordering +@dataclass(frozen=True) +class Interval: +    class Type(Enum): +        START = auto() +        END = auto() + +    class Axis(Enum): +        Y = auto() +        X = auto() + +    time: int +    type: Interval.Type +    axis: Interval.Axis +    value: int + +    def __lt__(self, other: Interval) -> bool: +        def cmptuple(o: Interval) -> tuple[int, bool, bool]: +            is_end: bool = o.type == Interval.Type.END +            return o.time, is_end, (o.axis == Interval.Axis.Y) ^ is_end, o.value +        if not isinstance(other, Interval): +            return NotImplemented +        return cmptuple(self) < cmptuple(other)  tx_min, tx_max, ty_min, ty_max = \      map(int, parse_day(17, r'.* x=(-?\d+)..(-?\d+), y=(-?\d+)..(-?\d+)')) @@ -11,16 +37,19 @@ vx_max = tx_max  vy_min = ty_min  vy_max = abs(ty_min) - 1 +@cache  def sum_n(n):      return n * (n + 1) // 2  print(sum_n(vy_max)) +@cache  def time_at_pos(v, d):      """d = -t(t - 2v - 1) / 2 solved for t"""      middle = sqrt((2 * v + 1) ** 2 - 8 * d)      return (-middle + 2 * v + 1) / 2, (middle + 2 * v + 1) / 2 +@cache  def time_at_xpos(v, d):      if d > sum_n(v): return float('inf')      return time_at_pos(v, d)[0] @@ -32,18 +61,38 @@ def y_times(vy0):  def x_times(vx0):      start = time_at_xpos(vx0, tx_min) +    if start != float('inf'): start = ceil(start)      end = time_at_xpos(vx0, tx_max) +    if end != float('inf'): end = floor(end)      return start, end -times = defaultdict(set) +intervals = [] +for vx0 in range(vx_min, vx_max + 1): +    start, end = x_times(vx0) +    if end < start: continue +    intervals.append(Interval(start, Interval.Type.START, Interval.Axis.X, vx0)) +    intervals.append(Interval(end,   Interval.Type.END,   Interval.Axis.X, vx0))  for vy0 in range(vy_min, vy_max + 1): -    tmin, tmax = y_times(vy0) -    for t in range(tmin, tmax + 1): -        times[t].add(vy0) +    start, end = y_times(vy0) +    if end < start: continue +    intervals.append(Interval(start, Interval.Type.START, Interval.Axis.Y, vy0)) +    intervals.append(Interval(end,   Interval.Type.END,   Interval.Axis.Y, vy0)) +intervals.sort() +active_vx0 = set() +active_vy0 = set()  count = 0 -for vx0 in range(vx_min, vx_max + 1): -    tmin, tmax = x_times(vx0) -    count += len(set(chain.from_iterable( -        vy0set for t, vy0set in times.items() if tmin <= t <= tmax -    ))) +for interval in intervals: +    match interval: +        case Interval(type=Interval.Type.END, axis=axis, value=value): +            match axis: +                case Interval.Axis.X: active_vx0.remove(value) +                case Interval.Axis.Y: active_vy0.remove(value) +        case Interval(type=Interval.Type.START, axis=axis, value=value): +            match axis: +                case Interval.Axis.X: +                    count += len(active_vy0) +                    active_vx0.add(value) +                case Interval.Axis.Y: +                    count += len(active_vx0) +                    active_vy0.add(value)  print(count) | 
