summaryrefslogtreecommitdiffstats
path: root/9.py
blob: 1e9270b61f07ed663e8a7611e2099cda922bf0a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from dataclasses import dataclass

@dataclass(frozen=True)
class Point:
    """Immutable 2-D integer grid coordinate, also used as a displacement.

    ``frozen=True`` makes instances hashable, so they can be stored in sets.
    """

    x: int
    y: int

    def __add__(self, other):
        """Return the component-wise sum ``self + other``."""
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        """Return the component-wise difference ``self - other``."""
        return Point(self.x - other.x, self.y - other.y)

    def norm(self):
        """Clamp each component to the range [-1, 1].

        For integer coordinates this is the per-axis sign. It turns an
        arbitrary offset into a step of at most one cell along each axis.
        """
        return Point(min(max(self.x, -1), 1), min(max(self.y, -1), 1))

def pabs(p):
    """Return the Chebyshev (chessboard) length of ``p``: max(|x|, |y|)."""
    return max(map(abs, (p.x, p.y)))
# Unit step for each direction letter that can appear in the input.
delta = {
    'U': Point(0, 1), 'D': Point(0, -1), 'L': Point(-1, 0), 'R': Point(1, 0),
}

# Each input line is "<direction> <count>" (e.g. "R 4"). Parse every line
# into a (unit step, repeat count) pair. Blank or whitespace-only lines are
# skipped; without this check they would crash the two-name unpack.
with open('9.in', encoding='utf-8') as f:
    inp = [
        (delta[d], int(q))
        for d, q in (line.split() for line in f if line.strip())
    ]
def solve(tails, moves=None):
    """Simulate a rope and count the distinct cells its last knot visits.

    The rope is one head plus ``tails`` following knots, all starting at the
    origin. For each ``(step, count)`` pair the head moves by ``step`` a total
    of ``count`` times. After every single head step, each following knot
    checks its distance to its predecessor. If it is more than one cell away
    (Chebyshev distance), it moves one cell toward the predecessor,
    diagonally when both axes differ.

    Args:
        tails: number of knots trailing the head; must be >= 0.
        moves: iterable of ``(Point step, int count)`` pairs. Defaults to the
            module-level ``inp`` parsed from the input file.

    Returns:
        The number of distinct positions occupied by the last knot,
        including its starting position.
    """
    if moves is None:
        moves = inp
    knots = [Point(0, 0) for _ in range(tails + 1)]
    visited = {knots[-1]}
    for move, count in moves:
        for _ in range(count):
            knots[0] += move
            # Index loop on purpose: knot i reads the already-updated i-1.
            for i in range(1, len(knots)):
                diff = knots[i - 1] - knots[i]
                # Only a knot that has lost contact follows. norm() clamps
                # each axis to -1/0/1, so it moves exactly one cell.
                if pabs(diff) > 1:
                    knots[i] += diff.norm()
            visited.add(knots[-1])
    return len(visited)

# Print the answer for a rope with 1 trailing knot, then for one with 9.
for tail_count in (1, 9):
    print(solve(tail_count))