# Rope simulation: moves a chain of knots across a grid following the steps
# read from '9.in' and prints how many distinct cells the last knot visits.
from dataclasses import dataclass
@dataclass(frozen=True)
class Point:
    """Immutable 2-D integer point/vector with component-wise arithmetic.

    ``frozen=True`` (together with the default ``eq=True``) makes instances
    hashable, so they can be stored in sets and used as dict keys.
    """

    x: int
    y: int

    def __add__(self, other: "Point") -> "Point":
        """Return the component-wise sum of this point and *other*."""
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other: "Point") -> "Point":
        """Return the component-wise difference ``self - other``."""
        return Point(self.x - other.x, self.y - other.y)

    def norm(self) -> "Point":
        """Return a new point with each coordinate clamped to [-1, 1]."""
        return Point(min(max(self.x, -1), 1), min(max(self.y, -1), 1))
def pabs(p: "Point") -> int:
    """Return the larger of the absolute x and y coordinates of *p*.

    For a difference vector this is the Chebyshev (chessboard) distance:
    a value of at most 1 means the two points are the same or touching,
    diagonals included.
    """
    return max(abs(p.x), abs(p.y))
# Unit step applied to the head for each direction letter:
# 'U'/'D' change y by +1/-1, 'L'/'R' change x by -1/+1.
delta = {
    'U': Point(0, 1),
    'D': Point(0, -1),
    'L': Point(-1, 0),
    'R': Point(1, 0),
}
# Parse the input file into (unit step, repeat count) pairs. Each non-blank
# line holds a direction letter and an integer separated by whitespace.
# Blank lines (such as a trailing empty line) are skipped rather than
# crashing the two-value unpacking.
with open('9.in', encoding='utf-8') as f:
    inp = [
        (delta[direction], int(count))
        for direction, count in (line.split() for line in f if line.strip())
    ]
def solve(tails, moves=None):
    """Simulate a chain of ``tails + 1`` knots and count the last knot's cells.

    All knots start at ``Point(0, 0)``. For every single step the first knot
    (the head) is moved; each following knot then moves one clamped step
    (at most 1 per axis) toward its predecessor whenever it is more than one
    cell away from it on either axis.

    Args:
        tails: Number of knots that follow the head.
        moves: Iterable of ``(step, count)`` pairs, where ``step`` is a
            ``Point`` added to the head ``count`` times. Defaults to the
            module-level ``inp`` parsed from the input file.

    Returns:
        The number of distinct positions the last knot occupied, including
        its starting position.
    """
    if moves is None:
        moves = inp

    # Points are immutable, so every slot can safely start as the same object;
    # updates below rebind list entries instead of mutating them.
    knots = [Point(0, 0)] * (tails + 1)
    visited = {knots[-1]}

    for step, count in moves:
        for _ in range(count):
            knots[0] += step
            for i in range(1, len(knots)):
                gap = knots[i - 1] - knots[i]
                # Only follow when no longer touching the previous knot.
                if pabs(gap) > 1:
                    knots[i] += gap.norm()
            visited.add(knots[-1])

    return len(visited)
# Print the result for a head followed by one knot, then by nine knots.
for tails in (1, 9):
    print(solve(tails))