#!/usr/bin/env bash
set -euo pipefail

# ---------------------------------------------------------------------------
# Command line and environment configuration.
#
# Positional arguments:
#   $1  world size (number of ranks to start)
#   $2  executable to run for every rank
#   $3+ extra arguments forwarded unchanged to the executable
#
# Environment overrides (defaults in parentheses):
#   BACKEND (gloo), MASTER_ADDR (127.0.0.1), MASTER_PORT (empty),
#   BLOCKS_PER_PROCESS (1), NODE_RANK (0), LOCAL_WORLD_SIZE (WORLD_SIZE),
#   PD_RUN_TIMEOUT (10), CUDA_VISIBLE_DEVICES (0,1,2,3)
# ---------------------------------------------------------------------------
if [ "$#" -lt 2 ]; then
  # Diagnostics belong on stderr so they never mix with the ranks' stdout.
  echo "Usage: $0 <world_size> <executable>" >&2
  exit 1
fi

WORLD_SIZE=${1:-2}
EXECUTABLE=${2}
if [ "$#" -gt 2 ]; then
  ALL_OTHER_ARGS=("${@:3}")
else
  ALL_OTHER_ARGS=()
fi

# Check the digits first: `[ abc -le 0 ]` is an error, and inside `if` an
# error just reads as "false", which would let a non-numeric value through.
if ! [[ "${WORLD_SIZE}" =~ ^[0-9]+$ ]] || [ "$((10#${WORLD_SIZE}))" -le 0 ]; then
  echo "WORLD_SIZE must be positive" >&2
  exit 1
fi
# Force base 10 so a zero-padded value such as "010" is not read as octal by
# later arithmetic on WORLD_SIZE.
WORLD_SIZE=$((10#${WORLD_SIZE}))

BACKEND=${BACKEND:-gloo}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-}
BLOCKS_PER_PROCESS=${BLOCKS_PER_PROCESS:-1}
NODE_RANK=${NODE_RANK:-0}
LOCAL_WORLD_SIZE=${LOCAL_WORLD_SIZE:-$WORLD_SIZE}
LAUNCH_TIMEOUT=${PD_RUN_TIMEOUT:-10}

# If using NCCL, specify GPUs. The comma-separated list is split into
# GPU_LIST; an unset or empty value falls back to 0,1,2,3.
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3}
IFS=',' read -r -a GPU_LIST <<< "${CUDA_VISIBLE_DEVICES}"

if ! [[ "${BLOCKS_PER_PROCESS}" =~ ^[0-9]+$ ]] \
  || [ "$((10#${BLOCKS_PER_PROCESS}))" -le 0 ]; then
  echo "BLOCKS_PER_PROCESS must be positive" >&2
  exit 1
fi

#######################################
# Print a TCP port number for MASTER_PORT.
# Tries up to 32 random ports in 20000..39999 and prints the first one for
# which `ss` lists no LISTEN socket. When `ss` is not installed, or every
# candidate is busy, prints a port in the same range derived from the
# shell's PID ($$).
# Arguments:
#   None
# Outputs:
#   The chosen port number on stdout.
#######################################
choose_master_port() {
  local candidate attempt listeners
  if command -v ss >/dev/null 2>&1; then
    for ((attempt = 0; attempt < 32; attempt++)); do
      candidate=$((20000 + RANDOM % 20000))
      # Capture the listing instead of piping into `grep -q`: with
      # `set -o pipefail` a non-zero exit from ss (for instance SIGPIPE once
      # grep has seen its match) fails the whole pipeline, and the leading `!`
      # would then report a busy port as free.
      listeners=$(ss -ltn "( sport = :${candidate} )" 2>/dev/null) || true
      if [[ "${listeners}" != *LISTEN* ]]; then
        echo "${candidate}"
        return
      fi
    done
  fi
  echo $((20000 + $$ % 20000))
}

# Pick a rendezvous port automatically unless the caller supplied one.
[ -n "${MASTER_PORT}" ] || MASTER_PORT=$(choose_master_port)

# Launcher bookkeeping: counters and flags first, then the per-rank arrays.
EXIT_CODE=0
REMAINING=0
STOPPING=0
PIDS=()
STATUS_FILES=()

# Private scratch directory; each rank gets one status file inside it.
STATUS_DIR=$(mktemp -d "${TMPDIR:-/tmp}/pd-run.XXXXXX")

#######################################
# Stop every launched process and delete the status directory.
# Only the first call does any work; later calls return immediately, so the
# function is safe to reach from a signal handler and again from the EXIT
# trap. Sequence: SIGTERM to each live PID, poll for up to LAUNCH_TIMEOUT
# seconds until all are gone, SIGKILL the survivors, reap, remove STATUS_DIR.
# Globals:
#   STOPPING        (read, written) re-entrancy guard
#   PIDS            (read) PIDs to stop
#   LAUNCH_TIMEOUT  (read) grace period in seconds before SIGKILL
#   STATUS_DIR      (read) directory removed at the end
# Arguments:
#   $1 - label for the reason (EXIT, INT, TERM, FAIL); accepted for the
#        callers' benefit, not used by the function
#######################################
cleanup() {
  if [ "${STOPPING}" -ne 0 ]; then
    return
  fi
  STOPPING=1

  # Keep loop state local so the caller's variables are not overwritten.
  local pid alive
  local deadline

  for pid in "${PIDS[@]:-}"; do
    # "${PIDS[@]:-}" expands to a single empty word when PIDS is empty.
    if [ -z "${pid}" ]; then
      continue
    fi
    if kill -0 "${pid}" 2>/dev/null; then
      # The process may exit between the probe and the signal.
      kill "${pid}" 2>/dev/null || true
    fi
  done

  deadline=$((SECONDS + LAUNCH_TIMEOUT))
  while [ "${SECONDS}" -lt "${deadline}" ]; do
    alive=0
    for pid in "${PIDS[@]:-}"; do
      if [ -n "${pid}" ] && kill -0 "${pid}" 2>/dev/null; then
        alive=1
        break
      fi
    done
    if [ "${alive}" -eq 0 ]; then
      break
    fi
    sleep 0.1
  done

  # Grace period is over: force-kill whatever is still running.
  for pid in "${PIDS[@]:-}"; do
    if [ -n "${pid}" ] && kill -0 "${pid}" 2>/dev/null; then
      kill -9 "${pid}" 2>/dev/null || true
    fi
  done

  wait 2>/dev/null || true
  # Never hand rm -rf an empty path, and stop option parsing with `--`.
  if [ -n "${STATUS_DIR:-}" ]; then
    rm -rf -- "${STATUS_DIR}"
  fi
}

#######################################
# Signal handler: record the exit code, tear everything down, then exit.
# The exit code follows the shell convention 128 + signal number, so SIGINT
# yields 130 and SIGTERM yields 143. Any other label keeps 130.
# Globals:
#   EXIT_CODE (written)
# Arguments:
#   $1 - signal name without the SIG prefix (INT or TERM)
#######################################
on_signal() {
  local signal_name=${1:-INT}
  if [ "${signal_name}" = "TERM" ]; then
    EXIT_CODE=143
  else
    EXIT_CODE=130
  fi
  cleanup "${signal_name}"
  exit "${EXIT_CODE}"
}

# Always clean up on the way out; an interrupt or a termination request goes
# through on_signal, which is told which signal arrived.
trap "cleanup EXIT" EXIT
trap "on_signal TERM" TERM
trap "on_signal INT" INT

#######################################
# Start one rank of EXECUTABLE in a background wrapper subshell.
# The wrapper runs the executable with the rank's environment, waits for it,
# and writes its exit status as a single line to the status file. A SIGTERM
# sent to the wrapper (its PID is what the caller sees in $!) is forwarded to
# the executable, so stopping the wrapper also stops the rank instead of
# leaving it running without a parent. SIGKILL cannot be trapped and is
# therefore not forwarded.
# Globals (read):
#   WORLD_SIZE, LOCAL_WORLD_SIZE, NODE_RANK, BLOCKS_PER_PROCESS, BACKEND,
#   MASTER_ADDR, MASTER_PORT, CUDA_VISIBLE_DEVICES, EXECUTABLE, ALL_OTHER_ARGS
# Arguments:
#   $1 - rank, exported as RANK and PROCESS_RANK
#   $2 - local rank, exported as LOCAL_RANK
#   $3 - device id, exported as DEVICE_ID
#   $4 - path of the file that receives the exit status
#######################################
launch_rank() {
  local rank=$1
  local local_rank=$2
  local device_id=$3
  local status_file=$4

  (
    set +e
    child_pid=""
    stop_requested=0
    # Without a handler SIGTERM would end the wrapper on the spot and the
    # executable would carry on unsupervised.
    trap 'stop_requested=1
      if [ -n "${child_pid}" ]; then
        kill -TERM "${child_pid}" 2>/dev/null
      fi' TERM

    # Run in the background so the wrapper sits in the `wait` builtin, which
    # (unlike a foreground command) returns as soon as a trapped signal lands.
    # The ${arr[@]+...} form keeps an empty argument list safe under `set -u`
    # on bash releases older than 4.4.
    RANK=$rank \
    WORLD_SIZE=$WORLD_SIZE \
    PROCESS_RANK=$rank \
    PROCESS_WORLD_SIZE=$WORLD_SIZE \
    LOCAL_RANK=$local_rank \
    LOCAL_WORLD_SIZE=$LOCAL_WORLD_SIZE \
    NODE_RANK=$NODE_RANK \
    BLOCKS_PER_PROCESS=$BLOCKS_PER_PROCESS \
    BACKEND=$BACKEND \
    MASTER_ADDR=$MASTER_ADDR \
    MASTER_PORT=$MASTER_PORT \
    CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \
    DEVICE_ID=$device_id \
    "${EXECUTABLE}" ${ALL_OTHER_ARGS[@]+"${ALL_OTHER_ARGS[@]}"} &
    child_pid=$!

    # A SIGTERM that arrived before child_pid was recorded is forwarded now.
    if [ "${stop_requested}" -ne 0 ]; then
      kill -TERM "${child_pid}" 2>/dev/null
    fi

    wait "${child_pid}"
    status=$?
    # `wait` returns early (status > 128) when the trap fires; keep waiting
    # until the executable has really gone so its true status is recorded.
    while kill -0 "${child_pid}" 2>/dev/null; do
      wait "${child_pid}"
      status=$?
      # 127 means the shell no longer tracks this PID; stop rather than spin.
      if [ "${status}" -eq 127 ]; then
        break
      fi
    done
    echo "${status}" > "${status_file}"
  ) &
}

# Start ranks 0 .. WORLD_SIZE-1. Devices are handed out round-robin from
# GPU_LIST, and every rank gets its own status file under STATUS_DIR. The
# background PID and the status file are appended in lock-step, so index i
# of PIDS and STATUS_FILES always describes the same rank.
RANK=0
while (( RANK < WORLD_SIZE )); do
  LOCAL_RANK=$(( RANK % ${#GPU_LIST[@]} ))
  DEV=${GPU_LIST[LOCAL_RANK]:-0}
  STATUS_FILE=${STATUS_DIR}/rank_${RANK}.status

  launch_rank "${RANK}" "${LOCAL_RANK}" "${DEV}" "${STATUS_FILE}"
  PIDS+=("$!")
  STATUS_FILES+=("${STATUS_FILE}")

  REMAINING=$((REMAINING + 1))
  RANK=$((RANK + 1))
done

# Monitor loop: poll the per-rank status files until every rank has reported.
# The first non-zero status stops all other ranks and becomes the script's
# exit code. A rank whose wrapper has disappeared without leaving a readable
# status (killed outright, or its write failed) is treated as a failure, so
# the loop cannot wait forever for a file that will never appear.
while [ "${REMAINING}" -gt 0 ]; do
  PROGRESS=0
  for i in "${!STATUS_FILES[@]}"; do
    STATUS_FILE=${STATUS_FILES[$i]}
    # An empty slot marks a rank that has already been accounted for.
    if [ -z "${STATUS_FILE}" ]; then
      continue
    fi

    WAIT_STATUS=""
    if [ ! -f "${STATUS_FILE}" ] || ! read -r WAIT_STATUS < "${STATUS_FILE}"; then
      # Nothing readable yet: normal while the wrapper is still running.
      if kill -0 "${PIDS[$i]}" 2>/dev/null; then
        continue
      fi
      # The wrapper is gone. It may have written the file just before
      # exiting, so look once more before declaring the rank lost.
      if [ ! -f "${STATUS_FILE}" ] || ! read -r WAIT_STATUS < "${STATUS_FILE}"; then
        WAIT_STATUS=0
        wait "${PIDS[$i]}" 2>/dev/null || WAIT_STATUS=$?
        if [ "${WAIT_STATUS}" -eq 0 ]; then
          WAIT_STATUS=1
        fi
        echo "rank ${i} exited without reporting a status" >&2
      fi
    fi

    # Anything that is not a plain number cannot be compared; count it as a
    # failure instead of letting the numeric test error out and pass it.
    case "${WAIT_STATUS}" in
      '' | *[!0-9]*) WAIT_STATUS=1 ;;
    esac

    rm -f "${STATUS_FILE}"
    STATUS_FILES[$i]=""
    PROGRESS=1
    REMAINING=$((REMAINING - 1))
    if [ "${WAIT_STATUS}" -ne 0 ]; then
      EXIT_CODE=${WAIT_STATUS}
      cleanup FAIL
      exit "${EXIT_CODE}"
    fi
  done

  # Back off briefly only when a full pass found nothing new.
  if [ "${PROGRESS}" -eq 0 ]; then
    sleep 0.1
  fi
done

# Every rank succeeded: drop the handlers and remove the scratch directory.
trap - EXIT INT TERM
rm -rf "${STATUS_DIR}"
