#!/usr/bin/env python

import argparse
import sys

import normalpyrunner as runner

def describe_image(image_path: str, mute: bool = False):
    """Parse the machine-readable ``parted`` listing of a disk image.

    Runs ``parted -ms '<image>' unit B print free`` through the runner and
    turns the colon-separated records it collects into plain dictionaries.

    Args:
        image_path: Path of the image file handed to ``parted``.
        mute: Forwarded to the runner as both ``mute_out`` and ``mute_err``.

    Returns:
        A dict with two keys. ``"drive"`` holds ``path``, ``size``, ``type``,
        ``logical_sector_size``, ``physical_sector_size`` and
        ``partition_table``. ``"partitions"`` is a list of dicts with
        ``number``, ``start``, ``end``, ``length``, ``type`` and, when a sixth
        field is present, ``flags``. Offsets and sizes are ints, parsed after
        dropping the trailing unit character.

    Raises:
        AssertionError: If the first record is not the ``BYT`` header.
    """
    out_lines = []
    runner.exec(f"parted -ms '{image_path}' unit B print free", True, out_lines=out_lines, mute_out=mute, mute_err=mute)

    # Each record ends with ';' and separates its fields with ':'.
    records = [line.split(";")[0].split(":") for line in out_lines]

    # Raised explicitly instead of via ``assert`` so the validation is not
    # stripped when Python runs with -O; the exception type is unchanged.
    header = records[0][0]
    if header != 'BYT':
        raise AssertionError(f"Unexpected parted output header: {header!r}")

    partitions = []
    # Records 0 and 1 are the header and the drive line. The final captured
    # line is skipped too - presumably a trailing blank line from the runner;
    # TODO confirm, otherwise the last entry is being dropped.
    for fields in records[2:-1]:
        partition = {
            "number": int(fields[0]),
            "start": int(fields[1][:-1]),
            "end": int(fields[2][:-1]),
            "length": int(fields[3][:-1]),
            "type": fields[4],
        }
        if len(fields) > 5:
            # NOTE(review): in parted's machine format field 5 looks like the
            # partition name, with the flag set in field 6 - confirm.
            partition["flags"] = fields[5]
        partitions.append(partition)

    drive = records[1]
    return {
        "drive": {
            "path": drive[0],
            "size": int(drive[1][:-1]),
            "type": drive[2],
            "logical_sector_size": int(drive[3]),
            "physical_sector_size": int(drive[4]),
            "partition_table": drive[5],
        },
        "partitions": partitions,
    }


def check_fix_filesystem(loop_device):
    """Check the filesystem on *loop_device* and report its minimum size.

    ``e2fsck -pf`` is tried first; when the runner raises ``ValueError`` the
    check is repeated with ``e2fsck -y``. Afterwards ``resize2fs -P`` is run
    and its output parsed.

    Args:
        loop_device: Device path interpolated into the commands.

    Returns:
        The integer that follows the last ``:`` in the ``resize2fs -P`` output.
    """
    try:
        runner.exec(f"e2fsck -pf {loop_device}", sudo=True)
    except ValueError:
        # The preen run was rejected; retry answering "yes" to every question.
        runner.exec(f"e2fsck -y {loop_device}", sudo=True)

    estimate = runner.exec(f"resize2fs -P {loop_device}", sudo=True)
    _, _, tail = estimate.rpartition(":")
    return int(tail)


def describe_check_fix_filesystem(loop_device):
    """Check the filesystem on *loop_device* and summarise its geometry.

    ``tune2fs -l`` is run once silently, then the filesystem is checked via
    ``check_fix_filesystem``, then ``tune2fs -l`` is run again. Both runs are
    given the same list, and when a key occurs more than once the last
    occurrence in that list wins.

    Args:
        loop_device: Device path interpolated into the commands.

    Returns:
        A dict with ``block_count`` and ``block_size`` (ints taken from the
        "Block count" and "Block size" lines) and ``min_size`` (the value
        returned by ``check_fix_filesystem``).
    """
    listing = []
    runner.exec(f"tune2fs -l {loop_device}", sudo=True, out_lines=listing, mute_out=True, mute_err=True)

    minimum = check_fix_filesystem(loop_device)

    runner.exec(f"tune2fs -l {loop_device}", sudo=True, out_lines=listing)

    fields = {}
    for entry in listing:
        # "Key:   value" lines split at the first colon; lines without a
        # colon are stored as a key whose value is None.
        name, colon, rest = entry.partition(":")
        fields[name.strip()] = rest.strip() if colon else None

    return {
        "block_count": int(fields["Block count"]),
        "block_size": int(fields["Block size"]),
        "min_size": minimum,
    }


def plug_device(partition, image_path):
    """Attach a partition's byte range of the image to a free loop device.

    Args:
        partition: Partition dict; its ``start`` and ``length`` values are
            passed to ``losetup`` as ``-o`` and ``--sizelimit``. The resulting
            device path is stored back under ``"loop_device"``.
        image_path: Path of the image file backing the loop device.

    Returns:
        The stripped output of ``losetup -f --show``, i.e. the device path.
    """
    start = partition["start"]
    length = partition["length"]

    # The path is single-quoted, as in the parted commands of this file, so
    # that an image path containing spaces is still passed as one argument.
    attached = runner.exec(
        f"losetup -f --show -o {start} --sizelimit {length} '{image_path}'", sudo=True)
    loop_device = attached.strip()
    partition["loop_device"] = loop_device

    return loop_device


def unplug_device(partition):
    """Detach the loop device recorded on *partition*, if there is one.

    Nothing happens when the dict has no ``"loop_device"`` entry or the entry
    is empty; otherwise ``losetup -d`` is run for it.
    """
    device = partition.get("loop_device")
    if not device:
        return

    runner.exec(f"losetup -d {device}", sudo=True)


def shrink_filesystem(partition, image_path):
    """Shrink the filesystem inside *partition* close to its minimum size.

    The partition is attached to a loop device, checked and described, then
    resized with ``resize2fs`` unless it is already minimal. The loop device
    is always detached again, even on failure.

    Args:
        partition: Partition dict; receives the result under ``"filesystem"``.
        image_path: Path of the image file that contains the partition.

    Returns:
        The filesystem descriptor. When nothing was resized it carries
        ``"new_size"`` (the unchanged block count); after a resize it is a
        fresh description carrying ``"old_size"`` (the previous block count).
    """
    plug_device(partition, image_path)

    try:
        device = partition["loop_device"]
        fs_info = describe_check_fix_filesystem(device)

        current_size = fs_info["block_count"]
        target_size = fs_info["min_size"]
        reclaimable = current_size - target_size

        # Keep the largest slack step that is strictly smaller than the
        # reclaimable space; no step qualifies when nothing can be reclaimed.
        target_size += next((step for step in (5000, 1000, 100) if step < reclaimable), 0)

        if target_size < current_size:
            runner.exec(f"resize2fs -p '{device}' {target_size}", sudo=True)
            fs_info = describe_check_fix_filesystem(device)
            fs_info["old_size"] = current_size
        else:
            print("The partition is already of minimal size", file=sys.stderr)
            fs_info["new_size"] = current_size

        partition["filesystem"] = fs_info
        return fs_info

    finally:
        unplug_device(partition)


def shrink_partition(partition, image_path):
    """Shrink one partition's filesystem and re-create the partition to fit.

    Free-space entries are skipped. For any other entry the filesystem is
    shrunk, the partition is removed from the table and created again as
    ``primary`` with the same start and an end that matches the filesystem
    size (``block_count * block_size`` bytes). On success the partition dict
    is updated with ``end``, ``old_length``, ``length`` and ``resized``.

    This is best effort: any exception is reported on stderr, including its
    cause, and the function returns without re-raising.

    Args:
        partition: Partition dict as produced by ``describe_image``.
        image_path: Path of the image file that contains the partition.

    Returns:
        None.
    """
    print(f"Shrinking partition '{partition['number']}'...", file=sys.stderr)

    if partition["type"] == "free":
        print(f"Unable to shrink partition '{partition['type']}'", file=sys.stderr)
        return

    try:
        shrink_filesystem(partition, image_path)

        partno = partition["number"]
        part_type = "primary"
        start = partition["start"]

        # Work out the new geometry before touching the partition table, so a
        # problem here cannot leave the partition removed but not re-created.
        filesystem = partition["filesystem"]
        # The end offset is inclusive, hence the -1.
        new_end = start + filesystem["block_count"] * filesystem["block_size"] - 1

        runner.exec(f"parted -s -a minimal '{image_path}' rm {partno}", sudo=True)
        runner.exec(f"parted -s '{image_path}' unit B mkpart {part_type} {start} {new_end}", sudo=True)

        partition["end"] = new_end
        partition["old_length"] = partition["length"]
        partition["length"] = (new_end + 1) - start

        partition["resized"] = True

        print(f"Partition '{partition['number']}' resized from {partition['old_length']} to {partition['length']}...", file=sys.stderr)

    except Exception as e:
        # Still best effort, but say why it failed instead of hiding the cause.
        print(f"Unable to shrink this partition: {e!r}", file=sys.stderr)
        return


def copy_within_file(filename, old_start, new_start, length, chunk_size=4096):
    """Copy ``length`` bytes inside one file from ``old_start`` to ``new_start``.

    The data is transferred in pieces of at most ``chunk_size`` bytes. When
    the destination lies after the source the pieces are processed from the
    last one to the first, otherwise from the first to the last, so each
    piece is read before a later write to the same region can replace it.

    Args:
        filename: File to modify in place (opened in ``'rb+'`` mode).
        old_start: Offset of the first source byte.
        new_start: Offset of the first destination byte.
        length: Number of bytes to copy.
        chunk_size: Largest number of bytes moved per read/write pair.
    """
    with open(filename, 'rb+') as handle:
        if new_start > old_start:
            # Walk backwards: each step names the end of a piece.
            pieces = (
                (tail - min(chunk_size, tail), min(chunk_size, tail))
                for tail in range(length, 0, -chunk_size)
            )
        else:
            # Walk forwards: each step names the beginning of a piece.
            pieces = (
                (head, min(chunk_size, length - head))
                for head in range(0, length, chunk_size)
            )

        for offset, size in pieces:
            handle.seek(old_start + offset)
            payload = handle.read(size)
            handle.seek(new_start + offset)
            handle.write(payload)


def move_partition(partition, image_path, new_start):
    """Move a partition's data to *new_start* and re-create its table entry.

    The partition is removed from the table, its bytes are copied inside the
    image file, and a ``primary`` partition covering the new range is created.
    The partition dict is updated with the new ``start`` and ``end``.

    Args:
        partition: Partition dict with ``number``, ``start`` and ``length``.
        image_path: Path of the image file that contains the partition.
        new_start: Byte offset at which the partition should begin.
    """
    partno = partition["number"]
    print(f"Moving partition {partno}...", file=sys.stderr)

    part_type = "primary"
    length = partition["length"]

    runner.exec(f"parted -s -a minimal '{image_path}' rm {partno}", sudo=True)

    copy_within_file(image_path, partition["start"], new_start, length)
    partition["start"] = new_start
    # The end offset is inclusive (shrink_partition computes it as
    # start + size - 1 and compact_partitions continues at end + 1), so the
    # last byte of the partition is one before start + length.
    partition["end"] = new_start + length - 1

    runner.exec(f"parted -s '{image_path}' unit B mkpart {part_type} {partition['start']} {partition['end']}", sudo=True)


def compact_partitions(image_path):
    """Close the gaps between consecutive partitions of the image.

    The image is described once, then every non-free entry is compared with
    the non-free entry before it: when more than 1024 bytes separate them,
    the later partition is moved to start right after the earlier one ends.

    Args:
        image_path: Path of the image file to compact.
    """
    last_used = None

    for entry in describe_image(image_path)["partitions"]:
        if entry["type"] == "free":
            continue

        if last_used is not None:
            # First byte after the preceding partition.
            target = last_used["end"] + 1
            if entry["start"] - target > 1024:
                move_partition(entry, image_path, target)

        last_used = entry


def truncate_image_file_if_needed(image_path):
    """Cut the image file off where its trailing free space begins.

    The image is described again; when the last listed entry is of type
    ``free`` the file is truncated to that entry's start offset. Otherwise
    the file is left untouched.

    Args:
        image_path: Path of the image file to truncate.
    """
    tail = describe_image(image_path)["partitions"][-1]
    if tail["type"] != "free":
        return

    filelength = tail["start"]
    print(f"Truncating image file {image_path} to {filelength} bytes...", file=sys.stderr)
    with open(image_path, 'rb+') as image:
        image.truncate(filelength)


def main(image_path: str):
    """Shrink every partition of an image, compact them and trim the file.

    Steps, in order: capture the initial ``parted`` listing, shrink each
    listed partition, close the gaps between partitions, truncate the image
    file if it ends in free space, then print the captured initial listing
    and run ``parted`` once more for the final one.

    Args:
        image_path: Path of the image file to shrink in place.
    """
    # The path is single-quoted, as in the other parted commands of this
    # file, so that an image path containing spaces stays one argument.
    initial_parted = runner.exec(f"parted '{image_path}' print free", sudo=True, mute_out=True, mute_err=True)

    descriptor = describe_image(image_path)

    for partition in descriptor["partitions"]:
        shrink_partition(partition, image_path)

    compact_partitions(image_path)

    truncate_image_file_if_needed(image_path)

    print("Initial partition table:")

    print(initial_parted, file=sys.stdout)

    print("Final partition table:")

    runner.exec(f"parted '{image_path}' print free", sudo=True)


if __name__ == "__main__":
    # Command-line entry point: the single required --img option names the
    # image file that is passed on to main().
    cli = argparse.ArgumentParser(description="Shrinks Raspberry Pi (any?) images")
    cli.add_argument("--img", required=True, help="Image file")

    main(cli.parse_args().img)
