Fabrice-TIERCELIN commited on
Commit
8715afd
·
verified ·
1 Parent(s): 8e58701

Upload clip_encoder.py

Browse files
llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+ from CKPT_PTH import LLAVA_CLIP_PATH
6
+
7
+
8
+ class CLIPVisionTower(nn.Module):
9
+ def __init__(self, vision_tower, args, delay_load=False):
10
+ super().__init__()
11
+
12
+ self.is_loaded = False
13
+
14
+ self.vision_tower_name = vision_tower
15
+ print(f'Loading vision tower: {self.vision_tower_name}')
16
+ self.select_layer = args.mm_vision_select_layer
17
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
18
+
19
+ if not delay_load:
20
+ self.load_model()
21
+ else:
22
+ # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23
+ self.cfg_only = CLIPVisionConfig.from_pretrained(
24
+ self.vision_tower_name if LLAVA_CLIP_PATH is None else LLAVA_CLIP_PATH)
25
+
26
+ def load_model(self):
27
+ self.image_processor = CLIPImageProcessor.from_pretrained(
28
+ self.vision_tower_name if LLAVA_CLIP_PATH is None else LLAVA_CLIP_PATH)
29
+ self.vision_tower = CLIPVisionModel.from_pretrained(
30
+ self.vision_tower_name if LLAVA_CLIP_PATH is None else LLAVA_CLIP_PATH)
31
+ self.vision_tower.requires_grad_(False)
32
+
33
+ self.is_loaded = True
34
+
35
+ def feature_select(self, image_forward_outs):
36
+ image_features = image_forward_outs.hidden_states[self.select_layer]
37
+ if self.select_feature == 'patch':
38
+ image_features = image_features[:, 1:]
39
+ elif self.select_feature == 'cls_patch':
40
+ image_features = image_features
41
+ else:
42
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
43
+ return image_features
44
+
45
+ @torch.no_grad()
46
+ def forward(self, images):
47
+ if type(images) is list:
48
+ image_features = []
49
+ for image in images:
50
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
52
+ image_features.append(image_feature)
53
+ else:
54
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
56
+
57
+ return image_features
58
+
59
+ @property
60
+ def dummy_feature(self):
61
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62
+
63
+ @property
64
+ def dtype(self):
65
+ return self.vision_tower.dtype
66
+
67
+ @property
68
+ def device(self):
69
+ return self.vision_tower.device
70
+
71
+ @property
72
+ def config(self):
73
+ if self.is_loaded:
74
+ return self.vision_tower.config
75
+ else:
76
+ return self.cfg_only
77
+
78
+ @property
79
+ def hidden_size(self):
80
+ return self.config.hidden_size
81
+
82
+ @property
83
+ def num_patches(self):
84
+ return (self.config.image_size // self.config.patch_size) ** 2