# Lekr0's picture
# Add files using upload-large-folder tool
# a402b9b verified
import sys
from pathlib import Path
import pytest
import torch
from sglang.srt.debug_utils.comparator.output_types import SummaryRecord
from sglang.srt.debug_utils.comparator.utils import (
Pair,
argmax_coord,
auto_descend_dir,
calc_per_token_rel_diff,
calc_rel_diff,
compute_exit_code,
compute_smaller_dtype,
try_unify_shape,
)
from sglang.test.ci.ci_register import register_cpu_ci
# Register this module for CPU CI: "default" suite, nightly run enabled.
# est_time=10 is presumably an estimate in seconds -- confirm in ci_register.
register_cpu_ci(suite="default", nightly=True, est_time=10)
class TestCalcRelDiff:
    """calc_rel_diff expectations: ~0 for equal inputs, ~1 for orthogonal
    vectors, ~2 for a vector against its negation."""

    def test_identical_tensors(self):
        tensor = torch.randn(10, 10)
        diff = calc_rel_diff(tensor, tensor).item()
        assert diff == pytest.approx(0.0, abs=1e-5)

    def test_orthogonal_tensors(self):
        e0 = torch.tensor([1.0, 0.0])
        e1 = torch.tensor([0.0, 1.0])
        diff = calc_rel_diff(e0, e1).item()
        assert diff == pytest.approx(1.0, abs=1e-5)

    def test_similar_tensors(self):
        base = torch.tensor([1.0, 2.0, 3.0])
        shifted = torch.tensor([1.01, 2.01, 3.01])
        diff = calc_rel_diff(base, shifted).item()
        # A small uniform offset should give a small but non-zero diff.
        assert 0.0 < diff < 0.01

    def test_negated_tensors(self):
        vec = torch.tensor([1.0, 2.0])
        diff = calc_rel_diff(vec, -vec).item()
        assert diff == pytest.approx(2.0, abs=1e-5)
class TestCalcPerTokenRelDiff:
    """calc_per_token_rel_diff yields one value per index along ``seq_dim``."""

    def test_identical_tensors(self) -> None:
        """Comparing a tensor with itself gives zero diff at every position."""
        data: torch.Tensor = torch.randn(8, 16)
        per_token: torch.Tensor = calc_per_token_rel_diff(data, data, seq_dim=0)
        assert per_token.shape == (8,)
        assert torch.allclose(per_token, torch.zeros(8), atol=1e-6)

    def test_different_tensors(self) -> None:
        """Perturbing a single row raises the diff only at that row."""
        torch.manual_seed(42)
        reference: torch.Tensor = torch.randn(8, 16)
        perturbed: torch.Tensor = reference.clone()
        perturbed[3, :] += 10.0
        per_token: torch.Tensor = calc_per_token_rel_diff(
            reference, perturbed, seq_dim=0
        )
        assert per_token.shape == (8,)
        assert per_token[3] > per_token[0]
        assert per_token[3] > per_token[7]
        untouched_rows = (row for row in range(8) if row != 3)
        for row in untouched_rows:
            assert per_token[row] < 1e-6

    def test_seq_dim_selection(self) -> None:
        """Output length follows the size of whichever dim is chosen as seq_dim."""
        lhs: torch.Tensor = torch.randn(4, 8, 16)
        rhs: torch.Tensor = lhs + torch.randn_like(lhs) * 0.01
        for dim, expected_len in enumerate((4, 8, 16)):
            got = calc_per_token_rel_diff(lhs, rhs, seq_dim=dim)
            assert got.shape == (expected_len,)

    def test_1d_tensor(self) -> None:
        """For 1D inputs with seq_dim=0 every element is its own token."""
        lhs: torch.Tensor = torch.tensor([1.0, 2.0, 3.0])
        rhs: torch.Tensor = torch.tensor([1.0, 2.0, 4.0])
        per_elem: torch.Tensor = calc_per_token_rel_diff(lhs, rhs, seq_dim=0)
        assert per_elem.shape == (3,)
        assert per_elem[0] < 1e-6
        assert per_elem[1] < 1e-6
        assert per_elem[2] > 0.01
class TestArgmaxCoord:
    """argmax_coord returns the full index tuple of the largest element."""

    def test_1d_tensor(self):
        vec = torch.zeros(4)
        vec[2] = 5.0
        assert argmax_coord(vec) == (2,)

    def test_2d_tensor(self):
        grid = torch.zeros(3, 4)
        grid[1, 2] = 10.0
        assert argmax_coord(grid) == (1, 2)

    def test_3d_tensor(self):
        cube = torch.zeros(2, 3, 4)
        peak = (1, 2, 3)
        cube[peak] = 10.0
        assert argmax_coord(cube) == peak
class TestTryUnifyShape:
    """try_unify_shape drops leading size-1 dims when that reaches the target
    shape; otherwise the input shape is left unchanged."""

    def test_squeeze_leading_ones(self):
        target = torch.Size([3, 4])
        unified = try_unify_shape(torch.randn(1, 1, 3, 4), target)
        assert unified.shape == target

    def test_no_squeeze_when_leading_dim_not_one(self):
        target = torch.Size([3, 4])
        unified = try_unify_shape(torch.randn(2, 3, 4), target)
        assert unified.shape == (2, 3, 4)

    def test_same_shape_noop(self):
        target = torch.Size([3, 4])
        original = torch.randn(3, 4)
        unified = try_unify_shape(original, target)
        assert unified.shape == target
        # Same storage pointer means no copy was made.
        assert unified.data_ptr() == original.data_ptr()

    def test_trailing_dims_mismatch(self):
        source = torch.randn(1, 3, 4)
        unified = try_unify_shape(source, torch.Size([5, 6]))
        assert unified.shape == (1, 3, 4)
class TestComputeSmallerDtype:
    """compute_smaller_dtype picks the lower-precision dtype of a pair, and
    returns None for identical or unrecognized pairs."""

    def test_float32_bfloat16(self):
        dtypes = Pair(x=torch.float32, y=torch.bfloat16)
        assert compute_smaller_dtype(dtypes) == torch.bfloat16

    def test_reverse_order(self):
        dtypes = Pair(x=torch.bfloat16, y=torch.float32)
        assert compute_smaller_dtype(dtypes) == torch.bfloat16

    def test_same_dtype_returns_none(self):
        dtypes = Pair(x=torch.float32, y=torch.float32)
        assert compute_smaller_dtype(dtypes) is None

    def test_unknown_pair_returns_none(self):
        dtypes = Pair(x=torch.int32, y=torch.int64)
        assert compute_smaller_dtype(dtypes) is None
class TestPairMap:
    """Pair.map applies one callable to both sides and returns a new Pair."""

    def test_map_basic(self):
        mapped = Pair(x=[1, 2, 3], y=[4, 5, 6]).map(sum)
        assert mapped.x == 6
        assert mapped.y == 15

    def test_map_type_change(self):
        # list -> int: the mapped Pair may hold a different type.
        mapped = Pair(x=[1, 2, 3], y=[10, 20]).map(len)
        assert mapped.x == 3
        assert mapped.y == 2

    def test_map_returns_new_pair(self):
        source = Pair(x="hello", y="world")
        mapped = source.map(str.upper)
        assert mapped.x == "HELLO"
        assert mapped.y == "WORLD"
        assert mapped is not source
class TestComputeExitCode:
    """Unit tests for compute_exit_code logic.

    Every case goes through ``_exit_code``. It fills in baseline arguments
    (every skip allowed, no failure pattern, no skipped/failed names), so each
    test only spells out what it varies.
    """

    @staticmethod
    def _exit_code(summary, **overrides):
        """Call compute_exit_code with baseline kwargs updated by ``overrides``."""
        kwargs = {
            "allow_skipped_pattern": ".*",
            "skipped_names": [],
            "allow_failed_pattern": None,
            "failed_names": [],
        }
        kwargs.update(overrides)
        return compute_exit_code(summary, **kwargs)

    def test_all_passed(self):
        """Everything passed -> exit 0."""
        summary = SummaryRecord(total=3, passed=3, failed=0, skipped=0)
        assert self._exit_code(summary) == 0

    def test_has_failed_and_passed(self):
        """Failures with no allow pattern, alongside passes -> exit 1."""
        summary = SummaryRecord(total=4, passed=2, failed=2, skipped=0)
        assert self._exit_code(summary, failed_names=["a", "b"]) == 1

    def test_all_failed(self):
        """All failed (0 passed) -> exit 1."""
        summary = SummaryRecord(total=3, passed=0, failed=3, skipped=0)
        assert self._exit_code(summary, failed_names=["a", "b", "c"]) == 1

    def test_all_skipped_allow_all(self):
        """All skipped with every skip allowed -> still exit 1 (nothing passed)."""
        summary = SummaryRecord(total=2, passed=0, failed=0, skipped=2)
        assert self._exit_code(summary, skipped_names=["a", "b"]) == 1

    def test_all_skipped_forbid_all(self):
        """All skipped with allow_skipped_pattern='^$' -> exit 1."""
        summary = SummaryRecord(total=2, passed=0, failed=0, skipped=2)
        code = self._exit_code(
            summary, allow_skipped_pattern="^$", skipped_names=["a", "b"]
        )
        assert code == 1

    def test_passed_and_skipped_allow_all(self):
        """Passed + skipped, every skip allowed -> exit 0."""
        summary = SummaryRecord(total=3, passed=2, failed=0, skipped=1)
        assert self._exit_code(summary, skipped_names=["a"]) == 0

    def test_passed_and_skipped_forbid_all(self):
        """Passed + skipped, no skip allowed -> exit 1."""
        summary = SummaryRecord(total=3, passed=2, failed=0, skipped=1)
        code = self._exit_code(
            summary, allow_skipped_pattern="^$", skipped_names=["a"]
        )
        assert code == 1

    def test_skip_pattern_matches_specific_name(self):
        """Alternation pattern covering every skipped name -> exit 0."""
        summary = SummaryRecord(total=4, passed=2, failed=0, skipped=2)
        code = self._exit_code(
            summary,
            allow_skipped_pattern="positions|seq_lens",
            skipped_names=["positions", "seq_lens"],
        )
        assert code == 0

    def test_skip_pattern_partial_match_forbidden(self):
        """Pattern misses one skipped name (hidden_states) -> exit 1."""
        summary = SummaryRecord(total=4, passed=1, failed=0, skipped=3)
        code = self._exit_code(
            summary,
            allow_skipped_pattern="positions|seq_lens",
            skipped_names=["positions", "seq_lens", "hidden_states"],
        )
        assert code == 1

    def test_allow_failed_pattern_matches_all(self):
        """allow_failed_pattern='.*' tolerates every failure -> exit 0."""
        summary = SummaryRecord(total=3, passed=1, failed=2, skipped=0)
        code = self._exit_code(
            summary, allow_failed_pattern=".*", failed_names=["a", "b"]
        )
        assert code == 0

    def test_allow_failed_pattern_matches_specific(self):
        """Pattern names every failed tensor -> exit 0."""
        summary = SummaryRecord(total=3, passed=1, failed=2, skipped=0)
        code = self._exit_code(
            summary,
            allow_failed_pattern="hidden_states|logits",
            failed_names=["hidden_states", "logits"],
        )
        assert code == 0

    def test_allow_failed_pattern_partial_match(self):
        """Pattern tolerates only some of the failures -> exit 1."""
        summary = SummaryRecord(total=3, passed=0, failed=3, skipped=0)
        code = self._exit_code(
            summary,
            allow_failed_pattern="hidden_states",
            failed_names=["hidden_states", "logits", "attn"],
        )
        assert code == 1

    def test_allow_failed_pattern_no_failures(self):
        """Failure pattern set but nothing failed -> exit 0."""
        summary = SummaryRecord(total=2, passed=2, failed=0, skipped=0)
        assert self._exit_code(summary, allow_failed_pattern=".*") == 0

    def test_both_failed_and_skipped_patterns(self):
        """Both patterns set and both satisfied -> exit 0."""
        summary = SummaryRecord(total=4, passed=1, failed=1, skipped=2)
        code = self._exit_code(
            summary,
            allow_skipped_pattern="positions|seq_lens",
            skipped_names=["positions", "seq_lens"],
            allow_failed_pattern="logits",
            failed_names=["logits"],
        )
        assert code == 0

    def test_failed_pattern_satisfied_but_skipped_not(self):
        """Failures are allowed but the skip is not -> exit 1."""
        summary = SummaryRecord(total=3, passed=1, failed=1, skipped=1)
        code = self._exit_code(
            summary,
            allow_skipped_pattern="^$",
            skipped_names=["a"],
            allow_failed_pattern=".*",
            failed_names=["b"],
        )
        assert code == 1

    def test_zero_passed_exits_one(self):
        """Nothing passed -> exit 1, even when every failure is allowed."""
        summary = SummaryRecord(total=2, passed=0, failed=2, skipped=0)
        code = self._exit_code(
            summary, allow_failed_pattern=".*", failed_names=["a", "b"]
        )
        assert code == 1

    def test_zero_passed_all_skipped_exits_one(self):
        """All skipped, nothing passed -> exit 1."""
        summary = SummaryRecord(total=3, passed=0, failed=0, skipped=3)
        assert self._exit_code(summary, skipped_names=["a", "b", "c"]) == 1

    def test_errored_with_passed_exits_one(self):
        """An errored bundle forces exit 1 even when the others passed."""
        summary = SummaryRecord(total=3, passed=2, failed=0, skipped=0, errored=1)
        assert self._exit_code(summary, errored_names=["broken_tensor"]) == 1

    def test_errored_only_exits_one(self):
        """All errored -> exit 1 (passed == 0 alone already forces this; the
        error would trigger it independently too)."""
        summary = SummaryRecord(total=1, passed=0, failed=0, skipped=0, errored=1)
        assert self._exit_code(summary, errored_names=["broken_tensor"]) == 1
def _make_pt(directory: Path) -> None:
    """Create ``directory`` (parents included, existing is fine) and save a
    one-element float tensor into it as ``dummy.pt``."""
    target_file = directory / "dummy.pt"
    target_file.parent.mkdir(parents=True, exist_ok=True)
    torch.save(torch.tensor([1.0]), target_file)
class TestAutoDescendDir:
    """auto_descend_dir returns the directory that actually holds .pt files.

    Data at the root wins. Otherwise the single child with data is used.
    Several children with data, or no data at all, raises ValueError.
    """

    def test_no_descend_when_pt_at_root(self, tmp_path: Path) -> None:
        """A directory with .pt files directly inside is returned as-is,
        even when a child also holds data."""
        _make_pt(tmp_path)
        _make_pt(tmp_path / "child_a")
        assert auto_descend_dir(tmp_path, label="test") == tmp_path

    def test_descend_into_single_child(self, tmp_path: Path) -> None:
        """A single child holding .pt files triggers a descend into it."""
        child: Path = tmp_path / "engine_0"
        _make_pt(child)
        assert auto_descend_dir(tmp_path, label="test") == child

    def test_descend_single_nonempty_child_among_empty(self, tmp_path: Path) -> None:
        """Two subdirs but only one has .pt files - descend into that one."""
        nonempty: Path = tmp_path / "engine_0"
        _make_pt(nonempty)
        (tmp_path / "empty_child").mkdir()
        assert auto_descend_dir(tmp_path, label="test") == nonempty

    def test_error_with_multiple_nonempty_children(self, tmp_path: Path) -> None:
        """Two children with .pt files - ambiguous, raises ValueError."""
        _make_pt(tmp_path / "engine_0")
        _make_pt(tmp_path / "engine_1")
        with pytest.raises(ValueError, match="multiple subdirectories contain data"):
            auto_descend_dir(tmp_path, label="test")

    def test_error_when_no_data_found(self, tmp_path: Path) -> None:
        """No .pt files anywhere - raises ValueError."""
        (tmp_path / "empty_child").mkdir()
        # ``match`` is a regex applied with re.search. Escape the dot so it
        # only matches a literal ".pt", not any character followed by "pt".
        with pytest.raises(ValueError, match=r"no \.pt files found"):
            auto_descend_dir(tmp_path, label="test")
if __name__ == "__main__":
    # Run this file's tests directly and propagate pytest's status code.
    exit_status = pytest.main([__file__])
    sys.exit(exit_status)