#!python

import argparse
import yaml

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.table import Table
from astropy.time import Time
import galsim
import romanisim
from romanisim import log, wcs, persistence, parameters, l3, bandpass, util
from romanisim import ris_make_utils as ris
from copy import deepcopy
import math
import asdf
import os.path


# Command-line entry point: simulate a single L3 (mosaic) image from an input
# catalog and write the image plus the simulation parameters to an ASDF file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Make an L3 image.',
        epilog='EXAMPLE: %(prog)s output_image.asdf catalog.ecsv',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('filename', type=str, help='output image (asdf)')
    parser.add_argument('catalog', type=str, help='input catalog (ecsv)')
    parser.add_argument('--bandpass', type=str, help='bandpass to simulate',
                        default='F087')
    # NOTE(review): --config is parsed but nothing below reads args.config;
    # confirm whether the override file is meant to be applied here.
    parser.add_argument('--config', type=str, help='input parameter override file (yaml)',
                        default=None)
    parser.add_argument('--date', type=str, default=parameters.default_date.isot,
                        help=('UTC Date and Time of observation to simulate in ISOT format.'))
    parser.add_argument('--radec', type=float, nargs=2,
                        help='ra and dec (deg)', default=None)
    parser.add_argument('--npix', type=int, default=4000,
                        help='number of pixels across image')
    parser.add_argument('--pixscalefrac', type=float, default=1.0,
                        help='pixel scale as fraction of original')
    parser.add_argument('--exptime', type=float, default=100.0,
                        help='total exposure time on field; '
                        'roughly time per exposure times number of exposures')
    parser.add_argument('--effreadnoise', type=float, default=None,
                        help='effective readnoise per pixel in MJy/sr.  If '
                        'None, a pessimistic estimate is computed.')
    parser.add_argument('--nexposures', type=int, default=1,
                        help='number of exposures on field.  Used only '
                        'to compute the effective read noise.')
    # NOTE(review): --rng_seed is only recorded in the output metadata below;
    # it is not forwarded to the catalog or image simulation calls in this
    # script.  Confirm whether it should be.
    parser.add_argument('--rng_seed', type=int, default=None)
    parser.add_argument('--psftype', type=str, default='galsim',
                        choices=['epsf', 'galsim', 'stpsf'],
                        help='Type of PSF generator to use.')
    ris.add_meta_args(parser)

    args = parser.parse_args()

    # --radec defaults to None, but the image center is unpacked from it
    # unconditionally below.  Fail with a usage error instead of a TypeError
    # traceback when it was not supplied.
    if args.radec is None:
        parser.error('--radec is required: provide the field center as '
                     'RA DEC in degrees')

    log.info(f'Running with args: {args}')
    log.info('Starting simulation...')

    # Build a tangent-plane WCS whose reference pixel is the center of the
    # (0-indexed) npix x npix grid and whose reference sky position is --radec.
    pixscale = args.pixscalefrac * parameters.pixel_scale
    midpoint = (args.npix - 1) / 2
    r, d = args.radec
    center = util.celestialcoord(SkyCoord(ra=r * u.deg, dec=d * u.deg))
    twcs = wcs.create_tangent_plane_gwcs(
        (midpoint, midpoint), pixscale, center)

    # Start from a private copy of the default mosaic metadata so the
    # module-level defaults are not mutated.
    metadata = deepcopy(parameters.default_mosaic_parameters_dictionary)
    if args.date is not None:
        metadata['coadd_info']['time_mean'] = Time(args.date, format='isot')
    else:
        metadata['coadd_info']['time_mean'] = parameters.default_date

    metadata['filename'] = os.path.basename(args.filename)
    ris.apply_meta_args(args, metadata)

    # NOTE(review): radius is given the pixel count (args.npix); confirm the
    # units create_catalog expects for this argument.
    cat = ris.create_catalog(metadata=metadata, catalog_name=args.catalog,
                             bandpasses=[args.bandpass], coord=center, radius=args.npix)

    im, extras = l3.simulate(
        (args.npix, args.npix), twcs, args.exptime, args.bandpass,
        cat, effreadnoise=args.effreadnoise, nexposures=args.nexposures,
        metadata=metadata, psftype=args.psftype)

    # Create metadata for simulation parameter: the parsed command-line
    # arguments, the extras returned by the simulation, and the package version.
    romanisimdict = deepcopy(vars(args))
    if 'filename' in romanisimdict:
        romanisimdict['filename'] = str(romanisimdict['filename'])
    romanisimdict.update(**extras)
    romanisimdict['version'] = romanisim.__version__

    af = asdf.AsdfFile()
    af.tree = {'roman': im, 'romanisim': romanisimdict}
    # Use a context manager so the output file is flushed and closed
    # deterministically, even if serialization raises.
    with open(args.filename, 'wb') as outfile:
        af.write_to(outfile)
