dinhanhx committed on
Commit
44c2fe5
1 Parent(s): 58f13d7

Add model files

Files changed (3)
  1. app.py +91 -0
  2. standalone_velvet.py +305 -0
  3. visual_bloom.torch +3 -0
app.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+
+ from standalone_velvet import setup_models
+
+ models_dict = setup_models("visual_bloom.torch")
+ visual_bloom = models_dict["visual_bloom"]
+ tokenizer = models_dict["tokenizer"]
+ image_feature_collator = models_dict["image_feature_collator"]
+
+
+ def run_inference(text_input, image_input):
+     image_features, image_attentions = image_feature_collator([image_input])
+     instruction_inputs = tokenizer([text_input], return_tensors="pt")
+     language_output = visual_bloom.generate(
+         image_features,
+         image_attentions,
+         instruction_inputs["input_ids"],
+         instruction_inputs["attention_mask"],
+     )
+
+     human_output = tokenizer.decode(language_output[0], skip_special_tokens=True)
+     return human_output.split(".")[0]
+
+
+ if __name__ == "__main__":
+     markdown = """
+ # Quick introduction
+
+ We have proposed a prompt-based vision-language model.
+ The model can caption images and answer questions about images.
+ It is trained on CC3M, COCO, VQAv2, OK-VQA, TextCaps, and TextVQA.
+ Because Google Translate was used for translation,
+ these datasets collectively contain millions of image-text pairs in both English and Vietnamese.
+
+ For further details, please refer to [Velvet](https://github.com/dinhanhx/velvet?tab=readme-ov-file#introduction).
+
+ # Usage
+
+ ## Run with pre-defined examples
+
+ 1. Scroll to the bottom of the page to see the examples.
+ 2. Click one of them.
+ 3. Click the `Run Inference` button.
+
+ ## Run with user-defined inputs
+
+ ### 1. Prepare text input
+
+ Image captioning:
+ - `Generate caption in en:`
+ - `Generate caption in vi:`
+
+ Visual question answering:
+ - `Generate answer in en: <question>?`
+ - `Generate answer in vi: <question>?`
+
+ Don't forget to replace `<question>` with your own question, in either English or Vietnamese.
+
+ When writing a prompt, you can refer to the examples at the bottom of the page.
+
+ ### 2. Prepare image input
+
+ Follow the instructions in the Image Input box. A wide range of image types is supported by PIL.
+
+ ### 3. Click the `Run Inference` button
+ """
+     examples = [
+         ["Generate caption in en:", "examples/cat.png"],
+         ["Generate caption in vi:", "examples/cat.png"],
+         ["Generate answer in en: what is the color of the cat?", "examples/cat.png"],
+         ["Generate answer in vi: màu sắc của con mèo là gì?", "examples/cat.png"],
+     ]
+
+     with gr.Blocks() as demo:
+         gr.Markdown(markdown)
+
+         text_input = gr.Textbox(label="Text Input")
+         image_input = gr.Image(label="Image Input", type="pil")
+
+         text_output = gr.Textbox(label="Text Output")
+
+         infer_button = gr.Button("Run Inference")
+         infer_button.click(
+             run_inference, inputs=[text_input, image_input], outputs=text_output
+         )
+
+         examples = gr.Examples(
+             examples=examples,
+             inputs=[text_input, image_input],
+         )
+     demo.launch()
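For reference, the same inference path can be exercised without launching the Gradio UI. The sketch below is a minimal, untested example assuming `visual_bloom.torch` and the bundled `examples/cat.png` are present in the working directory; importing `app` runs `setup_models("visual_bloom.torch")` at import time, which also downloads the ConvNeXt V2 and BLOOMZ backbones.

```python
# Minimal sketch: call run_inference directly, without the Gradio front end.
# Assumes visual_bloom.torch and examples/cat.png exist in the working directory.
from PIL import Image

from app import run_inference

image = Image.open("examples/cat.png").convert("RGB")
print(run_inference("Generate caption in en:", image))
print(run_inference("Generate answer in en: what is the color of the cat?", image))
```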
standalone_velvet.py ADDED
@@ -0,0 +1,305 @@
+ import warnings
+ from dataclasses import dataclass
+ from typing import List
+
+ import torch
+ from einops import rearrange
+ from PIL import Image
+ from torch import nn
+ from transformers.models.bert import BertConfig, BertModel
+ from transformers.models.bloom import BloomConfig, BloomForCausalLM, BloomTokenizerFast
+ from transformers.models.convnext import ConvNextImageProcessor
+ from transformers.models.convnextv2 import ConvNextV2Config
+ from transformers.models.convnextv2.modeling_convnextv2 import ConvNextV2Model
+
+
+ # Copied from
+ # https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/collator.py#L13-L32
+ @dataclass
+ class ImageFeatureCollator:
+     image_processor: ConvNextImageProcessor
+     image_model: ConvNextV2Model
+
+     def __call__(self, batch_image: List[Image.Image]):
+         return self.tensorize_batch_image(batch_image=batch_image)
+
+     def tensorize_batch_image(self, batch_image: List[Image.Image]):
+         image_inputs = self.image_processor(batch_image, return_tensors="pt")
+
+         with torch.no_grad():
+             image_outputs = self.image_model(**image_inputs)
+             image_features = image_outputs["last_hidden_state"]
+
+         image_features = rearrange(image_features, "b c h w -> b h w c")
+         image_features = rearrange(image_features, "b h w c -> b (h w) c")
+
+         image_attentions = torch.ones(image_features.size()[:-1], dtype=torch.long)
+         return image_features, image_attentions
+
+
+ # Copied from
+ # https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/model/cutie.py#L6C1-L78C28
+ class IdentityForBertEmbeddings(nn.Module):
+     """Skip all BertEmbeddings, because text embeddings provided by another model are used instead"""
+
+     def __init__(self, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+
+     def forward(self, **bert_embeddings_args):
+         inputs_embeds = bert_embeddings_args.get("inputs_embeds", None)
+         return inputs_embeds
+
+
+ class Cutie(nn.Module):
+     """Cutie - Qt - Query Transformer - Q-Former
+
+     Cutie is motivated by the underlying theoretical foundations of Q-Former presented in BLIP-2 https://arxiv.org/abs/2301.12597
+     It should be noted that Cutie differs from the specific approach described in that paper.
+     Both Cutie and Q-Former have query tokens.
+     Cutie uses the same unmodified BERT,
+     whereas Q-Former modifies BERT to behave differently on some tasks.
+     """
+
+     def __init__(
+         self,
+         bert_config: BertConfig,
+         max_query_length: int = 32,
+         language_model_ignore_label: int = -100,
+     ) -> None:
+         assert bert_config.is_decoder, "BERT must be a decoder"
+         assert bert_config.add_cross_attention, "BERT must have cross attention layer"
+         super().__init__()
+         self.bert_model = BertModel(bert_config, add_pooling_layer=False)
+         self.bert_model.embeddings = IdentityForBertEmbeddings()
+
+         self.query_tokens = nn.Parameter(
+             torch.zeros(1, max_query_length, bert_config.hidden_size)
+         )
+         self.query_tokens.data.normal_(mean=0.0, std=bert_config.initializer_range)
+         self.query_attentions = torch.ones(
+             self.query_tokens.size()[:-1], dtype=torch.long
+         )
+         self.query_labels = torch.full(
+             self.query_tokens.size()[:-1], language_model_ignore_label, dtype=torch.long
+         )
+
+     def forward(
+         self,
+         image_features: torch.Tensor,
+         image_attentions: torch.Tensor,
+         instruction_embeds: torch.Tensor,
+         instruction_attention_mask: torch.Tensor,
+     ):
+         batch_size = image_features.size(0)
+
+         query_tokens = self.query_tokens.expand(batch_size, -1, -1).to(
+             self.query_tokens.device
+         )
+         query_attentions = self.query_attentions.expand(batch_size, -1).to(
+             self.query_tokens.device
+         )
+
+         cat_embeds = torch.cat([query_tokens, instruction_embeds], dim=1)
+         cat_attentions = torch.cat(
+             [query_attentions, instruction_attention_mask], dim=1
+         )
+
+         bert_outputs = self.bert_model(
+             inputs_embeds=cat_embeds,
+             attention_mask=cat_attentions,
+             encoder_hidden_states=image_features,
+             encoder_attention_mask=image_attentions,
+         )
+         cutie_output = bert_outputs.last_hidden_state[:, : query_tokens.size(1), :]
+         return cutie_output
+
+
+ # Copied from
+ # https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/model/visual_bloom.py#L12C1-L162C31
+ class VisualBloom(nn.Module):
+     """A BLOOM-based model that can take image inputs"""
+
+     def __init__(
+         self,
+         convnextv2_config: ConvNextV2Config,
+         bert_config: BertConfig,
+         bloom_config: BloomConfig,
+         bloom_name: str,
+         use_frozen_bloom: bool = True,
+     ) -> None:
+         super().__init__()
+
+         if (
+             convnextv2_config.hidden_sizes[-1]
+             == bert_config.hidden_size
+             == bloom_config.hidden_size
+         ):
+             self.use_projection = False
+             warnings.warn(
+                 "All embedding dimensions are equal. No linear projection layers are created."
+             )
+         else:
+             self.use_projection = True
+             self.text_to_cutie = nn.Linear(
+                 bloom_config.hidden_size, bert_config.hidden_size
+             )
+             self.image_to_cutie = nn.Linear(
+                 convnextv2_config.hidden_sizes[-1], bert_config.hidden_size
+             )
+             self.cutie_to_text = nn.Linear(
+                 bert_config.hidden_size, bloom_config.hidden_size
+             )
+
+         self.cutie_model = Cutie(bert_config)
+
+         # Load and freeze BLOOM model
+         if use_frozen_bloom:
+             self.bloom_model = BloomForCausalLM.from_pretrained(bloom_name)
+             for param in self.bloom_model.parameters():
+                 param.requires_grad = False
+         else:
+             self.bloom_model = BloomForCausalLM(bloom_config)
+
+     def forward(
+         self,
+         # Image model outputs - Q-former inputs
+         image_features: torch.Tensor,
+         image_attentions: torch.Tensor,
+         # Q-former inputs
+         instruction_input_ids: torch.Tensor,
+         instruction_attention_mask: torch.Tensor,
+         # Frozen language model inputs
+         language_model_input_ids: torch.Tensor,
+         language_model_attention_mask: torch.Tensor,
+         language_model_labels: torch.Tensor,
+     ):
+         instruction_embeds = self.bloom_model.transformer.word_embeddings(
+             instruction_input_ids
+         )
+         instruction_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
+             instruction_embeds
+         )
+
+         if self.use_projection:
+             image_features = self.image_to_cutie(image_features)
+             instruction_embeds = self.text_to_cutie(instruction_embeds)
+
+         cutie_output = self.cutie_model(
+             image_features=image_features,
+             image_attentions=image_attentions,
+             instruction_embeds=instruction_embeds,
+             instruction_attention_mask=instruction_attention_mask,
+         )
+
+         if self.use_projection:
+             cutie_output = self.cutie_to_text(cutie_output)
+
+         cutie_attentions = self.cutie_model.query_attentions.expand(
+             cutie_output.size(0), -1
+         ).to(cutie_output.device)
+         cutie_labels = self.cutie_model.query_labels.expand(
+             cutie_output.size(0), -1
+         ).to(cutie_output.device)
+
+         language_model_embeds = self.bloom_model.transformer.word_embeddings(
+             language_model_input_ids
+         )
+         language_model_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
+             language_model_embeds
+         )
+
+         cat_embeds = torch.cat([cutie_output, language_model_embeds], dim=1)
+         cat_attentions = torch.cat(
+             [cutie_attentions, language_model_attention_mask], dim=1
+         )
+         cat_labels = torch.cat([cutie_labels, language_model_labels], dim=1)
+
+         bloom_outputs = self.bloom_model(
+             inputs_embeds=cat_embeds, attention_mask=cat_attentions, labels=cat_labels
+         )
+         return bloom_outputs
+
+     @torch.no_grad()
+     def generate(
+         self,
+         # Image model outputs - Q-former inputs
+         image_features: torch.Tensor,
+         image_attentions: torch.Tensor,
+         # Q-former inputs
+         instruction_input_ids: torch.Tensor,
+         instruction_attention_mask: torch.Tensor,
+     ):
+         instruction_embeds = self.bloom_model.transformer.word_embeddings(
+             instruction_input_ids
+         )
+         instruction_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
+             instruction_embeds
+         )
+
+         if self.use_projection:
+             image_features = self.image_to_cutie(image_features)
+             cutie_instruction_embeds = self.text_to_cutie(instruction_embeds)
+
+         cutie_output = self.cutie_model(
+             image_features=image_features,
+             image_attentions=image_attentions,
+             instruction_embeds=cutie_instruction_embeds,
+             instruction_attention_mask=instruction_attention_mask,
+         )
+
+         if self.use_projection:
+             cutie_output = self.cutie_to_text(cutie_output)
+
+         cutie_attentions = self.cutie_model.query_attentions.expand(
+             cutie_output.size(0), -1
+         ).to(cutie_output.device)
+
+         cat_embeds = torch.cat([cutie_output, instruction_embeds], dim=1)
+         cat_attentions = torch.cat(
+             [cutie_attentions, instruction_attention_mask], dim=1
+         )
+
+         language_output = self.bloom_model.generate(
+             inputs_embeds=cat_embeds,
+             attention_mask=cat_attentions,
+             max_length=96,
+             penalty_alpha=0.6,
+             top_k=4,
+         )
+         return language_output
+
+
+ def setup_models(visual_bloom_state_dict_path: str):
+     image_model_name = "facebook/convnextv2-large-22k-224"
+     image_config = ConvNextV2Config.from_pretrained(image_model_name)
+     image_processor = ConvNextImageProcessor.from_pretrained(image_model_name)
+     image_model = ConvNextV2Model.from_pretrained(image_model_name)
+     image_feature_collator = ImageFeatureCollator(image_processor, image_model)
+
+     bloom_model_name = "bigscience/bloomz-1b7"
+     bloom_config = BloomConfig.from_pretrained(bloom_model_name)
+     tokenizer = BloomTokenizerFast.from_pretrained(bloom_model_name)
+     tokenizer.padding_side = "right"
+
+     bert_config = BertConfig(
+         hidden_size=1024,
+         num_hidden_layers=6,
+         num_attention_heads=16,
+         is_decoder=True,
+         add_cross_attention=True,
+     )
+
+     visual_bloom = VisualBloom(
+         image_config,
+         bert_config,
+         bloom_config,
+         bloom_model_name,
+         use_frozen_bloom=False,
+     )
+     visual_bloom.load_state_dict(torch.load(visual_bloom_state_dict_path))
+     visual_bloom = visual_bloom.eval()
+     return {
+         "visual_bloom": visual_bloom,
+         "tokenizer": tokenizer,
+         "image_feature_collator": image_feature_collator,
+     }
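To make the Cutie block easier to follow, here is a rough, self-contained shape walk-through that runs Cutie in isolation with small stand-in dimensions (in the deployed configuration the BERT decoder from `setup_models` is 1024-d, and the image and text features are linearly projected into it); every size below is illustrative, not the real one.

```python
# Rough shape walk-through of Cutie in isolation; all dimensions are stand-ins.
import torch
from transformers.models.bert import BertConfig

from standalone_velvet import Cutie

bert_config = BertConfig(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    is_decoder=True,
    add_cross_attention=True,
)
cutie = Cutie(bert_config, max_query_length=8)

image_features = torch.randn(1, 49, 64)  # (batch, image patches, hidden)
image_attentions = torch.ones(1, 49, dtype=torch.long)
instruction_embeds = torch.randn(1, 5, 64)  # (batch, prompt tokens, hidden)
instruction_attention_mask = torch.ones(1, 5, dtype=torch.long)

# Query tokens and prompt embeddings are concatenated and cross-attend to the
# image features; only the query positions are returned.
cutie_output = cutie(
    image_features, image_attentions, instruction_embeds, instruction_attention_mask
)
print(cutie_output.shape)  # torch.Size([1, 8, 64]): one vector per query token
```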
visual_bloom.torch ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:18440703d035a942db21b82fe9aaf0d15895e46e97cfb7ae30217fa9c04daf0d
+ size 7265806579