#!/usr/bin/env bash
# Strict mode: exit on unhandled errors (errtrace lets functions inherit ERR
# traps), treat unset variables as errors, and fail pipelines on any stage.
set -o errexit -o errtrace -o nounset -o pipefail

# Print the command-line help to stdout.
usage() {
  cat <<'EOF'
Usage:
  llama-cpp-bisect-context MODEL --low N --high N [options]
  llama-cpp-bisect-context --cmd-template CMD --low N --high N [options]
  llama-cpp-bisect-context --cmd-file FILE --low N --high N [options]
  llama-cpp-bisect-context --list-presets [--preset-file FILE]

Find, by bisection, the largest llama.cpp llama-server context size that can
start and then complete a near-context-length prompt without running out of
memory. Use --startup-only to find only the first cliff (server startup).

MODEL is the name of a llama.cpp preset generated from the Nix llama-swap
config; use --list-presets to see the available names.

Commands are run with bash and can use these exported environment variables:
  PORT  random listen port for this trial
  CTX   candidate context size

Options:
  --cmd-template CMD        llama-server command, e.g. 'llama-server --port ${PORT} -c ${CTX} ...'
  --cmd-file FILE           bash script using $PORT and $CTX (run as: bash FILE)
  --preset-file FILE        preset JSON file (default: Nix-generated presets)
  --list-presets            list the presets in the preset file and exit
  --low N                   lower context bound; must pass (it is tested first)
  --high N                  upper context bound, expected to fail
  --step N                  stop when min failing - max passing <= N (default: 1024)
  --prompt-ratio PCT        prompt fill target as percent of CTX (default: 90)
  --chars-per-token N       rough prompt sizing ratio (default: 4)
  --prompt-turns N          split the prompt across N user/assistant turns (default: 4)
  --max-tokens N            generated tokens for prompt test (default: 32)
  --startup-timeout SEC     seconds to wait for the server to become ready (default: 300)
  --request-timeout SEC     seconds to wait for prompt response (default: 600)
  --cooldown SEC            seconds to sleep after stopping server (default: 5)
  --startup-only            only test server startup, not prompt/runtime OOM
  --verbose                 print llama-server logs for each failed trial
  --keep-logs               keep trial logs after a successful run too
  -h, --help                show this help

Examples:
  llama-cpp-bisect-context \
    --cmd-template 'llama-server --port ${PORT} -m model.gguf -c ${CTX} -ngl 99' \
    --low 32768 --high 196608

  llama-cpp-bisect-context qwen3.6-27b-ik-cuda0 --low 32768 --high 180000
  llama-cpp-bisect-context --cmd-file ./server-command.sh --low 32768 --high 196608
EOF
}

# Command source: exactly one of these ends up non-empty after validation.
preset_model=""
cmd_template=""
cmd_file=""

# Preset catalogue. The placeholder is presumably replaced by the Nix build
# (the help text calls the default "Nix-generated presets").
preset_file="__LLAMA_CPP_PRESETS__"
list_presets=0

# Search bounds (required on the command line) and bisection resolution.
low=""
high=""
step=1024

# Prompt shaping for the qualified-context test.
prompt_ratio=90
chars_per_token=4
prompt_turns=4
max_tokens=32

# Timing, in seconds.
startup_timeout=300
request_timeout=600
cooldown=5

# 0/1 switches.
startup_only=0
verbose=0
keep_logs=0

# Exit with a usage error when an option that takes a value was given none.
# Without this check, `shift 2` on a trailing option fails and `set -e` ends
# the script silently with status 1.
# Arguments: the remaining command line, starting at the option itself.
need_value() {
  if (( $# < 2 )); then
    echo "missing value for $1" >&2
    usage >&2
    exit 2
  fi
}

# Parse options; a single non-option argument is the preset MODEL name.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --cmd-template) need_value "$@"; cmd_template="$2"; shift 2 ;;
    --cmd-file) need_value "$@"; cmd_file="$2"; shift 2 ;;
    --preset-file) need_value "$@"; preset_file="$2"; shift 2 ;;
    --list-presets) list_presets=1; shift ;;
    --low) need_value "$@"; low="$2"; shift 2 ;;
    --high) need_value "$@"; high="$2"; shift 2 ;;
    --step) need_value "$@"; step="$2"; shift 2 ;;
    --prompt-ratio) need_value "$@"; prompt_ratio="$2"; shift 2 ;;
    --chars-per-token) need_value "$@"; chars_per_token="$2"; shift 2 ;;
    --prompt-turns) need_value "$@"; prompt_turns="$2"; shift 2 ;;
    --max-tokens) need_value "$@"; max_tokens="$2"; shift 2 ;;
    --startup-timeout) need_value "$@"; startup_timeout="$2"; shift 2 ;;
    --request-timeout) need_value "$@"; request_timeout="$2"; shift 2 ;;
    --cooldown) need_value "$@"; cooldown="$2"; shift 2 ;;
    --startup-only) startup_only=1; shift ;;
    --verbose) verbose=1; shift ;;
    --keep-logs) keep_logs=1; shift ;;
    -h|--help) usage; exit 0 ;;
    # Any other dash-prefixed word (long or short) is an unknown option rather
    # than a MODEL name.
    -?*) echo "unknown argument: $1" >&2; usage >&2; exit 2 ;;
    *)
      if [[ -n "$preset_model" ]]; then
        echo "unexpected positional argument: $1" >&2
        usage >&2
        exit 2
      fi
      preset_model="$1"
      shift
      ;;
  esac
done

# Print every preset in $preset_file, sorted by key, one per line: "key<TAB>name",
# or just "key" when the preset has no (or an empty) name. Then exit the script
# with status 0.
list_presets_json() {
  python3 - "$preset_file" <<'PY'
import json
import sys

with open(sys.argv[1]) as handle:
    catalog = json.load(handle)

for preset_id in sorted(catalog):
    label = catalog[preset_id].get("name", "")
    print(f"{preset_id}\t{label}" if label else preset_id)
PY
  exit 0
}

# --list-presets short-circuits everything else: list_presets_json exits.
if (( list_presets != 0 )); then
  list_presets_json
fi

# Render preset $preset_model from $preset_file as a bash script at
# $tmpdir/preset-command.sh and point cmd_file at it. The script contains
# "set -e", one `export KEY=VALUE` per well-formed "KEY=VALUE" env entry
# (value shell-quoted), then the preset's cmd with ${ctx}/$ctx rewritten to
# ${CTX}, newline-terminated.
# The Python helper exits 2 and lists the known presets on stderr when the
# preset name is not in the file.
load_preset() {
  local script_path="$tmpdir/preset-command.sh"
  python3 - "$preset_file" "$preset_model" "$script_path" <<'PY'
import json
import shlex
import sys

catalog_path, wanted, script_path = sys.argv[1:]
with open(catalog_path) as handle:
    catalog = json.load(handle)

try:
    entry = catalog[wanted]
except KeyError:
    print(f"unknown preset: {wanted}", file=sys.stderr)
    print("available presets:", file=sys.stderr)
    for name in sorted(catalog):
        print(f"  {name}", file=sys.stderr)
    sys.exit(2)

command = entry["cmd"]
for pattern in ("${ctx}", "$ctx"):
    command = command.replace(pattern, "${CTX}")
if not command.endswith("\n"):
    command += "\n"

exports = []
for assignment in entry.get("env", []):
    name, has_eq, value = assignment.partition("=")
    # Skip entries without "=" or with an empty variable name.
    if has_eq and name:
        exports.append(f"export {name}={shlex.quote(value)}\n")

with open(script_path, "w") as script:
    script.write("set -e\n")
    script.writelines(exports)
    script.write(command)
PY
  cmd_file="$script_path"
}

# Exit with status 2 unless VALUE is a plain non-negative decimal integer.
# Arguments: $1 - option name used in the error message; $2 - value to check.
# Leading zeros are rejected because bash arithmetic reads "08" as (invalid)
# octal and "010" as 8. Values are capped at 18 digits to stay inside bash's
# 64-bit arithmetic.
require_int() {
  local name="$1" value="$2"
  if [[ ! "$value" =~ ^(0|[1-9][0-9]{0,17})$ ]]; then
    echo "$name must be a non-negative integer" >&2
    exit 2
  fi
}

# Exactly one command source: a MODEL preset, --cmd-template or --cmd-file.
mode_count=0
for mode_value in "$preset_model" "$cmd_template" "$cmd_file"; do
  if [[ -n "$mode_value" ]]; then
    mode_count=$((mode_count + 1))
  fi
done
if (( mode_count != 1 )); then
  echo "use exactly one of MODEL, --cmd-template, or --cmd-file" >&2
  exit 2
fi
if [[ -z "$low" || -z "$high" ]]; then
  echo "missing --low or --high" >&2
  exit 2
fi

for pair in \
  "--low:$low" \
  "--high:$high" \
  "--step:$step" \
  "--prompt-ratio:$prompt_ratio" \
  "--chars-per-token:$chars_per_token" \
  "--prompt-turns:$prompt_turns" \
  "--max-tokens:$max_tokens" \
  "--startup-timeout:$startup_timeout" \
  "--request-timeout:$request_timeout" \
  "--cooldown:$cooldown"; do
  require_int "${pair%%:*}" "${pair#*:}"
done

if (( low <= 0 || high <= low || step <= 0 || prompt_ratio <= 0 || chars_per_token <= 0 || prompt_turns <= 0 )); then
  echo "invalid numeric bounds/options" >&2
  exit 2
fi
# A prompt sized above CTX would make every trial fail for reasons unrelated
# to memory.
if (( prompt_ratio > 100 )); then
  echo "--prompt-ratio must be at most 100" >&2
  exit 2
fi
# With a zero timeout the readiness poll never runs, so every trial fails.
if (( startup_timeout <= 0 )); then
  echo "--startup-timeout must be positive" >&2
  exit 2
fi

if [[ -n "$cmd_file" && ! -f "$cmd_file" ]]; then
  echo "cmd file not found: $cmd_file" >&2
  exit 2
fi

# setsid is needed because every trial is started in its own session so it
# can be signalled as a process group.
for dep in curl python3 setsid; do
  if ! command -v "$dep" >/dev/null 2>&1; then
    echo "missing required command: $dep" >&2
    exit 2
  fi
done

# Scratch directory for trial logs, prompt payloads and the rendered preset
# command; cleanup removes it after a successful run unless --keep-logs.
tmpdir=$(mktemp -d)
# State of the currently running trial (empty when none is running).
server_pid=''
log_file=''

# Succeeds while the current trial's leader process, or any other member of
# its process group, still exists.
server_group_alive() {
  kill -0 -- "-${server_pid}" >/dev/null 2>&1 || kill -0 "$server_pid" >/dev/null 2>&1
}

# Stop the current trial, if any, and forget it.
# Each trial is started under setsid, so server_pid normally also names the
# trial's process group. Sends SIGTERM to the group and to the leader, waits
# up to 30 s for the whole group to exit, escalates to SIGKILL, then reaps the
# leader.
# Globals: server_pid (read, then cleared).
terminate_server() {
  if [[ -z "${server_pid:-}" ]]; then
    return 0
  fi

  kill -- "-${server_pid}" >/dev/null 2>&1 || true
  kill "$server_pid" >/dev/null 2>&1 || true

  # Wait on the group, not only the leader. When the command runs through a
  # bash wrapper (--cmd-file, presets), the wrapper can exit on SIGTERM at once
  # while llama-server is still shutting down and holding memory the next
  # trial needs.
  local waited=0
  while server_group_alive && (( waited < 30 )); do
    sleep 1
    waited=$((waited + 1))
  done

  if server_group_alive; then
    kill -9 -- "-${server_pid}" >/dev/null 2>&1 || true
    kill -9 "$server_pid" >/dev/null 2>&1 || true
  fi

  wait "$server_pid" >/dev/null 2>&1 || true
  server_pid=""
}

# EXIT trap: stop any running trial, then delete the scratch directory on a
# clean exit, or report where the logs are kept (failure or --keep-logs).
cleanup() {
  # Must come first: capture the script's exit status before $? changes.
  local rc=$?
  # Reset the traps so a second signal cannot re-enter this handler.
  trap - EXIT INT TERM HUP
  terminate_server
  if (( keep_logs == 0 && rc == 0 )); then
    rm -rf "$tmpdir"
  else
    echo "logs kept in: $tmpdir" >&2
  fi
}

# INT/TERM/HUP handler: exit with status 130. Exiting fires the EXIT trap
# (cleanup), which stops the running server.
interrupt() {
  printf '%s\n' "interrupted; stopping llama-server" >&2
  exit 130
}

# Signals go through interrupt (exit 130); every exit path then runs cleanup.
trap interrupt HUP INT TERM
trap cleanup EXIT

# A MODEL argument means the server command is rendered from a preset.
if [[ -n "$preset_model" ]]; then
  load_preset
fi

# Print a TCP port number the kernel reports as free on 127.0.0.1. The probe
# socket is closed before returning, so another process could in principle
# take the port before llama-server binds it.
free_port() {
  python3 -c '
import socket

probe = socket.socket()
try:
    probe.bind(("127.0.0.1", 0))
    print(probe.getsockname()[1])
finally:
    probe.close()
'
}

# Launch one trial at context size $1 in the background, in a new session
# (setsid), with stdout/stderr appended to a per-ctx log under $tmpdir.
# Globals written: PORT and CTX (exported for the command), log_file,
# server_pid.
start_server() {
  local ctx="$1"
  PORT="$(free_port)"
  CTX="$ctx"
  export PORT CTX
  log_file="$tmpdir/llama-server-${ctx}.log"

  # Either run the command file or evaluate the template with bash.
  local -a launch
  local source_key source_value
  if [[ -n "$cmd_file" ]]; then
    launch=(bash "$cmd_file")
    source_key=CMD_FILE
    source_value="$cmd_file"
  else
    launch=(bash -c "$cmd_template")
    source_key=CMD_TEMPLATE
    source_value="$cmd_template"
  fi

  # Log header: trial parameters, then a separator before the server output.
  printf 'CTX=%s\nPORT=%s\n%s=%s\n--- llama-server output ---\n' \
    "$CTX" "$PORT" "$source_key" "$source_value" >"$log_file"

  setsid "${launch[@]}" >>"$log_file" 2>&1 &
  server_pid=$!
}

# Stop the current trial, then pause for --cooldown seconds before the next
# trial is started.
stop_server() {
  terminate_server
  local pause="$cooldown"
  sleep "$pause"
}

# Report a failed trial on stderr; with --verbose also dump the log's tail.
# Arguments: $1 - phase label (e.g. startup, ctx-prompt); $2 - context size.
# Globals read: log_file, verbose.
print_failure_log() {
  local label="$1" ctx="$2"
  # Number of trailing log lines shown with --verbose.
  local tail_lines=220
  echo "[$label] ctx=$ctx failed; log: $log_file" >&2
  if (( verbose )) && [[ -f "$log_file" ]]; then
    # Show the end of the log: the failure (allocation error, crash, request
    # error) is typically written after the model-loading output, so the head
    # of a long log hides it. The full log is at the path printed above.
    tail -n "$tail_lines" -- "$log_file" >&2 || true
  fi
}

# Poll the trial until it is ready to serve requests.
# Returns 0 once /health answers 200 (or, for servers without /health, i.e. a
# 404, once /v1/models answers). Returns 1 if the trial process exits or
# --startup-timeout elapses.
# Globals read: PORT, server_pid, startup_timeout.
wait_ready() {
  local deadline=$((SECONDS + startup_timeout))
  local health_code
  while (( SECONDS < deadline )); do
    # The trial process already exited (e.g. failed during load).
    if [[ -n "${server_pid:-}" ]] && ! kill -0 "$server_pid" >/dev/null 2>&1; then
      return 1
    fi
    # curl prints 000 when no HTTP response arrives (refused, timed out); the
    # decision is made on the status code alone.
    health_code="$(curl -sS -o /dev/null -w '%{http_code}' --max-time 5 \
      "http://127.0.0.1:${PORT}/health" 2>/dev/null)" || true
    case "$health_code" in
      200)
        return 0
        ;;
      404)
        # Only a server without /health falls back to /v1/models. While
        # llama-server loads, /health answers 503; /v1/models may already
        # answer at that point, so it must not decide readiness then.
        if curl -fsS --max-time 5 "http://127.0.0.1:${PORT}/v1/models" >/dev/null 2>&1; then
          return 0
        fi
        ;;
    esac
    sleep 2
  done
  return 1
}

# Print a chat-completions request body (JSON) for a context of $1 tokens.
# The user messages hold roughly prompt_ratio% of it, converted to characters
# via chars_per_token and spread over prompt_turns user messages separated by
# "Acknowledged." assistant turns. Output is deterministic for given inputs.
# Globals read: prompt_ratio, chars_per_token, max_tokens, prompt_turns.
make_prompt_json() {
  local ctx="$1"
  local fill_chars=$(( ctx * prompt_ratio / 100 * chars_per_token ))
  python3 - "$fill_chars" "$max_tokens" "$prompt_turns" <<'PY'
import itertools
import json
import sys

fill_chars, max_tokens, turn_count = (int(arg) for arg in sys.argv[1:4])
FILLER = (
    "This is deterministic context filler for memory testing. "
    "It uses normal words so token estimates are closer to real prompts. "
)


def chunk_sizes(budget, count):
    # Even split of the budget; later turns absorb the rounding remainder.
    for turns_left in range(count, 0, -1):
        size = max(1, budget // turns_left)
        budget -= size
        yield size


conversation = []
for index, size in enumerate(chunk_sizes(fill_chars, turn_count)):
    if index:
        conversation.append({"role": "assistant", "content": "Acknowledged."})
    text = "".join(itertools.islice(itertools.cycle(FILLER), size))
    conversation.append({"role": "user", "content": text})

print(json.dumps({
    "messages": conversation,
    "max_tokens": max_tokens,
    "temperature": 0,
    "stream": False,
}))
PY
}

# Send the near-context prompt for ctx $1 to the running trial.
# Succeeds only if the payload was built and the server answered 2xx within
# --request-timeout. The payload is kept at $tmpdir/prompt-<ctx>.json.
# Globals read: tmpdir, PORT, request_timeout.
run_prompt() {
  local ctx="$1"
  local payload="$tmpdir/prompt-${ctx}.json"
  # Callers run this inside `if`, which suspends errexit, so a failed payload
  # build must be caught here rather than POSTing a truncated body.
  if ! make_prompt_json "$ctx" >"$payload"; then
    echo "[ctx] could not build prompt payload for ctx=$ctx" >&2
    return 1
  fi
  curl -fsS \
    --max-time "$request_timeout" \
    -H 'Content-Type: application/json' \
    -d "@$payload" \
    "http://127.0.0.1:${PORT}/v1/chat/completions" \
    >/dev/null
}

# Startup-only trial: launch at ctx $1 and succeed iff the server becomes
# ready. The server is stopped (with cooldown) on both outcomes; on failure
# the log is reported before stopping.
test_startup() {
  local ctx="$1"
  echo "[startup] testing ctx=$ctx" >&2
  start_server "$ctx"
  if ! wait_ready; then
    print_failure_log startup "$ctx"
    stop_server
    return 1
  fi
  stop_server
  echo "[startup] ctx=$ctx PASS" >&2
  return 0
}

# Full trial at ctx $1: the server must become ready and then answer the
# near-context prompt. The server is always stopped; on failure the log is
# reported first, labelled by the phase that failed (ctx-startup/ctx-prompt).
test_qualified_context() {
  local ctx="$1" failed_phase=""
  echo "[ctx] testing ctx=$ctx with prompt_ratio=${prompt_ratio}% prompt_turns=${prompt_turns}" >&2
  start_server "$ctx"
  if ! wait_ready; then
    failed_phase=ctx-startup
  elif ! run_prompt "$ctx"; then
    failed_phase=ctx-prompt
  fi

  if [[ -n "$failed_phase" ]]; then
    print_failure_log "$failed_phase" "$ctx"
    stop_server
    return 1
  fi
  stop_server
  echo "[ctx] ctx=$ctx PASS" >&2
  return 0
}

# Binary search for the largest value accepted by trial function $4.
# Arguments: $1 - label; $2 - value assumed to pass; $3 - value assumed to
# fail; $4 - function called as "$4 CANDIDATE" (status 0 means pass).
# Neither bound is tested here. Stops once fail - pass <= step and prints
# "label:max_passing:min_failing" to stdout.
bisect_max() {
  local label="$1" fn="$4"
  local lo="$2" hi="$3" probe
  until (( hi - lo <= step )); do
    probe=$(( (lo + hi) / 2 ))
    if "$fn" "$probe"; then
      lo=$probe
    else
      hi=$probe
    fi
  done
  printf '%s:%s:%s\n' "$label" "$lo" "$hi"
}

# --startup-only: bisect on server startup alone, print the result, and exit.
if (( startup_only )); then
  if ! test_startup "$low"; then
    echo "low bound does not pass startup: $low" >&2
    exit 1
  fi
  # Bisect in this shell, not inside $(...). A command-substitution subshell
  # resets the INT/TERM/EXIT traps, and server_pid would only be set there, so
  # an interrupt mid-search would leave the setsid'd server running.
  result_file="$tmpdir/bisect-startup.result"
  bisect_max startup "$low" "$high" test_startup >"$result_file"
  IFS=: read -r _ pass fail <"$result_file"

  # bisect_max never tries --high; verify it when the search converged on it.
  if (( fail == high )) && test_startup "$high"; then
    echo "warning: --high $high passed startup; no failing ctx found" >&2
    pass="$high"
    fail=""
  fi

  printf '\nResult:\n'
  printf '  startup max passing ctx: %s\n' "$pass"
  printf '  startup min failing ctx: %s\n' "${fail:-none (--high passed)}"
  python3 - "$pass" "$fail" <<'PY'
import json
import sys
max_passing = int(sys.argv[1])
# Empty when --high itself passed: no failing context is known.
min_failing = int(sys.argv[2]) if sys.argv[2] else None
print(json.dumps({"startup": {"maxPassingCtx": max_passing, "minFailingCtx": min_failing}}, indent=2))
PY
  exit 0
fi

# Default mode: bisect on startup plus a near-context prompt, then report.
if ! test_qualified_context "$low"; then
  echo "low bound does not pass qualified context test: $low" >&2
  exit 1
fi

# Bisect in this shell, not inside $(...). A command-substitution subshell
# resets the INT/TERM/EXIT traps, and server_pid would only be set there, so an
# interrupt mid-search would leave the setsid'd server running.
result_file="$tmpdir/bisect-context.result"
bisect_max context "$low" "$high" test_qualified_context >"$result_file"
IFS=: read -r _ pass fail <"$result_file"

# bisect_max never tries --high; verify it when the search converged on it.
if (( fail == high )) && test_qualified_context "$high"; then
  echo "warning: --high $high passed the qualified context test; no failing ctx found" >&2
  pass="$high"
  fail=""
fi

printf '\nResult:\n'
printf '  context max passing ctx: %s\n' "$pass"
printf '  context min failing ctx: %s\n' "${fail:-none (--high passed)}"
printf '  prompt ratio:            %s%%\n' "$prompt_ratio"
printf '  prompt turns:            %s\n' "$prompt_turns"

python3 - "$pass" "$fail" "$prompt_ratio" "$prompt_turns" <<'PY'
import json
import sys
max_passing = int(sys.argv[1])
# Empty when --high itself passed: no failing context is known.
min_failing = int(sys.argv[2]) if sys.argv[2] else None
prompt_ratio = int(sys.argv[3])
prompt_turns = int(sys.argv[4])
print(json.dumps({
    "context": {
        "maxPassingCtx": max_passing,
        "minFailingCtx": min_failing,
        "promptRatio": prompt_ratio,
        "promptTurns": prompt_turns,
    }
}, indent=2))
PY
