#!/usr/bin/env python3
"""Launch daylily-omics-analysis inside a tmux session on the head node."""

from __future__ import annotations

import argparse
import json
import os
import shlex
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional, Tuple


# Extra ``ssh`` flags used for every head-node connection.
# Host-key checking is turned off and the known-hosts file is pointed at
# /dev/null, so ssh neither prompts about unknown hosts nor records them.
# NOTE(review): presumably chosen because head-node addresses are short-lived;
# it also removes protection against man-in-the-middle hosts -- confirm this
# trade-off is acceptable.
SSH_OPTIONS: Tuple[str, ...] = (
    "-o",
    "StrictHostKeyChecking=no",
    "-o",
    "UserKnownHostsFile=/dev/null",
)


class CommandError(RuntimeError):
    """Raised when an external command fails.

    The same exception type is also raised in this module for user-facing
    failures such as missing options, missing PEM files, or remote lookup
    errors. When the file is run as a script, the ``__main__`` guard catches
    it, prints ``Error: <message>`` to stderr and exits with status 1.
    """


@dataclass
class RemoteConfig:
    """Locations of the staged sample sheets as reported by the head node.

    The values are the text that follows the ``__DAYLILY_STAGE_*__=`` markers
    printed by the remote discovery script (see ``parse_remote_config``).
    """

    # Remote staging directory in which the two sheets were found.
    stage_dir: str
    # Remote path of the ``*_samples.tsv`` file.
    samples_path: str
    # Remote path of the ``*_units.tsv`` file.
    units_path: str


def normalize_remote_path(path: str) -> str:
    """Expand a leading ``~`` to the remote ``ubuntu`` home directory.

    Only the exact forms ``~`` and ``~/...`` are rewritten; every other
    path (including ``~user/...``) is returned unchanged.
    """
    remote_home = "/home/ubuntu"
    if path == "~":
        return remote_home
    if path[:2] == "~/":
        # Keep the slash that follows the tilde so the result stays absolute.
        return remote_home + path[1:]
    return path


def run_command(
    command: Iterable[str],
    *,
    capture_output: bool = False,
    env: Optional[dict] = None,
    check: bool = True,
) -> subprocess.CompletedProcess:
    """Run an external command and return its ``CompletedProcess``.

    Args:
        command: Program name followed by its arguments. Any iterable is
            accepted; it is materialised exactly once, so one-shot
            iterators/generators are safe.
        capture_output: When true, stdout and stderr are captured (as text)
            on the returned object instead of being inherited.
        env: Complete environment for the child process, or ``None`` to
            inherit the current one.
        check: When true, a non-zero exit status raises ``CommandError``.

    Returns:
        The ``subprocess.CompletedProcess`` for the finished command.

    Raises:
        CommandError: If the command exits non-zero while ``check`` is true,
            or if the program cannot be started at all (for example the
            executable is not installed).
    """
    # Materialise once: the argument vector is needed both for the call and
    # for the error message, and ``command`` may be a one-shot iterator.
    argv = list(command)
    # Shell-quote each argument so the message is unambiguous and can be
    # pasted back into a terminal (arguments may contain whole scripts).
    rendered = " ".join(shlex.quote(str(part)) for part in argv)
    try:
        result = subprocess.run(  # type: ignore[call-arg]  # noqa: S603
            argv,
            check=check,
            capture_output=capture_output,
            text=True,
            env=env,
        )
    except subprocess.CalledProcessError as exc:  # pragma: no cover
        stdout = exc.stdout or ""
        stderr = exc.stderr or ""
        message = f"Command failed ({exc.returncode}): {rendered}"
        if stdout:
            message += f"\nSTDOUT:\n{stdout.strip()}"
        if stderr:
            message += f"\nSTDERR:\n{stderr.strip()}"
        raise CommandError(message) from exc
    except OSError as exc:  # pragma: no cover
        # Missing or non-executable program: report it the same way as any
        # other command failure instead of leaking a raw traceback.
        raise CommandError(f"Unable to execute command: {rendered}\n{exc}") from exc
    return result


def choose_from(prompt: str, options: List[str]) -> str:
    """Return one entry of ``options``, asking on stdin when there are several.

    A single option is returned without prompting. With more than one, a
    numbered menu is printed and the question is repeated until a number
    within range is entered.

    Raises:
        CommandError: If ``options`` is empty.
    """
    if not options:
        raise CommandError(f"No options available for: {prompt}")
    if len(options) == 1:
        return options[0]

    print(prompt)
    for number, label in enumerate(options, start=1):
        print(f"  {number}) {label}")

    highest = len(options)
    while True:
        try:
            number = int(input("Select an option: "))
        except ValueError:
            print("Please enter a number.")
            continue
        if number < 1 or number > highest:
            print("Selection out of range; try again.")
            continue
        return options[number - 1]


def resolve_region(profile: str) -> str:
    """Pick the AWS region to operate in.

    The region list comes from ``aws ec2 describe-regions``. If
    ``AWS_REGION`` (or, failing that, ``AWS_DEFAULT_REGION``) names one of
    the returned regions it is used directly; otherwise the user is asked
    to choose.

    Raises:
        CommandError: If the AWS CLI call fails or returns no regions.
    """
    aws_env = dict(os.environ, AWS_PROFILE=profile) if profile else None
    completed = run_command(
        ["aws", "ec2", "describe-regions", "--output", "json"],
        capture_output=True,
        env=aws_env,
    )
    payload = json.loads(completed.stdout)
    region_names = [item["RegionName"] for item in payload.get("Regions", [])]
    region_names.sort()
    if not region_names:
        raise CommandError("Unable to retrieve AWS regions. Check AWS credentials.")

    preferred = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")
    if not preferred or preferred not in region_names:
        return choose_from("Select AWS region:", region_names)
    return preferred


def resolve_cluster(profile: str, region: str) -> str:
    """Pick a ParallelCluster cluster name in ``region``.

    Cluster names are read from ``pcluster list-clusters``; entries without
    a ``clusterName`` are ignored. The user is prompted when more than one
    cluster is available.

    Raises:
        CommandError: If the CLI call fails or no named cluster is listed.
    """
    cluster_env = dict(os.environ, AWS_PROFILE=profile) if profile else None
    completed = run_command(
        ["pcluster", "list-clusters", "--region", region, "--output", "json"],
        capture_output=True,
        env=cluster_env,
    )
    # Empty output is treated like an empty JSON object.
    listing = json.loads(completed.stdout or "{}")
    names = []
    for entry in listing.get("clusters", []):
        name = entry.get("clusterName")
        if name:
            names.append(name)
    if not names:
        raise CommandError(f"No ParallelCluster clusters found in region {region!r}.")
    names.sort()
    return choose_from("Select cluster:", names)


def resolve_pem_file(path: Optional[str]) -> str:
    """Return the SSH key path to use.

    An explicit ``path`` is ``~``-expanded and must exist. Without one, the
    ``*.pem`` files directly under ``~/.ssh`` are offered for selection.

    Raises:
        CommandError: If the given path does not exist, or no PEM file is
            found under ``~/.ssh``.
    """
    if not path:
        discovered = sorted(Path.home().glob(".ssh/*.pem"))
        if not discovered:
            raise CommandError("No PEM files found under ~/.ssh. Provide one with --pem.")
        return choose_from("Select SSH PEM key:", [str(item) for item in discovered])

    expanded = os.path.expanduser(path)
    if os.path.exists(expanded):
        return expanded
    raise CommandError(f"PEM file not found: {expanded}")


def fetch_headnode_ip(profile: str, region: str, cluster_name: str) -> str:
    """Return the public IP address of the cluster's head node.

    ``pcluster describe-cluster --output json`` is tried first and the
    address is read from the ``headNode`` object (looked up under
    ``cluster`` and at the top level). If that call fails, its output is
    not valid JSON, or no address is present, the command is re-run without
    ``--output`` and the output is scanned line by line for a
    ``publicIp``/``publicIpAddress`` entry.

    Raises:
        CommandError: If neither approach yields an address.
    """
    aws_env = dict(os.environ, AWS_PROFILE=profile) if profile else None
    describe = [
        "pcluster",
        "describe-cluster",
        "--region",
        region,
        "--cluster-name",
        cluster_name,
    ]
    try:
        completed = run_command(
            [*describe, "--output", "json"],
            capture_output=True,
            env=aws_env,
        )
        document = json.loads(completed.stdout)
        node = document.get("cluster", {}).get("headNode") or document.get("headNode", {})
        address = node.get("publicIpAddress") or node.get("publicIp")
        if not address:
            # Routed to the text-scan fallback below.
            raise KeyError("publicIpAddress")
        return address
    except (CommandError, json.JSONDecodeError, KeyError):  # pragma: no cover
        completed = run_command(describe, capture_output=True, env=aws_env)
        # Drop the JSON punctuation so "key": "value", becomes key: value.
        strip_punctuation = {ord('"'): None, ord(","): None}
        for text in (completed.stdout or "").splitlines():
            if not any(marker in text for marker in ("publicIpAddress", "publicIp")):
                continue
            _, colon, value = text.translate(strip_punctuation).partition(":")
            if colon and value.strip():
                return value.strip()
        raise CommandError("Unable to determine head node IP address.")


def run_remote_script(pem: str, host: str, script: str) -> subprocess.CompletedProcess:
    """Run ``script`` on ``host`` as user ``ubuntu`` via ``ssh`` and capture its output.

    The script text is passed as a single shell-quoted argument to
    ``bash -lc`` on the remote side.
    """
    ssh_argv = ["ssh", "-i", pem]
    ssh_argv.extend(SSH_OPTIONS)
    ssh_argv.append(f"ubuntu@{host}")
    ssh_argv.append("bash -lc " + shlex.quote(script))
    return run_command(ssh_argv, capture_output=True)


def parse_remote_config(stdout: str) -> RemoteConfig:
    """Extract the staged-config paths from the discovery script's output.

    Lines starting with ``__DAYLILY_STAGE_DIR__=``,
    ``__DAYLILY_STAGE_SAMPLES__=`` and ``__DAYLILY_STAGE_UNITS__=`` supply
    the three fields; when a marker repeats, the last occurrence wins.

    Raises:
        CommandError: If a ``__DAYLILY_ERROR__=`` line is seen, or any of
            the three values is missing or empty.
    """
    field_by_marker = {
        "__DAYLILY_STAGE_DIR__=": "stage_dir",
        "__DAYLILY_STAGE_SAMPLES__=": "samples_path",
        "__DAYLILY_STAGE_UNITS__=": "units_path",
    }
    found = {}
    for text in stdout.splitlines():
        if text.startswith("__DAYLILY_ERROR__="):
            raise CommandError(f"Remote lookup failed: {text.partition('=')[2]}")
        for marker, field_name in field_by_marker.items():
            if text.startswith(marker):
                found[field_name] = text[len(marker):].strip()
                break

    if not all(found.get(field_name) for field_name in field_by_marker.values()):
        raise CommandError("Unable to determine staged config paths on the head node.")
    return RemoteConfig(found["stage_dir"], found["samples_path"], found["units_path"])


def discover_stage_config(
    pem: str,
    host: str,
    stage_dir: Optional[str],
    stage_base: str,
) -> RemoteConfig:
    """Locate the staged ``*_samples.tsv`` / ``*_units.tsv`` pair on the head node.

    Args:
        pem: Path to the SSH private key.
        host: Head-node address.
        stage_dir: Specific staging directory to inspect. When empty or
            ``None``, the most recently modified sub-directory of
            ``stage_base`` is used instead.
        stage_base: Directory whose sub-directories are staging runs.

    Returns:
        The staging directory and the first matching samples/units files.

    Raises:
        CommandError: If the remote script fails (missing directory, no
            staging runs, missing config files) or its output lacks the
            expected markers.
    """
    # NOTE: both scripts run under ``set -euo pipefail``. A bare
    # ``var=$(ls ... | head -n 1)`` would abort the script as soon as the
    # glob matches nothing (``ls`` exits non-zero), before the explicit
    # ``__DAYLILY_ERROR__`` diagnostics can be printed. ``|| true`` keeps
    # the lookup best-effort so the emptiness checks below it decide.
    if stage_dir:
        target_dir = normalize_remote_path(stage_dir.rstrip("/"))
        script = f"""
set -euo pipefail
STAGE_DIR={shlex.quote(target_dir)}
if [[ ! -d "$STAGE_DIR" ]]; then
  echo "__DAYLILY_ERROR__=missing_stage_dir"
  exit 2
fi
samples_file=$(ls -1 "$STAGE_DIR"/*_samples.tsv 2>/dev/null | head -n 1 || true)
units_file=$(ls -1 "$STAGE_DIR"/*_units.tsv 2>/dev/null | head -n 1 || true)
if [[ -z "$samples_file" || -z "$units_file" ]]; then
  echo "__DAYLILY_ERROR__=missing_config"
  exit 3
fi
echo "__DAYLILY_STAGE_DIR__=$STAGE_DIR"
echo "__DAYLILY_STAGE_SAMPLES__=$samples_file"
echo "__DAYLILY_STAGE_UNITS__=$units_file"
"""
    else:
        stage_base_norm = normalize_remote_path(stage_base.rstrip("/"))
        # The ``*/`` glob yields names with a trailing slash; it is stripped
        # so the derived file paths do not contain ``//``.
        script = f"""
set -euo pipefail
STAGE_BASE={shlex.quote(stage_base_norm)}
if [[ ! -d "$STAGE_BASE" ]]; then
  echo "__DAYLILY_ERROR__=missing_stage_base"
  exit 2
fi
latest_dir=$(ls -1dt "$STAGE_BASE"/*/ 2>/dev/null | head -n 1 || true)
if [[ -z "$latest_dir" ]]; then
  echo "__DAYLILY_ERROR__=no_stage_runs"
  exit 3
fi
latest_dir=${{latest_dir%/}}
samples_file=$(ls -1 "$latest_dir"/*_samples.tsv 2>/dev/null | head -n 1 || true)
units_file=$(ls -1 "$latest_dir"/*_units.tsv 2>/dev/null | head -n 1 || true)
if [[ -z "$samples_file" || -z "$units_file" ]]; then
  echo "__DAYLILY_ERROR__=missing_config"
  exit 4
fi
echo "__DAYLILY_STAGE_DIR__=$latest_dir"
echo "__DAYLILY_STAGE_SAMPLES__=$samples_file"
echo "__DAYLILY_STAGE_UNITS__=$units_file"
"""
    result = run_remote_script(pem, host, script)
    # Echo the remote output locally so the user sees what was discovered.
    if result.stdout:
        print(result.stdout, end="")
    if result.stderr:
        print(result.stderr, file=sys.stderr, end="")
    return parse_remote_config(result.stdout)


def format_list(values: List[str]) -> str:
    """Render ``values`` as a bracketed, comma-separated list of single-quoted items.

    Each value is stripped of surrounding whitespace; values that are empty
    after stripping are dropped. Example: ``["a", " b "]`` -> ``['a','b']``.
    """
    rendered = []
    for raw in values:
        token = raw.strip()
        if token:
            rendered.append("'" + token + "'")
    return "[" + ",".join(rendered) + "]"


def build_default_command(
    target: str,
    genome: str,
    jobs: int,
    aligners: List[str],
    dedupers: List[str],
    snv_callers: List[str],
    containerized: bool,
    dry_run: bool,
    extra: Optional[str],
) -> str:
    """Assemble the default ``dy-r`` command line as a single string.

    The result has the form
    ``DAY_CONTAINERIZED=<bool> dy-r <target> -p -k -j <jobs> --config <k=v ...>``
    with ``-n`` appended when ``dry_run`` is set, followed by ``extra``
    verbatim when it is non-empty. Values are inserted as given, without
    shell quoting.
    """
    containerized_value = "true" if containerized else "false"
    overrides = " ".join(
        (
            f"genome_build={genome}",
            f"aligners={format_list(aligners)}",
            f"dedupers={format_list(dedupers)}",
            f"snv_callers={format_list(snv_callers)}",
        )
    )
    words = [
        f"DAY_CONTAINERIZED={containerized_value}",
        "dy-r",
        target,
        "-p",
        "-k",
        f"-j {jobs}",
        "--config",
        overrides,
    ]
    if dry_run:
        words += ["-n"]
    if extra:
        words += [extra]
    return " ".join(words)


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this launcher.

    Options are declared in a table and registered in that order, which is
    also the order in which they appear in ``--help``.
    """
    option_table = (
        # AWS / cluster connection
        (
            "--profile",
            dict(
                default=os.environ.get("AWS_PROFILE"),
                help="AWS CLI profile to use (default: $AWS_PROFILE)",
            ),
        ),
        ("--region", dict(help="AWS region for the cluster")),
        ("--cluster", dict(help="ParallelCluster name")),
        ("--pem", dict(help="Path to the SSH PEM key")),
        # Staged inputs
        (
            "--stage-dir",
            dict(help="Specific staging directory containing *_samples.tsv and *_units.tsv"),
        ),
        (
            "--stage-base",
            dict(
                default="/fsx/staged_sample_data",
                help="Base staging directory to scan when --stage-dir is omitted",
            ),
        ),
        # Remote session and repository checkout
        (
            "--session-name",
            dict(
                default="daylily-omics-analysis",
                help="Name of the tmux session to create on the head node",
            ),
        ),
        (
            "--destination",
            dict(default="dayoa", help="Workspace destination passed to day-clone"),
        ),
        (
            "--repository",
            dict(
                default="daylily-omics-analysis",
                help="Repository key to pass to day-clone",
            ),
        ),
        (
            "--transport",
            dict(
                choices=["https", "ssh"],
                default="https",
                help="Git transport for day-clone",
            ),
        ),
        ("--project", dict(help="Project/budget to supply to dyoainit")),
        (
            "--skip-project-check",
            dict(action="store_true", help="Pass --skip-project-check to dyoainit"),
        ),
        # Workflow settings
        (
            "--genome",
            dict(default="hg38", help="Genome build to activate (default: %(default)s)"),
        ),
        (
            "--jobs",
            dict(type=int, default=6, help="Value for dy-r -j (default: %(default)s)"),
        ),
        ("--aligners", dict(default="bwa2a", help="Comma separated list of aligners")),
        ("--dedupers", dict(default="dppl", help="Comma separated list of dedupers")),
        (
            "--snv-callers",
            dict(default="deep", help="Comma separated list of SNV callers"),
        ),
        (
            "--target",
            dict(
                default="produce_snv_concordances",
                help="Workflow target to run via dy-r",
            ),
        ),
        ("--dy-command", dict(help="Override the dy-r command entirely")),
        (
            "--snakemake-extra",
            dict(help="Additional arguments appended to the dy-r command"),
        ),
        (
            "--no-containerized",
            dict(
                action="store_true",
                help="Disable DAY_CONTAINERIZED (enabled by default)",
            ),
        ),
        ("--dry-run", dict(action="store_true", help="Add -n to the dy-r command")),
    )

    cli = argparse.ArgumentParser(
        description="Clone daylily-omics-analysis and launch a workflow inside tmux.",
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli


def main(argv: Optional[List[str]] = None) -> int:
    """Resolve the cluster, find the staged inputs and start the workflow in tmux.

    Steps:
      1. Parse options and resolve AWS profile, region, cluster, PEM key and
         head-node address (prompting where a value was not supplied).
      2. Discover the staged samples/units sheets on the head node.
      3. Build the ``dy-r`` command (or take ``--dy-command`` verbatim).
      4. Over ssh, write a pipeline script to a temp file on the head node
         and start it inside a new detached tmux session.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use
            ``sys.argv[1:]``.

    Returns:
        ``0`` once the tmux session has been reported as created.

    Raises:
        CommandError: If no AWS profile is available, any local or remote
            command fails, or the remote side does not report the session.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    aws_profile = args.profile
    if not aws_profile:
        raise CommandError("AWS profile is required. Set AWS_PROFILE or use --profile.")

    region = args.region or resolve_region(aws_profile)
    cluster_name = args.cluster or resolve_cluster(aws_profile, region)
    pem_file = resolve_pem_file(args.pem)
    headnode_ip = fetch_headnode_ip(aws_profile, region, cluster_name)

    stage_config = discover_stage_config(pem_file, headnode_ip, args.stage_dir, args.stage_base)

    if args.dy_command:
        dy_command = args.dy_command
    else:
        dy_command = build_default_command(
            target=args.target,
            genome=args.genome,
            jobs=args.jobs,
            aligners=args.aligners.split(","),
            dedupers=args.dedupers.split(","),
            snv_callers=args.snv_callers.split(","),
            containerized=not args.no_containerized,
            dry_run=args.dry_run,
            extra=args.snakemake_extra,
        )

    # An absent project becomes the empty string (rendered as '' in the
    # script), which the ``-n "$PROJECT_VALUE"`` test below treats as unset.
    project_value = shlex.quote(args.project or "")
    skip_check = "true" if args.skip_project_check else "false"

    # Notes on the generated script:
    # * DY_COMMAND must be shell-quoted. Unquoted, bash reads
    #   ``DY_COMMAND=DAY_CONTAINERIZED=true dy-r ...`` as an environment
    #   prefix followed by an immediate ``dy-r`` invocation, leaving
    #   DY_COMMAND unset for the later ``eval``.
    # * The repository key reaches the embedded Python through ``sys.argv``.
    #   Interpolating a shell-quoted value into Python source yields a bare
    #   expression (e.g. ``daylily-omics-analysis``) whenever the value needs
    #   no shell quoting, which raises NameError.
    pipeline_script = f"""
set -euo pipefail
SESSION_NAME={shlex.quote(args.session_name)}
STAGE_DIR={shlex.quote(stage_config.stage_dir)}
STAGE_SAMPLES={shlex.quote(stage_config.samples_path)}
STAGE_UNITS={shlex.quote(stage_config.units_path)}
DESTINATION={shlex.quote(args.destination)}
REPO_KEY={shlex.quote(args.repository)}
TRANSPORT={shlex.quote(args.transport)}
PROJECT_VALUE={project_value}
SKIP_PROJECT_CHECK={skip_check}
DY_COMMAND={shlex.quote(dy_command)}

analysis_root=$(python3 - <<'PYCONFIG'
from pathlib import Path
analysis_root = '/fsx/analysis_results'
config_path = Path.home() / '.config/daylily/daylily_cli_global.yaml'
if config_path.exists():
    for line in config_path.read_text().splitlines():
        line = line.strip()
        if line.startswith('analysis_root:'):
            analysis_root = line.split(':', 1)[1].strip()
            break
print(analysis_root.rstrip('/'))
PYCONFIG
)
repo_relative=$(python3 - "$REPO_KEY" <<'PYREPOS'
import sys
from pathlib import Path
repo_key = sys.argv[1]
relative = 'daylily-omics-analysis'
config_path = Path.home() / '.config/daylily/daylily_available_repositories.yaml'
if config_path.exists():
    current_key = None
    for raw in config_path.read_text().splitlines():
        stripped = raw.strip()
        if not stripped or stripped.startswith('#'):
            continue
        if stripped.endswith(':'):
            current_key = stripped[:-1].strip()
            continue
        if current_key == repo_key and stripped.startswith('relative_path:'):
            relative = stripped.split(':', 1)[1].strip()
            break
print(relative.strip())
PYREPOS
)
if [[ -z "$analysis_root" ]]; then
  echo "[ERROR] Unable to determine analysis_root." >&2
  exit 5
fi
analysis_root=${{analysis_root%/}}
user_dir=$(whoami)
clone_root="$analysis_root/$user_dir/$DESTINATION"
repo_path="$clone_root/$repo_relative"
mkdir -p "$clone_root"
if [[ ! -d "$repo_path/.git" ]]; then
  echo "[INFO] Cloning $REPO_KEY into $clone_root via day-clone..."
  day-clone --destination "$DESTINATION" --repository "$REPO_KEY" --which-one "$TRANSPORT"
else
  echo "[INFO] Repository already exists at $repo_path; skipping clone."
fi
if [[ ! -d "$repo_path" ]]; then
  echo "[ERROR] day-clone did not create $REPO_KEY at $repo_path" >&2
  exit 6
fi
cd "$repo_path"
mkdir -p config
cp "$STAGE_SAMPLES" config/samples.tsv
cp "$STAGE_UNITS" config/units.tsv

if [[ ! -f dyoainit ]]; then
  echo "[ERROR] dyoainit not found within $repo_path" >&2
  exit 7
fi

declare -a dyoa_args=()
if [[ -n "$PROJECT_VALUE" ]]; then
  dyoa_args+=(--project "$PROJECT_VALUE")
fi
if [[ "$SKIP_PROJECT_CHECK" == "true" ]]; then
  dyoa_args+=(--skip-project-check)
fi
. dyoainit "${{dyoa_args[@]}}"
dy-a slurm {shlex.quote(args.genome)} remote

echo "[INFO] Launching workflow: $DY_COMMAND"
set +e
eval "$DY_COMMAND"
workflow_status=$?
set -e
echo "[INFO] Workflow exited with status $workflow_status"
if [[ $workflow_status -ne 0 ]]; then
  echo "[ERROR] Workflow failed with status $workflow_status" >&2
fi
echo "[INFO] Attach with: tmux attach -t $SESSION_NAME"
exec bash
"""

    # The pipeline script is shipped through a quoted heredoc so nothing in it
    # is expanded while being written to the temp file. The doubled
    # backslashes emit a literal \" so the path handed to ``source`` stays
    # quoted inside the command string given to tmux.
    tmux_script = f"""
set -euo pipefail
SESSION_NAME={shlex.quote(args.session_name)}
if tmux has-session -t "$SESSION_NAME" 2>/dev/null; then
  echo "__DAYLILY_ERROR__=session_exists"
  exit 8
fi
work_script=$(mktemp)
cat <<'PAYLOAD' > "$work_script"
{pipeline_script}
PAYLOAD
tmux new-session -d -s "$SESSION_NAME" "bash -lc 'source \\"$work_script\\"'"
echo "__DAYLILY_SESSION__=$SESSION_NAME"
echo "__DAYLILY_WORK_SCRIPT__=$work_script"
"""

    result = run_remote_script(pem_file, headnode_ip, tmux_script)
    if result.stdout:
        print(result.stdout, end="")
    if result.stderr:
        print(result.stderr, file=sys.stderr, end="")

    # Success is signalled by the session marker; an error marker or the
    # absence of both is treated as failure.
    for line in result.stdout.splitlines():
        if line.startswith("__DAYLILY_ERROR__="):
            raise CommandError(line.split("=", 1)[1])
        if line.startswith("__DAYLILY_SESSION__="):
            session_name = line.split("=", 1)[1].strip()
            print(f"Tmux session '{session_name}' created on the head node.")
            print(
                f"Attach with: ssh -i {pem_file} ubuntu@{headnode_ip} "
                f"'tmux attach -t {session_name}'"
            )
            break
    else:
        raise CommandError("Tmux session creation did not report success.")

    return 0


if __name__ == "__main__":  # pragma: no cover
    # Script entry point: expected failures are reported as a one-line
    # message on stderr with exit status 1 instead of a traceback.
    try:
        exit_status = main()
    except CommandError as failure:
        print(f"Error: {failure}", file=sys.stderr)
        exit_status = 1
    raise SystemExit(exit_status)

