"""

Mostly copy-paste from LLaVA-HR

https://github.com/luogen1996/LLaVA-HR

"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


def forward_embeddings(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Patched CLIPVisionEmbeddings.forward that resamples the position embeddings
    when the input resolution differs from the pretraining resolution."""
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)  # shape = [*, grid*grid, width]

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    position_embeddings = self.position_embedding(self.position_ids)

    # interpolate the pretrained position embeddings when the token count differs
    # (e.g. a 448px input against a 336px-pretrained CLIP)
    if position_embeddings.shape[1] != embeddings.shape[1]:
        position_embeddings = resample_pos_embed(position_embeddings, embeddings.shape[1])

    embeddings = embeddings + position_embeddings
    return embeddings


def resample_pos_embed(
        posemb,
        new_size: int,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample a [batch, tokens, dim] position embedding to `new_size` tokens by
    interpolating over the 2D patch grid; prefix (class) tokens are carried over unchanged."""
    grid_size = int(math.sqrt(new_size - num_prefix_tokens))
    new_size = [grid_size, grid_size]
    num_pos_tokens = posemb.shape[1] - num_prefix_tokens
    old_size = int(math.sqrt(num_pos_tokens))
    bs = posemb.shape[0]

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation on the 2D grid
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(bs, old_size, old_size, -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(bs, -1, embed_dim)
    posemb = posemb.to(dtype=orig_dtype)

    # add back the extra (class, etc.) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], 1)

    if not torch.jit.is_scripting() and verbose:
        print(f'Resized position embedding: {old_size} to {new_size}.')

    return posemb
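
# Shape sketch (assumption: CLIP ViT-L/14 pretrained at 336px, hidden size 1024):
# the pretrained table has 1 + 24*24 = 577 position tokens; a 448px input yields
# 1 + 32*32 = 1025 tokens, so the patch grid is interpolated from 24x24 to 32x32.
#
#   posemb = torch.randn(1, 577, 1024)
#   resampled = resample_pos_embed(posemb, new_size=1025)
#   assert resampled.shape == (1, 1025, 1024)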

class HRCLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)


    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        # enable gradient checkpointing for the CLIP encoder
        self.vision_tower.vision_model.encoder.gradient_checkpointing = True

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # monkey-patch the embeddings module so position embeddings are resampled
        # when the input resolution differs from the pretraining resolution
        embeddings = self.vision_tower.vision_model.embeddings
        bound_method = forward_embeddings.__get__(embeddings, embeddings.__class__)
        setattr(embeddings, 'forward', bound_method)

        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size
            }

        self.is_loaded = True

    def forward(self, x):
        # x: pixel values, e.g. [batch, 3, 448, 448] for a 448px input
        blks = self.vision_tower.vision_model.encoder.layers
        x = self.vision_tower.vision_model.embeddings(x)
        # drop the class token and apply CLIP's pre-layernorm
        # (`pre_layrnorm` is the attribute name used by transformers)
        x = self.vision_tower.vision_model.pre_layrnorm(x[:, 1:])

        # inference of the fast branch; checkpoint the encoder layers during training
        for blk in blks:
            if self.training:
                x = checkpoint(
                    blk.__call__,
                    x,
                    None,
                    None
                )[0]
            else:
                x = blk(x, None, None)[0]

        return x

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device


    @property
    def num_attention_heads(self):
        return self.config.num_attention_heads

    @property
    def num_layers(self):
        return self.config.num_hidden_layers

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
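

if __name__ == '__main__':
    # Minimal smoke-test sketch. The checkpoint name and the 448px input size are
    # assumptions; the args fields mirror what __init__ reads (freeze_vision,
    # input_image_size, mm_vision_select_layer, mm_vision_select_feature).
    from types import SimpleNamespace

    args = SimpleNamespace(
        freeze_vision=True,
        input_image_size=448,
        mm_vision_select_layer=-2,
        mm_vision_select_feature='patch',
    )
    tower = HRCLIPVisionTower('openai/clip-vit-large-patch14-336', args)
    tower.eval()

    pixel_values = torch.randn(1, 3, 448, 448, dtype=tower.dtype, device=tower.device)
    with torch.no_grad():
        features = tower(pixel_values)
    # expected: [1, 1024, 1024] -> 32*32 patch tokens, ViT-L hidden size 1024
    print(features.shape)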