#!/usr/bin/env bash
#
# Launch <world_size> copies of <executable> on this host as background
# processes, one per rank, then wait for all of them to finish.
#
# Usage: <script> <world_size> <executable> [args...]
#   world_size  number of processes to start (non-negative integer)
#   executable  program to run for every rank (a single word; pass any
#               further command words as args)
#   args...     forwarded unchanged to every rank
#
# Optional environment (defaults in parentheses):
#   BACKEND (gloo), MASTER_ADDR (127.0.0.1), MASTER_PORT (29500),
#   CUDA_VISIBLE_DEVICES (0,1,2,3)
#
# Every rank is started with RANK, WORLD_SIZE, LOCAL_RANK, BACKEND,
# MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES in its environment.
# LOCAL_RANK is RANK modulo the number of entries in CUDA_VISIBLE_DEVICES.
#
# Exit status: 0 if every rank exited 0; 1 on a usage error or if any rank
# exited non-zero.
set -euo pipefail

if [ "$#" -lt 2 ]; then
  echo "Usage: $0 <world_size> <executable> [args...]" >&2
  exit 1
fi

WORLD_SIZE=$1
if ! [[ "$WORLD_SIZE" =~ ^[0-9]+$ ]]; then
  echo "Error: world_size must be a non-negative integer, got '$WORLD_SIZE'" >&2
  exit 1
fi
# Force base 10 so a value such as "08" is not rejected as invalid octal.
WORLD_SIZE=$(( 10#$WORLD_SIZE ))

# Keep the command as an array so that arguments containing spaces or glob
# characters reach the executable exactly as they were passed to this script.
CMD=("$2" "${@:3}")

BACKEND=${BACKEND:-gloo}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-29500}

# If using NCCL, specify GPUs
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3}

IFS=',' read -r -a GPU_LIST <<< "${CUDA_VISIBLE_DEVICES}"
NUM_GPUS=${#GPU_LIST[@]}

#echo "Launching WORLD_SIZE=$WORLD_SIZE BACKEND=$BACKEND"

PIDS=()
for (( rank = 0; rank < WORLD_SIZE; rank++ )); do
  # Wrap around when there are more ranks than listed devices.
  local_rank=$(( rank % NUM_GPUS ))

  #echo "Launching RANK=$rank LOCAL_RANK=$local_rank"

  RANK=$rank \
  WORLD_SIZE=$WORLD_SIZE \
  LOCAL_RANK=$local_rank \
  BACKEND=$BACKEND \
  MASTER_ADDR=$MASTER_ADDR \
  MASTER_PORT=$MASTER_PORT \
  CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \
  "${CMD[@]}" &
  PIDS+=("$!")
done

# A bare `wait` returns 0 regardless of how the children exited, so wait on
# each PID individually and remember whether any rank failed. The `if` keeps
# `set -e` from aborting before the remaining ranks have been reaped.
STATUS=0
# The ${PIDS[@]+...} form avoids an "unbound variable" error under `set -u`
# on older bash versions when no rank was started (world_size 0).
for pid in ${PIDS[@]+"${PIDS[@]}"}; do
  if ! wait "$pid"; then
    STATUS=1
  fi
done

#echo "All ranks finished."
exit "$STATUS"
