summaryrefslogtreecommitdiffstats
path: root/24.py
blob: 47b4ba076328fd3dc39b77e7f5fffe450657899e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from dataclasses import dataclass
from functools import cache
from itertools import combinations
from sys import stdin

from z3 import z3


@dataclass(frozen=True, slots=True)
class Vec:
    x: float
    y: float
    z: float

    def __str__(self) -> str:
        return f"{self.x}, {self.y}, {self.z}"


@dataclass(frozen=True, slots=True)
class Hailstone:
    """A hailstone moving in a straight line: its position is ``pos + vel * t``.

    ``m`` and ``c`` describe the path projected onto the XY plane as the line
    ``y = m * x + c``. Both raise ZeroDivisionError when ``vel.x`` is 0.
    """

    pos: Vec  # position at time t = 0
    vel: Vec  # change in position per unit of time

    # No functools.cache on these properties: a cache on a method keys on
    # ``self`` and keeps every Hailstone alive for the life of the process.
    # The cache lookup (which hashes the whole dataclass) is also no cheaper
    # than the one or two float operations it would save.

    @property
    def m(self) -> float:
        """Slope (dy/dx) of the XY-projected path."""
        return self.vel.y / self.vel.x

    @property
    def c(self) -> float:
        """Y-intercept of the XY-projected path (its y value at x = 0)."""
        # Operand order kept exactly so the float result is unchanged.
        return self.pos.y - self.pos.x / self.vel.x * self.vel.y

    def __str__(self) -> str:
        """Render in the input's ``px, py, pz @ vx, vy, vz`` layout."""
        return f"{self.pos} @ {self.vel}"


def _parse_vec(text: str) -> Vec:
    """Parse a comma-separated ``x, y, z`` triple into a Vec.

    Surrounding and interior whitespace is tolerated because float() strips it.
    """
    return Vec(*(float(part) for part in text.split(",")))


# One hailstone per non-blank input line, formatted ``px, py, pz @ vx, vy, vz``.
inp: list[Hailstone] = []
for line in stdin:
    line = line.strip()
    if not line:
        # Skip blank lines (e.g. a trailing newline in pasted input) instead
        # of failing the two-way unpack below.
        continue
    pos_text, vel_text = line.split("@")
    inp.append(Hailstone(_parse_vec(pos_text), _parse_vec(vel_text)))


def intersection(a: Hailstone, b: Hailstone) -> tuple[float, float] | None:
    """Return the (x, y) point where the XY-plane paths of *a* and *b* cross.

    Each path is treated as the line ``y = m * x + c``. Returns ``None`` when
    the slopes are equal (parallel or coincident paths), because there is no
    single crossing point. Time is not considered here: the point may lie in
    either stone's past.
    """
    slope_gap = a.m - b.m
    if not slope_gap:
        return None
    cross_x = (b.c - a.c) / slope_gap
    cross_y = a.m * cross_x + a.c
    return cross_x, cross_y


# Part 1 test area: a crossing counts only if both its x and y fall within
# [lbound, ubound], inclusive.
lbound = 200000000000000
ubound = 400000000000000


def _in_test_area(px: float, py: float) -> bool:
    """Return True if (px, py) lies inside the inclusive part 1 test area."""
    return lbound <= px <= ubound and lbound <= py <= ubound


p1 = 0
for a, b in combinations(inp, 2):
    crossing = intersection(a, b)
    if crossing is None:
        # Equal slopes: there is no single crossing point to test.
        continue
    ix, iy = crossing
    # Time at which each stone's x coordinate reaches ix. A negative time
    # means the paths crossed in that stone's past.
    a_time = (ix - a.pos.x) / a.vel.x
    b_time = (ix - b.pos.x) / b.vel.x
    both_ahead = a_time >= 0 and b_time >= 0
    if both_ahead and _in_test_area(ix, iy):
        p1 += 1
print(p1)

# Part 2: solve for an integer rock start (x, y, z) and velocity (a, b, c)
# that coincides with each of the first four hailstones at some integer time.
# There is one time variable per hailstone: t, u, v and w.
if len(inp) < 4:
    raise SystemExit("part 2 needs at least 4 hailstones in the input")

a, b, c, x, y, z = (z3.Int(name) for name in "abcxyz")
s = z3.Solver()
for time_name, h in zip("tuvw", inp):
    t = z3.Int(time_name)
    # Coordinates were parsed as floats, but these are z3 Int terms, so pass
    # integer constants explicitly.
    # NOTE(review): assumes the input coordinates are whole numbers, since
    # int() truncates any fractional part.
    px, py, pz = int(h.pos.x), int(h.pos.y), int(h.pos.z)
    vx, vy, vz = int(h.vel.x), int(h.vel.y), int(h.vel.z)
    s.add(px + vx * t == x + a * t)
    s.add(py + vy * t == y + b * t)
    s.add(pz + vz * t == z + c * t)

result = s.check()
if result != z3.sat:
    # A model is only meaningful after a sat result; fail with a clear
    # message instead of an opaque error from s.model().
    raise SystemExit(f"part 2: z3 returned {result}, no trajectory found")
print(s.model().eval(x + y + z))