#!/usr/bin/env python
"""
Optimize a data-processing pipeline specified in config-file using
principles from statistical Design of Experiments (DoE).
"""
import argparse
import os
import sys
import logging
import pandas as pd

from doepipeline.executor import LocalPipelineExecutor, SlurmPipelineExecutor
from doepipeline.generator import PipelineGenerator


## todo: recovery-flag.

class Range:
    """
    Closed numeric interval usable as an argparse ``choices`` entry.

    argparse validates a converted value with ``value in choices``, which
    relies on ``==``. An instance therefore compares equal to every value
    lying within ``[start, end]`` (both ends included).
    """
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __eq__(self, other):
        try:
            return self.start <= other <= self.end
        except TypeError:
            # `other` cannot be ordered against the bounds (e.g. a str or
            # None). Defer to Python's default comparison, which evaluates
            # to "not equal", instead of crashing the membership test.
            return NotImplemented

    def __repr__(self):
        # Rendered as "start-end", e.g. "1-100".
        return '{0}-{1}'.format(self.start, self.end)


def reduction_argument(value):
    """
    argparse ``type`` converter for the ``--screening_reduction`` option.

    :param value: raw command-line value.
    :return: the string ``'auto'`` unchanged, otherwise ``value`` converted
        to an integer.
    :raises argparse.ArgumentTypeError: if ``value`` is neither ``'auto'``
        nor an integer >= 1.
    """
    if value == 'auto':
        return value

    err_msg = '--screening_reduction must be "auto" or a positive integer'
    try:
        value = int(value)
    except (TypeError, ValueError):
        # int() raises ValueError for non-numeric strings and TypeError for
        # unsupported types; both must produce the explanatory message.
        raise argparse.ArgumentTypeError(err_msg)
    if value < 1:
        raise argparse.ArgumentTypeError(err_msg)
    return value


def recover_last_iteration(basedir, maxiter):
    """
    Find the last iteration that has a directory below ``basedir``.

    Iteration directories are looked up as ``basedir/1``, ``basedir/2``, ...
    and counted until the first missing one, or until ``maxiter``
    directories have been found.

    :param basedir: directory holding one sub-directory per iteration.
    :param int maxiter: highest iteration number to look for.
    :return: number of the last consecutively existing iteration directory,
        or 0 if not even ``basedir/1`` exists. Always an integer.
    :rtype: int
    """
    iteration = 0
    while iteration < maxiter:
        candidate = os.path.join(basedir, str(iteration + 1))
        if not os.path.isdir(candidate):
            break
        iteration += 1
    return iteration


def make_parser():
    """ Create and config argument-parser.

    The module docstring is used as the program description.

    :return: prepared parser.
    :rtype: argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('config', help='pipeline config-file')
    parser.add_argument('-e', '--execution', default='serial',
                        help=('how to execute steps in pipeline '
                              '(default serial)'),
                        choices=('serial', 'slurm', 'parallel'))

    # A Range compares equal to any number inside its interval, so a
    # one-element choices list restricts the option to that interval.
    parser.add_argument('-i', '--maxiter', default=10, type=int,
                        choices=[Range(1, 100)],
                        help=('maximum number of iterations (default: 10)'))

    parser.add_argument('--tol', type=float, default=.25, choices=[Range(0, 1)],
                        help=('maximum allowed relative distance between '
                              'predicted optimum to edge of experimental '
                              'design in order to consider optimization '
                              'converged (default: 0.25)'))
    parser.add_argument('--skip_screening', action='store_true',
                        help=('If this flag is set, no screening will be run '
                              'spanning all values from factor mins to factor '
                              'maxes.'))
    parser.add_argument('--reevaluate_screening', action='store_true',
                        help=('If this flag is set and high or low-limits '
                              'for response(s) are set, the results from '
                              'screening will be re-evaluated if algorithm '
                              'converges at an optimum which does not '
                              'fulfill the requirements.'))
    parser.add_argument('--screening_reduction', default='auto',
                        type=reduction_argument,
                        help='Reduction factor for GSD-screening (default auto)')

    parser.add_argument('--degree', type=int, choices=[Range(1, 3)],
                        default=2, help='degree of polynomial fitted during '
                                        'optimization (default: 2)')
    parser.add_argument('-o', '--output', default='optimal_parameters.csv',
                        help='output file (default: optimal_parameters.csv)')
    # Without a log file, logging.basicConfig() attaches a stream handler,
    # which writes to standard error.
    parser.add_argument('-l', '--log', default=None,
                        help='log file (default: STDERR)')
    parser.add_argument('-d', '--debug', action='store_true',
                        help='if set, logging will be set to debug')
    parser.add_argument('-m', '--model_selection_method', default='brute',
                        choices=['brute', 'greedy'],
                        help=('The method of selecting the best model in each '
                              'iteration. It\'s recommended to use '
                              '"-m greedy" when fitting a quadratic model and '
                              'factors are more than 3 because of the steep '
                              'running time increase.'))
    # "%%" is required: argparse %-formats help strings.
    parser.add_argument('-s', '--shrinkage', type=float, choices=[Range(0.9, 1)],
                        default=1.0,
                        help=('The span between high and low settings for '
                              'numeric/ordinal factors can be decreased '
                              'between iterations. The shrinkage specifies '
                              'the ratio that is retained of the previous '
                              'span - i.e. "-s 0.9" will decrease the span by '
                              '10%% between iterations. Default is no '
                              'shrinkage (1.0).'))
    parser.add_argument('-q', '--q2_limit', type=float, choices=[Range(0, 1)],
                        default=0.5,
                        help=('The Q2 value (predictive power) of a model must '
                              'be greater than this value in order to continue '
                              'the optimization process (default: 0.5)'))
    parser.add_argument('--gsd_span_ratio', type=float, choices=[Range(0, 1)],
                        default=0.5,
                        help=('After running screening, the factor highs and '
                              'lows for the coming optimization iteration are '
                              'placed around the point in the screening design '
                              'that produced the best result. This option sets '
                              'the distance between the new highs and lows '
                              '(span) to a ratio of the span between nearest '
                              'points of the best screening point '
                              '(default 0.5).'))
    parser.add_argument('--recover', action='store_true',
                        help=('If set, will try to recover a previous run. '
                              'IMPORTANT: You MUST run doepipeline with the '
                              'exact same parameters and yaml-file as with the '
                              'run you are trying to recover. There are no '
                              'internal checks for this, so use at own '
                              'discretion.'))

    # TODO: Flag for setting maximum number of threads when running in
    # parallel mode. As of now, it will start all jobs regardless of how
    # many threads are available.

    return parser


if __name__ == '__main__':
    # Script entry point: parse arguments, build designer and executor from
    # the config, then iterate design -> run -> evaluate until convergence,
    # a worse iteration, or args.maxiter iterations.
    parser = make_parser()
    args = parser.parse_args()
    # One-shot flag: previous state is restored only on the first pass
    # through the loop below, after which the flag is cleared.
    recovering = args.recover

    if args.debug:
        log_format = ('[%(filename)s:%(lineno)s - %(funcName)20s() ] '
                      '%(asctime)s %(message)s')
    else:
        log_format = '%(asctime)s %(levelname)-8s: %(message)s'
    logging.basicConfig(
        format=log_format, filename=args.log,
        level=logging.INFO if not args.debug else logging.DEBUG
    )

    if recovering:
        logging.info('### Attempting to recover a previous run. ###')

    try:
        logging.info('Reads config: {}'.format(args.config))
        generator = PipelineGenerator.from_yaml(args.config)
    except ValueError as e:
        logging.critical('Failed to read config.')
        sys.exit(str(e))

    logging.info('Initialize designer.')
    # NOTE(review): args.degree and args.gsd_span_ratio are parsed but not
    # used anywhere in this script - confirm whether they should be
    # forwarded to the designer.
    designer = generator.new_designer_from_config(
        skip_screening=args.skip_screening,
        gsd_reduction=args.screening_reduction,
        model_selection=args.model_selection_method,
        shrinkage=args.shrinkage,
        q2_limit=args.q2_limit)

    if args.execution == 'slurm':
        executor_class = SlurmPipelineExecutor
    elif args.execution == 'serial':
        executor_class = LocalPipelineExecutor
    elif args.execution == 'parallel':
        # Same local executor, constructed with run_serial=False.
        executor_class = lambda *a, **kw: LocalPipelineExecutor(
            *a, run_serial=False, **kw)
    else:
        sys.exit('Unknown executor: {}'.format(args.execution))

    n_iter = 0
    old_optimum = None
    while n_iter < args.maxiter:
        n_iter += 1

        if recovering:
            basedir = generator.get_base_directory()
            n_iter = recover_last_iteration(basedir, args.maxiter)
            if n_iter == 0:
                # not even the first iteration had begun, nothing to recover.
                n_iter = 1
            logging.info('Previously, the iteration being '
                         'executed was {}'.format(n_iter))
            if n_iter > 1:
                # no need to do anything if n_iter==1, but otherwise we
                # need to update stuff:

                # update factors to correspond to last iteration
                # update best run
                generator.set_current_iteration(n_iter)
                designer.set_phase('optimization')

                # Recover factor settings from the iteration that failed,
                # and update designer accordingly
                iterdir = os.path.join(basedir, str(n_iter))
                factors_csv_file = os.path.join(iterdir, 'factor_settings.csv')
                logging.info('Fetching factor settings '
                             'from iter {}'.format(n_iter))
                if os.path.isfile(factors_csv_file):
                    designer.update_factors_from_csv(factors_csv_file)
                else:
                    sys.exit('Missing factor_settings.csv '
                             'file in {}'.format(iterdir))

                # Recover the best result obtained previously
                # The design and results from the last completed iteration
                last_iterdir = os.path.join(basedir, str(n_iter-1))
                logging.debug('Fetching design and results from last '
                              'completed iteration (iter {})'.format(n_iter-1))
                # DataFrame.from_csv was removed in pandas 1.0; this is its
                # documented read_csv equivalent (same defaults).
                previous_design = pd.read_csv(
                    os.path.join(last_iterdir, 'design.csv'),
                    index_col=0, parse_dates=True)
                previous_results = pd.read_csv(
                    os.path.join(last_iterdir, 'results.csv'),
                    index_col=0, parse_dates=True)

                # The previously best result. Update designer to be aware of it
                best_results = designer.get_best_experiment(previous_design,
                                                            previous_results)
                logging.debug('The best experiment in '
                              'iter {} was {}'.format(n_iter-1, best_results))
                designer.new_design()  # set up design_sheet within designer

                # Recreate old_optimum
                old_optimum = designer.update_factors_from_optimum(best_results, recovery=True)

            recovering = False

        if not args.skip_screening and n_iter == 1:
            logging.info('Starts iteration {} (screening).'.format(n_iter))
        else:
            logging.info('Starts iteration {} (optimization).'.format(n_iter))

        executor = executor_class(base_command='{script}', recovery_mode=args.recover)

        logging.info('Sets up new design.')
        design = designer.new_design()
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):
            logging.info('The following experiments will be run:\n{}'.format(design))

        logging.info('Preparing pipeline with design parameters.')
        pipeline = generator.new_pipeline_collection(design)

        # Save factor settings used in this iteration
        iter_dir = pipeline['WORKDIR']
        factor_csv_file = os.path.join(iter_dir, 'factor_settings.csv')
        if not os.path.isdir(iter_dir):
            executor.make_dir(iter_dir)
        designer.write_factor_csv(factor_csv_file)

        logging.info('Start execution of pipeline.')
        results = executor.run_pipeline_collection(pipeline)

        # Design and results side by side, one row per experiment.
        exp_sheet_complete = pd.concat([design, results], axis=1)
        exp_sheet_complete.index.name = 'Exp'
        exp_sheet_outfile = os.path.join(
            pipeline['WORKDIR'], 'complete_experimental_sheet.csv')
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):
            logging.info('The design sheet with results:\n{}'.format(exp_sheet_complete))
        logging.info('Saving the design sheet and result in {}'.format(exp_sheet_outfile))
        exp_sheet_complete.to_csv(exp_sheet_outfile)

        logging.info('Execution of iteration {} finished.'.format(n_iter))
        optimum = designer.get_optimal_settings(results)
        if n_iter < 2 and not args.skip_screening:
            # First iteration with screening enabled: store its design and
            # results and take its best experiment as the current best.
            best_results = designer.get_best_experiment(design, results)
            design.to_csv(os.path.join(pipeline['WORKDIR'], 'design.csv'))
            results.to_csv(os.path.join(pipeline['WORKDIR'], 'results.csv'))

        if not optimum.empirically_found:
            # If the optimum was predicted from a model
            if not optimum.predicted_optimum.isnull().all():
                # If it was possible to obtain a prediction, run a validation
                # experiment with the predicted optimal settings
                logging.info(
                    'Will run a validation experiment using the predicted '
                    'optimal settings.')
                validation_experiment = pd.DataFrame(
                    optimum.predicted_optimum,
                    columns=['validation']).transpose()
                validation_pipeline = generator.new_pipeline_collection(validation_experiment,
                                                                        validation_run=True)
                validation_result = executor.run_pipeline_collection(validation_pipeline)

                logging.info('Done with execution of the validation '
                             'experiment. The result was:\n{}'.format(validation_result))

                if validation_result.shape[1] > 1:
                    combined_response, _ = designer.treat_response(validation_result,
                                                                   perform_transform=False)
                    logging.info('The combined response was:\n{}'.format(combined_response))

                # Append the validation run to this iteration's sheets.
                results = pd.concat([results, validation_result])
                design = pd.concat([design, validation_experiment])

            # Find the best experiment (incl. possible validation experiment)
            optimal_experiment = designer.get_best_experiment(design, results)
            design.to_csv(os.path.join(pipeline['WORKDIR'], 'design.csv'))
            results.to_csv(os.path.join(pipeline['WORKDIR'], 'results.csv'))

            if optimal_experiment['new_best']:
                logging.info('Found a new best response among the experiments in this iteration.')
                best_results = optimal_experiment

                if optimal_experiment['factor_settings'].name == 'validation':
                    logging.info('It was found using the predicted optimal settings.')
                else:
                    logging.info('It was not found using the predicted optimal settings.')

                # Update the factor settings based on the best experiment
                optimum = designer.update_factors_from_optimum(optimal_experiment, tol=args.tol)

            else:
                logging.info('Did not find a better response in this '
                             'iteration compared to the previous one.')
                logging.info('Rolling back to the best result in previous iteration.')
                # NOTE(review): old_optimum is still None if this branch is
                # reached before any iteration has stored an optimum -
                # confirm that cannot happen.
                optimum = old_optimum

                if not optimum.reached_limits:
                    logging.warning('Response limits were not reached.')
                    if not args.skip_screening and args.reevaluate_screening:
                        logging.info(('Re-evaluate screening results to optimize '
                        'from next best screening results.'))
                        designer.reevaluate_screening()
                        continue

                # Since the best response in this iteration was worse than before, stop
                break

        old_optimum = optimum

        if optimum.converged:
            logging.info('Convergence reached after {} iterations.'.format(n_iter))

            if not optimum.reached_limits:
                logging.warning('Response limits not reached.')
                if not args.skip_screening and args.reevaluate_screening:
                    logging.info(('Re-evaluate screening results to optimize '
                                  'from next best screening results.'))
                    designer.reevaluate_screening()
                    continue

            logging.info('Saving best settings to {}.'.format(args.output))
            optimum.predicted_optimum.to_csv(args.output)

            break

    if not optimum.converged:
        logging.info('Failed to converge to optimum in {} iterations'.format(n_iter))
        # Insert ".unconverged" before the file extension. splitext ignores
        # dots in directory names and leaves no trailing dot when the output
        # name has no extension.
        prefix, suffix = os.path.splitext(args.output)
        outpath = '{}.unconverged{}'.format(prefix, suffix)

        logging.info('Saves unconverged results to {}.'.format(outpath))
        optimum.predicted_optimum.to_csv(outpath)

    logging.info('Best response was:\n{}'.format(best_results['response']))
    logging.info('The settings were:\n{}'.format(best_results['factor_settings']))
    if best_results['response'].shape[0] > 1:
        logging.info('The combined response was:\n{}'.format(best_results['weighted_response']))
    logging.info('Shuts down.')
