summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tk@the-tk.com>2021-12-18 00:24:20 +0000
committerTomasz Kramkowski <tk@the-tk.com>2021-12-18 00:39:35 +0000
commit77660a8cf9264db955b35801272386e7d6f68fdd (patch)
treef58880718f81ac58ab7182022a3a04a2df43092b
parent5a13478012d0968eec24e1962153886f512a6ee2 (diff)
downloadaoc2021-77660a8cf9264db955b35801272386e7d6f68fdd.tar.gz
aoc2021-77660a8cf9264db955b35801272386e7d6f68fdd.tar.xz
aoc2021-77660a8cf9264db955b35801272386e7d6f68fdd.zip
day 17: now faster with a list of intervals
-rw-r--r--17.py71
1 files changed, 60 insertions, 11 deletions
diff --git a/17.py b/17.py
index ebe3116..88665b3 100644
--- a/17.py
+++ b/17.py
@@ -1,7 +1,33 @@
+from __future__ import annotations
from math import sqrt, ceil, floor
from utils import parse_day
-from collections import defaultdict
-from itertools import chain
+from functools import cache, total_ordering
+from enum import Enum, auto
+from dataclasses import dataclass
+
+@total_ordering
+@dataclass(frozen=True)
+class Interval:
+ class Type(Enum):
+ START = auto()
+ END = auto()
+
+ class Axis(Enum):
+ Y = auto()
+ X = auto()
+
+ time: int
+ type: Interval.Type
+ axis: Interval.Axis
+ value: int
+
+ def __lt__(self, other: Interval) -> bool:
+ def cmptuple(o: Interval) -> tuple[int, bool, bool]:
+ is_end: bool = o.type == Interval.Type.END
+ return o.time, is_end, (o.axis == Interval.Axis.Y) ^ is_end, o.value
+ if not isinstance(other, Interval):
+ return NotImplemented
+ return cmptuple(self) < cmptuple(other)
tx_min, tx_max, ty_min, ty_max = \
map(int, parse_day(17, r'.* x=(-?\d+)..(-?\d+), y=(-?\d+)..(-?\d+)'))
@@ -11,16 +37,19 @@ vx_max = tx_max
vy_min = ty_min
vy_max = abs(ty_min) - 1
+@cache
def sum_n(n):
return n * (n + 1) // 2
print(sum_n(vy_max))
+@cache
def time_at_pos(v, d):
"""d = -t(t - 2v - 1) / 2 solved for t"""
middle = sqrt((2 * v + 1) ** 2 - 8 * d)
return (-middle + 2 * v + 1) / 2, (middle + 2 * v + 1) / 2
+@cache
def time_at_xpos(v, d):
if d > sum_n(v): return float('inf')
return time_at_pos(v, d)[0]
@@ -32,18 +61,38 @@ def y_times(vy0):
def x_times(vx0):
start = time_at_xpos(vx0, tx_min)
+ if start != float('inf'): start = ceil(start)
end = time_at_xpos(vx0, tx_max)
+ if end != float('inf'): end = floor(end)
return start, end
-times = defaultdict(set)
+intervals = []
+for vx0 in range(vx_min, vx_max + 1):
+ start, end = x_times(vx0)
+ if end < start: continue
+ intervals.append(Interval(start, Interval.Type.START, Interval.Axis.X, vx0))
+ intervals.append(Interval(end, Interval.Type.END, Interval.Axis.X, vx0))
for vy0 in range(vy_min, vy_max + 1):
- tmin, tmax = y_times(vy0)
- for t in range(tmin, tmax + 1):
- times[t].add(vy0)
+ start, end = y_times(vy0)
+ if end < start: continue
+ intervals.append(Interval(start, Interval.Type.START, Interval.Axis.Y, vy0))
+ intervals.append(Interval(end, Interval.Type.END, Interval.Axis.Y, vy0))
+intervals.sort()
+active_vx0 = set()
+active_vy0 = set()
count = 0
-for vx0 in range(vx_min, vx_max + 1):
- tmin, tmax = x_times(vx0)
- count += len(set(chain.from_iterable(
- vy0set for t, vy0set in times.items() if tmin <= t <= tmax
- )))
+for interval in intervals:
+ match interval:
+ case Interval(type=Interval.Type.END, axis=axis, value=value):
+ match axis:
+ case Interval.Axis.X: active_vx0.remove(value)
+ case Interval.Axis.Y: active_vy0.remove(value)
+ case Interval(type=Interval.Type.START, axis=axis, value=value):
+ match axis:
+ case Interval.Axis.X:
+ count += len(active_vy0)
+ active_vx0.add(value)
+ case Interval.Axis.Y:
+ count += len(active_vx0)
+ active_vy0.add(value)
print(count)