aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/update_chitu.py
blob: cf7fcfe9da440c5d418ec5303099efdcf8f916c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
# Encodes STM32 firmware images so that Chitu motherboards can flash them from an SD card.
# The firmware must be relocated to start at address 0x08008800!

# Copied from Marlin and modified.
# Licensed under GPL-3.0

import os
import struct
import uuid
import sys
import hashlib

def calculate_crc(contents, seed):
    """Fold *contents* into *seed* by XOR-ing it as little-endian 32-bit words.

    Despite the name this is not a CRC: every consecutive 4-byte chunk of
    ``contents`` is read as an unsigned little-endian integer and XOR-ed into
    the running value. ``contents`` must be a bytes-like object whose length is
    a multiple of 4; otherwise ``struct.error`` is raised.

    Returns the updated XOR accumulator.
    """
    checksum = seed
    for (word,) in struct.iter_unpack('<I', contents):
        checksum ^= word
    return checksum

def xor_block(r0, r1, block_number, block_size, file_key):
    """Scramble one block of firmware bytes with the Chitu XOR keystream.

    For every byte index ``i`` in ``range(block_size)`` this stores::

        r1[i] = (((i * i + 0x4BAD * block_number) >> (i % 24)) ^ file_key ^ r0[i]) & 0xFF

    Args:
        r0: indexable source of byte values, at least ``block_size`` long.
        r1: mutable destination buffer (e.g. a ``bytearray``). It may be the
            same object as ``r0``: each index is read before it is written,
            and later iterations only read higher indices.
        block_number: index of this block within the image; it offsets the
            keystream so that every block is scrambled differently.
        block_size: number of bytes to process.
        file_key: per-file integer key. Only its low byte affects the result,
            because the output is masked to 8 bits and XOR is bitwise.

    Returns:
        None; ``r1`` is modified in place.
    """
    # The shift amount cycles through 0..key_length-1 across the block.
    key_length = 0x18
    block_seed = 0x4BAD * block_number

    for i in range(block_size):
        # Note: this is a RIGHT shift of a quadratic sequence by (i mod 24) bits.
        keystream = (i * i + block_seed) >> (i % key_length)
        r1[i] = (keystream ^ file_key ^ r0[i]) & 0xFF


def encode_file(input, output_file, file_length):
    """Encode a raw firmware image into the Chitu SD-card update format.

    Layout of the data written to ``output_file``:
        offset 0:  magic 0x443D2D3F, big-endian 32-bit
        offset 4:  file key, little-endian 32-bit
        offset 8:  XOR checksum of the encoded payload, little-endian 32-bit
        offset 12: encoded payload, zero-padded to a multiple of 0x800 bytes

    Args:
        input: binary file-like object; it is read to the end.
        output_file: binary file-like object opened for writing.
        file_length: unused; kept for interface compatibility.

    Returns:
        None. Progress information is printed to stdout.
    """
    input_file = bytearray(input.read())
    block_size = 0x800

    # The key comes from the MD5 of the unpadded image: the first 8 hex
    # digits of the digest (rendered as a UUID) parsed as an integer.
    file_digest = hashlib.md5(input_file).digest()
    uid_value = uuid.UUID(bytes=file_digest)
    print("Update UUID ", uid_value)
    file_key = int(uid_value.hex[0:8], 16)

    # Initial seed for the running XOR checksum.
    xor_crc = 0xEF3D4323

    # The payload is processed in 0x800-byte blocks, so zero-pad the tail up
    # to the next block boundary. The original code extended with the 3-byte
    # ASCII string b'0x0' instead, which wrote text garbage and could add up
    # to ~3 unnecessary blocks.
    input_file.extend(bytes(-len(input_file) % block_size))

    # write the file header magic
    output_file.write(struct.pack(">I", 0x443D2D3F))

    # write the file_key
    output_file.write(struct.pack("<I", file_key))

    block_count = len(input_file) // block_size
    print("Block Count is ", block_count)
    for block_number in range(block_count):
        block_offset = block_number * block_size
        block_end = block_offset + block_size
        block_array = input_file[block_offset:block_end]
        xor_block(block_array, block_array, block_number, block_size, file_key)
        # Replace the plaintext block with its encoded form.
        input_file[block_offset:block_end] = block_array

        # The checksum covers the encoded bytes, block by block.
        xor_crc = calculate_crc(block_array, xor_crc)

    # write the checksum ahead of the payload
    output_file.write(struct.pack("<I", xor_crc))

    # finally, append the encoded results.
    output_file.write(input_file)

def main():
    """Command-line entry point: ``update_chitu <input_file> <output_file>``.

    Reads the firmware image from ``input_file`` and writes the encoded
    update to ``output_file``. Exits with status 1 when the argument count is
    wrong or the input file does not exist.
    """
    usage = "Usage: update_chitu <input_file> <output_file>"

    if len(sys.argv) != 3:
        print(usage)
        # sys.exit, not the site-provided exit(), which is absent under
        # `python -S` and in some embedded/frozen interpreters.
        sys.exit(1)

    fw, output = sys.argv[1:]

    if not os.path.isfile(fw):
        print(usage)
        print("Firmware file", fw, "does not exist")
        sys.exit(1)

    length = os.path.getsize(fw)

    # Context managers close both files (flushing the output) even if
    # encoding raises part-way through.
    with open(fw, "rb") as firmware, open(output, "wb") as update:
        encode_file(firmware, update, length)

    print("Encoding complete.")

# Run the encoder only when executed as a script, not when imported.
if __name__ == "__main__":
    main()