"""
tests/test_patch.py
-------------------
Integration tests for the SparseVLM attention hook.
Uses a tiny fake attention module so tests run without downloading Qwen2.5-VL.
"""

import pytest
import torch
import torch.nn as nn

from sparsevlm.patch import SparseVLMAttentionWrapper, reset_n_vis


class FakeAttention(nn.Module):
    """Minimal attention module that mimics Qwen2VLAttention's output signature.

    The module does no real attention: ``forward`` applies one linear
    projection to the hidden states and, on request, fabricates a random,
    row-normalised attention tensor of shape ``(B, H, N, N)``.

    Args:
        D: Hidden size; the projection is ``nn.Linear(D, D)``.
        H: Number of heads reported in the fabricated attention tensor.
    """

    def __init__(self, D=64, H=4):
        super().__init__()
        self.D = D
        self.H = H
        self.proj = nn.Linear(D, D)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False,
                **kwargs):
        """Project ``hidden_states`` and optionally return fake attention.

        Args:
            hidden_states: Tensor with exactly three dimensions ``(B, N, D)``.
            attention_mask, position_ids, past_key_value, use_cache, **kwargs:
                Accepted for call compatibility and ignored.
            output_attentions: When true, also return attention weights.

        Returns:
            ``(out,)`` when ``output_attentions`` is false, otherwise
            ``(out, attn)`` where ``attn`` has shape ``(B, H, N, N)`` and each
            row along the last dimension sums to 1.
        """
        B, N, _ = hidden_states.shape
        out = self.proj(hidden_states)

        if not output_attentions:
            return (out,)

        # Follow the dtype of the projected output so the fake weights stay
        # consistent when the module is cast (e.g. ``.double()`` / ``.half()``)
        # instead of always using the global default dtype.
        attn = torch.rand(B, self.H, N, N,
                          device=hidden_states.device, dtype=out.dtype)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        return out, attn


def make_wrapper(D=64, H=4, n_vis=32, is_target=True, min_keep=8, layer_idx=0):
    """Build a SparseVLMAttentionWrapper around a fresh FakeAttention.

    Args:
        D: Hidden size passed to ``FakeAttention``.
        H: Head count passed to ``FakeAttention``.
        n_vis: Initial value stored under ``'n_vis'`` in the shared state.
        is_target: Forwarded to the wrapper as ``is_target_layer``.
        min_keep: Forwarded to the wrapper as ``min_keep``
            (presumably the floor on retained visual tokens; confirm against
            ``sparsevlm.patch``).
        layer_idx: Forwarded to the wrapper as ``layer_idx``.

    Returns:
        ``(wrapper, shared)`` where ``shared`` is the dict handed to the
        wrapper as ``shared_state``, so tests can observe changes to it.
    """
    shared = {'n_vis': n_vis}
    wrapper = SparseVLMAttentionWrapper(
        original_attn=FakeAttention(D=D, H=H),
        shared_state=shared,
        layer_idx=layer_idx,
        is_target_layer=is_target,
        min_keep=min_keep,
    )
    return wrapper, shared


def test_wrapper_forward_shape():
    """Output hidden states have correct shape."""
    torch.manual_seed(0)
    B, N_vis, N_text, D = 2, 32, 8, 64
    N_total = N_vis + N_text
    wrapper, _ = make_wrapper(D=D, n_vis=N_vis)

    hidden = torch.randn(B, N_total, D)
    out = wrapper(hidden)

    assert isinstance(out, tuple)
    # Only batch and feature dimensions are pinned; the sequence length
    # (dimension 1) is deliberately left unchecked.
    assert out[0].dim() == 3
    assert out[0].shape[0] == B
    assert out[0].shape[2] == D


def test_wrapper_reduces_n_vis():
    """n_vis in shared_state decreases after pruning."""
    torch.manual_seed(1)
    B, N_vis, N_text, D = 2, 64, 16, 64
    N_total = N_vis + N_text
    wrapper, shared = make_wrapper(D=D, n_vis=N_vis)

    hidden = torch.randn(B, N_total, D)
    wrapper(hidden)

    assert shared['n_vis'] < N_vis, "n_vis should decrease after pruning"
    # 8 is the min_keep value make_wrapper passes to the wrapper.
    assert shared['n_vis'] >= 8, "n_vis should respect min_keep"


def test_non_target_layer_no_pruning():
    """Non-target layers pass through without changing n_vis."""
    torch.manual_seed(2)
    B, N_vis, N_text, D = 2, 64, 16, 64
    N_total = N_vis + N_text
    wrapper, shared = make_wrapper(D=D, n_vis=N_vis, is_target=False)

    # Snapshot the value before the forward pass so the comparison does not
    # depend on what make_wrapper was asked to store.
    original_n_vis = shared['n_vis']
    hidden = torch.randn(B, N_total, D)
    wrapper(hidden)

    assert shared['n_vis'] == original_n_vis


def test_reset_n_vis():
    """reset_n_vis correctly resets shared state."""
    shared = {'n_vis': 64}
    # The dict is updated in place; the return value of reset_n_vis is unused.
    reset_n_vis(shared, 256)
    assert shared['n_vis'] == 256


def test_no_nan_output():
    """No NaN or Inf in wrapper output."""
    torch.manual_seed(3)
    B, N_vis, N_text, D = 2, 48, 12, 64
    N_total = N_vis + N_text
    wrapper, _ = make_wrapper(D=D, n_vis=N_vis)

    hidden = torch.randn(B, N_total, D)
    out = wrapper(hidden)

    # Checked separately so a failure says which kind of non-finite value
    # appeared.
    assert not torch.isnan(out[0]).any(), "output contains NaN"
    assert not torch.isinf(out[0]).any(), "output contains Inf"