"""Simulate a rope of knots pulled by its head and count cells the last knot visits.

Moves are read from ``9.in``: one ``<direction> <count>`` pair per line, where
direction is one of ``U``, ``D``, ``L``, ``R``. Running the script prints the
number of distinct cells visited by the last knot for a rope with 1 and with
9 trailing knots.
"""

from dataclasses import dataclass


@dataclass(frozen=True)
class Point:
    """Immutable integer 2-D point / vector; frozen, so hashable and usable in sets."""

    x: int
    y: int

    def __add__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return Point(self.x - other.x, self.y - other.y)

    def norm(self):
        """Return this vector with each component clamped to -1..1 (its sign).

        Used as the single step a trailing knot takes toward the knot ahead.
        """
        return Point(min(max(self.x, -1), 1), min(max(self.y, -1), 1))


def pabs(p):
    """Return the Chebyshev length of *p*: ``max(|x|, |y|)``."""
    return max(abs(p.x), abs(p.y))


# Unit step for each direction letter used in the input.
delta = {
    'U': Point(0, 1),
    'D': Point(0, -1),
    'L': Point(-1, 0),
    'R': Point(1, 0),
}


def parse_moves(lines):
    """Parse ``"<dir> <count>"`` lines into ``(step_vector, count)`` pairs.

    Blank or whitespace-only lines are skipped. An unknown direction letter
    raises KeyError; a non-integer count raises ValueError.
    """
    moves = []
    for line in lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines, e.g. a trailing empty line
        direction, count = fields
        moves.append((delta[direction], int(count)))
    return moves


def load_moves(path='9.in'):
    """Read and parse the moves file at *path*."""
    with open(path, encoding='utf-8') as f:
        return parse_moves(f)


def solve(tails, moves=None):
    """Return how many distinct cells the last knot visits.

    Args:
        tails: number of knots following the head (the rope has ``tails + 1``
            knots). ``0`` tracks the head itself.
        moves: ``(step_vector, count)`` pairs as produced by ``parse_moves``.
            If omitted, they are loaded from ``9.in``.

    Raises:
        ValueError: if *tails* is negative.
    """
    if tails < 0:
        raise ValueError(f"tails must be non-negative, got {tails}")
    if moves is None:
        moves = load_moves()

    knots = [Point(0, 0) for _ in range(tails + 1)]
    visited = {knots[-1]}
    for move, count in moves:
        # Move one cell at a time: every intermediate tail position counts.
        for _ in range(count):
            knots[0] += move
            for i in range(1, len(knots)):
                diff = knots[i - 1] - knots[i]
                # A knot only moves once it is no longer touching
                # (including diagonally) the knot ahead of it.
                if pabs(diff) > 1:
                    knots[i] += diff.norm()
            visited.add(knots[-1])
    return len(visited)


def render(visited, knots, radius=20):
    """Return a debug picture of the grid around the origin.

    Rows run from ``y = radius`` down to ``-radius + 1`` and columns from
    ``x = -radius`` to ``radius - 1``. A visited cell is drawn as ``#``, and a
    cell holding a knot shows that knot's index in *knots*, which takes
    precedence over ``#``. All other cells are blank.
    """
    rows = []
    for y in range(radius, -radius, -1):
        row = []
        for x in range(-radius, radius):
            p = Point(x, y)
            ch = '#' if p in visited else ' '
            if p in knots:
                ch = str(knots.index(p))
            row.append(ch)
        rows.append(''.join(row))
    return '\n'.join(rows)


if __name__ == '__main__':
    inp = load_moves()
    print(solve(1, inp))
    print(solve(9, inp))