summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tk@the-tk.com>2021-12-17 20:15:58 +0000
committerTomasz Kramkowski <tk@the-tk.com>2021-12-17 20:15:58 +0000
commitb2ab1dda5fedd20b05ec2fa10b4a2af8196465dc (patch)
treed66a195438cfb4222e2992c3fc9c32382f158caf
parentf0fe0c242e97908bde78fdeee1549406e2059c04 (diff)
downloadaoc2021-b2ab1dda5fedd20b05ec2fa10b4a2af8196465dc.tar.gz
aoc2021-b2ab1dda5fedd20b05ec2fa10b4a2af8196465dc.tar.xz
aoc2021-b2ab1dda5fedd20b05ec2fa10b4a2af8196465dc.zip
day 17: faster
-rw-r--r--17.py42
1 files changed, 27 insertions, 15 deletions
diff --git a/17.py b/17.py
index fa426d4..be5bccb 100644
--- a/17.py
+++ b/17.py
@@ -1,5 +1,6 @@
from math import sqrt, ceil, floor
from utils import parse_day
+from collections import defaultdict
tx_min, tx_max, ty_min, ty_max = \
map(int, parse_day(17, r'.* x=(-?\d+)..(-?\d+), y=(-?\d+)..(-?\d+)'))
@@ -15,23 +16,34 @@ def sum_n(n):
print(sum_n(vy_max))
def time_at_pos(v, d):
- """d = - (t + 1)(t - 2v) / 2 solved for t"""
- return (sqrt((2 * v + 1) ** 2 - 8 * d) + 2 * v - 1) / 2 + 1
+ """d = -t(t - 2v - 1) / 2 solved for t"""
+ middle = sqrt((2 * v + 1) ** 2 - 8 * d)
+ return (-middle + 2 * v + 1) / 2, (middle + 2 * v + 1) / 2
+
+def time_at_xpos(v, d):
+ if d > sum_n(v): return float('inf')
+ return time_at_pos(v, d)[0]
def y_times(vy0):
- start = time_at_pos(vy0, ty_max)
- end = time_at_pos(vy0, ty_min)
- return range(ceil(start), floor(end) + 1)
+ start = ceil(time_at_pos(vy0, ty_max)[1])
+ end = floor(time_at_pos(vy0, ty_min)[1])
+ return start, end
-def x_at(vx0, t):
- assert(t >= 0)
- return sum_n(vx0) - sum_n(max(vx0 - t, 0))
+def x_times(vx0):
+ start = time_at_xpos(vx0, tx_min)
+ end = time_at_xpos(vx0, tx_max)
+ return start, end
-count = 0
+times = defaultdict(set)
for vy0 in range(vy_min, vy_max + 1):
- for vx0 in range(vx_min, vx_max + 1):
- for t in y_times(vy0):
- if tx_min <= x_at(vx0, t) <= tx_max:
- count += 1
- break
-print(count)
+ tmin, tmax = y_times(vy0)
+ for t in range(tmin, tmax + 1):
+ times[t].add(vy0)
+points = set()
+for vx0 in range(vx_min, vx_max + 1):
+ tmin, tmax = x_times(vx0)
+ for t, vy0set in times.items():
+ if tmin <= t <= tmax:
+ for vy0 in vy0set:
+ points.add((vx0, vy0))
+print(len(points))