import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from .convnext_encoder import ConvNextVisionTower
from .hr_clip_encoder import HRCLIPVisionTower
from .vision_models.eva_vit import EVAVITVisionTower
from .sam_encoder import SAMVisionTower
from .pix2struct_encoder import Pix2StructLargeVisionTower

class MultiBackboneChannelConcatenationVisionTower(nn.Module):
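    """Run several vision backbones on the same image, resample each feature
    map to a shared ``grid_size x grid_size`` token grid, and concatenate the
    per-tower features along the channel dimension.
    """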
    def __init__(self, vision_tower, args, grid_size=32):
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2

        # `vision_tower` is a semicolon-separated list of tower names,
        # e.g. "convnext-1024;sam-1024"
        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024  # hardcoded input resolution
        self.load_vision_towers(vision_tower_name_list, args)

    def load_vision_towers(self, vision_tower_name_list, args):
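        """Instantiate each requested tower with its own copy of ``args`` so
        per-tower settings (resolution, freezing, etc.) can differ.
        """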
        self.vision_towers = nn.ModuleList()
        for name in vision_tower_name_list:
            if name == 'det-1024':
                ## EVA-02 detection backbone
                det_args = deepcopy(args)
                det_args.input_image_size = 1024
                det_args.freeze_vision = False
                det_args.vision_tower_pretrained_from = '/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth'
                det_vision_tower = EVAVITVisionTower("eva02-l-16", det_args)
                det_vision_tower.load_model()
                self.vision_towers.append(det_vision_tower)

            elif name == 'convnext-1024':
                ## ConvNeXt
                convnext_args = deepcopy(args)
                convnext_args.freeze_vision = False
                convnext_args.input_image_size = 1024
                convnext_name = "convnext_xxlarge.clip_laion2b_soup"  # hardcoded checkpoint
                convnext_vision_tower = ConvNextVisionTower(convnext_name, convnext_args)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)
            
            elif name == "sam-1024":
                sam_args = deepcopy(args)
                sam_args.freeze_vision = False
                sam_args.input_image_size = 1024
                sam_args.add_pixel_shuffle = True
                sam_vision_tower = SAMVisionTower("SAM-L", sam_args)
                sam_vision_tower.load_model()
                self.vision_towers.append(sam_vision_tower)

            elif name == 'pix2struct-1024':
                ## Pix2Struct
                pix_args = deepcopy(args)
                pix_args.input_image_size = 1024
                pix_args.freeze_vision = False
                pix_args.do_resize = True
                pix_args.de_normalize = True
                pix_vision_tower = Pix2StructLargeVisionTower("pix2struct-large", pix_args)
                pix_vision_tower.load_model()
                self.vision_towers.append(pix_vision_tower)

            elif name == 'clip-448':
                ## CLIP
                clip_args = deepcopy(args)
                clip_args.input_image_size = 336  # the HR tower actually runs at 448, so this value has no effect
                clip_args.freeze_vision = False
                clip_vision_tower = HRCLIPVisionTower("openai/clip-vit-large-patch14-336", clip_args)
                clip_vision_tower.load_model()
                self.vision_towers.append(clip_vision_tower)

            else:
                raise ValueError(f"Unknown vision tower in mixture: {name}")

        # Hardcoded: the ConvNeXt tower's image processor serves the whole
        # mixture, so 'convnext-1024' must always be in the tower list.
        self.image_processor = convnext_vision_tower.image_processor
        self.is_loaded = True

    def load_model(self):
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"

    def forward(self, x):
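        """Extract features from every tower on the (optionally resized) input
        and concatenate them along the channel dimension.

        Args:
            x: image batch of shape (B, 3, input_image_size, input_image_size).

        Returns:
            Tensor of shape (B, grid_size**2, sum of per-tower hidden sizes).
        """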
        features = []
        for vision_tower in self.vision_towers:
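            # bring the input to this tower's native resolution if needed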
            if vision_tower.input_image_size != self.input_image_size:
                resized_x = F.interpolate(x.float(), 
                                          size=(vision_tower.input_image_size, vision_tower.input_image_size), 
                                          mode='bilinear', 
                                          align_corners=True).to(dtype=x.dtype)
            else:
                resized_x = x
            feature = vision_tower(resized_x)
            if feature.dim() == 3:  # (b, n, c) token sequence
                b, n, c = feature.shape
                if n == self.num_tokens:
                    features.append(feature)
                    continue

                # fold the token sequence back into a square 2D feature map
                w = h = int(n ** 0.5)
                feature = feature.transpose(1, 2).reshape(b, c, h, w)
            else:
                b, c, h, w = feature.shape

            # resample to the shared grid, then flatten back to (b, n, c)
            if w != self.grid_size:
                feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
            features.append(feature.flatten(2, 3).transpose(1, 2))
        
        features = torch.cat(features, dim=-1)

        return features
        
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # assumes all towers share a dtype; query the first one
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        # assumes all towers live on the same device
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        # no single config describes the mixture of towers
        raise NotImplementedError

    @property
    def hidden_size(self):
        return sum(tower.hidden_size for tower in self.vision_towers)

    @property
    def num_patches(self):
        return self.num_tokens
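
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The `args` fields below are assumptions
# about what the individual towers read from their config; real callers pass
# the project's own config object, and loading the towers requires their
# pretrained checkpoints to be available. Because this module uses relative
# imports, run it as a module (`python -m <package>.<module>`), not a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(input_image_size=1024, freeze_vision=False)  # assumed fields
    tower = MultiBackboneChannelConcatenationVisionTower(
        "convnext-1024;sam-1024", args, grid_size=32)

    images = torch.zeros(2, 3, 1024, 1024)
    feats = tower(images)  # (2, 32 * 32, tower.hidden_size)
    print(feats.shape)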