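"""Benchmark SGLang serving throughput with and without DFLASH speculative decoding.

For each attention backend and client concurrency level, this script launches an
SGLang server for the target model (optionally a plain baseline, then one with
``--speculative-algorithm DFLASH`` and the draft model), replays dataset prompts,
and records output tokens/s, latency, and speculative acceptance statistics. A
markdown report can be written with ``--output-md``.

Example invocation (flag names and defaults are defined in ``main()`` below):

    python benchmark_sglang.py \
        --target-model Qwen/Qwen3-8B \
        --draft-model z-lab/Qwen3-8B-DFlash-b16 \
        --concurrencies 1,2,4,8,16,32 \
        --output-md dflash_report.md
"""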
from __future__ import annotations
import argparse
import time
import statistics
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Optional
import requests
import torch
from transformers import AutoTokenizer
from model import load_and_process_dataset
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
find_available_port,
popen_launch_server,
)
def _is_blackwell() -> bool:
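    """Return True on Blackwell-class GPUs (SM100+) or when the env override is set."""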
if envs.IS_BLACKWELL.get():
return True
return get_device_sm() >= 100
def _flush_cache(base_url: str) -> None:
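    """Ask the server to clear its cache via the /flush_cache endpoint."""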
resp = requests.get(base_url + "/flush_cache", timeout=60)
resp.raise_for_status()
def _send_generate(
base_url: str,
prompt: str,
*,
max_new_tokens: int,
stop: list[str],
timeout_s: int,
) -> dict:
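    """POST a single prompt to the server's /generate endpoint.

    Uses greedy sampling (temperature=0, top_p=1, top_k=1) and returns the
    parsed JSON response.
    """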
sampling_params: dict = {
"temperature": 0.0,
"top_p": 1.0,
"top_k": 1,
"max_new_tokens": int(max_new_tokens),
}
if stop:
sampling_params["stop"] = stop
resp = requests.post(
base_url + "/generate",
json={
"text": prompt,
"sampling_params": sampling_params,
},
timeout=int(timeout_s),
)
resp.raise_for_status()
return resp.json()
def _send_generate_batch(
base_url: str,
prompts: list[str],
*,
max_new_tokens: int,
stop: list[str],
timeout_s: int,
) -> list[dict]:
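    """POST a list of prompts as one server-side batched /generate request.

    The server is expected to return one JSON object per prompt; any other
    response shape raises a RuntimeError.
    """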
if not prompts:
return []
sampling_params: dict = {
"temperature": 0.0,
"top_p": 1.0,
"top_k": 1,
"max_new_tokens": int(max_new_tokens),
}
if stop:
sampling_params["stop"] = stop
resp = requests.post(
base_url + "/generate",
json={
"text": prompts,
"sampling_params": sampling_params,
},
timeout=int(timeout_s),
)
resp.raise_for_status()
out = resp.json()
if not isinstance(out, list):
raise RuntimeError(
"Expected a list response for batched /generate, but got "
f"type={type(out).__name__}."
)
return out
@dataclass(frozen=True)
class BenchMetrics:
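    """Aggregated metrics for one (backend, concurrency) benchmark run.

    ``spec_accept_length`` is the mean of the per-request ``spec_accept_length``
    values reported by the server, or None when none were reported.
    """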
latency_s: float
output_tokens: int
output_toks_per_s: float
spec_accept_length: Optional[float]
spec_verify_ct_sum: int
def _run_bench_requests(
base_url: str,
*,
prompts: list[str],
max_new_tokens: int,
concurrency: int,
batch_requests: bool,
stop: list[str],
timeout_s: int,
expect_dflash: bool,
) -> BenchMetrics:
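    """Send ``prompts`` to the server and aggregate throughput and speculative stats.

    The first ``concurrency`` prompts are used as an untimed warmup batch and are
    dropped from the measured set. Requests are issued either as server-side
    batched /generate calls (``batch_requests=True``) or via a client-side thread
    pool of size ``concurrency``.
    """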
# Drop the first batch from metrics to exclude one-time JIT/cuda-graph overhead
bs = max(int(concurrency), 1)
if len(prompts) > bs:
warmup_prompts = prompts[:bs]
if batch_requests:
_send_generate_batch(
base_url,
warmup_prompts,
max_new_tokens=max_new_tokens,
stop=stop,
timeout_s=timeout_s,
)
else:
with ThreadPoolExecutor(max_workers=int(concurrency)) as pool:
futures = [
pool.submit(
_send_generate,
base_url,
prompt,
max_new_tokens=max_new_tokens,
stop=stop,
timeout_s=timeout_s,
)
for prompt in warmup_prompts
]
for fut in as_completed(futures):
fut.result()
prompts = prompts[bs:]
start = time.perf_counter()
total_tokens = 0
spec_verify_ct_sum = 0
spec_accept_lengths: list[float] = []
    if batch_requests:
for start_idx in range(0, len(prompts), bs):
chunk_prompts = prompts[start_idx : start_idx + bs]
outs = _send_generate_batch(
base_url,
chunk_prompts,
max_new_tokens=max_new_tokens,
stop=stop,
timeout_s=timeout_s,
)
if len(outs) != len(chunk_prompts):
raise RuntimeError(
"Batched /generate output length mismatch: "
f"got {len(outs)} outputs for {len(chunk_prompts)} prompts."
)
            for out in outs:
meta = out.get("meta_info", {}) or {}
total_tokens += int(meta.get("completion_tokens", 0))
spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0))
if "spec_accept_length" in meta:
try:
spec_accept_lengths.append(float(meta["spec_accept_length"]))
except (TypeError, ValueError):
pass
else:
with ThreadPoolExecutor(max_workers=int(concurrency)) as pool:
futures = {
pool.submit(
_send_generate,
base_url,
prompt,
max_new_tokens=max_new_tokens,
stop=stop,
timeout_s=timeout_s,
): i
for i, prompt in enumerate(prompts)
}
for fut in as_completed(futures):
out = fut.result()
meta = out.get("meta_info", {}) or {}
total_tokens += int(meta.get("completion_tokens", 0))
spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0))
if "spec_accept_length" in meta:
try:
spec_accept_lengths.append(float(meta["spec_accept_length"]))
except (TypeError, ValueError):
pass
latency = time.perf_counter() - start
toks_per_s = total_tokens / max(latency, 1e-6)
if expect_dflash and spec_verify_ct_sum <= 0:
raise RuntimeError(
"DFLASH sanity check failed: did not observe any `spec_verify_ct` in responses "
"(DFLASH may not have been enabled)."
)
spec_accept_length = (
float(statistics.mean(spec_accept_lengths)) if spec_accept_lengths else None
)
return BenchMetrics(
latency_s=float(latency),
output_tokens=int(total_tokens),
output_toks_per_s=float(toks_per_s),
spec_accept_length=spec_accept_length,
spec_verify_ct_sum=int(spec_verify_ct_sum),
)
def _format_table(
*,
concurrencies: list[int],
values: dict[int, Optional[float]],
float_fmt: str,
) -> str:
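    """Render a one-row markdown table with one column per concurrency level.

    Missing values (None) are shown as "N/A"; floats use ``float_fmt``.
    """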
header = ["conc"] + [str(c) for c in concurrencies]
lines = [
"| " + " | ".join(header) + " |",
"| " + " | ".join(["---"] * len(header)) + " |",
]
row = ["value"]
for c in concurrencies:
v = values.get(c, None)
row.append("N/A" if v is None else format(v, float_fmt))
lines.append("| " + " | ".join(row) + " |")
return "\n".join(lines)
def main() -> None:
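    """Run the attention-backend x concurrency sweep for baseline and DFLASH servers."""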
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-md",
type=str,
default=None,
help="Write a markdown report to this file (disabled by default).",
)
parser.add_argument("--dataset-name", type=str, default="gsm8k")
parser.add_argument("--target-model", type=str, default="Qwen/Qwen3-8B")
parser.add_argument("--draft-model", type=str, default="z-lab/Qwen3-8B-DFlash-b16")
parser.add_argument(
"--skip-baseline",
action="store_true",
help="Skip running the baseline (target-only) sweep; only run DFLASH and report N/A for baseline/speedup.",
)
parser.add_argument(
"--batch-requests",
action="store_true",
help="Send prompts as server-side batched /generate requests (batch size = concurrency) instead of client-side concurrent requests.",
)
parser.add_argument("--max-new-tokens", type=int, default=2048)
parser.add_argument("--timeout-s", type=int, default=3600)
parser.add_argument("--mem-fraction-static", type=float, default=0.75)
parser.add_argument("--disable-radix-cache", action="store_true")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--max-running-requests", type=int, default=64)
parser.add_argument(
"--tp-size",
type=int,
default=1,
help="Tensor parallel size (single value, no sweep).",
)
parser.add_argument(
"--concurrencies",
type=str,
default="1,2,4,8,16,32",
help="Comma-separated list of client concurrency levels.",
)
parser.add_argument(
"--questions-per-concurrency-base",
type=int,
default=128,
help="num_questions = base * concurrency (default matches the sweep plan).",
)
parser.add_argument(
"--max-questions-per-config",
type=int,
default=1024,
help="Cap num_questions per (tp, concurrency) run (default: 1024).",
)
parser.add_argument(
"--attention-backends",
type=str,
default="flashinfer,fa3,fa4",
help="Comma-separated list. Will auto-skip fa3 unless SM90 (Hopper), and fa4 unless SM100+ (Blackwell).",
)
args = parser.parse_args()
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for this sweep.")
concurrencies = [int(x) for x in args.concurrencies.split(",") if x.strip()]
concurrencies = [c for c in concurrencies if c >= 1]
if not concurrencies:
raise RuntimeError("No concurrencies specified.")
num_questions_by_conc = {
c: min(int(args.questions_per_concurrency_base) * int(c), int(args.max_questions_per_config))
for c in concurrencies
}
max_questions = max(num_questions_by_conc.values())
max_concurrency = max(concurrencies)
attention_backends = [s.strip() for s in args.attention_backends.split(",") if s.strip()]
is_blackwell = _is_blackwell()
device_sm = get_device_sm()
if device_sm != 90:
attention_backends = [b for b in attention_backends if b != "fa3"]
if device_sm < 100:
attention_backends = [b for b in attention_backends if b != "fa4"]
attention_backends = attention_backends or ["flashinfer"]
    # --- Load dataset and build prompts ---
print(f"Loading dataset: {args.dataset_name}...")
dataset = load_and_process_dataset(args.dataset_name)
required_questions = max_questions + max_concurrency
if len(dataset) < required_questions:
print(f"Warning: Dataset has {len(dataset)} items, but need up to {required_questions}. Reusing items.")
tokenizer = AutoTokenizer.from_pretrained(args.target_model)
prompts: list[str] = []
# Build prompts list
for i in range(max(len(dataset), required_questions)):
item = dataset[i % len(dataset)]
user_content = item["turns"][0] # Extract the formatted turn
# Apply chat template
prompt_text = tokenizer.apply_chat_template(
[{"role": "user", "content": user_content}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
prompts.append(prompt_text)
if len(prompts) >= required_questions:
break
    # Results indexed by (backend, concurrency) for baseline + DFLASH.
    # TP size is fixed for the whole sweep, so it is not part of the key.
baseline_toks: dict[tuple[str, int], Optional[float]] = {}
dflash_toks: dict[tuple[str, int], Optional[float]] = {}
dflash_accept_len: dict[tuple[str, int], Optional[float]] = {}
tp = args.tp_size # Fixed TP size
for backend in attention_backends:
port_base = find_available_port(20000)
common_server_args: list[str] = [
"--trust-remote-code",
"--attention-backend",
backend,
"--tp-size",
str(tp),
"--dtype",
str(args.dtype),
"--mem-fraction-static",
str(args.mem_fraction_static),
"--max-running-requests",
str(args.max_running_requests),
]
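        # Capture CUDA graphs for every decode batch size 1..32; the default
        # concurrency sweep tops out at 32, so each tested batch size can map
        # directly onto a captured graph.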
common_server_args.extend(
["--cuda-graph-bs", *[str(i) for i in range(1, 33)], "--cuda-graph-max-bs", "32"]
)
if args.disable_radix_cache:
common_server_args.append("--disable-radix-cache")
if not args.skip_baseline:
print(f"\n=== backend={backend} tp={tp} (baseline) ===")
baseline_port = port_base
baseline_url = f"http://127.0.0.1:{baseline_port}"
baseline_proc = popen_launch_server(
args.target_model,
baseline_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=common_server_args,
)
try:
# Warm up.
_send_generate(
baseline_url,
"Hello",
max_new_tokens=8,
stop=[],
timeout_s=min(int(args.timeout_s), 300),
)
for conc in concurrencies:
n = num_questions_by_conc[conc]
_flush_cache(baseline_url)
                    print(
                        f"[warmup] running one warmup batch (size={conc}) after /flush_cache; excluded from metrics."
                    )
metrics = _run_bench_requests(
baseline_url,
prompts=prompts[: n + conc],
max_new_tokens=int(args.max_new_tokens),
concurrency=int(conc),
batch_requests=bool(args.batch_requests),
stop=[],
timeout_s=int(args.timeout_s),
expect_dflash=False,
)
baseline_toks[(backend, conc)] = metrics.output_toks_per_s
print(
f"[baseline] conc={conc:>2} n={n:<4} "
f"toks/s={metrics.output_toks_per_s:,.2f} "
f"latency={metrics.latency_s:.1f}s "
)
finally:
kill_process_tree(baseline_proc.pid)
try:
baseline_proc.wait(timeout=30)
except Exception:
pass
print(f"\n=== backend={backend} tp={tp} (DFLASH) ===")
dflash_port = find_available_port(port_base + 1)
dflash_url = f"http://127.0.0.1:{dflash_port}"
dflash_proc = popen_launch_server(
args.target_model,
dflash_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
*common_server_args,
"--speculative-algorithm",
"DFLASH",
"--speculative-draft-model-path",
args.draft_model,
],
)
try:
_send_generate(
dflash_url,
"Hello",
max_new_tokens=8,
stop=[],
timeout_s=min(int(args.timeout_s), 300),
)
for conc in concurrencies:
n = num_questions_by_conc[conc]
_flush_cache(dflash_url)
                print(
                    f"[warmup] running one warmup batch (size={conc}) after /flush_cache; excluded from metrics."
                )
metrics = _run_bench_requests(
dflash_url,
prompts=prompts[: n + conc],
max_new_tokens=int(args.max_new_tokens),
concurrency=int(conc),
batch_requests=bool(args.batch_requests),
stop=[],
timeout_s=int(args.timeout_s),
expect_dflash=True,
)
dflash_toks[(backend, conc)] = metrics.output_toks_per_s
dflash_accept_len[(backend, conc)] = metrics.spec_accept_length
                accept_str = (
                    "N/A"
                    if metrics.spec_accept_length is None
                    else f"{metrics.spec_accept_length:.3f}"
                )
                print(
                    f"[DFLASH] conc={conc:>2} n={n:<4} "
                    f"toks/s={metrics.output_toks_per_s:,.2f} "
                    f"latency={metrics.latency_s:.1f}s "
                    f"accept_len={accept_str} "
                    f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}"
                )
finally:
kill_process_tree(dflash_proc.pid)
try:
dflash_proc.wait(timeout=30)
except Exception:
pass
# Render markdown.
md_lines: list[str] = []
md_lines.append("# DFLASH Bench Report")
md_lines.append("")
md_lines.append("## Settings")
md_lines.append(f"- dataset: `{args.dataset_name}`")
md_lines.append(f"- target_model: `{args.target_model}`")
md_lines.append(f"- draft_model: `{args.draft_model}`")
md_lines.append(f"- max_new_tokens: `{args.max_new_tokens}`")
md_lines.append(f"- attention_backends: `{', '.join(attention_backends)}`")
md_lines.append(f"- tp_size: `{tp}`")
md_lines.append(f"- concurrencies: `{', '.join(str(x) for x in concurrencies)}`")
md_lines.append(f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`")
md_lines.append(f"- device_sm: `{device_sm}`")
md_lines.append(f"- is_blackwell: `{is_blackwell}`")
md_lines.append(f"- skip_baseline: `{bool(args.skip_baseline)}`")
md_lines.append("- drop_first_batch: `true`")
md_lines.append("")
for backend in attention_backends:
md_lines.append(f"## Backend: `{backend}`")
md_lines.append("")
baseline_values = {
c: baseline_toks.get((backend, c), None) for c in concurrencies
}
dflash_values = {
c: dflash_toks.get((backend, c), None) for c in concurrencies
}
speedup_values: dict[int, Optional[float]] = {}
for c in concurrencies:
b = baseline_values.get(c, None)
d = dflash_values.get(c, None)
speedup_values[c] = None if (b is None or d is None or b <= 0) else (d / b)
md_lines.append("### Baseline output tok/s")
md_lines.append(
_format_table(
concurrencies=concurrencies,
values=baseline_values,
float_fmt=",.2f",
)
)
md_lines.append("")
md_lines.append("### DFLASH output tok/s")
md_lines.append(
_format_table(
concurrencies=concurrencies,
values=dflash_values,
float_fmt=",.2f",
)
)
md_lines.append("")
md_lines.append("### Speedup (DFLASH / baseline)")
md_lines.append(
_format_table(
concurrencies=concurrencies,
values=speedup_values,
float_fmt=".3f",
)
)
md_lines.append("")
md_lines.append("### DFLASH acceptance length")
md_lines.append(
_format_table(
concurrencies=concurrencies,
values={
c: dflash_accept_len.get((backend, c), None)
for c in concurrencies
},
float_fmt=".3f",
)
)
md_lines.append("")
if args.output_md:
with open(args.output_md, "w", encoding="utf-8") as f:
f.write("\n".join(md_lines))
f.write("\n")
print(f"\nWrote markdown report to: {args.output_md}")
else:
print("\nMarkdown report disabled (pass --output-md to write one).")
if __name__ == "__main__":
main()