#!/usr/bin/env bash
# llama-cpp-bisect-context: bisect the largest llama-server context size that
# passes a startup test (or, by default, a startup + near-context prompt test).
#
# Shell options:
#   -E           ERR traps are inherited by functions and subshells
#   -e           exit on an unhandled non-zero status
#   -u           expanding an unset variable is an error
#   -o pipefail  a pipeline fails when any of its stages fails
set -Eeuo pipefail

# Print the command-line help text to stdout.
# The heredoc delimiter is quoted, so ${PORT} and ${CTX} are printed literally.
usage() {
  cat <<'EOF'
Usage:
  llama-cpp-bisect-context MODEL --low N --high N [options]
  llama-cpp-bisect-context --cmd-template CMD --low N --high N [options]
  llama-cpp-bisect-context --cmd-file FILE --low N --high N [options]

Bisect the largest llama.cpp llama-server context that can start and complete
a near-context prompt without OOMing. Startup-only mode is available for isolating the first cliff.

MODEL is a llama.cpp preset generated from the Nix llama-swap config.

Command templates are evaluated with these environment variables:
  PORT  random listen port for this trial
  CTX   candidate context size

Options:
  --cmd-template CMD     llama-server command, e.g. 'llama-server --port ${PORT} -c ${CTX} ...'
  --cmd-file FILE        executable or shell snippet using $PORT and $CTX
  --preset-file FILE     preset JSON file (default: Nix-generated presets)
  --list-presets         list available Nix-generated presets and exit
  --low N                known/assumed lower context bound
  --high N               upper context bound to test
  --step N               stop when high-low <= N (default: 1024)
  --prompt-ratio PCT     prompt fill target as percent of CTX (default: 90)
  --chars-per-token N    rough prompt sizing ratio (default: 4)
  --prompt-turns N       split the prompt across N user/assistant turns (default: 4)
  --max-tokens N         generated tokens for prompt test (default: 32)
  --startup-timeout SEC  seconds to wait for /health readiness (default: 300)
  --request-timeout SEC  seconds to wait for prompt response (default: 600)
  --cooldown SEC         seconds to sleep after stopping server (default: 5)
  --startup-only         only test server startup, not prompt/runtime OOM
  --verbose              print llama-server logs for each failed trial
  --keep-logs            keep trial logs after a successful run too
  -h, --help             show this help

Examples:
  llama-cpp-bisect-context \
    --cmd-template 'llama-server --port ${PORT} -m model.gguf -c ${CTX} -ngl 99' \
    --low 32768 --high 196608

  llama-cpp-bisect-context qwen3.6-27b-ik-cuda0 --low 32768 --high 180000
  llama-cpp-bisect-context --cmd-file ./server-command.sh --low 32768 --high 196608
EOF
}

# Option defaults; each one can be overridden on the command line.

# Server command source: exactly one of these three must end up non-empty.
preset_model=""
# NOTE(review): this looks like a build-time placeholder that is replaced with
# the real preset path (usage mentions Nix-generated presets) — confirm.
preset_file="__LLAMA_CPP_PRESETS__"
list_presets=0
cmd_template=""
cmd_file=""

# Bisection bounds and stop threshold.
low=""
high=""
step=1024

# Prompt sizing for the qualified-context test.
prompt_ratio=90
chars_per_token=4
prompt_turns=4
max_tokens=32

# Timing, in seconds.
startup_timeout=300
request_timeout=600
cooldown=5

# Boolean flags (0 = off, 1 = on).
startup_only=0
verbose=0
keep_logs=0

# Abort with a diagnostic when a value-taking option was given without a value.
#   $1 - the option name
#   $2 - number of arguments remaining, including the option itself
# Without this check `shift 2` fails when the option is the last argument and
# `set -e` terminates the script without any message.
require_value() {
  local opt="$1" remaining="$2"
  if (( remaining < 2 )); then
    echo "missing value for $opt" >&2
    usage >&2
    exit 2
  fi
}

# Parse command-line arguments into the global option variables.
# At most one positional argument (the preset MODEL) is accepted.
# Exits 0 after printing help, exits 2 on unknown or malformed arguments.
parse_args() {
  while [[ $# -gt 0 ]]; do
    # First make sure value-taking options really have a value.
    case "$1" in
      --cmd-template|--cmd-file|--preset-file|--low|--high|--step|\
      --prompt-ratio|--chars-per-token|--prompt-turns|--max-tokens|\
      --startup-timeout|--request-timeout|--cooldown)
        require_value "$1" "$#"
        ;;
    esac

    case "$1" in
      --cmd-template) cmd_template="$2"; shift 2 ;;
      --cmd-file) cmd_file="$2"; shift 2 ;;
      --preset-file) preset_file="$2"; shift 2 ;;
      --list-presets) list_presets=1; shift ;;
      --low) low="$2"; shift 2 ;;
      --high) high="$2"; shift 2 ;;
      --step) step="$2"; shift 2 ;;
      --prompt-ratio) prompt_ratio="$2"; shift 2 ;;
      --chars-per-token) chars_per_token="$2"; shift 2 ;;
      --prompt-turns) prompt_turns="$2"; shift 2 ;;
      --max-tokens) max_tokens="$2"; shift 2 ;;
      --startup-timeout) startup_timeout="$2"; shift 2 ;;
      --request-timeout) request_timeout="$2"; shift 2 ;;
      --cooldown) cooldown="$2"; shift 2 ;;
      --startup-only) startup_only=1; shift ;;
      --verbose) verbose=1; shift ;;
      --keep-logs) keep_logs=1; shift ;;
      -h|--help) usage; exit 0 ;;
      --*) echo "unknown argument: $1" >&2; usage >&2; exit 2 ;;
      *)
        if [[ -n "$preset_model" ]]; then
          echo "unexpected positional argument: $1" >&2
          usage >&2
          exit 2
        fi
        preset_model="$1"
        shift
        ;;
    esac
  done
}

parse_args "$@"

# Print every preset key from $preset_file, sorted, one per line.
# When a preset has a non-empty "name" field it is appended after a tab.
# NOTE: on success this function terminates the whole script with status 0.
list_presets_json() {
  python3 - "$preset_file" <<'PY'
import json
import sys

with open(sys.argv[1]) as f:
    presets = json.load(f)

for key in sorted(presets):
    name = presets[key].get("name", "")
    print(f"{key}\t{name}" if name else key)
PY
  exit 0
}

# --list-presets is handled before any other validation;
# list_presets_json ends the script itself.
if (( list_presets )); then
  list_presets_json
fi

# Turn the preset named $preset_model (looked up in $preset_file) into a shell
# snippet at $tmpdir/preset-command.sh and point $cmd_file at it.
#
# The generated snippet contains: `set -e`, one `export KEY=VALUE` line per
# well-formed "KEY=VALUE" entry of the preset's "env" list (values are
# shell-quoted), and the preset's "cmd" with ${ctx} / $ctx rewritten to ${CTX}.
#
# Globals:  preset_file, preset_model, tmpdir (read); cmd_file (written)
# Returns:  0 on success; the generator's exit status (2 for an unknown preset
#           or a preset without a usable "cmd") on failure, in which case
#           cmd_file is left untouched.
load_preset() {
  local command_file="$tmpdir/preset-command.sh"
  # The explicit `|| return` keeps failures visible even when this function is
  # called from a context where `set -e` is suppressed.
  python3 - "$preset_file" "$preset_model" "$command_file" <<'PY' || return
import json
import shlex
import sys

preset_file, model_id, command_file = sys.argv[1:]
with open(preset_file) as f:
    presets = json.load(f)

try:
    preset = presets[model_id]
except KeyError:
    print(f"unknown preset: {model_id}", file=sys.stderr)
    print("available presets:", file=sys.stderr)
    for key in sorted(presets):
        print(f"  {key}", file=sys.stderr)
    sys.exit(2)

raw_cmd = preset.get("cmd")
if not isinstance(raw_cmd, str) or not raw_cmd.strip():
    print(f"preset has no cmd: {model_id}", file=sys.stderr)
    sys.exit(2)

# Presets spell the context placeholder in lower case; trials export CTX.
cmd = raw_cmd.replace("${ctx}", "${CTX}").replace("$ctx", "${CTX}")
env = preset.get("env", [])

with open(command_file, "w") as f:
    f.write("set -e\n")
    for item in env:
        key, sep, value = item.partition("=")
        if not sep or not key:
            # Entries that are not KEY=VALUE are skipped.
            continue
        f.write(f"export {key}={shlex.quote(value)}\n")
    f.write(cmd)
    if not cmd.endswith("\n"):
        f.write("\n")
PY
  cmd_file="$command_file"
}

# Validate that an option value is a plain non-negative decimal integer.
#   $1 - option name, used in the error message
#   $2 - value to check
# Exits the script with status 2 when the value is rejected.
#
# Leading zeros are rejected because bash arithmetic reads them as octal:
# "010" would silently become 8 and "08" would abort with an arithmetic error.
# Values longer than 18 digits are rejected because they can overflow bash's
# 64-bit arithmetic.
require_int() {
  local name="$1" value="$2"
  if [[ ! "$value" =~ ^(0|[1-9][0-9]*)$ ]] || (( ${#value} > 18 )); then
    echo "$name must be a non-negative integer (no leading zeros)" >&2
    exit 2
  fi
}

# Exactly one server-command source must be chosen.
mode_count=0
if [[ -n "$preset_model" ]]; then
  mode_count=$((mode_count + 1))
fi
if [[ -n "$cmd_template" ]]; then
  mode_count=$((mode_count + 1))
fi
if [[ -n "$cmd_file" ]]; then
  mode_count=$((mode_count + 1))
fi
if (( mode_count != 1 )); then
  echo "use exactly one of MODEL, --cmd-template, or --cmd-file" >&2
  exit 2
fi

if [[ -z "$low" || -z "$high" ]]; then
  echo "missing --low or --high" >&2
  exit 2
fi

# Every numeric option is checked as "name:value"; the name is only used for
# the error message printed by require_int.
for pair in \
  "--low:$low" \
  "--high:$high" \
  "--step:$step" \
  "--prompt-ratio:$prompt_ratio" \
  "--chars-per-token:$chars_per_token" \
  "--prompt-turns:$prompt_turns" \
  "--max-tokens:$max_tokens" \
  "--startup-timeout:$startup_timeout" \
  "--request-timeout:$request_timeout" \
  "--cooldown:$cooldown"; do
  require_int "${pair%%:*}" "${pair#*:}"
done

# Range checks: bounds must be ordered and the sizing knobs strictly positive.
if (( low <= 0 || high <= low || step <= 0 || prompt_ratio <= 0 || chars_per_token <= 0 || prompt_turns <= 0 )); then
  echo "invalid numeric bounds/options" >&2
  exit 2
fi

# In preset mode cmd_file is still empty here, so this only checks --cmd-file.
if [[ -n "$cmd_file" && ! -f "$cmd_file" ]]; then
  echo "cmd file not found: $cmd_file" >&2
  exit 2
fi

# setsid is included because every trial launches the server through it.
for dep in curl python3 setsid; do
  if ! command -v "$dep" >/dev/null 2>&1; then
    echo "missing required command: $dep" >&2
    exit 2
  fi
done

# Per-run state shared by the functions below.
tmpdir="$(mktemp -d)"   # holds trial logs, prompt payloads and generated files
server_pid=""           # PID of the running trial server; empty when none
log_file=""             # log path of the most recently started trial

# Succeeds while the trial server's leader process OR any other member of its
# process group is still alive.
_server_group_alive() {
  kill -0 -- "-${server_pid}" >/dev/null 2>&1 ||
    kill -0 "$server_pid" >/dev/null 2>&1
}

# Stop the trial server recorded in $server_pid, if any, and clear the variable.
#
# The server is launched through setsid, so -$server_pid addresses its whole
# process group. Sequence: SIGTERM to group and leader, wait up to 30 seconds,
# SIGKILL whatever is left, then reap the leader.
#
# Liveness is judged on the whole group, not only the leader: when the leader
# is a wrapper shell it can exit while the actual server process is still
# shutting down (or hung), and that process must not survive into the next
# trial.
#
# Always returns 0; signalling errors are ignored on purpose because the
# processes may already be gone.
terminate_server() {
  if [[ -z "${server_pid:-}" ]]; then
    return 0
  fi

  kill -- "-${server_pid}" >/dev/null 2>&1 || true
  kill "$server_pid" >/dev/null 2>&1 || true

  local waited=0
  while _server_group_alive && (( waited < 30 )); do
    sleep 1
    waited=$((waited + 1))
  done

  if _server_group_alive; then
    kill -9 -- "-${server_pid}" >/dev/null 2>&1 || true
    kill -9 "$server_pid" >/dev/null 2>&1 || true
  fi

  wait "$server_pid" >/dev/null 2>&1 || true
  server_pid=""
}

# EXIT handler: stop any running trial server, then either remove the
# temporary directory or tell the user where it is.
# The directory is kept when --keep-logs was given or the script is exiting
# with a non-zero status.
cleanup() {
  # Must be first: capture the status the script is exiting with.
  local status=$?
  # Disarm all handlers so cleanup cannot be re-entered.
  trap - EXIT INT TERM HUP
  terminate_server
  if (( keep_logs || status != 0 )); then
    echo "logs kept in: $tmpdir" >&2
  else
    rm -rf -- "$tmpdir"
  fi
}

# Signal handler: announce the interruption and exit with status 130.
# Exiting is what triggers whatever EXIT handler is installed.
interrupt() {
  echo "interrupted; stopping llama-server" >&2
  exit 130
}

# cleanup runs on every exit path; INT/TERM/HUP are routed through interrupt,
# whose `exit 130` in turn fires the EXIT trap.
trap cleanup EXIT
trap interrupt INT TERM HUP

# Preset mode: generate the command snippet now (load_preset sets cmd_file).
if [[ -n "$preset_model" ]]; then
  load_preset
fi

# Print a TCP port number that the OS just handed out for 127.0.0.1.
# The probing socket is closed before this function returns, so the port is
# only known to have been free at that moment.
free_port() {
  python3 - <<'PY'
import socket

with socket.socket() as s:
    # Port 0 asks the OS to pick any unused port.
    s.bind(("127.0.0.1", 0))
    print(s.getsockname()[1])
PY
}

# Launch one trial server in the background for context size $1.
#
# Exports PORT (freshly picked) and CTX for the server command, writes a small
# header to $tmpdir/llama-server-<ctx>.log, then starts either $cmd_file or
# $cmd_template under setsid with stdout/stderr appended to that log.
#
# Globals:  tmpdir, cmd_file, cmd_template (read);
#           PORT, CTX, log_file, server_pid (written)
# Returns:  1 when no port could be obtained (nothing is launched), else 0.
start_server() {
  local ctx="$1"

  # Callers run this inside conditionals where `set -e` is suppressed, so a
  # failed port lookup has to be caught here explicitly.
  if ! PORT="$(free_port)" || [[ -z "$PORT" ]]; then
    echo "could not obtain a free port for ctx=$ctx" >&2
    return 1
  fi
  CTX="$ctx"
  export PORT CTX

  log_file="$tmpdir/llama-server-${ctx}.log"
  {
    printf 'CTX=%s\n' "$CTX"
    printf 'PORT=%s\n' "$PORT"
    if [[ -n "$cmd_file" ]]; then
      printf 'CMD_FILE=%s\n' "$cmd_file"
    else
      printf 'CMD_TEMPLATE=%s\n' "$cmd_template"
    fi
    printf -- '--- llama-server output ---\n'
  } >"$log_file"

  # setsid gives the server its own session/process group so that
  # terminate_server can signal the whole group via the negative PID.
  if [[ -n "$cmd_file" ]]; then
    setsid bash "$cmd_file" >>"$log_file" 2>&1 &
  else
    setsid bash -c "$cmd_template" >>"$log_file" 2>&1 &
  fi
  server_pid="$!"
}

# Stop the current trial server, then pause for $cooldown seconds before the
# caller continues.
stop_server() {
  terminate_server
  sleep "$cooldown"
}

# Report a failed trial on stderr.
#   $1 - label identifying the failing phase
#   $2 - context size that failed
# Always prints a one-line summary with the log path. With --verbose it also
# prints the first 220 lines of the trial log.
print_failure_log() {
  local label="$1" ctx="$2"
  echo "[$label] ctx=$ctx failed; log: $log_file" >&2
  if (( verbose )) && [[ -f "$log_file" ]]; then
    # Best effort: a sed failure must not abort the run.
    sed -n '1,220p' "$log_file" >&2 || true
  fi
}

# Wait until the trial server answers on 127.0.0.1:$PORT.
#
# Polls every 2 seconds for at most $startup_timeout seconds. Each round:
#   1. gives up immediately if the server process is gone,
#   2. succeeds if GET /health returns an HTTP success status,
#   3. otherwise succeeds if GET /v1/models returns an HTTP success status.
#
# NOTE(review): some llama-server builds may answer /v1/models before the
# model has finished loading; if so, step 3 can report readiness too early —
# confirm against the server versions in use.
#
# Returns: 0 when ready, 1 on timeout or when the server process has exited.
wait_ready() {
  local deadline=$((SECONDS + startup_timeout))
  while (( SECONDS < deadline )); do
    if [[ -n "${server_pid:-}" ]] && ! kill -0 "$server_pid" >/dev/null 2>&1; then
      return 1
    fi
    if curl -fsS --max-time 5 "http://127.0.0.1:${PORT}/health" >/dev/null 2>&1; then
      return 0
    fi
    if curl -fsS --max-time 5 "http://127.0.0.1:${PORT}/v1/models" >/dev/null 2>&1; then
      return 0
    fi
    sleep 2
  done
  return 1
}

# Print a chat-completions request body sized for context size $1.
#
# The total user text is ctx * prompt_ratio / 100 estimated tokens times
# chars_per_token characters (integer arithmetic), split as evenly as possible
# over $prompt_turns user messages. A short assistant message separates
# consecutive user messages. The request asks for $max_tokens tokens,
# temperature 0, non-streaming.
make_prompt_json() {
  local ctx="$1"
  local approx_tokens=$(( ctx * prompt_ratio / 100 ))
  local chars=$(( approx_tokens * chars_per_token ))
  python3 - "$chars" "$max_tokens" "$prompt_turns" <<'PY'
import json
import sys

chars = int(sys.argv[1])
max_tokens = int(sys.argv[2])
prompt_turns = int(sys.argv[3])
seed = (
    "This is deterministic context filler for memory testing. "
    "It uses normal words so token estimates are closer to real prompts. "
)
messages = []
remaining = chars
for turn in range(prompt_turns):
    turns_left = prompt_turns - turn
    # Every user turn carries at least one character.
    chunk_chars = max(1, remaining // turns_left)
    # Repeat the seed past the target length, then cut to the exact size.
    content = (seed * ((chunk_chars // len(seed)) + 1))[:chunk_chars]
    messages.append({"role": "user", "content": content})
    remaining -= chunk_chars
    if turn != prompt_turns - 1:
        messages.append({"role": "assistant", "content": "Acknowledged."})

print(json.dumps({
    "messages": messages,
    "max_tokens": max_tokens,
    "temperature": 0,
    "stream": False,
}))
PY
}

# Send the near-context prompt for context size $1 to the trial server.
#
# The request body is written to $tmpdir/prompt-<ctx>.json and POSTed to
# /v1/chat/completions on 127.0.0.1:$PORT; the response body is discarded.
#
# Returns: 0 when curl gets an HTTP success status within $request_timeout
#          seconds; non-zero when the payload could not be generated or the
#          request failed.
run_prompt() {
  local ctx="$1"
  local payload="$tmpdir/prompt-${ctx}.json"

  # This function is called from an `if`, where `set -e` is suppressed, so a
  # failed payload generation must be caught explicitly; otherwise curl would
  # post an empty or truncated file.
  make_prompt_json "$ctx" >"$payload" || return 1

  curl -fsS \
    --max-time "$request_timeout" \
    -H 'Content-Type: application/json' \
    -d "@$payload" \
    "http://127.0.0.1:${PORT}/v1/chat/completions" \
    >/dev/null
}

# Trial: does the server come up at context size $1?
# Starts the server, waits for readiness, and always stops it again.
# Progress is reported on stderr.
# Returns: 0 on PASS, 1 when the server could not be launched or never became
#          ready.
test_startup() {
  local ctx="$1"
  echo "[startup] testing ctx=$ctx" >&2

  if start_server "$ctx" && wait_ready; then
    stop_server
    echo "[startup] ctx=$ctx PASS" >&2
    return 0
  fi

  print_failure_log startup "$ctx"
  stop_server
  return 1
}

# Trial: does the server start at context size $1 AND complete the
# near-context prompt?
# The server is always stopped before returning. Progress goes to stderr.
# Returns: 0 on PASS; 1 when launch, readiness or the prompt request fails.
test_qualified_context() {
  local ctx="$1"
  echo "[ctx] testing ctx=$ctx with prompt_ratio=${prompt_ratio}% prompt_turns=${prompt_turns}" >&2

  # Phase 1: launch and readiness.
  if ! { start_server "$ctx" && wait_ready; }; then
    print_failure_log ctx-startup "$ctx"
    stop_server
    return 1
  fi

  # Phase 2: the prompt itself.
  if run_prompt "$ctx"; then
    stop_server
    echo "[ctx] ctx=$ctx PASS" >&2
    return 0
  fi

  print_failure_log ctx-prompt "$ctx"
  stop_server
  return 1
}

# Binary-search between a passing and a failing context size.
#   $1 - label echoed back in the result line
#   $2 - context size assumed to pass
#   $3 - context size assumed to fail
#   $4 - name of a test function; called with one candidate size, returns 0
#        for pass and non-zero for fail
#
# Repeatedly tests the midpoint until the gap is at most $step, then prints
# "label:pass:fail" on stdout.
# NOTE: the initial values of $2 and $3 are never tested here; if no midpoint
# fails, the reported "fail" value is still the untested $3.
bisect_max() {
  local label="$1" pass="$2" fail="$3" fn="$4"
  local mid
  while (( fail - pass > step )); do
    mid=$(( (pass + fail) / 2 ))
    if "$fn" "$mid"; then
      pass="$mid"
    else
      fail="$mid"
    fi
  done
  printf '%s:%s:%s\n' "$label" "$pass" "$fail"
}

# Main flow: verify the lower bound, bisect, then print a human-readable
# summary followed by the same result as JSON on stdout.
#
# bisect_max is deliberately NOT run inside $(...). A command substitution is
# a subshell: the server_pid set by each trial would be invisible to this
# shell, and the INT/TERM/HUP/EXIT handlers are reset there, so interrupting a
# bisection would leave the setsid-detached server running. Its result line is
# therefore redirected to a file and read back.
result_file="$tmpdir/bisect-result.txt"

if (( startup_only )); then
  if ! test_startup "$low"; then
    echo "low bound does not pass startup: $low" >&2
    exit 1
  fi

  bisect_max startup "$low" "$high" test_startup >"$result_file"
  # Result line format: label:pass:fail
  IFS=: read -r _ pass fail <"$result_file"

  printf '\nResult:\n'
  printf '  startup max passing ctx: %s\n' "$pass"
  printf '  startup min failing ctx: %s\n' "$fail"
  python3 - "$pass" "$fail" <<'PY'
import json
import sys

max_passing, min_failing = map(int, sys.argv[1:])
print(json.dumps({"startup": {"maxPassingCtx": max_passing, "minFailingCtx": min_failing}}, indent=2))
PY
  exit 0
fi

if ! test_qualified_context "$low"; then
  echo "low bound does not pass qualified context test: $low" >&2
  exit 1
fi

bisect_max context "$low" "$high" test_qualified_context >"$result_file"
# Result line format: label:pass:fail
IFS=: read -r _ pass fail <"$result_file"

printf '\nResult:\n'
printf '  context max passing ctx: %s\n' "$pass"
printf '  context min failing ctx: %s\n' "$fail"
printf '  prompt ratio: %s%%\n' "$prompt_ratio"
printf '  prompt turns: %s\n' "$prompt_turns"

python3 - "$pass" "$fail" "$prompt_ratio" "$prompt_turns" <<'PY'
import json
import sys

max_passing = int(sys.argv[1])
min_failing = int(sys.argv[2])
prompt_ratio = int(sys.argv[3])
prompt_turns = int(sys.argv[4])
print(json.dumps({
    "context": {
        "maxPassingCtx": max_passing,
        "minFailingCtx": min_failing,
        "promptRatio": prompt_ratio,
        "promptTurns": prompt_turns,
    }
}, indent=2))
PY