#!/usr/bin/env python
import os
import argparse
from collections import Counter
from copy import deepcopy
import yaml
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.table import Table
from astropy.time import Time
import galsim
from romanisim import wcs, persistence, log, parameters, util
from romanisim.parameters import default_parameters_dictionary
from romanisim import ris_make_utils as ris


def _build_parser():
    """Return the argument parser for the stack-making command line."""
    parser = argparse.ArgumentParser(
        description='Make a stack of demo images.',
        epilog='EXAMPLE: %(prog)s mosaic_list.csv',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Command line Argument - pointing file input
    parser.add_argument('pointing_file_name',
                        type=str,
                        metavar='mosaic_list.csv',
                        help='Input (csv) file containing lists of observation parameters: '
                             'ra, dec, roll_angle, optical_element, date, overhead_time, '
                             'ma_table_number')

    # WCS Object Catalog
    parser.add_argument('cat_file_name',
                        type=str,
                        metavar='object_table.csv',
                        help='Object catalog file for wcs matching (csv)')

    # APT file supplying observation metadata
    parser.add_argument('-a', "--apt",
                        type=str,
                        metavar='small_dither_program.apt',
                        help='APT file for metadata.')

    # Boresight
    parser.add_argument('-b', '--boresight',
                        action='store_true',
                        help=('RA & Dec specifies location of boresight, not center of WFI.'))

    # Use config file to override metadata or values in default parameters file
    parser.add_argument('-c', '--config',
                        type=str,
                        help='Input parameter override file (yaml)',
                        default=None)

    # Date & Time argument
    parser.add_argument('-d', '--date',
                        type=str,
                        default=parameters.default_date.isot,
                        help=('UTC Date and Time of observation to simulate in ISOT format.'))

    # Level argument
    parser.add_argument('-l', '--level',
                        type=int,
                        default=1,
                        help='1 or 2, for L1 or L2 output')

    # Create script making argument
    parser.add_argument('-m', "--make_script",
                        type=str,
                        metavar='sims',
                        default=None,
                        help='Filename to output list of romanisim calls (sims.script) instead of '
                             'making simulation files (e.g. for cluster usage).')

    # Persistence
    parser.add_argument('-p', '--persistence',
                        action='store_true',
                        help=('Enable persistence modeling of previous files for exposures after '
                              'the first in each SCA.'))

    # Random number seed
    parser.add_argument('-r', '--rng_seed',
                        type=int,
                        default=None,
                        help='Random number seed (int)')

    # SCA detector identifier number
    parser.add_argument('-s', '--sca',
                        type=int,
                        default=-1,
                        help='SCA to simulate; -1 simulates all SCAs')

    # Use CRDS for distortion maps
    parser.add_argument('-u', '--usecrds',
                        action='store_true',
                        help='Use CRDS for distortion map')

    parser.add_argument('--psftype', type=str, default='galsim',
                        choices=['epsf', 'galsim', 'stpsf'],
                        help=('Type of PSF generator to use.'
                              ' If None and --usecrds, then "epsf" will be used'
                              ' otherwise "galsim" is used.'))
    parser.add_argument('-w', '--webbpsf',
                        action='store_true',
                        help='Use stpsf for PSFs. (deprecated, use `--psftype stpsf`)')

    parser.add_argument('--stpsf',
                        action='store_true',
                        help='Use stpsf for PSFs. (deprecated, use `--psftype stpsf`)')

    parser.add_argument('-f', '--force-ma-table-number', type=int, default=-1,
                        help='Force MA table number, ignoring file content.')

    return parser


def _format_script_line(output_file_name, options_dct, sca, previous_file_name,
                        entry, bandpass, ma_table_number, cat_file_name,
                        nexposures):
    """Return one ``romanisim-make-image`` command line for a single SCA exposure.

    ``options_dct`` holds the stack-level options forwarded to every image
    call; ``previous_file_name`` maps an SCA number to the last output file
    generated for it (used for ``--previous`` when persistence is enabled).
    """
    # Stack-only options that are either meaningless to the image command or
    # re-emitted below under the image command's own option names.
    ignore_opts = ('cat_file_name', 'apt', 'force_ma_table_number')

    parts = [f"romanisim-make-image {output_file_name}"]

    # Preserve relevant stack options for image lines
    for option, value in options_dct.items():
        if (not isinstance(value, bool) and option != 'sca'
                and option not in ignore_opts):
            parts.append(f"--{option} {value}")
        elif option == 'sca':
            # Always emit the SCA currently being looped over, not args.sca
            # (which is -1 when all detectors are requested).
            parts.append(f"--{option} {sca}")
        elif value is True:
            if option != 'persistence':
                parts.append(f"--{option}")
            elif sca in previous_file_name:
                parts.append(f"--previous {previous_file_name[sca]}")

    # Add image options for contents of pointing input file
    # Subtract 60 deg from APT PA to convert from APT V3 PA
    # to WFI_CEN PA
    parts.append(f"--bandpass {bandpass}")
    parts.append(f"--radec {entry['RA']} {entry['DEC']}")
    parts.append(f"--roll {entry['PA'] - 60}")
    parts.append(f"--ma_table_number {ma_table_number}")
    parts.append(f"--catalog {cat_file_name}")
    parts.append(f"--meta visit.nexposures={nexposures}")

    return " ".join(parts)


def main():
    """Simulate (or script the simulation of) a stack of exposures.

    Parses the command line, optionally merges a yaml override file into
    ``romanisim.parameters`` (or, with ``--usecrds`` and no config, sets every
    ``parameters.reference_data`` entry to ``None``), then reads the pointing
    table (space-delimited, ``#`` comments).  The columns read from each row
    are PLAN, PASS, SEGMENT, OBSERVATION, VISIT, EXPOSURE, MA_TABLE_NUMBER,
    BANDPASS, RA, DEC, PA and DURATION.

    For every row and every requested SCA (a single one when ``--sca`` is
    positive, otherwise 1 through ``parameters.NUMBER_OF_DETECTORS``) it
    either appends a ``romanisim-make-image`` command to
    ``<make_script>.script`` or runs the image simulation directly.  The
    observation time advances by the row's DURATION (seconds) after each row.
    """
    args = _build_parser().parse_args()

    log.info('Starting simulation...')

    if args.config is not None:
        # Open and parse overrides file
        with open(args.config, "r") as config_file:
            config = yaml.safe_load(config_file)
        combo_dict = parameters.__dict__
        ris.merge_nested_dicts(combo_dict, config)
        parameters.__dict__.update(combo_dict)
    elif args.usecrds:
        # don't use default values
        for k in parameters.reference_data:
            parameters.reference_data[k] = None

    # Set up metadata
    metadata = deepcopy(default_parameters_dictionary)

    # Create random number distribution
    rng = galsim.UniformDeviate(args.rng_seed)

    # Open table pointing file
    pointings = Table.read(args.pointing_file_name, comment="#", delimiter=" ")

    # Count exposures per visit for nexposures metadata
    visit_counts = Counter(
        (e['PLAN'], e['PASS'], e['SEGMENT'], e['OBSERVATION'], e['VISIT'])
        for e in pointings)

    # Acquire cleaned options to pass to script call
    options_dct = {}
    if args.make_script:
        for option, value in vars(args).items():
            log.debug(f'option, value = {option, value}')
            if option not in ('pointing_file_name', 'debug', 'make_script') and \
                    value is not None:
                options_dct[option] = value

    # Initialize time offset
    time_offset = 0.0 * u.s

    # Set file name suffix
    suffix = "uncal" if (args.level == 1) else "cal"

    # Last output file per SCA; only populated when persistence is enabled
    previous_file_name = {}

    apt_metadata = ris.parse_apt_file(args.apt) if args.apt else None

    if apt_metadata is None:
        program = '00001'
    else:
        # Keep the value as parsed for the file name, but store an integer in
        # the metadata.  Guarded so that running without --apt no longer
        # fails on subscripting None.
        program = apt_metadata['observation']['program']
        apt_metadata['observation']['program'] = int(program)

    # Open the script file last so earlier failures do not leave an empty
    # file behind, and close it in ``finally`` so errors do not leak it.
    script_file = open(args.make_script + ".script", "w") if args.make_script else None

    try:
        # Loop over pointings
        for entry in pointings:

            # Debug: print keys
            log.debug(f'entry.keys = {entry.keys()}')

            # Set initial (possibly only) SCA
            sca = args.sca if (args.sca > 0) else 1
            plan, passno, segment, observation, visit, exposure = (
                entry['PLAN'], entry['PASS'], entry['SEGMENT'],
                entry['OBSERVATION'], entry['VISIT'], entry['EXPOSURE'])
            nexposures = visit_counts[(plan, passno, segment, observation, visit)]
            ma_table_number = int(entry['MA_TABLE_NUMBER'])
            bandpass = entry['BANDPASS']
            if args.force_ma_table_number > 0:
                ma_table_number = args.force_ma_table_number

            # If SCA value is within 1-18, only run once
            # otherwise loop over all detectors
            while sca <= parameters.NUMBER_OF_DETECTORS:
                # Create output file name
                output_file_name = (
                    f'r{program}{plan:02d}{passno:03d}{segment:03d}'
                    f'{observation:03d}{visit:03d}_{exposure:04d}'
                    f'_wfi{sca:02d}_{bandpass.lower()}_{suffix}.asdf')

                log.debug(f"output_file_name = {output_file_name}")

                # If making a script, write the line and proceed
                # otherwise create simulation file
                if script_file is not None:
                    line = _format_script_line(
                        output_file_name, options_dct, sca, previous_file_name,
                        entry, bandpass, ma_table_number, args.cat_file_name,
                        nexposures)

                    # Debug print line
                    log.debug(f'line = {line}')

                    # add line to script
                    script_file.write(line + '\n')

                else:
                    # RA & Dec
                    log.debug(f"float(entry['PA']) = {float(entry['PA'])}")
                    radius = parameters.WFS_FOV
                    coord = SkyCoord(ra=float(entry['RA']) * u.deg,
                                     dec=float(entry['DEC']) * u.deg, frame='icrs')

                    wcs.fill_in_parameters(metadata, coord, boresight=args.boresight,
                                           pa_aper=float(entry['PA']))

                    # Set metadata
                    metadata = ris.set_metadata(meta=metadata,
                                                date=Time(args.date, format='isot') + time_offset,
                                                bandpass=entry['BANDPASS'], sca=sca,
                                                ma_table_number=ma_table_number,
                                                usecrds=args.usecrds)
                    if apt_metadata:
                        util.merge_dicts(metadata, apt_metadata)
                    metadata.setdefault('visit', {})['nexposures'] = nexposures

                    # If selected, apply persistence from the file previously
                    # created for this SCA (none exists for the first exposure)
                    if args.persistence and sca in previous_file_name:
                        persist = persistence.Persistence.read(previous_file_name[sca])
                    else:
                        persist = persistence.Persistence()

                    # Create catalog object
                    cat = ris.create_catalog(metadata=metadata, catalog_name=args.cat_file_name,
                                             bandpasses=[bandpass], coord=coord, radius=radius)

                    # Set arguments to pass to image simulation
                    args.filename = output_file_name
                    args.bandpass = bandpass
                    args.pretend_spectral = None

                    # Simulate image files
                    ris.simulate_image_file(args, metadata, cat, rng, persist)

                # Update persistence file if applicable
                if args.persistence:
                    previous_file_name[sca] = output_file_name

                # Break the loop if only one detector specified
                if args.sca > 0:
                    break
                sca += 1

            # Add time offset for the next exposure group
            time_offset += (float(entry['DURATION'])) * u.s
    finally:
        # Close script file if appropriate
        if script_file is not None:
            script_file.close()


# Run the command-line entry point only when this file is executed as a
# script; importing it as a module does not invoke main().
if __name__ == "__main__":
    main()
