"""
tests/test_patch.py
--------------------
Integration tests for the SparseVLM attention hook.

A tiny fake attention module stands in for the real model so the tests
run without downloading Qwen2.5-VL.
"""

import pytest
import torch
import torch.nn as nn

from sparsevlm.patch import SparseVLMAttentionWrapper, reset_n_vis


class FakeAttention(nn.Module):
    """Stand-in attention layer mirroring Qwen2VLAttention's output signature.

    The "attention" is a single linear projection of the hidden states.
    When attention weights are requested, random row-normalised weights
    are returned alongside the projected states.
    """

    def __init__(self, D=64, H=4):
        super().__init__()
        self.D = D
        self.H = H
        self.proj = nn.Linear(D, D)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False,
                use_cache=False, **kwargs):
        """Project ``hidden_states`` and optionally fabricate attention weights.

        Returns ``(out,)`` by default, or ``(out, attn)`` when
        ``output_attentions`` is truthy. All other arguments are accepted
        and ignored.
        """
        # Unpacking doubles as a check that the input is 3-dimensional.
        batch, seq_len, _ = hidden_states.shape
        projected = self.proj(hidden_states)

        if not output_attentions:
            return (projected,)

        # Fake attention weights of shape [B, H, N, N]; each row is
        # rescaled so that it sums to 1 along the last dimension.
        weights = torch.rand(
            batch, self.H, seq_len, seq_len, device=hidden_states.device
        )
        weights = weights / weights.sum(dim=-1, keepdim=True)
        return projected, weights


def make_wrapper(D=64, H=4, n_vis=32, is_target=True):
    """Build a wrapper around a ``FakeAttention`` plus its shared state.

    Returns a ``(wrapper, shared_state)`` pair, where ``shared_state`` is
    the dict (holding the ``'n_vis'`` entry) that was handed to the wrapper.
    """
    state = {'n_vis': n_vis}
    wrapped = SparseVLMAttentionWrapper(
        original_attn=FakeAttention(D=D, H=H),
        shared_state=state,
        layer_idx=0,
        is_target_layer=is_target,
        min_keep=8,
    )
    return wrapped, state


def test_wrapper_forward_shape():
    """The wrapper returns a tuple whose first item is a [B, *, D] tensor."""
    torch.manual_seed(0)
    batch, n_visual, n_text, dim = 2, 32, 8, 64
    wrapper, _ = make_wrapper(D=dim, n_vis=n_visual)
    inputs = torch.randn(batch, n_visual + n_text, dim)

    result = wrapper(inputs)

    assert isinstance(result, tuple)
    states = result[0]
    assert states.dim() == 3
    assert states.shape[0] == batch
    assert states.shape[2] == dim


def test_wrapper_reduces_n_vis():
    """A target layer lowers ``n_vis`` in the shared state, down to min_keep at most."""
    torch.manual_seed(1)
    batch, n_visual, n_text, dim = 2, 64, 16, 64
    wrapper, state = make_wrapper(D=dim, n_vis=n_visual)
    inputs = torch.randn(batch, n_visual + n_text, dim)

    wrapper(inputs)

    assert state['n_vis'] < n_visual, "n_vis should decrease after pruning"
    assert state['n_vis'] >= 8, "n_vis should respect min_keep"


def test_non_target_layer_no_pruning():
    """A non-target layer leaves ``n_vis`` in the shared state untouched."""
    torch.manual_seed(2)
    batch, n_visual, n_text, dim = 2, 64, 16, 64
    wrapper, state = make_wrapper(D=dim, n_vis=n_visual, is_target=False)
    n_vis_before = state['n_vis']
    inputs = torch.randn(batch, n_visual + n_text, dim)

    wrapper(inputs)

    assert state['n_vis'] == n_vis_before


def test_reset_n_vis():
    """``reset_n_vis`` overwrites the ``n_vis`` entry of the shared state."""
    state = {'n_vis': 64}
    reset_n_vis(state, 256)
    assert state['n_vis'] == 256


def test_no_nan_output():
    """The wrapper's output hidden states contain neither NaN nor Inf."""
    torch.manual_seed(3)
    batch, n_visual, n_text, dim = 2, 48, 12, 64
    wrapper, _ = make_wrapper(D=dim, n_vis=n_visual)
    inputs = torch.randn(batch, n_visual + n_text, dim)

    result = wrapper(inputs)

    assert not torch.isnan(result[0]).any()
    assert not torch.isinf(result[0]).any()