Spaces: Runtime error
GunaKoppula committed on
Commit a8d9c50
1 Parent(s): efe75b3
Upload 19 files
Browse files
- Experiments/clip_expt.ipynb +840 -0
- Experiments/eval.ipynb +782 -0
- Experiments/instruct_150k_data.ipynb +0 -0
- Experiments/instruct_data.py +39 -0
- Experiments/llava_exp.ipynb +145 -0
- Experiments/multimodal_exp.ipynb +362 -0
- Experiments/pretrain_data_check.ipynb +304 -0
- Experiments/whispher_exp.ipynb +500 -0
- inference/__init__.py +0 -0
- inference/conversation.py +224 -0
- inference/inference.ipynb +369 -0
- inference/main.py +226 -0
- inference/model/__init__.py +2 -0
- inference/model/builder.py +180 -0
- inference/model/language_model/configuration_llava_phi.py +191 -0
- inference/model/language_model/llava_phi.py +126 -0
- inference/model/llava_arch.py +330 -0
- inference/model/multimodal_encoder/clip_encoder.py +89 -0
- inference/model/multimodal_projector/builder.py +50 -0
Experiments/clip_expt.ipynb
ADDED
@@ -0,0 +1,840 @@
In [2]:
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModel

In [3]:
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32", device_map="cuda:0")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32", device_map="cuda:0")

In [4]:
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
image = Image.open("002579.jpg")

In [17]:
# image

In [18]:
# image = None

In [5]:
inputs = processor(images=image, return_tensors="pt", device_map="cuda:0")

In [20]:
# inputs

In [6]:
inputs["pixel_values"].shape

Out[6]: torch.Size([1, 3, 224, 224])
In [22]:
outputs = model(inputs["pixel_values"].to("cuda:0"))
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled CLS states

In [23]:
pooled_output.shape

Out[23]: torch.Size([1, 768])

In [24]:
last_hidden_state.shape

Out[24]: torch.Size([1, 50, 768])

In [25]:
outputs[1].shape

Out[25]: torch.Size([1, 768])

In [8]:
import torch
from torch import nn
from transformers import CLIPVisionConfig, CLIPPreTrainedModel
In [9]:
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output

        image_embeds = self.visual_projection(pooled_output)

        if not return_dict:
            outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

NameError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
      2     config_class = CLIPVisionConfig
      3     main_input_name = "pixel_values"

Cell In[9], line 20, in CLIPVisionModelWithProjection()
---> 20     pixel_values: Optional[torch.FloatTensor] = None,

NameError: name 'Optional' is not defined
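The NameError in the cell above is a missing-import issue: the class is defined before Optional/Tuple/Union and the CLIP vision building blocks have been imported. A minimal sketch of the imports that would let this cell run, mirroring the ones the notebook itself brings in later (In [10]):

# Imports needed by the class definition above; the same ones appear later in In [10].
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import CLIPPreTrainedModel, CLIPVisionConfig
from transformers.models.clip.modeling_clip import (
    CLIPVisionModelOutput,   # output dataclass returned by forward()
    CLIPVisionTransformer,   # the CLIP vision tower wrapped by the class above
)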
In [27]:
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

image_features = model.get_image_features(**inputs)

stderr (abridged): verbose transformers loading log. config.json, pytorch_model.bin, preprocessor_config.json and the tokenizer files are loaded from the local Hugging Face cache; the CLIPConfig dump ("_name_or_path": "openai/clip-vit-base-patch32", projection_dim 512, transformers 4.36.2) and the CLIPImageProcessor dump (224x224 resize and center crop, CLIP mean/std normalization) are printed, with the config dump repeated; "All model checkpoint weights were used when initializing CLIPModel."

In [29]:
image_features.shape

Out[29]: torch.Size([1, 512])
In [30]:
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModel

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled CLS states

stderr (abridged): verbose transformers loading log. The CLIPVisionConfig dump is printed (hidden_size 768, intermediate_size 3072, 12 layers, 12 attention heads, patch_size 32, image_size 224, projection_dim 512, transformers 4.36.2), followed by:
"Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: [...]" (the full list of text_model.* parameters plus logit_scale, text_projection.weight and visual_projection.weight)
"- This IS expected if you are initializing CLIPVisionModel from the checkpoint of a model trained on another task or with another architecture ..."
"All the weights of CLIPVisionModel were initialized from the model checkpoint at openai/clip-vit-base-patch32."
"If your task is similar to the task the model of the checkpoint was trained on, you can already use CLIPVisionModel for predictions without further training."
The preprocessor/tokenizer loading log and the CLIPConfig and CLIPImageProcessor dumps then repeat as in the previous cell.

In [31]:
pooled_output.shape

Out[31]: torch.Size([1, 768])

In [10]:
from transformers import CLIPPreTrainedModel
from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer
from typing import Optional, Union, Tuple
In [54]:
class VisionLanguageConnector(nn.Module):
    def __init__(self, hidden_size, projection_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(hidden_size, projection_dim, bias=False)
        )

    def forward(self, x):
        return self.mlp(x)

class ClipWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)
        self.vision_language_connector = VisionLanguageConnector(config.hidden_size, config.projection_dim)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output

        image_embeds = self.vision_language_connector(pooled_output)

        if not return_dict:
            outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )
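For reference, a self-contained sketch (not part of the original notebook) of how the connector defined above is exercised by the cells that follow, assuming `inputs` is the processed image batch prepared earlier; the pooled 768-d CLIP feature is projected down to the config's 512-d projection space:

# Consolidated sketch of the following cells: load the pretrained CLIP vision
# tower, attach the (randomly initialized) MLP connector, and project the
# pooled image feature from hidden_size=768 to projection_dim=512.
import torch

model = ClipWithProjection.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

with torch.no_grad():
    outputs = model(**inputs)   # `inputs` from processor(images=image, return_tensors="pt")

image_embeds = outputs[0]       # CLIPVisionModelOutput.image_embeds
print(image_embeds.shape)       # torch.Size([1, 512])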
In [55]:
model = ClipWithProjection.from_pretrained("openai/clip-vit-base-patch32")

stderr (abridged): the CLIPVisionConfig dump and checkpoint loading log as before, followed by:
"Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing ClipWithProjection: [...]" (again the full list of text_model.* parameters plus logit_scale, text_projection.weight and visual_projection.weight)
"Some weights of ClipWithProjection were not initialized from the model checkpoint at openai/clip-vit-base-patch32 and are newly initialized: ['vision_language_connector.mlp.2.weight', 'vision_language_connector.mlp.0.weight']"
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."

In [56]:
model.config.hidden_size

Out[56]: 768
749 |
+
{
|
750 |
+
"cell_type": "code",
|
751 |
+
"execution_count": 57,
|
752 |
+
"id": "05d95b9e-9831-4415-860e-94793e29d210",
|
753 |
+
"metadata": {},
|
754 |
+
"outputs": [],
|
755 |
+
"source": [
|
756 |
+
"outputs = model(**inputs)"
|
757 |
+
]
|
758 |
+
},
|
759 |
+
{
|
760 |
+
"cell_type": "code",
|
761 |
+
"execution_count": 61,
|
762 |
+
"id": "185b1bff-6ffe-4cce-9255-ee7629feba54",
|
763 |
+
"metadata": {},
|
764 |
+
"outputs": [
|
765 |
+
{
|
766 |
+
"data": {
|
767 |
+
"text/plain": [
|
768 |
+
"torch.Size([1, 512])"
|
769 |
+
]
|
770 |
+
},
|
771 |
+
"execution_count": 61,
|
772 |
+
"metadata": {},
|
773 |
+
"output_type": "execute_result"
|
774 |
+
}
|
775 |
+
],
|
776 |
+
"source": [
|
777 |
+
"outputs[0].shape"
|
778 |
+
]
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"cell_type": "code",
|
782 |
+
"execution_count": null,
|
783 |
+
"id": "04414a35-c7b3-4986-a79e-1d363916caa4",
|
784 |
+
"metadata": {},
|
785 |
+
"outputs": [],
|
786 |
+
"source": []
|
787 |
+
},
|
788 |
+
{
|
789 |
+
"cell_type": "code",
|
790 |
+
"execution_count": 1,
|
791 |
+
"id": "485dbbcb-06df-4926-b257-dfd1a4081d44",
|
792 |
+
"metadata": {},
|
793 |
+
"outputs": [
|
794 |
+
{
|
795 |
+
"ename": "NameError",
|
796 |
+
"evalue": "name 'outputs' is not defined",
|
797 |
+
"output_type": "error",
|
798 |
+
"traceback": [
|
799 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
800 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
801 |
+
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43moutputs\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
|
802 |
+
"\u001b[0;31mNameError\u001b[0m: name 'outputs' is not defined"
|
803 |
+
]
|
804 |
+
}
|
805 |
+
],
|
806 |
+
"source": [
|
807 |
+
"outputs[0]"
|
808 |
+
]
|
809 |
+
},
|
810 |
+
{
|
811 |
+
"cell_type": "code",
|
812 |
+
"execution_count": null,
|
813 |
+
"id": "f983313c-8e0f-4805-af14-25bb69afd04c",
|
814 |
+
"metadata": {},
|
815 |
+
"outputs": [],
|
816 |
+
"source": []
|
817 |
+
}
|
818 |
+
],
|
819 |
+
"metadata": {
|
820 |
+
"kernelspec": {
|
821 |
+
"display_name": "Python 3 (ipykernel)",
|
822 |
+
"language": "python",
|
823 |
+
"name": "python3"
|
824 |
+
},
|
825 |
+
"language_info": {
|
826 |
+
"codemirror_mode": {
|
827 |
+
"name": "ipython",
|
828 |
+
"version": 3
|
829 |
+
},
|
830 |
+
"file_extension": ".py",
|
831 |
+
"mimetype": "text/x-python",
|
832 |
+
"name": "python",
|
833 |
+
"nbconvert_exporter": "python",
|
834 |
+
"pygments_lexer": "ipython3",
|
835 |
+
"version": "3.10.12"
|
836 |
+
}
|
837 |
+
},
|
838 |
+
"nbformat": 4,
|
839 |
+
"nbformat_minor": 5
|
840 |
+
}
|
Experiments/eval.ipynb
ADDED
@@ -0,0 +1,782 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 4,
|
6 |
+
"id": "215cfd2f-62b0-4a86-a407-777a1d32597f",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"[2024-01-24 15:18:49,948] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"from PIL import Image\n",
|
19 |
+
"import requests\n",
|
20 |
+
"\n",
|
21 |
+
"import torch\n",
|
22 |
+
"from torch import nn\n",
|
23 |
+
"from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
|
24 |
+
"from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
|
25 |
+
"from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
|
26 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": 5,
|
32 |
+
"id": "2244e8f3-fcc7-4309-9d4d-fea557f89f79",
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [],
|
35 |
+
"source": [
|
36 |
+
"from llava_phi import LlavaPhiForCausalLM"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": 3,
|
42 |
+
"id": "587883e1-3419-4b14-b16b-38fabbc8bfaa",
|
43 |
+
"metadata": {},
|
44 |
+
"outputs": [],
|
45 |
+
"source": [
|
46 |
+
"# model = LlavaPhiForCausalLM.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 4,
|
52 |
+
"id": "0e27a7db-e2ab-4d65-b21d-497222e318ad",
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [],
|
55 |
+
"source": [
|
56 |
+
"# processor = AutoProcessor.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 5,
|
62 |
+
"id": "663efdd8-ea21-4231-a2ae-bcc0fb47b46a",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [],
|
65 |
+
"source": [
|
66 |
+
"# prompt = \"<image>\\nUSER: What's the content of the image?\\nASSISTANT:\"\n",
|
67 |
+
"# url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
|
68 |
+
"# image = Image.open(requests.get(url, stream=True).raw)"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": 6,
|
74 |
+
"id": "f622609f-f6a7-4ec1-ac35-c1d33d9436ca",
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"# # Generate\n",
|
79 |
+
"# generate_ids = model.generate(**inputs, max_length=30)\n",
|
80 |
+
"# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 6,
|
86 |
+
"id": "45f5ba72-2e41-4ccc-84c1-97d542ebee63",
|
87 |
+
"metadata": {},
|
88 |
+
"outputs": [],
|
89 |
+
"source": [
|
90 |
+
"from llava_phi.model.builder import load_pretrained_model\n",
|
91 |
+
"from llava_phi.mm_utils import tokenizer_image_token, get_model_name_from_path\n",
|
92 |
+
"from llava_phi.utils import disable_torch_init\n",
|
93 |
+
"from llava_phi.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
|
94 |
+
"from llava_phi.conversation import conv_templates, SeparatorStyle"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "code",
|
99 |
+
"execution_count": 11,
|
100 |
+
"id": "b98ac5d3-5503-4430-81d1-19a4f8d6bd75",
|
101 |
+
"metadata": {},
|
102 |
+
"outputs": [],
|
103 |
+
"source": [
|
104 |
+
"model_path = \"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\"\n",
|
105 |
+
"model_name = get_model_name_from_path(model_path)"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": 12,
|
111 |
+
"id": "42fd5721-75a7-475b-bd30-5ee23aeaac64",
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [
|
114 |
+
{
|
115 |
+
"data": {
|
116 |
+
"text/plain": [
|
117 |
+
"'llavaPhi-v0-3b-finetune_checkpoint-4000'"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
"execution_count": 12,
|
121 |
+
"metadata": {},
|
122 |
+
"output_type": "execute_result"
|
123 |
+
}
|
124 |
+
],
|
125 |
+
"source": [
|
126 |
+
"model_name"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"execution_count": 13,
|
132 |
+
"id": "8c2076b5-3bfc-48fd-917b-5dfd06fc532f",
|
133 |
+
"metadata": {},
|
134 |
+
"outputs": [
|
135 |
+
{
|
136 |
+
"name": "stderr",
|
137 |
+
"output_type": "stream",
|
138 |
+
"text": [
|
139 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"name": "stdout",
|
144 |
+
"output_type": "stream",
|
145 |
+
"text": [
|
146 |
+
"load llaVA-Phi MLLM!!!\n"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"name": "stderr",
|
151 |
+
"output_type": "stream",
|
152 |
+
"text": [
|
153 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
154 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"data": {
|
159 |
+
"application/vnd.jupyter.widget-view+json": {
|
160 |
+
"model_id": "20b86f2c01744081b537620c8780f12e",
|
161 |
+
"version_major": 2,
|
162 |
+
"version_minor": 0
|
163 |
+
},
|
164 |
+
"text/plain": [
|
165 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
"metadata": {},
|
169 |
+
"output_type": "display_data"
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"name": "stdout",
|
173 |
+
"output_type": "stream",
|
174 |
+
"text": [
|
175 |
+
"{'device_map': 'cuda'}\n"
|
176 |
+
]
|
177 |
+
}
|
178 |
+
],
|
179 |
+
"source": [
|
180 |
+
"tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 14,
|
186 |
+
"id": "4e46221e-0907-453e-8126-76199828493e",
|
187 |
+
"metadata": {},
|
188 |
+
"outputs": [],
|
189 |
+
"source": [
|
190 |
+
"qs = \"What's the content of the image?\"\n",
|
191 |
+
"qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "code",
|
196 |
+
"execution_count": 15,
|
197 |
+
"id": "07355444-0eb8-4d4d-ad50-48b91c969664",
|
198 |
+
"metadata": {},
|
199 |
+
"outputs": [],
|
200 |
+
"source": [
|
201 |
+
"conv = conv_templates[\"default\"].copy()\n",
|
202 |
+
"conv.append_message(conv.roles[0], qs)\n",
|
203 |
+
"conv.append_message(conv.roles[1], None)\n",
|
204 |
+
"prompt = conv.get_prompt()"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": 16,
|
210 |
+
"id": "ccb5674f-aff8-456e-b61b-1d167864f1a6",
|
211 |
+
"metadata": {},
|
212 |
+
"outputs": [
|
213 |
+
{
|
214 |
+
"data": {
|
215 |
+
"text/plain": [
|
216 |
+
"\"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <im_start><image><im_end>\\nWhat's the content of the image? ASSISTANT:\""
|
217 |
+
]
|
218 |
+
},
|
219 |
+
"execution_count": 16,
|
220 |
+
"metadata": {},
|
221 |
+
"output_type": "execute_result"
|
222 |
+
}
|
223 |
+
],
|
224 |
+
"source": [
|
225 |
+
"prompt"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": 17,
|
231 |
+
"id": "a89cc181-2214-4844-b966-164a41744e54",
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": [
|
235 |
+
"url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
|
236 |
+
"image = Image.open(requests.get(url, stream=True).raw)\n",
|
237 |
+
"image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
|
238 |
+
"\n",
|
239 |
+
"input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
|
240 |
+
"\n",
|
241 |
+
"stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "code",
|
246 |
+
"execution_count": 25,
|
247 |
+
"id": "0d519851-64d4-4cf5-b2eb-19474f9aa260",
|
248 |
+
"metadata": {},
|
249 |
+
"outputs": [
|
250 |
+
{
|
251 |
+
"data": {
|
252 |
+
"text/plain": [
|
253 |
+
"torch.Size([1, 55])"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
"execution_count": 25,
|
257 |
+
"metadata": {},
|
258 |
+
"output_type": "execute_result"
|
259 |
+
}
|
260 |
+
],
|
261 |
+
"source": [
|
262 |
+
"input_ids.shape"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "code",
|
267 |
+
"execution_count": 24,
|
268 |
+
"id": "1694ff36-f214-4ed3-b2f3-d3dbd0a1a25b",
|
269 |
+
"metadata": {},
|
270 |
+
"outputs": [
|
271 |
+
{
|
272 |
+
"name": "stderr",
|
273 |
+
"output_type": "stream",
|
274 |
+
"text": [
|
275 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
276 |
+
]
|
277 |
+
}
|
278 |
+
],
|
279 |
+
"source": [
|
280 |
+
"from datasets import load_dataset\n",
|
281 |
+
"audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
|
282 |
+
"audio = audio_ds[0][\"audio\"]\n",
|
283 |
+
"\n",
|
284 |
+
"whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
|
285 |
+
"audio_embed = whisper_w_proj(audio)[\"input_ids\"]"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"execution_count": 28,
|
291 |
+
"id": "9c4a9fae-d6ed-4fc2-ba02-97df64cddd93",
|
292 |
+
"metadata": {},
|
293 |
+
"outputs": [
|
294 |
+
{
|
295 |
+
"data": {
|
296 |
+
"text/plain": [
|
297 |
+
"(torch.Size([1, 33]), device(type='cpu'))"
|
298 |
+
]
|
299 |
+
},
|
300 |
+
"execution_count": 28,
|
301 |
+
"metadata": {},
|
302 |
+
"output_type": "execute_result"
|
303 |
+
}
|
304 |
+
],
|
305 |
+
"source": [
|
306 |
+
"audio_embed.shape, audio_embed.device"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": 29,
|
312 |
+
"id": "c3fffe29-98fb-4f4b-ac51-4bdda9e46752",
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"input_ids = torch.concat([input_ids, audio_embed.to(\"cuda:0\")], dim=1)"
|
317 |
+
]
|
318 |
+
},
|
319 |
+
{
|
320 |
+
"cell_type": "code",
|
321 |
+
"execution_count": 30,
|
322 |
+
"id": "5dee1ec8-2db2-4f65-99e8-d34bd2735c9c",
|
323 |
+
"metadata": {},
|
324 |
+
"outputs": [
|
325 |
+
{
|
326 |
+
"data": {
|
327 |
+
"text/plain": [
|
328 |
+
"torch.Size([1, 88])"
|
329 |
+
]
|
330 |
+
},
|
331 |
+
"execution_count": 30,
|
332 |
+
"metadata": {},
|
333 |
+
"output_type": "execute_result"
|
334 |
+
}
|
335 |
+
],
|
336 |
+
"source": [
|
337 |
+
"input_ids.shape"
|
338 |
+
]
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"cell_type": "code",
|
342 |
+
"execution_count": 31,
|
343 |
+
"id": "96033b43-4f57-4f0c-bcf7-37b57ca02e47",
|
344 |
+
"metadata": {},
|
345 |
+
"outputs": [],
|
346 |
+
"source": [
|
347 |
+
"with torch.inference_mode():\n",
|
348 |
+
" output_ids = model.generate(\n",
|
349 |
+
" input_ids,\n",
|
350 |
+
" images=image_tensor,\n",
|
351 |
+
" do_sample=True,\n",
|
352 |
+
" temperature=0.2,\n",
|
353 |
+
" max_new_tokens=1024,\n",
|
354 |
+
" eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
|
355 |
+
" pad_token_id=tokenizer.eos_token_id, # Pad token\n",
|
356 |
+
" use_cache=True,\n",
|
357 |
+
" )"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"cell_type": "code",
|
362 |
+
"execution_count": 32,
|
363 |
+
"id": "741e8da5-0d18-4c11-b559-76054ce4ca3a",
|
364 |
+
"metadata": {},
|
365 |
+
"outputs": [
|
366 |
+
{
|
367 |
+
"name": "stdout",
|
368 |
+
"output_type": "stream",
|
369 |
+
"text": [
|
370 |
+
"is a Japanese character from the story of Jesus, who is a Chinese monk who is also known for his teachings. The story is based on the story of the story of Jesus Christ, and it is a representation of the story of Jesus and the story of Jesus Christ.\n"
|
371 |
+
]
|
372 |
+
}
|
373 |
+
],
|
374 |
+
"source": [
|
375 |
+
"input_token_len = input_ids.shape[1]\n",
|
376 |
+
"n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
|
377 |
+
"if n_diff_input_output > 0:\n",
|
378 |
+
" print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
|
379 |
+
"outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
|
380 |
+
"outputs = outputs.strip()\n",
|
381 |
+
"if outputs.endswith(stop_str):\n",
|
382 |
+
" outputs = outputs[:-len(stop_str)]\n",
|
383 |
+
"outputs = outputs.strip()\n",
|
384 |
+
"print(outputs)"
|
385 |
+
]
|
386 |
+
},
|
387 |
+
{
|
388 |
+
"cell_type": "code",
|
389 |
+
"execution_count": 20,
|
390 |
+
"id": "69d494d4-d768-4645-b4d6-5c455791b50d",
|
391 |
+
"metadata": {},
|
392 |
+
"outputs": [],
|
393 |
+
"source": [
|
394 |
+
"# image"
|
395 |
+
]
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"cell_type": "code",
|
399 |
+
"execution_count": null,
|
400 |
+
"id": "8a340856-a13f-4b18-9911-126a4ba37816",
|
401 |
+
"metadata": {},
|
402 |
+
"outputs": [],
|
403 |
+
"source": []
|
404 |
+
},
|
405 |
+
{
|
406 |
+
"cell_type": "code",
|
407 |
+
"execution_count": null,
|
408 |
+
"id": "3c56fdea-c7a1-4e67-9832-e2ed077d8704",
|
409 |
+
"metadata": {},
|
410 |
+
"outputs": [],
|
411 |
+
"source": []
|
412 |
+
},
|
413 |
+
{
|
414 |
+
"cell_type": "code",
|
415 |
+
"execution_count": 52,
|
416 |
+
"id": "89e84d39-8ed8-45db-ae82-27c156ee6dd1",
|
417 |
+
"metadata": {},
|
418 |
+
"outputs": [],
|
419 |
+
"source": [
|
420 |
+
"class AudioLanguageConnector:\n",
|
421 |
+
" def __init__(self, projection_dim):\n",
|
422 |
+
" model_name = \"microsoft/phi-2\"\n",
|
423 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
424 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
425 |
+
" self.phi2_tokenizer.max_length = projection_dim\n",
|
426 |
+
"\n",
|
427 |
+
" def __call__(self, text):\n",
|
428 |
+
" text = f\"<audio_start> {text} <audio_end>\"\n",
|
429 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
430 |
+
" return tokens\n",
|
431 |
+
" \n",
|
432 |
+
"\n",
|
433 |
+
"class WhisperWithProjection:\n",
|
434 |
+
" def __init__(self, projection_dim, device):\n",
|
435 |
+
" self.device = device\n",
|
436 |
+
" self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
|
437 |
+
" self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
|
438 |
+
" self.model.config.forced_decoder_ids = None\n",
|
439 |
+
" self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
|
440 |
+
" \n",
|
441 |
+
" def __call__(self, audio):\n",
|
442 |
+
" input_features = self.processor(audio[\"array\"],\n",
|
443 |
+
" sampling_rate=audio[\"sampling_rate\"],\n",
|
444 |
+
" return_tensors=\"pt\").input_features\n",
|
445 |
+
" # generate token ids\n",
|
446 |
+
" predicted_ids = self.model.generate(input_features.to(self.device))\n",
|
447 |
+
" # decode token ids to text \n",
|
448 |
+
" transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
|
449 |
+
"\n",
|
450 |
+
" audio_embeddings = self.audio_language_connector(transcription)\n",
|
451 |
+
" return audio_embeddings.to(self.device)"
|
452 |
+
]
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"cell_type": "code",
|
456 |
+
"execution_count": 53,
|
457 |
+
"id": "75e24be0-b236-4047-83ef-5c344e262476",
|
458 |
+
"metadata": {},
|
459 |
+
"outputs": [],
|
460 |
+
"source": [
|
461 |
+
"class MultiModalPhi2:\n",
|
462 |
+
" def __init__(self, model_path=\"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\",\n",
|
463 |
+
" temperature=0.2,\n",
|
464 |
+
" max_new_tokens=1024,\n",
|
465 |
+
" device=\"cuda\"):\n",
|
466 |
+
" self.temperature = temperature\n",
|
467 |
+
" self.max_new_tokens = max_new_tokens\n",
|
468 |
+
" self.device = device\n",
|
469 |
+
" model_name = get_model_name_from_path(model_path)\n",
|
470 |
+
" self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, model_name, device_map=device)\n",
|
471 |
+
" self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n",
|
472 |
+
" \n",
|
473 |
+
" \n",
|
474 |
+
" def __call__(self, text, audio, image):\n",
|
475 |
+
" qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n",
|
476 |
+
" conv = conv_templates[\"default\"].copy()\n",
|
477 |
+
" conv.append_message(conv.roles[0], qs)\n",
|
478 |
+
" conv.append_message(conv.roles[1], None)\n",
|
479 |
+
" prompt = conv.get_prompt()\n",
|
480 |
+
"\n",
|
481 |
+
" image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
|
482 |
+
" \n",
|
483 |
+
" input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
|
484 |
+
"\n",
|
485 |
+
" audio_embed = self.whisper_w_proj(audio)[\"input_ids\"]\n",
|
486 |
+
" \n",
|
487 |
+
" stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
|
488 |
+
"\n",
|
489 |
+
" input_ids = torch.concat([input_ids, audio_embed], dim=1)\n",
|
490 |
+
"\n",
|
491 |
+
" with torch.inference_mode():\n",
|
492 |
+
" output_ids = self.model.generate(\n",
|
493 |
+
" input_ids,\n",
|
494 |
+
" images=image_tensor,\n",
|
495 |
+
" do_sample=True,\n",
|
496 |
+
" temperature=self.temperature,\n",
|
497 |
+
" max_new_tokens=self.max_new_tokens,\n",
|
498 |
+
" eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
|
499 |
+
" pad_token_id=tokenizer.eos_token_id, # Pad token\n",
|
500 |
+
" use_cache=True,\n",
|
501 |
+
" )\n",
|
502 |
+
"\n",
|
503 |
+
" input_token_len = input_ids.shape[1]\n",
|
504 |
+
" n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
|
505 |
+
" if n_diff_input_output > 0:\n",
|
506 |
+
" print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
|
507 |
+
" outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
|
508 |
+
" outputs = outputs.strip()\n",
|
509 |
+
" if outputs.endswith(stop_str):\n",
|
510 |
+
" outputs = outputs[:-len(stop_str)]\n",
|
511 |
+
" outputs = outputs.strip()\n",
|
512 |
+
" return outputs"
|
513 |
+
]
|
514 |
+
},
|
515 |
+
{
|
516 |
+
"cell_type": "code",
|
517 |
+
"execution_count": 54,
|
518 |
+
"id": "4efdbad4-d88a-4477-a3a0-f5591cd0b172",
|
519 |
+
"metadata": {},
|
520 |
+
"outputs": [
|
521 |
+
{
|
522 |
+
"name": "stderr",
|
523 |
+
"output_type": "stream",
|
524 |
+
"text": [
|
525 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
|
526 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
527 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
528 |
+
]
|
529 |
+
},
|
530 |
+
{
|
531 |
+
"name": "stdout",
|
532 |
+
"output_type": "stream",
|
533 |
+
"text": [
|
534 |
+
"load llaVA-Phi MLLM!!!\n"
|
535 |
+
]
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"data": {
|
539 |
+
"application/vnd.jupyter.widget-view+json": {
|
540 |
+
"model_id": "492c17cf54f34d4d9e4f288fc9e72e79",
|
541 |
+
"version_major": 2,
|
542 |
+
"version_minor": 0
|
543 |
+
},
|
544 |
+
"text/plain": [
|
545 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
546 |
+
]
|
547 |
+
},
|
548 |
+
"metadata": {},
|
549 |
+
"output_type": "display_data"
|
550 |
+
},
|
551 |
+
{
|
552 |
+
"name": "stdout",
|
553 |
+
"output_type": "stream",
|
554 |
+
"text": [
|
555 |
+
"{'device_map': 'cuda'}\n"
|
556 |
+
]
|
557 |
+
},
|
558 |
+
{
|
559 |
+
"name": "stderr",
|
560 |
+
"output_type": "stream",
|
561 |
+
"text": [
|
562 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
563 |
+
]
|
564 |
+
}
|
565 |
+
],
|
566 |
+
"source": [
|
567 |
+
"multimodal_phi2 = MultiModalPhi2()"
|
568 |
+
]
|
569 |
+
},
|
570 |
+
{
|
571 |
+
"cell_type": "code",
|
572 |
+
"execution_count": 57,
|
573 |
+
"id": "9a6de0b0-a231-4d50-88e8-e40c6f7216c3",
|
574 |
+
"metadata": {},
|
575 |
+
"outputs": [],
|
576 |
+
"source": [
|
577 |
+
"text = \"tell me about the audio\""
|
578 |
+
]
|
579 |
+
},
|
580 |
+
{
|
581 |
+
"cell_type": "code",
|
582 |
+
"execution_count": 58,
|
583 |
+
"id": "b4919948-6a75-4d19-ba95-9ba233a7d3d9",
|
584 |
+
"metadata": {},
|
585 |
+
"outputs": [
|
586 |
+
{
|
587 |
+
"data": {
|
588 |
+
"text/plain": [
|
589 |
+
"'is a popular Japanese drama series featuring a man in a red and white costume, who is dressed as Santa Claus, is walking down the street. The scene takes place in a busy city environment, with people walking and standing on the sidewalk, likely enjoying the festive atmosphere and the festive atmosphere.'"
|
590 |
+
]
|
591 |
+
},
|
592 |
+
"execution_count": 58,
|
593 |
+
"metadata": {},
|
594 |
+
"output_type": "execute_result"
|
595 |
+
}
|
596 |
+
],
|
597 |
+
"source": [
|
598 |
+
"multimodal_phi2(text, audio, image)"
|
599 |
+
]
|
600 |
+
},
|
601 |
+
{
|
602 |
+
"cell_type": "code",
|
603 |
+
"execution_count": null,
|
604 |
+
"id": "590f2d64-62ed-4e6f-b7c8-b0cf68aecaab",
|
605 |
+
"metadata": {},
|
606 |
+
"outputs": [],
|
607 |
+
"source": []
|
608 |
+
},
|
609 |
+
{
|
610 |
+
"cell_type": "code",
|
611 |
+
"execution_count": 64,
|
612 |
+
"id": "c921eb63-feb5-4fa9-993b-2faeb6dfe1db",
|
613 |
+
"metadata": {},
|
614 |
+
"outputs": [],
|
615 |
+
"source": [
|
616 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor"
|
617 |
+
]
|
618 |
+
},
|
619 |
+
{
|
620 |
+
"cell_type": "code",
|
621 |
+
"execution_count": 65,
|
622 |
+
"id": "b470a2c4-806a-435d-9fc2-f17448dbe5fc",
|
623 |
+
"metadata": {},
|
624 |
+
"outputs": [],
|
625 |
+
"source": [
|
626 |
+
"from llava_phi.model import LlavaPhiConfig"
|
627 |
+
]
|
628 |
+
},
|
629 |
+
{
|
630 |
+
"cell_type": "code",
|
631 |
+
"execution_count": 66,
|
632 |
+
"id": "4f7bc91a-0a41-45e5-92a4-daa1e3eea0da",
|
633 |
+
"metadata": {},
|
634 |
+
"outputs": [
|
635 |
+
{
|
636 |
+
"name": "stderr",
|
637 |
+
"output_type": "stream",
|
638 |
+
"text": [
|
639 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
|
640 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
641 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
642 |
+
]
|
643 |
+
},
|
644 |
+
{
|
645 |
+
"data": {
|
646 |
+
"application/vnd.jupyter.widget-view+json": {
|
647 |
+
"model_id": "993bc3a38cb84de4a2e3a79a3448c4d6",
|
648 |
+
"version_major": 2,
|
649 |
+
"version_minor": 0
|
650 |
+
},
|
651 |
+
"text/plain": [
|
652 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
653 |
+
]
|
654 |
+
},
|
655 |
+
"metadata": {},
|
656 |
+
"output_type": "display_data"
|
657 |
+
}
|
658 |
+
],
|
659 |
+
"source": [
|
660 |
+
"device_map = \"cuda:0\"\n",
|
661 |
+
"load_8bit = False\n",
|
662 |
+
"load_4bit = False\n",
|
663 |
+
"kwargs = {\"device_map\": device_map}\n",
|
664 |
+
"if load_8bit:\n",
|
665 |
+
" kwargs['load_in_8bit'] = True\n",
|
666 |
+
"elif load_4bit:\n",
|
667 |
+
" kwargs['load_in_4bit'] = True\n",
|
668 |
+
" kwargs['quantization_config'] = BitsAndBytesConfig(\n",
|
669 |
+
" load_in_4bit=True,\n",
|
670 |
+
" bnb_4bit_compute_dtype=torch.float16,\n",
|
671 |
+
" bnb_4bit_use_double_quant=True,\n",
|
672 |
+
" bnb_4bit_quant_type='nf4'\n",
|
673 |
+
" )\n",
|
674 |
+
"config = LlavaPhiConfig.from_pretrained(model_path, trust_remote_code=True)\n",
|
675 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n",
|
676 |
+
"model = LlavaPhiForCausalLM.from_pretrained(\n",
|
677 |
+
" model_path, \n",
|
678 |
+
" config=config, \n",
|
679 |
+
" use_safetensors=True, \n",
|
680 |
+
" **kwargs).to(\"cuda\")\n",
|
681 |
+
"image_processor = CLIPImageProcessor.from_pretrained(model_path)\n",
|
682 |
+
"mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
|
683 |
+
"mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n",
|
684 |
+
"\n",
|
685 |
+
"# TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200\n",
|
686 |
+
"if mm_use_im_patch_token:\n",
|
687 |
+
" tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
|
688 |
+
"if mm_use_im_start_end:\n",
|
689 |
+
" tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
|
690 |
+
" \n",
|
691 |
+
"if hasattr(model.config, \"max_sequence_length\"):\n",
|
692 |
+
" context_len = model.config.max_sequence_length\n",
|
693 |
+
"else:\n",
|
694 |
+
" context_len = 2048"
|
695 |
+
]
|
696 |
+
},
|
697 |
+
{
|
698 |
+
"cell_type": "code",
|
699 |
+
"execution_count": 70,
|
700 |
+
"id": "99355837-a297-4a25-aeb3-1670af7e9251",
|
701 |
+
"metadata": {},
|
702 |
+
"outputs": [
|
703 |
+
{
|
704 |
+
"ename": "KeyboardInterrupt",
|
705 |
+
"evalue": "",
|
706 |
+
"output_type": "error",
|
707 |
+
"traceback": [
|
708 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
709 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
710 |
+
"Cell \u001b[0;32mIn[70], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLlava-Phi-Checkpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
711 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/transformers/modeling_utils.py:2376\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m 2372\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_file, shard \u001b[38;5;129;01min\u001b[39;00m shards\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 2373\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m 2374\u001b[0m \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m 2375\u001b[0m \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 2376\u001b[0m \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2377\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2378\u001b[0m save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n",
|
712 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/safetensors/torch.py:281\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m 251\u001b[0m tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m 252\u001b[0m filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m 253\u001b[0m metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 254\u001b[0m ):\n\u001b[1;32m 255\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 256\u001b[0m \u001b[38;5;124;03m Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;124;03m ```\u001b[39;00m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n",
|
713 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
714 |
+
]
|
715 |
+
}
|
716 |
+
],
|
717 |
+
"source": [
|
718 |
+
"model.save_pretrained(\"Llava-Phi-Checkpoint\")"
|
719 |
+
]
|
720 |
+
},
|
721 |
+
{
|
722 |
+
"cell_type": "code",
|
723 |
+
"execution_count": null,
|
724 |
+
"id": "fa0bec34-a148-4340-a30c-6f09dd5e71ca",
|
725 |
+
"metadata": {},
|
726 |
+
"outputs": [],
|
727 |
+
"source": [
|
728 |
+
"model.push_to_hub(\"RaviNaik/Llava-Phi2\")"
|
729 |
+
]
|
730 |
+
},
|
731 |
+
{
|
732 |
+
"cell_type": "code",
|
733 |
+
"execution_count": 73,
|
734 |
+
"id": "382f74b0-2967-408a-badc-a90918810d74",
|
735 |
+
"metadata": {},
|
736 |
+
"outputs": [
|
737 |
+
{
|
738 |
+
"data": {
|
739 |
+
"text/plain": [
|
740 |
+
"CommitInfo(commit_url='https://huggingface.co/RaviNaik/Llava-Phi2/commit/fa8f7240058241243f6bdc3d6ab44bb691f76e39', commit_message='Upload tokenizer', commit_description='', oid='fa8f7240058241243f6bdc3d6ab44bb691f76e39', pr_url=None, pr_revision=None, pr_num=None)"
|
741 |
+
]
|
742 |
+
},
|
743 |
+
"execution_count": 73,
|
744 |
+
"metadata": {},
|
745 |
+
"output_type": "execute_result"
|
746 |
+
}
|
747 |
+
],
|
748 |
+
"source": [
|
749 |
+
"tokenizer.push_to_hub(\"RaviNaik/Llava-Phi2\")"
|
750 |
+
]
|
751 |
+
},
|
752 |
+
{
|
753 |
+
"cell_type": "code",
|
754 |
+
"execution_count": null,
|
755 |
+
"id": "b851459b-d3ac-4fb8-99b6-17a648adc41f",
|
756 |
+
"metadata": {},
|
757 |
+
"outputs": [],
|
758 |
+
"source": []
|
759 |
+
}
|
760 |
+
],
|
761 |
+
"metadata": {
|
762 |
+
"kernelspec": {
|
763 |
+
"display_name": "Python 3 (ipykernel)",
|
764 |
+
"language": "python",
|
765 |
+
"name": "python3"
|
766 |
+
},
|
767 |
+
"language_info": {
|
768 |
+
"codemirror_mode": {
|
769 |
+
"name": "ipython",
|
770 |
+
"version": 3
|
771 |
+
},
|
772 |
+
"file_extension": ".py",
|
773 |
+
"mimetype": "text/x-python",
|
774 |
+
"name": "python",
|
775 |
+
"nbconvert_exporter": "python",
|
776 |
+
"pygments_lexer": "ipython3",
|
777 |
+
"version": "3.10.12"
|
778 |
+
}
|
779 |
+
},
|
780 |
+
"nbformat": 4,
|
781 |
+
"nbformat_minor": 5
|
782 |
+
}
|
Experiments/instruct_150k_data.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Experiments/instruct_data.py
ADDED
@@ -0,0 +1,39 @@
1 |
+
from datasets import Dataset, IterableDataset
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
# ChatML format
|
5 |
+
templates = {
|
6 |
+
"assistant": "<|im_start|>assistant\n{msg}<|im_end|>", # message by assistant
|
7 |
+
"user": "<|im_start|>user\n{msg}<|im_end|>" # message by user
|
8 |
+
}
|
9 |
+
|
10 |
+
ds = Dataset.from_json("llava_instruct_150k.json", split="train")
|
11 |
+
ds_stream = ds.to_iterable_dataset()
|
12 |
+
|
13 |
+
|
14 |
+
def get_image(image_path):
|
15 |
+
image_path = f"train2014/COCO_train2014_{image_path}"
|
16 |
+
img = Image.open(image_path)
|
17 |
+
return img
|
18 |
+
|
19 |
+
def get_chatml_text(conversations):
|
20 |
+
chatml_text = ""
|
21 |
+
for conversation in conversations:
|
22 |
+
role = conversation["from"]
|
23 |
+
role = "user" if role == "human" else "assistant"
|
24 |
+
content = conversation["value"]
|
25 |
+
|
26 |
+
formatted_text = templates[role].format(msg=content)
|
27 |
+
chatml_text += formatted_text + "\n"
|
28 |
+
return chatml_text
|
29 |
+
|
30 |
+
def instruct_data_generator():
|
31 |
+
for sample in ds_stream:
|
32 |
+
image_path = sample["image"]
|
33 |
+
conversations = sample["conversations"]
|
34 |
+
|
35 |
+
image = get_image(image_path)
|
36 |
+
text = get_chatml_text(conversations)
|
37 |
+
yield {"text": text, "image": image}
|
38 |
+
|
39 |
+
instruct_ds = IterableDataset.from_generator(generator=instruct_data_generator)
|
Experiments/llava_exp.ipynb
ADDED
@@ -0,0 +1,145 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "99576983-f881-47c8-8b5e-c6f561a93e71",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import transformers"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 2,
|
16 |
+
"id": "58ba19f2-4b91-4f90-a33d-4c1ed17e202a",
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, PhiConfig\n",
|
21 |
+
"\n",
|
22 |
+
"# Initializing a CLIP-vision config\n",
|
23 |
+
"vision_config = CLIPVisionConfig()\n",
|
24 |
+
"\n",
|
25 |
+
"# Initializing a Llama config\n",
|
26 |
+
"text_config = PhiConfig()\n",
|
27 |
+
"\n",
|
28 |
+
"# Initializing a Llava llava-1.5-7b style configuration\n",
|
29 |
+
"configuration = LlavaConfig(vision_config, text_config)\n",
|
30 |
+
"\n",
|
31 |
+
"# Initializing a model from the llava-1.5-7b style configuration\n",
|
32 |
+
"model = LlavaForConditionalGeneration(configuration)\n",
|
33 |
+
"\n",
|
34 |
+
"# Accessing the model configuration\n",
|
35 |
+
"configuration = model.config"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": 5,
|
41 |
+
"id": "a806a07a-fe72-45a3-8ceb-8e942c6c845d",
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [
|
44 |
+
{
|
45 |
+
"data": {
|
46 |
+
"text/plain": [
|
47 |
+
"LlavaConfig {\n",
|
48 |
+
" \"ignore_index\": -100,\n",
|
49 |
+
" \"image_token_index\": 32000,\n",
|
50 |
+
" \"model_type\": \"llava\",\n",
|
51 |
+
" \"projector_hidden_act\": \"gelu\",\n",
|
52 |
+
" \"text_config\": {\n",
|
53 |
+
" \"embd_pdrop\": 0.0,\n",
|
54 |
+
" \"hidden_act\": \"gelu_new\",\n",
|
55 |
+
" \"hidden_size\": 2048,\n",
|
56 |
+
" \"intermediate_size\": 8192,\n",
|
57 |
+
" \"layer_norm_eps\": 1e-05,\n",
|
58 |
+
" \"model_type\": \"phi\",\n",
|
59 |
+
" \"num_hidden_layers\": 24,\n",
|
60 |
+
" \"partial_rotary_factor\": 0.5,\n",
|
61 |
+
" \"qk_layernorm\": false,\n",
|
62 |
+
" \"resid_pdrop\": 0.0,\n",
|
63 |
+
" \"vocab_size\": 51200\n",
|
64 |
+
" },\n",
|
65 |
+
" \"transformers_version\": \"4.36.2\",\n",
|
66 |
+
" \"vision_config\": {\n",
|
67 |
+
" \"hidden_size\": 768,\n",
|
68 |
+
" \"image_size\": 224,\n",
|
69 |
+
" \"intermediate_size\": 3072,\n",
|
70 |
+
" \"model_type\": \"clip_vision_model\",\n",
|
71 |
+
" \"num_attention_heads\": 12,\n",
|
72 |
+
" \"num_hidden_layers\": 12,\n",
|
73 |
+
" \"patch_size\": 32,\n",
|
74 |
+
" \"projection_dim\": 512\n",
|
75 |
+
" },\n",
|
76 |
+
" \"vision_feature_layer\": -2,\n",
|
77 |
+
" \"vision_feature_select_strategy\": \"default\",\n",
|
78 |
+
" \"vocab_size\": 32000\n",
|
79 |
+
"}"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
"execution_count": 5,
|
83 |
+
"metadata": {},
|
84 |
+
"output_type": "execute_result"
|
85 |
+
}
|
86 |
+
],
|
87 |
+
"source": [
|
88 |
+
"model.config"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 6,
|
94 |
+
"id": "79efbc6b-f005-4a5c-82a1-112fa37f1904",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [
|
97 |
+
{
|
98 |
+
"name": "stdout",
|
99 |
+
"output_type": "stream",
|
100 |
+
"text": [
|
101 |
+
"Cloning into 'llava-phi'...\n",
|
102 |
+
"remote: Enumerating objects: 151, done.\u001b[K\n",
|
103 |
+
"remote: Counting objects: 100% (151/151), done.\u001b[K\n",
|
104 |
+
"remote: Compressing objects: 100% (116/116), done.\u001b[K\n",
|
105 |
+
"remote: Total 151 (delta 36), reused 133 (delta 25), pack-reused 0\u001b[K\n",
|
106 |
+
"Receiving objects: 100% (151/151), 333.89 KiB | 112.00 KiB/s, done.\n",
|
107 |
+
"Resolving deltas: 100% (36/36), done.\n"
|
108 |
+
]
|
109 |
+
}
|
110 |
+
],
|
111 |
+
"source": [
|
112 |
+
"!git clone https://github.com/zhuyiche/llava-phi.git"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"id": "cf827184-f334-4d86-ace1-fe9c92f84d66",
|
119 |
+
"metadata": {},
|
120 |
+
"outputs": [],
|
121 |
+
"source": []
|
122 |
+
}
|
123 |
+
],
|
124 |
+
"metadata": {
|
125 |
+
"kernelspec": {
|
126 |
+
"display_name": "Python 3 (ipykernel)",
|
127 |
+
"language": "python",
|
128 |
+
"name": "python3"
|
129 |
+
},
|
130 |
+
"language_info": {
|
131 |
+
"codemirror_mode": {
|
132 |
+
"name": "ipython",
|
133 |
+
"version": 3
|
134 |
+
},
|
135 |
+
"file_extension": ".py",
|
136 |
+
"mimetype": "text/x-python",
|
137 |
+
"name": "python",
|
138 |
+
"nbconvert_exporter": "python",
|
139 |
+
"pygments_lexer": "ipython3",
|
140 |
+
"version": "3.10.12"
|
141 |
+
}
|
142 |
+
},
|
143 |
+
"nbformat": 4,
|
144 |
+
"nbformat_minor": 5
|
145 |
+
}
|
Experiments/multimodal_exp.ipynb
ADDED
@@ -0,0 +1,362 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 23,
|
6 |
+
"id": "d4bed9ef-4bff-4d61-a4f9-a585f377f136",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from PIL import Image\n",
|
11 |
+
"import requests\n",
|
12 |
+
"\n",
|
13 |
+
"import torch\n",
|
14 |
+
"from torch import nn\n",
|
15 |
+
"from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
|
16 |
+
"from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
|
17 |
+
"from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
|
18 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
|
19 |
+
"from typing import Optional, Union, Tuple"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 43,
|
25 |
+
"id": "952314f0-ee9d-45e7-85b8-1e3e44c1a2fd",
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"class VisionLanguageConnector(nn.Module):\n",
|
30 |
+
" def __init__(self, hidden_size, projection_dim):\n",
|
31 |
+
" super().__init__()\n",
|
32 |
+
" self.mlp = nn.Sequential(\n",
|
33 |
+
" nn.Linear(hidden_size, hidden_size, bias=False),\n",
|
34 |
+
" nn.GELU(),\n",
|
35 |
+
" nn.Linear(hidden_size, projection_dim, bias=False)\n",
|
36 |
+
" )\n",
|
37 |
+
"\n",
|
38 |
+
" def forward(self, x):\n",
|
39 |
+
" return self.mlp(x)\n",
|
40 |
+
" \n",
|
41 |
+
"class ClipWithProjection():\n",
|
42 |
+
" config_class = CLIPVisionConfig\n",
|
43 |
+
" main_input_name = \"pixel_values\"\n",
|
44 |
+
"\n",
|
45 |
+
" def __init__(self, hidden_size, projection_dim):\n",
|
46 |
+
" super().__init__()\n",
|
47 |
+
" \n",
|
48 |
+
" self.processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
49 |
+
" self.vision_model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
50 |
+
" self.vision_language_connector = VisionLanguageConnector(hidden_size, projection_dim)\n",
|
51 |
+
"\n",
|
52 |
+
" def forward(\n",
|
53 |
+
" self,\n",
|
54 |
+
" image = None,\n",
|
55 |
+
" output_attentions: Optional[bool] = None,\n",
|
56 |
+
" output_hidden_states: Optional[bool] = None,\n",
|
57 |
+
" return_dict: Optional[bool] = None,\n",
|
58 |
+
" ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
|
59 |
+
" \n",
|
60 |
+
" pixel_values = self.processor(images=image, return_tensors=\"pt\")[\"pixel_values\"]\n",
|
61 |
+
" vision_outputs = self.vision_model(\n",
|
62 |
+
" pixel_values=pixel_values,\n",
|
63 |
+
" output_attentions=output_attentions,\n",
|
64 |
+
" output_hidden_states=output_hidden_states,\n",
|
65 |
+
" return_dict=return_dict,\n",
|
66 |
+
" )\n",
|
67 |
+
"\n",
|
68 |
+
" pooled_output = vision_outputs[1] # pooled_output\n",
|
69 |
+
"\n",
|
70 |
+
" image_embeds = self.vision_language_connector(pooled_output)\n",
|
71 |
+
"\n",
|
72 |
+
" return CLIPVisionModelOutput(\n",
|
73 |
+
" image_embeds=image_embeds,\n",
|
74 |
+
" last_hidden_state=vision_outputs.last_hidden_state,\n",
|
75 |
+
" hidden_states=vision_outputs.hidden_states,\n",
|
76 |
+
" attentions=vision_outputs.attentions,\n",
|
77 |
+
" )"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": 44,
|
83 |
+
"id": "bd2889fe-be85-44a3-afe8-65b47f7a93c3",
|
84 |
+
"metadata": {},
|
85 |
+
"outputs": [],
|
86 |
+
"source": [
|
87 |
+
"url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
|
88 |
+
"image = Image.open(requests.get(url, stream=True).raw)"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 46,
|
94 |
+
"id": "17c72699-fe98-4b96-b63c-5c8ab7c1a65f",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"# model = ClipWithProjection(768, 512)\n",
|
99 |
+
"# model.forward(image)"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"cell_type": "code",
|
104 |
+
"execution_count": 47,
|
105 |
+
"id": "70806156-38a9-45a2-bf9f-e72047a0173f",
|
106 |
+
"metadata": {},
|
107 |
+
"outputs": [],
|
108 |
+
"source": [
|
109 |
+
"class AudioLanguageConnector:\n",
|
110 |
+
" def __init__(self, projection_dim):\n",
|
111 |
+
" model_name = \"microsoft/phi-2\"\n",
|
112 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
113 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
114 |
+
" self.phi2_tokenizer.max_length = projection_dim\n",
|
115 |
+
"\n",
|
116 |
+
" def __call__(self, text):\n",
|
117 |
+
" text = f\"<audio_start> {text} <audio_end>\"\n",
|
118 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
119 |
+
" return tokens\n",
|
120 |
+
" \n",
|
121 |
+
"\n",
|
122 |
+
"class WhisperWithProjection:\n",
|
123 |
+
" def __init__(self, projection_dim):\n",
|
124 |
+
" self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
|
125 |
+
" self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
|
126 |
+
" self.model.config.forced_decoder_ids = None\n",
|
127 |
+
" self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
|
128 |
+
" \n",
|
129 |
+
" def forward(self, audio):\n",
|
130 |
+
" input_features = self.processor(audio[\"array\"],\n",
|
131 |
+
" sampling_rate=audio[\"sampling_rate\"],\n",
|
132 |
+
" return_tensors=\"pt\").input_features\n",
|
133 |
+
" # generate token ids\n",
|
134 |
+
" predicted_ids = self.model.generate(input_features)\n",
|
135 |
+
" # decode token ids to text \n",
|
136 |
+
" transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
|
137 |
+
"\n",
|
138 |
+
" audio_embeddings = self.audio_language_connector(transcription)\n",
|
139 |
+
" return audio_embeddings"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 48,
|
145 |
+
"id": "79cc4d98-498b-4042-bd71-143b2477733d",
|
146 |
+
"metadata": {},
|
147 |
+
"outputs": [],
|
148 |
+
"source": [
|
149 |
+
"class TextModality:\n",
|
150 |
+
" def __init__(self, projection_dim):\n",
|
151 |
+
" model_name = \"microsoft/phi-2\"\n",
|
152 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
153 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
154 |
+
" self.phi2_tokenizer.max_length = projection_dim\n",
|
155 |
+
"\n",
|
156 |
+
"\n",
|
157 |
+
" def __call__(self, text):\n",
|
158 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
159 |
+
" return tokens"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": 77,
|
165 |
+
"id": "ba4c4772-923f-48e8-a4af-b7d9c192dd4b",
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": [
|
169 |
+
"class MultiModalPhi2:\n",
|
170 |
+
" def __init__(self):\n",
|
171 |
+
" self.text_modality = TextModality(projection_dim=768)\n",
|
172 |
+
" self.whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
|
173 |
+
" self.clip_w_proj = ClipWithProjection(hidden_size=768, projection_dim=768)\n",
|
174 |
+
" self.llm = self.load_llm()\n",
|
175 |
+
"\n",
|
176 |
+
" def load_llm(self):\n",
|
177 |
+
" model_name = \"microsoft/phi-2\"\n",
|
178 |
+
" \n",
|
179 |
+
" bnb_config = BitsAndBytesConfig(\n",
|
180 |
+
" load_in_4bit=True,\n",
|
181 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
182 |
+
" bnb_4bit_compute_dtype=torch.float16)\n",
|
183 |
+
" \n",
|
184 |
+
" model = AutoModelForCausalLM.from_pretrained(\n",
|
185 |
+
" model_name,\n",
|
186 |
+
" quantization_config=bnb_config,\n",
|
187 |
+
" trust_remote_code=True,\n",
|
188 |
+
" device_map=\"cuda:0\"\n",
|
189 |
+
" )\n",
|
190 |
+
" model.config.use_cache = False\n",
|
191 |
+
" return model\n",
|
192 |
+
"\n",
|
193 |
+
" def forward(self, audio, image, text):\n",
|
194 |
+
" if text is not None:\n",
|
195 |
+
" text_embed = self.text_modality(text)[\"input_ids\"]\n",
|
196 |
+
" if audio is not None:\n",
|
197 |
+
" audio_embed = self.whisper_w_proj.forward(audio)[\"input_ids\"]\n",
|
198 |
+
" if image is not None:\n",
|
199 |
+
" image_embed = self.clip_w_proj.forward(image)[0]\n",
|
200 |
+
" print(text_embed.shape, text_embed.dtype)\n",
|
201 |
+
" print(audio_embed.shape, audio_embed.dtype)\n",
|
202 |
+
" print(image_embed.shape, image_embed.dtype)\n",
|
203 |
+
" \n",
|
204 |
+
" inputs = torch.concat([text_embed, audio_embed, image_embed], dim=1)\n",
|
205 |
+
" print(inputs.shape, inputs.dtype)\n",
|
206 |
+
" outputs = self.llm(inputs)\n",
|
207 |
+
"\n",
|
208 |
+
" return outputs \n",
|
209 |
+
" \n",
|
210 |
+
"\n",
|
211 |
+
" def generate(self, audio, text):\n",
|
212 |
+
" text_embeddings = self.text_modality(text)\n",
|
213 |
+
" audio_embeddings = self.whisper_w_proj.forward(audio)\n",
|
214 |
+
" inputs = torch.concat([text_embed[\"input_ids\"], audio_embed[\"input_ids\"]], dim=1)\n",
|
215 |
+
" \n",
|
216 |
+
" outputs = self.llm.generate(inputs, max_length=200)\n",
|
217 |
+
" text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
|
218 |
+
" print(text)"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": 74,
|
224 |
+
"id": "7ca694eb-8009-4eb9-9a4c-eac406ab9584",
|
225 |
+
"metadata": {},
|
226 |
+
"outputs": [],
|
227 |
+
"source": [
|
228 |
+
"from datasets import load_dataset\n",
|
229 |
+
"audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
|
230 |
+
"audio = audio_ds[0][\"audio\"]"
|
231 |
+
]
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"cell_type": "code",
|
235 |
+
"execution_count": 58,
|
236 |
+
"id": "37be28c5-4cc3-4471-b394-032c7602accc",
|
237 |
+
"metadata": {},
|
238 |
+
"outputs": [],
|
239 |
+
"source": [
|
240 |
+
"text = \"explain about the audio\""
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"execution_count": 59,
|
246 |
+
"id": "c0705114-1670-4937-bc3e-3660e5a5d2c5",
|
247 |
+
"metadata": {},
|
248 |
+
"outputs": [],
|
249 |
+
"source": [
|
250 |
+
"# image"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": 78,
|
256 |
+
"id": "0d7e5b49-b4bd-477c-87b8-91ef70857677",
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [
|
259 |
+
{
|
260 |
+
"name": "stderr",
|
261 |
+
"output_type": "stream",
|
262 |
+
"text": [
|
263 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
264 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"data": {
|
269 |
+
"application/vnd.jupyter.widget-view+json": {
|
270 |
+
"model_id": "733dc7b2208b4853a89aea49bff9a55c",
|
271 |
+
"version_major": 2,
|
272 |
+
"version_minor": 0
|
273 |
+
},
|
274 |
+
"text/plain": [
|
275 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
"metadata": {},
|
279 |
+
"output_type": "display_data"
|
280 |
+
}
|
281 |
+
],
|
282 |
+
"source": [
|
283 |
+
"model = MultiModalPhi2()"
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": 79,
|
289 |
+
"id": "0b6471c4-4553-47f3-b38f-46057dcf80f2",
|
290 |
+
"metadata": {},
|
291 |
+
"outputs": [
|
292 |
+
{
|
293 |
+
"name": "stdout",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"torch.Size([1, 5]) torch.int64\n",
|
297 |
+
"torch.Size([1, 33]) torch.int64\n",
|
298 |
+
"torch.Size([1, 768]) torch.float32\n",
|
299 |
+
"torch.Size([1, 806]) torch.float32\n"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"ename": "RuntimeError",
|
304 |
+
"evalue": "Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)",
|
305 |
+
"output_type": "error",
|
306 |
+
"traceback": [
|
307 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
308 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
309 |
+
"Cell \u001b[0;32mIn[79], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43maudio\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
|
310 |
+
"Cell \u001b[0;32mIn[77], line 38\u001b[0m, in \u001b[0;36mMultiModalPhi2.forward\u001b[0;34m(self, audio, image, text)\u001b[0m\n\u001b[1;32m 36\u001b[0m inputs \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mconcat([text_embed, audio_embed, image_embed], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mprint\u001b[39m(inputs\u001b[38;5;241m.\u001b[39mshape, inputs\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[0;32m---> 38\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mllm\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
|
311 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
312 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
313 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
|
314 |
+
"File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:1049\u001b[0m, in \u001b[0;36mPhiForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1046\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1048\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1049\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1055\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1056\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1057\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1058\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1059\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1061\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1062\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
|
315 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
316 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
317 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
|
318 |
+
"File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:893\u001b[0m, in \u001b[0;36mPhiModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 890\u001b[0m position_ids \u001b[38;5;241m=\u001b[39m position_ids\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 892\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 893\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 895\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_dropout(inputs_embeds)\n\u001b[1;32m 897\u001b[0m \u001b[38;5;66;03m# Attention mask.\u001b[39;00m\n",
|
319 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
320 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
321 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
|
322 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162\u001b[0m, in \u001b[0;36mEmbedding.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 162\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
|
323 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/functional.py:2233\u001b[0m, in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2227\u001b[0m \u001b[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001b[39;00m\n\u001b[1;32m 2228\u001b[0m \u001b[38;5;66;03m# XXX: equivalent to\u001b[39;00m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# torch.embedding_renorm_\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m \u001b[38;5;66;03m# remove once script supports set_grad_enabled\u001b[39;00m\n\u001b[1;32m 2232\u001b[0m _no_grad_embedding_renorm_(weight, \u001b[38;5;28minput\u001b[39m, max_norm, norm_type)\n\u001b[0;32m-> 2233\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
|
324 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)"
|
325 |
+
]
|
326 |
+
}
|
327 |
+
],
|
328 |
+
"source": [
|
329 |
+
"model.forward(audio, image, text)"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
{
|
333 |
+
"cell_type": "code",
|
334 |
+
"execution_count": null,
|
335 |
+
"id": "4ca96caf-82e2-4f07-87b3-8654dfdc89aa",
|
336 |
+
"metadata": {},
|
337 |
+
"outputs": [],
|
338 |
+
"source": []
|
339 |
+
}
|
340 |
+
],
|
341 |
+
"metadata": {
|
342 |
+
"kernelspec": {
|
343 |
+
"display_name": "Python 3 (ipykernel)",
|
344 |
+
"language": "python",
|
345 |
+
"name": "python3"
|
346 |
+
},
|
347 |
+
"language_info": {
|
348 |
+
"codemirror_mode": {
|
349 |
+
"name": "ipython",
|
350 |
+
"version": 3
|
351 |
+
},
|
352 |
+
"file_extension": ".py",
|
353 |
+
"mimetype": "text/x-python",
|
354 |
+
"name": "python",
|
355 |
+
"nbconvert_exporter": "python",
|
356 |
+
"pygments_lexer": "ipython3",
|
357 |
+
"version": "3.10.12"
|
358 |
+
}
|
359 |
+
},
|
360 |
+
"nbformat": 4,
|
361 |
+
"nbformat_minor": 5
|
362 |
+
}
|
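Note on the `RuntimeError` recorded in the last executed cell of `clip_expt.ipynb`: `forward()` concatenates integer token ids (text, audio) with float image embeddings and passes the result to phi-2 as `input_ids`, but the model's embedding layer only accepts integer indices. A minimal sketch of one workaround, not part of the committed code: look the ids up in the LLM's own embedding table and pass `inputs_embeds` instead. The helper name is hypothetical, and it assumes the image projection already emits vectors of the LLM hidden size (2560 for phi-2) shaped `(batch, num_image_tokens, hidden)`.

```python
import torch

def build_multimodal_inputs(llm, text_ids, audio_ids, image_embeds):
    # Map integer token ids into the LLM's float embedding space so all three
    # modalities can be concatenated along the sequence dimension.
    embed_tokens = llm.get_input_embeddings()     # phi-2's token embedding layer
    text_embeds = embed_tokens(text_ids)          # (batch, T, hidden)
    audio_embeds = embed_tokens(audio_ids)        # (batch, A, hidden)
    # Assumption: image_embeds is already projected to (batch, N, hidden).
    inputs_embeds = torch.cat([text_embeds, audio_embeds, image_embeds], dim=1)
    # Passing inputs_embeds skips the embedding lookup that raised the error.
    return llm(inputs_embeds=inputs_embeds)
```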
Experiments/pretrain_data_check.ipynb
ADDED
@@ -0,0 +1,304 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 5,
|
6 |
+
"id": "61c272f2-edbe-4b7d-8fec-3ab431400cd3",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import json"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 2,
|
16 |
+
"id": "e9dfd7d7-1685-4fc7-bbb9-3905c32d8ba1",
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"with open(\"metadata.json\", \"rb\") as f:\n",
|
21 |
+
" metadata = json.load(f)"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 4,
|
27 |
+
"id": "70bdba48-db01-42ac-8d89-edc69d7d7672",
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [
|
30 |
+
{
|
31 |
+
"data": {
|
32 |
+
"text/plain": [
|
33 |
+
"595375"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
"execution_count": 4,
|
37 |
+
"metadata": {},
|
38 |
+
"output_type": "execute_result"
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"source": [
|
42 |
+
"len(metadata)"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 14,
|
48 |
+
"id": "59e193cc-0dd8-4f7e-959a-fbad0133d76c",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"with open(\"blip_laion_cc_sbu_558k.jsonblip_laion_cc_sbu_558k.json\", \"rb\") as f:\n",
|
53 |
+
" data = json.load(f)"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 7,
|
59 |
+
"id": "f3157f41-269b-4f7a-b3ba-9be711babe02",
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [
|
62 |
+
{
|
63 |
+
"data": {
|
64 |
+
"text/plain": [
|
65 |
+
"{'id': '004539375',\n",
|
66 |
+
" 'image': '00453/004539375.jpg',\n",
|
67 |
+
" 'conversations': [{'from': 'human',\n",
|
68 |
+
" 'value': 'Render a clear and concise summary of the photo.\\n<image>'},\n",
|
69 |
+
" {'from': 'gpt',\n",
|
70 |
+
" 'value': 'select luxury furniture 3 - inch gel memory foam mattress topper'}]}"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
"execution_count": 7,
|
74 |
+
"metadata": {},
|
75 |
+
"output_type": "execute_result"
|
76 |
+
}
|
77 |
+
],
|
78 |
+
"source": [
|
79 |
+
"data[0]"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": 8,
|
85 |
+
"id": "50d8a051-1526-47dd-ad71-d3c66f7bd34e",
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [
|
88 |
+
{
|
89 |
+
"data": {
|
90 |
+
"text/plain": [
|
91 |
+
"{'id': '004374662',\n",
|
92 |
+
" 'image': '00437/004374662.jpg',\n",
|
93 |
+
" 'conversations': [{'from': 'human',\n",
|
94 |
+
" 'value': 'Give a brief description of the image.\\n<image>'},\n",
|
95 |
+
" {'from': 'gpt', 'value': 'the north face duffel bag camo large'}]}"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
"execution_count": 8,
|
99 |
+
"metadata": {},
|
100 |
+
"output_type": "execute_result"
|
101 |
+
}
|
102 |
+
],
|
103 |
+
"source": [
|
104 |
+
"data[234]"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": 17,
|
110 |
+
"id": "2e6d5664-4583-49a6-93cc-079ee2d1ff6c",
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [
|
113 |
+
{
|
114 |
+
"data": {
|
115 |
+
"text/plain": [
|
116 |
+
"558128"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
"execution_count": 17,
|
120 |
+
"metadata": {},
|
121 |
+
"output_type": "execute_result"
|
122 |
+
}
|
123 |
+
],
|
124 |
+
"source": [
|
125 |
+
"len(data)"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": 10,
|
131 |
+
"id": "11ed106d-6bef-482c-a456-5eaaf2025534",
|
132 |
+
"metadata": {},
|
133 |
+
"outputs": [
|
134 |
+
{
|
135 |
+
"data": {
|
136 |
+
"text/plain": [
|
137 |
+
"{'id': 'GCC_train_001749371',\n",
|
138 |
+
" 'image': 'GCC_train_001749371.jpg',\n",
|
139 |
+
" 'caption': 'if you are dreaming of simpler or off - the - grid living , a yurt is a fantastic option',\n",
|
140 |
+
" 'blip_caption': 'a white and tan yurt sitting on a dirt road',\n",
|
141 |
+
" 'url': 'https://i.pinimg.com/736x/14/7b/64/147b64467ee966d9a578097bb70475ad--yurt-kits-small-space-living.jpg'}"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
"execution_count": 10,
|
145 |
+
"metadata": {},
|
146 |
+
"output_type": "execute_result"
|
147 |
+
}
|
148 |
+
],
|
149 |
+
"source": [
|
150 |
+
"metadata[67]"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 15,
|
156 |
+
"id": "ce8adcec-2499-4be3-be1d-7313fe54e96a",
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [
|
159 |
+
{
|
160 |
+
"data": {
|
161 |
+
"text/plain": [
|
162 |
+
"{'id': '000466761',\n",
|
163 |
+
" 'image': '00046/000466761.jpg',\n",
|
164 |
+
" 'conversations': [{'from': 'human',\n",
|
165 |
+
" 'value': '<image>\\nProvide a brief description of the given image.'},\n",
|
166 |
+
" {'from': 'gpt',\n",
|
167 |
+
" 'value': 'a clipboard and a pen with the words public health emergency next to it on a white table'}]}"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
"execution_count": 15,
|
171 |
+
"metadata": {},
|
172 |
+
"output_type": "execute_result"
|
173 |
+
}
|
174 |
+
],
|
175 |
+
"source": [
|
176 |
+
"data[67]"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 16,
|
182 |
+
"id": "068313b6-6379-4ca2-892c-682634d3581e",
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [
|
185 |
+
{
|
186 |
+
"data": {
|
187 |
+
"text/plain": [
|
188 |
+
"list"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
"execution_count": 16,
|
192 |
+
"metadata": {},
|
193 |
+
"output_type": "execute_result"
|
194 |
+
}
|
195 |
+
],
|
196 |
+
"source": [
|
197 |
+
"type(data)"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 24,
|
203 |
+
"id": "9ec33b51-4a0b-4a1e-81f7-2fda7cddb25f",
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [],
|
206 |
+
"source": [
|
207 |
+
"sample_data = data[:200000]"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": 25,
|
213 |
+
"id": "095685e5-40f1-4d84-8280-ef74fa56c5a2",
|
214 |
+
"metadata": {},
|
215 |
+
"outputs": [
|
216 |
+
{
|
217 |
+
"data": {
|
218 |
+
"text/plain": [
|
219 |
+
"200000"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
"execution_count": 25,
|
223 |
+
"metadata": {},
|
224 |
+
"output_type": "execute_result"
|
225 |
+
}
|
226 |
+
],
|
227 |
+
"source": [
|
228 |
+
"len(sample_data)"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "code",
|
233 |
+
"execution_count": 26,
|
234 |
+
"id": "ffbad552-23fd-475f-8e9a-7118bcc4f51e",
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [],
|
237 |
+
"source": [
|
238 |
+
"with open(\"llava-phi/pretrain_data/blip_sample.json\", \"w\") as f:\n",
|
239 |
+
" json.dump(sample_data, f)"
|
240 |
+
]
|
241 |
+
},
|
242 |
+
{
|
243 |
+
"cell_type": "code",
|
244 |
+
"execution_count": 27,
|
245 |
+
"id": "69a05d25-6f3b-40c0-a3b5-e185ff526471",
|
246 |
+
"metadata": {},
|
247 |
+
"outputs": [],
|
248 |
+
"source": [
|
249 |
+
"with open(\"llava-phi/pretrain_data/blip_sample.json\", \"rb\") as f:\n",
|
250 |
+
" sample = json.load(f)"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": 28,
|
256 |
+
"id": "200eea06-dfd6-4b3a-bb91-82af7d363951",
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [
|
259 |
+
{
|
260 |
+
"data": {
|
261 |
+
"text/plain": [
|
262 |
+
"200000"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
"execution_count": 28,
|
266 |
+
"metadata": {},
|
267 |
+
"output_type": "execute_result"
|
268 |
+
}
|
269 |
+
],
|
270 |
+
"source": [
|
271 |
+
"len(sample)"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "code",
|
276 |
+
"execution_count": null,
|
277 |
+
"id": "f86caa1e-edea-4a9c-934f-5420ede80d0d",
|
278 |
+
"metadata": {},
|
279 |
+
"outputs": [],
|
280 |
+
"source": []
|
281 |
+
}
|
282 |
+
],
|
283 |
+
"metadata": {
|
284 |
+
"kernelspec": {
|
285 |
+
"display_name": "Python 3 (ipykernel)",
|
286 |
+
"language": "python",
|
287 |
+
"name": "python3"
|
288 |
+
},
|
289 |
+
"language_info": {
|
290 |
+
"codemirror_mode": {
|
291 |
+
"name": "ipython",
|
292 |
+
"version": 3
|
293 |
+
},
|
294 |
+
"file_extension": ".py",
|
295 |
+
"mimetype": "text/x-python",
|
296 |
+
"name": "python",
|
297 |
+
"nbconvert_exporter": "python",
|
298 |
+
"pygments_lexer": "ipython3",
|
299 |
+
"version": "3.10.12"
|
300 |
+
}
|
301 |
+
},
|
302 |
+
"nbformat": 4,
|
303 |
+
"nbformat_minor": 5
|
304 |
+
}
|
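`pretrain_data_check.ipynb` builds `blip_sample.json` from the first 200,000 records of `blip_laion_cc_sbu_558k.json`. A small sketch of an alternative, assuming the same file paths: draw a seeded random sample instead of slicing the head, so the subset is not biased by the ordering of the source file. The random-sampling choice is an illustration, not what the notebook does.

```python
import json
import random

random.seed(42)  # reproducible subset

with open("blip_laion_cc_sbu_558k.json", "rb") as f:
    data = json.load(f)

# Draw 200k records uniformly at random rather than taking the first 200k.
sample_data = random.sample(data, k=200_000)

with open("llava-phi/pretrain_data/blip_sample.json", "w") as f:
    json.dump(sample_data, f)
```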
Experiments/whispher_exp.ipynb
ADDED
@@ -0,0 +1,500 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 9,
|
6 |
+
"id": "bb4dd66b-0c17-48d4-9d34-f48cece2feb5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"# !pip install soundfile\n",
|
11 |
+
"# !pip install librosa"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 1,
|
17 |
+
"id": "6e9386ea-4862-4f5b-a02f-d656e1a5ab9e",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
|
22 |
+
"from datasets import load_dataset"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": 2,
|
28 |
+
"id": "914ab2b4-389d-4c48-8d1d-1250356646ac",
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"# load model and processor\n",
|
33 |
+
"processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
|
34 |
+
"model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
|
35 |
+
"model.config.forced_decoder_ids = None\n",
|
36 |
+
"\n",
|
37 |
+
"# load dummy dataset and read audio files\n",
|
38 |
+
"ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
|
39 |
+
"sample = ds[0][\"audio\"]"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": 3,
|
45 |
+
"id": "2b299bab-1228-48d9-a8a5-3d5b6c52162d",
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [
|
48 |
+
{
|
49 |
+
"data": {
|
50 |
+
"text/plain": [
|
51 |
+
"{'path': '/home/ravi.naik/.cache/huggingface/datasets/downloads/extracted/431c2c946d216530b2666a0e7ffa5ac3f5b3da89dd28858a9de6c78fae7caa4a/dev_clean/1272/128104/1272-128104-0000.flac',\n",
|
52 |
+
" 'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,\n",
|
53 |
+
" 0.0010376 ]),\n",
|
54 |
+
" 'sampling_rate': 16000}"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
"execution_count": 3,
|
58 |
+
"metadata": {},
|
59 |
+
"output_type": "execute_result"
|
60 |
+
}
|
61 |
+
],
|
62 |
+
"source": [
|
63 |
+
"sample"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "code",
|
68 |
+
"execution_count": 4,
|
69 |
+
"id": "b7e570a1-cf5c-450c-a7b6-49b45a10d2df",
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"input_features = processor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"], return_tensors=\"pt\").input_features "
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": 5,
|
79 |
+
"id": "584e920b-a7fd-402d-95dd-3b9128cd34bb",
|
80 |
+
"metadata": {},
|
81 |
+
"outputs": [],
|
82 |
+
"source": [
|
83 |
+
"# generate token ids\n",
|
84 |
+
"predicted_ids = model.generate(input_features)\n",
|
85 |
+
"# decode token ids to text\n",
|
86 |
+
"transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
|
87 |
+
"\n",
|
88 |
+
"transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 6,
|
94 |
+
"id": "b27ab660-861b-49d1-81f9-f51cb7f9d8d8",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [
|
97 |
+
{
|
98 |
+
"data": {
|
99 |
+
"text/plain": [
|
100 |
+
"[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
"execution_count": 6,
|
104 |
+
"metadata": {},
|
105 |
+
"output_type": "execute_result"
|
106 |
+
}
|
107 |
+
],
|
108 |
+
"source": [
|
109 |
+
"transcription"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": 3,
|
115 |
+
"id": "eca553b8-68f6-493d-b567-3d526b49ae1b",
|
116 |
+
"metadata": {},
|
117 |
+
"outputs": [],
|
118 |
+
"source": [
|
119 |
+
"import torch\n",
|
120 |
+
"from torch import nn"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 4,
|
126 |
+
"id": "c619a4cf-9068-4e4d-8139-e16d15345f4f",
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": 5,
|
136 |
+
"id": "47d5b1ff-ab0f-4d11-af64-d2fa2be39286",
|
137 |
+
"metadata": {},
|
138 |
+
"outputs": [
|
139 |
+
{
|
140 |
+
"name": "stderr",
|
141 |
+
"output_type": "stream",
|
142 |
+
"text": [
|
143 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
144 |
+
]
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"source": [
|
148 |
+
"model_name = \"microsoft/phi-2\"\n",
|
149 |
+
"phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
150 |
+
"phi2_tokenizer.pad_token = phi2_tokenizer.eos_token"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 6,
|
156 |
+
"id": "0b36b3f0-db5b-4029-9072-0a53bcab315a",
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [
|
159 |
+
{
|
160 |
+
"ename": "NameError",
|
161 |
+
"evalue": "name 'transcription' is not defined",
|
162 |
+
"output_type": "error",
|
163 |
+
"traceback": [
|
164 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
165 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
166 |
+
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m tokens \u001b[38;5;241m=\u001b[39m phi2_tokenizer(\u001b[38;5;241m*\u001b[39m\u001b[43mtranscription\u001b[49m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, return_attention_mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
|
167 |
+
"\u001b[0;31mNameError\u001b[0m: name 'transcription' is not defined"
|
168 |
+
]
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"source": [
|
172 |
+
"tokens = phi2_tokenizer(*transcription, return_tensors=\"pt\", return_attention_mask=False)"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": 22,
|
178 |
+
"id": "91f6d3d3-bb00-434f-a91e-6952375890d0",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [
|
181 |
+
{
|
182 |
+
"data": {
|
183 |
+
"text/plain": [
|
184 |
+
"{'input_ids': tensor([[ 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262,\n",
|
185 |
+
" 3504, 6097, 290, 356, 389, 9675, 284, 7062, 465, 21443,\n",
|
186 |
+
" 13]])}"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
"execution_count": 22,
|
190 |
+
"metadata": {},
|
191 |
+
"output_type": "execute_result"
|
192 |
+
}
|
193 |
+
],
|
194 |
+
"source": [
|
195 |
+
"tokens"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": 12,
|
201 |
+
"id": "533191d9-4b3b-417a-918d-6fe854f24b50",
|
202 |
+
"metadata": {},
|
203 |
+
"outputs": [
|
204 |
+
{
|
205 |
+
"name": "stderr",
|
206 |
+
"output_type": "stream",
|
207 |
+
"text": [
|
208 |
+
"A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
|
209 |
+
"- configuration_phi.py\n",
|
210 |
+
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
|
211 |
+
]
|
212 |
+
},
|
213 |
+
{
|
214 |
+
"data": {
|
215 |
+
"application/vnd.jupyter.widget-view+json": {
|
216 |
+
"model_id": "2a65a119388b4cb4b123b532176e786e",
|
217 |
+
"version_major": 2,
|
218 |
+
"version_minor": 0
|
219 |
+
},
|
220 |
+
"text/plain": [
|
221 |
+
"modeling_phi.py: 0%| | 0.00/62.7k [00:00<?, ?B/s]"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
"metadata": {},
|
225 |
+
"output_type": "display_data"
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"name": "stderr",
|
229 |
+
"output_type": "stream",
|
230 |
+
"text": [
|
231 |
+
"A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
|
232 |
+
"- modeling_phi.py\n",
|
233 |
+
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"data": {
|
238 |
+
"application/vnd.jupyter.widget-view+json": {
|
239 |
+
"model_id": "7183811844304c16b72d53fe11098a74",
|
240 |
+
"version_major": 2,
|
241 |
+
"version_minor": 0
|
242 |
+
},
|
243 |
+
"text/plain": [
|
244 |
+
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
"metadata": {},
|
248 |
+
"output_type": "display_data"
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"data": {
|
252 |
+
"application/vnd.jupyter.widget-view+json": {
|
253 |
+
"model_id": "3e78fe144e8f42139a4d7a1830dbf192",
|
254 |
+
"version_major": 2,
|
255 |
+
"version_minor": 0
|
256 |
+
},
|
257 |
+
"text/plain": [
|
258 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
"metadata": {},
|
262 |
+
"output_type": "display_data"
|
263 |
+
}
|
264 |
+
],
|
265 |
+
"source": [
|
266 |
+
"bnb_config = BitsAndBytesConfig(\n",
|
267 |
+
" load_in_4bit=True,\n",
|
268 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
269 |
+
" bnb_4bit_compute_dtype=torch.float16,\n",
|
270 |
+
")\n",
|
271 |
+
"\n",
|
272 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
273 |
+
" model_name,\n",
|
274 |
+
" quantization_config=bnb_config,\n",
|
275 |
+
" trust_remote_code=True,\n",
|
276 |
+
" device_map=\"cuda:0\"\n",
|
277 |
+
")\n",
|
278 |
+
"model.config.use_cache = False"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 19,
|
284 |
+
"id": "155c054a-a00f-4ed5-bfff-1ad64889e7f1",
|
285 |
+
"metadata": {},
|
286 |
+
"outputs": [
|
287 |
+
{
|
288 |
+
"data": {
|
289 |
+
"text/plain": [
|
290 |
+
"[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.\\n']"
|
291 |
+
]
|
292 |
+
},
|
293 |
+
"execution_count": 19,
|
294 |
+
"metadata": {},
|
295 |
+
"output_type": "execute_result"
|
296 |
+
}
|
297 |
+
],
|
298 |
+
"source": [
|
299 |
+
"phi2_tokenizer.batch_decode(model.generate(**tokens))"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "code",
|
304 |
+
"execution_count": 7,
|
305 |
+
"id": "04f940c9-586d-4937-ae31-cc0f96d33e92",
|
306 |
+
"metadata": {},
|
307 |
+
"outputs": [],
|
308 |
+
"source": [
|
309 |
+
"class AudioLanguageConnector:\n",
|
310 |
+
" def __init__(self):\n",
|
311 |
+
" model_name = \"microsoft/phi-2\"\n",
|
312 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
313 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
314 |
+
"\n",
|
315 |
+
" def __call__(self, text):\n",
|
316 |
+
" text = f\"<audio_start> {text} <audio_end>\"\n",
|
317 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
318 |
+
" return tokens\n",
|
319 |
+
" \n",
|
320 |
+
"\n",
|
321 |
+
"class WhisperWithProjection:\n",
|
322 |
+
" def __init__(self):\n",
|
323 |
+
" self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
|
324 |
+
" self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
|
325 |
+
" self.model.config.forced_decoder_ids = None\n",
|
326 |
+
" self.audio_language_connector = AudioLanguageConnector()\n",
|
327 |
+
" \n",
|
328 |
+
" def forward(self, audio):\n",
|
329 |
+
" input_features = self.processor(audio[\"array\"],\n",
|
330 |
+
" sampling_rate=audio[\"sampling_rate\"],\n",
|
331 |
+
" return_tensors=\"pt\").input_features\n",
|
332 |
+
" # generate token ids\n",
|
333 |
+
" predicted_ids = self.model.generate(input_features)\n",
|
334 |
+
" # decode token ids to text \n",
|
335 |
+
" transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
|
336 |
+
"\n",
|
337 |
+
" audio_embeddings = self.audio_language_connector(transcription)\n",
|
338 |
+
" return audio_embeddings"
|
339 |
+
]
|
340 |
+
},
|
341 |
+
{
|
342 |
+
"cell_type": "code",
|
343 |
+
"execution_count": 8,
|
344 |
+
"id": "2b1f8f44-bfe6-413c-9e32-c38fa5517981",
|
345 |
+
"metadata": {},
|
346 |
+
"outputs": [],
|
347 |
+
"source": [
|
348 |
+
"class TextModality:\n",
|
349 |
+
" def __init__(self):\n",
|
350 |
+
" model_name = \"microsoft/phi-2\"\n",
|
351 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
352 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
353 |
+
"\n",
|
354 |
+
" def __call__(self, text):\n",
|
355 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
356 |
+
" return tokens"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"cell_type": "code",
|
361 |
+
"execution_count": 15,
|
362 |
+
"id": "21c51648-abb6-4bbd-b4c1-509967a69337",
|
363 |
+
"metadata": {},
|
364 |
+
"outputs": [],
|
365 |
+
"source": [
|
366 |
+
"class MultiModalPhi2:\n",
|
367 |
+
" def __init__(self):\n",
|
368 |
+
" self.text_modality = TextModality()\n",
|
369 |
+
" self.whisper_w_proj = WhisperWithProjection()\n",
|
370 |
+
" self.llm = self.load_llm()\n",
|
371 |
+
"\n",
|
372 |
+
" def load_llm(self):\n",
|
373 |
+
" bnb_config = BitsAndBytesConfig(\n",
|
374 |
+
" load_in_4bit=True,\n",
|
375 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
376 |
+
" bnb_4bit_compute_dtype=torch.float16)\n",
|
377 |
+
" \n",
|
378 |
+
" model = AutoModelForCausalLM.from_pretrained(\n",
|
379 |
+
" model_name,\n",
|
380 |
+
" quantization_config=bnb_config,\n",
|
381 |
+
" trust_remote_code=True,\n",
|
382 |
+
" device_map=\"cuda:0\"\n",
|
383 |
+
" )\n",
|
384 |
+
" model.config.use_cache = False\n",
|
385 |
+
" return model\n",
|
386 |
+
"\n",
|
387 |
+
" def generate(self, audio, text):\n",
|
388 |
+
" text_embeddings = self.text_modality(text)\n",
|
389 |
+
" audio_embeddings = self.whisper_w_proj.forward(audio)\n",
|
390 |
+
" inputs = torch.concat([text_embeddings[\"input_ids\"], audio_embeddings[\"input_ids\"]], dim=1)\n",
|
391 |
+
" \n",
|
392 |
+
" # outputs = self.llm.generate(inputs, max_length=200)\n",
|
393 |
+
" outputs = self.llm(inputs)\n",
|
394 |
+
" return outputs\n",
|
395 |
+
" \n",
|
396 |
+
" # text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
|
397 |
+
" # print(text)"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "code",
|
402 |
+
"execution_count": 16,
|
403 |
+
"id": "472a00cb-bae9-4c09-a0ef-bc57881b5e2c",
|
404 |
+
"metadata": {},
|
405 |
+
"outputs": [
|
406 |
+
{
|
407 |
+
"name": "stderr",
|
408 |
+
"output_type": "stream",
|
409 |
+
"text": [
|
410 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
411 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
412 |
+
]
|
413 |
+
},
|
414 |
+
{
|
415 |
+
"data": {
|
416 |
+
"application/vnd.jupyter.widget-view+json": {
|
417 |
+
"model_id": "2236e6b1e26d444fa3d48181ba1a6cf9",
|
418 |
+
"version_major": 2,
|
419 |
+
"version_minor": 0
|
420 |
+
},
|
421 |
+
"text/plain": [
|
422 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
"metadata": {},
|
426 |
+
"output_type": "display_data"
|
427 |
+
}
|
428 |
+
],
|
429 |
+
"source": [
|
430 |
+
"multi_modal_phi = MultiModalPhi2()"
|
431 |
+
]
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"cell_type": "code",
|
435 |
+
"execution_count": 17,
|
436 |
+
"id": "c350f2d3-0929-4c46-b63d-ff92dea437f3",
|
437 |
+
"metadata": {},
|
438 |
+
"outputs": [
|
439 |
+
{
|
440 |
+
"data": {
|
441 |
+
"text/plain": [
|
442 |
+
"CausalLMOutputWithPast(loss={'logits': tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
|
443 |
+
" [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
|
444 |
+
" [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
|
445 |
+
" ...,\n",
|
446 |
+
" [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
|
447 |
+
" [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
|
448 |
+
" [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
|
449 |
+
" grad_fn=<ToCopyBackward0>)}, logits=tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
|
450 |
+
" [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
|
451 |
+
" [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
|
452 |
+
" ...,\n",
|
453 |
+
" [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
|
454 |
+
" [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
|
455 |
+
" [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
|
456 |
+
" grad_fn=<ToCopyBackward0>), past_key_values=None, hidden_states=None, attentions=None)"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
"execution_count": 17,
|
460 |
+
"metadata": {},
|
461 |
+
"output_type": "execute_result"
|
462 |
+
}
|
463 |
+
],
|
464 |
+
"source": [
|
465 |
+
"audio = sample\n",
|
466 |
+
"text = \"explain about the audio\"\n",
|
467 |
+
"multi_modal_phi.generate(audio, text)"
|
468 |
+
]
|
469 |
+
},
|
470 |
+
{
|
471 |
+
"cell_type": "code",
|
472 |
+
"execution_count": null,
|
473 |
+
"id": "46aa9c66-a5bb-4760-8895-92673f49345f",
|
474 |
+
"metadata": {},
|
475 |
+
"outputs": [],
|
476 |
+
"source": []
|
477 |
+
}
|
478 |
+
],
|
479 |
+
"metadata": {
|
480 |
+
"kernelspec": {
|
481 |
+
"display_name": "Python 3 (ipykernel)",
|
482 |
+
"language": "python",
|
483 |
+
"name": "python3"
|
484 |
+
},
|
485 |
+
"language_info": {
|
486 |
+
"codemirror_mode": {
|
487 |
+
"name": "ipython",
|
488 |
+
"version": 3
|
489 |
+
},
|
490 |
+
"file_extension": ".py",
|
491 |
+
"mimetype": "text/x-python",
|
492 |
+
"name": "python",
|
493 |
+
"nbconvert_exporter": "python",
|
494 |
+
"pygments_lexer": "ipython3",
|
495 |
+
"version": "3.10.12"
|
496 |
+
}
|
497 |
+
},
|
498 |
+
"nbformat": 4,
|
499 |
+
"nbformat_minor": 5
|
500 |
+
}
|
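`whispher_exp.ipynb` stops at the raw `CausalLMOutputWithPast`; the commented-out lines in `MultiModalPhi2.generate` hint at the decode step. A sketch of that path, assuming the `multi_modal_phi` and `sample` objects from the cells above are in scope and reusing only methods shown in the notebook:

```python
import torch

# Tokenize the text prompt and the Whisper transcription of the audio clip.
text_tokens = multi_modal_phi.text_modality("explain about the audio")
audio_tokens = multi_modal_phi.whisper_w_proj.forward(sample)

# Concatenate the two id sequences and move them to the LLM's device.
input_ids = torch.cat(
    [text_tokens["input_ids"], audio_tokens["input_ids"]], dim=1
).to(multi_modal_phi.llm.device)

# Let phi-2 continue the prompt, then decode the generated ids back to text.
output_ids = multi_modal_phi.llm.generate(input_ids, max_length=200)
print(multi_modal_phi.text_modality.phi2_tokenizer.batch_decode(output_ids)[0])
```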
inference/__init__.py
ADDED
File without changes
|
inference/conversation.py
ADDED
@@ -0,0 +1,224 @@
|
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_phi_v0 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v0",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

default_conversation = conv_phi_v0
conv_templates = {
    "default": conv_phi_v0,
    "v0": conv_phi_v0,
    "phi-2_v0": conv_phi_v0,

    "plain": conv_llava_plain,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
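A quick usage sketch (not part of the commit) of the phi-2 template defined above; it assumes the file is importable as `conversation` and shows how the TWO separator style lays out a prompt.

# Minimal sketch, assuming this file is importable as `conversation`.
from conversation import conv_templates

conv = conv_templates["phi-2_v0"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this picture?")
conv.append_message(conv.roles[1], None)

# With sep=" " and sep2="<|endoftext|>", get_prompt() yields:
# "<system text> USER: <image>\nWhat is shown in this picture? ASSISTANT:"
print(conv.get_prompt())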
inference/inference.ipynb
ADDED
@@ -0,0 +1,369 @@
Notebook cells (Python 3 ipykernel, Python 3.10.12), with captured outputs:

Cell [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer, AutoConfig, CLIPImageProcessor

Cell [2]:
from model import LlavaPhiForCausalLM
Output:
[2024-01-25 14:31:58,511] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)

Cell [3]:
model_name = "RaviNaik/Llava-Phi2"
model = LlavaPhiForCausalLM.from_pretrained(model_name)
Output (progress widget):
Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]

Cell [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
Output (stderr):
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Cell [5]:
image_processor = CLIPImageProcessor.from_pretrained(model_name)
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)

Cell [6]:
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
else:
    context_len = 2048

Cell [7]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration

Cell [8]:
class AudioLanguageConnector:
    def __init__(self, projection_dim):
        model_name = "microsoft/phi-2"
        self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token
        self.phi2_tokenizer.max_length = projection_dim

    def __call__(self, text):
        text = f"<audio_start> {text} <audio_end>"
        tokens = self.phi2_tokenizer(text, return_tensors="pt", return_attention_mask=False)
        return tokens


class WhisperWithProjection:
    def __init__(self, projection_dim, device):
        self.device = device
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", device_map=device)
        self.model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", device_map=device)
        self.model.config.forced_decoder_ids = None
        # self.audio_language_connector = AudioLanguageConnector(projection_dim)

    def __call__(self, audio):
        input_features = self.processor(audio["array"],
                                        sampling_rate=audio["sampling_rate"],
                                        return_tensors="pt").input_features
        # generate token ids
        predicted_ids = self.model.generate(input_features.to(self.device))
        # decode token ids to text
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)

        # audio_embeddings = self.audio_language_connector(transcription)
        return transcription

Cell [9]:
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

from conversation import conv_templates, SeparatorStyle

class MultiModalPhi2:
    def __init__(self, modelname_or_path="RaviNaik/Llava-Phi2",
                 temperature=0.2,
                 max_new_tokens=1024,
                 device="cuda:0"):
        self.model_name = modelname_or_path
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.disable_torch_init()
        self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)
        self.load_pretrained_model()

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def load_pretrained_model(self):
        self.model = LlavaPhiForCausalLM.from_pretrained(self.model_name, device_map=self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model_name)
        mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(self.model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            self.tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

    def tokenizer_image_token(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
            offset = 1
            input_ids.append(prompt_chunks[0][0])
        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors is not None:
            if return_tensors == 'pt':
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f'Unsupported tensor type: {return_tensors}')
        return input_ids

    def __call__(self, text, audio, image):
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text
        conv = conv_templates["phi-2_v0"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].to(self.device)

        input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
        if audio is not None:
            audio_transcript = self.whisper_w_proj(audio)
            audio_embed = self.tokenizer(audio_transcript, return_tensors='pt')["input_ids"]
            input_ids = torch.concat([input_ids, audio_embed], dim=1)
        input_ids = input_ids.to(self.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=self.temperature,
                max_new_tokens=self.max_new_tokens,
                eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                pad_token_id=self.tokenizer.eos_token_id,  # Pad token
                use_cache=True,
            )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        return outputs

Cell [10]:
multimodal_phi2 = MultiModalPhi2()
Output (progress widget, then stderr):
Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Cell [11]:
from PIL import Image
import requests

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

from datasets import load_dataset
audio_ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = audio_ds[0]["audio"]

text = "tell me about the image"

Cell [12]:
multimodal_phi2(text, None, image)
Result:
'In the image, there is a Chinese writing on a pole in a foreign language. This suggests that the image was taken in a foreign country, possibly in a foreign country. The sign is in a foreign language, which might be in a foreign language. The sign is written in Japanese, which is a common language in Japan. The sign is also written in two different languages, which suggests that it is written in a language that is not in the native language.'

Cell [13]: (empty)
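Note (an observation, not part of the commit): the notebook imports `model` and `conversation` as top-level modules, so it is meant to be run with `inference/` as the working directory. The package-style equivalents used by inference/main.py below would be, assuming the repository root is on sys.path:

from inference.model import LlavaPhiForCausalLM
from inference.conversation import conv_templates, SeparatorStyle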
inference/main.py
ADDED
@@ -0,0 +1,226 @@
import soundfile as sf
import librosa
import torch
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)

from .model import LlavaPhiForCausalLM
from .conversation import conv_templates, SeparatorStyle

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class AudioLanguageConnector:
    def __init__(self, projection_dim):
        model_name = "microsoft/phi-2"
        self.phi2_tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token
        self.phi2_tokenizer.max_length = projection_dim

    def __call__(self, text):
        text = f"<audio_start> {text} <audio_end>"
        tokens = self.phi2_tokenizer(
            text, return_tensors="pt", return_attention_mask=False
        )
        return tokens


class WhisperWithProjection:
    def __init__(self, projection_dim, device):
        self.device = device
        self.processor = WhisperProcessor.from_pretrained(
            "openai/whisper-tiny", device_map=device
        )
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-tiny", device_map=device
        )
        self.model.config.forced_decoder_ids = None
        # self.audio_language_connector = AudioLanguageConnector(projection_dim)

    def __call__(self, audio):
        array, sampling_rate = sf.read(audio)
        resampled_array = librosa.resample(
            array,
            orig_sr=sampling_rate,
            target_sr=16000,
        )
        input_features = self.processor(
            resampled_array, sampling_rate=16000, return_tensors="pt"
        ).input_features
        # generate token ids
        predicted_ids = self.model.generate(input_features.to(self.device))
        # decode token ids to text
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )

        # audio_embeddings = self.audio_language_connector(transcription)
        return transcription


class MultiModalPhi2:
    def __init__(
        self,
        modelname_or_path="RaviNaik/Llava-Phi2",
        temperature=0.2,
        max_new_tokens=1024,
        device="cuda:0",
    ):
        self.model_name = modelname_or_path
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.disable_torch_init()
        self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)
        self.load_pretrained_model()

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def load_pretrained_model(self):
        self.model = LlavaPhiForCausalLM.from_pretrained(
            self.model_name, device_map=self.device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model_name)
        mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(
            self.model.config, "mm_use_im_patch_token", True
        )
        if mm_use_im_patch_token:
            self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            self.tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )

    def tokenizer_image_token(
        self,
        prompt,
        tokenizer,
        image_token_index=IMAGE_TOKEN_INDEX,
        return_tensors=None,
    ):
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
        ]

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(prompt_chunks[0][0])
        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors is not None:
            if return_tensors == "pt":
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        return input_ids

    def __call__(self, text, audio, image):
        if text is None:
            text = ""
        if image is not None:
            qs = (
                DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_TOKEN
                + DEFAULT_IM_END_TOKEN
                + "\n"
                + text
            )
            conv = conv_templates["phi-2_v0"].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = self.tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0)

            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")[
                "pixel_values"
            ].to(self.device)
        else:
            qs = text
            conv = conv_templates["phi-2_v0"].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]

            image_tensor = None

        if audio is not None:
            audio_transcript = self.whisper_w_proj(audio)
            audio_embed = self.tokenizer(audio_transcript, return_tensors="pt")[
                "input_ids"
            ]
            input_ids = torch.concat([input_ids, audio_embed], dim=1)
        input_ids = input_ids.to(self.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

        with torch.inference_mode():
            if image is not None:
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=self.temperature,
                    max_new_tokens=self.max_new_tokens,
                    eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                    pad_token_id=self.tokenizer.eos_token_id,  # Pad token
                    use_cache=True,
                )
            else:
                output_ids = self.model.generate(
                    input_ids,
                    do_sample=True,
                    temperature=self.temperature,
                    max_new_tokens=self.max_new_tokens,
                    eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                    pad_token_id=self.tokenizer.eos_token_id,  # Pad token
                    use_cache=True,
                )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (
            (input_ids != output_ids[:, :input_token_len]).sum().item()
        )
        if n_diff_input_output > 0:
            print(
                f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
            )
        outputs = self.tokenizer.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True
        )[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        outputs = outputs.strip()
        return outputs
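A hypothetical driver for the class above (not part of the commit). It assumes the package is importable as `inference`, a CUDA device, and that `audio`, when given, is a path to an audio file readable by soundfile; note that, unlike the notebook version, this WhisperWithProjection takes a file path and resamples to 16 kHz itself.

# Illustrative usage sketch; the "RaviNaik/Llava-Phi2" weights are downloaded from the Hub on first run.
import requests
from PIL import Image
from inference.main import MultiModalPhi2

pipeline = MultiModalPhi2(device="cuda:0")
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
print(pipeline("tell me about the image", audio=None, image=image))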
inference/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .language_model.llava_phi import LlavaPhiForCausalLM
from .language_model.configuration_llava_phi import LlavaPhiConfig, LlavaPhiVisionConfig, ProjectorConfig
inference/model/builder.py
ADDED
@@ -0,0 +1,180 @@
import os
import warnings

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    BitsAndBytesConfig,
    CLIPImageProcessor,
)
import torch
from .language_model.llava_phi import LlavaPhiForCausalLM
from .language_model.configuration_llava_phi import LlavaPhiConfig

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def load_pretrained_model(
    model_path,
    model_base,
    model_name,
    load_8bit=False,
    load_4bit=False,
    device_map="cuda",
    device="cuda",
):
    kwargs = {"device_map": device_map}
    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    # else:  # TODO: after fine-tuning LLava-Phi, load the model weights with fp16 will pose nan
    #     kwargs['torch_dtype'] = torch.float16

    if "phi" in model_name.lower():
        # Load LLaVA-Phi model
        if "lora" in model_name.lower() and model_base is None:
            warnings.warn(
                "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument."
            )
        if "lora" in model_name.lower() and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print("Loading LLaVA-Phi from base model...")
            model = LlavaPhiForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
            )
            token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, tokem_dim, device=model.device, dtype=model.dtype
                    )
                )
                model.model.embed_tokens.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, tokem_dim, device=model.device, dtype=model.dtype
                    )
                )

            print("Loading additional LLaVA-Phi weights...")
            if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
                non_lora_trainables = torch.load(
                    os.path.join(model_path, "non_lora_trainables.bin"),
                    map_location="cpu",
                )
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id, filename=filename, subfolder=subfolder
                    )
                    return torch.load(cache_file, map_location="cpu")

                non_lora_trainables = load_from_hf(
                    model_path, "non_lora_trainables.bin"
                )
            non_lora_trainables = {
                (k[11:] if k.startswith("base_model.") else k): v
                for k, v in non_lora_trainables.items()
            }
            if any(k.startswith("model.model.") for k in non_lora_trainables):
                non_lora_trainables = {
                    (k[6:] if k.startswith("model.") else k): v
                    for k, v in non_lora_trainables.items()
                }
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel

            print("Loading LoRA weights...")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging LoRA weights...")
            model = model.merge_and_unload()
            print("Model is loaded...")
        elif model_base is not None:
            # this may be mm projector only
            print("Loading LLaVA-Phi from base model...")
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            cfg_pretrained = AutoConfig.from_pretrained(model_path)
            model = LlavaPhiForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
            )

            mm_projector_weights = torch.load(
                os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
            )
            mm_projector_weights = {
                k: v.to(torch.float16) for k, v in mm_projector_weights.items()
            }
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            print("load llaVA-Phi MLLM!!!")
            config = LlavaPhiConfig.from_pretrained(model_path, trust_remote_code=True)
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
            model = LlavaPhiForCausalLM.from_pretrained(
                model_path, config=config, use_safetensors=True, **kwargs
            ).to("cuda")
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_base,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                device_map="auto",
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print("Convert to FP16...")
            model.to(torch.float16)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_path, low_cpu_mem_usage=True, **kwargs
            )

    image_processor = CLIPImageProcessor.from_pretrained(model_path)

    if "phi" in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)

        # TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        # model.resize_token_embeddings(len(tokenizer))
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048
    model.to(device="cuda")
    print(kwargs)
    return tokenizer, model, image_processor, context_len
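For reference, a hedged example of calling the loader above for the merged checkpoint (no LoRA adapter, no separate base model); `model_name` must contain "phi" so the LLaVA-Phi branch is taken.

# Sketch only; assumes the published "RaviNaik/Llava-Phi2" weights and a CUDA device.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="RaviNaik/Llava-Phi2",
    model_base=None,
    model_name="llava-phi2",  # contains "phi" -> selects the LLaVA-Phi branch above
)
print(context_len)  # 2048 unless the config defines max_sequence_length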
inference/model/language_model/configuration_llava_phi.py
ADDED
@@ -0,0 +1,191 @@
import os
from typing import Union
from transformers import PretrainedConfig, PhiConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class LlavaPhiVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a
    CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
    [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 512):
            Dimensionality of text and vision projection layers.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 32):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        mm_vision_select_feature (`str`, *optional*, defaults to `"patch"`):
            The feature to select from the vision encoder output. Can be one of `"patch"` or `"cls_patch"`.
        mm_vision_select_layer (`int`, *optional*, defaults to `-2`):
            The layer to select from the vision encoder output.

    Example:

    ```python
    >>> from transformers import CLIPVisionConfig, CLIPVisionModel

    >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
    >>> configuration = CLIPVisionConfig()

    >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
    >>> model = CLIPVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llava_phi_clip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        mm_vision_select_feature="patch",
        mm_vision_select_layer=-2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.mm_vision_select_feature = mm_vision_select_feature
        self.mm_vision_select_layer = mm_vision_select_layer

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # get the vision config dict if we are loading from CLIPConfig
        if config_dict.get("model_type") == "llava_phi-phi":
            config_dict = config_dict["vision_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class ProjectorConfig(PretrainedConfig):
    model_type = "llava_phi_projector"

    def __init__(
        self, mm_projector_type="linear", mm_hidden_size=768, hidden_size=2560, **kwargs
    ):
        self.mm_projector_type = mm_projector_type
        self.mm_hidden_size = mm_hidden_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # get the projector config dict if we are loading from a full LlavaPhi config
        if config_dict.get("model_type") == "llava_phi-phi":
            config_dict = config_dict["projector_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


DEFAULT_VISUAL_CONFIG = {
    "vision_tower": LlavaPhiVisionConfig().to_dict(),
    "mm_projector": ProjectorConfig().to_dict(),
}


class LlavaPhiConfig(PhiConfig):
    model_type = "llava_phi"

    def __init__(self, vision_config=None, **kwargs):
        if vision_config is None:
            self.vision_config = DEFAULT_VISUAL_CONFIG
        else:
            self.vision_config = vision_config

        super().__init__(**kwargs)


if __name__ == "__main__":
    print(LlavaPhiVisionConfig())
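A small sketch of how these configs nest (assumption: a transformers version that provides PhiConfig): LlavaPhiConfig stores the vision tower and projector settings as plain dicts under `vision_config`.

# Default construction; vision_config falls back to DEFAULT_VISUAL_CONFIG.
config = LlavaPhiConfig()
print(config.model_type)                                           # "llava_phi"
print(config.vision_config["vision_tower"]["image_size"])          # 224
print(config.vision_config["mm_projector"]["mm_projector_type"])   # "linear"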
inference/model/language_model/llava_phi.py
ADDED
@@ -0,0 +1,126 @@
import os
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
    PhiModel, PhiPreTrainedModel

from transformers.modeling_outputs import CausalLMOutputWithPast
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from transformers.utils import logging
from .configuration_llava_phi import LlavaPhiConfig

logger = logging.get_logger(__name__)


class LLavaPhiModel(LlavaMetaModel, PhiModel):
    config_class = LlavaPhiConfig

    def __init__(self, config):
        super(LLavaPhiModel, self).__init__(config)


class LlavaPhiForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
    config_class = LlavaPhiConfig

    def __init__(self, config):
        super(PhiPreTrainedModel, self).__init__(config)
        self.model = LLavaPhiModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs


AutoConfig.register("llava_phi", LlavaPhiConfig)
AutoModelForCausalLM.register(LlavaPhiConfig, LlavaPhiForCausalLM)
inference/model/llava_arch.py
ADDED
@@ -0,0 +1,330 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod

import torch

from .multimodal_encoder.clip_encoder import CLIPVisionTower
from .multimodal_projector.builder import build_vision_projector
from .language_model.configuration_llava_phi import (
    LlavaPhiConfig,
    LlavaPhiVisionConfig,
    ProjectorConfig,
)

# from llava_phi.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class LlavaMetaModel:
    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)
        self.vision_tower = CLIPVisionTower(
            LlavaPhiVisionConfig(**config.vision_config["vision_tower"])
        )
        self.mm_projector = build_vision_projector(
            ProjectorConfig(**config.vision_config["mm_projector"])
        )

    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower


class LlavaMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if (
                past_key_values is not None
                and vision_tower is not None
                and images is not None
                and input_ids.shape[1] == 1
            ):
                attention_mask = torch.ones(
                    (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(
                    cur_input_ids[:half_len]
                )
                cur_input_embeds_2 = self.get_model().embed_tokens(
                    cur_input_ids[half_len:]
                )
                cur_input_embeds = torch.cat(
                    [cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2],
                    dim=0,
                )
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_new_input_embeds.append(
                        self.get_model()
                        .embed_tokens(cur_input_ids[: image_token_start - 1])
                        .detach()
                    )
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(
                            cur_input_ids[image_token_start - 1 : image_token_start]
                        )
                    )
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(
                            cur_input_ids[image_token_start + 1 : image_token_start + 2]
                        )
                    )
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(
                            torch.full(
                                (cur_image_features.shape[0],),
                                IGNORE_INDEX,
                                device=labels.device,
                                dtype=labels.dtype,
                            )
                        )
                        cur_new_labels.append(
                            cur_labels[image_token_start : image_token_start + 1]
                        )
                        cur_labels = cur_labels[image_token_start + 2 :]
                else:
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(cur_input_ids[:image_token_start])
                    )
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(
                            torch.full(
                                (cur_image_features.shape[0],),
                                IGNORE_INDEX,
                                device=labels.device,
                                dtype=labels.dtype,
                            )
                        )
                        cur_labels = cur_labels[image_token_start + 1 :]
                cur_image_idx += 1
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_input_ids = cur_input_ids[image_token_start + 2 :]
                else:
                    cur_input_ids = cur_input_ids[image_token_start + 1 :]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(cur_input_ids).detach()
                    )
                else:
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(cur_input_ids)
                    )
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [
                x.to(device=self.device) for x in cur_new_input_embeds
            ]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat(
                        (
                            cur_new_label,
                            torch.full(
                                (max_len - cur_new_label.shape[0],),
                                IGNORE_INDEX,
                                dtype=cur_new_label.dtype,
                                device=cur_new_label.device,
                            ),
                        ),
                        dim=0,
                    )
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(
                    attention_mask, _new_labels, new_labels
                ):
                    new_attn_mask_pad_left = torch.full(
                        (cur_new_labels.shape[0] - labels.shape[1],),
                        True,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    new_attn_mask_pad_right = torch.full(
                        (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
                        False,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    cur_new_attention_mask = torch.cat(
                        (
                            new_attn_mask_pad_left,
                            cur_attention_mask,
                            new_attn_mask_pad_right,
                        ),
                        dim=0,
                    )
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full(
                    (
                        attention_mask.shape[0],
                        new_input_embeds.shape[1] - input_ids.shape[1],
                    ),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat(
                    (new_attn_mask_pad_left, attention_mask), dim=1
                )
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True
                )
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True
                )

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
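The core of prepare_inputs_labels_for_multimodal is splicing the projected image features into the text embedding sequence at each IMAGE_TOKEN_INDEX placeholder, while padding the corresponding label positions with IGNORE_INDEX so the image tokens never contribute to the loss. A standalone sketch of that splice for a single placeholder, with assumed shapes (576 CLIP patches, 2560-d embeddings, a stand-in embedding table), not part of the diff:

# Illustration only: embed stands in for the language model's embed_tokens,
# image_features stands in for the mm_projector output.
import torch

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200

embed_dim, num_patches = 2560, 576            # assumed Phi-2 hidden size / 336px CLIP tower
embed = torch.nn.Embedding(51200, embed_dim)  # stand-in embedding table

input_ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4])
labels = input_ids.clone()
image_features = torch.randn(num_patches, embed_dim)   # projected image features

start = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0][0]
inputs_embeds = torch.cat(
    [embed(input_ids[:start]), image_features, embed(input_ids[start + 1:])], dim=0
)
new_labels = torch.cat(
    [labels[:start],
     torch.full((num_patches,), IGNORE_INDEX, dtype=labels.dtype),
     labels[start + 1:]], dim=0
)
print(inputs_embeds.shape, new_labels.shape)  # torch.Size([580, 2560]) torch.Size([580])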
inference/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,89 @@
from abc import ABC

import torch
import torch.nn as nn

from transformers import CLIPPreTrainedModel, CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPVisionTransformer
from inference.model.language_model.configuration_llava_phi import LlavaPhiVisionConfig


class CLIPVisionTower(CLIPPreTrainedModel):
    config_class = LlavaPhiVisionConfig

    def __init__(self, config):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[
            self.config.mm_vision_select_layer
        ]
        if self.config.mm_vision_select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.config.mm_vision_select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(
                f"Unexpected select feature: {self.config.mm_vision_select_feature}"
            )
        return image_features

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_model(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_model(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return list(self.vision_model.parameters())[0].dtype

    @property
    def device(self):
        return list(self.vision_model.parameters())[0].device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


if __name__ == "__main__":
    clip_config = CLIPVisionConfig.from_pretrained(
        "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336"
    )
    print("################ clip_config ##############")
    print(clip_config)
    phi_vis_config = LlavaPhiVisionConfig(**clip_config.to_dict())
    print("################ phi_vis_config ##############")
    print(phi_vis_config)

    model = CLIPVisionTower(clip_config)
    # print(list(model.vision_model.parameters())[0].dtype)
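feature_select either keeps all hidden states ("cls_patch") or drops the leading CLS token ("patch"). A quick sanity sketch, assuming the clip-vit-large-patch14-336 geometry this tower targets (336px images, 14px patches, 1024-d hidden size):

# Sketch only: dummy_hidden_states imitates one layer of CLIP hidden states
# (CLS token followed by patch tokens).
import torch
from transformers import CLIPVisionConfig

cfg = CLIPVisionConfig(image_size=336, patch_size=14, hidden_size=1024)
num_patches = (cfg.image_size // cfg.patch_size) ** 2
print(num_patches)  # 576, matching the num_patches property above

dummy_hidden_states = torch.randn(1, num_patches + 1, cfg.hidden_size)
patch_only = dummy_hidden_states[:, 1:]   # what feature_select keeps for "patch"
print(patch_only.shape)                   # torch.Size([1, 576, 1024])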
inference/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config):
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")
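A usage sketch for build_vision_projector with assumed sizes (1024-d CLIP features projected into a 2560-d language model) and the repository root on PYTHONPATH; an "mlp2x_gelu" projector expands to Linear -> GELU -> Linear:

# Sketch only: the sizes below are assumptions, not values taken from the diff.
from types import SimpleNamespace
import torch

from inference.model.multimodal_projector.builder import build_vision_projector

proj_cfg = SimpleNamespace(
    mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=2560
)
projector = build_vision_projector(proj_cfg)
print(projector)  # Sequential(Linear(1024->2560), GELU(), Linear(2560->2560))

image_features = torch.randn(1, 576, 1024)   # per-image CLIP patch features
print(projector(image_features).shape)       # torch.Size([1, 576, 2560])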