File size: 2,315 Bytes
9a43e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
from transformers import CLIPVisionModel


class DFN5B_CLIP_ViT_H_14_378(nn.Module):
    """Frozen DFN5B CLIP ViT-H/14-378 vision tower.

    Wraps a Hugging Face ``CLIPVisionModel`` used as a fixed feature
    extractor: the tower is loaded once, all parameters are frozen, and
    ``forward`` returns hidden states from a selected layer, optionally
    dropping the CLS token.
    """

    def __init__(self, vision_tower):
        """
        Args:
            vision_tower: Hugging Face model identifier or local checkpoint
                path passed to ``CLIPVisionModel.from_pretrained``.
        """
        super().__init__()

        self.is_loaded = False
        self.is_resize_pos = False

        self.vision_tower_name = vision_tower
        # -1 selects the final hidden state; 'patch' drops the CLS token.
        self.select_layer = -1
        self.select_feature = 'patch'
        self.load_model()

    def load_model(self):
        """Instantiate the vision tower from ``vision_tower_name`` and freeze it."""
        if self.is_loaded:
            # Guard against double-loading (forward() re-checks is_loaded).
            return

        # Bug fix: previously loaded from a hard-coded local path
        # ('/root/LWT/Models/DFN5B-CLIP-ViT-H-14-378'), silently ignoring
        # the ``vision_tower`` argument given to ``__init__``.
        self.vision_tower = CLIPVisionModel.from_pretrained(
            self.vision_tower_name)

        # The tower is a frozen feature extractor; no gradients needed.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Select hidden states from ``select_layer`` and slice tokens.

        Args:
            image_forward_outs: model output exposing ``hidden_states``
                (requires ``output_hidden_states=True`` at forward time).

        Returns:
            'patch' mode drops the leading CLS token ([:, 1:]);
            'cls_patch' keeps all tokens.

        Raises:
            ValueError: if ``select_feature`` is not a recognized mode.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(
                f'Unexpected select feature: {self.select_feature}')
        return image_features

    def forward(self, images):
        """Extract features for a batched tensor or a list of single images.

        Args:
            images: a batched image tensor, or a list of unbatched image
                tensors (each is unsqueezed to batch size 1).

        Returns:
            A feature tensor, or a list of feature tensors matching the
            input structure; dtype follows the input dtype.
        """
        if not self.is_loaded:
            self.load_model()

        # NOTE: list path runs one forward pass per image — slower than
        # batched inference; kept for variable-size inputs.
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device,
                             dtype=image.dtype).unsqueeze(0),
                    output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(
                    image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=images.dtype),
                output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(
                images.dtype)
        return image_features

    @property
    def device(self):
        """Device of the underlying vision tower."""
        return self.vision_tower.device

    @property
    def dtype(self):
        """Parameter dtype of the underlying vision tower."""
        return self.vision_tower.dtype