#!/usr/bin/env python3
#
# Python script to run and analyse MMS test

from boututils.run_wrapper import build_and_log, launch_safe
from boutdata.collect import collect
import itertools
import sys

import numpy.testing as npt

# Candidate processor counts in x and y (NXPE, NYPE); every pair is tried
NLIST = [1, 2, 4]
# Core budget per run: decompositions with NXPE * NYPE above this are
# skipped, and MAXCORES // (NXPE * NYPE) is handed on as ``mthread``
MAXCORES = 8
# Values of nslice: variables output_<i>_<+/-y> are compared for
# y = 1 .. nslice
NSLICES = [1]

build_and_log("FCI MMS test")

# Keyword arguments shared by every collect() call: read from ./data,
# exclude x and y guard cells, and suppress informational output
COLLECT_KW = {"info": False, "xguards": False, "yguards": False, "path": "data"}


def run_case(nxpe: int, nype: int, mthread: int, nslice: int = 1) -> None:
    """Run ``fci_mpi`` on an ``nxpe`` x ``nype`` decomposition and log its output.

    Parameters
    ----------
    nxpe, nype : int
        Processor counts in x and y. They are passed to the executable as
        ``NXPE``/``NYPE``, and their product is the process count given to
        ``launch_safe``.
    mthread : int
        Forwarded unchanged to ``launch_safe`` as ``mthread``.
    nslice : int
        Used only to label the log file
        ``run.log.{nxpe}.{nype}.{nslice}.log``. It is now an explicit
        parameter; previously it was read from a module-level loop variable.
    """
    cmd = f"./fci_mpi NXPE={nxpe} NYPE={nype}"
    print(f"Running command: {cmd}")

    _, out = launch_safe(cmd, nproc=nxpe * nype, mthread=mthread, pipe=True)

    # Save output to log file
    with open(
        f"run.log.{nxpe}.{nype}.{nslice}.log", "w", encoding="utf-8"
    ) as f:
        f.write(out)


def test_case(
    nxpe: int, nype: int, mthread: int, ref: dict, nslice: int = 1
) -> list:
    """Run one decomposition and compare its output against ``ref``.

    Parameters
    ----------
    nxpe, nype, mthread, nslice
        Forwarded to ``run_case``.
    ref : dict
        Maps variable name -> reference data collected from the 1x1 run.

    Returns
    -------
    list of (nxpe, nype, name, AssertionError)
        One entry per variable whose freshly collected data does not match
        its reference. The list is empty when everything matches.
    """
    run_case(nxpe, nype, mthread, nslice)

    mismatches = []
    for name, expected in ref.items():
        try:
            # actual = new run, desired = reference, so that rtol is
            # measured relative to the trusted reference values
            npt.assert_allclose(collect(name, **COLLECT_KW), expected)
        except AssertionError as e:
            mismatches.append((nxpe, nype, name, e))

    return mismatches


failures = []

for nslice in NSLICES:
    # Reference data: a single-process (1x1) run with mthread=MAXCORES
    run_case(1, 1, MAXCORES, nslice)

    # Collect output_<i>_<+/-y> for i in 0..3 and y = 1..nslice
    ref = {}
    for i in range(4):
        for yp in range(1, nslice + 1):
            for y in [-yp, yp]:
                name = f"output_{i}_{y:+d}"
                ref[name] = collect(name, **COLLECT_KW)

    for nxpe, nype in itertools.product(NLIST, NLIST):
        if (nxpe, nype) == (1, 1):
            # reference case, done above
            continue

        if nxpe * nype > MAXCORES:
            # decomposition needs more processes than the core budget
            continue

        # Hand the remaining core budget to each process as threads
        mthread = MAXCORES // (nxpe * nype)
        failures.extend(test_case(nxpe, nype, mthread, ref, nslice))


# Report every mismatch (or a pass message) and encode the outcome in the
# process exit status: 0 when all comparisons matched, 1 otherwise.
success = not failures
if not success:
    print("\nSome tests failed:")
    # Names nxpe/nype/name are kept: the f-string's ``{x=}`` prints them
    for nxpe, nype, name, err in failures:
        print("----------")
        print(f"case {nxpe=} {nype=} {name=}\n{err}")
else:
    print("\nAll tests passed")

sys.exit(0 if success else 1)
