#!python
# Copyright (C) 2022 Karl Wette
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; see the file COPYING. If not, write to the
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA  02110-1301  USA

## \file
## \ingroup lalpulsar_bin_SFTTools
"""Copy SFTs between directories. The destination directory is organised
following the convention detailed in the SFT spec (T040164)."""

import argparse
import os
import sys
import time
import shutil
from contextlib import contextmanager
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import logging

from lal import LALERRORBIT, LALWARNINGBIT, LALINFOBIT, LALTRACEBIT
from lal import GetDebugLevel, ClobberDebugLevel

from lalpulsar import git_version
from lalpulsar import ValidateSFTFile, SFTErrorMessage
from lalpulsar.public_sft_directory import public_sft_directory
from lalpulsar.public_sft_directory import public_sft_directory_readme_md

# Module metadata; version and date are taken from the lalpulsar git build info
__author__ = "Karl Wette <karl.wette@ligo.org>"
__version__, __date__ = git_version.id, git_version.date

# Configure logging: an INFO-level logger named after the last component of
# this module's name. If no handler is reachable from it yet, attach a stream
# handler, using the coloured formatter from the optional "coloredlogs"
# package when it can be imported, or the plain logging formatter otherwise.
LOGGER = logging.getLogger(__name__.rsplit(".", 1)[-1])
LOGGER.setLevel(logging.INFO)
try:
    from coloredlogs import ColoredFormatter as _Formatter
except ImportError:
    _Formatter = logging.Formatter
if not LOGGER.hasHandlers():
    _LOG_HANDLER = logging.StreamHandler()
    _LOG_HANDLER.setFormatter(
        _Formatter(fmt="[%(asctime)s] %(levelname)+8s: %(message)s")
    )
    LOGGER.addHandler(_LOG_HANDLER)


@contextmanager
def silence_xlal_error_messages():
    """Temporarily clear the LAL error/warning/info/trace debug-level bits.

    On entry the current LAL debug level is saved. It is then replaced by a
    copy with LALERRORBIT, LALWARNINGBIT, LALINFOBIT and LALTRACEBIT cleared.
    On exit the saved level is restored, even if the ``with`` body raises.
    The context value is ``None``.
    """
    original_level = GetDebugLevel()
    message_bits = LALERRORBIT | LALWARNINGBIT | LALINFOBIT | LALTRACEBIT
    ClobberDebugLevel(original_level & ~message_bits)
    try:
        yield None
    finally:
        # always put the caller's debug level back
        ClobberDebugLevel(original_level)


def parse_command_line(argv=None):
    """Parse and check the command-line arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When ``None`` (the default), ``argparse`` reads
        ``sys.argv[1:]``, exactly as before. Passing a list lets the parser
        be driven programmatically (e.g. from tests).

    Returns
    -------
    argparse.Namespace
        With attributes ``processes``, ``force``, ``test``, ``validate``,
        ``readme_md``, ``source_directory`` and ``dest_directory``.

    Exits via ``parser.error`` (usage message, exit status 2) if
    ``--processes`` is not strictly positive, or if either directory
    argument is not an existing directory.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "-p", "--processes", type=int, default=1, help="number of copying processes"
    )
    parser.add_argument(
        "-f", "--force", action="store_true", help="overwrite existing SFTs"
    )
    parser.add_argument(
        "-t", "--test", action="store_true", help="test finding SFTs without copying"
    )
    parser.add_argument(
        "-n",
        "--no-validate",
        dest="validate",
        action="store_false",
        help="do not validate destination SFTs",
    )
    parser.add_argument(
        "--no-readme-md",
        dest="readme_md",
        action="store_false",
        help="do not write README.md in the destination directory",
    )
    parser.add_argument("source_directory", type=str, help="SFT source directory")
    parser.add_argument("dest_directory", type=str, help="SFT destination directory")
    args = parser.parse_args(argv)

    # check arguments; parser.error() prints usage and exits
    if args.processes <= 0:
        parser.error("--processes must be strictly positive")
    if not os.path.isdir(args.source_directory):
        parser.error("source_directory is not a directory")
    if not os.path.isdir(args.dest_directory):
        parser.error("dest_directory is not a directory")

    return args


def find_SFT_files(source_directory, dest_directory, force):
    """Find the SFT files below ``source_directory`` that need to be copied.

    Each file whose name ends in ``.sft``, anywhere below
    ``source_directory``, is mapped to the destination path
    ``<dest_directory>/<public_sft_directory(name)>/<name>``. A file whose
    destination already exists is skipped unless ``force`` is true.
    Progress is logged every 100 SFTs up to 1000, every 1000 up to 10000,
    and so on.

    Returns
    -------
    dest_dirs : set of str
        Destination directories required by the SFTs to be copied.
    src_dest_paths : list of (str, str)
        ``(source path, destination path)`` pair for each SFT to copy.
    """
    dest_dirs = set()
    src_dest_paths = []

    t0 = time.time()
    num_SFTs = 0

    # next SFT count at which to log progress, the increment to the following
    # one, and the count at which that increment grows tenfold
    print_progress = 100
    print_progress_step = 100
    print_progress_max = 1000

    for src_root, _, src_files in os.walk(source_directory):
        for src_name in src_files:
            if not src_name.endswith(".sft"):
                continue
            src_path = os.path.join(src_root, src_name)

            # build SFT destination directory
            dest_dir = os.path.join(dest_directory, public_sft_directory(src_name))
            dest_path = os.path.join(dest_dir, src_name)

            # skip file if it already exists, unless forced to overwrite
            # (checking force first avoids a needless stat when forcing)
            if not force and os.path.isfile(dest_path):
                continue

            dest_dirs.add(dest_dir)
            src_dest_paths.append((src_path, dest_path))

            # num_SFTs never exceeds print_progress, so equality is the
            # same test as the modulo check it replaces
            num_SFTs += 1
            if num_SFTs == print_progress:
                LOGGER.info(
                    "found {n} SFTs in {dt:0.1f} seconds".format(
                        n=num_SFTs,
                        dt=time.time() - t0,
                    )
                )
                print_progress += print_progress_step
                if print_progress == print_progress_max:
                    print_progress_step *= 10
                    print_progress_max *= 10

    LOGGER.info("found {n} SFTs to copy".format(n=num_SFTs))

    return dest_dirs, src_dest_paths


def make_dest_dirs(dest_dirs):
    """Create every directory in ``dest_dirs``, including missing parents.

    Directories that already exist are left alone. ``os.makedirs`` still
    raises ``FileExistsError`` if a path exists but is not a directory.
    """
    LOGGER.info("making {n} directories ...".format(n=len(dest_dirs)))
    for dest_dir in dest_dirs:
        # exist_ok=True replaces the check-then-create isdir() guard, which
        # could race with another process creating the same directory
        os.makedirs(dest_dir, exist_ok=True)
    LOGGER.info("making {n} directories ... done".format(n=len(dest_dirs)))


def copy_SFT_file(src_path, dest_path, validate):
    """Copy one SFT to ``dest_path``, optionally validating it first.

    The SFT is first copied to ``dest_path + "_TO_BE_VALIDATED"``. It is
    moved to ``dest_path`` only after the copy has completed and, if
    ``validate`` is true, ``ValidateSFTFile`` has returned 0 for it.

    Returns
    -------
    None
        On success.
    (str, str)
        ``(temporary path, error message)`` if validation failed. The
        temporary file is left in place for inspection.
    """
    tmp_dest_path = dest_path + "_TO_BE_VALIDATED"
    shutil.copyfile(src_path, tmp_dest_path)

    if validate:
        with silence_xlal_error_messages():
            validate_errorcode = ValidateSFTFile(tmp_dest_path)
        if validate_errorcode != 0:
            return (tmp_dest_path, SFTErrorMessage(validate_errorcode))

    # os.replace overwrites an existing dest_path (e.g. with --force) on all
    # platforms, whereas os.rename raises on Windows if the target exists
    os.replace(tmp_dest_path, dest_path)

    return None


def copy_all_SFT_files(src_dest_paths, validate, processes):
    """Copy SFTs in parallel with a pool of ``processes`` worker processes.

    Parameters
    ----------
    src_dest_paths : list of (str, str)
        ``(source path, destination path)`` pair for each SFT to copy.
    validate : bool
        Whether each copied SFT is validated before being moved into place.
    processes : int
        Number of worker processes.

    Logs every validation failure and exits with status 1 if any copied SFT
    failed validation. Any other exception raised by a worker is re-raised
    here by ``task.result()``.
    """
    validate_errors = []

    LOGGER.info("copying {n} SFTs ...".format(n=len(src_dest_paths)))
    # use the ``processes`` argument, not the script-level ``args`` global,
    # so this works when the module is imported and honours the caller
    with ProcessPoolExecutor(max_workers=processes) as executor:
        # submit one copy task per SFT
        pool = [
            executor.submit(copy_SFT_file, src_path, dest_path, validate)
            for src_path, dest_path in src_dest_paths
        ]

        # collect results as they finish, with a progress bar
        for task in tqdm(as_completed(pool), total=len(pool)):
            validate_error = task.result()
            if validate_error is not None:
                validate_errors.append(validate_error)

    # show any validation errors
    if validate_errors:
        LOGGER.critical(
            "failed to validate {n} SFTs after copying:".format(n=len(validate_errors))
        )
        for tmp_dest_path, validate_errorstr in validate_errors:
            LOGGER.critical("  {p}".format(p=tmp_dest_path))
            LOGGER.critical("    {e}".format(e=validate_errorstr))
        sys.exit(1)

    LOGGER.info("copying {n} SFTs ... done".format(n=len(src_dest_paths)))


def write_readme_md(dest_directory):
    """Write ``README.md`` into ``dest_directory``.

    The contents come from ``public_sft_directory_readme_md()``. An existing
    ``README.md`` is overwritten.
    """
    # explicit UTF-8 so the Markdown file does not depend on the locale's
    # default text encoding
    with open(
        os.path.join(dest_directory, "README.md"), "w", encoding="utf-8"
    ) as f:
        f.write(public_sft_directory_readme_md())


if __name__ == "__main__":
    args = parse_command_line()

    # Locate the SFTs to copy and create the destination directories they need.
    # NOTE(review): directories are created even in --test mode; confirm that
    # this is intended for a dry run.
    dest_dirs, src_dest_paths = find_SFT_files(
        source_directory=args.source_directory,
        dest_directory=args.dest_directory,
        force=args.force,
    )
    make_dest_dirs(dest_dirs)

    if not args.test:
        # copy (and optionally validate) the SFTs, then describe the layout
        copy_all_SFT_files(
            src_dest_paths, validate=args.validate, processes=args.processes
        )
        if args.readme_md:
            write_readme_md(args.dest_directory)
        LOGGER.info("DONE")
    else:
        LOGGER.info("TESTING, not copying SFTs")
        sys.exit(0)
