#!/usr/bin/env bash
# sg-qe-bench-qe-vs-ngc: run a Quantum ESPRESSO pw.x input with a natively
# built pw.x (and, optionally, the pw.x inside an NGC container image) over a
# list of MPI rank counts. Per-run output, command logs and nvidia-smi samples
# go under <bench-root>/work and <bench-root>/logs; both are zipped at the end.
#
# -e is not set: a failing run is recorded (rc / job_done in the summaries)
# and the sweep continues with the next case.
set -u -o pipefail

# Print command-line help to stdout.
usage() {
  cat <<USAGE
Usage:
  sg-qe-bench-qe-vs-ngc [--preset epw_metal_bench_heavy]
                        [--bench-root DIR] [--input FILE]
                        [--np1 N] [--np4 N]
                        [--auto-scale|--no-auto-scale] [--max-scale N]
                        [--native-qe-prefix DIR]
                        [--with-ngc]
                        [--ngc-image IMAGE] [--ngc-pw PATH]
                        [--mca-profile ob1-tcp-eth0|ob1-tcp|ucx|smcuda]

Defaults:
  --preset             epw_metal_bench_heavy
  --np1                1 (used only with --no-auto-scale)
  --np4                4 (used only with --no-auto-scale)
  --auto-scale         on (1,2,4,8,... up to min(GPU count, --max-scale))
  --max-scale          8
  --native-qe-prefix   \$HOME/.local/sg/qe-gpu-src/qe-7.5
  --with-ngc           off (default native-only)
  --ngc-image          (unset; providing it implies --with-ngc; NGC runs are
                        skipped when no image is given)
  --ngc-pw             /usr/local/qe/bin/pw.x
  --mca-profile        ob1-tcp-eth0 (stable default; applies to NGC runs only)

Notes:
  - fixed comparison: npN uses -nk N
  - mpirun pinning is unified: --bind-to core --map-by slot --rank-by slot
  - rank->GPU is fixed by OMPI local rank (CUDA_VISIBLE_DEVICES=local_rank)
  - nvidia-smi log is sampled every 1s per run.
USAGE
}

# Print "ERROR: <message>" to stderr and terminate the script with status 1.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

# ---- Defaults (each can be overridden on the command line) -----------------

# Input selection.
PRESET="epw_metal_bench_heavy"
INPUT=""
BENCH_ROOT=""

# Rank counts: NP1/NP4 are the fixed pair used with --no-auto-scale;
# otherwise powers of two up to min(GPU count, MAX_SCALE) are run.
AUTO_SCALE=1
MAX_SCALE=8
NP1=1
NP4=4

# Native QE installation prefix (pw.x is expected under bin/).
NATIVE_QE_PREFIX="${HOME}/.local/sg/qe-gpu-src/qe-7.5"

# NGC container comparison (off unless --with-ngc / --ngc-image is given).
WITH_NGC=0
NGC_IMAGE=""
NGC_PW="/usr/local/qe/bin/pw.x"
MCA_PROFILE="ob1-tcp-eth0"
MCA_PROFILE_SET=0   # set to 1 when --mca-profile is passed explicitly

# Output naming; a preset may override both.
RUN_TAG="bench_qe_vs_ngc"
LABEL_PREFIX="bench"

# Abort like an unknown option (message, usage, exit 2) when option $1 has no
# value. $2 is the number of arguments remaining, including the option itself.
# Without this, a trailing "--np1" dies with an opaque "unbound variable"
# error from $2 under set -u.
require_value() {
  if (( $2 < 2 )); then
    echo "ERROR: option $1 requires a value" >&2
    usage
    exit 2
  fi
}

while [[ $# -gt 0 ]]; do
  case "$1" in
    --preset) require_value "$1" "$#"; PRESET="$2"; shift 2;;
    --bench-root) require_value "$1" "$#"; BENCH_ROOT="$2"; shift 2;;
    --input) require_value "$1" "$#"; INPUT="$2"; shift 2;;
    --np1) require_value "$1" "$#"; NP1="$2"; shift 2;;
    --np4) require_value "$1" "$#"; NP4="$2"; shift 2;;
    --auto-scale) AUTO_SCALE=1; shift 1;;
    --no-auto-scale) AUTO_SCALE=0; shift 1;;
    --max-scale) require_value "$1" "$#"; MAX_SCALE="$2"; shift 2;;
    --native-qe-prefix) require_value "$1" "$#"; NATIVE_QE_PREFIX="$2"; shift 2;;
    --with-ngc) WITH_NGC=1; shift 1;;
    # Giving an image is taken as a request for the NGC comparison.
    --ngc-image) require_value "$1" "$#"; NGC_IMAGE="$2"; WITH_NGC=1; shift 2;;
    --ngc-pw) require_value "$1" "$#"; NGC_PW="$2"; shift 2;;
    # Remember that the profile was explicit so a preset does not override it.
    --mca-profile) require_value "$1" "$#"; MCA_PROFILE="$2"; MCA_PROFILE_SET=1; shift 2;;
    -h|--help) usage; exit 0;;
    *) echo "ERROR: unknown arg: $1" >&2; usage; exit 2;;
  esac
done

# --- Validate options, apply preset defaults, prepare the run environment ---

[[ "$MAX_SCALE" =~ ^[1-9][0-9]*$ ]] || die "--max-scale must be >=1"
# --np1/--np4 are only consumed with --no-auto-scale; reject non-integers
# there so mpirun is never handed a bogus -np value.
if [[ "$AUTO_SCALE" -eq 0 ]]; then
  [[ "$NP1" =~ ^[1-9][0-9]*$ ]] || die "--np1 must be a positive integer: $NP1"
  [[ "$NP4" =~ ^[1-9][0-9]*$ ]] || die "--np4 must be a positive integer: $NP4"
fi

if [[ -n "$PRESET" ]]; then
  case "$PRESET" in
    epw_metal_bench_heavy)
      RUN_TAG="epw_metal_vs_ngc"
      LABEL_PREFIX="bench2"
      [[ -n "$BENCH_ROOT" ]] || BENCH_ROOT="/home/dl/bench/BENCH-QE-TESTCASE-PILOT-001"
      # Auto-locate the input only when --input was not given: an explicit
      # --input that does not exist must fail below instead of being silently
      # replaced by a different file.
      if [[ -z "$INPUT" ]]; then
        INPUT="${BENCH_ROOT}/work/epw_metal_vs_ngc_20260220_145445/bench/input_bench_heavy.in"
        if [[ ! -f "$INPUT" ]]; then
          # Pick the lexically last match; presumably the timestamped
          # directory names (epw_metal_vs_ngc_YYYYMMDD_HHMMSS) make that the
          # newest prepared input.
          PRESET_FOUND_INPUT="$(find "${BENCH_ROOT}/work" -maxdepth 3 -type f -path "*/epw_metal_vs_ngc_*/bench/input_bench_heavy.in" 2>/dev/null | sort | tail -n 1)"
          # Keep the default path when nothing matches, so the error below
          # names the file that was expected.
          [[ -z "$PRESET_FOUND_INPUT" ]] || INPUT="$PRESET_FOUND_INPUT"
        fi
      fi
      if [[ "$MCA_PROFILE_SET" -eq 0 ]]; then
        MCA_PROFILE="ob1-tcp-eth0"
      fi
      ;;
    *)
      die "unknown --preset: $PRESET"
      ;;
  esac
fi

# Reject an unknown MPI profile now, not after the native runs have finished.
if [[ "$WITH_NGC" -eq 1 ]]; then
  case "$MCA_PROFILE" in
    ob1-tcp-eth0|ob1-tcp|ucx|smcuda) ;;
    *) die "unknown --mca-profile: $MCA_PROFILE" ;;
  esac
fi

[[ -n "$BENCH_ROOT" ]] || die "--bench-root is required"
[[ -n "$INPUT" ]] || die "--input is required (or use --preset epw_metal_bench_heavy)"
[[ -f "$INPUT" ]] || die "input not found: $INPUT"
[[ -d "$BENCH_ROOT" ]] || die "bench-root not found: $BENCH_ROOT"

# Pseudopotentials must sit in a "pseudos" directory next to the input file;
# each case directory gets its own copy.
PSEUDO_DIR="$(cd "$(dirname "$INPUT")" && pwd)/pseudos"
[[ -d "$PSEUDO_DIR" ]] || die "pseudos dir not found: $PSEUDO_DIR"

NATIVE_PW="${NATIVE_QE_PREFIX}/bin/pw.x"
[[ -x "$NATIVE_PW" ]] || die "native pw.x not found: $NATIVE_PW"

MPIRUN="$(command -v mpirun || true)"
[[ -n "$MPIRUN" ]] || die "mpirun not found in PATH"

TS="$(date +%Y%m%d_%H%M%S)"
WORKROOT="${BENCH_ROOT}/work/${RUN_TAG}_${TS}"
LOGROOT="${BENCH_ROOT}/logs/${RUN_TAG}_${TS}"
mkdir -p "$WORKROOT" "$LOGROOT" || die "failed to create output dirs under bench-root: $BENCH_ROOT"

SUMMARY="${LOGROOT}/summary.txt"
SUMMARY_ALL="${LOGROOT}/summary_all.txt"
SUMMARY_BENCH2="${LOGROOT}/summary_bench2.txt"
: > "$SUMMARY"
: > "$SUMMARY_ALL"
: > "$SUMMARY_BENCH2"

COMMON_PINNING=(--bind-to core --map-by slot --rank-by slot)
COMMON_MPI_ENV=(OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 FFTW_NUM_THREADS=1 OMP_PROC_BIND=close OMP_PLACES=cores)
# Fall back to the NVHPC 25.7 install locations only when the caller has not
# already exported these variables and the directories exist.
NVHPC_COMM_LIBS_DEFAULT="/opt/nvidia/hpc_sdk/Linux_x86_64/25.7/comm_libs/12.9"
NVHPC_CUDA_DEFAULT="/opt/nvidia/hpc_sdk/Linux_x86_64/25.7/cuda/12.9"
if [[ -z "${NVCOMPILER_COMM_LIBS_HOME:-}" && -d "$NVHPC_COMM_LIBS_DEFAULT" ]]; then
  export NVCOMPILER_COMM_LIBS_HOME="$NVHPC_COMM_LIBS_DEFAULT"
fi
if [[ -z "${NVHPC_CUDA_HOME:-}" && -d "$NVHPC_CUDA_DEFAULT" ]]; then
  export NVHPC_CUDA_HOME="$NVHPC_CUDA_DEFAULT"
fi

# 1 when nvidia-smi is on PATH (enables GPU counting and utilisation logs).
have_nvidia_smi=0
command -v nvidia-smi >/dev/null 2>&1 && have_nvidia_smi=1

# Print the number of GPUs listed by `nvidia-smi --list-gpus`.
# Prints 1 when nvidia-smi is unavailable or reports no usable count.
detect_gpu_count() {
  local count=""
  if (( have_nvidia_smi == 1 )); then
    count="$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d '[:space:]')"
  fi
  if [[ "$count" =~ ^[1-9][0-9]*$ ]]; then
    echo "$count"
  else
    echo "1"
  fi
}

# Print the MPI rank counts to benchmark, one per line, without duplicates.
#   $1 - number of GPUs available.
# With AUTO_SCALE=0: NP1, then NP4 if it differs.
# Otherwise: 1, 2, 4, ... while <= min($1, MAX_SCALE).
build_scale_list() {
  local gpus="$1"
  local -a scales=()
  if (( AUTO_SCALE == 0 )); then
    scales=("$NP1")
    [[ "$NP4" == "$NP1" ]] || scales+=("$NP4")
  else
    local cap="$MAX_SCALE" step
    (( gpus < cap )) && cap="$gpus"
    for (( step = 1; step <= cap; step *= 2 )); do
      scales+=("$step")
    done
  fi
  printf '%s\n' "${scales[@]}" | awk '!seen[$0]++'
}

# Start a background nvidia-smi CSV sampler (1 s interval) writing to $1.
# Prints the sampler's PID, or an empty line when nvidia-smi is unavailable.
start_smi_log() {
  local csv_path="$1"
  if (( have_nvidia_smi != 1 )); then
    echo ""
    return
  fi
  local fields="timestamp,index,utilization.gpu,utilization.memory,memory.used,memory.total"
  # stdout goes to the file, so a caller's $(...) does not wait on the sampler.
  nvidia-smi --query-gpu="$fields" --format=csv -l 1 > "$csv_path" 2>/dev/null &
  echo $!
}

# Stop the sampler whose PID is $1 (no-op for an empty PID). Errors are
# ignored: the process may already be gone or may not be a child of this
# shell, in which case `wait` cannot reap it.
stop_smi_log() {
  local logger_pid="$1"
  [[ -z "$logger_pid" ]] && return 0
  kill "$logger_pid" >/dev/null 2>&1 || true
  wait "$logger_pid" >/dev/null 2>&1 || true
}

# Print the last "PWSCF : ... WALL" timing line of pw.x output file $1,
# normalised to begin with five spaces; prints nothing if no such line exists.
# Uses POSIX grep rather than ripgrep (rg): rg is not a standard tool and,
# when absent, every run silently appeared to have no timing line.
extract_wall() {
  local out="$1"
  grep -E -e 'PWSCF[[:space:]]*:.*WALL' -- "$out" \
    | tail -n 1 \
    | sed -E 's/^.*(PWSCF[[:space:]]*:.*WALL)/     \1/'
}

# Succeed (status 0) if pw.x output file $1 contains the literal "JOB DONE"
# marker. Uses POSIX grep instead of ripgrep, whose absence made every run
# look unfinished.
job_done() {
  local out="$1"
  grep -qF -e "JOB DONE" -- "$out"
}

# Append one "label|job_done=YES|NO|<wall line>|" row for a finished case.
#   $1 label, $2 exit code, $3 profile name, $4 pw.x output file, $5 command log
# The row goes to SUMMARY and SUMMARY_ALL, and also to SUMMARY_BENCH2 for
# labels starting with "bench2_". rc and profile are appended to the command log.
write_summary() {
  local label="$1" rc="$2" profile="$3" out="$4" cmdlog="$5"
  local finished="NO"
  job_done "$out" && finished="YES"
  local wall
  wall="$(extract_wall "$out")"
  local row="${label}|job_done=${finished}|${wall}|"

  local -a targets=("$SUMMARY" "$SUMMARY_ALL")
  [[ "$label" == bench2_* ]] && targets+=("$SUMMARY_BENCH2")
  local dest
  for dest in "${targets[@]}"; do
    echo "$row" >> "$dest"
  done

  printf 'rc=%s\nprofile=%s\n' "$rc" "$profile" >> "$cmdlog"
}

# Print the space-separated OMPI_MCA_* KEY=VALUE settings for profile $1,
# or an empty line for an unknown profile.
mca_env_for_profile() {
  local profile="$1" env_line=""
  case "$profile" in
    ob1-tcp-eth0) env_line="OMPI_MCA_pml=ob1 OMPI_MCA_btl=self,tcp OMPI_MCA_btl_tcp_if_include=eth0 OMPI_MCA_oob_tcp_if_include=eth0 OMPI_MCA_coll=^hcoll" ;;
    ob1-tcp)      env_line="OMPI_MCA_pml=ob1 OMPI_MCA_btl=self,tcp OMPI_MCA_coll=^hcoll" ;;
    ucx)          env_line="OMPI_MCA_pml=ucx OMPI_MCA_coll=^hcoll" ;;
    smcuda)       env_line="OMPI_MCA_pml=ob1 OMPI_MCA_btl=self,tcp,smcuda OMPI_MCA_coll=^hcoll" ;;
  esac
  printf '%s\n' "$env_line"
}

# Run one benchmark case with the natively installed pw.x under host mpirun.
#   $1 label  - case name (directory / file basename)
#   $2 np     - number of MPI ranks
#   $3 nk     - value passed to pw.x -nk
# Stages the input and pseudos into WORKROOT/<label>, writes stdout/stderr
# there, writes LOGROOT/<label>.log (command log) and
# LOGROOT/<label>_nvidia_smi.csv, then appends a summary row.
run_native_case() {
  local label="$1" np="$2" nk="$3"
  local casedir="${WORKROOT}/${label}"
  mkdir -p "$casedir"
  cp -a "$INPUT" "${casedir}/input.in"
  cp -a "$PSEUDO_DIR" "${casedir}/pseudos"

  local out="${casedir}/${label}.out"
  local err="${casedir}/${label}.err"
  local smi="${LOGROOT}/${label}_nvidia_smi.csv"
  local cmdlog="${LOGROOT}/${label}.log"

  # Environment handed to mpirun: threading/pinning knobs plus the NVHPC
  # locations when they are known.
  local -a native_env=("${COMMON_MPI_ENV[@]}")
  if [[ -n "${NVCOMPILER_COMM_LIBS_HOME:-}" ]]; then
    native_env+=("NVCOMPILER_COMM_LIBS_HOME=${NVCOMPILER_COMM_LIBS_HOME}")
  fi
  if [[ -n "${NVHPC_CUDA_HOME:-}" ]]; then
    native_env+=("NVHPC_CUDA_HOME=${NVHPC_CUDA_HOME}")
  fi

  # Per-rank script, single-quoted so each rank's own shell expands it: every
  # rank selects the GPU matching its Open MPI node-local rank.
  # shellcheck disable=SC2016
  local rank_script='cd "$SG_CASEDIR"; export CUDA_VISIBLE_DEVICES=${OMPI_COMM_WORLD_LOCAL_RANK:-0}; exec "$SG_PW" -nk "$SG_NK" -in input.in'
  # One argv used both for logging and execution, so the log cannot drift
  # from what actually runs.
  local -a cmd=(
    env "${native_env[@]}" SG_PW="$NATIVE_PW" SG_CASEDIR="$casedir" SG_NK="$nk"
    "$MPIRUN" "${COMMON_PINNING[@]}" -np "$np"
    bash -lc "$rank_script"
  )

  {
    echo "label=${label}"
    echo "cmd: mpirun ${COMMON_PINNING[*]} -np ${np} bash -lc 'cd \"${casedir}\"; export CUDA_VISIBLE_DEVICES=\${OMPI_COMM_WORLD_LOCAL_RANK:-0}; exec \"${NATIVE_PW}\" -nk ${nk} -in input.in'"
    echo "env: ${native_env[*]}"
    # Exact, shell-quoted argv (copy-pasteable to reproduce the run).
    printf 'argv:'
    printf ' %q' "${cmd[@]}"
    printf '\n'
    sha256sum "${casedir}/input.in"
  } > "$cmdlog"

  local smi_pid
  smi_pid="$(start_smi_log "$smi")"

  local rc=0
  "${cmd[@]}" > "$out" 2> "$err" || rc=$?

  stop_smi_log "$smi_pid"
  write_summary "$label" "$rc" "native" "$out" "$cmdlog"
}

# Run one benchmark case with pw.x inside the NGC container image.
#   $1 label   - case name (directory / file basename)
#   $2 np      - number of MPI ranks (mpirun runs inside the container)
#   $3 nk      - value passed to pw.x -nk
#   $4 profile - MCA profile name (see mca_env_for_profile)
# The case directory is mounted at /work with --gpus all, host IPC and host
# network. Outputs mirror run_native_case.
run_ngc_case() {
  local label="$1" np="$2" nk="$3" profile="$4"
  local casedir="${WORKROOT}/${label}"
  mkdir -p "$casedir"
  cp -a "$INPUT" "${casedir}/input.in"
  cp -a "$PSEUDO_DIR" "${casedir}/pseudos"

  local out="${casedir}/${label}.out"
  local err="${casedir}/${label}.err"
  local smi="${LOGROOT}/${label}_nvidia_smi.csv"
  local cmdlog="${LOGROOT}/${label}.log"

  local mca_env
  mca_env="$(mca_env_for_profile "$profile")"

  # Forward the common threading env and the MCA settings into the container.
  local -a docker_env=()
  local kv
  for kv in "${COMMON_MPI_ENV[@]}"; do docker_env+=(-e "$kv"); done
  # mca_env is a space-separated KEY=VALUE list: split it on purpose.
  # shellcheck disable=SC2086
  for kv in $mca_env; do docker_env+=(-e "$kv"); done

  # Command run by the container's shell; \${...} stays literal here and is
  # expanded per rank inside the container.
  local inner="mpirun ${COMMON_PINNING[*]} -np ${np} bash -lc 'cd /work; export CUDA_VISIBLE_DEVICES=\${OMPI_COMM_WORLD_LOCAL_RANK:-0}; exec \"${NGC_PW}\" -nk \${SG_NK} -in input.in'"
  # One argv used both for logging and execution.
  local -a cmd=(
    docker run --rm --gpus all --ipc=host --network=host
    -w /work
    -v "${casedir}:/work"
    "${docker_env[@]}"
    -e "SG_NK=${nk}"
    "$NGC_IMAGE"
    bash -lc "$inner"
  )

  {
    echo "label=${label}"
    echo "profile=${profile}"
    echo "cmd: docker run --rm --gpus all --ipc=host --network=host -w /work -v ${casedir}:/work ${docker_env[*]} -e SG_NK=${nk} ${NGC_IMAGE} bash -lc \"${inner}\""
    echo "env: ${COMMON_MPI_ENV[*]} ${mca_env}"
    # Exact, shell-quoted argv (copy-pasteable to reproduce the run).
    printf 'argv:'
    printf ' %q' "${cmd[@]}"
    printf '\n'
    sha256sum "${casedir}/input.in"
  } > "$cmdlog"

  local smi_pid
  smi_pid="$(start_smi_log "$smi")"

  local rc=0
  "${cmd[@]}" > "$out" 2> "$err" || rc=$?

  stop_smi_log "$smi_pid"
  write_summary "$label" "$rc" "$profile" "$out" "$cmdlog"
}

# --- Main: header, native runs, optional NGC runs, packaging -----------------

echo "== sg-qe-bench-qe-vs-ngc =="
echo "WORKROOT: $WORKROOT" | tee -a "$SUMMARY"
echo "LOGROOT : $LOGROOT" | tee -a "$SUMMARY"
echo "INPUT   : $INPUT" | tee -a "$SUMMARY"
echo "PROFILE : $MCA_PROFILE" | tee -a "$SUMMARY"
GPU_COUNT="$(detect_gpu_count)"
echo "GPUS    : $GPU_COUNT" | tee -a "$SUMMARY"
echo "AUTO_SCALE: $AUTO_SCALE (max=$MAX_SCALE)" | tee -a "$SUMMARY"
echo "WITH_NGC: $WITH_NGC" | tee -a "$SUMMARY"
echo "" >> "$SUMMARY"

# Check every tool needed later before the (possibly long) runs start, so a
# missing zip cannot abort the script after all benchmarks have finished.
command -v sha256sum >/dev/null 2>&1 || die "sha256sum not found"
command -v zip >/dev/null 2>&1 || die "zip not found"
sha256sum "$INPUT" > "${LOGROOT}/input_hashes.sha256"

mapfile -t SCALE_LIST < <(build_scale_list "$GPU_COUNT")
[[ "${#SCALE_LIST[@]}" -gt 0 ]] || die "empty scale list"

# Native sweep: np ranks with -nk np.
for np in "${SCALE_LIST[@]}"; do
  run_native_case "${LABEL_PREFIX}_native_np${np}_nk${np}" "$np" "$np"
done

# Record the actual reason the NGC comparison is skipped, if it is.
NGC_SKIP_REASON=""
if [[ "$WITH_NGC" -eq 0 ]]; then
  NGC_SKIP_REASON="NGC disabled (pass --with-ngc or --ngc-image)"
elif [[ -z "$NGC_IMAGE" ]]; then
  NGC_SKIP_REASON="image not provided"
elif ! command -v docker >/dev/null 2>&1; then
  NGC_SKIP_REASON="docker not available"
fi

if [[ -n "$NGC_SKIP_REASON" ]]; then
  echo "NGC: SKIP (${NGC_SKIP_REASON})" | tee -a "$SUMMARY" "$SUMMARY_ALL"
  echo "NGC_SKIP_REASON: ${NGC_SKIP_REASON}" | tee -a "$SUMMARY" "$SUMMARY_ALL"
else
  case "$MCA_PROFILE" in
    ob1-tcp-eth0|ob1-tcp|ucx|smcuda)
      for np in "${SCALE_LIST[@]}"; do
        run_ngc_case "${LABEL_PREFIX}_ngc_np${np}_nk${np}_${MCA_PROFILE}" "$np" "$np" "$MCA_PROFILE"
      done
      ;;
    *)
      die "unknown --mca-profile: $MCA_PROFILE"
      ;;
  esac
fi

# Package this run's work and log directories; only record the archive path
# when zip actually succeeded, and fail the script otherwise.
ZIP="${BENCH_ROOT}/${RUN_TAG}_${TS}.zip"
ZIP_OK=1
if (cd "$BENCH_ROOT" && zip -r "$ZIP" "work/${RUN_TAG}_${TS}" "logs/${RUN_TAG}_${TS}" >/dev/null); then
  echo "ZIP: $ZIP" >> "$SUMMARY"
else
  ZIP_OK=0
  echo "ZIP: FAILED ($ZIP)" >> "$SUMMARY"
fi

echo "Summary: $SUMMARY"
[[ "$ZIP_OK" -eq 1 ]] || die "failed to create archive: $ZIP"
