File size: 3,202 Bytes
176b11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
tests/test_patch.py
--------------------
Integration tests for the SparseVLM attention hook.
Uses a tiny fake attention module so the tests run without downloading Qwen2.5-VL.
"""

import pytest
import torch
import torch.nn as nn
from sparsevlm.patch import SparseVLMAttentionWrapper, reset_n_vis


class FakeAttention(nn.Module):
    """Lightweight stand-in for a real attention layer.

    Intended to mimic Qwen2VLAttention's output signature (per the original
    intent of this helper; not checked against ``transformers`` here). It
    applies a single ``D -> D`` linear projection and, on request, returns
    random row-normalised attention weights.
    """

    def __init__(self, D=64, H=4):
        super().__init__()
        self.D, self.H = D, H
        self.proj = nn.Linear(D, D)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        """Project ``hidden_states`` and optionally emit fake attention weights.

        Only ``hidden_states`` and ``output_attentions`` are used. The other
        arguments exist so the signature accepts what the wrapper may forward.

        Returns:
            ``(out, attn)`` when ``output_attentions`` is truthy, where ``attn``
            has shape ``[B, H, N, N]`` and each last-axis row sums to 1;
            otherwise the 1-tuple ``(out,)``.
        """
        batch, seq_len, _ = hidden_states.shape  # also rejects non-3-D input
        projected = self.proj(hidden_states)

        if not output_attentions:
            return (projected,)

        # Uniform noise normalised per row so it resembles a softmax output.
        weights = torch.rand(batch, self.H, seq_len, seq_len, device=hidden_states.device)
        return projected, weights / weights.sum(dim=-1, keepdim=True)


def make_wrapper(D=64, H=4, n_vis=32, is_target=True, *, min_keep=8, layer_idx=0):
    """Build a SparseVLMAttentionWrapper around a freshly created FakeAttention.

    Args:
        D: Hidden size of the ``FakeAttention`` projection.
        H: Number of heads in ``FakeAttention``'s fake attention weights.
        n_vis: Initial value stored under ``'n_vis'`` in the shared-state dict.
        is_target: Forwarded to the wrapper as ``is_target_layer``.
        min_keep: Forwarded to the wrapper. Defaults to 8, the value this
            helper previously hard-coded.
        layer_idx: Forwarded to the wrapper. Defaults to 0, the value this
            helper previously hard-coded.

    Returns:
        ``(wrapper, shared)``: the wrapper and the same shared-state dict it
        was given, so tests can inspect ``shared['n_vis']`` after a forward.
    """
    shared = {'n_vis': n_vis}
    attn = FakeAttention(D=D, H=H)
    wrapper = SparseVLMAttentionWrapper(
        original_attn=attn,
        shared_state=shared,
        layer_idx=layer_idx,
        is_target_layer=is_target,
        min_keep=min_keep,
    )
    return wrapper, shared


def test_wrapper_forward_shape():
    """Wrapped forward returns a tuple whose first item is 3-D, with batch and hidden sizes preserved."""
    torch.manual_seed(0)
    batch, n_vis, n_text, dim = 2, 32, 8, 64
    wrapper, _ = make_wrapper(D=dim, n_vis=n_vis)

    result = wrapper(torch.randn(batch, n_vis + n_text, dim))

    assert isinstance(result, tuple)
    hidden_out = result[0]
    assert hidden_out.dim() == 3
    # The sequence axis is deliberately not checked; presumably pruning may
    # shorten it. TODO confirm against SparseVLMAttentionWrapper.
    assert hidden_out.shape[0] == batch
    assert hidden_out.shape[2] == dim


def test_wrapper_reduces_n_vis():
    """A target-layer forward lowers shared['n_vis'] but keeps it at or above 8 (min_keep)."""
    torch.manual_seed(1)
    batch, n_vis, n_text, dim = 2, 64, 16, 64
    wrapper, shared = make_wrapper(D=dim, n_vis=n_vis)

    wrapper(torch.randn(batch, n_vis + n_text, dim))

    remaining = shared['n_vis']
    assert remaining < n_vis, "n_vis should decrease after pruning"
    assert remaining >= 8, "n_vis should respect min_keep"


def test_non_target_layer_no_pruning():
    """A wrapper built with is_target=False leaves shared['n_vis'] unchanged after a forward."""
    torch.manual_seed(2)
    batch, n_vis, n_text, dim = 2, 64, 16, 64
    wrapper, shared = make_wrapper(D=dim, n_vis=n_vis, is_target=False)
    before = shared['n_vis']

    wrapper(torch.randn(batch, n_vis + n_text, dim))

    assert shared['n_vis'] == before


def test_reset_n_vis():
    """reset_n_vis writes the given count into the 'n_vis' entry of the dict it is passed."""
    fresh_count = 256
    state = {'n_vis': 64}
    reset_n_vis(state, fresh_count)
    assert state['n_vis'] == fresh_count


def test_no_nan_output():
    """Wrapper output contains only finite values (no NaN and no +/-inf)."""
    torch.manual_seed(3)
    batch, n_vis, n_text, dim = 2, 48, 12, 64
    wrapper, _ = make_wrapper(D=dim, n_vis=n_vis)

    hidden_out = wrapper(torch.randn(batch, n_vis + n_text, dim))[0]

    # isfinite is False exactly where an element is NaN or infinite.
    assert torch.isfinite(hidden_out).all()