From 77660a8cf9264db955b35801272386e7d6f68fdd Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Sat, 18 Dec 2021 00:24:20 +0000 Subject: day 17: now faster with a list of intervals --- 17.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/17.py b/17.py index ebe3116..88665b3 100644 --- a/17.py +++ b/17.py @@ -1,7 +1,33 @@ +from __future__ import annotations from math import sqrt, ceil, floor from utils import parse_day -from collections import defaultdict -from itertools import chain +from functools import cache, total_ordering +from enum import Enum, auto +from dataclasses import dataclass + +@total_ordering +@dataclass(frozen=True) +class Interval: + class Type(Enum): + START = auto() + END = auto() + + class Axis(Enum): + Y = auto() + X = auto() + + time: int + type: Interval.Type + axis: Interval.Axis + value: int + + def __lt__(self, other: Interval) -> bool: + def cmptuple(o: Interval) -> tuple[int, bool, bool]: + is_end: bool = o.type == Interval.Type.END + return o.time, is_end, (o.axis == Interval.Axis.Y) ^ is_end, o.value + if not isinstance(other, Interval): + return NotImplemented + return cmptuple(self) < cmptuple(other) tx_min, tx_max, ty_min, ty_max = \ map(int, parse_day(17, r'.* x=(-?\d+)..(-?\d+), y=(-?\d+)..(-?\d+)')) @@ -11,16 +37,19 @@ vx_max = tx_max vy_min = ty_min vy_max = abs(ty_min) - 1 +@cache def sum_n(n): return n * (n + 1) // 2 print(sum_n(vy_max)) +@cache def time_at_pos(v, d): """d = -t(t - 2v - 1) / 2 solved for t""" middle = sqrt((2 * v + 1) ** 2 - 8 * d) return (-middle + 2 * v + 1) / 2, (middle + 2 * v + 1) / 2 +@cache def time_at_xpos(v, d): if d > sum_n(v): return float('inf') return time_at_pos(v, d)[0] @@ -32,18 +61,38 @@ def y_times(vy0): def x_times(vx0): start = time_at_xpos(vx0, tx_min) + if start != float('inf'): start = ceil(start) end = time_at_xpos(vx0, tx_max) + if end != float('inf'): end = floor(end) return start, end -times = defaultdict(set) +intervals = [] +for vx0 in range(vx_min, vx_max + 1): + start, end = x_times(vx0) + if end < start: continue + intervals.append(Interval(start, Interval.Type.START, Interval.Axis.X, vx0)) + intervals.append(Interval(end, Interval.Type.END, Interval.Axis.X, vx0)) for vy0 in range(vy_min, vy_max + 1): - tmin, tmax = y_times(vy0) - for t in range(tmin, tmax + 1): - times[t].add(vy0) + start, end = y_times(vy0) + if end < start: continue + intervals.append(Interval(start, Interval.Type.START, Interval.Axis.Y, vy0)) + intervals.append(Interval(end, Interval.Type.END, Interval.Axis.Y, vy0)) +intervals.sort() +active_vx0 = set() +active_vy0 = set() count = 0 -for vx0 in range(vx_min, vx_max + 1): - tmin, tmax = x_times(vx0) - count += len(set(chain.from_iterable( - vy0set for t, vy0set in times.items() if tmin <= t <= tmax - ))) +for interval in intervals: + match interval: + case Interval(type=Interval.Type.END, axis=axis, value=value): + match axis: + case Interval.Axis.X: active_vx0.remove(value) + case Interval.Axis.Y: active_vy0.remove(value) + case Interval(type=Interval.Type.START, axis=axis, value=value): + match axis: + case Interval.Axis.X: + count += len(active_vy0) + active_vx0.add(value) + case Interval.Axis.Y: + count += len(active_vx0) + active_vy0.add(value) print(count) -- cgit v1.2.3-54-g00ecf