GunaKoppula committed on
Commit
a8d9c50
1 Parent(s): efe75b3

Upload 19 files

Experiments/clip_expt.ipynb ADDED
@@ -0,0 +1,840 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "9fe51ce7-4c87-4186-9fd3-0fb18ac43e56",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from PIL import Image\n",
11
+ "import requests\n",
12
+ "from transformers import AutoProcessor, CLIPVisionModel"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 3,
18
+ "id": "0f4c21dd-4258-461d-8511-5be089d068a8",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\", device_map=\"cuda:0\")\n",
23
+ "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\", device_map=\"cuda:0\")"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 4,
29
+ "id": "98b9f906-ffaa-4be4-8671-4ecf65f12c49",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "# url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
34
+ "# image = Image.open(requests.get(url, stream=True).raw)\n",
35
+ "image = Image.open(\"002579.jpg\")"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 17,
41
+ "id": "54b2e4ce-b77b-4314-87f6-ca2a1970fc79",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# image"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 18,
51
+ "id": "cdd65c58-007f-450b-8deb-f8b4f372a823",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "# image = None"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 5,
61
+ "id": "e9066c2e-c78b-49d1-979b-10d0f4f09441",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "inputs = processor(images=image, return_tensors=\"pt\", device_map=\"cuda:0\")"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 20,
71
+ "id": "e98b211d-29d9-4662-be0b-e011e89b0101",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "# inputs"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 6,
81
+ "id": "b030bd3d-4282-4074-98fe-97e658bd0f50",
82
+ "metadata": {},
83
+ "outputs": [
84
+ {
85
+ "data": {
86
+ "text/plain": [
87
+ "torch.Size([1, 3, 224, 224])"
88
+ ]
89
+ },
90
+ "execution_count": 6,
91
+ "metadata": {},
92
+ "output_type": "execute_result"
93
+ }
94
+ ],
95
+ "source": [
96
+ "inputs[\"pixel_values\"].shape"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 22,
102
+ "id": "0ce68f11-1c88-4dd7-8b17-0d1de5811fe6",
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "outputs = model(inputs[\"pixel_values\"].to(\"cuda:0\"))\n",
107
+ "last_hidden_state = outputs.last_hidden_state\n",
108
+ "pooled_output = outputs.pooler_output # pooled CLS states"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 23,
114
+ "id": "30cb0918-a30e-4246-b540-6b8e0d876807",
115
+ "metadata": {},
116
+ "outputs": [
117
+ {
118
+ "data": {
119
+ "text/plain": [
120
+ "torch.Size([1, 768])"
121
+ ]
122
+ },
123
+ "execution_count": 23,
124
+ "metadata": {},
125
+ "output_type": "execute_result"
126
+ }
127
+ ],
128
+ "source": [
129
+ "pooled_output.shape"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": 24,
135
+ "id": "6399543a-f23f-426d-8289-3bb52d293ece",
136
+ "metadata": {},
137
+ "outputs": [
138
+ {
139
+ "data": {
140
+ "text/plain": [
141
+ "torch.Size([1, 50, 768])"
142
+ ]
143
+ },
144
+ "execution_count": 24,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "last_hidden_state.shape"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 25,
156
+ "id": "19a70443-5942-4937-b3ea-6a52d76e2b08",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "data": {
161
+ "text/plain": [
162
+ "torch.Size([1, 768])"
163
+ ]
164
+ },
165
+ "execution_count": 25,
166
+ "metadata": {},
167
+ "output_type": "execute_result"
168
+ }
169
+ ],
170
+ "source": [
171
+ "outputs[1].shape"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 8,
177
+ "id": "fa13903f-a94a-4839-ae5a-8df4f55c68b6",
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "import torch\n",
182
+ "from torch import nn\n",
183
+ "from transformers import CLIPVisionConfig,CLIPPreTrainedModel"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": 9,
189
+ "id": "b2bd9198-42f0-40c3-80e1-d167c0b038fb",
190
+ "metadata": {},
191
+ "outputs": [
192
+ {
193
+ "ename": "NameError",
194
+ "evalue": "name 'Optional' is not defined",
195
+ "output_type": "error",
196
+ "traceback": [
197
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
198
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
199
+ "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mCLIPVisionModelWithProjection\u001b[39;00m(CLIPPreTrainedModel):\n\u001b[1;32m 2\u001b[0m config_class \u001b[38;5;241m=\u001b[39m CLIPVisionConfig\n\u001b[1;32m 3\u001b[0m main_input_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpixel_values\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
200
+ "Cell \u001b[0;32mIn[9], line 20\u001b[0m, in \u001b[0;36mCLIPVisionModelWithProjection\u001b[0;34m()\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_input_embeddings\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m nn\u001b[38;5;241m.\u001b[39mModule:\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvision_model\u001b[38;5;241m.\u001b[39membeddings\u001b[38;5;241m.\u001b[39mpatch_embedding\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m---> 20\u001b[0m pixel_values: \u001b[43mOptional\u001b[49m[torch\u001b[38;5;241m.\u001b[39mFloatTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 21\u001b[0m output_attentions: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 22\u001b[0m output_hidden_states: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 23\u001b[0m return_dict: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 24\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tuple, CLIPVisionModelOutput]:\n\u001b[1;32m 25\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 27\u001b[0m vision_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvision_model(\n\u001b[1;32m 28\u001b[0m pixel_values\u001b[38;5;241m=\u001b[39mpixel_values,\n\u001b[1;32m 29\u001b[0m output_attentions\u001b[38;5;241m=\u001b[39moutput_attentions,\n\u001b[1;32m 30\u001b[0m output_hidden_states\u001b[38;5;241m=\u001b[39moutput_hidden_states,\n\u001b[1;32m 31\u001b[0m return_dict\u001b[38;5;241m=\u001b[39mreturn_dict,\n\u001b[1;32m 32\u001b[0m )\n",
201
+ "\u001b[0;31mNameError\u001b[0m: name 'Optional' is not defined"
202
+ ]
203
+ }
204
+ ],
205
+ "source": [
206
+ "class CLIPVisionModelWithProjection(CLIPPreTrainedModel):\n",
207
+ " config_class = CLIPVisionConfig\n",
208
+ " main_input_name = \"pixel_values\"\n",
209
+ "\n",
210
+ " def __init__(self, config: CLIPVisionConfig):\n",
211
+ " super().__init__(config)\n",
212
+ "\n",
213
+ " self.vision_model = CLIPVisionTransformer(config)\n",
214
+ "\n",
215
+ " self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)\n",
216
+ "\n",
217
+ " # Initialize weights and apply final processing\n",
218
+ " self.post_init()\n",
219
+ "\n",
220
+ " def get_input_embeddings(self) -> nn.Module:\n",
221
+ " return self.vision_model.embeddings.patch_embedding\n",
222
+ "\n",
223
+ " def forward(\n",
224
+ " self,\n",
225
+ " pixel_values: Optional[torch.FloatTensor] = None,\n",
226
+ " output_attentions: Optional[bool] = None,\n",
227
+ " output_hidden_states: Optional[bool] = None,\n",
228
+ " return_dict: Optional[bool] = None,\n",
229
+ " ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
230
+ " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
231
+ "\n",
232
+ " vision_outputs = self.vision_model(\n",
233
+ " pixel_values=pixel_values,\n",
234
+ " output_attentions=output_attentions,\n",
235
+ " output_hidden_states=output_hidden_states,\n",
236
+ " return_dict=return_dict,\n",
237
+ " )\n",
238
+ "\n",
239
+ " pooled_output = vision_outputs[1] # pooled_output\n",
240
+ "\n",
241
+ " image_embeds = self.visual_projection(pooled_output)\n",
242
+ "\n",
243
+ " if not return_dict:\n",
244
+ " outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]\n",
245
+ " return tuple(output for output in outputs if output is not None)\n",
246
+ "\n",
247
+ " return CLIPVisionModelOutput(\n",
248
+ " image_embeds=image_embeds,\n",
249
+ " last_hidden_state=vision_outputs.last_hidden_state,\n",
250
+ " hidden_states=vision_outputs.hidden_states,\n",
251
+ " attentions=vision_outputs.attentions,\n",
252
+ " )"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": 27,
258
+ "id": "68a9ee4a-d977-4725-842d-e64e0dd2f61d",
259
+ "metadata": {
260
+ "collapsed": true,
261
+ "jupyter": {
262
+ "outputs_hidden": true
263
+ }
264
+ },
265
+ "outputs": [
266
+ {
267
+ "name": "stderr",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
271
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
272
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
273
+ "Model config CLIPConfig {\n",
274
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
275
+ " \"architectures\": [\n",
276
+ " \"CLIPModel\"\n",
277
+ " ],\n",
278
+ " \"initializer_factor\": 1.0,\n",
279
+ " \"logit_scale_init_value\": 2.6592,\n",
280
+ " \"model_type\": \"clip\",\n",
281
+ " \"projection_dim\": 512,\n",
282
+ " \"text_config\": {\n",
283
+ " \"bos_token_id\": 0,\n",
284
+ " \"dropout\": 0.0,\n",
285
+ " \"eos_token_id\": 2,\n",
286
+ " \"model_type\": \"clip_text_model\"\n",
287
+ " },\n",
288
+ " \"transformers_version\": \"4.36.2\",\n",
289
+ " \"vision_config\": {\n",
290
+ " \"dropout\": 0.0,\n",
291
+ " \"model_type\": \"clip_vision_model\"\n",
292
+ " }\n",
293
+ "}\n",
294
+ "\n",
295
+ "loading weights file pytorch_model.bin from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/pytorch_model.bin\n",
296
+ "All model checkpoint weights were used when initializing CLIPModel.\n",
297
+ "\n",
298
+ "All the weights of CLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.\n",
299
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use CLIPModel for predictions without further training.\n",
300
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
301
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
302
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
303
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
304
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
305
+ "Model config CLIPConfig {\n",
306
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
307
+ " \"architectures\": [\n",
308
+ " \"CLIPModel\"\n",
309
+ " ],\n",
310
+ " \"initializer_factor\": 1.0,\n",
311
+ " \"logit_scale_init_value\": 2.6592,\n",
312
+ " \"model_type\": \"clip\",\n",
313
+ " \"projection_dim\": 512,\n",
314
+ " \"text_config\": {\n",
315
+ " \"bos_token_id\": 0,\n",
316
+ " \"dropout\": 0.0,\n",
317
+ " \"eos_token_id\": 2,\n",
318
+ " \"model_type\": \"clip_text_model\"\n",
319
+ " },\n",
320
+ " \"transformers_version\": \"4.36.2\",\n",
321
+ " \"vision_config\": {\n",
322
+ " \"dropout\": 0.0,\n",
323
+ " \"model_type\": \"clip_vision_model\"\n",
324
+ " }\n",
325
+ "}\n",
326
+ "\n",
327
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
328
+ "size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'shortest_edge': 224}.\n",
329
+ "crop_size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'height': 224, 'width': 224}.\n",
330
+ "Image processor CLIPImageProcessor {\n",
331
+ " \"crop_size\": {\n",
332
+ " \"height\": 224,\n",
333
+ " \"width\": 224\n",
334
+ " },\n",
335
+ " \"do_center_crop\": true,\n",
336
+ " \"do_convert_rgb\": true,\n",
337
+ " \"do_normalize\": true,\n",
338
+ " \"do_rescale\": true,\n",
339
+ " \"do_resize\": true,\n",
340
+ " \"feature_extractor_type\": \"CLIPFeatureExtractor\",\n",
341
+ " \"image_mean\": [\n",
342
+ " 0.48145466,\n",
343
+ " 0.4578275,\n",
344
+ " 0.40821073\n",
345
+ " ],\n",
346
+ " \"image_processor_type\": \"CLIPImageProcessor\",\n",
347
+ " \"image_std\": [\n",
348
+ " 0.26862954,\n",
349
+ " 0.26130258,\n",
350
+ " 0.27577711\n",
351
+ " ],\n",
352
+ " \"resample\": 3,\n",
353
+ " \"rescale_factor\": 0.00392156862745098,\n",
354
+ " \"size\": {\n",
355
+ " \"shortest_edge\": 224\n",
356
+ " }\n",
357
+ "}\n",
358
+ "\n",
359
+ "loading file vocab.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/vocab.json\n",
360
+ "loading file merges.txt from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/merges.txt\n",
361
+ "loading file tokenizer.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer.json\n",
362
+ "loading file added_tokens.json from cache at None\n",
363
+ "loading file special_tokens_map.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/special_tokens_map.json\n",
364
+ "loading file tokenizer_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer_config.json\n",
365
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
366
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
367
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
368
+ "Model config CLIPConfig {\n",
369
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
370
+ " \"architectures\": [\n",
371
+ " \"CLIPModel\"\n",
372
+ " ],\n",
373
+ " \"initializer_factor\": 1.0,\n",
374
+ " \"logit_scale_init_value\": 2.6592,\n",
375
+ " \"model_type\": \"clip\",\n",
376
+ " \"projection_dim\": 512,\n",
377
+ " \"text_config\": {\n",
378
+ " \"bos_token_id\": 0,\n",
379
+ " \"dropout\": 0.0,\n",
380
+ " \"eos_token_id\": 2,\n",
381
+ " \"model_type\": \"clip_text_model\"\n",
382
+ " },\n",
383
+ " \"transformers_version\": \"4.36.2\",\n",
384
+ " \"vision_config\": {\n",
385
+ " \"dropout\": 0.0,\n",
386
+ " \"model_type\": \"clip_vision_model\"\n",
387
+ " }\n",
388
+ "}\n",
389
+ "\n"
390
+ ]
391
+ }
392
+ ],
393
+ "source": [
394
+ "from PIL import Image\n",
395
+ "import requests\n",
396
+ "from transformers import AutoProcessor, CLIPModel\n",
397
+ "\n",
398
+ "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
399
+ "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
400
+ "\n",
401
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
402
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
403
+ "\n",
404
+ "inputs = processor(images=image, return_tensors=\"pt\")\n",
405
+ "\n",
406
+ "image_features = model.get_image_features(**inputs)"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": 29,
412
+ "id": "9ff63766-b706-452b-b735-bf9000fb9c20",
413
+ "metadata": {},
414
+ "outputs": [
415
+ {
416
+ "data": {
417
+ "text/plain": [
418
+ "torch.Size([1, 512])"
419
+ ]
420
+ },
421
+ "execution_count": 29,
422
+ "metadata": {},
423
+ "output_type": "execute_result"
424
+ }
425
+ ],
426
+ "source": [
427
+ "image_features.shape"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": 30,
433
+ "id": "82566e7b-3c91-421a-94c5-f1e2b3e91c8c",
434
+ "metadata": {
435
+ "collapsed": true,
436
+ "jupyter": {
437
+ "outputs_hidden": true
438
+ }
439
+ },
440
+ "outputs": [
441
+ {
442
+ "name": "stderr",
443
+ "output_type": "stream",
444
+ "text": [
445
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
446
+ "Model config CLIPVisionConfig {\n",
447
+ " \"attention_dropout\": 0.0,\n",
448
+ " \"dropout\": 0.0,\n",
449
+ " \"hidden_act\": \"quick_gelu\",\n",
450
+ " \"hidden_size\": 768,\n",
451
+ " \"image_size\": 224,\n",
452
+ " \"initializer_factor\": 1.0,\n",
453
+ " \"initializer_range\": 0.02,\n",
454
+ " \"intermediate_size\": 3072,\n",
455
+ " \"layer_norm_eps\": 1e-05,\n",
456
+ " \"model_type\": \"clip_vision_model\",\n",
457
+ " \"num_attention_heads\": 12,\n",
458
+ " \"num_channels\": 3,\n",
459
+ " \"num_hidden_layers\": 12,\n",
460
+ " \"patch_size\": 32,\n",
461
+ " \"projection_dim\": 512,\n",
462
+ " \"transformers_version\": \"4.36.2\"\n",
463
+ "}\n",
464
+ "\n",
465
+ "loading weights file pytorch_model.bin from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/pytorch_model.bin\n",
466
+ "Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'logit_scale', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'visual_projection.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 
'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.final_layer_norm.weight', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.embeddings.position_ids', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.10.layer_norm1.weight', 
'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.final_layer_norm.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_projection.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight']\n",
467
+ "- This IS expected if you are initializing CLIPVisionModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
468
+ "- This IS NOT expected if you are initializing CLIPVisionModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
469
+ "All the weights of CLIPVisionModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.\n",
470
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use CLIPVisionModel for predictions without further training.\n",
471
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
472
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
473
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
474
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
475
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
476
+ "Model config CLIPConfig {\n",
477
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
478
+ " \"architectures\": [\n",
479
+ " \"CLIPModel\"\n",
480
+ " ],\n",
481
+ " \"initializer_factor\": 1.0,\n",
482
+ " \"logit_scale_init_value\": 2.6592,\n",
483
+ " \"model_type\": \"clip\",\n",
484
+ " \"projection_dim\": 512,\n",
485
+ " \"text_config\": {\n",
486
+ " \"bos_token_id\": 0,\n",
487
+ " \"dropout\": 0.0,\n",
488
+ " \"eos_token_id\": 2,\n",
489
+ " \"model_type\": \"clip_text_model\"\n",
490
+ " },\n",
491
+ " \"transformers_version\": \"4.36.2\",\n",
492
+ " \"vision_config\": {\n",
493
+ " \"dropout\": 0.0,\n",
494
+ " \"model_type\": \"clip_vision_model\"\n",
495
+ " }\n",
496
+ "}\n",
497
+ "\n",
498
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
499
+ "size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'shortest_edge': 224}.\n",
500
+ "crop_size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'height': 224, 'width': 224}.\n",
501
+ "Image processor CLIPImageProcessor {\n",
502
+ " \"crop_size\": {\n",
503
+ " \"height\": 224,\n",
504
+ " \"width\": 224\n",
505
+ " },\n",
506
+ " \"do_center_crop\": true,\n",
507
+ " \"do_convert_rgb\": true,\n",
508
+ " \"do_normalize\": true,\n",
509
+ " \"do_rescale\": true,\n",
510
+ " \"do_resize\": true,\n",
511
+ " \"feature_extractor_type\": \"CLIPFeatureExtractor\",\n",
512
+ " \"image_mean\": [\n",
513
+ " 0.48145466,\n",
514
+ " 0.4578275,\n",
515
+ " 0.40821073\n",
516
+ " ],\n",
517
+ " \"image_processor_type\": \"CLIPImageProcessor\",\n",
518
+ " \"image_std\": [\n",
519
+ " 0.26862954,\n",
520
+ " 0.26130258,\n",
521
+ " 0.27577711\n",
522
+ " ],\n",
523
+ " \"resample\": 3,\n",
524
+ " \"rescale_factor\": 0.00392156862745098,\n",
525
+ " \"size\": {\n",
526
+ " \"shortest_edge\": 224\n",
527
+ " }\n",
528
+ "}\n",
529
+ "\n",
530
+ "loading file vocab.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/vocab.json\n",
531
+ "loading file merges.txt from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/merges.txt\n",
532
+ "loading file tokenizer.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer.json\n",
533
+ "loading file added_tokens.json from cache at None\n",
534
+ "loading file special_tokens_map.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/special_tokens_map.json\n",
535
+ "loading file tokenizer_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer_config.json\n",
536
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
537
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
538
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
539
+ "Model config CLIPConfig {\n",
540
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
541
+ " \"architectures\": [\n",
542
+ " \"CLIPModel\"\n",
543
+ " ],\n",
544
+ " \"initializer_factor\": 1.0,\n",
545
+ " \"logit_scale_init_value\": 2.6592,\n",
546
+ " \"model_type\": \"clip\",\n",
547
+ " \"projection_dim\": 512,\n",
548
+ " \"text_config\": {\n",
549
+ " \"bos_token_id\": 0,\n",
550
+ " \"dropout\": 0.0,\n",
551
+ " \"eos_token_id\": 2,\n",
552
+ " \"model_type\": \"clip_text_model\"\n",
553
+ " },\n",
554
+ " \"transformers_version\": \"4.36.2\",\n",
555
+ " \"vision_config\": {\n",
556
+ " \"dropout\": 0.0,\n",
557
+ " \"model_type\": \"clip_vision_model\"\n",
558
+ " }\n",
559
+ "}\n",
560
+ "\n"
561
+ ]
562
+ }
563
+ ],
564
+ "source": [
565
+ "from PIL import Image\n",
566
+ "import requests\n",
567
+ "from transformers import AutoProcessor, CLIPVisionModel\n",
568
+ "\n",
569
+ "model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
570
+ "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
571
+ "\n",
572
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
573
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
574
+ "\n",
575
+ "inputs = processor(images=image, return_tensors=\"pt\")\n",
576
+ "\n",
577
+ "outputs = model(**inputs)\n",
578
+ "last_hidden_state = outputs.last_hidden_state\n",
579
+ "pooled_output = outputs.pooler_output # pooled CLS states"
580
+ ]
581
+ },
582
+ {
583
+ "cell_type": "code",
584
+ "execution_count": 31,
585
+ "id": "bcf0a7b3-6cbb-492e-bc2c-42e3edbe6a0c",
586
+ "metadata": {},
587
+ "outputs": [
588
+ {
589
+ "data": {
590
+ "text/plain": [
591
+ "torch.Size([1, 768])"
592
+ ]
593
+ },
594
+ "execution_count": 31,
595
+ "metadata": {},
596
+ "output_type": "execute_result"
597
+ }
598
+ ],
599
+ "source": [
600
+ "pooled_output.shape"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": 10,
606
+ "id": "67240294-c7a0-4e94-a8c1-86bfe1b21977",
607
+ "metadata": {},
608
+ "outputs": [],
609
+ "source": [
610
+ "from transformers import CLIPPreTrainedModel\n",
611
+ "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
612
+ "from typing import Optional, Union, Tuple"
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "code",
617
+ "execution_count": 54,
618
+ "id": "cc9b20db-7f84-44c3-9c78-e84164ccc192",
619
+ "metadata": {},
620
+ "outputs": [],
621
+ "source": [
622
+ "class VisionLanguageConnector(nn.Module):\n",
623
+ " def __init__(self, hidden_size, projection_dim):\n",
624
+ " super().__init__()\n",
625
+ " self.mlp = nn.Sequential(\n",
626
+ " nn.Linear(hidden_size, hidden_size, bias=False),\n",
627
+ " nn.GELU(),\n",
628
+ " nn.Linear(hidden_size, projection_dim, bias=False)\n",
629
+ " )\n",
630
+ "\n",
631
+ " def forward(self, x):\n",
632
+ " return self.mlp(x)\n",
633
+ " \n",
634
+ "class ClipWithProjection(CLIPPreTrainedModel):\n",
635
+ " config_class = CLIPVisionConfig\n",
636
+ " main_input_name = \"pixel_values\"\n",
637
+ "\n",
638
+ " def __init__(self, config: CLIPVisionConfig):\n",
639
+ " super().__init__(config)\n",
640
+ "\n",
641
+ " self.vision_model = CLIPVisionTransformer(config)\n",
642
+ " self.vision_model.\n",
643
+ " self.vision_language_connector = VisionLanguageConnector(config.hidden_size, config.projection_dim)\n",
644
+ "\n",
645
+ " # Initialize weights and apply final processing\n",
646
+ " self.post_init()\n",
647
+ "\n",
648
+ " def forward(\n",
649
+ " self,\n",
650
+ " pixel_values: Optional[torch.FloatTensor] = None,\n",
651
+ " output_attentions: Optional[bool] = None,\n",
652
+ " output_hidden_states: Optional[bool] = None,\n",
653
+ " return_dict: Optional[bool] = None,\n",
654
+ " ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
655
+ " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
656
+ "\n",
657
+ " vision_outputs = self.vision_model(\n",
658
+ " pixel_values=pixel_values,\n",
659
+ " output_attentions=output_attentions,\n",
660
+ " output_hidden_states=output_hidden_states,\n",
661
+ " return_dict=return_dict,\n",
662
+ " )\n",
663
+ "\n",
664
+ " pooled_output = vision_outputs[1] # pooled_output\n",
665
+ "\n",
666
+ " image_embeds = self.vision_language_connector(pooled_output)\n",
667
+ "\n",
668
+ " if not return_dict:\n",
669
+ " outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]\n",
670
+ " return tuple(output for output in outputs if output is not None)\n",
671
+ "\n",
672
+ " return CLIPVisionModelOutput(\n",
673
+ " image_embeds=image_embeds,\n",
674
+ " last_hidden_state=vision_outputs.last_hidden_state,\n",
675
+ " hidden_states=vision_outputs.hidden_states,\n",
676
+ " attentions=vision_outputs.attentions,\n",
677
+ " )"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "code",
682
+ "execution_count": 55,
683
+ "id": "a4892ab8-39d2-41c9-ad2a-04711c22b95f",
684
+ "metadata": {
685
+ "collapsed": true,
686
+ "jupyter": {
687
+ "outputs_hidden": true
688
+ }
689
+ },
690
+ "outputs": [
691
+ {
692
+ "name": "stderr",
693
+ "output_type": "stream",
694
+ "text": [
695
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
696
+ "Model config CLIPVisionConfig {\n",
697
+ " \"attention_dropout\": 0.0,\n",
698
+ " \"dropout\": 0.0,\n",
699
+ " \"hidden_act\": \"quick_gelu\",\n",
700
+ " \"hidden_size\": 768,\n",
701
+ " \"image_size\": 224,\n",
702
+ " \"initializer_factor\": 1.0,\n",
703
+ " \"initializer_range\": 0.02,\n",
704
+ " \"intermediate_size\": 3072,\n",
705
+ " \"layer_norm_eps\": 1e-05,\n",
706
+ " \"model_type\": \"clip_vision_model\",\n",
707
+ " \"num_attention_heads\": 12,\n",
708
+ " \"num_channels\": 3,\n",
709
+ " \"num_hidden_layers\": 12,\n",
710
+ " \"patch_size\": 32,\n",
711
+ " \"projection_dim\": 512,\n",
712
+ " \"transformers_version\": \"4.36.2\"\n",
713
+ "}\n",
714
+ "\n",
715
+ "loading weights file pytorch_model.bin from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/pytorch_model.bin\n",
716
+ "Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing ClipWithProjection: ['text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'logit_scale', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'visual_projection.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 
'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.final_layer_norm.weight', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.embeddings.position_ids', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.10.layer_norm1.weight', 
'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.final_layer_norm.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_projection.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight']\n",
717
+ "- This IS expected if you are initializing ClipWithProjection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
718
+ "- This IS NOT expected if you are initializing ClipWithProjection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
719
+ "Some weights of ClipWithProjection were not initialized from the model checkpoint at openai/clip-vit-base-patch32 and are newly initialized: ['vision_language_connector.mlp.2.weight', 'vision_language_connector.mlp.0.weight']\n",
720
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
721
+ ]
722
+ }
723
+ ],
724
+ "source": [
725
+ "model = ClipWithProjection.from_pretrained(\"openai/clip-vit-base-patch32\")"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": 56,
731
+ "id": "588ef914-5be9-49e1-b68d-b899e0e74edd",
732
+ "metadata": {},
733
+ "outputs": [
734
+ {
735
+ "data": {
736
+ "text/plain": [
737
+ "768"
738
+ ]
739
+ },
740
+ "execution_count": 56,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": [
746
+ "model.config.hidden_size"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": 57,
752
+ "id": "05d95b9e-9831-4415-860e-94793e29d210",
753
+ "metadata": {},
754
+ "outputs": [],
755
+ "source": [
756
+ "outputs = model(**inputs)"
757
+ ]
758
+ },
759
+ {
760
+ "cell_type": "code",
761
+ "execution_count": 61,
762
+ "id": "185b1bff-6ffe-4cce-9255-ee7629feba54",
763
+ "metadata": {},
764
+ "outputs": [
765
+ {
766
+ "data": {
767
+ "text/plain": [
768
+ "torch.Size([1, 512])"
769
+ ]
770
+ },
771
+ "execution_count": 61,
772
+ "metadata": {},
773
+ "output_type": "execute_result"
774
+ }
775
+ ],
776
+ "source": [
777
+ "outputs[0].shape"
778
+ ]
779
+ },
780
+ {
781
+ "cell_type": "code",
782
+ "execution_count": null,
783
+ "id": "04414a35-c7b3-4986-a79e-1d363916caa4",
784
+ "metadata": {},
785
+ "outputs": [],
786
+ "source": []
787
+ },
788
+ {
789
+ "cell_type": "code",
790
+ "execution_count": 1,
791
+ "id": "485dbbcb-06df-4926-b257-dfd1a4081d44",
792
+ "metadata": {},
793
+ "outputs": [
794
+ {
795
+ "ename": "NameError",
796
+ "evalue": "name 'outputs' is not defined",
797
+ "output_type": "error",
798
+ "traceback": [
799
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
800
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
801
+ "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43moutputs\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
802
+ "\u001b[0;31mNameError\u001b[0m: name 'outputs' is not defined"
803
+ ]
804
+ }
805
+ ],
806
+ "source": [
807
+ "outputs[0]"
808
+ ]
809
+ },
810
+ {
811
+ "cell_type": "code",
812
+ "execution_count": null,
813
+ "id": "f983313c-8e0f-4805-af14-25bb69afd04c",
814
+ "metadata": {},
815
+ "outputs": [],
816
+ "source": []
817
+ }
818
+ ],
819
+ "metadata": {
820
+ "kernelspec": {
821
+ "display_name": "Python 3 (ipykernel)",
822
+ "language": "python",
823
+ "name": "python3"
824
+ },
825
+ "language_info": {
826
+ "codemirror_mode": {
827
+ "name": "ipython",
828
+ "version": 3
829
+ },
830
+ "file_extension": ".py",
831
+ "mimetype": "text/x-python",
832
+ "name": "python",
833
+ "nbconvert_exporter": "python",
834
+ "pygments_lexer": "ipython3",
835
+ "version": "3.10.12"
836
+ }
837
+ },
838
+ "nbformat": 4,
839
+ "nbformat_minor": 5
840
+ }
Experiments/eval.ipynb ADDED
@@ -0,0 +1,782 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "id": "215cfd2f-62b0-4a86-a407-777a1d32597f",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "[2024-01-24 15:18:49,948] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "from PIL import Image\n",
19
+ "import requests\n",
20
+ "\n",
21
+ "import torch\n",
22
+ "from torch import nn\n",
23
+ "from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
24
+ "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
25
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
26
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 5,
32
+ "id": "2244e8f3-fcc7-4309-9d4d-fea557f89f79",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "from llava_phi import LlavaPhiForCausalLM"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 3,
42
+ "id": "587883e1-3419-4b14-b16b-38fabbc8bfaa",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "# model = LlavaPhiForCausalLM.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 4,
52
+ "id": "0e27a7db-e2ab-4d65-b21d-497222e318ad",
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "# processor = AutoProcessor.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 5,
62
+ "id": "663efdd8-ea21-4231-a2ae-bcc0fb47b46a",
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "# prompt = \"<image>\\nUSER: What's the content of the image?\\nASSISTANT:\"\n",
67
+ "# url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
68
+ "# image = Image.open(requests.get(url, stream=True).raw)"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 6,
74
+ "id": "f622609f-f6a7-4ec1-ac35-c1d33d9436ca",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# # Generate\n",
79
+ "# generate_ids = model.generate(**inputs, max_length=30)\n",
80
+ "# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 6,
86
+ "id": "45f5ba72-2e41-4ccc-84c1-97d542ebee63",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "from llava_phi.model.builder import load_pretrained_model\n",
91
+ "from llava_phi.mm_utils import tokenizer_image_token, get_model_name_from_path\n",
92
+ "from llava_phi.utils import disable_torch_init\n",
93
+ "from llava_phi.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
94
+ "from llava_phi.conversation import conv_templates, SeparatorStyle"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 11,
100
+ "id": "b98ac5d3-5503-4430-81d1-19a4f8d6bd75",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "model_path = \"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\"\n",
105
+ "model_name = get_model_name_from_path(model_path)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 12,
111
+ "id": "42fd5721-75a7-475b-bd30-5ee23aeaac64",
112
+ "metadata": {},
113
+ "outputs": [
114
+ {
115
+ "data": {
116
+ "text/plain": [
117
+ "'llavaPhi-v0-3b-finetune_checkpoint-4000'"
118
+ ]
119
+ },
120
+ "execution_count": 12,
121
+ "metadata": {},
122
+ "output_type": "execute_result"
123
+ }
124
+ ],
125
+ "source": [
126
+ "model_name"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 13,
132
+ "id": "8c2076b5-3bfc-48fd-917b-5dfd06fc532f",
133
+ "metadata": {},
134
+ "outputs": [
135
+ {
136
+ "name": "stderr",
137
+ "output_type": "stream",
138
+ "text": [
139
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n"
140
+ ]
141
+ },
142
+ {
143
+ "name": "stdout",
144
+ "output_type": "stream",
145
+ "text": [
146
+ "load llaVA-Phi MLLM!!!\n"
147
+ ]
148
+ },
149
+ {
150
+ "name": "stderr",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
154
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
155
+ ]
156
+ },
157
+ {
158
+ "data": {
159
+ "application/vnd.jupyter.widget-view+json": {
160
+ "model_id": "20b86f2c01744081b537620c8780f12e",
161
+ "version_major": 2,
162
+ "version_minor": 0
163
+ },
164
+ "text/plain": [
165
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
166
+ ]
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ },
171
+ {
172
+ "name": "stdout",
173
+ "output_type": "stream",
174
+ "text": [
175
+ "{'device_map': 'cuda'}\n"
176
+ ]
177
+ }
178
+ ],
179
+ "source": [
180
+ "tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 14,
186
+ "id": "4e46221e-0907-453e-8126-76199828493e",
187
+ "metadata": {},
188
+ "outputs": [],
189
+ "source": [
190
+ "qs = \"What's the content of the image?\"\n",
191
+ "qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 15,
197
+ "id": "07355444-0eb8-4d4d-ad50-48b91c969664",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "conv = conv_templates[\"default\"].copy()\n",
202
+ "conv.append_message(conv.roles[0], qs)\n",
203
+ "conv.append_message(conv.roles[1], None)\n",
204
+ "prompt = conv.get_prompt()"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 16,
210
+ "id": "ccb5674f-aff8-456e-b61b-1d167864f1a6",
211
+ "metadata": {},
212
+ "outputs": [
213
+ {
214
+ "data": {
215
+ "text/plain": [
216
+ "\"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <im_start><image><im_end>\\nWhat's the content of the image? ASSISTANT:\""
217
+ ]
218
+ },
219
+ "execution_count": 16,
220
+ "metadata": {},
221
+ "output_type": "execute_result"
222
+ }
223
+ ],
224
+ "source": [
225
+ "prompt"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 17,
231
+ "id": "a89cc181-2214-4844-b966-164a41744e54",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
236
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
237
+ "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
238
+ "\n",
239
+ "input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
240
+ "\n",
241
+ "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 25,
247
+ "id": "0d519851-64d4-4cf5-b2eb-19474f9aa260",
248
+ "metadata": {},
249
+ "outputs": [
250
+ {
251
+ "data": {
252
+ "text/plain": [
253
+ "torch.Size([1, 55])"
254
+ ]
255
+ },
256
+ "execution_count": 25,
257
+ "metadata": {},
258
+ "output_type": "execute_result"
259
+ }
260
+ ],
261
+ "source": [
262
+ "input_ids.shape"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 24,
268
+ "id": "1694ff36-f214-4ed3-b2f3-d3dbd0a1a25b",
269
+ "metadata": {},
270
+ "outputs": [
271
+ {
272
+ "name": "stderr",
273
+ "output_type": "stream",
274
+ "text": [
275
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
276
+ ]
277
+ }
278
+ ],
279
+ "source": [
280
+ "from datasets import load_dataset\n",
281
+ "audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
282
+ "audio = audio_ds[0][\"audio\"]\n",
283
+ "\n",
284
+ "whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
285
+ "audio_embed = whisper_w_proj(audio)[\"input_ids\"]"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 28,
291
+ "id": "9c4a9fae-d6ed-4fc2-ba02-97df64cddd93",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/plain": [
297
+ "(torch.Size([1, 33]), device(type='cpu'))"
298
+ ]
299
+ },
300
+ "execution_count": 28,
301
+ "metadata": {},
302
+ "output_type": "execute_result"
303
+ }
304
+ ],
305
+ "source": [
306
+ "audio_embed.shape, audio_embed.device"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": 29,
312
+ "id": "c3fffe29-98fb-4f4b-ac51-4bdda9e46752",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "input_ids = torch.concat([input_ids, audio_embed.to(\"cuda:0\")], dim=1)"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 30,
322
+ "id": "5dee1ec8-2db2-4f65-99e8-d34bd2735c9c",
323
+ "metadata": {},
324
+ "outputs": [
325
+ {
326
+ "data": {
327
+ "text/plain": [
328
+ "torch.Size([1, 88])"
329
+ ]
330
+ },
331
+ "execution_count": 30,
332
+ "metadata": {},
333
+ "output_type": "execute_result"
334
+ }
335
+ ],
336
+ "source": [
337
+ "input_ids.shape"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": 31,
343
+ "id": "96033b43-4f57-4f0c-bcf7-37b57ca02e47",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "with torch.inference_mode():\n",
348
+ " output_ids = model.generate(\n",
349
+ " input_ids,\n",
350
+ " images=image_tensor,\n",
351
+ " do_sample=True,\n",
352
+ " temperature=0.2,\n",
353
+ " max_new_tokens=1024,\n",
354
+ " eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
355
+ " pad_token_id=tokenizer.eos_token_id, # Pad token\n",
356
+ " use_cache=True,\n",
357
+ " )"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": 32,
363
+ "id": "741e8da5-0d18-4c11-b559-76054ce4ca3a",
364
+ "metadata": {},
365
+ "outputs": [
366
+ {
367
+ "name": "stdout",
368
+ "output_type": "stream",
369
+ "text": [
370
+ "is a Japanese character from the story of Jesus, who is a Chinese monk who is also known for his teachings. The story is based on the story of the story of Jesus Christ, and it is a representation of the story of Jesus and the story of Jesus Christ.\n"
371
+ ]
372
+ }
373
+ ],
374
+ "source": [
375
+ "input_token_len = input_ids.shape[1]\n",
376
+ "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
377
+ "if n_diff_input_output > 0:\n",
378
+ " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
379
+ "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
380
+ "outputs = outputs.strip()\n",
381
+ "if outputs.endswith(stop_str):\n",
382
+ " outputs = outputs[:-len(stop_str)]\n",
383
+ "outputs = outputs.strip()\n",
384
+ "print(outputs)"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": 20,
390
+ "id": "69d494d4-d768-4645-b4d6-5c455791b50d",
391
+ "metadata": {},
392
+ "outputs": [],
393
+ "source": [
394
+ "# image"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "8a340856-a13f-4b18-9911-126a4ba37816",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": []
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "id": "3c56fdea-c7a1-4e67-9832-e2ed077d8704",
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": []
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": 52,
416
+ "id": "89e84d39-8ed8-45db-ae82-27c156ee6dd1",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "class AudioLanguageConnector:\n",
421
+ " def __init__(self, projection_dim):\n",
422
+ " model_name = \"microsoft/phi-2\"\n",
423
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
424
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
425
+ " self.phi2_tokenizer.max_length = projection_dim\n",
426
+ "\n",
427
+ " def __call__(self, text):\n",
428
+ " text = f\"<audio_start> {text} <audio_end>\"\n",
429
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
430
+ " return tokens\n",
431
+ " \n",
432
+ "\n",
433
+ "class WhisperWithProjection:\n",
434
+ " def __init__(self, projection_dim, device):\n",
435
+ " self.device = device\n",
436
+ " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
437
+ " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
438
+ " self.model.config.forced_decoder_ids = None\n",
439
+ " self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
440
+ " \n",
441
+ " def __call__(self, audio):\n",
442
+ " input_features = self.processor(audio[\"array\"],\n",
443
+ " sampling_rate=audio[\"sampling_rate\"],\n",
444
+ " return_tensors=\"pt\").input_features\n",
445
+ " # generate token ids\n",
446
+ " predicted_ids = self.model.generate(input_features.to(self.device))\n",
447
+ " # decode token ids to text \n",
448
+ " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
449
+ "\n",
450
+ " audio_embeddings = self.audio_language_connector(transcription)\n",
451
+ " return audio_embeddings.to(self.device)"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": 53,
457
+ "id": "75e24be0-b236-4047-83ef-5c344e262476",
458
+ "metadata": {},
459
+ "outputs": [],
460
+ "source": [
461
+ "class MultiModalPhi2:\n",
462
+ " def __init__(self, model_path=\"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\",\n",
463
+ " temperature=0.2,\n",
464
+ " max_new_tokens=1024,\n",
465
+ " device=\"cuda\"):\n",
466
+ " self.temperature = temperature\n",
467
+ " self.max_new_tokens = max_new_tokens\n",
468
+ " self.device = device\n",
469
+ " model_name = get_model_name_from_path(model_path)\n",
470
+ " self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, model_name, device_map=device)\n",
471
+ " self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n",
472
+ " \n",
473
+ " \n",
474
+ " def __call__(self, text, audio, image):\n",
475
+ " qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n",
476
+ " conv = conv_templates[\"default\"].copy()\n",
477
+ " conv.append_message(conv.roles[0], qs)\n",
478
+ " conv.append_message(conv.roles[1], None)\n",
479
+ " prompt = conv.get_prompt()\n",
480
+ "\n",
481
+ " image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
482
+ " \n",
483
+ " input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
484
+ "\n",
485
+ " audio_embed = self.whisper_w_proj(audio)[\"input_ids\"]\n",
486
+ " \n",
487
+ " stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
488
+ "\n",
489
+ " input_ids = torch.concat([input_ids, audio_embed], dim=1)\n",
490
+ "\n",
491
+ " with torch.inference_mode():\n",
492
+ " output_ids = self.model.generate(\n",
493
+ " input_ids,\n",
494
+ " images=image_tensor,\n",
495
+ " do_sample=True,\n",
496
+ " temperature=self.temperature,\n",
497
+ " max_new_tokens=self.max_new_tokens,\n",
498
+ " eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
499
+ " pad_token_id=tokenizer.eos_token_id, # Pad token\n",
500
+ " use_cache=True,\n",
501
+ " )\n",
502
+ "\n",
503
+ " input_token_len = input_ids.shape[1]\n",
504
+ " n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
505
+ " if n_diff_input_output > 0:\n",
506
+ " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
507
+ " outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
508
+ " outputs = outputs.strip()\n",
509
+ " if outputs.endswith(stop_str):\n",
510
+ " outputs = outputs[:-len(stop_str)]\n",
511
+ " outputs = outputs.strip()\n",
512
+ " return outputs"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": 54,
518
+ "id": "4efdbad4-d88a-4477-a3a0-f5591cd0b172",
519
+ "metadata": {},
520
+ "outputs": [
521
+ {
522
+ "name": "stderr",
523
+ "output_type": "stream",
524
+ "text": [
525
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
526
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
527
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
528
+ ]
529
+ },
530
+ {
531
+ "name": "stdout",
532
+ "output_type": "stream",
533
+ "text": [
534
+ "load llaVA-Phi MLLM!!!\n"
535
+ ]
536
+ },
537
+ {
538
+ "data": {
539
+ "application/vnd.jupyter.widget-view+json": {
540
+ "model_id": "492c17cf54f34d4d9e4f288fc9e72e79",
541
+ "version_major": 2,
542
+ "version_minor": 0
543
+ },
544
+ "text/plain": [
545
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
546
+ ]
547
+ },
548
+ "metadata": {},
549
+ "output_type": "display_data"
550
+ },
551
+ {
552
+ "name": "stdout",
553
+ "output_type": "stream",
554
+ "text": [
555
+ "{'device_map': 'cuda'}\n"
556
+ ]
557
+ },
558
+ {
559
+ "name": "stderr",
560
+ "output_type": "stream",
561
+ "text": [
562
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
563
+ ]
564
+ }
565
+ ],
566
+ "source": [
567
+ "multimodal_phi2 = MultiModalPhi2()"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": 57,
573
+ "id": "9a6de0b0-a231-4d50-88e8-e40c6f7216c3",
574
+ "metadata": {},
575
+ "outputs": [],
576
+ "source": [
577
+ "text = \"tell me about the audio\""
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 58,
583
+ "id": "b4919948-6a75-4d19-ba95-9ba233a7d3d9",
584
+ "metadata": {},
585
+ "outputs": [
586
+ {
587
+ "data": {
588
+ "text/plain": [
589
+ "'is a popular Japanese drama series featuring a man in a red and white costume, who is dressed as Santa Claus, is walking down the street. The scene takes place in a busy city environment, with people walking and standing on the sidewalk, likely enjoying the festive atmosphere and the festive atmosphere.'"
590
+ ]
591
+ },
592
+ "execution_count": 58,
593
+ "metadata": {},
594
+ "output_type": "execute_result"
595
+ }
596
+ ],
597
+ "source": [
598
+ "multimodal_phi2(text, audio, image)"
599
+ ]
600
+ },
601
+ {
602
+ "cell_type": "code",
603
+ "execution_count": null,
604
+ "id": "590f2d64-62ed-4e6f-b7c8-b0cf68aecaab",
605
+ "metadata": {},
606
+ "outputs": [],
607
+ "source": []
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": 64,
612
+ "id": "c921eb63-feb5-4fa9-993b-2faeb6dfe1db",
613
+ "metadata": {},
614
+ "outputs": [],
615
+ "source": [
616
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor"
617
+ ]
618
+ },
619
+ {
620
+ "cell_type": "code",
621
+ "execution_count": 65,
622
+ "id": "b470a2c4-806a-435d-9fc2-f17448dbe5fc",
623
+ "metadata": {},
624
+ "outputs": [],
625
+ "source": [
626
+ "from llava_phi.model import LlavaPhiConfig"
627
+ ]
628
+ },
629
+ {
630
+ "cell_type": "code",
631
+ "execution_count": 66,
632
+ "id": "4f7bc91a-0a41-45e5-92a4-daa1e3eea0da",
633
+ "metadata": {},
634
+ "outputs": [
635
+ {
636
+ "name": "stderr",
637
+ "output_type": "stream",
638
+ "text": [
639
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
640
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
641
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
642
+ ]
643
+ },
644
+ {
645
+ "data": {
646
+ "application/vnd.jupyter.widget-view+json": {
647
+ "model_id": "993bc3a38cb84de4a2e3a79a3448c4d6",
648
+ "version_major": 2,
649
+ "version_minor": 0
650
+ },
651
+ "text/plain": [
652
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
653
+ ]
654
+ },
655
+ "metadata": {},
656
+ "output_type": "display_data"
657
+ }
658
+ ],
659
+ "source": [
660
+ "device_map = \"cuda:0\"\n",
661
+ "load_8bit = False\n",
662
+ "load_4bit = False\n",
663
+ "kwargs = {\"device_map\": device_map}\n",
664
+ "if load_8bit:\n",
665
+ " kwargs['load_in_8bit'] = True\n",
666
+ "elif load_4bit:\n",
667
+ " kwargs['load_in_4bit'] = True\n",
668
+ " kwargs['quantization_config'] = BitsAndBytesConfig(\n",
669
+ " load_in_4bit=True,\n",
670
+ " bnb_4bit_compute_dtype=torch.float16,\n",
671
+ " bnb_4bit_use_double_quant=True,\n",
672
+ " bnb_4bit_quant_type='nf4'\n",
673
+ " )\n",
674
+ "config = LlavaPhiConfig.from_pretrained(model_path, trust_remote_code=True)\n",
675
+ "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n",
676
+ "model = LlavaPhiForCausalLM.from_pretrained(\n",
677
+ " model_path, \n",
678
+ " config=config, \n",
679
+ " use_safetensors=True, \n",
680
+ " **kwargs).to(\"cuda\")\n",
681
+ "image_processor = CLIPImageProcessor.from_pretrained(model_path)\n",
682
+ "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
683
+ "mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n",
684
+ "\n",
685
+ "# TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200\n",
686
+ "if mm_use_im_patch_token:\n",
687
+ " tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
688
+ "if mm_use_im_start_end:\n",
689
+ " tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
690
+ " \n",
691
+ "if hasattr(model.config, \"max_sequence_length\"):\n",
692
+ " context_len = model.config.max_sequence_length\n",
693
+ "else:\n",
694
+ " context_len = 2048"
695
+ ]
696
+ },
697
+ {
698
+ "cell_type": "code",
699
+ "execution_count": 70,
700
+ "id": "99355837-a297-4a25-aeb3-1670af7e9251",
701
+ "metadata": {},
702
+ "outputs": [
703
+ {
704
+ "ename": "KeyboardInterrupt",
705
+ "evalue": "",
706
+ "output_type": "error",
707
+ "traceback": [
708
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
709
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
710
+ "Cell \u001b[0;32mIn[70], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLlava-Phi-Checkpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
711
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/transformers/modeling_utils.py:2376\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m 2372\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_file, shard \u001b[38;5;129;01min\u001b[39;00m shards\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 2373\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m 2374\u001b[0m \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m 2375\u001b[0m \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 2376\u001b[0m \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2377\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2378\u001b[0m save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n",
712
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/safetensors/torch.py:281\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m 251\u001b[0m tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m 252\u001b[0m filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m 253\u001b[0m metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 254\u001b[0m ):\n\u001b[1;32m 255\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 256\u001b[0m \u001b[38;5;124;03m Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;124;03m ```\u001b[39;00m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n",
713
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
714
+ ]
715
+ }
716
+ ],
717
+ "source": [
718
+ "model.save_pretrained(\"Llava-Phi-Checkpoint\")"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "execution_count": null,
724
+ "id": "fa0bec34-a148-4340-a30c-6f09dd5e71ca",
725
+ "metadata": {},
726
+ "outputs": [],
727
+ "source": [
728
+ "model.push_to_hub(\"RaviNaik/Llava-Phi2\")"
729
+ ]
730
+ },
731
+ {
732
+ "cell_type": "code",
733
+ "execution_count": 73,
734
+ "id": "382f74b0-2967-408a-badc-a90918810d74",
735
+ "metadata": {},
736
+ "outputs": [
737
+ {
738
+ "data": {
739
+ "text/plain": [
740
+ "CommitInfo(commit_url='https://huggingface.co/RaviNaik/Llava-Phi2/commit/fa8f7240058241243f6bdc3d6ab44bb691f76e39', commit_message='Upload tokenizer', commit_description='', oid='fa8f7240058241243f6bdc3d6ab44bb691f76e39', pr_url=None, pr_revision=None, pr_num=None)"
741
+ ]
742
+ },
743
+ "execution_count": 73,
744
+ "metadata": {},
745
+ "output_type": "execute_result"
746
+ }
747
+ ],
748
+ "source": [
749
+ "tokenizer.push_to_hub(\"RaviNaik/Llava-Phi2\")"
750
+ ]
751
+ },
752
+ {
753
+ "cell_type": "code",
754
+ "execution_count": null,
755
+ "id": "b851459b-d3ac-4fb8-99b6-17a648adc41f",
756
+ "metadata": {},
757
+ "outputs": [],
758
+ "source": []
759
+ }
760
+ ],
761
+ "metadata": {
762
+ "kernelspec": {
763
+ "display_name": "Python 3 (ipykernel)",
764
+ "language": "python",
765
+ "name": "python3"
766
+ },
767
+ "language_info": {
768
+ "codemirror_mode": {
769
+ "name": "ipython",
770
+ "version": 3
771
+ },
772
+ "file_extension": ".py",
773
+ "mimetype": "text/x-python",
774
+ "name": "python",
775
+ "nbconvert_exporter": "python",
776
+ "pygments_lexer": "ipython3",
777
+ "version": "3.10.12"
778
+ }
779
+ },
780
+ "nbformat": 4,
781
+ "nbformat_minor": 5
782
+ }
Experiments/instruct_150k_data.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/instruct_data.py ADDED
@@ -0,0 +1,39 @@
1
+ from datasets import Dataset, IterableDataset
2
+ from PIL import Image
3
+
4
+ # ChatML format
5
+ templates = {
6
+ "assistant": "<|im_start|>assistant\n{msg}<|im_end|>", # message by assistant
7
+ "user": "<|im_start|>user\n{msg}<|im_end|>" # message by user
8
+ }
9
+
10
+ ds = Dataset.from_json("llava_instruct_150k.json", split="train")
11
+ ds_stream = ds.to_iterable_dataset()
12
+
13
+
14
+ def get_image(image_path):
15
+ image_path = f"train2014/COCO_train2014_{image_path}"
16
+ img = Image.open(image_path)
17
+ return img
18
+
19
+ def get_chatml_text(conversations):
20
+ chatml_text = ""
21
+ for conversation in conversations:
22
+ role = conversation["from"]
23
+ role = "user" if role == "human" else "assistant"
24
+ content = conversation["value"]
25
+
26
+ formatted_text = templates[role].format(msg=content)
27
+ chatml_text += formatted_text + "\n"
28
+ return chatml_text
29
+
30
+ def instruct_data_generator():
31
+ for sample in ds_stream:
32
+ image_path = sample["image"]
33
+ conversations = sample["conversations"]
34
+
35
+ image = get_image(image_path)
36
+ text = get_chatml_text(conversations)
37
+ yield {"text": text, "image": image}
38
+
39
+ instruct_ds = IterableDataset.from_generator(generator=instruct_data_generator)
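+
+ # Illustrative sanity check (an assumption, not part of the original pipeline):
+ # pull one sample from the streaming dataset to verify the ChatML formatting and
+ # that the corresponding COCO image loads correctly.
+ if __name__ == "__main__":
+     sample = next(iter(instruct_ds))
+     print(sample["text"])
+     print(sample["image"].size)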
Experiments/llava_exp.ipynb ADDED
@@ -0,0 +1,145 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "99576983-f881-47c8-8b5e-c6f561a93e71",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import transformers"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "58ba19f2-4b91-4f90-a33d-4c1ed17e202a",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, PhiConfig\n",
21
+ "\n",
22
+ "# Initializing a CLIP-vision config\n",
23
+ "vision_config = CLIPVisionConfig()\n",
24
+ "\n",
25
+ "# Initializing a Llama config\n",
26
+ "text_config = PhiConfig()\n",
27
+ "\n",
28
+ "# Initializing a Llava llava-1.5-7b style configuration\n",
29
+ "configuration = LlavaConfig(vision_config, text_config)\n",
30
+ "\n",
31
+ "# Initializing a model from the llava-1.5-7b style configuration\n",
32
+ "model = LlavaForConditionalGeneration(configuration)\n",
33
+ "\n",
34
+ "# Accessing the model configuration\n",
35
+ "configuration = model.config"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 5,
41
+ "id": "a806a07a-fe72-45a3-8ceb-8e942c6c845d",
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "data": {
46
+ "text/plain": [
47
+ "LlavaConfig {\n",
48
+ " \"ignore_index\": -100,\n",
49
+ " \"image_token_index\": 32000,\n",
50
+ " \"model_type\": \"llava\",\n",
51
+ " \"projector_hidden_act\": \"gelu\",\n",
52
+ " \"text_config\": {\n",
53
+ " \"embd_pdrop\": 0.0,\n",
54
+ " \"hidden_act\": \"gelu_new\",\n",
55
+ " \"hidden_size\": 2048,\n",
56
+ " \"intermediate_size\": 8192,\n",
57
+ " \"layer_norm_eps\": 1e-05,\n",
58
+ " \"model_type\": \"phi\",\n",
59
+ " \"num_hidden_layers\": 24,\n",
60
+ " \"partial_rotary_factor\": 0.5,\n",
61
+ " \"qk_layernorm\": false,\n",
62
+ " \"resid_pdrop\": 0.0,\n",
63
+ " \"vocab_size\": 51200\n",
64
+ " },\n",
65
+ " \"transformers_version\": \"4.36.2\",\n",
66
+ " \"vision_config\": {\n",
67
+ " \"hidden_size\": 768,\n",
68
+ " \"image_size\": 224,\n",
69
+ " \"intermediate_size\": 3072,\n",
70
+ " \"model_type\": \"clip_vision_model\",\n",
71
+ " \"num_attention_heads\": 12,\n",
72
+ " \"num_hidden_layers\": 12,\n",
73
+ " \"patch_size\": 32,\n",
74
+ " \"projection_dim\": 512\n",
75
+ " },\n",
76
+ " \"vision_feature_layer\": -2,\n",
77
+ " \"vision_feature_select_strategy\": \"default\",\n",
78
+ " \"vocab_size\": 32000\n",
79
+ "}"
80
+ ]
81
+ },
82
+ "execution_count": 5,
83
+ "metadata": {},
84
+ "output_type": "execute_result"
85
+ }
86
+ ],
87
+ "source": [
88
+ "model.config"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 6,
94
+ "id": "79efbc6b-f005-4a5c-82a1-112fa37f1904",
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "name": "stdout",
99
+ "output_type": "stream",
100
+ "text": [
101
+ "Cloning into 'llava-phi'...\n",
102
+ "remote: Enumerating objects: 151, done.\u001b[K\n",
103
+ "remote: Counting objects: 100% (151/151), done.\u001b[K\n",
104
+ "remote: Compressing objects: 100% (116/116), done.\u001b[K\n",
105
+ "remote: Total 151 (delta 36), reused 133 (delta 25), pack-reused 0\u001b[K\n",
106
+ "Receiving objects: 100% (151/151), 333.89 KiB | 112.00 KiB/s, done.\n",
107
+ "Resolving deltas: 100% (36/36), done.\n"
108
+ ]
109
+ }
110
+ ],
111
+ "source": [
112
+ "!git clone https://github.com/zhuyiche/llava-phi.git"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "cf827184-f334-4d86-ace1-fe9c92f84d66",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": []
122
+ }
123
+ ],
124
+ "metadata": {
125
+ "kernelspec": {
126
+ "display_name": "Python 3 (ipykernel)",
127
+ "language": "python",
128
+ "name": "python3"
129
+ },
130
+ "language_info": {
131
+ "codemirror_mode": {
132
+ "name": "ipython",
133
+ "version": 3
134
+ },
135
+ "file_extension": ".py",
136
+ "mimetype": "text/x-python",
137
+ "name": "python",
138
+ "nbconvert_exporter": "python",
139
+ "pygments_lexer": "ipython3",
140
+ "version": "3.10.12"
141
+ }
142
+ },
143
+ "nbformat": 4,
144
+ "nbformat_minor": 5
145
+ }
Experiments/multimodal_exp.ipynb ADDED
@@ -0,0 +1,362 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 23,
6
+ "id": "d4bed9ef-4bff-4d61-a4f9-a585f377f136",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from PIL import Image\n",
11
+ "import requests\n",
12
+ "\n",
13
+ "import torch\n",
14
+ "from torch import nn\n",
15
+ "from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
16
+ "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
17
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
18
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
19
+ "from typing import Optional, Union, Tuple"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 43,
25
+ "id": "952314f0-ee9d-45e7-85b8-1e3e44c1a2fd",
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "class VisionLanguageConnector(nn.Module):\n",
30
+ " def __init__(self, hidden_size, projection_dim):\n",
31
+ " super().__init__()\n",
32
+ " self.mlp = nn.Sequential(\n",
33
+ " nn.Linear(hidden_size, hidden_size, bias=False),\n",
34
+ " nn.GELU(),\n",
35
+ " nn.Linear(hidden_size, projection_dim, bias=False)\n",
36
+ " )\n",
37
+ "\n",
38
+ " def forward(self, x):\n",
39
+ " return self.mlp(x)\n",
40
+ " \n",
41
+ "class ClipWithProjection():\n",
42
+ " config_class = CLIPVisionConfig\n",
43
+ " main_input_name = \"pixel_values\"\n",
44
+ "\n",
45
+ " def __init__(self, hidden_size, projection_dim):\n",
46
+ " super().__init__()\n",
47
+ " \n",
48
+ " self.processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
49
+ " self.vision_model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
50
+ " self.vision_language_connector = VisionLanguageConnector(hidden_size, projection_dim)\n",
51
+ "\n",
52
+ " def forward(\n",
53
+ " self,\n",
54
+ " image = None,\n",
55
+ " output_attentions: Optional[bool] = None,\n",
56
+ " output_hidden_states: Optional[bool] = None,\n",
57
+ " return_dict: Optional[bool] = None,\n",
58
+ " ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
59
+ " \n",
60
+ " pixel_values = self.processor(images=image, return_tensors=\"pt\")[\"pixel_values\"]\n",
61
+ " vision_outputs = self.vision_model(\n",
62
+ " pixel_values=pixel_values,\n",
63
+ " output_attentions=output_attentions,\n",
64
+ " output_hidden_states=output_hidden_states,\n",
65
+ " return_dict=return_dict,\n",
66
+ " )\n",
67
+ "\n",
68
+ " pooled_output = vision_outputs[1] # pooled_output\n",
69
+ "\n",
70
+ " image_embeds = self.vision_language_connector(pooled_output)\n",
71
+ "\n",
72
+ " return CLIPVisionModelOutput(\n",
73
+ " image_embeds=image_embeds,\n",
74
+ " last_hidden_state=vision_outputs.last_hidden_state,\n",
75
+ " hidden_states=vision_outputs.hidden_states,\n",
76
+ " attentions=vision_outputs.attentions,\n",
77
+ " )"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 44,
83
+ "id": "bd2889fe-be85-44a3-afe8-65b47f7a93c3",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
88
+ "image = Image.open(requests.get(url, stream=True).raw)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 46,
94
+ "id": "17c72699-fe98-4b96-b63c-5c8ab7c1a65f",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "# model = ClipWithProjection(768, 512)\n",
99
+ "# model.forward(image)"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 47,
105
+ "id": "70806156-38a9-45a2-bf9f-e72047a0173f",
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "class AudioLanguageConnector:\n",
110
+ " def __init__(self, projection_dim):\n",
111
+ " model_name = \"microsoft/phi-2\"\n",
112
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
113
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
114
+ " self.phi2_tokenizer.max_length = projection_dim\n",
115
+ "\n",
116
+ " def __call__(self, text):\n",
117
+ " text = f\"<audio_start> {text} <audio_end>\"\n",
118
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
119
+ " return tokens\n",
120
+ " \n",
121
+ "\n",
122
+ "class WhisperWithProjection:\n",
123
+ " def __init__(self, projection_dim):\n",
124
+ " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
125
+ " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
126
+ " self.model.config.forced_decoder_ids = None\n",
127
+ " self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
128
+ " \n",
129
+ " def forward(self, audio):\n",
130
+ " input_features = self.processor(audio[\"array\"],\n",
131
+ " sampling_rate=audio[\"sampling_rate\"],\n",
132
+ " return_tensors=\"pt\").input_features\n",
133
+ " # generate token ids\n",
134
+ " predicted_ids = self.model.generate(input_features)\n",
135
+ " # decode token ids to text \n",
136
+ " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
137
+ "\n",
138
+ " audio_embeddings = self.audio_language_connector(transcription)\n",
139
+ " return audio_embeddings"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 48,
145
+ "id": "79cc4d98-498b-4042-bd71-143b2477733d",
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "class TextModality:\n",
150
+ " def __init__(self, projection_dim):\n",
151
+ " model_name = \"microsoft/phi-2\"\n",
152
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
153
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
154
+ " self.phi2_tokenizer.max_length = projection_dim\n",
155
+ "\n",
156
+ "\n",
157
+ " def __call__(self, text):\n",
158
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
159
+ " return tokens"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 77,
165
+ "id": "ba4c4772-923f-48e8-a4af-b7d9c192dd4b",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "class MultiModalPhi2:\n",
170
+ " def __init__(self):\n",
171
+ " self.text_modality = TextModality(projection_dim=768)\n",
172
+ " self.whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
173
+ " self.clip_w_proj = ClipWithProjection(hidden_size=768, projection_dim=768)\n",
174
+ " self.llm = self.load_llm()\n",
175
+ "\n",
176
+ " def load_llm(self):\n",
177
+ " model_name = \"microsoft/phi-2\"\n",
178
+ " \n",
179
+ " bnb_config = BitsAndBytesConfig(\n",
180
+ " load_in_4bit=True,\n",
181
+ " bnb_4bit_quant_type=\"nf4\",\n",
182
+ " bnb_4bit_compute_dtype=torch.float16)\n",
183
+ " \n",
184
+ " model = AutoModelForCausalLM.from_pretrained(\n",
185
+ " model_name,\n",
186
+ " quantization_config=bnb_config,\n",
187
+ " trust_remote_code=True,\n",
188
+ " device_map=\"cuda:0\"\n",
189
+ " )\n",
190
+ " model.config.use_cache = False\n",
191
+ " return model\n",
192
+ "\n",
193
+ " def forward(self, audio, image, text):\n",
194
+ " if text is not None:\n",
195
+ " text_embed = self.text_modality(text)[\"input_ids\"]\n",
196
+ " if audio is not None:\n",
197
+ " audio_embed = self.whisper_w_proj.forward(audio)[\"input_ids\"]\n",
198
+ " if image is not None:\n",
199
+ " image_embed = self.clip_w_proj.forward(image)[0]\n",
200
+ " print(text_embed.shape, text_embed.dtype)\n",
201
+ " print(audio_embed.shape, audio_embed.dtype)\n",
202
+ " print(image_embed.shape, image_embed.dtype)\n",
203
+ " \n",
204
+ " inputs = torch.concat([text_embed, audio_embed, image_embed], dim=1)\n",
205
+ " print(inputs.shape, inputs.dtype)\n",
206
+ " outputs = self.llm(inputs)\n",
207
+ "\n",
208
+ " return outputs \n",
209
+ " \n",
210
+ "\n",
211
+ " def generate(self, audio, text):\n",
212
+ " text_embeddings = self.text_modality(text)\n",
213
+ " audio_embeddings = self.whisper_w_proj.forward(audio)\n",
214
+ " inputs = torch.concat([text_embed[\"input_ids\"], audio_embed[\"input_ids\"]], dim=1)\n",
215
+ " \n",
216
+ " outputs = self.llm.generate(inputs, max_length=200)\n",
217
+ " text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
218
+ " print(text)"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 74,
224
+ "id": "7ca694eb-8009-4eb9-9a4c-eac406ab9584",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "from datasets import load_dataset\n",
229
+ "audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
230
+ "audio = audio_ds[0][\"audio\"]"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": 58,
236
+ "id": "37be28c5-4cc3-4471-b394-032c7602accc",
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": [
240
+ "text = \"explain about the audio\""
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 59,
246
+ "id": "c0705114-1670-4937-bc3e-3660e5a5d2c5",
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": [
250
+ "# image"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 78,
256
+ "id": "0d7e5b49-b4bd-477c-87b8-91ef70857677",
257
+ "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "name": "stderr",
261
+ "output_type": "stream",
262
+ "text": [
263
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
264
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
265
+ ]
266
+ },
267
+ {
268
+ "data": {
269
+ "application/vnd.jupyter.widget-view+json": {
270
+ "model_id": "733dc7b2208b4853a89aea49bff9a55c",
271
+ "version_major": 2,
272
+ "version_minor": 0
273
+ },
274
+ "text/plain": [
275
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
276
+ ]
277
+ },
278
+ "metadata": {},
279
+ "output_type": "display_data"
280
+ }
281
+ ],
282
+ "source": [
283
+ "model = MultiModalPhi2()"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": 79,
289
+ "id": "0b6471c4-4553-47f3-b38f-46057dcf80f2",
290
+ "metadata": {},
291
+ "outputs": [
292
+ {
293
+ "name": "stdout",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "torch.Size([1, 5]) torch.int64\n",
297
+ "torch.Size([1, 33]) torch.int64\n",
298
+ "torch.Size([1, 768]) torch.float32\n",
299
+ "torch.Size([1, 806]) torch.float32\n"
300
+ ]
301
+ },
302
+ {
303
+ "ename": "RuntimeError",
304
+ "evalue": "Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)",
305
+ "output_type": "error",
306
+ "traceback": [
307
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
308
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
309
+ "Cell \u001b[0;32mIn[79], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43maudio\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
310
+ "Cell \u001b[0;32mIn[77], line 38\u001b[0m, in \u001b[0;36mMultiModalPhi2.forward\u001b[0;34m(self, audio, image, text)\u001b[0m\n\u001b[1;32m 36\u001b[0m inputs \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mconcat([text_embed, audio_embed, image_embed], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mprint\u001b[39m(inputs\u001b[38;5;241m.\u001b[39mshape, inputs\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[0;32m---> 38\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mllm\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
311
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
312
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
313
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
314
+ "File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:1049\u001b[0m, in \u001b[0;36mPhiForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1046\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1048\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1049\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1055\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1056\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1057\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1058\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1059\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1061\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1062\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
315
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
316
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
317
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
318
+ "File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:893\u001b[0m, in \u001b[0;36mPhiModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 890\u001b[0m position_ids \u001b[38;5;241m=\u001b[39m position_ids\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 892\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 893\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 895\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_dropout(inputs_embeds)\n\u001b[1;32m 897\u001b[0m \u001b[38;5;66;03m# Attention mask.\u001b[39;00m\n",
319
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
320
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
321
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
322
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162\u001b[0m, in \u001b[0;36mEmbedding.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 162\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
323
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/functional.py:2233\u001b[0m, in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2227\u001b[0m \u001b[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001b[39;00m\n\u001b[1;32m 2228\u001b[0m \u001b[38;5;66;03m# XXX: equivalent to\u001b[39;00m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# torch.embedding_renorm_\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m \u001b[38;5;66;03m# remove once script supports set_grad_enabled\u001b[39;00m\n\u001b[1;32m 2232\u001b[0m _no_grad_embedding_renorm_(weight, \u001b[38;5;28minput\u001b[39m, max_norm, norm_type)\n\u001b[0;32m-> 2233\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
324
+ "\u001b[0;31mRuntimeError\u001b[0m: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)"
325
+ ]
326
+ }
327
+ ],
328
+ "source": [
329
+ "model.forward(audio, image, text)"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": null,
335
+ "id": "4ca96caf-82e2-4f07-87b3-8654dfdc89aa",
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": []
339
+ }
340
+ ],
341
+ "metadata": {
342
+ "kernelspec": {
343
+ "display_name": "Python 3 (ipykernel)",
344
+ "language": "python",
345
+ "name": "python3"
346
+ },
347
+ "language_info": {
348
+ "codemirror_mode": {
349
+ "name": "ipython",
350
+ "version": 3
351
+ },
352
+ "file_extension": ".py",
353
+ "mimetype": "text/x-python",
354
+ "name": "python",
355
+ "nbconvert_exporter": "python",
356
+ "pygments_lexer": "ipython3",
357
+ "version": "3.10.12"
358
+ }
359
+ },
360
+ "nbformat": 4,
361
+ "nbformat_minor": 5
362
+ }
Experiments/pretrain_data_check.ipynb ADDED
@@ -0,0 +1,304 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 5,
6
+ "id": "61c272f2-edbe-4b7d-8fec-3ab431400cd3",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import json"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "e9dfd7d7-1685-4fc7-bbb9-3905c32d8ba1",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "with open(\"metadata.json\", \"rb\") as f:\n",
21
+ " metadata = json.load(f)"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 4,
27
+ "id": "70bdba48-db01-42ac-8d89-edc69d7d7672",
28
+ "metadata": {},
29
+ "outputs": [
30
+ {
31
+ "data": {
32
+ "text/plain": [
33
+ "595375"
34
+ ]
35
+ },
36
+ "execution_count": 4,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "len(metadata)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 14,
48
+ "id": "59e193cc-0dd8-4f7e-959a-fbad0133d76c",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "with open(\"blip_laion_cc_sbu_558k.jsonblip_laion_cc_sbu_558k.json\", \"rb\") as f:\n",
53
+ " data = json.load(f)"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 7,
59
+ "id": "f3157f41-269b-4f7a-b3ba-9be711babe02",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "data": {
64
+ "text/plain": [
65
+ "{'id': '004539375',\n",
66
+ " 'image': '00453/004539375.jpg',\n",
67
+ " 'conversations': [{'from': 'human',\n",
68
+ " 'value': 'Render a clear and concise summary of the photo.\\n<image>'},\n",
69
+ " {'from': 'gpt',\n",
70
+ " 'value': 'select luxury furniture 3 - inch gel memory foam mattress topper'}]}"
71
+ ]
72
+ },
73
+ "execution_count": 7,
74
+ "metadata": {},
75
+ "output_type": "execute_result"
76
+ }
77
+ ],
78
+ "source": [
79
+ "data[0]"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 8,
85
+ "id": "50d8a051-1526-47dd-ad71-d3c66f7bd34e",
86
+ "metadata": {},
87
+ "outputs": [
88
+ {
89
+ "data": {
90
+ "text/plain": [
91
+ "{'id': '004374662',\n",
92
+ " 'image': '00437/004374662.jpg',\n",
93
+ " 'conversations': [{'from': 'human',\n",
94
+ " 'value': 'Give a brief description of the image.\\n<image>'},\n",
95
+ " {'from': 'gpt', 'value': 'the north face duffel bag camo large'}]}"
96
+ ]
97
+ },
98
+ "execution_count": 8,
99
+ "metadata": {},
100
+ "output_type": "execute_result"
101
+ }
102
+ ],
103
+ "source": [
104
+ "data[234]"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 17,
110
+ "id": "2e6d5664-4583-49a6-93cc-079ee2d1ff6c",
111
+ "metadata": {},
112
+ "outputs": [
113
+ {
114
+ "data": {
115
+ "text/plain": [
116
+ "558128"
117
+ ]
118
+ },
119
+ "execution_count": 17,
120
+ "metadata": {},
121
+ "output_type": "execute_result"
122
+ }
123
+ ],
124
+ "source": [
125
+ "len(data)"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 10,
131
+ "id": "11ed106d-6bef-482c-a456-5eaaf2025534",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "data": {
136
+ "text/plain": [
137
+ "{'id': 'GCC_train_001749371',\n",
138
+ " 'image': 'GCC_train_001749371.jpg',\n",
139
+ " 'caption': 'if you are dreaming of simpler or off - the - grid living , a yurt is a fantastic option',\n",
140
+ " 'blip_caption': 'a white and tan yurt sitting on a dirt road',\n",
141
+ " 'url': 'https://i.pinimg.com/736x/14/7b/64/147b64467ee966d9a578097bb70475ad--yurt-kits-small-space-living.jpg'}"
142
+ ]
143
+ },
144
+ "execution_count": 10,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "metadata[67]"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 15,
156
+ "id": "ce8adcec-2499-4be3-be1d-7313fe54e96a",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "data": {
161
+ "text/plain": [
162
+ "{'id': '000466761',\n",
163
+ " 'image': '00046/000466761.jpg',\n",
164
+ " 'conversations': [{'from': 'human',\n",
165
+ " 'value': '<image>\\nProvide a brief description of the given image.'},\n",
166
+ " {'from': 'gpt',\n",
167
+ " 'value': 'a clipboard and a pen with the words public health emergency next to it on a white table'}]}"
168
+ ]
169
+ },
170
+ "execution_count": 15,
171
+ "metadata": {},
172
+ "output_type": "execute_result"
173
+ }
174
+ ],
175
+ "source": [
176
+ "data[67]"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 16,
182
+ "id": "068313b6-6379-4ca2-892c-682634d3581e",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "data": {
187
+ "text/plain": [
188
+ "list"
189
+ ]
190
+ },
191
+ "execution_count": 16,
192
+ "metadata": {},
193
+ "output_type": "execute_result"
194
+ }
195
+ ],
196
+ "source": [
197
+ "type(data)"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 24,
203
+ "id": "9ec33b51-4a0b-4a1e-81f7-2fda7cddb25f",
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "sample_data = data[:200000]"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 25,
213
+ "id": "095685e5-40f1-4d84-8280-ef74fa56c5a2",
214
+ "metadata": {},
215
+ "outputs": [
216
+ {
217
+ "data": {
218
+ "text/plain": [
219
+ "200000"
220
+ ]
221
+ },
222
+ "execution_count": 25,
223
+ "metadata": {},
224
+ "output_type": "execute_result"
225
+ }
226
+ ],
227
+ "source": [
228
+ "len(sample_data)"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 26,
234
+ "id": "ffbad552-23fd-475f-8e9a-7118bcc4f51e",
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "with open(\"llava-phi/pretrain_data/blip_sample.json\", \"w\") as f:\n",
239
+ " json.dump(sample_data, f)"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 27,
245
+ "id": "69a05d25-6f3b-40c0-a3b5-e185ff526471",
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "with open(\"llava-phi/pretrain_data/blip_sample.json\", \"rb\") as f:\n",
250
+ " sample = json.load(f)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 28,
256
+ "id": "200eea06-dfd6-4b3a-bb91-82af7d363951",
257
+ "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "data": {
261
+ "text/plain": [
262
+ "200000"
263
+ ]
264
+ },
265
+ "execution_count": 28,
266
+ "metadata": {},
267
+ "output_type": "execute_result"
268
+ }
269
+ ],
270
+ "source": [
271
+ "len(sample)"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "id": "f86caa1e-edea-4a9c-934f-5420ede80d0d",
278
+ "metadata": {},
279
+ "outputs": [],
280
+ "source": []
281
+ }
282
+ ],
283
+ "metadata": {
284
+ "kernelspec": {
285
+ "display_name": "Python 3 (ipykernel)",
286
+ "language": "python",
287
+ "name": "python3"
288
+ },
289
+ "language_info": {
290
+ "codemirror_mode": {
291
+ "name": "ipython",
292
+ "version": 3
293
+ },
294
+ "file_extension": ".py",
295
+ "mimetype": "text/x-python",
296
+ "name": "python",
297
+ "nbconvert_exporter": "python",
298
+ "pygments_lexer": "ipython3",
299
+ "version": "3.10.12"
300
+ }
301
+ },
302
+ "nbformat": 4,
303
+ "nbformat_minor": 5
304
+ }
Experiments/whispher_exp.ipynb ADDED
@@ -0,0 +1,500 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 9,
6
+ "id": "bb4dd66b-0c17-48d4-9d34-f48cece2feb5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# !pip install soundfile\n",
11
+ "# !pip install librosa"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 1,
17
+ "id": "6e9386ea-4862-4f5b-a02f-d656e1a5ab9e",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
22
+ "from datasets import load_dataset"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "id": "914ab2b4-389d-4c48-8d1d-1250356646ac",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# load model and processor\n",
33
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
34
+ "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
35
+ "model.config.forced_decoder_ids = None\n",
36
+ "\n",
37
+ "# load dummy dataset and read audio files\n",
38
+ "ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
39
+ "sample = ds[0][\"audio\"]"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 3,
45
+ "id": "2b299bab-1228-48d9-a8a5-3d5b6c52162d",
46
+ "metadata": {},
47
+ "outputs": [
48
+ {
49
+ "data": {
50
+ "text/plain": [
51
+ "{'path': '/home/ravi.naik/.cache/huggingface/datasets/downloads/extracted/431c2c946d216530b2666a0e7ffa5ac3f5b3da89dd28858a9de6c78fae7caa4a/dev_clean/1272/128104/1272-128104-0000.flac',\n",
52
+ " 'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,\n",
53
+ " 0.0010376 ]),\n",
54
+ " 'sampling_rate': 16000}"
55
+ ]
56
+ },
57
+ "execution_count": 3,
58
+ "metadata": {},
59
+ "output_type": "execute_result"
60
+ }
61
+ ],
62
+ "source": [
63
+ "sample"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 4,
69
+ "id": "b7e570a1-cf5c-450c-a7b6-49b45a10d2df",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "input_features = processor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"], return_tensors=\"pt\").input_features "
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "id": "584e920b-a7fd-402d-95dd-3b9128cd34bb",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "# generate token ids\n",
84
+ "predicted_ids = model.generate(input_features)\n",
85
+ "# decode token ids to text\n",
86
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
87
+ "\n",
88
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 6,
94
+ "id": "b27ab660-861b-49d1-81f9-f51cb7f9d8d8",
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "data": {
99
+ "text/plain": [
100
+ "[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']"
101
+ ]
102
+ },
103
+ "execution_count": 6,
104
+ "metadata": {},
105
+ "output_type": "execute_result"
106
+ }
107
+ ],
108
+ "source": [
109
+ "transcription"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 3,
115
+ "id": "eca553b8-68f6-493d-b567-3d526b49ae1b",
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "import torch\n",
120
+ "from torch import nn"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 4,
126
+ "id": "c619a4cf-9068-4e4d-8139-e16d15345f4f",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 5,
136
+ "id": "47d5b1ff-ab0f-4d11-af64-d2fa2be39286",
137
+ "metadata": {},
138
+ "outputs": [
139
+ {
140
+ "name": "stderr",
141
+ "output_type": "stream",
142
+ "text": [
143
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
144
+ ]
145
+ }
146
+ ],
147
+ "source": [
148
+ "model_name = \"microsoft/phi-2\"\n",
149
+ "phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
150
+ "phi2_tokenizer.pad_token = phi2_tokenizer.eos_token"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 6,
156
+ "id": "0b36b3f0-db5b-4029-9072-0a53bcab315a",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "ename": "NameError",
161
+ "evalue": "name 'transcription' is not defined",
162
+ "output_type": "error",
163
+ "traceback": [
164
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
165
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
166
+ "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m tokens \u001b[38;5;241m=\u001b[39m phi2_tokenizer(\u001b[38;5;241m*\u001b[39m\u001b[43mtranscription\u001b[49m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, return_attention_mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
167
+ "\u001b[0;31mNameError\u001b[0m: name 'transcription' is not defined"
168
+ ]
169
+ }
170
+ ],
171
+ "source": [
172
+ "tokens = phi2_tokenizer(*transcription, return_tensors=\"pt\", return_attention_mask=False)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 22,
178
+ "id": "91f6d3d3-bb00-434f-a91e-6952375890d0",
179
+ "metadata": {},
180
+ "outputs": [
181
+ {
182
+ "data": {
183
+ "text/plain": [
184
+ "{'input_ids': tensor([[ 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262,\n",
185
+ " 3504, 6097, 290, 356, 389, 9675, 284, 7062, 465, 21443,\n",
186
+ " 13]])}"
187
+ ]
188
+ },
189
+ "execution_count": 22,
190
+ "metadata": {},
191
+ "output_type": "execute_result"
192
+ }
193
+ ],
194
+ "source": [
195
+ "tokens"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": 12,
201
+ "id": "533191d9-4b3b-417a-918d-6fe854f24b50",
202
+ "metadata": {},
203
+ "outputs": [
204
+ {
205
+ "name": "stderr",
206
+ "output_type": "stream",
207
+ "text": [
208
+ "A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
209
+ "- configuration_phi.py\n",
210
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
211
+ ]
212
+ },
213
+ {
214
+ "data": {
215
+ "application/vnd.jupyter.widget-view+json": {
216
+ "model_id": "2a65a119388b4cb4b123b532176e786e",
217
+ "version_major": 2,
218
+ "version_minor": 0
219
+ },
220
+ "text/plain": [
221
+ "modeling_phi.py: 0%| | 0.00/62.7k [00:00<?, ?B/s]"
222
+ ]
223
+ },
224
+ "metadata": {},
225
+ "output_type": "display_data"
226
+ },
227
+ {
228
+ "name": "stderr",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
232
+ "- modeling_phi.py\n",
233
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
234
+ ]
235
+ },
236
+ {
237
+ "data": {
238
+ "application/vnd.jupyter.widget-view+json": {
239
+ "model_id": "7183811844304c16b72d53fe11098a74",
240
+ "version_major": 2,
241
+ "version_minor": 0
242
+ },
243
+ "text/plain": [
244
+ "Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
245
+ ]
246
+ },
247
+ "metadata": {},
248
+ "output_type": "display_data"
249
+ },
250
+ {
251
+ "data": {
252
+ "application/vnd.jupyter.widget-view+json": {
253
+ "model_id": "3e78fe144e8f42139a4d7a1830dbf192",
254
+ "version_major": 2,
255
+ "version_minor": 0
256
+ },
257
+ "text/plain": [
258
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
259
+ ]
260
+ },
261
+ "metadata": {},
262
+ "output_type": "display_data"
263
+ }
264
+ ],
265
+ "source": [
266
+ "bnb_config = BitsAndBytesConfig(\n",
267
+ " load_in_4bit=True,\n",
268
+ " bnb_4bit_quant_type=\"nf4\",\n",
269
+ " bnb_4bit_compute_dtype=torch.float16,\n",
270
+ ")\n",
271
+ "\n",
272
+ "model = AutoModelForCausalLM.from_pretrained(\n",
273
+ " model_name,\n",
274
+ " quantization_config=bnb_config,\n",
275
+ " trust_remote_code=True,\n",
276
+ " device_map=\"cuda:0\"\n",
277
+ ")\n",
278
+ "model.config.use_cache = False"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 19,
284
+ "id": "155c054a-a00f-4ed5-bfff-1ad64889e7f1",
285
+ "metadata": {},
286
+ "outputs": [
287
+ {
288
+ "data": {
289
+ "text/plain": [
290
+ "[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.\\n']"
291
+ ]
292
+ },
293
+ "execution_count": 19,
294
+ "metadata": {},
295
+ "output_type": "execute_result"
296
+ }
297
+ ],
298
+ "source": [
299
+ "phi2_tokenizer.batch_decode(model.generate(**tokens))"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": 7,
305
+ "id": "04f940c9-586d-4937-ae31-cc0f96d33e92",
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "class AudioLanguageConnector:\n",
310
+ " def __init__(self):\n",
311
+ " model_name = \"microsoft/phi-2\"\n",
312
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
313
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
314
+ "\n",
315
+ " def __call__(self, text):\n",
316
+ " text = f\"<audio_start> {text} <audio_end>\"\n",
317
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
318
+ " return tokens\n",
319
+ " \n",
320
+ "\n",
321
+ "class WhisperWithProjection:\n",
322
+ " def __init__(self):\n",
323
+ " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
324
+ " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
325
+ " self.model.config.forced_decoder_ids = None\n",
326
+ " self.audio_language_connector = AudioLanguageConnector()\n",
327
+ " \n",
328
+ " def forward(self, audio):\n",
329
+ " input_features = self.processor(audio[\"array\"],\n",
330
+ " sampling_rate=audio[\"sampling_rate\"],\n",
331
+ " return_tensors=\"pt\").input_features\n",
332
+ " # generate token ids\n",
333
+ " predicted_ids = self.model.generate(input_features)\n",
334
+ " # decode token ids to text \n",
335
+ " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
336
+ "\n",
337
+ " audio_embeddings = self.audio_language_connector(transcription)\n",
338
+ " return audio_embeddings"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 8,
344
+ "id": "2b1f8f44-bfe6-413c-9e32-c38fa5517981",
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "class TextModality:\n",
349
+ " def __init__(self):\n",
350
+ " model_name = \"microsoft/phi-2\"\n",
351
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
352
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
353
+ "\n",
354
+ " def __call__(self, text):\n",
355
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
356
+ " return tokens"
357
+ ]
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "execution_count": 15,
362
+ "id": "21c51648-abb6-4bbd-b4c1-509967a69337",
363
+ "metadata": {},
364
+ "outputs": [],
365
+ "source": [
366
+ "class MultiModalPhi2:\n",
367
+ " def __init__(self):\n",
368
+ " self.text_modality = TextModality()\n",
369
+ " self.whisper_w_proj = WhisperWithProjection()\n",
370
+ " self.llm = self.load_llm()\n",
371
+ "\n",
372
+ " def load_llm(self):\n",
373
+ " bnb_config = BitsAndBytesConfig(\n",
374
+ " load_in_4bit=True,\n",
375
+ " bnb_4bit_quant_type=\"nf4\",\n",
376
+ " bnb_4bit_compute_dtype=torch.float16)\n",
377
+ " \n",
378
+ " model = AutoModelForCausalLM.from_pretrained(\n",
379
+ " model_name,\n",
380
+ " quantization_config=bnb_config,\n",
381
+ " trust_remote_code=True,\n",
382
+ " device_map=\"cuda:0\"\n",
383
+ " )\n",
384
+ " model.config.use_cache = False\n",
385
+ " return model\n",
386
+ "\n",
387
+ " def generate(self, audio, text):\n",
388
+ " text_embeddings = self.text_modality(text)\n",
389
+ " audio_embeddings = self.whisper_w_proj.forward(audio)\n",
390
+ " inputs = torch.concat([text_embeddings[\"input_ids\"], audio_embeddings[\"input_ids\"]], dim=1)\n",
391
+ " \n",
392
+ " # outputs = self.llm.generate(inputs, max_length=200)\n",
393
+ " outputs = self.llm(inputs)\n",
394
+ " return outputs\n",
395
+ " \n",
396
+ " # text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
397
+ " # print(text)"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 16,
403
+ "id": "472a00cb-bae9-4c09-a0ef-bc57881b5e2c",
404
+ "metadata": {},
405
+ "outputs": [
406
+ {
407
+ "name": "stderr",
408
+ "output_type": "stream",
409
+ "text": [
410
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
411
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
412
+ ]
413
+ },
414
+ {
415
+ "data": {
416
+ "application/vnd.jupyter.widget-view+json": {
417
+ "model_id": "2236e6b1e26d444fa3d48181ba1a6cf9",
418
+ "version_major": 2,
419
+ "version_minor": 0
420
+ },
421
+ "text/plain": [
422
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
423
+ ]
424
+ },
425
+ "metadata": {},
426
+ "output_type": "display_data"
427
+ }
428
+ ],
429
+ "source": [
430
+ "multi_modal_phi = MultiModalPhi2()"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": 17,
436
+ "id": "c350f2d3-0929-4c46-b63d-ff92dea437f3",
437
+ "metadata": {},
438
+ "outputs": [
439
+ {
440
+ "data": {
441
+ "text/plain": [
442
+ "CausalLMOutputWithPast(loss={'logits': tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
443
+ " [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
444
+ " [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
445
+ " ...,\n",
446
+ " [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
447
+ " [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
448
+ " [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
449
+ " grad_fn=<ToCopyBackward0>)}, logits=tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
450
+ " [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
451
+ " [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
452
+ " ...,\n",
453
+ " [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
454
+ " [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
455
+ " [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
456
+ " grad_fn=<ToCopyBackward0>), past_key_values=None, hidden_states=None, attentions=None)"
457
+ ]
458
+ },
459
+ "execution_count": 17,
460
+ "metadata": {},
461
+ "output_type": "execute_result"
462
+ }
463
+ ],
464
+ "source": [
465
+ "audio = sample\n",
466
+ "text = \"explain about the audio\"\n",
467
+ "multi_modal_phi.generate(audio, text)"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": null,
473
+ "id": "46aa9c66-a5bb-4760-8895-92673f49345f",
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": []
477
+ }
478
+ ],
479
+ "metadata": {
480
+ "kernelspec": {
481
+ "display_name": "Python 3 (ipykernel)",
482
+ "language": "python",
483
+ "name": "python3"
484
+ },
485
+ "language_info": {
486
+ "codemirror_mode": {
487
+ "name": "ipython",
488
+ "version": 3
489
+ },
490
+ "file_extension": ".py",
491
+ "mimetype": "text/x-python",
492
+ "name": "python",
493
+ "nbconvert_exporter": "python",
494
+ "pygments_lexer": "ipython3",
495
+ "version": "3.10.12"
496
+ }
497
+ },
498
+ "nbformat": 4,
499
+ "nbformat_minor": 5
500
+ }
inference/__init__.py ADDED
File without changes
inference/conversation.py ADDED
@@ -0,0 +1,224 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+
14
+
15
+ @dataclasses.dataclass
16
+ class Conversation:
17
+ """A class that keeps all conversation history."""
18
+ system: str
19
+ roles: List[str]
20
+ messages: List[List[str]]
21
+ offset: int
22
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23
+ sep: str = "###"
24
+ sep2: str = None
25
+ version: str = "Unknown"
26
+
27
+ skip_next: bool = False
28
+
29
+ def get_prompt(self):
30
+ messages = self.messages
31
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
32
+ messages = self.messages.copy()
33
+ init_role, init_msg = messages[0].copy()
34
+ init_msg = init_msg[0].replace("<image>", "").strip()
35
+ if 'mmtag' in self.version:
36
+ messages[0] = (init_role, init_msg)
37
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
38
+ messages.insert(1, (self.roles[1], "Received."))
39
+ else:
40
+ messages[0] = (init_role, "<image>\n" + init_msg)
41
+
42
+ if self.sep_style == SeparatorStyle.SINGLE:
43
+ ret = self.system + self.sep
44
+ for role, message in messages:
45
+ if message:
46
+ if type(message) is tuple:
47
+ message, _, _ = message
48
+ ret += role + ": " + message + self.sep
49
+ else:
50
+ ret += role + ":"
51
+ elif self.sep_style == SeparatorStyle.TWO:
52
+ seps = [self.sep, self.sep2]
53
+ ret = self.system + seps[0]
54
+ for i, (role, message) in enumerate(messages):
55
+ if message:
56
+ if type(message) is tuple:
57
+ message, _, _ = message
58
+ ret += role + ": " + message + seps[i % 2]
59
+ else:
60
+ ret += role + ":"
61
+ elif self.sep_style == SeparatorStyle.PLAIN:
62
+ seps = [self.sep, self.sep2]
63
+ ret = self.system
64
+ for i, (role, message) in enumerate(messages):
65
+ if message:
66
+ if type(message) is tuple:
67
+ message, _, _ = message
68
+ ret += message + seps[i % 2]
69
+ else:
70
+ ret += ""
71
+ else:
72
+ raise ValueError(f"Invalid style: {self.sep_style}")
73
+
74
+ return ret
75
+
76
+ def append_message(self, role, message):
77
+ self.messages.append([role, message])
78
+
79
+ def get_images(self, return_pil=False):
80
+ images = []
81
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
82
+ if i % 2 == 0:
83
+ if type(msg) is tuple:
84
+ import base64
85
+ from io import BytesIO
86
+ from PIL import Image
87
+ msg, image, image_process_mode = msg
88
+ if image_process_mode == "Pad":
89
+ def expand2square(pil_img, background_color=(122, 116, 104)):
90
+ width, height = pil_img.size
91
+ if width == height:
92
+ return pil_img
93
+ elif width > height:
94
+ result = Image.new(pil_img.mode, (width, width), background_color)
95
+ result.paste(pil_img, (0, (width - height) // 2))
96
+ return result
97
+ else:
98
+ result = Image.new(pil_img.mode, (height, height), background_color)
99
+ result.paste(pil_img, ((height - width) // 2, 0))
100
+ return result
101
+ image = expand2square(image)
102
+ elif image_process_mode in ["Default", "Crop"]:
103
+ pass
104
+ elif image_process_mode == "Resize":
105
+ image = image.resize((336, 336))
106
+ else:
107
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
108
+ max_hw, min_hw = max(image.size), min(image.size)
109
+ aspect_ratio = max_hw / min_hw
110
+ max_len, min_len = 800, 400
111
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
112
+ longest_edge = int(shortest_edge * aspect_ratio)
113
+ W, H = image.size
114
+ if longest_edge != max(image.size):
115
+ if H > W:
116
+ H, W = longest_edge, shortest_edge
117
+ else:
118
+ H, W = shortest_edge, longest_edge
119
+ image = image.resize((W, H))
120
+ if return_pil:
121
+ images.append(image)
122
+ else:
123
+ buffered = BytesIO()
124
+ image.save(buffered, format="PNG")
125
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
126
+ images.append(img_b64_str)
127
+ return images
128
+
129
+ def to_gradio_chatbot(self):
130
+ ret = []
131
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
132
+ if i % 2 == 0:
133
+ if type(msg) is tuple:
134
+ import base64
135
+ from io import BytesIO
136
+ msg, image, image_process_mode = msg
137
+ max_hw, min_hw = max(image.size), min(image.size)
138
+ aspect_ratio = max_hw / min_hw
139
+ max_len, min_len = 800, 400
140
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
141
+ longest_edge = int(shortest_edge * aspect_ratio)
142
+ W, H = image.size
143
+ if H > W:
144
+ H, W = longest_edge, shortest_edge
145
+ else:
146
+ H, W = shortest_edge, longest_edge
147
+ image = image.resize((W, H))
148
+ buffered = BytesIO()
149
+ image.save(buffered, format="JPEG")
150
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
151
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
152
+ msg = img_str + msg.replace('<image>', '').strip()
153
+ ret.append([msg, None])
154
+ else:
155
+ ret.append([msg, None])
156
+ else:
157
+ ret[-1][-1] = msg
158
+ return ret
159
+
160
+ def copy(self):
161
+ return Conversation(
162
+ system=self.system,
163
+ roles=self.roles,
164
+ messages=[[x, y] for x, y in self.messages],
165
+ offset=self.offset,
166
+ sep_style=self.sep_style,
167
+ sep=self.sep,
168
+ sep2=self.sep2,
169
+ version=self.version)
170
+
171
+ def dict(self):
172
+ if len(self.get_images()) > 0:
173
+ return {
174
+ "system": self.system,
175
+ "roles": self.roles,
176
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
177
+ "offset": self.offset,
178
+ "sep": self.sep,
179
+ "sep2": self.sep2,
180
+ }
181
+ return {
182
+ "system": self.system,
183
+ "roles": self.roles,
184
+ "messages": self.messages,
185
+ "offset": self.offset,
186
+ "sep": self.sep,
187
+ "sep2": self.sep2,
188
+ }
189
+
190
+
191
+ conv_phi_v0 = Conversation(
192
+ system="A chat between a curious user and an artificial intelligence assistant. "
193
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
194
+ roles=("USER", "ASSISTANT"),
195
+ version="v0",
196
+ messages=(),
197
+ offset=0,
198
+ sep_style=SeparatorStyle.TWO,
199
+ sep=" ",
200
+ sep2="<|endoftext|>",
201
+ )
202
+
203
+ conv_llava_plain = Conversation(
204
+ system="",
205
+ roles=("", ""),
206
+ messages=(
207
+ ),
208
+ offset=0,
209
+ sep_style=SeparatorStyle.PLAIN,
210
+ sep="\n",
211
+ )
212
+
213
+ default_conversation = conv_phi_v0
214
+ conv_templates = {
215
+ "default": conv_phi_v0,
216
+ "v0": conv_phi_v0,
217
+ "phi-2_v0": conv_phi_v0,
218
+
219
+ "plain": conv_llava_plain,
220
+ }
221
+
222
+
223
+ if __name__ == "__main__":
224
+ print(default_conversation.get_prompt())
inference/inference.ipynb ADDED
@@ -0,0 +1,369 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "cdad6b21-030a-40d3-9b31-a229e5b6196d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer, AutoConfig, CLIPImageProcessor"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "id": "1f832710-0e8c-42ec-b581-1b15fd2a6acc",
18
+ "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "[2024-01-25 14:31:58,511] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
25
+ ]
26
+ }
27
+ ],
28
+ "source": [
29
+ "from model import LlavaPhiForCausalLM"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 3,
35
+ "id": "9e68f1d4-1ae3-4d45-b818-4600218d2215",
36
+ "metadata": {},
37
+ "outputs": [
38
+ {
39
+ "data": {
40
+ "application/vnd.jupyter.widget-view+json": {
41
+ "model_id": "e5e13e666e3a43d4ad26cc70904abee8",
42
+ "version_major": 2,
43
+ "version_minor": 0
44
+ },
45
+ "text/plain": [
46
+ "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
47
+ ]
48
+ },
49
+ "metadata": {},
50
+ "output_type": "display_data"
51
+ }
52
+ ],
53
+ "source": [
54
+ "model_name = \"RaviNaik/Llava-Phi2\"\n",
55
+ "model = LlavaPhiForCausalLM.from_pretrained(model_name)"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 4,
61
+ "id": "49edfa0d-e08a-4d3c-a1d6-34068b122419",
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stderr",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "id": "dcec20cd-d946-42d7-8e10-c198cd49b486",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "image_processor = CLIPImageProcessor.from_pretrained(model_name)\n",
84
+ "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
85
+ "mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 6,
91
+ "id": "443c13c4-b7e6-4bc5-b6c7-c577bd4708f6",
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "if mm_use_im_patch_token:\n",
96
+ " tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
97
+ "if mm_use_im_start_end:\n",
98
+ " tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
99
+ " \n",
100
+ "if hasattr(model.config, \"max_sequence_length\"):\n",
101
+ " context_len = model.config.max_sequence_length\n",
102
+ "else:\n",
103
+ " context_len = 2048"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 7,
109
+ "id": "d8caee43-0d2a-46d4-bdbc-2cfc7dec9e52",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": 8,
119
+ "id": "3acea526-d8ae-4eb6-8dfc-4ea72651b547",
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "class AudioLanguageConnector:\n",
124
+ " def __init__(self, projection_dim):\n",
125
+ " model_name = \"microsoft/phi-2\"\n",
126
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
127
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
128
+ " self.phi2_tokenizer.max_length = projection_dim\n",
129
+ "\n",
130
+ " def __call__(self, text):\n",
131
+ " text = f\"<audio_start> {text} <audio_end>\"\n",
132
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
133
+ " return tokens\n",
134
+ " \n",
135
+ "\n",
136
+ "class WhisperWithProjection:\n",
137
+ " def __init__(self, projection_dim, device):\n",
138
+ " self.device = device\n",
139
+ " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
140
+ " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
141
+ " self.model.config.forced_decoder_ids = None\n",
142
+ " # self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
143
+ " \n",
144
+ " def __call__(self, audio):\n",
145
+ " input_features = self.processor(audio[\"array\"],\n",
146
+ " sampling_rate=audio[\"sampling_rate\"],\n",
147
+ " return_tensors=\"pt\").input_features\n",
148
+ " # generate token ids\n",
149
+ " predicted_ids = self.model.generate(input_features.to(self.device))\n",
150
+ " # decode token ids to text \n",
151
+ " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
152
+ "\n",
153
+ " # audio_embeddings = self.audio_language_connector(transcription)\n",
154
+ " return transcription"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": 10,
160
+ "id": "a2757c91-2ec1-4fe7-9216-03740bf80061",
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "IGNORE_INDEX = -100\n",
165
+ "IMAGE_TOKEN_INDEX = -200\n",
166
+ "DEFAULT_IMAGE_TOKEN = \"<image>\"\n",
167
+ "DEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\n",
168
+ "DEFAULT_IM_START_TOKEN = \"<im_start>\"\n",
169
+ "DEFAULT_IM_END_TOKEN = \"<im_end>\"\n",
170
+ "\n",
171
+ "from conversation import conv_templates, SeparatorStyle\n",
172
+ "\n",
173
+ "class MultiModalPhi2:\n",
174
+ " def __init__(self, modelname_or_path=\"RaviNaik/Llava-Phi2\",\n",
175
+ " temperature=0.2,\n",
176
+ " max_new_tokens=1024,\n",
177
+ " device=\"cuda:0\"):\n",
178
+ " self.model_name = modelname_or_path\n",
179
+ " self.temperature = temperature\n",
180
+ " self.max_new_tokens = max_new_tokens\n",
181
+ " self.device = device\n",
182
+ " self.disable_torch_init()\n",
183
+ " self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n",
184
+ " self.load_pretrained_model()\n",
185
+ " \n",
186
+ " def disable_torch_init(self):\n",
187
+ " \"\"\"\n",
188
+ " Disable the redundant torch default initialization to accelerate model creation.\n",
189
+ " \"\"\"\n",
190
+ " setattr(torch.nn.Linear, \"reset_parameters\", lambda self: None)\n",
191
+ " setattr(torch.nn.LayerNorm, \"reset_parameters\", lambda self: None)\n",
192
+ " \n",
193
+ " def load_pretrained_model(self):\n",
194
+ " self.model = LlavaPhiForCausalLM.from_pretrained(self.model_name, device_map=self.device)\n",
195
+ " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
196
+ " self.image_processor = CLIPImageProcessor.from_pretrained(self.model_name)\n",
197
+ " mm_use_im_start_end = getattr(self.model.config, \"mm_use_im_start_end\", False)\n",
198
+ " mm_use_im_patch_token = getattr(self.model.config, \"mm_use_im_patch_token\", True)\n",
199
+ " if mm_use_im_patch_token:\n",
200
+ " self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
201
+ " if mm_use_im_start_end:\n",
202
+ " self.tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
203
+ " \n",
204
+ " def tokenizer_image_token(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):\n",
205
+ " prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]\n",
206
+ " \n",
207
+ " def insert_separator(X, sep):\n",
208
+ " return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]\n",
209
+ " \n",
210
+ " input_ids = []\n",
211
+ " offset = 0\n",
212
+ " if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:\n",
213
+ " offset = 1\n",
214
+ " input_ids.append(prompt_chunks[0][0])\n",
215
+ " for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):\n",
216
+ " input_ids.extend(x[offset:])\n",
217
+ " \n",
218
+ " if return_tensors is not None:\n",
219
+ " if return_tensors == 'pt':\n",
220
+ " return torch.tensor(input_ids, dtype=torch.long)\n",
221
+ " raise ValueError(f'Unsupported tensor type: {return_tensors}')\n",
222
+ " return input_ids\n",
223
+ " \n",
224
+ " def __call__(self, text, audio, image):\n",
225
+ " qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n",
226
+ " conv = conv_templates[\"phi-2_v0\"].copy()\n",
227
+ " conv.append_message(conv.roles[0], qs)\n",
228
+ " conv.append_message(conv.roles[1], None)\n",
229
+ " prompt = conv.get_prompt()\n",
230
+ "\n",
231
+ " image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].to(self.device)\n",
232
+ " \n",
233
+ " input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)\n",
234
+ " if audio is not None:\n",
235
+ " audio_transcript = self.whisper_w_proj(audio)\n",
236
+ " audio_embed = self.tokenizer(audio_transcript, return_tensors='pt')[\"input_ids\"]\n",
237
+ " input_ids = torch.concat([input_ids, audio_embed], dim=1)\n",
238
+ " input_ids = input_ids.to(self.device)\n",
239
+ " \n",
240
+ " stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
241
+ "\n",
242
+ " with torch.inference_mode():\n",
243
+ " output_ids = self.model.generate(\n",
244
+ " input_ids,\n",
245
+ " images=image_tensor,\n",
246
+ " do_sample=True,\n",
247
+ " temperature=self.temperature,\n",
248
+ " max_new_tokens=self.max_new_tokens,\n",
249
+ " eos_token_id=self.tokenizer.eos_token_id, # End of sequence token\n",
250
+ " pad_token_id=self.tokenizer.eos_token_id, # Pad token\n",
251
+ " use_cache=True,\n",
252
+ " )\n",
253
+ "\n",
254
+ " input_token_len = input_ids.shape[1]\n",
255
+ " n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
256
+ " if n_diff_input_output > 0:\n",
257
+ " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
258
+ " outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
259
+ " outputs = outputs.strip()\n",
260
+ " if outputs.endswith(stop_str):\n",
261
+ " outputs = outputs[:-len(stop_str)]\n",
262
+ " outputs = outputs.strip()\n",
263
+ " return outputs"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": 11,
269
+ "id": "cc47e6a0-3544-4a60-930f-ccae87ef945a",
270
+ "metadata": {},
271
+ "outputs": [
272
+ {
273
+ "data": {
274
+ "application/vnd.jupyter.widget-view+json": {
275
+ "model_id": "9ef56077307d4cef907e25b092061611",
276
+ "version_major": 2,
277
+ "version_minor": 0
278
+ },
279
+ "text/plain": [
280
+ "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
281
+ ]
282
+ },
283
+ "metadata": {},
284
+ "output_type": "display_data"
285
+ },
286
+ {
287
+ "name": "stderr",
288
+ "output_type": "stream",
289
+ "text": [
290
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
291
+ ]
292
+ }
293
+ ],
294
+ "source": [
295
+ "multimodal_phi2 = MultiModalPhi2()"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 12,
301
+ "id": "cb8aca1b-7d75-45e7-b5a4-71d151f792e1",
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": [
305
+ "from PIL import Image\n",
306
+ "import requests\n",
307
+ "\n",
308
+ "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
309
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
310
+ "\n",
311
+ "from datasets import load_dataset\n",
312
+ "audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
313
+ "audio = audio_ds[0][\"audio\"]\n",
314
+ "\n",
315
+ "text = \"tell me about the image\""
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": 14,
321
+ "id": "6767efc6-be4f-44d3-84ff-34db57d9f940",
322
+ "metadata": {},
323
+ "outputs": [
324
+ {
325
+ "data": {
326
+ "text/plain": [
327
+ "'In the image, there is a Chinese writing on a pole in a foreign language. This suggests that the image was taken in a foreign country, possibly in a foreign country. The sign is in a foreign language, which might be in a foreign language. The sign is written in Japanese, which is a common language in Japan. The sign is also written in two different languages, which suggests that it is written in a language that is not in the native language.'"
328
+ ]
329
+ },
330
+ "execution_count": 14,
331
+ "metadata": {},
332
+ "output_type": "execute_result"
333
+ }
334
+ ],
335
+ "source": [
336
+ "multimodal_phi2(text, None, image)"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "id": "0bdd0b8a-709b-4c82-ac1d-dc746d3a0748",
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": []
346
+ }
347
+ ],
348
+ "metadata": {
349
+ "kernelspec": {
350
+ "display_name": "Python 3 (ipykernel)",
351
+ "language": "python",
352
+ "name": "python3"
353
+ },
354
+ "language_info": {
355
+ "codemirror_mode": {
356
+ "name": "ipython",
357
+ "version": 3
358
+ },
359
+ "file_extension": ".py",
360
+ "mimetype": "text/x-python",
361
+ "name": "python",
362
+ "nbconvert_exporter": "python",
363
+ "pygments_lexer": "ipython3",
364
+ "version": "3.10.12"
365
+ }
366
+ },
367
+ "nbformat": 4,
368
+ "nbformat_minor": 5
369
+ }
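The cells above only exercise the text+image path (audio is passed as None). A minimal, hedged sketch of also driving the audio path is shown below; it assumes the WhisperWithProjection from inference/main.py, which reads audio from disk with soundfile, so the in-memory librispeech sample is first written to a WAV file (sample.wav is a hypothetical name, reusing the audio, text and image variables from the cells above).

import soundfile as sf

# Persist the in-memory dataset sample so sf.read() inside WhisperWithProjection can load it.
sf.write("sample.wav", audio["array"], audio["sampling_rate"])

# text + audio, no image
print(multimodal_phi2("summarize the audio clip", "sample.wav", None))

# text + audio + image together
print(multimodal_phi2(text, "sample.wav", image))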
inference/main.py ADDED
@@ -0,0 +1,226 @@
1
+ import soundfile as sf
2
+ import librosa
3
+ import torch
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ CLIPImageProcessor,
7
+ WhisperProcessor,
8
+ WhisperForConditionalGeneration,
9
+ )
10
+
11
+ from .model import LlavaPhiForCausalLM
12
+ from .conversation import conv_templates, SeparatorStyle
13
+
14
+ IGNORE_INDEX = -100
15
+ IMAGE_TOKEN_INDEX = -200
16
+ DEFAULT_IMAGE_TOKEN = "<image>"
17
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
18
+ DEFAULT_IM_START_TOKEN = "<im_start>"
19
+ DEFAULT_IM_END_TOKEN = "<im_end>"
20
+
21
+
22
+ class AudioLanguageConnector:
23
+ def __init__(self, projection_dim):
24
+ model_name = "microsoft/phi-2"
25
+ self.phi2_tokenizer = AutoTokenizer.from_pretrained(
26
+ model_name, trust_remote_code=True
27
+ )
28
+ self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token
29
+ self.phi2_tokenizer.max_length = projection_dim
30
+
31
+ def __call__(self, text):
32
+ text = f"<audio_start> {text} <audio_end>"
33
+ tokens = self.phi2_tokenizer(
34
+ text, return_tensors="pt", return_attention_mask=False
35
+ )
36
+ return tokens
37
+
38
+
39
+ class WhisperWithProjection:
40
+ def __init__(self, projection_dim, device):
41
+ self.device = device
42
+ self.processor = WhisperProcessor.from_pretrained(
43
+ "openai/whisper-tiny", device_map=device
44
+ )
45
+ self.model = WhisperForConditionalGeneration.from_pretrained(
46
+ "openai/whisper-tiny", device_map=device
47
+ )
48
+ self.model.config.forced_decoder_ids = None
49
+ # self.audio_language_connector = AudioLanguageConnector(projection_dim)
50
+
51
+ def __call__(self, audio):
52
+ array, sampling_rate = sf.read(audio)
53
+ resampled_array = librosa.resample(
54
+ array,
55
+ orig_sr=sampling_rate,
56
+ target_sr=16000,
57
+ )
58
+ input_features = self.processor(
59
+ resampled_array, sampling_rate=16000, return_tensors="pt"
60
+ ).input_features
61
+ # generate token ids
62
+ predicted_ids = self.model.generate(input_features.to(self.device))
63
+ # decode token ids to text
64
+ transcription = self.processor.batch_decode(
65
+ predicted_ids, skip_special_tokens=True
66
+ )
67
+
68
+ # audio_embeddings = self.audio_language_connector(transcription)
69
+ return transcription
70
+
71
+
72
+ class MultiModalPhi2:
73
+ def __init__(
74
+ self,
75
+ modelname_or_path="RaviNaik/Llava-Phi2",
76
+ temperature=0.2,
77
+ max_new_tokens=1024,
78
+ device="cuda:0",
79
+ ):
80
+ self.model_name = modelname_or_path
81
+ self.temperature = temperature
82
+ self.max_new_tokens = max_new_tokens
83
+ self.device = device
84
+ self.disable_torch_init()
85
+ self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)
86
+ self.load_pretrained_model()
87
+
88
+ def disable_torch_init(self):
89
+ """
90
+ Disable the redundant torch default initialization to accelerate model creation.
91
+ """
92
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
93
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
94
+
95
+ def load_pretrained_model(self):
96
+ self.model = LlavaPhiForCausalLM.from_pretrained(
97
+ self.model_name, device_map=self.device
98
+ )
99
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
100
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.model_name)
101
+ mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
102
+ mm_use_im_patch_token = getattr(
103
+ self.model.config, "mm_use_im_patch_token", True
104
+ )
105
+ if mm_use_im_patch_token:
106
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
107
+ if mm_use_im_start_end:
108
+ self.tokenizer.add_tokens(
109
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
110
+ )
111
+
112
+ def tokenizer_image_token(
113
+ self,
114
+ prompt,
115
+ tokenizer,
116
+ image_token_index=IMAGE_TOKEN_INDEX,
117
+ return_tensors=None,
118
+ ):
119
+ prompt_chunks = [
120
+ tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
121
+ ]
122
+
123
+ def insert_separator(X, sep):
124
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
125
+
126
+ input_ids = []
127
+ offset = 0
128
+ if (
129
+ len(prompt_chunks) > 0
130
+ and len(prompt_chunks[0]) > 0
131
+ and prompt_chunks[0][0] == tokenizer.bos_token_id
132
+ ):
133
+ offset = 1
134
+ input_ids.append(prompt_chunks[0][0])
135
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
136
+ input_ids.extend(x[offset:])
137
+
138
+ if return_tensors is not None:
139
+ if return_tensors == "pt":
140
+ return torch.tensor(input_ids, dtype=torch.long)
141
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
142
+ return input_ids
143
+
144
+ def __call__(self, text, audio, image):
145
+ if text is None:
146
+ text = ""
147
+ if image is not None:
148
+ qs = (
149
+ DEFAULT_IM_START_TOKEN
150
+ + DEFAULT_IMAGE_TOKEN
151
+ + DEFAULT_IM_END_TOKEN
152
+ + "\n"
153
+ + text
154
+ )
155
+ conv = conv_templates["phi-2_v0"].copy()
156
+ conv.append_message(conv.roles[0], qs)
157
+ conv.append_message(conv.roles[1], None)
158
+ prompt = conv.get_prompt()
159
+
160
+ input_ids = self.tokenizer_image_token(
161
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
162
+ ).unsqueeze(0)
163
+
164
+ image_tensor = self.image_processor.preprocess(image, return_tensors="pt")[
165
+ "pixel_values"
166
+ ].to(self.device)
167
+ else:
168
+ qs = text
169
+ conv = conv_templates["phi-2_v0"].copy()
170
+ conv.append_message(conv.roles[0], qs)
171
+ conv.append_message(conv.roles[1], None)
172
+ prompt = conv.get_prompt()
173
+
174
+ input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
175
+
176
+ image_tensor = None
177
+
178
+ if audio is not None:
179
+ audio_transcript = self.whisper_w_proj(audio)
180
+ audio_embed = self.tokenizer(audio_transcript, return_tensors="pt")[
181
+ "input_ids"
182
+ ]
183
+ input_ids = torch.concat([input_ids, audio_embed], dim=1)
184
+ input_ids = input_ids.to(self.device)
185
+
186
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
187
+
188
+ with torch.inference_mode():
189
+ if image is not None:
190
+ output_ids = self.model.generate(
191
+ input_ids,
192
+ images=image_tensor,
193
+ do_sample=True,
194
+ temperature=self.temperature,
195
+ max_new_tokens=self.max_new_tokens,
196
+ eos_token_id=self.tokenizer.eos_token_id, # End of sequence token
197
+ pad_token_id=self.tokenizer.eos_token_id, # Pad token
198
+ use_cache=True,
199
+ )
200
+ else:
201
+ output_ids = self.model.generate(
202
+ input_ids,
203
+ do_sample=True,
204
+ temperature=self.temperature,
205
+ max_new_tokens=self.max_new_tokens,
206
+ eos_token_id=self.tokenizer.eos_token_id, # End of sequence token
207
+ pad_token_id=self.tokenizer.eos_token_id, # Pad token
208
+ use_cache=True,
209
+ )
210
+
211
+ input_token_len = input_ids.shape[1]
212
+ n_diff_input_output = (
213
+ (input_ids != output_ids[:, :input_token_len]).sum().item()
214
+ )
215
+ if n_diff_input_output > 0:
216
+ print(
217
+ f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
218
+ )
219
+ outputs = self.tokenizer.batch_decode(
220
+ output_ids[:, input_token_len:], skip_special_tokens=True
221
+ )[0]
222
+ outputs = outputs.strip()
223
+ if outputs.endswith(stop_str):
224
+ outputs = outputs[: -len(stop_str)]
225
+ outputs = outputs.strip()
226
+ return outputs
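Taken together, a minimal usage sketch of this module (hedged: it assumes the inference package is importable from the repo root, a CUDA device, the default RaviNaik/Llava-Phi2 checkpoint being reachable, and hypothetical local files sample.jpg / sample.wav):

from PIL import Image

from inference.main import MultiModalPhi2

mm = MultiModalPhi2(device="cuda:0")                    # loads RaviNaik/Llava-Phi2 by default

image = Image.open("sample.jpg")                        # hypothetical local image
print(mm("describe the image", None, image))            # text + image, no audio

print(mm("what is being said?", "sample.wav", None))    # text + audio, no image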
inference/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .language_model.llava_phi import LlavaPhiForCausalLM
2
+ from .language_model.configuration_llava_phi import LlavaPhiConfig, LlavaPhiVisionConfig, ProjectorConfig
inference/model/builder.py ADDED
@@ -0,0 +1,180 @@
1
+ import os
2
+ import warnings
3
+
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForCausalLM,
7
+ AutoConfig,
8
+ BitsAndBytesConfig,
9
+ CLIPImageProcessor,
10
+ )
11
+ import torch
12
+ from .language_model.llava_phi import LlavaPhiForCausalLM
13
+ from .language_model.configuration_llava_phi import LlavaPhiConfig
14
+
15
+ IGNORE_INDEX = -100
16
+ IMAGE_TOKEN_INDEX = -200
17
+ DEFAULT_IMAGE_TOKEN = "<image>"
18
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
19
+ DEFAULT_IM_START_TOKEN = "<im_start>"
20
+ DEFAULT_IM_END_TOKEN = "<im_end>"
21
+
22
+
23
+ def load_pretrained_model(
24
+ model_path,
25
+ model_base,
26
+ model_name,
27
+ load_8bit=False,
28
+ load_4bit=False,
29
+ device_map="cuda",
30
+ device="cuda",
31
+ ):
32
+ kwargs = {"device_map": device_map}
33
+ if load_8bit:
34
+ kwargs["load_in_8bit"] = True
35
+ elif load_4bit:
36
+ kwargs["load_in_4bit"] = True
37
+ kwargs["quantization_config"] = BitsAndBytesConfig(
38
+ load_in_4bit=True,
39
+ bnb_4bit_compute_dtype=torch.float16,
40
+ bnb_4bit_use_double_quant=True,
41
+ bnb_4bit_quant_type="nf4",
42
+ )
43
+ # else: # TODO: after fine-tuning LLava-Phi, load the model weights with fp16 will pose nan
44
+ # kwargs['torch_dtype'] = torch.float16
45
+
46
+ if "phi" in model_name.lower():
47
+ # Load LLaVA-Phi model
48
+ if "lora" in model_name.lower() and model_base is None:
49
+ warnings.warn(
50
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument."
51
+ )
52
+ if "lora" in model_name.lower() and model_base is not None:
53
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
54
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
55
+ print("Loading LLaVA-Phi from base model...")
56
+ model = LlavaPhiForCausalLM.from_pretrained(
57
+ model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
58
+ )
59
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
60
+ if model.lm_head.weight.shape[0] != token_num:
61
+ model.lm_head.weight = torch.nn.Parameter(
62
+ torch.empty(
63
+ token_num, token_dim, device=model.device, dtype=model.dtype
64
+ )
65
+ )
66
+ model.model.embed_tokens.weight = torch.nn.Parameter(
67
+ torch.empty(
68
+ token_num, token_dim, device=model.device, dtype=model.dtype
69
+ )
70
+ )
71
+
72
+ print("Loading additional LLaVA-Phi weights...")
73
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
74
+ non_lora_trainables = torch.load(
75
+ os.path.join(model_path, "non_lora_trainables.bin"),
76
+ map_location="cpu",
77
+ )
78
+ else:
79
+ # this is probably from HF Hub
80
+ from huggingface_hub import hf_hub_download
81
+
82
+ def load_from_hf(repo_id, filename, subfolder=None):
83
+ cache_file = hf_hub_download(
84
+ repo_id=repo_id, filename=filename, subfolder=subfolder
85
+ )
86
+ return torch.load(cache_file, map_location="cpu")
87
+
88
+ non_lora_trainables = load_from_hf(
89
+ model_path, "non_lora_trainables.bin"
90
+ )
91
+ non_lora_trainables = {
92
+ (k[11:] if k.startswith("base_model.") else k): v
93
+ for k, v in non_lora_trainables.items()
94
+ }
95
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
96
+ non_lora_trainables = {
97
+ (k[6:] if k.startswith("model.") else k): v
98
+ for k, v in non_lora_trainables.items()
99
+ }
100
+ model.load_state_dict(non_lora_trainables, strict=False)
101
+
102
+ from peft import PeftModel
103
+
104
+ print("Loading LoRA weights...")
105
+ model = PeftModel.from_pretrained(model, model_path)
106
+ print("Merging LoRA weights...")
107
+ model = model.merge_and_unload()
108
+ print("Model is loaded...")
109
+ elif model_base is not None:
110
+ # this may be mm projector only
111
+ print("Loading LLaVA-Phi from base model...")
112
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
113
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
114
+ model = LlavaPhiForCausalLM.from_pretrained(
115
+ model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
116
+ )
117
+
118
+ mm_projector_weights = torch.load(
119
+ os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
120
+ )
121
+ mm_projector_weights = {
122
+ k: v.to(torch.float16) for k, v in mm_projector_weights.items()
123
+ }
124
+ model.load_state_dict(mm_projector_weights, strict=False)
125
+ else:
126
+ print("load llaVA-Phi MLLM!!!")
127
+ config = LlavaPhiConfig.from_pretrained(model_path, trust_remote_code=True)
128
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
129
+ model = LlavaPhiForCausalLM.from_pretrained(
130
+ model_path, config=config, use_safetensors=True, **kwargs
131
+ ).to("cuda")
132
+ else:
133
+ # Load language model
134
+ if model_base is not None:
135
+ # PEFT model
136
+ from peft import PeftModel
137
+
138
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
139
+ model = AutoModelForCausalLM.from_pretrained(
140
+ model_base,
141
+ torch_dtype=torch.float16,
142
+ low_cpu_mem_usage=True,
143
+ device_map="auto",
144
+ )
145
+ print(f"Loading LoRA weights from {model_path}")
146
+ model = PeftModel.from_pretrained(model, model_path)
147
+ print(f"Merging weights")
148
+ model = model.merge_and_unload()
149
+ print("Convert to FP16...")
150
+ model.to(torch.float16)
151
+ else:
152
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
153
+ model = AutoModelForCausalLM.from_pretrained(
154
+ model_path, low_cpu_mem_usage=True, **kwargs
155
+ )
156
+
157
+ image_processor = CLIPImageProcessor.from_pretrained(model_path)
158
+
159
+ if "phi" in model_name.lower():
160
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
161
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
162
+
163
+ # TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200
164
+ if mm_use_im_patch_token:
165
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
166
+ if mm_use_im_start_end:
167
+ tokenizer.add_tokens(
168
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
169
+ )
170
+ # model.resize_token_embeddings(len(tokenizer))
171
+ else:
172
+ raise ValueError(f"Unsupported model name: {model_name}")
173
+
174
+ if hasattr(model.config, "max_sequence_length"):
175
+ context_len = model.config.max_sequence_length
176
+ else:
177
+ context_len = 2048
178
+ model.to(device="cuda")
179
+ print(kwargs)
180
+ return tokenizer, model, image_processor, context_len
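A hedged sketch of calling this loader directly; model_base is only needed for LoRA or projector-only checkpoints, and the checkpoint name simply mirrors the one used in inference/main.py (a CUDA device is assumed, since the loader moves the model to cuda):

from inference.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="RaviNaik/Llava-Phi2",
    model_base=None,           # full checkpoint, not LoRA / projector-only
    model_name="llava-phi",    # must contain "phi" to take the LLaVA-Phi branch
)
print(context_len)             # 2048 unless the config defines max_sequence_length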
inference/model/language_model/configuration_llava_phi.py ADDED
@@ -0,0 +1,191 @@
1
+ import os
2
+ from typing import Union
3
+ from transformers import PretrainedConfig, PhiConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+
9
+ class LlavaPhiVisionConfig(PretrainedConfig):
10
+ r"""
11
+ This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
12
+ CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
13
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
14
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
15
+
16
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
17
+ documentation from [`PretrainedConfig`] for more information.
18
+
19
+ Args:
20
+ hidden_size (`int`, *optional*, defaults to 768):
21
+ Dimensionality of the encoder layers and the pooler layer.
22
+ intermediate_size (`int`, *optional*, defaults to 3072):
23
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
24
+ projection_dim (`int`, *optional*, defaults to 512):
25
+ Dimensionality of text and vision projection layers.
26
+ num_hidden_layers (`int`, *optional*, defaults to 12):
27
+ Number of hidden layers in the Transformer encoder.
28
+ num_attention_heads (`int`, *optional*, defaults to 12):
29
+ Number of attention heads for each attention layer in the Transformer encoder.
30
+ num_channels (`int`, *optional*, defaults to 3):
31
+ The number of input channels.
32
+ image_size (`int`, *optional*, defaults to 224):
33
+ The size (resolution) of each image.
34
+ patch_size (`int`, *optional*, defaults to 32):
35
+ The size (resolution) of each patch.
36
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
37
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
38
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
39
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
40
+ The epsilon used by the layer normalization layers.
41
+ attention_dropout (`float`, *optional*, defaults to 0.0):
42
+ The dropout ratio for the attention probabilities.
43
+ initializer_range (`float`, *optional*, defaults to 0.02):
44
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
45
+ initializer_factor (`float`, *optional*, defaults to 1.0):
46
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
47
+ testing).
48
+ mm_vision_select_feature (`str`, *optional*, defaults to `"patch"`):
49
+ The feature to select from the vision encoder output. Can be one of `"patch"` or `"cls_patch"`.
50
+ mm_vision_select_layer (`int`, *optional*, defaults to `-2`):
51
+ The layer to select from the vision encoder output.
52
+
53
+ Example:
54
+
55
+ ```python
56
+ >>> from transformers import CLIPVisionConfig, CLIPVisionModel
57
+
58
+ >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
59
+ >>> configuration = CLIPVisionConfig()
60
+
61
+ >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
62
+ >>> model = CLIPVisionModel(configuration)
63
+
64
+ >>> # Accessing the model configuration
65
+ >>> configuration = model.config
66
+ ```"""
67
+
68
+ model_type = "llava_phi_clip_vision_model"
69
+
70
+ def __init__(
71
+ self,
72
+ hidden_size=768,
73
+ intermediate_size=3072,
74
+ projection_dim=512,
75
+ num_hidden_layers=12,
76
+ num_attention_heads=12,
77
+ num_channels=3,
78
+ image_size=224,
79
+ patch_size=32,
80
+ hidden_act="quick_gelu",
81
+ layer_norm_eps=1e-5,
82
+ attention_dropout=0.0,
83
+ initializer_range=0.02,
84
+ initializer_factor=1.0,
85
+ mm_vision_select_feature="patch",
86
+ mm_vision_select_layer=-2,
87
+ **kwargs,
88
+ ):
89
+ super().__init__(**kwargs)
90
+
91
+ self.hidden_size = hidden_size
92
+ self.intermediate_size = intermediate_size
93
+ self.projection_dim = projection_dim
94
+ self.num_hidden_layers = num_hidden_layers
95
+ self.num_attention_heads = num_attention_heads
96
+ self.num_channels = num_channels
97
+ self.patch_size = patch_size
98
+ self.image_size = image_size
99
+ self.initializer_range = initializer_range
100
+ self.initializer_factor = initializer_factor
101
+ self.attention_dropout = attention_dropout
102
+ self.layer_norm_eps = layer_norm_eps
103
+ self.hidden_act = hidden_act
104
+ self.mm_vision_select_feature = mm_vision_select_feature
105
+ self.mm_vision_select_layer = mm_vision_select_layer
106
+
107
+ @classmethod
108
+ def from_pretrained(
109
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
110
+ ) -> "PretrainedConfig":
111
+ cls._set_token_in_kwargs(kwargs)
112
+
113
+ config_dict, kwargs = cls.get_config_dict(
114
+ pretrained_model_name_or_path, **kwargs
115
+ )
116
+
117
+ # get the vision config dict if we are loading from a full LlavaPhi config
118
+ if config_dict.get("model_type") == "llava_phi-phi":
119
+ config_dict = config_dict["vision_config"]
120
+
121
+ if (
122
+ "model_type" in config_dict
123
+ and hasattr(cls, "model_type")
124
+ and config_dict["model_type"] != cls.model_type
125
+ ):
126
+ logger.warning(
127
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
128
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
129
+ )
130
+
131
+ return cls.from_dict(config_dict, **kwargs)
132
+
133
+
134
+ class ProjectorConfig(PretrainedConfig):
135
+ model_type = "llava_phi_projector"
136
+
137
+ def __init__(
138
+ self, mm_projector_type="linear", mm_hidden_size=768, hidden_size=2560, **kwargs
139
+ ):
140
+ self.mm_projector_type = mm_projector_type
141
+ self.mm_hidden_size = mm_hidden_size
142
+ self.hidden_size = hidden_size
143
+ super().__init__(**kwargs)
144
+
145
+ @classmethod
146
+ def from_pretrained(
147
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
148
+ ) -> "PretrainedConfig":
149
+ cls._set_token_in_kwargs(kwargs)
150
+
151
+ config_dict, kwargs = cls.get_config_dict(
152
+ pretrained_model_name_or_path, **kwargs
153
+ )
154
+
155
+ # get the projector config dict if we are loading from a full LlavaPhi config
156
+ if config_dict.get("model_type") == "llava_phi-phi":
157
+ config_dict = config_dict["projector_config"]
158
+
159
+ if (
160
+ "model_type" in config_dict
161
+ and hasattr(cls, "model_type")
162
+ and config_dict["model_type"] != cls.model_type
163
+ ):
164
+ logger.warning(
165
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
166
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
167
+ )
168
+
169
+ return cls.from_dict(config_dict, **kwargs)
170
+
171
+
172
+ DEFAULT_VISUAL_CONFIG = {
173
+ "vision_tower": LlavaPhiVisionConfig().to_dict(),
174
+ "mm_projector": ProjectorConfig().to_dict(),
175
+ }
176
+
177
+
178
+ class LlavaPhiConfig(PhiConfig):
179
+ model_type = "llava_phi"
180
+
181
+ def __init__(self, vision_config=None, **kwargs):
182
+ if vision_config is None:
183
+ self.vision_config = DEFAULT_VISUAL_CONFIG
184
+ else:
185
+ self.vision_config = vision_config
186
+
187
+ super().__init__(**kwargs)
188
+
189
+
190
+ if __name__ == "__main__":
191
+ print(LlavaPhiVisionConfig())
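A hedged sketch of how the three config classes compose into a LlavaPhiConfig (the values are illustrative, not those of the released checkpoint):

from inference.model.language_model.configuration_llava_phi import (
    LlavaPhiConfig,
    LlavaPhiVisionConfig,
    ProjectorConfig,
)

vision_config = {
    "vision_tower": LlavaPhiVisionConfig(image_size=336, patch_size=14).to_dict(),
    "mm_projector": ProjectorConfig(mm_projector_type="mlp2x_gelu").to_dict(),
}
config = LlavaPhiConfig(vision_config=vision_config)
print(config.model_type)                                    # llava_phi
print(config.vision_config["vision_tower"]["image_size"])   # 336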
inference/model/language_model/llava_phi.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import CrossEntropyLoss
7
+
8
+ from transformers import AutoConfig, AutoModelForCausalLM, \
9
+ PhiModel, PhiPreTrainedModel
10
+
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
13
+ from transformers.utils import logging
14
+ from .configuration_llava_phi import LlavaPhiConfig
15
+
16
+ logger = logging.get_logger(__name__)
17
+
18
+
19
+ class LLavaPhiModel(LlavaMetaModel, PhiModel):
20
+ config_class = LlavaPhiConfig
21
+
22
+ def __init__(self, config):
23
+ super(LLavaPhiModel, self).__init__(config)
24
+
25
+
26
+ class LlavaPhiForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
27
+ config_class = LlavaPhiConfig
28
+
29
+ def __init__(self, config):
30
+ super(PhiPreTrainedModel, self).__init__(config)
31
+ self.model = LLavaPhiModel(config)
32
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
33
+
34
+ # Initialize weights and apply final processing
35
+ self.post_init()
36
+
37
+ def get_model(self):
38
+ return self.model
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: torch.LongTensor = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
45
+ inputs_embeds: Optional[torch.FloatTensor] = None,
46
+ labels: Optional[torch.LongTensor] = None,
47
+ use_cache: Optional[bool] = None,
48
+ output_attentions: Optional[bool] = None,
49
+ output_hidden_states: Optional[bool] = None,
50
+ images: Optional[torch.FloatTensor] = None,
51
+ return_dict: Optional[bool] = None,
52
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
53
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
54
+ output_hidden_states = (
55
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
56
+ )
57
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
58
+
59
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(
60
+ input_ids, attention_mask, past_key_values, labels, images)
61
+
62
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
63
+ outputs = self.model(
64
+ input_ids=input_ids,
65
+ attention_mask=attention_mask,
66
+ past_key_values=past_key_values,
67
+ inputs_embeds=inputs_embeds,
68
+ use_cache=use_cache,
69
+ output_attentions=output_attentions,
70
+ output_hidden_states=output_hidden_states,
71
+ return_dict=return_dict
72
+ )
73
+
74
+ hidden_states = outputs[0]
75
+ logits = self.lm_head(hidden_states)
76
+
77
+ loss = None
78
+ if labels is not None:
79
+ # Shift so that tokens < n predict n
80
+ shift_logits = logits[..., :-1, :].contiguous()
81
+ shift_labels = labels[..., 1:].contiguous()
82
+ # Flatten the tokens
83
+ loss_fct = CrossEntropyLoss()
84
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
85
+ shift_labels = shift_labels.view(-1)
86
+ # Enable model/pipeline parallelism
87
+ shift_labels = shift_labels.to(shift_logits.device)
88
+ loss = loss_fct(shift_logits, shift_labels)
89
+
90
+ if not return_dict:
91
+ output = (logits,) + outputs[1:]
92
+ return (loss,) + output if loss is not None else output
93
+
94
+ return CausalLMOutputWithPast(
95
+ loss=loss,
96
+ logits=logits,
97
+ past_key_values=outputs.past_key_values,
98
+ hidden_states=outputs.hidden_states,
99
+ attentions=outputs.attentions,
100
+ )
101
+
102
+ def prepare_inputs_for_generation(
103
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
104
+ ):
105
+ if past_key_values:
106
+ input_ids = input_ids[:, -1:]
107
+
108
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
109
+ if inputs_embeds is not None and past_key_values is None:
110
+ model_inputs = {"inputs_embeds": inputs_embeds}
111
+ else:
112
+ model_inputs = {"input_ids": input_ids}
113
+
114
+ model_inputs.update(
115
+ {
116
+ "past_key_values": past_key_values,
117
+ "use_cache": kwargs.get("use_cache"),
118
+ "attention_mask": attention_mask,
119
+ "images": kwargs.get("images", None),
120
+ }
121
+ )
122
+ return model_inputs
123
+
124
+
125
+ AutoConfig.register("llava_phi", LlavaPhiConfig)
126
+ AutoModelForCausalLM.register(LlavaPhiConfig, LlavaPhiForCausalLM)
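Because the last two lines register the config and model with the Auto classes, the checkpoint can also be loaded generically once this module has been imported. A hedged sketch (assumes the RaviNaik/Llava-Phi2 weights are reachable and a CUDA device is available):

from transformers import AutoModelForCausalLM

import inference.model.language_model.llava_phi  # noqa: F401  (runs the registrations above)

model = AutoModelForCausalLM.from_pretrained("RaviNaik/Llava-Phi2", device_map="cuda:0")
print(type(model).__name__)  # LlavaPhiForCausalLM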
inference/model/llava_arch.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+
20
+ from .multimodal_encoder.clip_encoder import CLIPVisionTower
21
+ from .multimodal_projector.builder import build_vision_projector
22
+ from .language_model.configuration_llava_phi import (
23
+ LlavaPhiConfig,
24
+ LlavaPhiVisionConfig,
25
+ ProjectorConfig,
26
+ )
27
+
28
+ # from llava_phi.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
29
+ IGNORE_INDEX = -100
30
+ IMAGE_TOKEN_INDEX = -200
31
+ DEFAULT_IMAGE_TOKEN = "<image>"
32
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
33
+ DEFAULT_IM_START_TOKEN = "<im_start>"
34
+ DEFAULT_IM_END_TOKEN = "<im_end>"
35
+
36
+
37
+ class LlavaMetaModel:
38
+ def __init__(self, config):
39
+ super(LlavaMetaModel, self).__init__(config)
40
+ self.vision_tower = CLIPVisionTower(
41
+ LlavaPhiVisionConfig(**config.vision_config["vision_tower"])
42
+ )
43
+ self.mm_projector = build_vision_projector(
44
+ ProjectorConfig(**config.vision_config["mm_projector"])
45
+ )
46
+
47
+ def get_vision_tower(self):
48
+ vision_tower = getattr(self, "vision_tower", None)
49
+ if type(vision_tower) is list:
50
+ vision_tower = vision_tower[0]
51
+ return vision_tower
52
+
53
+
54
+ class LlavaMetaForCausalLM(ABC):
55
+ @abstractmethod
56
+ def get_model(self):
57
+ pass
58
+
59
+ def get_vision_tower(self):
60
+ return self.get_model().get_vision_tower()
61
+
62
+ def encode_images(self, images):
63
+ image_features = self.get_model().get_vision_tower()(images)
64
+ image_features = self.get_model().mm_projector(image_features)
65
+ return image_features
66
+
67
+ def prepare_inputs_labels_for_multimodal(
68
+ self, input_ids, attention_mask, past_key_values, labels, images
69
+ ):
70
+ vision_tower = self.get_vision_tower()
71
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
72
+ if (
73
+ past_key_values is not None
74
+ and vision_tower is not None
75
+ and images is not None
76
+ and input_ids.shape[1] == 1
77
+ ):
78
+ attention_mask = torch.ones(
79
+ (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
80
+ dtype=attention_mask.dtype,
81
+ device=attention_mask.device,
82
+ )
83
+ return input_ids, attention_mask, past_key_values, None, labels
84
+
85
+ if type(images) is list or images.ndim == 5:
86
+ concat_images = torch.cat([image for image in images], dim=0)
87
+ image_features = self.encode_images(concat_images)
88
+ split_sizes = [image.shape[0] for image in images]
89
+ image_features = torch.split(image_features, split_sizes, dim=0)
90
+ image_features = [x.flatten(0, 1) for x in image_features]
91
+ else:
92
+ image_features = self.encode_images(images)
93
+
94
+ new_input_embeds = []
95
+ new_labels = [] if labels is not None else None
96
+ cur_image_idx = 0
97
+ for batch_idx, cur_input_ids in enumerate(input_ids):
98
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
99
+ # multimodal LLM, but the current sample is not multimodal
100
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
101
+ half_len = cur_input_ids.shape[0] // 2
102
+ cur_image_features = image_features[cur_image_idx]
103
+ cur_input_embeds_1 = self.get_model().embed_tokens(
104
+ cur_input_ids[:half_len]
105
+ )
106
+ cur_input_embeds_2 = self.get_model().embed_tokens(
107
+ cur_input_ids[half_len:]
108
+ )
109
+ cur_input_embeds = torch.cat(
110
+ [cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2],
111
+ dim=0,
112
+ )
113
+ new_input_embeds.append(cur_input_embeds)
114
+ if labels is not None:
115
+ new_labels.append(labels[batch_idx])
116
+ cur_image_idx += 1
117
+ continue
118
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
119
+ cur_new_input_embeds = []
120
+ if labels is not None:
121
+ cur_labels = labels[batch_idx]
122
+ cur_new_labels = []
123
+ assert cur_labels.shape == cur_input_ids.shape
124
+ while image_token_indices.numel() > 0:
125
+ cur_image_features = image_features[cur_image_idx]
126
+ image_token_start = image_token_indices[0]
127
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
128
+ self.config, "mm_use_im_start_end", False
129
+ ):
130
+ cur_new_input_embeds.append(
131
+ self.get_model()
132
+ .embed_tokens(cur_input_ids[: image_token_start - 1])
133
+ .detach()
134
+ )
135
+ cur_new_input_embeds.append(
136
+ self.get_model().embed_tokens(
137
+ cur_input_ids[image_token_start - 1 : image_token_start]
138
+ )
139
+ )
140
+ cur_new_input_embeds.append(cur_image_features)
141
+ cur_new_input_embeds.append(
142
+ self.get_model().embed_tokens(
143
+ cur_input_ids[image_token_start + 1 : image_token_start + 2]
144
+ )
145
+ )
146
+ if labels is not None:
147
+ cur_new_labels.append(cur_labels[:image_token_start])
148
+ cur_new_labels.append(
149
+ torch.full(
150
+ (cur_image_features.shape[0],),
151
+ IGNORE_INDEX,
152
+ device=labels.device,
153
+ dtype=labels.dtype,
154
+ )
155
+ )
156
+ cur_new_labels.append(
157
+ cur_labels[image_token_start : image_token_start + 1]
158
+ )
159
+ cur_labels = cur_labels[image_token_start + 2 :]
160
+ else:
161
+ cur_new_input_embeds.append(
162
+ self.get_model().embed_tokens(cur_input_ids[:image_token_start])
163
+ )
164
+ cur_new_input_embeds.append(cur_image_features)
165
+ if labels is not None:
166
+ cur_new_labels.append(cur_labels[:image_token_start])
167
+ cur_new_labels.append(
168
+ torch.full(
169
+ (cur_image_features.shape[0],),
170
+ IGNORE_INDEX,
171
+ device=labels.device,
172
+ dtype=labels.dtype,
173
+ )
174
+ )
175
+ cur_labels = cur_labels[image_token_start + 1 :]
176
+ cur_image_idx += 1
177
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
178
+ self.config, "mm_use_im_start_end", False
179
+ ):
180
+ cur_input_ids = cur_input_ids[image_token_start + 2 :]
181
+ else:
182
+ cur_input_ids = cur_input_ids[image_token_start + 1 :]
183
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
184
+ if cur_input_ids.numel() > 0:
185
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
186
+ self.config, "mm_use_im_start_end", False
187
+ ):
188
+ cur_new_input_embeds.append(
189
+ self.get_model().embed_tokens(cur_input_ids).detach()
190
+ )
191
+ else:
192
+ cur_new_input_embeds.append(
193
+ self.get_model().embed_tokens(cur_input_ids)
194
+ )
195
+ if labels is not None:
196
+ cur_new_labels.append(cur_labels)
197
+ cur_new_input_embeds = [
198
+ x.to(device=self.device) for x in cur_new_input_embeds
199
+ ]
200
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
201
+ new_input_embeds.append(cur_new_input_embeds)
202
+ if labels is not None:
203
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
204
+ new_labels.append(cur_new_labels)
205
+
206
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
207
+ max_len = max(x.shape[0] for x in new_input_embeds)
208
+
209
+ new_input_embeds_align = []
210
+ for cur_new_embed in new_input_embeds:
211
+ cur_new_embed = torch.cat(
212
+ (
213
+ cur_new_embed,
214
+ torch.zeros(
215
+ (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
216
+ dtype=cur_new_embed.dtype,
217
+ device=cur_new_embed.device,
218
+ ),
219
+ ),
220
+ dim=0,
221
+ )
222
+ new_input_embeds_align.append(cur_new_embed)
223
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
224
+
225
+ if labels is not None:
226
+ new_labels_align = []
227
+ _new_labels = new_labels
228
+ for cur_new_label in new_labels:
229
+ cur_new_label = torch.cat(
230
+ (
231
+ cur_new_label,
232
+ torch.full(
233
+ (max_len - cur_new_label.shape[0],),
234
+ IGNORE_INDEX,
235
+ dtype=cur_new_label.dtype,
236
+ device=cur_new_label.device,
237
+ ),
238
+ ),
239
+ dim=0,
240
+ )
241
+ new_labels_align.append(cur_new_label)
242
+ new_labels = torch.stack(new_labels_align, dim=0)
243
+
244
+ if attention_mask is not None:
245
+ new_attention_mask = []
246
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(
247
+ attention_mask, _new_labels, new_labels
248
+ ):
249
+ new_attn_mask_pad_left = torch.full(
250
+ (cur_new_labels.shape[0] - labels.shape[1],),
251
+ True,
252
+ dtype=attention_mask.dtype,
253
+ device=attention_mask.device,
254
+ )
255
+ new_attn_mask_pad_right = torch.full(
256
+ (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
257
+ False,
258
+ dtype=attention_mask.dtype,
259
+ device=attention_mask.device,
260
+ )
261
+ cur_new_attention_mask = torch.cat(
262
+ (
263
+ new_attn_mask_pad_left,
264
+ cur_attention_mask,
265
+ new_attn_mask_pad_right,
266
+ ),
267
+ dim=0,
268
+ )
269
+ new_attention_mask.append(cur_new_attention_mask)
270
+ attention_mask = torch.stack(new_attention_mask, dim=0)
271
+ assert attention_mask.shape == new_labels.shape
272
+ else:
273
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
274
+ if labels is not None:
275
+ new_labels = torch.stack(new_labels, dim=0)
276
+
277
+ if attention_mask is not None:
278
+ new_attn_mask_pad_left = torch.full(
279
+ (
280
+ attention_mask.shape[0],
281
+ new_input_embeds.shape[1] - input_ids.shape[1],
282
+ ),
283
+ True,
284
+ dtype=attention_mask.dtype,
285
+ device=attention_mask.device,
286
+ )
287
+ attention_mask = torch.cat(
288
+ (new_attn_mask_pad_left, attention_mask), dim=1
289
+ )
290
+ assert attention_mask.shape == new_input_embeds.shape[:2]
291
+
292
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
293
+
294
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
295
+ if model_args.mm_use_im_patch_token:
296
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
297
+ self.resize_token_embeddings(len(tokenizer))
298
+
299
+ if model_args.mm_use_im_start_end:
300
+ num_new_tokens = tokenizer.add_tokens(
301
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
302
+ )
303
+ self.resize_token_embeddings(len(tokenizer))
304
+
305
+ if num_new_tokens > 0:
306
+ input_embeddings = self.get_input_embeddings().weight.data
307
+ output_embeddings = self.get_output_embeddings().weight.data
308
+
309
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
310
+ dim=0, keepdim=True
311
+ )
312
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
313
+ dim=0, keepdim=True
314
+ )
315
+
316
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
317
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
318
+
319
+ if model_args.tune_mm_mlp_adapter:
320
+ for p in self.get_input_embeddings().parameters():
321
+ p.requires_grad = True
322
+ for p in self.get_output_embeddings().parameters():
323
+ p.requires_grad = False
324
+
325
+ elif model_args.mm_use_im_patch_token:
326
+ if model_args.tune_mm_mlp_adapter:
327
+ for p in self.get_input_embeddings().parameters():
328
+ p.requires_grad = False
329
+ for p in self.get_output_embeddings().parameters():
330
+ p.requires_grad = False
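The heart of prepare_inputs_labels_for_multimodal is splicing the projected CLIP patch features in place of each IMAGE_TOKEN_INDEX placeholder, so the embedded sequence grows by num_patches - 1 per image (ignoring the optional im_start/im_end variant). A small illustrative calculation under the default vision config (224-pixel images, 32-pixel patches):

# Illustration only: one <image> placeholder expands into the tower's patch features.
image_size, patch_size = 224, 32
num_patches = (image_size // patch_size) ** 2            # 49 patch embeddings per image
prompt_len, num_images = 20, 1                           # hypothetical prompt with one image token
embedded_len = prompt_len + num_images * (num_patches - 1)
print(num_patches, embedded_len)                         # 49 68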
inference/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,89 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import CLIPPreTrainedModel, CLIPVisionConfig
7
+ from transformers.models.clip.modeling_clip import CLIPVisionTransformer
8
+ from inference.model.language_model.configuration_llava_phi import LlavaPhiVisionConfig
9
+
10
+
11
+ class CLIPVisionTower(CLIPPreTrainedModel):
12
+ config_class = LlavaPhiVisionConfig
13
+
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+
17
+ self.vision_model = CLIPVisionTransformer(config)
18
+ # Initialize weights and apply final processing
19
+ self.post_init()
20
+
21
+ def get_input_embeddings(self) -> nn.Module:
22
+ return self.vision_model.embeddings.patch_embedding
23
+
24
+ def feature_select(self, image_forward_outs):
25
+ image_features = image_forward_outs.hidden_states[
26
+ self.config.mm_vision_select_layer
27
+ ]
28
+ if self.config.mm_vision_select_feature == "patch":
29
+ image_features = image_features[:, 1:]
30
+ elif self.config.mm_vision_select_feature == "cls_patch":
31
+ image_features = image_features
32
+ else:
33
+ raise ValueError(
34
+ f"Unexpected select feature: {self.config.mm_vision_select_feature}"
35
+ )
36
+ return image_features
37
+
38
+ def forward(self, images):
39
+ if type(images) is list:
40
+ image_features = []
41
+ for image in images:
42
+ image_forward_out = self.vision_model(
43
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
44
+ output_hidden_states=True,
45
+ )
46
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
47
+ image_features.append(image_feature)
48
+ else:
49
+ image_forward_outs = self.vision_model(
50
+ images.to(device=self.device, dtype=self.dtype),
51
+ output_hidden_states=True,
52
+ )
53
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
54
+
55
+ return image_features
56
+
57
+ @property
58
+ def dummy_feature(self):
59
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
60
+
61
+ @property
62
+ def dtype(self):
63
+ return list(self.vision_model.parameters())[0].dtype
64
+
65
+ @property
66
+ def device(self):
67
+ return list(self.vision_model.parameters())[0].device
68
+
69
+ @property
70
+ def hidden_size(self):
71
+ return self.config.hidden_size
72
+
73
+ @property
74
+ def num_patches(self):
75
+ return (self.config.image_size // self.config.patch_size) ** 2
76
+
77
+
78
+ if __name__ == "__main__":
79
+ clip_config = CLIPVisionConfig.from_pretrained(
80
+ "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336"
81
+ )
82
+ print("################ clip_config ##############")
83
+ print(clip_config)
84
+ phi_vis_config = LlavaPhiVisionConfig(**clip_config.to_dict())
85
+ print("################ phi_vis_config ##############")
86
+ print(phi_vis_config)
87
+
88
+ model = CLIPVisionTower(clip_config)
89
+ # print(list(model.vision_model.parameters())[0].dtype)
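A hedged sketch of the tower's output shape, using randomly initialized weights and the default LlavaPhiVisionConfig (224-pixel images, 32-pixel patches; "patch" feature selection drops the CLS token):

import torch

from inference.model.language_model.configuration_llava_phi import LlavaPhiVisionConfig
from inference.model.multimodal_encoder.clip_encoder import CLIPVisionTower

tower = CLIPVisionTower(LlavaPhiVisionConfig())   # random weights, shape check only
images = torch.randn(2, 3, 224, 224)
features = tower(images)
print(tower.num_patches)                          # 49
print(features.shape)                             # torch.Size([2, 49, 768])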
inference/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+
6
+ class IdentityMap(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x, *args, **kwargs):
11
+ return x
12
+
13
+ @property
14
+ def config(self):
15
+ return {"mm_projector_type": "identity"}
16
+
17
+
18
+ class SimpleResBlock(nn.Module):
19
+ def __init__(self, channels):
20
+ super().__init__()
21
+ self.pre_norm = nn.LayerNorm(channels)
22
+
23
+ self.proj = nn.Sequential(
24
+ nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
25
+ )
26
+
27
+ def forward(self, x):
28
+ x = self.pre_norm(x)
29
+ return x + self.proj(x)
30
+
31
+
32
+ def build_vision_projector(config):
33
+ projector_type = getattr(config, "mm_projector_type", "linear")
34
+
35
+ if projector_type == "linear":
36
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
37
+
38
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
39
+ if mlp_gelu_match:
40
+ mlp_depth = int(mlp_gelu_match.group(1))
41
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
45
+ return nn.Sequential(*modules)
46
+
47
+ if projector_type == "identity":
48
+ return IdentityMap()
49
+
50
+ raise ValueError(f"Unknown projector type: {projector_type}")
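A hedged sketch of the two projector variants this factory builds, using the ProjectorConfig defaults (mm_hidden_size=768 in, hidden_size=2560 out):

import torch

from inference.model.language_model.configuration_llava_phi import ProjectorConfig
from inference.model.multimodal_projector.builder import build_vision_projector

linear_proj = build_vision_projector(ProjectorConfig())                              # single Linear(768, 2560)
mlp_proj = build_vision_projector(ProjectorConfig(mm_projector_type="mlp2x_gelu"))   # Linear -> GELU -> Linear

patch_features = torch.randn(1, 49, 768)          # e.g. CLIP patch features
print(linear_proj(patch_features).shape)          # torch.Size([1, 49, 2560])
print(mlp_proj(patch_features).shape)             # torch.Size([1, 49, 2560])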