razzant committed on
Commit
3e2b3c7
1 Parent(s): 070912c

Update README.md

Files changed (1):
  1. README.md +99 -7
README.md CHANGED
@@ -53,16 +53,103 @@ Model Performance on Visual Dialog Benchmark
 ```python
 import torch
 from PIL import Image
-from clip_encoder import CLIPVisionTower
 from transformers import AutoTokenizer, AutoModelForCausalLM
-
+from urllib.request import urlopen
+import torch.nn as nn
+from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
 
 DEVICE = "cuda:0"
 PROMPT = "This is a dialog with AI assistant.\n"
 tokenizer = AutoTokenizer.from_pretrained("OmniMistral-tokenizer", use_fast=False)
 model = AutoModelForCausalLM.from_pretrained("OmniMistral-model", torch_dtype=torch.bfloat16, device_map=DEVICE)
 
-clip = CLIPVisionTower("openai/clip-vit-large-patch14-336")
+projection = torch.load("projection", map_location=DEVICE)
+special_embs = torch.load("special_embeddings.pt", map_location=DEVICE)
+
+
+
+
+
+class CLIPVisionTower(nn.Module):
+    def __init__(self, vision_tower, args, delay_load=False):
+        super().__init__()
+
+        self.is_loaded = False
+
+        self.vision_tower_name = vision_tower
+        self.select_layer = args.mm_vision_select_layer
+        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+
+        if not delay_load:
+            self.load_model()
+        else:
+            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+    def load_model(self):
+        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
+        self.vision_tower.requires_grad_(False)
+
+        self.is_loaded = True
+
+    def feature_select(self, image_forward_outs):
+        image_features = image_forward_outs.hidden_states[self.select_layer]
+        if self.select_feature == 'patch':
+            image_features = image_features[:, 1:]
+        elif self.select_feature == 'cls_patch':
+            image_features = image_features
+        else:
+            raise ValueError(f'Unexpected select feature: {self.select_feature}')
+        return image_features
+
+    @torch.no_grad()
+    def forward(self, images):
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                image_feature = self.feature_select(image_forward_out).to(image.dtype)
+                image_features.append(image_feature)
+        else:
+            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+            image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        return self.vision_tower.dtype
+
+    @property
+    def device(self):
+        return self.vision_tower.device
+
+    @property
+    def config(self):
+        if self.is_loaded:
+            return self.vision_tower.config
+        else:
+            return self.cfg_only
+
+    @property
+    def hidden_size(self):
+        return self.config.hidden_size
+
+    @property
+    def num_patches(self):
+        return (self.config.image_size // self.config.patch_size) ** 2
+
+
+class ClipTowerCfg:
+    def __init__(self):
+        self.mm_vision_select_feature = 'patch'
+        self.mm_vision_select_layer = -2
+
+clip = CLIPVisionTower("openai/clip-vit-large-patch14-336", ClipTowerCfg())
 clip.load_model()
 clip = clip.to(device=DEVICE, dtype=torch.bfloat16)
 
@@ -84,7 +171,7 @@ def gen_answer(model, tokenizer, clip, projection, query, special_embs, image=None):
         "num_return_sequences": 1,
     }
     with torch.no_grad():
-        image_features = clip.module.image_processor(image, return_tensors='pt')
+        image_features = clip.image_processor(image, return_tensors='pt')
         image_embedding = clip(image_features['pixel_values']).to(device=DEVICE, dtype=torch.bfloat16)
 
         projected_vision_embeddings = projection(image_embedding).to(device=DEVICE, dtype=torch.bfloat16)
@@ -111,17 +198,22 @@ def gen_answer(model, tokenizer, clip, projection, query, special_embs, image=None):
     generated_texts = tokenizer.batch_decode(out)[0]
     return generated_texts
 
+img_url = "https://i.pinimg.com/originals/32/c7/81/32c78115cb47fd4825e6907a83b7afff.jpg"
+question = "who is the author?"
+img = Image.open(urlopen(img_url))
 
 answer = gen_answer(
     model,
     tokenizer,
     clip,
     projection,
-    query = "who is the author?",
-    special_embs,
-    Image.open("https://i.pinimg.com/originals/32/c7/81/32c78115cb47fd4825e6907a83b7afff.jpg")
+    query=question,
+    special_embs=special_embs,
+    image=img
 )
 
+img.show()
+print(question)
 print(answer)
 ```
 
219