# SparseVLM / tests/test_patch.py
# Uploaded by Aryan3108 via huggingface_hub (commit 176b11a, verified)
"""
tests/test_patch.py
--------------------
Integration tests for the SparseVLM attention hook.
Uses a tiny fake attention module so the tests run without downloading Qwen2.5-VL.
"""
import pytest
import torch
import torch.nn as nn
from sparsevlm.patch import SparseVLMAttentionWrapper, reset_n_vis
class FakeAttention(nn.Module):
    """Minimal stand-in for Qwen2VLAttention used by the SparseVLM hook tests.

    It applies one linear projection to the hidden states. When
    ``output_attentions`` is true, it also returns random attention weights
    normalized along the last dimension. Only the output tuple structure is
    meant to mimic the real module; the weights carry no meaning.

    Args:
        D: Hidden size (input and output width of the projection).
        H: Number of heads reported in the fake attention weights.
    """

    def __init__(self, D=64, H=4):
        super().__init__()
        self.D = D
        self.H = H
        self.proj = nn.Linear(D, D)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False,
                **kwargs):
        """Project ``hidden_states`` and optionally return fake attention weights.

        Args:
            hidden_states: Tensor unpacked as ``[B, N, D]``; the last dimension
                must equal the ``D`` passed to ``__init__``.
            attention_mask, position_ids, past_key_value, use_cache, **kwargs:
                Accepted only for signature compatibility; ignored.
            output_attentions: If true, also return weights of shape
                ``[B, H, N, N]``.

        Returns:
            ``(out,)`` or ``(out, attn)``. ``out`` has shape ``[B, N, D]``.
            Each row of ``attn`` sums to 1 along its last dimension.
        """
        B, N, _ = hidden_states.shape
        out = self.proj(hidden_states)
        if not output_attentions:
            return (out,)
        # Match both device AND dtype of the activations. Real attention
        # returns weights in the activation dtype, so fp16/bf16/fp64 runs
        # must not silently receive float32 weights.
        attn = torch.rand(B, self.H, N, N,
                          device=hidden_states.device,
                          dtype=hidden_states.dtype)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        return out, attn
def make_wrapper(D=64, H=4, n_vis=32, is_target=True, *, min_keep=8, layer_idx=0):
    """Build a SparseVLMAttentionWrapper around a fresh FakeAttention.

    Args:
        D: Hidden size of the fake attention module.
        H: Number of heads in the fake attention weights.
        n_vis: Initial visual-token count stored in the shared state dict.
        is_target: Forwarded as ``is_target_layer``.
        min_keep: Forwarded to the wrapper. The pruning tests expect
            ``n_vis`` never to drop below this value. Default 8 preserves
            the previously hard-coded value.
        layer_idx: Forwarded to the wrapper. Default 0 preserves the
            previously hard-coded value.

    Returns:
        ``(wrapper, shared)``. ``shared`` is the same dict object handed to
        the wrapper, so tests can inspect ``shared['n_vis']`` after a forward
        pass.
    """
    shared = {'n_vis': n_vis}
    attn = FakeAttention(D=D, H=H)
    wrapper = SparseVLMAttentionWrapper(
        original_attn=attn,
        shared_state=shared,
        layer_idx=layer_idx,
        is_target_layer=is_target,
        min_keep=min_keep,
    )
    return wrapper, shared
def test_wrapper_forward_shape():
    """The wrapper returns a tuple whose first item is a 3-D tensor.

    Only the batch and hidden dimensions are pinned. The token dimension is
    deliberately left unchecked (presumably because pruning may shrink it).
    """
    torch.manual_seed(0)
    batch, vis_tokens, text_tokens, hidden_dim = 2, 32, 8, 64
    # Build the wrapper before drawing inputs so the RNG stream is consumed
    # in the same order as before (Linear init first, then randn).
    wrapper, _ = make_wrapper(D=hidden_dim, n_vis=vis_tokens)
    inputs = torch.randn(batch, vis_tokens + text_tokens, hidden_dim)

    result = wrapper(inputs)

    assert isinstance(result, tuple)
    hidden_out = result[0]
    assert hidden_out.dim() == 3
    assert hidden_out.shape[0] == batch
    assert hidden_out.shape[2] == hidden_dim
def test_wrapper_reduces_n_vis():
    """A target-layer forward pass lowers shared['n_vis'], but not below 8 (min_keep)."""
    torch.manual_seed(1)
    batch, vis_before, text_tokens, hidden_dim = 2, 64, 16, 64
    wrapper, state = make_wrapper(D=hidden_dim, n_vis=vis_before)
    wrapper(torch.randn(batch, vis_before + text_tokens, hidden_dim))

    vis_after = state['n_vis']
    assert vis_after < vis_before, "n_vis should decrease after pruning"
    assert vis_after >= 8, "n_vis should respect min_keep"
def test_non_target_layer_no_pruning():
    """With is_target=False, a forward pass leaves shared['n_vis'] untouched."""
    torch.manual_seed(2)
    batch, vis_tokens, text_tokens, hidden_dim = 2, 64, 16, 64
    wrapper, state = make_wrapper(D=hidden_dim, n_vis=vis_tokens, is_target=False)
    vis_before = state['n_vis']

    wrapper(torch.randn(batch, vis_tokens + text_tokens, hidden_dim))

    assert state['n_vis'] == vis_before
def test_reset_n_vis():
    """reset_n_vis overwrites a stale 'n_vis' entry with the requested count."""
    state = {'n_vis': 64}
    target = 256
    reset_n_vis(state, target)
    assert state['n_vis'] == target
def test_no_nan_output():
    """The wrapper's first output contains neither NaN nor Inf entries."""
    torch.manual_seed(3)
    batch, vis_tokens, text_tokens, hidden_dim = 2, 48, 12, 64
    wrapper, _ = make_wrapper(D=hidden_dim, n_vis=vis_tokens)
    hidden_out = wrapper(torch.randn(batch, vis_tokens + text_tokens, hidden_dim))[0]

    assert not torch.isnan(hidden_out).any()
    assert not torch.isinf(hidden_out).any()