#!/usr/bin/env python3
"""Stage analysis samples on a head node from the developer workstation.

This helper reads an ``analysis_samples.tsv`` file, validates the referenced
inputs, stages the data into the FSx-backed staging directory via the AWS CLI,
and generates ``samples.tsv``/``units.tsv`` manifests that match the head node
workflow.  The staging directory created on FSx follows the pattern
``/fsx/data/staged_sample_data/remote_stage_<timestamp>/``.
"""

from __future__ import annotations

import argparse
import csv
import datetime as dt
import json
import os
import subprocess
import sys
import tempfile
import uuid
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple


RUN_ID = "RUN_ID"
SAMPLE_ID = "SAMPLE_ID"
EXPERIMENT_ID = "EXPERIMENTID"
SAMPLE_TYPE = "SAMPLE_TYPE"
LIB_PREP = "LIB_PREP"
SEQ_VENDOR = "SEQ_VENDOR"
SEQ_PLATFORM = "SEQ_PLATFORM"
LANE = "LANE"
SEQBC_ID = "SEQBC_ID"
PATH_TO_CONCORDANCE = "PATH_TO_CONCORDANCE_DATA_DIR"
R1_FQ = "R1_FQ"
R2_FQ = "R2_FQ"
STAGE_DIRECTIVE = "STAGE_DIRECTIVE"
STAGE_TARGET = "STAGE_TARGET"
SUBSAMPLE_PCT = "SUBSAMPLE_PCT"
IS_POS_CTRL = "IS_POS_CTRL"
IS_NEG_CTRL = "IS_NEG_CTRL"
N_X = "N_X"
N_Y = "N_Y"
EXTERNAL_SAMPLE_ID = "EXTERNAL_SAMPLE_ID"

KEY_FIELDS = [
    RUN_ID,
    SAMPLE_ID,
    EXPERIMENT_ID,
    SAMPLE_TYPE,
    LIB_PREP,
    SEQ_VENDOR,
    SEQ_PLATFORM,
]

DERIVED_UNITS_FIELDS = {
    "RUNID",
    "SAMPLEID",
    "EXPERIMENTID",
    "LANEID",
    "BARCODEID",
    "LIBPREP",
    "SEQ_VENDOR",
    "SEQ_PLATFORM",
    "ILMN_R1_PATH",
    "ILMN_R2_PATH",
    "PACBIO_R1_PATH",
    "PACBIO_R2_PATH",
    "ONT_R1_PATH",
    "ONT_R2_PATH",
    "UG_R1_PATH",
    "UG_R2_PATH",
    "SUBSAMPLE_PCT",
    "SAMPLEUSE",
    "BWA_KMER",
}

UNITS_HEADER = [
    "RUNID",
    "SAMPLEID",
    "EXPERIMENTID",
    "LANEID",
    "BARCODEID",
    "LIBPREP",
    "SEQ_VENDOR",
    "SEQ_PLATFORM",
    "ILMN_R1_PATH",
    "ILMN_R2_PATH",
    "PACBIO_R1_PATH",
    "PACBIO_R2_PATH",
    "ONT_R1_PATH",
    "ONT_R2_PATH",
    "UG_R1_PATH",
    "UG_R2_PATH",
    "SUBSAMPLE_PCT",
    "SAMPLEUSE",
    "BWA_KMER",
    "DEEP_MODEL",
    "ULTIMA_CRAM",
    "ULTIMA_CRAM_ALIGNER",
    "ULTIMA_CRAM_SNV_CALLER",
    "ONT_CRAM",
    "ONT_CRAM_ALIGNER",
    "ONT_CRAM_SNV_CALLER",
    "PB_BAM",
    "PB_BAM_ALIGNER",
    "PB_BAM_SNV_CALLER",
]

SAMPLES_HEADER = [
    "SAMPLEID",
    "SAMPLESOURCE",
    "SAMPLECLASS",
    "BIOLOGICAL_SEX",
    "CONCORDANCE_CONTROL_PATH",
    "IS_POSITIVE_CONTROL",
    "IS_NEGATIVE_CONTROL",
    "SAMPLE_TYPE",
    "TUM_NRM_SAMPLEID_MATCH",
    "EXTERNAL_SAMPLE_ID",
    "N_X",
    "N_Y",
    "TRUTH_DATA_DIR",
]


class CommandError(RuntimeError):
    """Raised when an external command fails."""


@dataclass(frozen=True)
class AwsConfig:
    profile: str
    region: Optional[str]


@dataclass
class StagePaths:
    remote_fsx_root: str
    remote_stage_name: str
    remote_fsx_stage: str
    remote_s3_stage: str


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Stage analysis samples into FSx from the local workstation.",
    )
    parser.add_argument("analysis_samples", help="Path to analysis_samples.tsv")
    parser.add_argument(
        "--stage-target",
        default="/data/staged_sample_data",
        help="FSx staging base directory (default: %(default)s)",
    )
    parser.add_argument(
        "--reference-bucket",
        required=True,
        help="S3 URI (s3://bucket[/prefix]) mapped to the FSx data repository",
    )
    parser.add_argument(
        "--config-dir",
        help="Directory to place generated samples.tsv/units.tsv (default: TSV dir)",
    )
    parser.add_argument(
        "--profile",
        default=os.environ.get("AWS_PROFILE"),
        help="AWS CLI profile to use (default: $AWS_PROFILE)",
    )
    parser.add_argument(
        "--region",
        default=os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION"),
        help="AWS region to use for CLI commands (defaults to environment)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print AWS CLI commands before execution",
    )
    return parser.parse_args(argv)


def ensure_profile(profile: Optional[str]) -> str:
    if not profile:
        raise CommandError(
            "AWS profile is required. Set AWS_PROFILE or pass --profile."
        )
    return profile


def parse_s3_uri(uri: str) -> Tuple[str, str]:
    if not uri.startswith("s3://"):
        raise CommandError(f"Expected an s3:// URI, received: {uri}")
    without_scheme = uri[5:]
    if "/" in without_scheme:
        bucket, key = without_scheme.split("/", 1)
    else:
        bucket, key = without_scheme, ""
    return bucket, key


def normalise_stage_target(stage_target: str) -> str:
    stage_target = stage_target.strip()
    if not stage_target.startswith("/data"):
        raise CommandError(
            "Stage target must be an FSx path (expected to start with /data)."
        )
    return stage_target.rstrip("/")


def build_stage_paths(stage_target: str, bucket_uri: str) -> StagePaths:
    stage_target = normalise_stage_target(stage_target)
    timestamp = dt.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    remote_stage_name = f"remote_stage_{timestamp}"
    remote_fsx_stage = f"{stage_target}/{remote_stage_name}"

    bucket, prefix = parse_s3_uri(bucket_uri.rstrip("/"))
    prefix = prefix.rstrip("/")
    fsx_relative = stage_target.lstrip("/")
    if prefix:
        remote_s3_stage = f"s3://{bucket}/{prefix}/{fsx_relative}/{remote_stage_name}"
    else:
        remote_s3_stage = f"s3://{bucket}/{fsx_relative}/{remote_stage_name}"
    return StagePaths(
        remote_fsx_root=stage_target,
        remote_stage_name=remote_stage_name,
        remote_fsx_stage=remote_fsx_stage,
        remote_s3_stage=remote_s3_stage,
    )


def build_aws_env(config: AwsConfig) -> Dict[str, str]:
    env = dict(os.environ)
    env["AWS_PROFILE"] = config.profile
    if config.region:
        env["AWS_REGION"] = config.region
        env.setdefault("AWS_DEFAULT_REGION", config.region)
    return env


def run_command(
    command: Sequence[str],
    *,
    env: Optional[Dict[str, str]] = None,
    check: bool = True,
    capture_output: bool = False,
) -> subprocess.CompletedProcess:
    try:
        return subprocess.run(  # type: ignore[arg-type]
            list(command),
            env=env,
            check=check,
            text=True,
            capture_output=capture_output,
        )
    except subprocess.CalledProcessError as exc:  # pragma: no cover - runtime path
        stdout = exc.stdout or ""
        stderr = exc.stderr or ""
        message = f"Command failed ({exc.returncode}): {' '.join(command)}"
        if stdout:
            message += f"\nSTDOUT:\n{stdout.strip()}"
        if stderr:
            message += f"\nSTDERR:\n{stderr.strip()}"
        raise CommandError(message) from exc


def aws_command(
    args: Sequence[str],
    *,
    aws_env: Dict[str, str],
    debug: bool = False,
    capture_output: bool = False,
) -> subprocess.CompletedProcess:
    command = ["aws", *args]
    if debug:
        print("[DEBUG]", " ".join(command))
    return run_command(command, env=aws_env, capture_output=capture_output)


def check_local_path(path: str) -> None:
    if not os.path.exists(path):
        raise CommandError(f"Local path not found: {path}")


def check_s3_path(uri: str, *, aws_env: Dict[str, str], debug: bool) -> None:
    args = ["s3", "ls", uri]
    result = aws_command(args, aws_env=aws_env, debug=debug, capture_output=True)
    if not result.stdout.strip():
        raise CommandError(f"S3 object or prefix not accessible: {uri}")


def validate_subsample_pct(value: str) -> str:
    if not value:
        return "na"
    try:
        pct = float(value)
    except ValueError:
        return "na"
    return value if 0.0 < pct < 1.0 else "na"


def build_concordance_reference_uri(path: str, reference_bucket: str) -> str:
    """Translate an FSx concordance path to the backing S3 URI."""

    relative = path[len("/fsx/") :]
    return f"{reference_bucket.rstrip('/')}/{relative.lstrip('/')}"


def check_concordance_path(
    path: str,
    *,
    reference_bucket: str,
    aws_env: Dict[str, str],
    debug: bool,
) -> None:
    if path.startswith("/fsx/"):
        s3_uri = build_concordance_reference_uri(path, reference_bucket)
        check_s3_path(s3_uri, aws_env=aws_env, debug=debug)
    else:
        check_local_path(os.path.expanduser(path))


def validate_sources(
    sources: Iterable[Tuple[str, bool]],
    *,
    reference_bucket: str,
    aws_env: Dict[str, str],
    debug: bool,
) -> None:
    for src, is_concordance in sources:
        if not src or src.lower() == "na":
            continue
        if src.startswith("s3://"):
            check_s3_path(src, aws_env=aws_env, debug=debug)
            continue
        if is_concordance:
            check_concordance_path(
                src,
                reference_bucket=reference_bucket,
                aws_env=aws_env,
                debug=debug,
            )
            continue
        check_local_path(os.path.expanduser(src))


def ensure_remote_stage_writable(stage: StagePaths, *, aws_env: Dict[str, str], debug: bool) -> None:
    with tempfile.NamedTemporaryFile("w", delete=False) as handle:
        handle.write("daylily staging write test\n")
        handle.flush()
        temp_path = handle.name
    dest = f"{stage.remote_s3_stage}/_write_test.txt"
    try:
        aws_command(["s3", "cp", temp_path, dest], aws_env=aws_env, debug=debug)
    finally:
        os.unlink(temp_path)
    aws_command(["s3", "rm", dest], aws_env=aws_env, debug=debug)


def safe_int(value: str, default: int = 0) -> int:
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def determine_sex(n_x: int, n_y: int) -> str:
    if n_x == 2 and n_y == 0:
        return "female"
    if n_x == 1 and n_y == 1:
        return "male"
    return "na"


def get_entry_value(entry: Dict[str, str], field: str, default: str = "") -> str:
    return (entry.get(field, default) or default).strip()


def normalise_identifier(value: str) -> str:
    return value.replace("_", "-")


def aws_copy(
    source: str,
    destination: str,
    *,
    aws_env: Dict[str, str],
    debug: bool,
    recursive: bool = False,
) -> None:
    args = ["s3", "cp", source, destination]
    if recursive:
        args.append("--recursive")
    aws_command(args, aws_env=aws_env, debug=debug)


def ensure_s3_objects(
    sources: Sequence[str],
    *,
    dest_s3_dir: str,
    sample_prefix: str,
    aws_env: Dict[str, str],
    debug: bool,
) -> Tuple[List[str], List[str]]:
    uploaded: List[str] = []
    resolved: List[str] = []
    try:
        for idx, source in enumerate(sources, start=1):
            if source.startswith("s3://"):
                resolved.append(source)
                continue
            expanded = os.path.expanduser(source)
            part_name = (
                f"{sample_prefix}_part{idx}_{uuid.uuid4().hex}_{Path(expanded).name}"
            )
            remote = f"{dest_s3_dir}/_parts/{part_name}"
            aws_copy(expanded, remote, aws_env=aws_env, debug=debug)
            uploaded.append(remote)
            resolved.append(remote)
    except Exception:
        cleanup_s3_objects(uploaded, aws_env=aws_env, debug=debug)
        raise
    return resolved, uploaded


def cleanup_s3_objects(
    uris: Sequence[str], *, aws_env: Dict[str, str], debug: bool
) -> None:
    for uri in uris:
        try:
            aws_command(["s3", "rm", uri], aws_env=aws_env, debug=debug)
        except CommandError:
            pass


def multipart_concatenate(
    sources: Sequence[str],
    destination: str,
    *,
    aws_env: Dict[str, str],
    debug: bool,
) -> None:
    if not sources:
        raise CommandError("No sources provided for multipart concatenation")

    bucket, key = parse_s3_uri(destination)
    create = aws_command(
        ["s3api", "create-multipart-upload", "--bucket", bucket, "--key", key],
        aws_env=aws_env,
        debug=debug,
        capture_output=True,
    )
    upload_id = json.loads(create.stdout or "{}").get("UploadId")
    if not upload_id:
        raise CommandError(
            f"Failed to initiate multipart upload for destination {destination}"
        )

    parts: List[Dict[str, str]] = []
    try:
        for idx, source in enumerate(sources, start=1):
            src_bucket, src_key = parse_s3_uri(source)
            copy_source = f"{src_bucket}/{src_key}"
            result = aws_command(
                [
                    "s3api",
                    "upload-part-copy",
                    "--bucket",
                    bucket,
                    "--key",
                    key,
                    "--part-number",
                    str(idx),
                    "--upload-id",
                    upload_id,
                    "--copy-source",
                    copy_source,
                ],
                aws_env=aws_env,
                debug=debug,
                capture_output=True,
            )
            payload = json.loads(result.stdout or "{}")
            etag = (
                payload.get("CopyPartResult", {}).get("ETag")
                if isinstance(payload, dict)
                else None
            )
            if not etag:
                raise CommandError(
                    f"Failed to copy part from {source} during multipart upload"
                )
            parts.append({"PartNumber": idx, "ETag": etag})

        complete_body = json.dumps({"Parts": parts})
        aws_command(
            [
                "s3api",
                "complete-multipart-upload",
                "--bucket",
                bucket,
                "--key",
                key,
                "--upload-id",
                upload_id,
                "--multipart-upload",
                complete_body,
            ],
            aws_env=aws_env,
            debug=debug,
        )
    except Exception:
        try:
            aws_command(
                [
                    "s3api",
                    "abort-multipart-upload",
                    "--bucket",
                    bucket,
                    "--key",
                    key,
                    "--upload-id",
                    upload_id,
                ],
                aws_env=aws_env,
                debug=debug,
            )
        except CommandError:
            pass
        raise


def stage_concordance(
    source: str,
    dest_fsx: str,
    dest_s3: str,
    *,
    aws_env: Dict[str, str],
    debug: bool,
) -> str:
    if source.lower() == "na" or source.startswith("/fsx/data"):
        return source
    if source.startswith("s3://"):
        aws_copy(source, dest_s3, aws_env=aws_env, debug=debug, recursive=True)
    else:
        expanded = os.path.expanduser(source)
        if os.path.isdir(expanded):
            aws_copy(expanded, dest_s3, aws_env=aws_env, debug=debug, recursive=True)
        else:
            aws_copy(expanded, dest_s3, aws_env=aws_env, debug=debug)
    return dest_fsx


def stage_single_lane(
    r1: str,
    r2: str,
    dest_fsx_dir: str,
    dest_s3_dir: str,
    *,
    aws_env: Dict[str, str],
    debug: bool,
) -> Tuple[str, str]:
    r1_name = os.path.basename(r1)
    r2_name = os.path.basename(r2)
    remote_r1_fsx = f"{dest_fsx_dir}/{r1_name}"
    remote_r2_fsx = f"{dest_fsx_dir}/{r2_name}"
    remote_r1_s3 = f"{dest_s3_dir}/{r1_name}"
    remote_r2_s3 = f"{dest_s3_dir}/{r2_name}"
    aws_copy(r1, remote_r1_s3, aws_env=aws_env, debug=debug)
    aws_copy(r2, remote_r2_s3, aws_env=aws_env, debug=debug)
    return remote_r1_fsx, remote_r2_fsx


def stage_multi_lane(
    r1_files: Sequence[str],
    r2_files: Sequence[str],
    sample_prefix: str,
    dest_fsx_dir: str,
    dest_s3_dir: str,
    *,
    aws_env: Dict[str, str],
    debug: bool,
) -> Tuple[str, str]:
    merged_r1_name = f"{sample_prefix}_merged_R1.fastq.gz"
    merged_r2_name = f"{sample_prefix}_merged_R2.fastq.gz"
    remote_r1_s3 = f"{dest_s3_dir}/{merged_r1_name}"
    remote_r2_s3 = f"{dest_s3_dir}/{merged_r2_name}"

    r1_sources, r1_uploaded = ensure_s3_objects(
        r1_files,
        dest_s3_dir=dest_s3_dir,
        sample_prefix=f"{sample_prefix}_R1",
        aws_env=aws_env,
        debug=debug,
    )
    try:
        r2_sources, r2_uploaded = ensure_s3_objects(
            r2_files,
            dest_s3_dir=dest_s3_dir,
            sample_prefix=f"{sample_prefix}_R2",
            aws_env=aws_env,
            debug=debug,
        )
    except Exception:
        cleanup_s3_objects(r1_uploaded, aws_env=aws_env, debug=debug)
        raise

    try:
        multipart_concatenate(
            r1_sources, remote_r1_s3, aws_env=aws_env, debug=debug
        )
        multipart_concatenate(
            r2_sources, remote_r2_s3, aws_env=aws_env, debug=debug
        )
    finally:
        cleanup_s3_objects(r1_uploaded, aws_env=aws_env, debug=debug)
        cleanup_s3_objects(r2_uploaded, aws_env=aws_env, debug=debug)

    remote_r1_fsx = f"{dest_fsx_dir}/{merged_r1_name}"
    remote_r2_fsx = f"{dest_fsx_dir}/{merged_r2_name}"
    return remote_r1_fsx, remote_r2_fsx


def write_tsv(path: Path, header: Sequence[str], rows: Sequence[Dict[str, str]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=header, delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)


def deduplicate_rows(
    rows: Sequence[Dict[str, str]], header: Sequence[str]
) -> List[Dict[str, str]]:
    """Return rows with duplicate data removed, preserving order."""

    seen: set[Tuple[str, ...]] = set()
    unique_rows: List[Dict[str, str]] = []
    for row in rows:
        key = tuple(row.get(column, "") for column in header)
        if key in seen:
            continue
        seen.add(key)
        unique_rows.append(row)
    return unique_rows


def normalise_units_paths(rows: Sequence[Dict[str, str]]) -> None:
    """Ensure FSx paths include the /fsx prefix when required."""

    for row in rows:
        for field, value in list(row.items()):
            if not isinstance(value, str):
                continue
            if value.startswith("/data/"):
                row[field] = f"/fsx{value}"
            elif value == "/data":
                row[field] = "/fsx/data"


def process_samples(
    analysis_samples: Path,
    stage: StagePaths,
    *,
    reference_bucket: str,
    aws_env: Dict[str, str],
    debug: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[str], List[str]]:
    with analysis_samples.open(newline="") as ff:
        reader = csv.DictReader(ff, delimiter="\t")
        if reader.fieldnames is None:
            raise CommandError("Input TSV is missing a header row")
        missing = [field for field in KEY_FIELDS + [LANE, SEQBC_ID, R1_FQ, R2_FQ] if field not in reader.fieldnames]
        if missing:
            raise CommandError(f"Missing required columns: {', '.join(missing)}")

        grouped: Dict[Tuple[str, ...], List[Dict[str, str]]] = defaultdict(list)
        sources_to_validate: List[Tuple[str, bool]] = []
        for row in reader:
            if not row:
                continue
            normalized = {k: (v or "").strip() for k, v in row.items() if k}
            if not any(normalized.values()):
                continue
            grouped[tuple(normalized[field] for field in KEY_FIELDS)].append(normalized)
            sources_to_validate.extend(
                [
                    (normalized.get(R1_FQ, ""), False),
                    (normalized.get(R2_FQ, ""), False),
                    (normalized.get(PATH_TO_CONCORDANCE, ""), True),
                ]
            )

    validate_sources(
        sources_to_validate,
        reference_bucket=reference_bucket,
        aws_env=aws_env,
        debug=debug,
    )

    samples_rows: Dict[str, Dict[str, str]] = {}
    sampleid_to_entry: Dict[str, Tuple[str, Dict[str, str]]] = {}
    units_rows: List[Dict[str, str]] = []
    created_files: List[str] = []
    run_ids: set[str] = set()

    for key, entries in grouped.items():
        first = entries[0]
        ruid = normalise_identifier(key[0])
        sampleid = normalise_identifier(key[1])
        experiment_id = normalise_identifier(key[2])
        sample_type = normalise_identifier(key[3])
        libprep = normalise_identifier(key[4])
        vendor_value = normalise_identifier(key[5])
        vendor = vendor_value.upper()
        seq_platform = normalise_identifier(key[6])
        lane = normalise_identifier(first[LANE])
        seqbc = normalise_identifier(first[SEQBC_ID])

        composite_sample_id = f"{sampleid}-{seq_platform}-{libprep}-{sample_type}-{experiment_id}"
        sample_name = f"{ruid}_{composite_sample_id}"
        sample_prefix = f"{ruid}_{composite_sample_id}_{seqbc}_0"
        dest_fsx_dir = f"{stage.remote_fsx_stage}/{sample_prefix}"
        dest_s3_dir = f"{stage.remote_s3_stage}/{sample_prefix}"

        subsample_pct = validate_subsample_pct(get_entry_value(first, SUBSAMPLE_PCT, "na"))

        is_multi_lane = len(entries) > 1
        if is_multi_lane and lane == "0":
            raise CommandError(
                f"Invalid LANE=0 for multi-lane sample: {sample_name}"
            )

        if is_multi_lane:
            r1_files = [get_entry_value(entry, R1_FQ) for entry in entries]
            r2_files = [get_entry_value(entry, R2_FQ) for entry in entries]
            remote_r1, remote_r2 = stage_multi_lane(
                r1_files,
                r2_files,
                sample_prefix,
                dest_fsx_dir,
                dest_s3_dir,
                aws_env=aws_env,
                debug=debug,
            )
            lane_id = "0"
        else:
            r1 = get_entry_value(first, R1_FQ)
            r2 = get_entry_value(first, R2_FQ)
            remote_r1, remote_r2 = stage_single_lane(
                r1,
                r2,
                dest_fsx_dir,
                dest_s3_dir,
                aws_env=aws_env,
                debug=debug,
            )
            lane_id = lane

        created_files.extend([remote_r1, remote_r2])

        concordance_source = get_entry_value(first, PATH_TO_CONCORDANCE, "na")
        concordance_fsx = dest_fsx_dir + "/concordance_data"
        concordance_s3 = dest_s3_dir + "/concordance_data"
        concordance_path = stage_concordance(
            concordance_source,
            concordance_fsx,
            concordance_s3,
            aws_env=aws_env,
            debug=debug,
        )
        if concordance_path.startswith(stage.remote_fsx_root):
            created_files.append(concordance_path)

        units_row = {column: "" for column in UNITS_HEADER}
        units_row.update(
            {
                "RUNID": ruid,
                "SAMPLEID": sampleid,
                "EXPERIMENTID": experiment_id,
                "LANEID": lane_id,
                "BARCODEID": seqbc,
                "LIBPREP": libprep,
                "SEQ_VENDOR": vendor,
                "SEQ_PLATFORM": seq_platform,
                "SUBSAMPLE_PCT": subsample_pct,
            }
        )

        is_pos_ctrl = get_entry_value(first, IS_POS_CTRL).lower() == "true"
        units_row["SAMPLEUSE"] = get_entry_value(first, "SAMPLEUSE") or ("posControl" if is_pos_ctrl else "sample")
        units_row["BWA_KMER"] = get_entry_value(first, "BWA_KMER") or "19"

        if vendor == "ILMN":
            units_row["ILMN_R1_PATH"] = remote_r1
            units_row["ILMN_R2_PATH"] = remote_r2
        elif vendor == "ONT":
            units_row["ONT_R1_PATH"] = remote_r1
            units_row["ONT_R2_PATH"] = remote_r2
        elif vendor == "PACBIO":
            units_row["PACBIO_R1_PATH"] = remote_r1
            units_row["PACBIO_R2_PATH"] = remote_r2
        elif vendor == "UG":
            units_row["UG_R1_PATH"] = remote_r1
            units_row["UG_R2_PATH"] = remote_r2

        for field in set(first.keys()).intersection(UNITS_HEADER):
            if field in DERIVED_UNITS_FIELDS:
                continue
            value = get_entry_value(first, field)
            if value:
                units_row[field] = value

        units_rows.append(units_row)

        sex = determine_sex(
            safe_int(get_entry_value(first, N_X)),
            safe_int(get_entry_value(first, N_Y)),
        )
        samples_row = {
            "SAMPLEID": sampleid,
            "SAMPLESOURCE": sample_type,
            "SAMPLECLASS": "research",
            "BIOLOGICAL_SEX": sex,
            "CONCORDANCE_CONTROL_PATH": concordance_path,
            "IS_POSITIVE_CONTROL": get_entry_value(first, IS_POS_CTRL),
            "IS_NEGATIVE_CONTROL": get_entry_value(first, IS_NEG_CTRL),
            "SAMPLE_TYPE": sample_type,
            "TUM_NRM_SAMPLEID_MATCH": "na",
            "EXTERNAL_SAMPLE_ID": get_entry_value(first, EXTERNAL_SAMPLE_ID) or "na",
            "N_X": get_entry_value(first, N_X),
            "N_Y": get_entry_value(first, N_Y),
            "TRUTH_DATA_DIR": concordance_path,
        }

        existing = samples_rows.get(sample_name)
        if existing and existing != samples_row:
            raise CommandError(
                f"Conflicting metadata for sample {sample_name}."
            )
        sampleid_entry = sampleid_to_entry.get(sampleid)
        if sampleid_entry and sampleid_entry[1] != samples_row:
            raise CommandError(
                f"Duplicate SAMPLEID with conflicting metadata: {sampleid}"
            )
        samples_rows[sample_name] = samples_row
        sampleid_to_entry[sampleid] = (sample_name, samples_row)

        run_ids.add(ruid)

    sorted_samples = [samples_rows[name] for name in sorted(samples_rows.keys())]
    return sorted_samples, units_rows, sorted(created_files), sorted(run_ids)


def main(argv: Optional[Sequence[str]] = None) -> int:
    args = parse_args(argv)
    analysis_samples = Path(args.analysis_samples).expanduser().resolve()
    if not analysis_samples.exists():
        raise CommandError(f"Analysis samples TSV not found: {analysis_samples}")

    aws_config = AwsConfig(profile=ensure_profile(args.profile), region=args.region)
    aws_env = build_aws_env(aws_config)

    stage = build_stage_paths(args.stage_target, args.reference_bucket)
    ensure_remote_stage_writable(stage, aws_env=aws_env, debug=args.debug)

    samples_rows, units_rows, created_files, run_ids = process_samples(
        analysis_samples,
        stage,
        reference_bucket=args.reference_bucket,
        aws_env=aws_env,
        debug=args.debug,
    )

    timestamp = stage.remote_stage_name.replace("remote_stage_", "")
    samples_filename = f"{timestamp}_samples.tsv"
    units_filename = f"{timestamp}_units.tsv"

    if args.config_dir:
        config_dir = Path(args.config_dir).expanduser()
    else:
        config_dir = analysis_samples.parent

    samples_path = config_dir / samples_filename
    units_path = config_dir / units_filename
    unique_samples_rows = deduplicate_rows(samples_rows, SAMPLES_HEADER)
    normalise_units_paths(units_rows)

    write_tsv(samples_path, SAMPLES_HEADER, unique_samples_rows)
    write_tsv(units_path, UNITS_HEADER, units_rows)

    remote_samples_path = f"{stage.remote_s3_stage}/{samples_filename}"
    remote_units_path = f"{stage.remote_s3_stage}/{units_filename}"

    aws_copy(str(samples_path), remote_samples_path, aws_env=aws_env, debug=args.debug)
    aws_copy(str(units_path), remote_units_path, aws_env=aws_env, debug=args.debug)

    print("Remote staging completed successfully.")
    print(f"Remote FSx stage directory: {stage.remote_fsx_stage}")
    print(f"Staged files ({len(created_files)}):")
    for path in created_files:
        print(f"  {path}")
    print("Generated configuration files:")
    print(f"  samples.tsv -> {remote_samples_path}")
    print(f"  units.tsv   -> {remote_units_path}")
    return 0


if __name__ == "__main__":  # pragma: no cover - manual execution
    try:
        raise SystemExit(main())
    except CommandError as exc:
        print(f"Error: {exc}", file=sys.stderr)
        raise SystemExit(1)
