from kaitaistruct import KaitaiStream


class Record:
    """Base t1k record: a 6-byte magic followed by a payload.

    Subclasses set ``magic``, ``valid_index``, ``nbyte`` and ``name``; those
    whose payload is not a flat run of bytes also override ``read``.
    """

    def __init__(self, io=None):
        # Stream the payload is read from; T1k builds io-less instances
        # purely for magic matching.
        self._io = io
        self.name = "unknown"
        self.magic = []
        # Positions of ``magic`` that match() compares (all six by default).
        self.valid_index = [0, 1, 2, 3, 4, 5]
        # Number of payload bytes consumed by the default read().
        self.nbyte = 0
        self.data = []

    def match(self, magic):
        """Return True when *magic* agrees with ``self.magic`` at every
        position listed in ``valid_index`` (checked in order)."""
        return not any(self.magic[idx] != magic[idx]
                       for idx in self.valid_index)

    def read(self, magic):
        """Append ``nbyte`` unsigned bytes from the stream to ``data``.

        *magic* is accepted for interface compatibility and unused here.
        """
        for _ in range(self.nbyte):
            self.data.append(self._io.read_u1())

    def __str__(self):
        # Comma-separated hex dump of the magic followed by the payload.
        return ",".join(hex(value) for value in self.magic + self.data)


class Record_0x0b(Record):
    """Record identified solely by a leading 0x0b magic byte.

    Its six payload bytes are read raw by the base ``Record.read``; the
    magic bytes actually seen in the stream are remembered in ``head``.
    """

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "0x0b"
        self.nbyte = 6
        self.magic = [0x0b, 0x0, 0x0, 0x5, 0x23, 0x24]
        # Only byte 0 is compared by match().
        self.valid_index = [0]
        self.head = None

    def read(self, magic):
        super().read(magic)
        self.head = magic

    def __str__(self):
        # Hex dump of the observed magic (if read) followed by the payload.
        parts = [] if self.head is None else [hex(b) for b in self.head]
        parts += [hex(b) for b in self.data]
        return ",".join(parts)


class RecordStHits(Record):
    """Variable-length hits record.

    Layout after the 6-byte magic, as read below:

    * u4le ``start``  - first position index
    * u1   ``length`` - number of consecutive positions
    * per position: u1 hit count, then that many (u1 depth, u2le channels)

    ``data`` becomes a list of ``(pos, top_rail, depth, channels)`` tuples.
    """

    def __init__(self, io=None):
        super(RecordStHits, self).__init__(io)
        self.magic = [0x0, 0x0, 0x0, 0x5, 0x23, 0x24]
        # match() only compares bytes 2-4; bytes 0, 1 and 5 are not checked.
        self.valid_index = [2, 3, 4]
        self.name = "st_hits"
        # Filled in by read(); initialised so head_str()/__str__ also work
        # on an instance that has not been read yet.
        self.head = None
        self.top_rail = None
        self.start = None
        self.length = None

    def read(self, magic):
        self.read_head(magic)
        self.read_data()

    def read_head(self, magic):
        """Record the magic, derive the rail flag and read start/length."""
        # Copy so the caller's list is not extended with start/length.
        self.head = list(magic)
        # Magic byte 5 selects the rail flag: 0x64 -> 0, anything else -> 1.
        self.top_rail = 0 if self.head[5] == 0x64 else 1
        self.start = self._io.read_u4le()
        self.length = self._io.read_u1()
        self.head.append(self.start)
        self.head.append(self.length)

    def read_data(self):
        """Read the per-position hit lists into ``data``."""
        self.data = []
        for pos in range(self.start, self.start + self.length):
            n_hits = self._io.read_u1()
            for _ in range(n_hits):
                depth = self._io.read_u1()
                channels = self._io.read_u2le()
                self.data.append((pos, self.top_rail, depth, channels))

    def head_str(self):
        """Hex dump of the magic plus start/length; empty before read()."""
        if self.head is None:
            return ""
        return ",".join(hex(x) for x in self.head)

    def __str__(self):
        # One CSV line per hit. Each tuple has four fields, so unpack all
        # four (the old 3-name unpack raised ValueError).
        return "\n".join(
            f"{pos},{top_rail},{depth},{channels}"
            for pos, top_rail, depth, channels in self.data)


class RecordUserIcons(Record):
    """User-icon record whose payload is a single key-code byte."""

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "user_icons"
        self.nbyte = 1
        self.magic = [0x6, 0x0, 0x0, 0x20, 0x22, 0x27]

    def read(self, magic):
        # The payload is one u1 key code, stored as a one-element list.
        self.data = [self._io.read_u1()]


class RecordPositions(Record):
    """Positions record: two little-endian u32 values (8 payload bytes)."""

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "positions"
        self.nbyte = 8
        self.magic = [0xd, 0x0, 0x0, 0x49, 0x22, 0x21]

    def read(self, magic):
        # Tuple elements evaluate left to right: bus position, then position.
        bus_pos, position = self._io.read_u4le(), self._io.read_u4le()
        self.data = [bus_pos, position]


class RecordChainage(Record):
    """Chainage record (13 payload bytes).

    Payload: one skipped byte, then f4le post, u4le bus position and
    f4le velocity.
    """

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "chainage"
        self.nbyte = 13
        self.magic = [0x12, 0x0, 0x0, 0x9, 0x22, 0x21]

    def read(self, magic):
        stream = self._io
        stream.read_u1()  # leading byte of unknown meaning; discarded
        fields = (stream.read_f4le(),   # post
                  stream.read_u4le(),   # bus position
                  stream.read_f4le())   # velocity
        self.data = list(fields)


class RecordUltraRecognitions(Record):
    """Ultrasonic recognition record (14 payload bytes).

    Only magic bytes 0-1 are compared, which separates it from
    ``Record_0x13_2`` (same byte 0, byte 1 = 0x01).
    """

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "ultra_recognitions"
        self.nbyte = 14
        self.valid_index = [0, 1]
        self.magic = [0x13, 0x0, 0x0, 0x7, 0x21, 0x23]

    def read(self, magic):
        stream = self._io
        # First byte packs the rail flag (top 2 bits) and a 6-bit code.
        packed = stream.read_u1()
        rail, code = packed >> 6, packed & 0x3f
        bus_pos = stream.read_u4le()
        depth_max = stream.read_u1()
        span = stream.read_u2le()
        depth_min = stream.read_u1()
        sev = stream.read_u1()
        tag = stream.read_u4le()
        self.data = [bus_pos, code, rail, depth_max, depth_min,
                     span, sev, tag]


class Record_0x13_2(Record):
    """Second 0x13-prefixed record kind (magic byte 1 == 0x01).

    Only magic bytes 0-1 are compared; the 14 payload bytes are read raw by
    the base ``Record.read``.
    """

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "0x13_2"
        self.nbyte = 14
        self.valid_index = [0, 1]
        self.magic = [0x13, 0x01, 0x0, 0x36, 0x22, 0x21]


class Record_0x17(Record):
    """Record with leading magic byte 0x17 whose 18 payload bytes stay raw.

    All six magic bytes must match (default ``valid_index``); the payload
    is consumed by the base ``Record.read``.
    """

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "0x17"
        self.nbyte = 18
        self.magic = [0x17, 0x0, 0x0, 0x47, 0x22, 0x21]


class RecordLocationGPS(Record):
    """GPS location record (27 payload bytes).

    Magic byte 1 is not compared (``valid_index`` omits it). Coordinates
    are kept as the raw u32 values read; any scaling is not decoded here.
    """

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "location_gps"
        self.nbyte = 27
        self.valid_index = [0, 2, 3, 4, 5]
        self.magic = [0x20, 0x2c, 0x0, 0x9, 0x22, 0x22]

    def read(self, magic):
        stream = self._io
        stream.read_u1()  # skipped byte (meaning not decoded)
        # Five consecutive u4le fields, consumed in order.
        latitude, longitude, gps_source, satellites, bus_pos = (
            stream.read_u4le() for _ in range(5))
        stream.read_u1()  # skipped byte
        post = stream.read_f4le()
        stream.read_u1()  # trailing skipped byte
        self.data = [latitude, longitude, gps_source,
                     satellites, bus_pos, post]


class RecordEnd(Record):
    """Record named ``end`` that carries no payload.

    ``nbyte`` is 0, so the base ``Record.read`` consumes nothing beyond
    the magic.
    """

    def __init__(self, io=None):
        super().__init__(io)
        self.name = "end"
        self.nbyte = 0
        self.magic = [0x05, 0x0, 0x0, 0x03, 0x23, 0x21]


class RecordStart(Record):
    """File-header record whose ``read`` skips the rest of the header.

    ``read`` slides a 6-byte window through the stream until it equals the
    data-start marker. The marker bytes are consumed too, so the stream is
    left positioned just past them.
    """

    def __init__(self, io=None):
        super(RecordStart, self).__init__(io)
        self.magic = [0xfb, 0x0, 0x0, 0x02, 0x22, 0x22]
        # Not used by read(), which scans for the marker instead of reading
        # a fixed size.
        self.nbyte = 0x1470
        self.name = "start"

    def read(self, magic):
        """Advance the stream until the last 6 bytes seen equal the marker.

        *magic* is the initial 6-byte window. It is copied, so the caller's
        list is left untouched (the previous version mutated it in place).
        If the marker is never found, read_u1 presumably raises at end of
        stream; that depends on the stream implementation.
        """
        data_start_magic = [0x00, 0x0c, 0x23, 0x21, 0x11, 0x00]
        window = list(magic)
        while window != data_start_magic:
            window = window[1:] + [self._io.read_u1()]


class T1k:
    """Sequential reader for the records of a ``.t1k`` file.

    The constructor checks the 6-byte start magic and skips the header via
    ``RecordStart.read``. It then remembers in ``offset`` where record data
    begins; ``all_data()`` iterates the records from there.

    The file is opened here. Call ``close()``, or use the instance as a
    context manager, to release it.
    """

    def __init__(self, filename, verbose=False):
        self.f = open(filename, "rb")
        try:
            self._io = KaitaiStream(self.f)
            record_start = RecordStart(self._io)
            start_magic = [self._io.read_u1() for _ in range(6)]
            if start_magic != record_start.magic:
                # Raising a plain string is itself a TypeError in Python 3;
                # raise a real exception instead.
                raise ValueError("Not a t1k file")
            record_start.read(start_magic)
            self.offset = self._io.pos()
        except BaseException:
            # Don't leak the handle when the header is invalid or truncated.
            self.f.close()
            raise
        # NOTE: printed regardless of ``verbose`` (matches existing behaviour).
        print("data start addr:", hex(self.offset))
        self.verbose = verbose
        # Candidate record types, tried in this order by read_record(). Order
        # matters because several types compare only a subset of magic bytes.
        self.record_cls = [
            Record_0x0b,
            RecordStHits,
            RecordUserIcons,
            RecordPositions,
            RecordChainage,
            RecordUltraRecognitions,
            Record_0x13_2,
            Record_0x17,
            RecordLocationGPS,
            RecordEnd
        ]
        # Stream-less prototypes used only to test magics in read_record().
        self.record_objs = [cls() for cls in self.record_cls]

    def close(self):
        """Close the underlying file."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __len__(self):
        # Number of bytes of record data (everything after the header).
        return self._io.size() - self.offset

    def read_record(self, magic):
        """Parse the record whose 6-byte *magic* has just been read.

        Returns a fresh record instance bound to this stream. Raises
        ValueError when no known record type matches *magic*. Previously an
        ``assert`` was used, which silently returned None under ``-O``.
        """
        for cls, prototype in zip(self.record_cls, self.record_objs):
            if prototype.match(magic):
                record = cls(self._io)
                record.read(magic)
                return record
        raise ValueError(
            f"unknown record magic {[hex(b) for b in magic]} "
            f"at {hex(self._io.pos())}")

    def all_data(self):
        """Yield ``(record, nbytes)`` for each record until end of stream.

        ``nbytes`` counts the magic plus payload bytes consumed by that
        record.
        """
        consumed = 0
        while not self._io.is_eof():
            prev_pos = self._io.pos()
            magic = [self._io.read_u1() for _ in range(6)]
            record = self.read_record(magic)
            if record.name == "unknown":
                raise ValueError(
                    f"unnamed record {record} ({type(record)}) "
                    f"at {hex(self._io.pos())}")
            size = self._io.pos() - prev_pos
            consumed += size
            yield record, size
        # Internal invariant: every byte after the header was consumed.
        assert consumed == len(self)


def _collect_records(t1k, show_process):
    """Read every record of *t1k* and group the payloads by record name.

    ``st_hits`` payloads (lists of hit tuples) are flattened into one list.
    Every other record contributes its ``data`` as a single row. When
    *show_process* is true, a tqdm bar tracks the bytes consumed.
    """
    pbar = None
    if show_process:
        import tqdm
        pbar = tqdm.tqdm(total=len(t1k))
    dts = {}
    try:
        for record, nbytes in t1k.all_data():
            rows = dts.setdefault(record.name, [])
            if record.name == "st_hits":
                rows.extend(record.data)
            else:
                rows.append(record.data)
            if pbar is not None:
                pbar.update(nbytes)
    finally:
        if pbar is not None:
            pbar.close()
    return dts


def _write_csvs(dts, dst_dir):
    """Write each ``dts`` table to ``<dst_dir>/<name>.csv``, one row per line.

    Tables whose name contains ``"0x"`` are written as hex values.
    """
    import os
    for name, rows in dts.items():
        fmt = hex if "0x" in name else str
        with open(os.path.join(dst_dir, f"{name}.csv"), "w") as f:
            for row in rows:
                f.write(",".join(map(fmt, row)))
                f.write("\n")


def t1k_parse(filename="/home/ellery/Downloads/805_DF280419_001.t1k",
              verbose=False,
              show_process=True,
              dst_dir="./"):
    """Parse a .t1k file and dump each record type to a CSV file.

    Output goes to ``<dst_dir>/<stem of filename>/<record name>.csv``;
    missing directories are created. ``st_hits`` rows are sorted by
    position, then channels (descending), depth, and top_rail (descending).

    Args:
        filename: path of the .t1k file (the default is a developer-local
            path).
        verbose: forwarded to ``T1k``; also disables the progress bar.
        show_process: show a tqdm progress bar while reading.
        dst_dir: parent directory for the per-file output folder.
    """
    import os
    from pathlib import PurePath

    if verbose:
        show_process = False
    t1k = T1k(filename, verbose)
    try:
        dts = _collect_records(t1k, show_process)
    finally:
        # T1k opens the file itself; release it once all records are read.
        t1k.f.close()

    # A file may contain no st_hits records at all; don't KeyError on it.
    if "st_hits" in dts:
        dts["st_hits"].sort(
            key=lambda hit: (hit[0], -hit[3], hit[2], -hit[1]))

    dst_dir = os.path.join(dst_dir, PurePath(filename).stem)
    os.makedirs(dst_dir, exist_ok=True)
    _write_csvs(dts, dst_dir)


if __name__ == "__main__":
    # Command-line entry point: python-fire maps CLI arguments onto
    # t1k_parse's parameters.
    from fire import Fire
    Fire(t1k_parse)
