import json
import subprocess
import sys
import tempfile
import unittest

from sglang.srt.debug_utils.schedule_simulator import (
    AttentionComputeBalancednessRecorder,
    BatchSizeBalancednessRecorder,
    FIFOScheduler,
    GPUState,
    RandomRouter,
    RoundRobinRouter,
    SimRequest,
    SimulationResult,
    Simulator,
    StepRecord,
    StickyRouter,
    create_arg_parser,
    generate_gsp_requests,
    generate_random_requests,
    load_from_request_logger,
    main,
)
from sglang.test.ci.ci_register import register_cpu_ci
from sglang.test.test_utils import CustomTestCase

register_cpu_ci(est_time=120, suite="default", nightly=True)


class TestSimRequest(CustomTestCase):
    def test_basic(self):
        req = SimRequest(request_id="r1", input_len=100, output_len=50)
        self.assertEqual(req.decoded_tokens, 0)
        self.assertEqual(req.seq_len(), 100)
        self.assertFalse(req.is_finished())

    def test_seq_len_with_decoded(self):
        req = SimRequest(
            request_id="r1", input_len=100, output_len=50, decoded_tokens=10
        )
        self.assertEqual(req.seq_len(), 110)

    def test_is_finished(self):
        req = SimRequest(
            request_id="r1", input_len=100, output_len=50, decoded_tokens=50
        )
        self.assertTrue(req.is_finished())


class TestGPUState(CustomTestCase):
    def test_batch_size(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=10000)
        self.assertEqual(gpu.batch_size(), 0)
        gpu.running_requests = [
            SimRequest(request_id="r1", input_len=100, output_len=50),
            SimRequest(request_id="r2", input_len=200, output_len=100),
        ]
        self.assertEqual(gpu.batch_size(), 2)

    def test_total_seq_len(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=10000)
        gpu.running_requests = [
            SimRequest(request_id="r1", input_len=100, output_len=50),
            SimRequest(
                request_id="r2", input_len=200, output_len=100, decoded_tokens=10
            ),
        ]
        self.assertEqual(gpu.total_seq_len(), 100 + 210)

    def test_total_seq_len_shared_prefix(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=10000)
        gpu.running_requests = [
            SimRequest(
                request_id="r1",
                input_len=150,
                output_len=50,
                group_id="g0",
                prefix_len=100,
            ),
            SimRequest(
                request_id="r2",
                input_len=150,
                output_len=50,
                group_id="g0",
                prefix_len=100,
            ),
        ]
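        # The 100-token prefix shared within group g0 is counted once:
        # 100 + (150 - 100) + (150 - 100) = 200.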
        self.assertEqual(gpu.total_seq_len(), 150 + 50)

    def test_total_seq_len_shared_prefix_with_decoded(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=10000)
        gpu.running_requests = [
            SimRequest(
                request_id="r1",
                input_len=150,
                output_len=50,
                decoded_tokens=10,
                group_id="g0",
                prefix_len=100,
            ),
            SimRequest(
                request_id="r2",
                input_len=150,
                output_len=50,
                decoded_tokens=5,
                group_id="g0",
                prefix_len=100,
            ),
        ]
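        # Shared prefix counted once; decoded tokens accrue per request:
        # 100 + (50 + 10) + (50 + 5) = 215.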
        self.assertEqual(gpu.total_seq_len(), 160 + 55)

    def test_total_seq_len_multiple_groups(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=10000)
        gpu.running_requests = [
            SimRequest(
                request_id="r1",
                input_len=150,
                output_len=50,
                group_id="g0",
                prefix_len=100,
            ),
            SimRequest(
                request_id="r2",
                input_len=150,
                output_len=50,
                group_id="g0",
                prefix_len=100,
            ),
            SimRequest(
                request_id="r3",
                input_len=200,
                output_len=50,
                group_id="g1",
                prefix_len=150,
            ),
            SimRequest(request_id="r4", input_len=80, output_len=20),
        ]
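        # g0 dedupes its 100-token prefix (200 total); g1 has a single member,
        # so its prefix saves nothing (200); r4 is ungrouped (80).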
        self.assertEqual(gpu.total_seq_len(), 150 + 50 + 200 + 80)


class TestRouters(CustomTestCase):
    def test_round_robin(self):
        router = RoundRobinRouter(num_gpus=4)
        req = SimRequest(request_id="r1", input_len=100, output_len=50)
        results = [router.route(req) for _ in range(8)]
        self.assertEqual(results, [0, 1, 2, 3, 0, 1, 2, 3])

    def test_random_router(self):
        router = RandomRouter(num_gpus=4)
        req = SimRequest(request_id="r1", input_len=100, output_len=50)
        results = [router.route(req) for _ in range(100)]
        self.assertTrue(all(0 <= r < 4 for r in results))

    def test_sticky_router_same_group_same_gpu(self):
        router = StickyRouter(num_gpus=4)
        reqs = [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50, group_id="g0")
            for i in range(10)
        ]
        results = [router.route(req) for req in reqs]
        self.assertEqual(len(set(results)), 1)

    def test_sticky_router_no_group_fallback(self):
        router = StickyRouter(num_gpus=4)
        reqs = [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50)
            for i in range(100)
        ]
        results = [router.route(req) for req in reqs]
        self.assertTrue(all(0 <= r < 4 for r in results))

    def test_sticky_router_multiple_groups(self):
        router = StickyRouter(num_gpus=4)
        for group_id in ["g0", "g1", "g2"]:
            reqs = [
                SimRequest(
                    request_id=f"{group_id}_r{i}",
                    input_len=100,
                    output_len=50,
                    group_id=group_id,
                )
                for i in range(5)
            ]
            results = [router.route(req) for req in reqs]
            self.assertEqual(len(set(results)), 1)


class TestFIFOScheduler(CustomTestCase):
    def test_runs_pending_requests(self):
        scheduler = FIFOScheduler()
        gpu = GPUState(gpu_id=0, max_total_tokens=10000)
        gpu.pending_requests = [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50)
            for i in range(3)
        ]
        scheduler.schedule(gpu)
        self.assertEqual(len(gpu.running_requests), 3)
        self.assertEqual(len(gpu.pending_requests), 0)

    def test_respects_token_limit(self):
        scheduler = FIFOScheduler()
        gpu = GPUState(gpu_id=0, max_total_tokens=250)
        gpu.pending_requests = [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50)
            for i in range(5)
        ]
        scheduler.schedule(gpu)
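        # Each request needs 100 tokens at admission, so only two fit in the
        # 250-token budget.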
        self.assertEqual(len(gpu.running_requests), 2)
        self.assertEqual(len(gpu.pending_requests), 3)

    def test_evicts_lifo_when_over_budget(self):
        scheduler = FIFOScheduler()
        gpu = GPUState(gpu_id=0, max_total_tokens=250)
        gpu.running_requests = [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50)
            for i in range(3)
        ]
        scheduler.schedule(gpu)
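        # 3 x 100 = 300 tokens exceed the 250-token budget; the most recently
        # admitted request (r2) is evicted first.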
        self.assertEqual(len(gpu.running_requests), 2)
        self.assertEqual(len(gpu.pending_requests), 1)
        self.assertEqual(gpu.pending_requests[0].request_id, "r2")


class TestMetrics(CustomTestCase):
    def test_batch_size_balancedness(self):
        recorder = BatchSizeBalancednessRecorder()
        gpu_states = [GPUState(gpu_id=i, max_total_tokens=10000) for i in range(2)]
        gpu_states[0].running_requests = [
            SimRequest(request_id="r1", input_len=100, output_len=50)
        ]
        gpu_states[1].running_requests = [
            SimRequest(request_id="r2", input_len=100, output_len=50),
            SimRequest(request_id="r3", input_len=100, output_len=50),
        ]
        recorder.on_step_end(0, gpu_states)
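        # Batch sizes are 1 and 2; the expected 0.75 is consistent with
        # balancedness = mean / max = 1.5 / 2.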
        self.assertAlmostEqual(
            recorder.get_summary()["batch_size_balancedness_mean"], 0.75
        )

    def test_attention_compute_balancedness(self):
        recorder = AttentionComputeBalancednessRecorder()
        gpu_states = [GPUState(gpu_id=i, max_total_tokens=10000) for i in range(2)]
        gpu_states[0].running_requests = [
            SimRequest(request_id="r1", input_len=100, output_len=50)
        ]
        gpu_states[1].running_requests = [
            SimRequest(request_id="r2", input_len=200, output_len=50)
        ]
        recorder.on_step_end(0, gpu_states)
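        # Sequence lengths are 100 and 200; 0.75 again matches
        # mean / max = 150 / 200.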
        self.assertAlmostEqual(
            recorder.get_summary()["attention_compute_balancedness_mean"], 0.75
        )

    def test_empty_history(self):
        recorder = BatchSizeBalancednessRecorder()
        self.assertEqual(recorder.get_summary()["batch_size_balancedness_mean"], 0.0)

    def test_all_zero_batch_size(self):
        recorder = BatchSizeBalancednessRecorder()
        gpu_states = [GPUState(gpu_id=i, max_total_tokens=10000) for i in range(2)]
        recorder.on_step_end(0, gpu_states)
        self.assertAlmostEqual(
            recorder.get_summary()["batch_size_balancedness_mean"], 1.0
        )


class TestDataLoader(CustomTestCase):
    def test_load_from_request_logger(self):
        log_data = [
            {"event": "request.received", "rid": "r1", "obj": {"text": "hello"}},
            {
                "event": "request.finished",
                "rid": "r1",
                "out": {"meta_info": {"prompt_tokens": 100, "completion_tokens": 50}},
            },
            {
                "event": "request.finished",
                "rid": "r2",
                "out": {"meta_info": {"prompt_tokens": 200, "completion_tokens": 100}},
            },
        ]
        with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as f:
            for item in log_data:
                f.write(json.dumps(item) + "\n")
            f.flush()
            requests = load_from_request_logger(f.name)

        self.assertEqual(len(requests), 2)
        self.assertEqual(requests[0].request_id, "r1")
        self.assertEqual(requests[0].input_len, 100)
        self.assertEqual(requests[1].input_len, 200)

    def test_empty_file(self):
        with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as f:
            f.write("")
            f.flush()
            self.assertEqual(len(load_from_request_logger(f.name)), 0)


class TestDataSynthesis(CustomTestCase):
    def test_generate_basic(self):
        requests = generate_random_requests(
            num_requests=10, input_len=100, output_len=50
        )
        self.assertEqual(len(requests), 10)
        for req in requests:
            self.assertEqual(req.input_len, 100)
            self.assertEqual(req.output_len, 50)

    def test_generate_with_range_ratio(self):
        requests = generate_random_requests(
            num_requests=100, input_len=100, output_len=50, range_ratio=0.5, seed=42
        )
        for req in requests:
            self.assertGreaterEqual(req.input_len, 50)
            self.assertLessEqual(req.input_len, 100)

    def test_generate_with_seed(self):
        r1 = generate_random_requests(
            num_requests=10, input_len=100, output_len=50, range_ratio=0.5, seed=42
        )
        r2 = generate_random_requests(
            num_requests=10, input_len=100, output_len=50, range_ratio=0.5, seed=42
        )
        for a, b in zip(r1, r2):
            self.assertEqual(a.input_len, b.input_len)

    def test_generate_gsp_basic(self):
        requests = generate_gsp_requests(
            num_groups=4,
            prompts_per_group=3,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            seed=42,
        )
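        # 4 groups x 3 prompts = 12 requests; input_len = system prompt (100)
        # + question (50).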
        self.assertEqual(len(requests), 12)
        for req in requests:
            self.assertIsNotNone(req.group_id)
            self.assertEqual(req.prefix_len, 100)
            self.assertEqual(req.input_len, 150)
            self.assertEqual(req.output_len, 25)

    def test_generate_gsp_group_assignment(self):
        requests = generate_gsp_requests(
            num_groups=3,
            prompts_per_group=2,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            seed=42,
        )
        group_counts = {}
        for req in requests:
            group_counts[req.group_id] = group_counts.get(req.group_id, 0) + 1
        self.assertEqual(len(group_counts), 3)
        for count in group_counts.values():
            self.assertEqual(count, 2)

    def test_generate_gsp_with_range_ratio(self):
        requests = generate_gsp_requests(
            num_groups=4,
            prompts_per_group=5,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            range_ratio=0.5,
            seed=42,
        )
        for req in requests:
            self.assertGreaterEqual(req.prefix_len, 50)
            self.assertLessEqual(req.prefix_len, 100)
            self.assertGreaterEqual(req.input_len - req.prefix_len, 25)
            self.assertLessEqual(req.input_len - req.prefix_len, 50)

    def test_generate_gsp_shuffled(self):
        requests = generate_gsp_requests(
            num_groups=4,
            prompts_per_group=10,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            seed=42,
        )
        group_ids = [req.group_id for req in requests]
        is_sorted = all(
            group_ids[i] <= group_ids[i + 1] for i in range(len(group_ids) - 1)
        )
        self.assertFalse(is_sorted)


class TestSimulator(CustomTestCase):
    def test_basic_run(self):
        requests = [
            SimRequest(request_id=f"r{i}", input_len=10, output_len=5)
            for i in range(10)
        ]
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
            recorders=[
                BatchSizeBalancednessRecorder(),
                AttentionComputeBalancednessRecorder(),
            ],
            max_total_tokens=100,
        )
        result = sim.run(requests)
        self.assertIsInstance(result, SimulationResult)
        self.assertIn("batch_size_balancedness_mean", result.summary)
        self.assertGreater(len(result.step_records), 0)

    def test_all_requests_complete(self):
        requests = [
            SimRequest(request_id=f"r{i}", input_len=10, output_len=3) for i in range(4)
        ]
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
            max_total_tokens=10000,
        )
        sim.run(requests)
        for gpu in sim.gpu_states:
            self.assertEqual(len(gpu.pending_requests), 0)
            self.assertEqual(len(gpu.running_requests), 0)

    def test_empty_requests(self):
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
        )
        result = sim.run([])
        self.assertEqual(result.summary, {})
        self.assertEqual(len(result.step_records), 0)

    def test_step_records(self):
        requests = [
            SimRequest(request_id=f"r{i}", input_len=10, output_len=3) for i in range(4)
        ]
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
            max_total_tokens=10000,
        )
        result = sim.run(requests)
        self.assertGreater(len(result.step_records), 0)
        for record in result.step_records:
            self.assertIsInstance(record, StepRecord)
            self.assertIn(record.gpu_id, [0, 1])
        self.assertEqual(len([r for r in result.step_records if r.step == 0]), 2)

    def test_preemption_due_to_token_growth(self):
        requests = [
            SimRequest(request_id="r0", input_len=50, output_len=10),
            SimRequest(request_id="r1", input_len=50, output_len=10),
        ]
        sim = Simulator(
            num_gpus_per_engine=1,
            router=RoundRobinRouter(num_gpus=1),
            scheduler=FIFOScheduler(),
            max_total_tokens=110,
        )
        result = sim.run(requests)

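        # Both prompts fit at first (50 + 50 = 100 <= 110), but decoding adds
        # two tokens per step, so the budget is exceeded mid-run and one
        # request should be pushed back to the pending queue.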
        found_preemption = False
        for record in result.step_records:
            if record.running_count == 1 and record.pending_count == 1:
                found_preemption = True
                break
        self.assertTrue(
            found_preemption, "Expected preemption to occur due to token growth"
        )


class TestCLI(CustomTestCase):
    def _run_cli(self, *args):
        return subprocess.run(
            [sys.executable, "-m", "sglang.srt.debug_utils.schedule_simulator", *args],
            capture_output=True,
            text=True,
        )

    def _assert_output_contains(self, output: str, expected_lines: str):
        for line in expected_lines.strip().split("\n"):
            self.assertIn(line, output)

    def test_cli_basic(self):
        log_data = [
            {
                "event": "request.finished",
                "rid": "r1",
                "out": {"meta_info": {"prompt_tokens": 100, "completion_tokens": 50}},
            },
            {
                "event": "request.finished",
                "rid": "r2",
                "out": {"meta_info": {"prompt_tokens": 200, "completion_tokens": 100}},
            },
        ]
        with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as f:
            for item in log_data:
                f.write(json.dumps(item) + "\n")
            input_file = f.name
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            output_file = f.name

        result = self._run_cli(
            "--input", input_file, "--num-gpus-per-engine", "2", "--output", output_file
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("Loaded 2 requests", result.stdout)
        with open(output_file) as f:
            self.assertIn("batch_size_balancedness_mean", json.load(f))

    def test_cli_random_router(self):
        log_data = [
            {
                "event": "request.finished",
                "rid": "r1",
                "out": {"meta_info": {"prompt_tokens": 100, "completion_tokens": 50}},
            }
        ]
        with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as f:
            for item in log_data:
                f.write(json.dumps(item) + "\n")
            input_file = f.name

        result = self._run_cli("--input", input_file, "--router", "random")
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("router=random", result.stdout)

    def test_e2e_sticky_router_group_locality(self):
        result = self._run_cli(
            "--synth-gsp",
            "--synth-gsp-num-groups",
            "1",
            "--synth-gsp-prompts-per-group",
            "4",
            "--synth-gsp-system-prompt-len",
            "10",
            "--synth-gsp-question-len",
            "10",
            "--synth-gsp-output-len",
            "2",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "2",
            "--router",
            "sticky",
            "--max-total-tokens",
            "1000",
            "--log-level",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
| self.assertIn("R=4:", result.stdout) |
| self.assertIn("R=0:-", result.stdout) |
|
|
| def test_cli_synthetic(self): |
| result = self._run_cli( |
| "--synthetic", |
| "--synth-random-num-requests", |
| "100", |
| "--synth-random-input-len", |
| "512", |
| "--synth-random-output-len", |
| "128", |
| "--synth-random-range-ratio", |
| "0.5", |
| "--num-gpus-per-engine", |
| "4", |
| ) |
| self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}") |
| self.assertIn("Generated 100 random requests", result.stdout) |
|
|
| def test_cli_log_level(self): |
| result = self._run_cli( |
| "--synthetic", |
| "--synth-random-num-requests", |
| "10", |
| "--synth-random-output-len", |
| "5", |
| "--num-gpus-per-engine", |
| "2", |
| "--log-level", |
| "1", |
| ) |
| self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}") |
| self.assertIn("step=", result.stdout) |
|
|
| def test_e2e_simple_no_queuing(self): |
| result = self._run_cli( |
| "--synthetic", |
| "--synth-random-num-requests", |
| "4", |
| "--synth-random-input-len", |
| "10", |
| "--synth-random-output-len", |
| "2", |
| "--synth-random-range-ratio", |
| "1.0", |
| "--synth-seed", |
| "42", |
| "--num-gpus-per-engine", |
| "2", |
| "--max-total-tokens", |
| "10000", |
| "--log-level", |
| "2", |
| ) |
| self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}") |
        self.assertIn(
            "step=0 | GPU0[R=2:syn0,syn2 Q=0:-] | GPU1[R=2:syn1,syn3 Q=0:-]",
            result.stdout,
        )
        self.assertIn(
            "step=1 | GPU0[R=0:- Q=0:-] | GPU1[R=0:- Q=0:-]", result.stdout
        )
        self.assertIn("batch_size_balancedness_mean: 1.0000", result.stdout)

    def test_e2e_queuing_due_to_token_limit(self):
        result = self._run_cli(
            "--synthetic",
            "--synth-random-num-requests",
            "4",
            "--synth-random-input-len",
            "100",
            "--synth-random-output-len",
            "3",
            "--synth-random-range-ratio",
            "1.0",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "1",
            "--max-total-tokens",
            "210",
            "--log-level",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self._assert_output_contains(
            result.stdout,
            """
step=0 | GPU0[R=2:syn0,syn1 Q=2:syn2,syn3]
step=1 | GPU0[R=2:syn0,syn1 Q=2:syn2,syn3]
step=2 | GPU0[R=0:- Q=2:syn2,syn3]
step=3 | GPU0[R=2:syn2,syn3 Q=0:-]
step=4 | GPU0[R=2:syn2,syn3 Q=0:-]
step=5 | GPU0[R=0:- Q=0:-]""",
        )

    def test_e2e_retraction_due_to_token_growth(self):
        result = self._run_cli(
            "--synthetic",
            "--synth-random-num-requests",
            "2",
            "--synth-random-input-len",
            "50",
            "--synth-random-output-len",
            "10",
            "--synth-random-range-ratio",
            "1.0",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "1",
            "--max-total-tokens",
            "110",
            "--log-level",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self._assert_output_contains(
            result.stdout,
            """
step=0 | GPU0[R=2:syn0,syn1 Q=0:-]
step=5 | GPU0[R=2:syn0,syn1 Q=0:-]
step=6 | GPU0[R=1:syn0 Q=1:syn1]
step=9 | GPU0[R=0:- Q=1:syn1]
step=10 | GPU0[R=1:syn1 Q=0:-]
step=13 | GPU0[R=0:- Q=0:-]""",
        )

    def test_cli_gsp_basic(self):
        result = self._run_cli(
            "--synth-gsp",
            "--synth-gsp-num-groups",
            "4",
            "--synth-gsp-prompts-per-group",
            "8",
            "--synth-gsp-system-prompt-len",
            "100",
            "--synth-gsp-question-len",
            "50",
            "--synth-gsp-output-len",
            "10",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("Generated 32 GSP requests", result.stdout)
        self.assertIn("4 groups x 8 prompts", result.stdout)

    def test_e2e_gsp_shared_prefix_enables_batching(self):
        for has_long_prefix in [True, False]:
            prefix_len, question_len = (50, 10) if has_long_prefix else (10, 50)
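            # With the long shared prefix, two requests cost 50 + 10 + 10 = 70
            # tokens, which fits the 80-token budget; with the short prefix
            # they cost 10 + 50 + 50 = 110 and cannot batch together.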
            result = self._run_cli(
                "--synth-gsp",
                "--synth-gsp-num-groups",
                "1",
                "--synth-gsp-prompts-per-group",
                "2",
                "--synth-gsp-system-prompt-len",
                str(prefix_len),
                "--synth-gsp-question-len",
                str(question_len),
                "--synth-gsp-output-len",
                "2",
                "--synth-seed",
                "42",
                "--num-gpus-per-engine",
                "1",
                "--max-total-tokens",
                "80",
                "--log-level",
                "2",
            )
            self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
            if has_long_prefix:
                self.assertIn("R=2:", result.stdout)
            else:
                self.assertNotIn("R=2:", result.stdout)


class TestLargerScale(CustomTestCase):
    def _run_main(self, *cli_args) -> SimulationResult:
        parser = create_arg_parser()
        args = parser.parse_args(cli_args)
        return main(args)

    def _assert_in_range(self, value, lo, hi, name):
        self.assertGreaterEqual(value, lo, f"{name}={value} < {lo}")
        self.assertLessEqual(value, hi, f"{name}={value} > {hi}")

    def test_vanilla_workload_random_policy(self):
        result = self._run_main(
            "--synthetic",
            "--synth-random-num-requests",
            "500000",
            "--synth-random-input-len",
            "32000",
            "--synth-random-output-len",
            "2000",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "8",
            "--num-engines",
            "250",
            "--router",
            "random",
            "--max-total-tokens",
            "2000000",
            "--stop-criteria",
            "exist_no_pending",
            "--max-steps",
            "1500",
        )
        self._assert_in_range(
            result.summary["attention_compute_balancedness_mean"], 0.95, 1.0, "attn"
        )
        self._assert_in_range(
            result.summary["batch_size_balancedness_mean"], 0.90, 0.98, "bs"
        )
        self._assert_in_range(result.summary["avg_batch_size"], 127, 141, "avg_bs")

    def _run_gsp_workload(self, router: str) -> SimulationResult:
        return self._run_main(
            "--synth-gsp",
            "--synth-gsp-num-groups",
            "50000",
            "--synth-gsp-prompts-per-group",
            "100",
            "--synth-gsp-system-prompt-len",
            "31000",
            "--synth-gsp-question-len",
            "1000",
            "--synth-gsp-output-len",
            "8000",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "8",
            "--num-engines",
            "250",
            "--router",
            router,
            "--max-total-tokens",
            "500000",
            "--stop-criteria",
            "exist_no_pending",
            "--max-steps",
            "1500",
        )

    def test_gsp_workload_random_policy(self):
        result = self._run_gsp_workload("random")
        self._assert_in_range(
            result.summary["attention_compute_balancedness_mean"], 0.90, 0.97, "attn"
        )
        self._assert_in_range(
            result.summary["batch_size_balancedness_mean"], 0.90, 0.97, "bs"
        )
        self._assert_in_range(result.summary["avg_batch_size"], 14, 17, "avg_bs")

    def test_gsp_workload_sticky_policy(self):
        result = self._run_gsp_workload("sticky")
        self._assert_in_range(
            result.summary["attention_compute_balancedness_mean"], 0.64, 0.71, "attn"
        )
        self._assert_in_range(
            result.summary["batch_size_balancedness_mean"], 0.64, 0.71, "bs"
        )
        self._assert_in_range(result.summary["avg_batch_size"], 31, 36, "avg_bs")


if __name__ == "__main__":
    unittest.main()