File size: 3,523 Bytes
970607e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn

def dense_connector_sti(image_features, image_forward_outs, is_siglip=True):
    """Build extra tokens from two intermediate hidden states by average pooling.

    Hidden states 7 and 15 of ``image_forward_outs.hidden_states`` are cast to
    the dtype of ``image_features``, average-pooled along dim 1 with a window
    and stride of 8, and concatenated along dim -2.

    Args:
        image_features: Tensor used here only for its ``dtype``.
        image_forward_outs: Object exposing an indexable ``hidden_states``
            with at least 16 entries; each entry must be 3-D because of the
            ``permute(0, 2, 1)`` below.
        is_siglip: When falsy, index 0 along dim 1 is dropped from each hidden
            state before pooling (presumably a CLS token -- confirm against
            the vision tower in use).

    Returns:
        Tensor holding the pooled layer-7 tokens followed by the pooled
        layer-15 tokens along dim -2.
    """
    # Non-SigLIP inputs skip the first position along dim 1.
    first_token = 0 if is_siglip else 1

    def _pooled_layer(layer_idx):
        hidden = image_forward_outs.hidden_states[layer_idx][:, first_token:]
        hidden = hidden.to(image_features.dtype)
        # avg_pool1d pools over the last dim, so move dim 1 there and back.
        pooled = nn.functional.avg_pool1d(hidden.permute(0, 2, 1), kernel_size=8, stride=8)
        return pooled.permute(0, 2, 1)

    return torch.cat([_pooled_layer(7), _pooled_layer(15)], dim=-2)

def dense_connector_sci(image_features, image_forward_outs, is_siglip=True):
    """Concatenate two intermediate hidden states along the last dimension.

    Hidden states 7 and 15 of ``image_forward_outs.hidden_states`` are cast to
    the dtype of ``image_features`` and joined along dim -1.

    Args:
        image_features: Tensor used here only for its ``dtype``.
        image_forward_outs: Object exposing an indexable ``hidden_states``
            with at least 16 entries.
        is_siglip: When falsy, index 0 along dim 1 is dropped from each hidden
            state (presumably a CLS token -- confirm against the vision tower
            in use).

    Returns:
        Tensor with the layer-7 features followed by the layer-15 features
        along dim -1.
    """
    # Non-SigLIP inputs skip the first position along dim 1.
    start = 0 if is_siglip else 1
    selected = [
        image_forward_outs.hidden_states[layer_idx][:, start:].to(image_features.dtype)
        for layer_idx in (7, 15)
    ]
    return torch.cat(selected, dim=-1)

def dense_connector_dci(image_features, image_forward_outs, is_siglip=True):
    """Average two consecutive groups of hidden states and join them on dim -1.

    With ``is_siglip`` truthy, the groups are hidden states 0-12 and 13-25
    (13 each). Otherwise they are hidden states 0-11 and 12-23 (12 each), and
    index 0 along dim 1 is dropped from every state first. Each state is cast
    to the dtype of ``image_features`` before averaging.

    Args:
        image_features: Tensor used here only for its ``dtype``.
        image_forward_outs: Object exposing an indexable ``hidden_states``
            with at least 26 entries (SigLIP path) or 24 entries (other path).
        is_siglip: Selects the group size and whether the first position along
            dim 1 is kept (the dropped position is presumably a CLS token --
            confirm against the vision tower in use).

    Returns:
        Tensor with the first group's average followed by the second group's
        average along dim -1.
    """
    if is_siglip:
        group_size, start = 13, 0
    else:
        group_size, start = 12, 1

    def _group_average(first_layer):
        layers = [
            image_forward_outs.hidden_states[idx][:, start:].to(image_features.dtype)
            for idx in range(first_layer, first_layer + group_size)
        ]
        # Stack, sum, then divide: same operation order as a plain mean over layers.
        return torch.sum(torch.stack(layers, dim=0), dim=0) / group_size

    lower = _group_average(0)
    upper = _group_average(group_size)
    return torch.cat([lower, upper], dim=-1)


def dense_connector(image_features, image_forward_outs, is_siglip=True, mm_dense_connector_type='dci'):
    """Augment ``image_features`` with features built from intermediate hidden states.

    Dispatches on ``mm_dense_connector_type``:

    * ``'sti'`` -- appends the output of ``dense_connector_sti`` along dim -2.
    * ``'sci'`` -- appends the output of ``dense_connector_sci`` along dim -1.
    * ``'dci'`` -- appends the output of ``dense_connector_dci`` along dim -1.

    Args:
        image_features: Tensor to extend; also passed through to the selected
            helper.
        image_forward_outs: Object exposing ``hidden_states``; passed through
            to the selected helper.
        is_siglip: Forwarded unchanged to the selected helper.
        mm_dense_connector_type: One of ``'sti'``, ``'sci'`` or ``'dci'``.

    Returns:
        ``image_features`` concatenated with the helper's output.

    Raises:
        NotImplementedError: If ``mm_dense_connector_type`` is not one of the
            supported values. The message names the offending value.
    """
    if mm_dense_connector_type == 'sti':
        image_features_dc = dense_connector_sti(image_features, image_forward_outs, is_siglip)
        image_features = torch.cat((image_features, image_features_dc), dim=-2)
    elif mm_dense_connector_type == 'sci':
        image_features_dc = dense_connector_sci(image_features, image_forward_outs, is_siglip)
        image_features = torch.cat((image_features, image_features_dc), dim=-1)
    elif mm_dense_connector_type == 'dci':
        image_features_dc = dense_connector_dci(image_features, image_forward_outs, is_siglip)
        image_features = torch.cat((image_features, image_features_dc), dim=-1)
    else:
        # Report the bad value so a misconfigured connector type is easy to diagnose.
        raise NotImplementedError(
            "Unsupported mm_dense_connector_type: {!r} "
            "(expected 'sti', 'sci' or 'dci')".format(mm_dense_connector_type)
        )

    return image_features