#!/usr/bin/env python

import argparse
import yaml

from astropy.coordinates import SkyCoord
from astropy import units as u
import galsim
from romanisim import log, wcs, persistence, parameters
from romanisim import ris_make_utils as ris
from astropy.io import fits


def _load_extra_counts(extra_counts_args):
    """Read the array requested through ``--extra-counts``.

    Parameters
    ----------
    extra_counts_args : list of str or None
        Either ``[filename]`` or ``[filename, hdu]``.  When the HDU is
        omitted, HDU 0 is used.  ``None`` means no extra counts were
        requested.

    Returns
    -------
    The ``data`` attribute of the selected HDU, or ``None`` when
    ``extra_counts_args`` is ``None``.

    Raises
    ------
    ValueError
        If anything other than one or two values is supplied, or if the
        HDU value cannot be converted with ``int()``.
    """
    if extra_counts_args is None:
        return None
    if len(extra_counts_args) not in (1, 2):
        raise ValueError("If --extra-counts is used, must provide one or two arguments only.")

    extra_counts_file = extra_counts_args[0]
    extra_counts_hdu = int(extra_counts_args[1]) if len(extra_counts_args) == 2 else 0
    with fits.open(extra_counts_file) as hdul:
        return hdul[extra_counts_hdu].data


def go(args):
    """Simulate one image (or one per SCA) from parsed command-line options.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options.  The attributes read here are ``usecrds``,
        ``config``, ``sca``, ``date``, ``bandpass``, ``ma_table_number``,
        ``truncate``, ``scale_factor``, ``radec``, ``boresight``, ``roll``,
        ``rng_seed``, ``catalog``, ``nobj``, ``previous`` and
        ``extra_counts``.  The whole namespace is also handed to
        ``ris.apply_meta_args`` and ``ris.simulate_image_file``.

        ``args.sca == -1`` runs the simulation once for each SCA 1..18.
        ``args.sca`` is temporarily overwritten during that loop and is
        restored to ``-1`` before returning.

    Raises
    ------
    ValueError
        If ``args.extra_counts`` does not hold one or two values.

    Notes
    -----
    This function modifies the ``romanisim.parameters`` module in place
    when ``args.usecrds`` is set or ``args.config`` is given.
    """
    if args.usecrds:
        # don't use default values
        for k in parameters.reference_data:
            parameters.reference_data[k] = None

    if args.config is not None:
        # Open and parse overrides file; the file only needs to stay open
        # while it is being parsed.
        with open(args.config, "r") as config_file:
            config = yaml.safe_load(config_file)
        # An empty YAML document parses to None: nothing to override.
        if config is not None:
            # merge_nested_dicts is relied on to update its first argument
            # in place (its return value has never been used here), so the
            # overrides land directly in the parameters module namespace.
            ris.merge_nested_dicts(parameters.__dict__, config)

    if args.sca == -1:
        # simulate all 18 SCAs sequentially.  Each recursive call re-applies
        # the CRDS / config handling above for its own SCA.
        requested_sca = args.sca
        try:
            for i in range(1, 19):
                args.sca = i
                go(args)
        finally:
            # Do not leak the loop's last SCA number into the caller's
            # namespace; otherwise a second go(args) would simulate SCA 18
            # only.
            args.sca = requested_sca
        return

    metadata = ris.set_metadata(
        date=args.date, bandpass=args.bandpass,
        sca=args.sca, ma_table_number=args.ma_table_number,
        usecrds=args.usecrds,
        truncate=args.truncate, scale_factor=args.scale_factor)

    ris.apply_meta_args(args, metadata)

    if args.radec is not None:
        coord = SkyCoord(ra=args.radec[0] * u.deg, dec=args.radec[1] * u.deg,
                         frame='icrs')
        wcs.fill_in_parameters(metadata, coord, boresight=args.boresight, pa_aper=args.roll)

    rng = galsim.UniformDeviate(args.rng_seed)

    # Create catalog
    cat = ris.create_catalog(metadata=metadata, catalog_name=args.catalog,
                             bandpasses=[args.bandpass], radius=0.1, rng=rng,
                             nobj=args.nobj, usecrds=args.usecrds)

    # Create persistence object
    if args.previous is not None:
        prevfn = ris.format_filename(args.previous, args.sca)
        persist = persistence.Persistence.read(prevfn)
    else:
        persist = persistence.Persistence()

    # Set up for extra counts as needed
    extra_counts = _load_extra_counts(args.extra_counts)

    # Simulate image and write to file
    ris.simulate_image_file(args, metadata, cat, rng, persist, extra_counts=extra_counts)


if __name__ == '__main__':
    # Command-line entry point: build the option parser, parse sys.argv and
    # hand the resulting namespace to go().
    parser = argparse.ArgumentParser(
        description='Make a demo image.',
        epilog='EXAMPLE: %(prog)s output_image.asdf',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('filename', type=str,
                        help=('output image (asdf).  {} and {bandpass} strings '
                              'will be automatically populated with detector '
                              'and bandpass information.'))
    parser.add_argument('--bandpass', type=str, help='bandpass to simulate',
                        default='F087')
    parser.add_argument('--boresight', action='store_true', default=False,
                        help=('radec specifies location of boresight, not '
                              'center of WFI.'))
    parser.add_argument('--catalog', type=str, help='input catalog (ecsv)',
                        default=None)
    parser.add_argument('--config', type=str, help='input parameter override file (yaml)',
                        default=None)
    parser.add_argument('--date', type=str, default=parameters.default_date.isot,
                        help=('UTC Date and Time of observation to simulate in ISOT format.'))
    parser.add_argument('--level', type=int, default=2,
                        help='1 or 2, for L1 or L2 output')
    parser.add_argument('--ma_table_number', type=int, default=4)
    parser.add_argument('--nobj', type=int, default=1000)
    parser.add_argument('--previous', default=None, type=str,
                        help=('previous simulated file in chronological order '
                              'used for persistence modeling.'))
    parser.add_argument('--radec', type=float, nargs=2,
                        help='ra and dec (deg)', default=None)
    parser.add_argument('--rng_seed', type=int, default=None)
    parser.add_argument('--roll', type=float, default=0,
                        help='Position angle (North towards YIdl) measured at the V2Ref/V3Ref of the aperture used.')
    # --sca -1 is the "all detectors" mode handled by go().
    parser.add_argument('--sca', type=int, default=7, help=(
        'SCA to simulate. Use -1 to generate images for all SCAs; include {} in filename for this mode '
        'to indicate where the detector number should be filled, e.g. l1_{}.asdf'))
    parser.add_argument('--usecrds', action='store_true',
                        help='Use CRDS references.')
    # NOTE(review): the help text below describes a "None" case, but the
    # default is 'galsim' and None is not among the choices, so that case
    # cannot be selected from the command line.  Confirm whether the default
    # or the help text reflects the intended behavior.
    parser.add_argument('--psftype', type=str, default='galsim',
                        choices=['epsf', 'galsim', 'stpsf'],
                        help=('Type of PSF generator to use.'
                              ' If None and --usecrds, then "epsf" will be used'
                              ' otherwise "galsim" is used.'))
    # Deprecated PSF switches, kept alongside --psftype.
    parser.add_argument('--webbpsf', action='store_true',
                        help='Use stpsf for PSF (deprecated, use `--psftype stpsf`)')
    parser.add_argument('--stpsf', action='store_true',
                        help='Use stpsf for PSF (deprecated, use `--psftype stpsf`)')
    parser.add_argument('--truncate', type=int, default=None, help=(
        'If set, truncate the MA table at given number of resultants.'))
    parser.add_argument('--pretend-spectral', type=str, default=None, help=(
        'Pretend the image is spectral.  exposure.type and instrument.element '
        'are updated to be grism / prism.'))
    parser.add_argument('--drop-extra-dq', default=False, action='store_true',
                        help=('Do not store the optional simulated dq array.'))
    parser.add_argument('--scale-factor', type=float, default=-1.,
                        help=('Velocity aberration-induced scale factor. If negative, use the given time to '
                              'calculate it based on the orbit ephemeris.'))
    # Additional metadata-related options are registered by the shared helper.
    ris.add_meta_args(parser)
    # One or two values: the FITS file name and, optionally, the HDU index.
    parser.add_argument('--extra-counts', type=str, default=None, nargs='+', help=(
        'An optional FITS file to read to get an array of counts to add into the simulated image. '
        'Useful for wrapping idealized images. '
        'If 2 arguments are sent in, then the second argument is assumed to be the HDU to use (default=0).'
    ))

    args = parser.parse_args()

    log.info('Starting simulation...')
    go(args)
