#!/usr/bin/env python3
#
# Python script to run and analyse MMS test
#

# Cores: 2
# requires: zoidberg

import argparse
import json
import pathlib
import sys
from time import time

import boutconfig as conf
import zoidberg as zb
from boutdata.collect import collect
from boututils.run_wrapper import build_and_log, launch_safe
from numpy import arange, array, linspace, log, polyfit
from scipy.interpolate import RectBivariateSpline as RBS

# Global parameters
# Directory handed to `collect` as the path to read results from
DIRECTORY = "data"
# Passed to `launch_safe` as `nproc` and `mthread` respectively
NPROC = 2
MTHREAD = 2
# Operators under test; their error norms are read back from the output as
# `<operator>_l_2` and `<operator>_l_inf`
OPERATORS = ("grad_par", "grad2_par2", "div_par", "div_par_K_grad_par", "laplace_par")
# Size passed as the first argument of `RectangularPoloidalGrid`.
# Note that we need at least _2_ interior points for hermite spline
# interpolation due to an awkwardness with the boundaries
NX = 4
# Resolution in y and z
NLIST = [8, 16, 32, 64]
# Nominal mesh spacing for each entry of NLIST; used as the abscissa for the
# convergence fits and the error-scaling plots
dx = 1.0 / array(NLIST)


def myRBS(a, b, c):
    """Construct a ``RectBivariateSpline`` with an x-degree suited to small arrays.

    The spline degree in the first dimension is chosen as one less than the
    number of rows of ``c``, clamped to the range 1..3. The degree in the
    second dimension is left at the ``RectBivariateSpline`` default.

    Parameters
    ----------
    a, b
        Coordinate arrays, forwarded unchanged as the first two arguments.
    c
        Two-dimensional array of values; must expose a two-element ``shape``.

    Returns
    -------
    The spline object created by ``RectBivariateSpline``.
    """
    num_rows, _ = c.shape
    # One less than the number of points, but never below linear or above cubic
    x_degree = min(max(num_rows - 1, 1), 3)
    return RBS(a, b, c, kx=x_degree)


# Rebind the `RectBivariateSpline` name inside `zoidberg.poloidal_grid` to
# `myRBS`, so code in that module that looks the name up gets the wrapper.
# NOTE(review): presumably needed because NX is too small for the default
# cubic spline in x -- confirm against zoidberg
zb.poloidal_grid.RectBivariateSpline = myRBS


def quiet_collect(name: str) -> float:
    """Read the variable ``name`` from the output in ``DIRECTORY`` as a scalar.

    Calls ``collect`` with ``info=False``, ``tind=[1, 1]`` and with both x and
    y guard cells disabled, then unwraps the result.
    """
    collected = collect(
        name,
        path=DIRECTORY,
        tind=[1, 1],
        xguards=False,
        yguards=False,
        info=False,
    )
    # Index with an empty tuple to return a plain (numpy) float rather than
    # `BoutArray`
    return collected[()]


def assert_convergence(error, dx, name, expected) -> bool:
    """Report the convergence order of ``error`` against ``dx`` and check it.

    Two estimates are printed on one line: the slope of a straight-line fit
    through all points in log-log space, and the order obtained from only the
    last two entries. Only the second estimate decides the outcome.

    Parameters
    ----------
    error
        Sequence of error norms, one per mesh spacing.
    dx
        Sequence of mesh spacings matching ``error``.
    name
        Label used as the prefix of the printed line.
    expected
        Expected order of convergence.

    Returns
    -------
    Truthy if the last-two-points order exceeds 95% of ``expected``.
    """
    fitted_slope, _ = polyfit(log(dx), log(error), 1)
    print(f"{name} convergence order = {fitted_slope:f} (fit)", end="")

    # Order from the two last entries only
    error_ratio = error[-2] / error[-1]
    spacing_ratio = dx[-2] / dx[-1]
    finest_order = log(error_ratio) / log(spacing_ratio)
    print(f", {finest_order:f} (small spacing)", end="")

    # Should be close to the expected order
    success = finest_order > expected * 0.95
    print(f"\t............ {'PASS' if success else 'FAIL'}")

    return success


def run_fci_operators(
    nslice: int, nz: int, yperiodic: bool, name: str
) -> dict[str, dict[str, float]]:
    """Generate an FCI grid, run ``./fci_mms`` on it and collect the error norms.

    Parameters
    ----------
    nslice
        Number of parallel slices; passed to ``zb.make_maps`` and to the
        executable as ``MYG``.
    nz
        Resolution: used as the second size of the poloidal grid, as the
        number of y locations, and passed to the executable as ``MZ``.
    yperiodic
        Whether the grid is periodic in y. This also selects how the y
        locations are laid out.
    name
        Extra command-line options appended verbatim to the run command
        (despite the parameter name, this is not a label).

    Returns
    -------
    A dict mapping each entry of ``OPERATORS`` to ``{"l_2": ..., "l_inf": ...}``.

    Notes
    -----
    The combined output of the run is written to ``run.log.<nz>``; the file
    name depends only on ``nz``, so a later call with the same ``nz``
    overwrites it. If the run returns a non-zero status, the output is printed
    and the process exits with that status via ``sys.exit``.
    """
    # Define the magnetic field using new poloidal gridding method
    # Note that the Bz and Bzprime parameters here must be the same as in mms.py
    field = zb.field.Slab(Bz=0.05, Bzprime=0.1)
    # Create rectangular poloidal grids
    poloidal_grid = zb.poloidal_grid.RectangularPoloidalGrid(NX, nz, 0.1, 1.0, MXG=1)
    # Set the ylength and y locations
    ylength = 10.0

    if yperiodic:
        ycoords = linspace(0.0, ylength, nz, endpoint=False)
    else:
        # Doesn't include the end points
        ycoords = (arange(nz) + 0.5) * ylength / float(nz)

    # Create the grid
    grid = zb.grid.Grid(poloidal_grid, ycoords, ylength, yperiodic=yperiodic)
    maps = zb.make_maps(grid, field, nslice=nslice, quiet=True, MXG=1)
    zb.write_maps(
        grid,
        field,
        maps,
        new_names=False,
        metric2d=conf.isMetric2D(),
        quiet=True,
    )

    # Command to run
    args = f"MZ={nz} MYG={nslice} mesh:paralleltransform:y_periodic={yperiodic} {name}"
    cmd = f"./fci_mms {args}"
    print(f"Running command: {cmd}", end="")

    # Launch using MPI
    start = time()
    status, out = launch_safe(cmd, nproc=NPROC, mthread=MTHREAD, pipe=True)
    print(f" ... done in {time() - start:.3}s")

    # Save output to log file (written before the status check so that the
    # log is available for failed runs too)
    pathlib.Path(f"run.log.{nz}").write_text(out)

    if status:
        print(f"Run failed!\nOutput was:\n{out}")
        sys.exit(status)

    return {
        operator: {
            "l_2": quiet_collect(f"{operator}_l_2"),
            "l_inf": quiet_collect(f"{operator}_l_inf"),
        }
        for operator in OPERATORS
    }


def transpose(
    errors: list[dict[str, dict[str, float]]],
) -> dict[str, dict[str, list[float]]]:
    """Regroup per-run error dicts into per-operator lists of errors.

    ``errors`` holds one ``{operator: {"l_2": x, "l_inf": y}}`` dict per run.
    The result has an entry for every operator in ``OPERATORS``, each holding
    an ``"l_2"`` and an ``"l_inf"`` list with values in the order of the runs.
    """

    norm_names = ("l_2", "l_inf")

    # Start from empty lists so every operator appears even with no runs
    collected = {}
    for operator in OPERATORS:
        collected[operator] = {norm: [] for norm in norm_names}

    for run in errors:
        for operator, norms in run.items():
            per_operator = collected[operator]
            for norm in norm_names:
                per_operator[norm].append(norms[norm])

    return collected


def check_fci_operators(
    name: str, case: dict
) -> tuple[dict[str, dict[str, list[float]]], list[str]]:
    """Run one test case at every resolution and check operator convergence.

    Parameters
    ----------
    name
        Label of the case; combined with each operator name to form the test
        name that is printed and reported on failure.
    case
        Case description with the keys ``"nslice"``, ``"yperiodic"``,
        ``"order"`` (expected convergence order) and ``"args"`` (extra
        command-line options for the run).

    Returns
    -------
    A tuple ``(final_errors, failures)``: the errors regrouped per operator by
    ``transpose``, and the list of test names whose convergence check failed.
    Convergence is judged on the ``"l_2"`` norm only.
    """
    nslice = case["nslice"]
    yperiodic = case["yperiodic"]
    order = case["order"]
    args = case["args"]

    all_errors = []

    for n in NLIST:
        errors = run_fci_operators(nslice, n, yperiodic, args)
        all_errors.append(errors)

        for operator in OPERATORS:
            l_2 = errors[operator]["l_2"]
            l_inf = errors[operator]["l_inf"]

            print(f"{operator} errors: l-2 {l_2:f} l-inf {l_inf:f}")

    final_errors = transpose(all_errors)

    failures = []
    for operator in OPERATORS:
        test_name = f"{operator} {name}"
        success = assert_convergence(
            final_errors[operator]["l_2"], dx, test_name, order
        )
        if not success:
            failures.append(test_name)

    return final_errors, failures


def make_plots(cases: dict[str, dict], show_plots: "bool | None" = None):
    """Plot the error scaling of every operator and save it to ``fci_mms.pdf``.

    One log-log panel is drawn per entry of ``OPERATORS``; each panel shows
    the ``"l_2"`` (solid) and ``"l_inf"`` (dashed) errors of every case
    against the module-level ``dx``.

    Parameters
    ----------
    cases
        Mapping of case label to a dict that contains, for each operator, a
        ``{"l_2": [...], "l_inf": [...]}`` entry.
    show_plots
        Whether to display the figure interactively after saving it. When
        ``None`` (the default), the ``show_plots`` attribute of a module-level
        ``args`` object is used if one exists, otherwise the figure is not
        shown.

    If matplotlib cannot be imported a message is printed and nothing is
    plotted.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("No matplotlib")
        return

    if show_plots is None:
        # Backwards compatible: the script stores its parsed command-line
        # options in a global `args`. Look it up defensively so that calling
        # this function without that global does not raise NameError.
        show_plots = getattr(globals().get("args"), "show_plots", False)

    num_operators = len(OPERATORS)
    # squeeze=False always gives a 2D array of axes, so a single operator
    # still yields something iterable
    fig, axes = plt.subplots(
        1, num_operators, figsize=(num_operators * 4, 4), squeeze=False
    )

    for ax, operator in zip(axes[0], OPERATORS):
        for name, case in cases.items():
            ax.loglog(dx, case[operator]["l_2"], "-", label=f"{name} $l_2$")
            ax.loglog(dx, case[operator]["l_inf"], "--", label=f"{name} $l_\\inf$")
        ax.legend(loc="upper left")
        ax.grid()
        ax.set_title(f"Error scaling for {operator}")
        ax.set_xlabel(r"Mesh spacing $\delta x$")
        ax.set_ylabel("Error norm")

    fig.tight_layout()
    fig.savefig("fci_mms.pdf")
    print("Plot saved to fci_mms.pdf")

    if show_plots:
        plt.show()
    plt.close()


if __name__ == "__main__":
    # Script entry point: build, run every case, optionally dump and plot the
    # errors, then exit with 0 if all convergence checks passed and 1 otherwise
    build_and_log("FCI MMS test")

    parser = argparse.ArgumentParser("Error scaling test for FCI operators")
    parser.add_argument(
        "--make-plots",
        help="Create plots of error scaling",
        action="store_true",
    )
    parser.add_argument(
        "--show-plots",
        help="Stop and show plots, implies --make-plots",
        action="store_true",
    )
    parser.add_argument(
        "--dump-errors",
        help="Output file to dump errors as JSON",
        default="fci_operator_errors.json",
        type=str,
    )

    # Must stay a module-level name called `args`: `make_plots` reads it
    args = parser.parse_args()

    # Each case gives the number of slices, the expected convergence order,
    # the y periodicity and the extra command-line options for the run
    cases = {
        "nslice=1 hermitespline": {
            "nslice": 1,
            "order": 2,
            "yperiodic": True,
            "args": "mesh:ddy:first=C2 mesh:paralleltransform:xzinterpolation:type=hermitespline",
        },
        "nslice=1 lagrange4pt": {
            "nslice": 1,
            "order": 2,
            "yperiodic": True,
            "args": "mesh:ddy:first=C2 mesh:paralleltransform:xzinterpolation:type=lagrange4pt",
        },
        "nslice=1 monotonichermitespline": {
            "nslice": 1,
            "order": 2,
            "yperiodic": True,
            "args": (
                "mesh:ddy:first=C2 "
                "mesh:paralleltransform:xzinterpolation:type=monotonichermitespline "
                "mesh:paralleltransform:xzinterpolation:rtol=1e-3 "
                "mesh:paralleltransform:xzinterpolation:atol=5e-3"
            ),
        },
    }

    failures = []
    for name, case in cases.items():
        case_errors, case_failures = check_fci_operators(name, case)
        # Store the per-operator errors alongside the case settings so they
        # are included in the JSON dump and available for plotting
        case.update(case_errors)
        failures.extend(case_failures)

    # The run as a whole succeeds only if no case recorded a failure
    success = not failures

    if args.dump_errors:
        pathlib.Path(args.dump_errors).write_text(json.dumps(cases))

    if args.make_plots or args.show_plots:
        make_plots(cases)

    if not success:
        print("\nSome tests failed:")
        for failed_test in failures:
            print(f"\t{failed_test}")
        sys.exit(1)

    print("\nAll tests passed")
    sys.exit(0)
