init LAMM
- app.py +240 -0
- model/CLIP/__init__.py +2 -0
- model/CLIP/__pycache__/__init__.cpython-310.pyc +0 -0
- model/CLIP/__pycache__/clip.cpython-310.pyc +0 -0
- model/CLIP/__pycache__/model.cpython-310.pyc +0 -0
- model/CLIP/__pycache__/simple_tokenizer.cpython-310.pyc +0 -0
- model/CLIP/bpe_simple_vocab_16e6.txt.gz +3 -0
- model/CLIP/clip.py +237 -0
- model/CLIP/model.py +449 -0
- model/CLIP/simple_tokenizer.py +132 -0
- model/PROCESS/__init__.py +1 -0
- model/PROCESS/bpe/bpe_simple_vocab_16e6.txt.gz +3 -0
- model/PROCESS/data.py +406 -0
- model/PROCESS/helpers.py +141 -0
- model/PROCESS/multimodal_preprocessors.py +687 -0
- model/__init__.py +10 -0
- model/agent.py +76 -0
- model/conversations.py +386 -0
- model/modeling_llama.py +755 -0
- model/openlamm.py +550 -0
- model/utils/__init__.py +2 -0
- model/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- model/utils/__pycache__/ceph_utils.cpython-310.pyc +0 -0
- model/utils/__pycache__/pcl_utils.cpython-310.pyc +0 -0
- model/utils/ceph_utils.py +260 -0
- model/utils/pcl_utils.py +131 -0
- pretrained_ckpt/lamm98k/config.json +23 -0
- pretrained_ckpt/lamm98k/pytorch_model.pt +3 -0
- pretrained_ckpt/lamm98k/special_tokens_map.json +24 -0
- pretrained_ckpt/lamm98k/tokenizer.model +3 -0
- pretrained_ckpt/lamm98k/tokenizer_config.json +34 -0
- requirements.txt +19 -0
app.py
ADDED
@@ -0,0 +1,240 @@
from transformers import AutoModel, AutoTokenizer
from copy import deepcopy
import os
import ipdb
import gradio as gr
import mdtex2html
from model.openlamm import LAMMPEFTModel
import torch
import json

# init the model
args = {
    'model': 'openllama_peft',
    'imagebind_ckpt_path': '../model_zoo/imagebind_ckpt',
    'vicuna_ckpt_path': '../model_zoo/vicuna_ckpt/13b_v0',
    'delta_ckpt_path': './pretrained_ckpt/lamm98k/pytorch_model.pt',
    'stage': 1,
    'max_tgt_len': 128,
    'lora_r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.1,
    'lora_target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    'vision_type': 'image',
    'vision_feature_type': 'local',
    'num_vision_token': 256,
    'encoder_pretrain': 'clip',
    'system_header': True,
}

model = LAMMPEFTModel(**args)
delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
model.load_state_dict(delta_ckpt, strict=False)
model = model.eval().half().cuda()
print(f'[!] init the 13b model over ...')

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def re_predict(
    input,
    image_path,
    chatbot,
    max_length,
    top_p,
    temperature,
    history,
    modality_cache,
):
    # drop the latest query and answers and generate again
    q, a = history.pop()
    chatbot.pop()
    return predict(q, image_path, chatbot, max_length, top_p, temperature, history, modality_cache)


def predict(
    input,
    image_path,
    chatbot,
    max_length,
    top_p,
    temperature,
    history,
    modality_cache,
):
    if image_path is None: #
        return [(input, "There is no input data provided! Please upload your data and start the conversation.")]
    else:
        print(f'[!] image path: {image_path}\n')    # [!] audio path: {audio_path}\n[!] video path: {video_path}\n[!] thermal path: {thermal_path}')

    # prepare the prompt
    prompt_text = ''
    for idx, (q, a) in enumerate(history):
        if idx == 0:
            prompt_text += f'{q}\n### Assistant: {a}\n###'
        else:
            prompt_text += f' Human: {q}\n### Assistant: {a}\n###'
    if len(history) == 0:
        prompt_text += f'{input}'
    else:
        prompt_text += f' Human: {input}'

    response = model.generate({
        'prompt': prompt_text,
        'image_paths': [image_path] if image_path else [],
        # 'audio_paths': [audio_path] if audio_path else [],
        # 'video_paths': [video_path] if video_path else [],
        # 'thermal_paths': [thermal_path] if thermal_path else [],
        'top_p': top_p,
        'temperature': temperature,
        'max_tgt_len': max_length,
        'modality_embeds': modality_cache
    })
    chatbot.append((parse_text(input), parse_text(response)))
    history.append((input, response))
    return chatbot, history, modality_cache


def reset_user_input():
    return gr.update(value='')

def reset_dialog():
    return [], []

def reset_state():
    return None, None, None, None, [], [], []


with gr.Blocks(scale=4) as demo:
    gr.HTML("""<h1 align="center">PandaGPT</h1>""")

    with gr.Row(scale=4):
        with gr.Column(scale=1):
            image_path = gr.Image(type="filepath", label="Image", value=None)
        # with gr.Column(scale=1):
        #     audio_path = gr.Audio(type="filepath", label="Audio", value=None)
        # with gr.Column(scale=1):
        #     video_path = gr.Video(type='file', label="Video")
        # with gr.Column(scale=1):
        #     thermal_path = gr.Image(type="filepath", label="Thermal Image", value=None)

    chatbot = gr.Chatbot().style(height=300)
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
            with gr.Column(min_width=32, scale=1):
                with gr.Row(scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")
                with gr.Row(scale=1):
                    resubmitBtn = gr.Button("Resubmit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 400, value=256, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.01, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=1.0, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])
    modality_cache = gr.State([])

    submitBtn.click(
        predict, [
            user_input,
            image_path,
            # audio_path,
            # video_path,
            # thermal_path,
            chatbot,
            max_length,
            top_p,
            temperature,
            history,
            modality_cache,
        ], [
            chatbot,
            history,
            modality_cache
        ],
        show_progress=True
    )

    resubmitBtn.click(
        re_predict, [
            user_input,
            image_path,
            # audio_path,
            # video_path,
            # thermal_path,
            chatbot,
            max_length,
            top_p,
            temperature,
            history,
            modality_cache,
        ], [
            chatbot,
            history,
            modality_cache
        ],
        show_progress=True
    )

    submitBtn.click(reset_user_input, [], [user_input])
    emptyBtn.click(reset_state, outputs=[
        image_path,
        # audio_path,
        # video_path,
        # thermal_path,
        chatbot,
        history,
        modality_cache
    ], show_progress=True)

demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=10050)
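For reference, a minimal programmatic smoke test of the predict() pipeline above, outside the Gradio UI. Illustrative only: it assumes the checkpoints configured in args are available locally, and "example.jpg", the question, and the sampling values are placeholders.

# illustrative sketch only: exercise predict() once without the Gradio UI
chatbot_state, history_state, cache_state = predict(
    input="Describe this image in detail.",   # placeholder question
    image_path="example.jpg",                 # placeholder image path
    chatbot=[],
    max_length=128,
    top_p=0.9,
    temperature=1.0,
    history=[],
    modality_cache=[],
)
print(history_state[-1][1])                   # the generated response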
model/CLIP/__init__.py
ADDED
@@ -0,0 +1,2 @@
# remove fp32 LN & return intermediate features
from .clip import *
model/CLIP/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (196 Bytes).
model/CLIP/__pycache__/clip.cpython-310.pyc
ADDED
Binary file (8.83 kB).
model/CLIP/__pycache__/model.cpython-310.pyc
ADDED
Binary file (15.6 kB).
model/CLIP/__pycache__/simple_tokenizer.cpython-310.pyc
ADDED
Binary file (5.73 kB).
model/CLIP/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
model/CLIP/clip.py
ADDED
@@ -0,0 +1,237 @@
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize((n_px, n_px), interpolation=BICUBIC),
        # CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
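A minimal usage sketch of this vendored loader (illustrative only; it assumes network access to download the ViT-L/14 weights on first call, and the prompt strings are placeholders):

import torch
from model.CLIP import load, tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = load("ViT-L/14", device=device)   # cached under ~/.cache/clip
text = tokenize(["a photo of a cat", "a photo of a dog"]).to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(text)           # shape [2, 768] for ViT-L/14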
model/CLIP/model.py
ADDED
@@ -0,0 +1,449 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        # orig_type = x.dtype
        ret = super().forward(x)  # .type(torch.float16))    # Warning: Originally avoid fp32 in clip
        return ret  # .type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

    def forward_patch_features(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        return x[:, 1:, :]


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
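The forward_patch_features() method added above is what exposes the per-patch tokens referred to in model/CLIP/__init__.py. A sketch of how it might be called (illustrative only; "example.jpg" is a placeholder path and the weights are fetched via model.CLIP.load):

import torch
from PIL import Image
from model.CLIP import load

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = load("ViT-L/14", device=device)
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    patches = clip_model.visual.forward_patch_features(image.type(clip_model.dtype))
print(patches.shape)   # [1, 256, 1024]: 16x16 patches at 224px, matching num_vision_token=256 in app.py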
model/CLIP/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
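A quick round-trip check of the tokenizer (illustrative only; it assumes the LFS file bpe_simple_vocab_16e6.txt.gz has actually been fetched rather than left as a pointer):

from model.CLIP.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()              # uses the bundled BPE vocab by default
ids = tokenizer.encode("a photo of a cat")
print(ids)                                 # BPE token ids, without SOT/EOT markers
print(tokenizer.decode(ids))               # "a photo of a cat "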
model/PROCESS/__init__.py
ADDED
@@ -0,0 +1 @@
import data
model/PROCESS/bpe/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
model/PROCESS/data.py
ADDED
@@ -0,0 +1,406 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
import math
import requests

import torch
import torch.nn as nn
import torchaudio
import logging

# from .ImageBind.models.multimodal_preprocessors import SimpleTokenizer
from .multimodal_preprocessors import SimpleTokenizer
from PIL import Image
# from pytorchvideo import transforms as pv_transforms
# from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
# from pytorchvideo.data.encoded_video import EncodedVideo

from torchvision import transforms
from torchvision.transforms._transforms_video import NormalizeVideo

DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # in milliseconds

BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"


def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
    waveform -= waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )
    # Convert to [mel_bins, num_frames] shape
    fbank = fbank.transpose(0, 1)
    # Pad to target_length
    n_frames = fbank.size(1)
    p = target_length - n_frames
    # if p is too large (say >20%), flash a warning
    if abs(p) / n_frames > 0.2:
        logging.warning(
            "Large gap between audio n_frames(%d) and "
            "target_length (%d). Is the audio_target_length "
            "setting correct?",
            n_frames,
            target_length,
        )
    # cut and pad
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
    # channel image
    fbank = fbank.unsqueeze(0)
    return fbank


def get_clip_timepoints(clip_sampler, duration):
    # Read out all clips in this video
    all_clips_timepoints = []
    is_last_clip = False
    end = 0.0
    while not is_last_clip:
        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
        all_clips_timepoints.append((start, end))
    return all_clips_timepoints


def load_and_transform_vision_data(image_paths, device, client=None):
    if image_paths is None:
        return None

    image_ouputs = []
    for image_path in image_paths:
        data_transform = transforms.Compose(
            [
                transforms.Resize(
                    224, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        if os.path.exists(image_path):
            with open(image_path, "rb") as fopen:
                image = Image.open(fopen).convert("RGB")
        elif image_path.startswith("s3://") and client is not None:
            image = Image.open(io.BytesIO(client.get(image_path))).convert("RGB")
        elif image_path.startswith("http"):
            image = Image.open(requests.get(image_path, stream=True).raw).convert(
                "RGB"
            )
        else:
            raise ValueError(f"Invalid image path: {image_path}")

        image = data_transform(image).to(device)
        image_ouputs.append(image)
    return torch.stack(image_ouputs, dim=0)

def transform_vision_data(images, device):
    image_ouputs = []
    for img in images:
        data_transform = transforms.Compose(
            [
                transforms.Resize(
                    224, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        image = data_transform(img).to(device)
        image_ouputs.append(image)
    return torch.stack(image_ouputs, dim=0)


def load_and_transform_thermal_data(thermal_paths, device):
    if thermal_paths is None:
        return None

    thermal_ouputs = []
    for thermal_path in thermal_paths:
        data_transform = transforms.Compose(
            [
                transforms.Resize(
                    224, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        with open(thermal_path, "rb") as fopen:
            thermal = Image.open(fopen).convert("L")
        thermal = data_transform(thermal).to(device)
        thermal_ouputs.append(thermal)
    return torch.stack(thermal_ouputs, dim=0)


def load_and_transform_text(text, device):
    if text is None:
        return None
    tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
    tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
    tokens = torch.cat(tokens, dim=0)
    return tokens


# def load_and_transform_audio_data(
#     audio_paths,
#     device,
#     num_mel_bins=128,
#     target_length=204,
#     sample_rate=16000,
#     clip_duration=2,
#     clips_per_video=3,
#     mean=-4.268,
#     std=9.138,
# ):
#     if audio_paths is None:
#         return None

#     audio_outputs = []
#     clip_sampler = ConstantClipsPerVideoSampler(
#         clip_duration=clip_duration, clips_per_video=clips_per_video
#     )

#     for audio_path in audio_paths:
#         waveform, sr = torchaudio.load(audio_path)
#         if sample_rate != sr:
#             waveform = torchaudio.functional.resample(
#                 waveform, orig_freq=sr, new_freq=sample_rate
#             )
#         all_clips_timepoints = get_clip_timepoints(
#             clip_sampler, waveform.size(1) / sample_rate
#         )
#         all_clips = []
#         for clip_timepoints in all_clips_timepoints:
#             waveform_clip = waveform[
#                 :,
#                 int(clip_timepoints[0] * sample_rate) : int(
#                     clip_timepoints[1] * sample_rate
#                 ),
#             ]
#             waveform_melspec = waveform2melspec(
#                 waveform_clip, sample_rate, num_mel_bins, target_length
#             )
#             all_clips.append(waveform_melspec)

#         normalize = transforms.Normalize(mean=mean, std=std)
#         all_clips = [normalize(ac).to(device) for ac in all_clips]

#         all_clips = torch.stack(all_clips, dim=0)
#         audio_outputs.append(all_clips)

#     return torch.stack(audio_outputs, dim=0)


def get_clip_timepoints(clip_sampler, duration):
    # Read out all clips in this video
    all_clips_timepoints = []
    is_last_clip = False
    end = 0.0
    while not is_last_clip:
        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
        all_clips_timepoints.append((start, end))
    return all_clips_timepoints


def crop_boxes(boxes, x_offset, y_offset):
    """
    Peform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to peform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and weight to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
258 |
+
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
259 |
+
crop if height is larger than width.
|
260 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
261 |
+
Dimension is `num boxes` x 4.
|
262 |
+
scale_size (int): optinal. If not None, resize the images to scale_size before
|
263 |
+
performing any crop.
|
264 |
+
Returns:
|
265 |
+
cropped (tensor): images with dimension of
|
266 |
+
`num frames` x `channel` x `size` x `size`.
|
267 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
268 |
+
`num boxes` x 4.
|
269 |
+
"""
|
270 |
+
assert spatial_idx in [0, 1, 2]
|
271 |
+
ndim = len(images.shape)
|
272 |
+
if ndim == 3:
|
273 |
+
images = images.unsqueeze(0)
|
274 |
+
height = images.shape[2]
|
275 |
+
width = images.shape[3]
|
276 |
+
|
277 |
+
if scale_size is not None:
|
278 |
+
if width <= height:
|
279 |
+
width, height = scale_size, int(height / width * scale_size)
|
280 |
+
else:
|
281 |
+
width, height = int(width / height * scale_size), scale_size
|
282 |
+
images = torch.nn.functional.interpolate(
|
283 |
+
images,
|
284 |
+
size=(height, width),
|
285 |
+
mode="bilinear",
|
286 |
+
align_corners=False,
|
287 |
+
)
|
288 |
+
|
289 |
+
y_offset = int(math.ceil((height - size) / 2))
|
290 |
+
x_offset = int(math.ceil((width - size) / 2))
|
291 |
+
|
292 |
+
if height > width:
|
293 |
+
if spatial_idx == 0:
|
294 |
+
y_offset = 0
|
295 |
+
elif spatial_idx == 2:
|
296 |
+
y_offset = height - size
|
297 |
+
else:
|
298 |
+
if spatial_idx == 0:
|
299 |
+
x_offset = 0
|
300 |
+
elif spatial_idx == 2:
|
301 |
+
x_offset = width - size
|
302 |
+
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
|
303 |
+
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
304 |
+
if ndim == 3:
|
305 |
+
cropped = cropped.squeeze(0)
|
306 |
+
return cropped, cropped_boxes
|
307 |
+
|
308 |
+
|
309 |
+
class SpatialCrop(nn.Module):
|
310 |
+
"""
|
311 |
+
Convert the video into 3 smaller clips spatially. Must be used after the
|
312 |
+
temporal crops to get spatial crops, and should be used with
|
313 |
+
-2 in the spatial crop at the slowfast augmentation stage (so full
|
314 |
+
frames are passed in here). Will return a larger list with the
|
315 |
+
3x spatial crops as well.
|
316 |
+
"""
|
317 |
+
|
318 |
+
def __init__(self, crop_size: int = 224, num_crops: int = 3):
|
319 |
+
super().__init__()
|
320 |
+
self.crop_size = crop_size
|
321 |
+
if num_crops == 3:
|
322 |
+
self.crops_to_ext = [0, 1, 2]
|
323 |
+
self.flipped_crops_to_ext = []
|
324 |
+
elif num_crops == 1:
|
325 |
+
self.crops_to_ext = [1]
|
326 |
+
self.flipped_crops_to_ext = []
|
327 |
+
else:
|
328 |
+
raise NotImplementedError("Nothing else supported yet")
|
329 |
+
|
330 |
+
def forward(self, videos):
|
331 |
+
"""
|
332 |
+
Args:
|
333 |
+
videos: A list of C, T, H, W videos.
|
334 |
+
Returns:
|
335 |
+
videos: A list with 3x the number of elements. Each video converted
|
336 |
+
to C, T, H', W' by spatial cropping.
|
337 |
+
"""
|
338 |
+
assert isinstance(videos, list), "Must be a list of videos after temporal crops"
|
339 |
+
assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
|
340 |
+
res = []
|
341 |
+
for video in videos:
|
342 |
+
for spatial_idx in self.crops_to_ext:
|
343 |
+
res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
|
344 |
+
if not self.flipped_crops_to_ext:
|
345 |
+
continue
|
346 |
+
flipped_video = transforms.functional.hflip(video)
|
347 |
+
for spatial_idx in self.flipped_crops_to_ext:
|
348 |
+
res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
|
349 |
+
return res
|
350 |
+
|
351 |
+
"""
|
352 |
+
def load_and_transform_video_data(
|
353 |
+
video_paths,
|
354 |
+
device,
|
355 |
+
clip_duration=2,
|
356 |
+
clips_per_video=5,
|
357 |
+
sample_rate=16000,
|
358 |
+
):
|
359 |
+
if video_paths is None:
|
360 |
+
return None
|
361 |
+
|
362 |
+
video_outputs = []
|
363 |
+
video_transform = transforms.Compose(
|
364 |
+
[
|
365 |
+
pv_transforms.ShortSideScale(224),
|
366 |
+
NormalizeVideo(
|
367 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
368 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
369 |
+
),
|
370 |
+
]
|
371 |
+
)
|
372 |
+
|
373 |
+
clip_sampler = ConstantClipsPerVideoSampler(
|
374 |
+
clip_duration=clip_duration, clips_per_video=clips_per_video
|
375 |
+
)
|
376 |
+
frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
|
377 |
+
|
378 |
+
for video_path in video_paths:
|
379 |
+
video = EncodedVideo.from_path(
|
380 |
+
video_path,
|
381 |
+
decoder="decord",
|
382 |
+
decode_audio=False,
|
383 |
+
**{"sample_rate": sample_rate},
|
384 |
+
)
|
385 |
+
|
386 |
+
all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
|
387 |
+
|
388 |
+
all_video = []
|
389 |
+
for clip_timepoints in all_clips_timepoints:
|
390 |
+
# Read the clip, get frames
|
391 |
+
clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
|
392 |
+
if clip is None:
|
393 |
+
raise ValueError("No clip found")
|
394 |
+
video_clip = frame_sampler(clip["video"])
|
395 |
+
video_clip = video_clip / 255.0 # since this is float, need 0-1
|
396 |
+
|
397 |
+
all_video.append(video_clip)
|
398 |
+
|
399 |
+
all_video = [video_transform(clip) for clip in all_video]
|
400 |
+
all_video = SpatialCrop(224, num_crops=3)(all_video)
|
401 |
+
|
402 |
+
all_video = torch.stack(all_video, dim=0)
|
403 |
+
video_outputs.append(all_video)
|
404 |
+
|
405 |
+
return torch.stack(video_outputs, dim=0).to(device)
|
406 |
+
"""
|
model/PROCESS/helpers.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
# All rights reserved.
|
4 |
+
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
import einops
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
|
17 |
+
class Normalize(nn.Module):
|
18 |
+
def __init__(self, dim: int) -> None:
|
19 |
+
super().__init__()
|
20 |
+
self.dim = dim
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
return torch.nn.functional.normalize(x, dim=self.dim, p=2)
|
24 |
+
|
25 |
+
|
26 |
+
class LearnableLogitScaling(nn.Module):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
logit_scale_init: float = 1 / 0.07,
|
30 |
+
learnable: bool = True,
|
31 |
+
max_logit_scale: float = 100,
|
32 |
+
) -> None:
|
33 |
+
super().__init__()
|
34 |
+
self.max_logit_scale = max_logit_scale
|
35 |
+
self.logit_scale_init = logit_scale_init
|
36 |
+
self.learnable = learnable
|
37 |
+
log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
|
38 |
+
if learnable:
|
39 |
+
self.log_logit_scale = nn.Parameter(log_logit_scale)
|
40 |
+
else:
|
41 |
+
self.register_buffer("log_logit_scale", log_logit_scale)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
|
45 |
+
|
46 |
+
def extra_repr(self):
|
47 |
+
st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
|
48 |
+
return st
|
49 |
+
|
50 |
+
|
51 |
+
class EinOpsRearrange(nn.Module):
|
52 |
+
def __init__(self, rearrange_expr: str, **kwargs) -> None:
|
53 |
+
super().__init__()
|
54 |
+
self.rearrange_expr = rearrange_expr
|
55 |
+
self.kwargs = kwargs
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
assert isinstance(x, torch.Tensor)
|
59 |
+
return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
|
60 |
+
|
61 |
+
|
62 |
+
class VerboseNNModule(nn.Module):
|
63 |
+
"""
|
64 |
+
Wrapper around nn.Module that prints registered buffers and parameter names.
|
65 |
+
"""
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
|
69 |
+
st = (
|
70 |
+
"("
|
71 |
+
+ name
|
72 |
+
+ "): "
|
73 |
+
+ "tensor("
|
74 |
+
+ str(tuple(tensor[1].shape))
|
75 |
+
+ ", requires_grad="
|
76 |
+
+ str(tensor[1].requires_grad)
|
77 |
+
+ ")\n"
|
78 |
+
)
|
79 |
+
return st
|
80 |
+
|
81 |
+
def extra_repr(self) -> str:
|
82 |
+
named_modules = set()
|
83 |
+
for p in self.named_modules():
|
84 |
+
named_modules.update([p[0]])
|
85 |
+
named_modules = list(named_modules)
|
86 |
+
|
87 |
+
string_repr = ""
|
88 |
+
for p in self.named_parameters():
|
89 |
+
name = p[0].split(".")[0]
|
90 |
+
if name not in named_modules:
|
91 |
+
string_repr += self.get_readable_tensor_repr(name, p)
|
92 |
+
|
93 |
+
for p in self.named_buffers():
|
94 |
+
name = p[0].split(".")[0]
|
95 |
+
string_repr += self.get_readable_tensor_repr(name, p)
|
96 |
+
|
97 |
+
return string_repr
|
98 |
+
|
99 |
+
|
100 |
+
def cast_if_src_dtype(
|
101 |
+
tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
|
102 |
+
):
|
103 |
+
updated = False
|
104 |
+
if tensor.dtype == src_dtype:
|
105 |
+
tensor = tensor.to(dtype=tgt_dtype)
|
106 |
+
updated = True
|
107 |
+
return tensor, updated
|
108 |
+
|
109 |
+
|
110 |
+
class QuickGELU(nn.Module):
|
111 |
+
# From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
|
112 |
+
def forward(self, x: torch.Tensor):
|
113 |
+
return x * torch.sigmoid(1.702 * x)
|
114 |
+
|
115 |
+
|
116 |
+
class SelectElement(nn.Module):
|
117 |
+
def __init__(self, index) -> None:
|
118 |
+
super().__init__()
|
119 |
+
self.index = index
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
assert x.ndim >= 3
|
123 |
+
return x[:, self.index, ...]
|
124 |
+
|
125 |
+
|
126 |
+
class SelectEOSAndProject(nn.Module):
|
127 |
+
"""
|
128 |
+
Text Pooling used in OpenCLIP
|
129 |
+
"""
|
130 |
+
|
131 |
+
def __init__(self, proj: nn.Module) -> None:
|
132 |
+
super().__init__()
|
133 |
+
self.proj = proj
|
134 |
+
|
135 |
+
def forward(self, x, seq_len):
|
136 |
+
assert x.ndim == 3
|
137 |
+
# x is of shape B x L x D
|
138 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
139 |
+
x = x[torch.arange(x.shape[0]), seq_len]
|
140 |
+
x = self.proj(x)
|
141 |
+
return x
|
model/PROCESS/multimodal_preprocessors.py
ADDED
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
# All rights reserved.
|
4 |
+
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
import gzip
|
9 |
+
import html
|
10 |
+
import io
|
11 |
+
import math
|
12 |
+
from functools import lru_cache
|
13 |
+
from typing import Callable, List, Optional
|
14 |
+
|
15 |
+
import ftfy
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import regex as re
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from iopath.common.file_io import g_pathmgr
|
22 |
+
from timm.models.layers import trunc_normal_
|
23 |
+
|
24 |
+
from .helpers import cast_if_src_dtype, VerboseNNModule
|
25 |
+
|
26 |
+
|
27 |
+
def get_sinusoid_encoding_table(n_position, d_hid):
|
28 |
+
"""Sinusoid position encoding table"""
|
29 |
+
|
30 |
+
# TODO: make it with torch instead of numpy
|
31 |
+
def get_position_angle_vec(position):
|
32 |
+
return [
|
33 |
+
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
34 |
+
for hid_j in range(d_hid)
|
35 |
+
]
|
36 |
+
|
37 |
+
sinusoid_table = np.array(
|
38 |
+
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
|
39 |
+
)
|
40 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
41 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
42 |
+
|
43 |
+
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
44 |
+
|
45 |
+
|
46 |
+
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
|
47 |
+
N = pos_embed.shape[1]
|
48 |
+
if N == target_spatial_size:
|
49 |
+
return pos_embed
|
50 |
+
dim = pos_embed.shape[-1]
|
51 |
+
# nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
|
52 |
+
pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
|
53 |
+
pos_embed = nn.functional.interpolate(
|
54 |
+
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
|
55 |
+
0, 3, 1, 2
|
56 |
+
),
|
57 |
+
scale_factor=math.sqrt(target_spatial_size / N),
|
58 |
+
mode="bicubic",
|
59 |
+
)
|
60 |
+
if updated:
|
61 |
+
pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
|
62 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
63 |
+
return pos_embed
|
64 |
+
|
65 |
+
|
66 |
+
def interpolate_pos_encoding(
|
67 |
+
npatch_per_img,
|
68 |
+
pos_embed,
|
69 |
+
patches_layout,
|
70 |
+
input_shape=None,
|
71 |
+
first_patch_idx=1,
|
72 |
+
):
|
73 |
+
assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
|
74 |
+
N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
|
75 |
+
if npatch_per_img == N:
|
76 |
+
return pos_embed
|
77 |
+
|
78 |
+
assert (
|
79 |
+
patches_layout[-1] == patches_layout[-2]
|
80 |
+
), "Interpolation of pos embed not supported for non-square layouts"
|
81 |
+
|
82 |
+
class_emb = pos_embed[:, :first_patch_idx]
|
83 |
+
pos_embed = pos_embed[:, first_patch_idx:]
|
84 |
+
|
85 |
+
if input_shape is None or patches_layout[0] == 1:
|
86 |
+
# simple 2D pos embedding, no temporal component
|
87 |
+
pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
|
88 |
+
elif patches_layout[0] > 1:
|
89 |
+
# pos embed has a temporal component
|
90 |
+
assert len(input_shape) == 4, "temporal interpolation not supported"
|
91 |
+
# we only support 2D interpolation in this case
|
92 |
+
num_frames = patches_layout[0]
|
93 |
+
num_spatial_tokens = patches_layout[1] * patches_layout[2]
|
94 |
+
pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
|
95 |
+
# interpolate embedding for zeroth frame
|
96 |
+
pos_embed = interpolate_pos_encoding_2d(
|
97 |
+
npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
raise ValueError("This type of interpolation isn't implemented")
|
101 |
+
|
102 |
+
return torch.cat((class_emb, pos_embed), dim=1)
|
103 |
+
|
104 |
+
|
105 |
+
def _get_pos_embedding(
|
106 |
+
npatch_per_img,
|
107 |
+
pos_embed,
|
108 |
+
patches_layout,
|
109 |
+
input_shape,
|
110 |
+
first_patch_idx=1,
|
111 |
+
):
|
112 |
+
pos_embed = interpolate_pos_encoding(
|
113 |
+
npatch_per_img,
|
114 |
+
pos_embed,
|
115 |
+
patches_layout,
|
116 |
+
input_shape=input_shape,
|
117 |
+
first_patch_idx=first_patch_idx,
|
118 |
+
)
|
119 |
+
return pos_embed
|
120 |
+
|
121 |
+
|
122 |
+
class PatchEmbedGeneric(nn.Module):
|
123 |
+
"""
|
124 |
+
PatchEmbed from Hydra
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
|
128 |
+
super().__init__()
|
129 |
+
|
130 |
+
if len(proj_stem) > 1:
|
131 |
+
self.proj = nn.Sequential(*proj_stem)
|
132 |
+
else:
|
133 |
+
# Special case to be able to load pre-trained models that were
|
134 |
+
# trained with a standard stem
|
135 |
+
self.proj = proj_stem[0]
|
136 |
+
self.norm_layer = norm_layer
|
137 |
+
|
138 |
+
def get_patch_layout(self, img_size):
|
139 |
+
with torch.no_grad():
|
140 |
+
dummy_img = torch.zeros(
|
141 |
+
[
|
142 |
+
1,
|
143 |
+
]
|
144 |
+
+ img_size
|
145 |
+
)
|
146 |
+
dummy_out = self.proj(dummy_img)
|
147 |
+
embed_dim = dummy_out.shape[1]
|
148 |
+
patches_layout = tuple(dummy_out.shape[2:])
|
149 |
+
num_patches = np.prod(patches_layout)
|
150 |
+
return patches_layout, num_patches, embed_dim
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
x = self.proj(x)
|
154 |
+
# B C (T) H W -> B (T)HW C
|
155 |
+
x = x.flatten(2).transpose(1, 2)
|
156 |
+
if self.norm_layer is not None:
|
157 |
+
x = self.norm_layer(x)
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
patches_layout: List,
|
165 |
+
num_patches: int,
|
166 |
+
num_cls_tokens: int,
|
167 |
+
embed_dim: int,
|
168 |
+
learnable: bool,
|
169 |
+
) -> None:
|
170 |
+
super().__init__()
|
171 |
+
self.num_cls_tokens = num_cls_tokens
|
172 |
+
self.patches_layout = patches_layout
|
173 |
+
self.num_patches = num_patches
|
174 |
+
self.num_tokens = num_cls_tokens + num_patches
|
175 |
+
self.learnable = learnable
|
176 |
+
if self.learnable:
|
177 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
|
178 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
179 |
+
else:
|
180 |
+
self.register_buffer(
|
181 |
+
"pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
|
182 |
+
)
|
183 |
+
|
184 |
+
def get_pos_embedding(self, vision_input, all_vision_tokens):
|
185 |
+
input_shape = vision_input.shape
|
186 |
+
pos_embed = _get_pos_embedding(
|
187 |
+
all_vision_tokens.size(1) - self.num_cls_tokens,
|
188 |
+
pos_embed=self.pos_embed,
|
189 |
+
patches_layout=self.patches_layout,
|
190 |
+
input_shape=input_shape,
|
191 |
+
first_patch_idx=self.num_cls_tokens,
|
192 |
+
)
|
193 |
+
return pos_embed
|
194 |
+
|
195 |
+
|
196 |
+
class RGBDTPreprocessor(VerboseNNModule):
|
197 |
+
def __init__(
|
198 |
+
self,
|
199 |
+
rgbt_stem: PatchEmbedGeneric,
|
200 |
+
depth_stem: PatchEmbedGeneric,
|
201 |
+
img_size: List = (3, 224, 224),
|
202 |
+
num_cls_tokens: int = 1,
|
203 |
+
pos_embed_fn: Callable = None,
|
204 |
+
use_type_embed: bool = False,
|
205 |
+
init_param_style: str = "openclip",
|
206 |
+
) -> None:
|
207 |
+
super().__init__()
|
208 |
+
stem = rgbt_stem if rgbt_stem is not None else depth_stem
|
209 |
+
(
|
210 |
+
self.patches_layout,
|
211 |
+
self.num_patches,
|
212 |
+
self.embed_dim,
|
213 |
+
) = stem.get_patch_layout(img_size)
|
214 |
+
self.rgbt_stem = rgbt_stem
|
215 |
+
self.depth_stem = depth_stem
|
216 |
+
self.use_pos_embed = pos_embed_fn is not None
|
217 |
+
self.use_type_embed = use_type_embed
|
218 |
+
self.num_cls_tokens = num_cls_tokens
|
219 |
+
|
220 |
+
if self.use_pos_embed:
|
221 |
+
self.pos_embedding_helper = pos_embed_fn(
|
222 |
+
patches_layout=self.patches_layout,
|
223 |
+
num_cls_tokens=num_cls_tokens,
|
224 |
+
num_patches=self.num_patches,
|
225 |
+
embed_dim=self.embed_dim,
|
226 |
+
)
|
227 |
+
if self.num_cls_tokens > 0:
|
228 |
+
self.cls_token = nn.Parameter(
|
229 |
+
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
|
230 |
+
)
|
231 |
+
if self.use_type_embed:
|
232 |
+
self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
233 |
+
|
234 |
+
self.init_parameters(init_param_style)
|
235 |
+
|
236 |
+
@torch.no_grad()
|
237 |
+
def init_parameters(self, init_param_style):
|
238 |
+
if init_param_style == "openclip":
|
239 |
+
# OpenCLIP style initialization
|
240 |
+
scale = self.embed_dim**-0.5
|
241 |
+
if self.use_pos_embed:
|
242 |
+
nn.init.normal_(self.pos_embedding_helper.pos_embed)
|
243 |
+
self.pos_embedding_helper.pos_embed *= scale
|
244 |
+
|
245 |
+
if self.num_cls_tokens > 0:
|
246 |
+
nn.init.normal_(self.cls_token)
|
247 |
+
self.cls_token *= scale
|
248 |
+
elif init_param_style == "vit":
|
249 |
+
self.cls_token.data.fill_(0)
|
250 |
+
else:
|
251 |
+
raise ValueError(f"Unknown init {init_param_style}")
|
252 |
+
|
253 |
+
if self.use_type_embed:
|
254 |
+
nn.init.normal_(self.type_embed)
|
255 |
+
|
256 |
+
def tokenize_input_and_cls_pos(self, input, stem, mask):
|
257 |
+
# tokens is of shape B x L x D
|
258 |
+
tokens = stem(input)
|
259 |
+
assert tokens.ndim == 3
|
260 |
+
assert tokens.shape[2] == self.embed_dim
|
261 |
+
B = tokens.shape[0]
|
262 |
+
if self.num_cls_tokens > 0:
|
263 |
+
class_tokens = self.cls_token.expand(
|
264 |
+
B, -1, -1
|
265 |
+
) # stole class_tokens impl from Phil Wang, thanks
|
266 |
+
tokens = torch.cat((class_tokens, tokens), dim=1)
|
267 |
+
if self.use_pos_embed:
|
268 |
+
pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
|
269 |
+
tokens = tokens + pos_embed
|
270 |
+
if self.use_type_embed:
|
271 |
+
tokens = tokens + self.type_embed.expand(B, -1, -1)
|
272 |
+
return tokens
|
273 |
+
|
274 |
+
def forward(self, vision=None, depth=None, patch_mask=None):
|
275 |
+
if patch_mask is not None:
|
276 |
+
raise NotImplementedError()
|
277 |
+
|
278 |
+
if vision is not None:
|
279 |
+
vision_tokens = self.tokenize_input_and_cls_pos(
|
280 |
+
vision, self.rgbt_stem, patch_mask
|
281 |
+
)
|
282 |
+
|
283 |
+
if depth is not None:
|
284 |
+
depth_tokens = self.tokenize_input_and_cls_pos(
|
285 |
+
depth, self.depth_stem, patch_mask
|
286 |
+
)
|
287 |
+
|
288 |
+
# aggregate tokens
|
289 |
+
if vision is not None and depth is not None:
|
290 |
+
final_tokens = vision_tokens + depth_tokens
|
291 |
+
else:
|
292 |
+
final_tokens = vision_tokens if vision is not None else depth_tokens
|
293 |
+
return_dict = {
|
294 |
+
"trunk": {
|
295 |
+
"tokens": final_tokens,
|
296 |
+
},
|
297 |
+
"head": {},
|
298 |
+
}
|
299 |
+
return return_dict
|
300 |
+
|
301 |
+
|
302 |
+
class AudioPreprocessor(RGBDTPreprocessor):
|
303 |
+
def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
|
304 |
+
super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
|
305 |
+
|
306 |
+
def forward(self, audio=None):
|
307 |
+
return super().forward(vision=audio)
|
308 |
+
|
309 |
+
|
310 |
+
class ThermalPreprocessor(RGBDTPreprocessor):
|
311 |
+
def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
|
312 |
+
super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
|
313 |
+
|
314 |
+
def forward(self, thermal=None):
|
315 |
+
return super().forward(vision=thermal)
|
316 |
+
|
317 |
+
|
318 |
+
def build_causal_attention_mask(context_length):
|
319 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
320 |
+
# pytorch uses additive attention mask; fill with -inf
|
321 |
+
mask = torch.empty(context_length, context_length, requires_grad=False)
|
322 |
+
mask.fill_(float("-inf"))
|
323 |
+
mask.triu_(1) # zero out the lower diagonal
|
324 |
+
return mask
|
325 |
+
|
326 |
+
|
327 |
+
class TextPreprocessor(VerboseNNModule):
|
328 |
+
def __init__(
|
329 |
+
self,
|
330 |
+
vocab_size: int,
|
331 |
+
context_length: int,
|
332 |
+
embed_dim: int,
|
333 |
+
causal_masking: bool,
|
334 |
+
supply_seq_len_to_head: bool = True,
|
335 |
+
num_cls_tokens: int = 0,
|
336 |
+
init_param_style: str = "openclip",
|
337 |
+
) -> None:
|
338 |
+
super().__init__()
|
339 |
+
self.vocab_size = vocab_size
|
340 |
+
self.context_length = context_length
|
341 |
+
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
|
342 |
+
self.pos_embed = nn.Parameter(
|
343 |
+
torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
|
344 |
+
)
|
345 |
+
self.causal_masking = causal_masking
|
346 |
+
if self.causal_masking:
|
347 |
+
mask = build_causal_attention_mask(self.context_length)
|
348 |
+
# register the mask as a buffer so it can be moved to the right device
|
349 |
+
self.register_buffer("mask", mask)
|
350 |
+
|
351 |
+
self.supply_seq_len_to_head = supply_seq_len_to_head
|
352 |
+
self.num_cls_tokens = num_cls_tokens
|
353 |
+
self.embed_dim = embed_dim
|
354 |
+
if num_cls_tokens > 0:
|
355 |
+
assert self.causal_masking is False, "Masking + CLS token isn't implemented"
|
356 |
+
self.cls_token = nn.Parameter(
|
357 |
+
torch.zeros(1, self.num_cls_tokens, embed_dim)
|
358 |
+
)
|
359 |
+
|
360 |
+
self.init_parameters(init_param_style)
|
361 |
+
|
362 |
+
@torch.no_grad()
|
363 |
+
def init_parameters(self, init_param_style="openclip"):
|
364 |
+
# OpenCLIP style initialization
|
365 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
366 |
+
nn.init.normal_(self.pos_embed, std=0.01)
|
367 |
+
|
368 |
+
if init_param_style == "openclip":
|
369 |
+
# OpenCLIP style initialization
|
370 |
+
scale = self.embed_dim**-0.5
|
371 |
+
if self.num_cls_tokens > 0:
|
372 |
+
nn.init.normal_(self.cls_token)
|
373 |
+
self.cls_token *= scale
|
374 |
+
elif init_param_style == "vit":
|
375 |
+
self.cls_token.data.fill_(0)
|
376 |
+
else:
|
377 |
+
raise ValueError(f"Unknown init {init_param_style}")
|
378 |
+
|
379 |
+
def forward(self, text):
|
380 |
+
# text tokens are of shape B x L x D
|
381 |
+
text_tokens = self.token_embedding(text)
|
382 |
+
# concat CLS tokens if any
|
383 |
+
if self.num_cls_tokens > 0:
|
384 |
+
B = text_tokens.shape[0]
|
385 |
+
class_tokens = self.cls_token.expand(
|
386 |
+
B, -1, -1
|
387 |
+
) # stole class_tokens impl from Phil Wang, thanks
|
388 |
+
text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
|
389 |
+
text_tokens = text_tokens + self.pos_embed
|
390 |
+
return_dict = {
|
391 |
+
"trunk": {
|
392 |
+
"tokens": text_tokens,
|
393 |
+
},
|
394 |
+
"head": {},
|
395 |
+
}
|
396 |
+
# Compute sequence length after adding CLS tokens
|
397 |
+
if self.supply_seq_len_to_head:
|
398 |
+
text_lengths = text.argmax(dim=-1)
|
399 |
+
return_dict["head"] = {
|
400 |
+
"seq_len": text_lengths,
|
401 |
+
}
|
402 |
+
if self.causal_masking:
|
403 |
+
return_dict["trunk"].update({"attn_mask": self.mask})
|
404 |
+
return return_dict
|
405 |
+
|
406 |
+
|
407 |
+
class Im2Video(nn.Module):
|
408 |
+
"""Convert an image into a trivial video."""
|
409 |
+
|
410 |
+
def __init__(self, time_dim=2):
|
411 |
+
super().__init__()
|
412 |
+
self.time_dim = time_dim
|
413 |
+
|
414 |
+
def forward(self, x):
|
415 |
+
if x.ndim == 4:
|
416 |
+
# B, C, H, W -> B, C, T, H, W
|
417 |
+
return x.unsqueeze(self.time_dim)
|
418 |
+
elif x.ndim == 5:
|
419 |
+
return x
|
420 |
+
else:
|
421 |
+
raise ValueError(f"Dimension incorrect {x.shape}")
|
422 |
+
|
423 |
+
|
424 |
+
class PadIm2Video(Im2Video):
|
425 |
+
def __init__(self, ntimes, pad_type, time_dim=2):
|
426 |
+
super().__init__(time_dim=time_dim)
|
427 |
+
assert ntimes > 0
|
428 |
+
assert pad_type in ["zero", "repeat"]
|
429 |
+
self.ntimes = ntimes
|
430 |
+
self.pad_type = pad_type
|
431 |
+
|
432 |
+
def forward(self, x):
|
433 |
+
x = super().forward(x)
|
434 |
+
if x.shape[self.time_dim] == 1:
|
435 |
+
if self.pad_type == "repeat":
|
436 |
+
new_shape = [1] * len(x.shape)
|
437 |
+
new_shape[self.time_dim] = self.ntimes
|
438 |
+
x = x.repeat(new_shape)
|
439 |
+
elif self.pad_type == "zero":
|
440 |
+
padarg = [0, 0] * len(x.shape)
|
441 |
+
padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
|
442 |
+
x = nn.functional.pad(x, padarg)
|
443 |
+
return x
|
444 |
+
|
445 |
+
|
446 |
+
# Modified from github.com/openai/CLIP
|
447 |
+
@lru_cache()
|
448 |
+
def bytes_to_unicode():
|
449 |
+
"""
|
450 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
451 |
+
The reversible bpe codes work on unicode strings.
|
452 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
453 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
454 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
455 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
456 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
457 |
+
"""
|
458 |
+
bs = (
|
459 |
+
list(range(ord("!"), ord("~") + 1))
|
460 |
+
+ list(range(ord("¡"), ord("¬") + 1))
|
461 |
+
+ list(range(ord("®"), ord("ÿ") + 1))
|
462 |
+
)
|
463 |
+
cs = bs[:]
|
464 |
+
n = 0
|
465 |
+
for b in range(2**8):
|
466 |
+
if b not in bs:
|
467 |
+
bs.append(b)
|
468 |
+
cs.append(2**8 + n)
|
469 |
+
n += 1
|
470 |
+
cs = [chr(n) for n in cs]
|
471 |
+
return dict(zip(bs, cs))
|
472 |
+
|
473 |
+
|
474 |
+
def get_pairs(word):
|
475 |
+
"""Return set of symbol pairs in a word.
|
476 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
477 |
+
"""
|
478 |
+
pairs = set()
|
479 |
+
prev_char = word[0]
|
480 |
+
for char in word[1:]:
|
481 |
+
pairs.add((prev_char, char))
|
482 |
+
prev_char = char
|
483 |
+
return pairs
|
484 |
+
|
485 |
+
|
486 |
+
def basic_clean(text):
|
487 |
+
text = ftfy.fix_text(text)
|
488 |
+
text = html.unescape(html.unescape(text))
|
489 |
+
return text.strip()
|
490 |
+
|
491 |
+
|
492 |
+
def whitespace_clean(text):
|
493 |
+
text = re.sub(r"\s+", " ", text)
|
494 |
+
text = text.strip()
|
495 |
+
return text
|
496 |
+
|
497 |
+
|
498 |
+
class SimpleTokenizer(object):
|
499 |
+
def __init__(self, bpe_path: str, context_length=77):
|
500 |
+
self.byte_encoder = bytes_to_unicode()
|
501 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
502 |
+
|
503 |
+
with g_pathmgr.open(bpe_path, "rb") as fh:
|
504 |
+
bpe_bytes = io.BytesIO(fh.read())
|
505 |
+
merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
|
506 |
+
merges = merges[1 : 49152 - 256 - 2 + 1]
|
507 |
+
merges = [tuple(merge.split()) for merge in merges]
|
508 |
+
vocab = list(bytes_to_unicode().values())
|
509 |
+
vocab = vocab + [v + "</w>" for v in vocab]
|
510 |
+
for merge in merges:
|
511 |
+
vocab.append("".join(merge))
|
512 |
+
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
|
513 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
514 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
515 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
516 |
+
self.cache = {
|
517 |
+
"<|startoftext|>": "<|startoftext|>",
|
518 |
+
"<|endoftext|>": "<|endoftext|>",
|
519 |
+
}
|
520 |
+
self.pat = re.compile(
|
521 |
+
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
522 |
+
re.IGNORECASE,
|
523 |
+
)
|
524 |
+
self.context_length = context_length
|
525 |
+
|
526 |
+
def bpe(self, token):
|
527 |
+
if token in self.cache:
|
528 |
+
return self.cache[token]
|
529 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
530 |
+
pairs = get_pairs(word)
|
531 |
+
|
532 |
+
if not pairs:
|
533 |
+
return token + "</w>"
|
534 |
+
|
535 |
+
while True:
|
536 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
537 |
+
if bigram not in self.bpe_ranks:
|
538 |
+
break
|
539 |
+
first, second = bigram
|
540 |
+
new_word = []
|
541 |
+
i = 0
|
542 |
+
while i < len(word):
|
543 |
+
try:
|
544 |
+
j = word.index(first, i)
|
545 |
+
new_word.extend(word[i:j])
|
546 |
+
i = j
|
547 |
+
except:
|
548 |
+
new_word.extend(word[i:])
|
549 |
+
break
|
550 |
+
|
551 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
552 |
+
new_word.append(first + second)
|
553 |
+
i += 2
|
554 |
+
else:
|
555 |
+
new_word.append(word[i])
|
556 |
+
i += 1
|
557 |
+
new_word = tuple(new_word)
|
558 |
+
word = new_word
|
559 |
+
if len(word) == 1:
|
560 |
+
break
|
561 |
+
else:
|
562 |
+
pairs = get_pairs(word)
|
563 |
+
word = " ".join(word)
|
564 |
+
self.cache[token] = word
|
565 |
+
return word
|
566 |
+
|
567 |
+
def encode(self, text):
|
568 |
+
bpe_tokens = []
|
569 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
570 |
+
for token in re.findall(self.pat, text):
|
571 |
+
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
572 |
+
bpe_tokens.extend(
|
573 |
+
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
|
574 |
+
)
|
575 |
+
return bpe_tokens
|
576 |
+
|
577 |
+
def decode(self, tokens):
|
578 |
+
text = "".join([self.decoder[token] for token in tokens])
|
579 |
+
text = (
|
580 |
+
bytearray([self.byte_decoder[c] for c in text])
|
581 |
+
.decode("utf-8", errors="replace")
|
582 |
+
.replace("</w>", " ")
|
583 |
+
)
|
584 |
+
return text
|
585 |
+
|
586 |
+
def __call__(self, texts, context_length=None):
|
587 |
+
if not context_length:
|
588 |
+
context_length = self.context_length
|
589 |
+
|
590 |
+
if isinstance(texts, str):
|
591 |
+
texts = [texts]
|
592 |
+
|
593 |
+
sot_token = self.encoder["<|startoftext|>"]
|
594 |
+
eot_token = self.encoder["<|endoftext|>"]
|
595 |
+
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
596 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
597 |
+
|
598 |
+
for i, tokens in enumerate(all_tokens):
|
599 |
+
tokens = tokens[:context_length]
|
600 |
+
result[i, : len(tokens)] = torch.tensor(tokens)
|
601 |
+
|
602 |
+
if len(result) == 1:
|
603 |
+
return result[0]
|
604 |
+
return result
|
605 |
+
|
606 |
+
|
607 |
+
class IMUPreprocessor(VerboseNNModule):
|
608 |
+
def __init__(
|
609 |
+
self,
|
610 |
+
kernel_size: int,
|
611 |
+
imu_stem: PatchEmbedGeneric,
|
612 |
+
embed_dim: int,
|
613 |
+
img_size: List = (6, 2000),
|
614 |
+
num_cls_tokens: int = 1,
|
615 |
+
pos_embed_fn: Callable = None,
|
616 |
+
init_param_style: str = "openclip",
|
617 |
+
) -> None:
|
618 |
+
super().__init__()
|
619 |
+
stem = imu_stem
|
620 |
+
self.imu_stem = imu_stem
|
621 |
+
self.embed_dim = embed_dim
|
622 |
+
self.use_pos_embed = pos_embed_fn is not None
|
623 |
+
self.num_cls_tokens = num_cls_tokens
|
624 |
+
self.kernel_size = kernel_size
|
625 |
+
self.pos_embed = nn.Parameter(
|
626 |
+
torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
|
627 |
+
)
|
628 |
+
|
629 |
+
if self.num_cls_tokens > 0:
|
630 |
+
self.cls_token = nn.Parameter(
|
631 |
+
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
|
632 |
+
)
|
633 |
+
|
634 |
+
self.init_parameters(init_param_style)
|
635 |
+
|
636 |
+
@torch.no_grad()
|
637 |
+
def init_parameters(self, init_param_style):
|
638 |
+
nn.init.normal_(self.pos_embed, std=0.01)
|
639 |
+
|
640 |
+
if init_param_style == "openclip":
|
641 |
+
# OpenCLIP style initialization
|
642 |
+
scale = self.embed_dim**-0.5
|
643 |
+
|
644 |
+
if self.num_cls_tokens > 0:
|
645 |
+
nn.init.normal_(self.cls_token)
|
646 |
+
self.cls_token *= scale
|
647 |
+
elif init_param_style == "vit":
|
648 |
+
self.cls_token.data.fill_(0)
|
649 |
+
else:
|
650 |
+
raise ValueError(f"Unknown init {init_param_style}")
|
651 |
+
|
652 |
+
def tokenize_input_and_cls_pos(self, input, stem):
|
653 |
+
# tokens is of shape B x L x D
|
654 |
+
tokens = stem.norm_layer(stem.proj(input))
|
655 |
+
assert tokens.ndim == 3
|
656 |
+
assert tokens.shape[2] == self.embed_dim
|
657 |
+
B = tokens.shape[0]
|
658 |
+
if self.num_cls_tokens > 0:
|
659 |
+
class_tokens = self.cls_token.expand(
|
660 |
+
B, -1, -1
|
661 |
+
) # stole class_tokens impl from Phil Wang, thanks
|
662 |
+
tokens = torch.cat((class_tokens, tokens), dim=1)
|
663 |
+
if self.use_pos_embed:
|
664 |
+
tokens = tokens + self.pos_embed
|
665 |
+
return tokens
|
666 |
+
|
667 |
+
def forward(self, imu):
|
668 |
+
# Patchify
|
669 |
+
imu = imu.unfold(
|
670 |
+
-1,
|
671 |
+
self.kernel_size,
|
672 |
+
self.kernel_size,
|
673 |
+
).permute(0, 2, 1, 3)
|
674 |
+
imu = imu.reshape(imu.size(0), imu.size(1), -1)
|
675 |
+
|
676 |
+
imu_tokens = self.tokenize_input_and_cls_pos(
|
677 |
+
imu,
|
678 |
+
self.imu_stem,
|
679 |
+
)
|
680 |
+
|
681 |
+
return_dict = {
|
682 |
+
"trunk": {
|
683 |
+
"tokens": imu_tokens,
|
684 |
+
},
|
685 |
+
"head": {},
|
686 |
+
}
|
687 |
+
return return_dict
|
model/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .agent import DeepSpeedAgent
|
2 |
+
from .openlamm import LAMMPEFTModel
|
3 |
+
|
4 |
+
|
5 |
+
def load_model(args):
|
6 |
+
agent_name = args['models'][args['model']]['agent_name']
|
7 |
+
model_name = args['models'][args['model']]['model_name']
|
8 |
+
model = globals()[model_name](**args)
|
9 |
+
agent = globals()[agent_name](model, args)
|
10 |
+
return agent
|
model/agent.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from header import *
|
2 |
+
from torch.utils.tensorboard import SummaryWriter
|
3 |
+
|
4 |
+
|
5 |
+
class DeepSpeedAgent:
|
6 |
+
|
7 |
+
def __init__(self, model, args):
|
8 |
+
super(DeepSpeedAgent, self).__init__()
|
9 |
+
self.args = args
|
10 |
+
self.model = model
|
11 |
+
self.writer = SummaryWriter(args['log_path'])
|
12 |
+
if args['stage'] == 2:
|
13 |
+
self.load_stage_1_parameters(args["delta_ckpt_path"])
|
14 |
+
print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}')
|
15 |
+
|
16 |
+
# load config parameters of deepspeed
|
17 |
+
ds_params = json.load(open(self.args['ds_config_path']))
|
18 |
+
ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps']
|
19 |
+
ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate']))
|
20 |
+
self.ds_engine, self.optimizer, _ , _ = deepspeed.initialize(
|
21 |
+
model=self.model,
|
22 |
+
model_parameters=self.model.parameters(),
|
23 |
+
config_params=ds_params,
|
24 |
+
dist_init_required=True,
|
25 |
+
args=types.SimpleNamespace(**args)
|
26 |
+
)
|
27 |
+
|
28 |
+
@torch.no_grad()
|
29 |
+
def predict(self, batch):
|
30 |
+
self.model.eval()
|
31 |
+
string = self.model.generate_one_sample(batch)
|
32 |
+
return string
|
33 |
+
|
34 |
+
def train_model(self, batch, current_step=0, pbar=None):
|
35 |
+
self.ds_engine.module.train()
|
36 |
+
loss, mle_acc = self.ds_engine(batch)
|
37 |
+
|
38 |
+
self.ds_engine.backward(loss)
|
39 |
+
self.ds_engine.step()
|
40 |
+
pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
|
41 |
+
pbar.update(1)
|
42 |
+
if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0:
|
43 |
+
elapsed = pbar.format_dict['elapsed']
|
44 |
+
rate = pbar.format_dict['rate']
|
45 |
+
remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0
|
46 |
+
remaining = str(datetime.timedelta(seconds=remaining))
|
47 |
+
self.writer.add_scalar('train/loss', loss.item(), current_step)
|
48 |
+
self.writer.add_scalar('train/token_acc', mle_acc*100, current_step)
|
49 |
+
logging.info(f'[!] progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
|
50 |
+
|
51 |
+
mle_acc *= 100
|
52 |
+
return mle_acc
|
53 |
+
|
54 |
+
def save_model(self, path, current_step):
|
55 |
+
# only save trainable model parameters
|
56 |
+
param_grad_dic = {
|
57 |
+
k: v.requires_grad for (k, v) in self.ds_engine.module.named_parameters()
|
58 |
+
}
|
59 |
+
state_dict = self.ds_engine.module.state_dict()
|
60 |
+
checkpoint = OrderedDict()
|
61 |
+
for k, v in self.ds_engine.module.named_parameters():
|
62 |
+
if v.requires_grad:
|
63 |
+
checkpoint[k] = v
|
64 |
+
if current_step <= 0:
|
65 |
+
torch.save(checkpoint, f'{path}/pytorch_model.pt')
|
66 |
+
else:
|
67 |
+
torch.save(checkpoint, f'{path}/pytorch_model_ep{current_step}.pt')
|
68 |
+
# save tokenizer
|
69 |
+
self.model.llama_tokenizer.save_pretrained(path)
|
70 |
+
# save configuration
|
71 |
+
self.model.llama_model.config.save_pretrained(path)
|
72 |
+
print(f'[!] save model into {path}')
|
73 |
+
|
74 |
+
def load_stage_1_parameters(self, path):
|
75 |
+
delta_ckpt = torch.load(path, map_location=torch.device('cpu'))
|
76 |
+
self.model.load_state_dict(delta_ckpt, strict=False)
|
model/conversations.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
|
12 |
+
|
13 |
+
@dataclasses.dataclass
|
14 |
+
class Conversation:
|
15 |
+
"""A class that keeps all conversation history."""
|
16 |
+
system: str
|
17 |
+
roles: List[str]
|
18 |
+
messages: List[List[str]]
|
19 |
+
offset: int
|
20 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
21 |
+
sep: str = "###"
|
22 |
+
sep2: str = None
|
23 |
+
version: str = "Unknown"
|
24 |
+
|
25 |
+
skip_next: bool = False
|
26 |
+
|
27 |
+
def get_prompt(self):
|
28 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
29 |
+
ret = self.system + self.sep
|
30 |
+
for role, message in self.messages:
|
31 |
+
if message:
|
32 |
+
if type(message) is tuple:
|
33 |
+
message, _, _ = message
|
34 |
+
ret += role + ": " + message + self.sep
|
35 |
+
else:
|
36 |
+
ret += role + ":"
|
37 |
+
return ret
|
38 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
39 |
+
seps = [self.sep, self.sep2]
|
40 |
+
ret = self.system + seps[0]
|
41 |
+
for i, (role, message) in enumerate(self.messages):
|
42 |
+
if message:
|
43 |
+
if type(message) is tuple:
|
44 |
+
message, _, _ = message
|
45 |
+
ret += role + ": " + message + seps[i % 2]
|
46 |
+
else:
|
47 |
+
ret += role + ":"
|
48 |
+
return ret
|
49 |
+
if self.sep_style == SeparatorStyle.MPT:
|
50 |
+
ret = self.system + self.sep
|
51 |
+
for role, message in self.messages:
|
52 |
+
if message:
|
53 |
+
if type(message) is tuple:
|
54 |
+
message, _, _ = message
|
55 |
+
ret += role + message + self.sep
|
56 |
+
else:
|
57 |
+
ret += role
|
58 |
+
return ret
|
59 |
+
else:
|
60 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
61 |
+
|
62 |
+
def append_message(self, role, message):
|
63 |
+
self.messages.append([role, message])
|
64 |
+
|
65 |
+
def get_images(self, return_pil=False):
|
66 |
+
images = []
|
67 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
68 |
+
if i % 2 == 0:
|
69 |
+
if type(msg) is tuple:
|
70 |
+
import base64
|
71 |
+
from io import BytesIO
|
72 |
+
from PIL import Image
|
73 |
+
msg, image, image_process_mode = msg
|
74 |
+
if image_process_mode == "Pad":
|
75 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
76 |
+
width, height = pil_img.size
|
77 |
+
if width == height:
|
78 |
+
return pil_img
|
79 |
+
elif width > height:
|
80 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
81 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
82 |
+
return result
|
83 |
+
else:
|
84 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
85 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
86 |
+
return result
|
87 |
+
image = expand2square(image)
|
88 |
+
elif image_process_mode == "Crop":
|
89 |
+
pass
|
90 |
+
elif image_process_mode == "Resize":
|
91 |
+
image = image.resize((224, 224))
|
92 |
+
else:
|
93 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
94 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
95 |
+
aspect_ratio = max_hw / min_hw
|
96 |
+
max_len, min_len = 800, 400
|
97 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
98 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
99 |
+
W, H = image.size
|
100 |
+
if H > W:
|
101 |
+
H, W = longest_edge, shortest_edge
|
102 |
+
else:
|
103 |
+
H, W = shortest_edge, longest_edge
|
104 |
+
image = image.resize((W, H))
|
105 |
+
if return_pil:
|
106 |
+
images.append(image)
|
107 |
+
else:
|
108 |
+
buffered = BytesIO()
|
109 |
+
image.save(buffered, format="JPEG")
|
110 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
111 |
+
images.append(img_b64_str)
|
112 |
+
return images
|
113 |
+
|
114 |
+
def to_gradio_chatbot(self):
|
115 |
+
ret = []
|
116 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
117 |
+
if i % 2 == 0:
|
118 |
+
if type(msg) is tuple:
|
119 |
+
import base64
|
120 |
+
from io import BytesIO
|
121 |
+
msg, image, image_process_mode = msg
|
122 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
123 |
+
aspect_ratio = max_hw / min_hw
|
124 |
+
max_len, min_len = 800, 400
|
125 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
126 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
127 |
+
W, H = image.size
|
128 |
+
if H > W:
|
129 |
+
H, W = longest_edge, shortest_edge
|
130 |
+
else:
|
131 |
+
H, W = shortest_edge, longest_edge
|
132 |
+
image = image.resize((W, H))
|
133 |
+
# image = image.resize((224, 224))
|
134 |
+
buffered = BytesIO()
|
135 |
+
image.save(buffered, format="JPEG")
|
136 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
137 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
138 |
+
msg = msg.replace('<image>', img_str)
|
139 |
+
ret.append([msg, None])
|
140 |
+
else:
|
141 |
+
ret[-1][-1] = msg
|
142 |
+
return ret
|
143 |
+
|
144 |
+
def copy(self):
|
145 |
+
return Conversation(
|
146 |
+
system=self.system,
|
147 |
+
roles=self.roles,
|
148 |
+
messages=[[x, y] for x, y in self.messages],
|
149 |
+
offset=self.offset,
|
150 |
+
sep_style=self.sep_style,
|
151 |
+
sep=self.sep,
|
152 |
+
sep2=self.sep2)
|
153 |
+
|
154 |
+
def dict(self):
|
155 |
+
if len(self.get_images()) > 0:
|
156 |
+
return {
|
157 |
+
"system": self.system,
|
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        ("Assistant",
            "Sure, here are three tips for staying healthy:\n"
            "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
            "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
            "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
            "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
            "activities at least two days per week.\n"
            "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
            "vegetables, whole grains, lean proteins, and healthy fats can help support "
            "your overall health. Try to limit your intake of processed and high-sugar foods, "
            "and aim to drink plenty of water throughout the day.\n"
            "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
            "and mental health. Adults should aim for seven to nine hours of sleep per night. "
            "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
            "help improve the quality of your sleep.")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
- You are a helpful language and vision assistant.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_mpt_text = Conversation(
    system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

simple_conv = Conversation(
    system="You are a large language model that can recognize visual contents based on LLaMA architecture."
           "You are designed to assist human with a variety of tasks using natural language."
           "Follow the instructions carefully.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_multimodal = Conversation(
    system="You are a large language and vision assistant trained with multi-modality vision signals."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_mpt_multimodal = Conversation(
    system="""<|im_start|>system
- You are a large language and vision assistant.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

simple_conv_legacy = Conversation(
    system="You are a large language model that trained on multi-modality visual contents."
           "You are designed to assist human with a variety of tasks using natural language."
           "Follow the instructions carefully.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!\n\n### Response:"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v1 = Conversation(
    system="You are a large language and vision assistant."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

default_conversation = conv_v1_2
conv_templates = {
    "default": conv_v1_2,
    "simple": simple_conv,
    "simple_legacy": simple_conv_legacy,
    "multimodal": simple_conv_multimodal,
    "mpt_multimodal": simple_conv_mpt_multimodal,
    "llava_v1": conv_llava_v1,

    # fastchat
    "v1": conv_v1_2,
    "bair_v1": conv_bair_v1,
    "vicuna_v1_1": conv_vicuna_v1_1,
    "mpt": conv_mpt,
    "mpt_text": conv_mpt_text,
}

conversation_dict = {
    "classification": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a classification task, and your goal is not just to provide a class label for a given image, you also need to ensure that the classification is accurate and reliable, as this information is critical for users to make informed decisions based on image data.",
    "detection": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an object detection task, and your goal is to locate all instances of objects in an image, such as people, cars, animals, or other objects, and give the corresponding coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y.",
    "VQA": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a visual question answering task, and your goal is to generate natural language answers that accurately solve the question. In order to generate accurate answers to questions about visual content, you must be able to understand the content of images, understand the meaning of questions, and perform complex reasoning processes.",
    "counting": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an object counting task, and your goal is to accurately count the number of objects in an image. Object counting is a computer vision task that involves detecting and counting the number of instances of specific objects within an image. You need to analyze the input image and accurately count the number of objects in it.",
    "conversation": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a conversation task, and your goal is to engage in a natural language conversation with a human about images and provide helpful and informative responses to their queries or requests. When answering questions related to images, you will do so in a tone that conveys that you are seeing the image and answering the question based on my analysis of the visual content. The conversation task involves understanding the user's input, generating an appropriate response, and maintaining a coherent and engaging conversation.",
    "description": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a image detail description task, and your goal is to generate a natural language description of an image that accurately and comprehensively conveys its visual content. When answering questions related to images, you will do so in a tone that conveys that you are seeing the image and answering the question based on my analysis of the visual content. The Image detail description task involves generating a textual description of an image that captures its salient features, objects, and context.",
    "commomsenseqa": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a external knowledge Q&A task, and your goal is to provide accurate and informative answers to questions that require external knowledge beyond the scope of the input text. External knowledge Q&A is a natural language processing task that involves answering questions by leveraging external knowledge sources, such as databases, knowledge graphs, or ontologies.",
    "normal": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",

    "classification3d": "You are an AI visual assistant that can analyze a point cloud of object, a chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a classification task, and your goal is not just to provide a class label for a given point clou, you also need to ensure that the classification is accurate and reliable, as this information is critical for users to make informed decisions based on point cloud data.",
    "detection3d": "You are an AI visual assistant that can analyze a scan of scene. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an object detection task, and your goal is to locate all instances of objects in an point cloud, such as people, cars, animals, or other objects, and give the corresponding coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, z1, x2, y2, z2) with unit of meters. These values correspond to the top left x, top left y, top left z, bottom right x, bottom right y and bottom right z.",
    "VQA3d": "You are an AI visual assistant that can analyze a point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a visual question answering task, and your goal is to generate natural language answers that accurately solve the question. In order to generate accurate answers to questions about visual content, you must be able to understand the content of point cloud, understand the meaning of questions, and perform complex reasoning processes.",
    "conversation3d": "You are an AI visual assistant that can analyze a point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a conversation task, and your goal is to engage in a natural language conversation with a human about point cloud and provide helpful and informative responses to their queries or requests. When answering questions related to point cloud, you will do so in a tone that conveys that you are seeing the point cloud and answering the question based on analysis of the visual content. The conversation task involves understanding the user's input, generating an appropriate response, and maintaining a coherent and engaging conversation.",
    "description3d": "You are an AI visual assistant that can analyze a single point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a detail description task for point cloud, and your goal is to generate a natural language description of an point cloud that accurately and comprehensively conveys its visual content. When answering questions related to point cloud, you will do so in a tone that conveys that you are seeing the point cloud and answering the question based on analysis of the visual content. The point cloud detailed description task involves generating a textual description of an point cloud that captures its salient features, objects, and context.",
    "commomsenseqa3d": "You are an AI visual assistant that can analyze a single point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an external knowledge Q&A task, and your goal is to provide accurate and informative answers to questions that require external knowledge beyond the scope of the input text. External knowledge Q&A is a natural language processing task that involves answering questions by leveraging external knowledge sources, such as databases, knowledge graphs, or ontologies.",
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
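The templates above are module-level singletons, and conversation_dict supplies the per-task system prompts that the demo swaps in at inference time. A minimal usage sketch follows; it assumes the Conversation class defined earlier in this file renders prompts in the usual LLaVA style (only get_prompt() is visible in this hunk), and the question text is made up for illustration:

from copy import deepcopy
from conversations import conv_templates, conversation_dict

conv = deepcopy(conv_templates["default"])          # copy before mutating the shared template
conv.system = conversation_dict["VQA"]              # task-specific system prompt
conv.messages = list(conv.messages) + [
    ("Human", "How many people are in the image?"),  # hypothetical user turn
]
print(conv.get_prompt())  # roughly: system prompt + "### Human: ... ### Assistant: ..." turns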
model/modeling_llama.py
ADDED
@@ -0,0 +1,755 @@
# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

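# Note on the two mask helpers above:
#   * `_make_causal_mask` builds an additive causal mask. For tgt_len=3 and no
#     cached keys it returns a [bsz, 1, 3, 3] tensor that is 0 on and below the
#     diagonal and `torch.finfo(dtype).min` (written `min` below) above it:
#
#         [[  0, min, min],
#          [  0,   0, min],
#          [  0,   0,   0]]
#
#     With `past_key_values_length = p`, p all-zero columns are prepended so a
#     new token may attend to every cached position.
#   * `_expand_mask` converts a [bsz, src_len] padding mask of 1s/0s into the
#     same additive form ([bsz, 1, tgt_len, src_len]; 0 = keep, min = drop), so
#     the two masks can simply be summed in
#     `LlamaModel._prepare_decoder_attention_mask` before being added to the
#     raw attention scores.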
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

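# Shape walk-through for the rotary helpers above:
#   * `LlamaRotaryEmbedding` caches cos/sin tables of shape
#     [1, 1, max_position_embeddings, head_dim].
#   * `apply_rotary_pos_emb` receives q and k of shape
#     [bsz, num_heads, seq_len, head_dim] plus `position_ids` of shape
#     [bsz, seq_len]; the gather picks out the cos/sin rows belonging to each
#     token's absolute position, so cached generation (where position_ids no
#     longer start at 0) stays consistent with the prefill pass.
#   * The rotation treats the first and second halves of each head dimension as
#     the two components of a 2-D vector and rotates them by a
#     position-dependent angle: q_embed = q * cos + rotate_half(q) * sin, and
#     likewise for k.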
class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

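# Shape flow through `LlamaAttention.forward`:
#   hidden_states            [bsz, q_len, hidden_size]
#   q/k/v after projection   [bsz, num_heads, q_len, head_dim]
#   k/v after cache concat   [bsz, num_heads, kv_seq_len, head_dim]
#                            (kv_seq_len = q_len + past_key_values_length)
#   attn_weights             [bsz, num_heads, q_len, kv_seq_len]
#   attn_output              [bsz, q_len, hidden_size] after o_proj
# The additive attention_mask produced by the helpers above is broadcast onto
# attn_weights before the fp32 softmax, which enforces causality and padding
# in a single step.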
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        query_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if query_embeds is not None:
            inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
            batch_size, seq_length, _ = inputs_embeds.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        query_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            query_embeds=query_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
                query_embeds = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "query_embeds": query_embeds,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

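The substantive change relative to the upstream Hugging Face implementation is the optional query_embeds argument threaded through LlamaModel.forward, LlamaForCausalLM.forward, and prepare_inputs_for_generation: when given, it is concatenated in front of the token embeddings, so projected vision tokens occupy the first positions of the sequence and are dropped again once a KV cache exists. A minimal shape-check sketch, assuming the repository root is on sys.path so the model package resolves; the tiny config values are made up for illustration only:

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from model.modeling_llama import LlamaForCausalLM

# Randomly initialised toy config, just to exercise the query_embeds path.
config = LlamaConfig(
    vocab_size=128, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4, max_position_embeddings=64,
)
model = LlamaForCausalLM(config).eval()

input_ids = torch.randint(0, 128, (1, 5))               # 5 text tokens
query_embeds = torch.randn(1, 3, config.hidden_size)    # e.g. 3 projected vision tokens

with torch.no_grad():
    out = model(input_ids=input_ids, query_embeds=query_embeds, use_cache=False)

# Vision tokens are prepended, so the logits cover 3 + 5 = 8 positions.
print(out.logits.shape)  # torch.Size([1, 8, 128])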
model/openlamm.py
ADDED
@@ -0,0 +1,550 @@
import io
import os

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
# from petrel_client.client import Client
from PIL import Image, ImageFile
from torch.nn.utils import rnn
from types import SimpleNamespace
from peft import LoraConfig, TaskType, get_peft_model
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
import conversations
import numpy as np
# from header import *

from transformers import StoppingCriteria, StoppingCriteriaList

from .CLIP import load as load_clip
from .PROCESS import data

from .modeling_llama import LlamaForCausalLM
from .utils.pcl_utils import MEAN_COLOR_RGB, RandomCuboid, random_sampling

ImageFile.LOAD_TRUNCATED_IMAGES = True

# sov: start of vision part; eov: end of vision part
VISION_TAGS = {
    'pos': {'image': '<image>', 'pcl': '<pcl>'},
    'sov': {'image': '<Img>', 'pcl': '<Pcl>'},
    'eov': {'image': '</Img>', 'pcl': '</Pcl>'},
}
ModalityType = SimpleNamespace(
    VISION="vision",
    TEXT="text",
    AUDIO="audio",
    THERMAL="thermal",
    DEPTH="depth",
    IMU="imu",
)

class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops = [], encounters=1):
        super().__init__()
        self.stops = stops
        self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        stop_count = 0
        for stop in self.stops:
            stop_count = (stop == input_ids[0]).sum().item()
        if stop_count >= self.ENCOUNTERS:
            return True
        return False


class MyStoppingCriteria(StoppingCriteria):
    def __init__(self, stops, input_ids):
        super().__init__()
        self.stops = [torch.tensor(stop).to('cuda:0') for stop in stops]
        self.stop_flag = [0]*input_ids.shape[0]

    def check_stop(self, input_ids):
        for stop in self.stops:
            if torch.all((stop == input_ids[-len(stop):])).item():
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        flag = 1
        for id, output_id in enumerate(output_ids):
            if self.stop_flag[id] == 1:
                continue
            if self.check_stop(output_id):
                self.stop_flag[id] = 1
            else:
                flag = 0
        if flag == 1:
            return True
        return False

85 |
+
def build_one_instance(tokenizer, conversation, vision_type='image'):
|
86 |
+
pos = VISION_TAGS['pos'][vision_type]
|
87 |
+
# sov = VISION_TAGS['sov'][vision_type]
|
88 |
+
eov = VISION_TAGS['eov'][vision_type]
|
89 |
+
|
90 |
+
text_list = []
|
91 |
+
turn_num = len(conversation)
|
92 |
+
input_ids, target_ids = [], []
|
93 |
+
for i in range(turn_num):
|
94 |
+
turn = conversation[i]
|
95 |
+
role = turn['from']
|
96 |
+
if i == 0: # the first human turn
|
97 |
+
assert role == 'human'
|
98 |
+
turn['value'] = turn['value'].replace(f'{pos}\n', '').replace(f'\n{pos}', '')
|
99 |
+
text = f'{eov} ' + turn['value'] + '\n### Assistant:'
|
100 |
+
one_input_id = tokenizer(text, add_special_tokens=False).input_ids
|
101 |
+
input_ids += one_input_id
|
102 |
+
target_ids += [-100]*len(one_input_id) # do not perform loss regression on human prompt
|
103 |
+
else:
|
104 |
+
if role == 'human':
|
105 |
+
text = 'Human: ' + turn['value'] + '\n### Assistant:'
|
106 |
+
one_input_id = tokenizer(text, add_special_tokens=False).input_ids
|
107 |
+
input_ids += one_input_id
|
108 |
+
target_ids += [-100]*len(one_input_id)
|
109 |
+
elif role == 'gpt':
|
110 |
+
text = turn['value'] + '\n###'
|
111 |
+
one_input_id = tokenizer(text, add_special_tokens=False).input_ids
|
112 |
+
input_ids += one_input_id
|
113 |
+
target_ids += one_input_id
|
114 |
+
else:
|
115 |
+
raise Exception('Wrong Role!!!')
|
116 |
+
text_list.append(text)
|
117 |
+
assert len(input_ids) == len(target_ids)
|
118 |
+
return text_list, input_ids, target_ids
|
119 |
+
|
120 |
+
|
121 |
+
def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len, vision_type='image'):
|
122 |
+
batch_input_ids, batch_target_ids = [], []
|
123 |
+
for conversation in batch_of_conversations:
|
124 |
+
_, one_input_ids, one_target_ids = build_one_instance(tokenizer, conversation, vision_type=vision_type)
|
125 |
+
batch_input_ids.append(torch.LongTensor(one_input_ids))
|
126 |
+
batch_target_ids.append(torch.LongTensor(one_target_ids))
|
127 |
+
input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
|
128 |
+
target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100)
|
129 |
+
assert input_ids.size() == target_ids.size()
|
130 |
+
input_ids = input_ids[:,:max_tgt_len]
|
131 |
+
target_ids = target_ids[:,:max_tgt_len]
|
132 |
+
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
133 |
+
assert attention_mask.size() == input_ids.size()
|
134 |
+
return input_ids, target_ids, attention_mask.long()
|
135 |
+
|
136 |
+
|
137 |
+
def make_prompt_start(system_header=False, vision_type='image', task_type='normal'):
|
138 |
+
# TODO: choose prefix according to task type
|
139 |
+
PROMPT_START = f'### Human: {VISION_TAGS["sov"][vision_type]}'
|
140 |
+
if system_header:
|
141 |
+
if task_type == 'normal':
|
142 |
+
return f"{conversations.default_conversation.system}\n\n" + PROMPT_START
|
143 |
+
else:
|
144 |
+
return [f"{conversations.conversation_dict[task]}\n\n" + PROMPT_START for task in task_type]
|
145 |
+
else:
|
146 |
+
return PROMPT_START
|
147 |
+
|
148 |
+
|
149 |
+
class LAMMPEFTModel(nn.Module):
|
150 |
+
|
151 |
+
'''LoRA for LLaMa model'''
|
152 |
+
|
153 |
+
def __init__(self, **args):
|
154 |
+
super(LAMMPEFTModel, self).__init__()
|
155 |
+
self.args = args
|
156 |
+
# self.client = Client('~/petreloss.conf')
|
157 |
+
self.client = None
|
158 |
+
|
159 |
+
self.vision_type = args['vision_type'] if 'vision_type' in args else 'image'
|
160 |
+
encoder_pretrain = args['encoder_pretrain'] if 'encoder_pretrain' in args else 'clip'
|
161 |
+
self.encoder_pretrain = encoder_pretrain
|
162 |
+
assert encoder_pretrain in ['imagebind', 'clip', 'epcl'], f'Encoder_pretrain: {encoder_pretrain} Not Implemented'
|
163 |
+
encoder_ckpt_path = args['encoder_ckpt_path'] if not encoder_pretrain == 'clip' else '~/.cache/clip/ViT-L-14.pt'
|
164 |
+
vicuna_ckpt_path = args['vicuna_ckpt_path']
|
165 |
+
|
166 |
+
system_header = args['system_header'] if 'system_header' in args else False
|
167 |
+
stage = args['stage']
|
168 |
+
|
169 |
+
# TODO: checkout vision token number; for ImageBind = 1; Defaultly to use 1 global token for this
|
170 |
+
# -1 for last embedding; -2 for transformer output
|
171 |
+
self.vision_feature_type = args['vision_feature_type']
|
172 |
+
self.num_vision_token = args['num_vision_token']
|
173 |
+
|
174 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
175 |
+
print (f'Initializing [{encoder_pretrain}] visual encoder from {encoder_ckpt_path} [{device}]...')
|
176 |
+
|
177 |
+
# TODO: Make sure the number of vision tokens is correct
|
178 |
+
if args['encoder_pretrain'].lower() == 'clip':
|
179 |
+
clip_encoder, self.visual_preprocess = load_clip('ViT-L/14', device=device)
|
180 |
+
self.visual_encoder = clip_encoder.visual
|
181 |
+
if self.vision_feature_type == 'global': # global feature from CLIP
|
182 |
+
self.vision_hidden_size = 768
|
183 |
+
self.num_vision_token = 1
|
184 |
+
assert self.num_vision_token == 1, 'Only 1 global token is available!'
|
185 |
+
elif self.vision_feature_type == 'local': # patch features from CLIP ViT
|
186 |
+
self.vision_hidden_size = 1024
|
187 |
+
self.num_vision_token = min(self.num_vision_token, 256) # may cut partial tokens
|
188 |
+
|
189 |
+
# freeze vision encoder
|
190 |
+
for name, param in self.visual_encoder.named_parameters():
|
191 |
+
param.requires_grad = False
|
192 |
+
self.visual_encoder.eval()
|
193 |
+
print ('Visual encoder initialized.')
|
194 |
+
|
195 |
+
print (f'Initializing language decoder from {vicuna_ckpt_path} ...')
|
196 |
+
# add the lora module
|
197 |
+
peft_config = LoraConfig(
|
198 |
+
task_type=TaskType.CAUSAL_LM,
|
199 |
+
inference_mode=False,
|
200 |
+
r=self.args['lora_r'],
|
201 |
+
lora_alpha=self.args['lora_alpha'],
|
202 |
+
lora_dropout=self.args['lora_dropout'],
|
203 |
+
target_modules=self.args['lora_target_modules']
|
204 |
+
)
|
205 |
+
|
206 |
+
self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)
|
207 |
+
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
208 |
+
self.llama_model.print_trainable_parameters()
|
209 |
+
|
210 |
+
self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
|
211 |
+
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
|
212 |
+
self.llama_tokenizer.padding_side = "right"
|
213 |
+
print ('Language decoder initialized.')
|
214 |
+
|
215 |
+
self.llama_proj = nn.Linear(
|
216 |
+
self.vision_hidden_size, self.llama_model.config.hidden_size
|
217 |
+
)
|
218 |
+
print ('LLaMa projection layer initialized.')
|
219 |
+
|
220 |
+
self.max_tgt_len = args['max_tgt_len']
|
221 |
+
self.system_header = system_header
|
222 |
+
self.device = torch.cuda.current_device()
|
223 |
+
|
224 |
+
# def encode_video(self, video_paths):
|
225 |
+
# inputs = {ModalityType.VISION: data.load_and_transform_video_data(video_paths, self.device)}
|
226 |
+
# # convert into visual dtype
|
227 |
+
# inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
|
228 |
+
# with torch.no_grad():
|
229 |
+
# embeddings = self.visual_encoder(inputs)
|
230 |
+
# video_embeds = embeddings[ModalityType.VISION] # bsz x 1024
|
231 |
+
# inputs_llama = self.llama_proj(video_embeds).unsqueeze(1) # bsz x 1 x llama_size
|
232 |
+
# atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
|
233 |
+
# return inputs_llama, atts_llama
|
234 |
+
|
235 |
+
# def encode_audio(self, audio_paths):
|
236 |
+
# inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, self.device)}
|
237 |
+
# # convert into visual dtype
|
238 |
+
# inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
|
239 |
+
# with torch.no_grad():
|
240 |
+
# embeddings = self.visual_encoder(inputs)
|
241 |
+
# audio_embeds = embeddings[ModalityType.AUDIO] # bsz x 1024
|
242 |
+
# inputs_llama = self.llama_proj(audio_embeds).unsqueeze(1) # bsz x 1 x llama_size
|
243 |
+
# atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
|
244 |
+
# return inputs_llama, atts_llama
|
245 |
+
|
246 |
+
# def encode_thermal(self, thermal_paths):
|
247 |
+
# inputs = {ModalityType.THERMAL: data.load_and_transform_thermal_data(thermal_paths, self.device)}
|
248 |
+
# # convert into visual dtype
|
249 |
+
# inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
|
250 |
+
# with torch.no_grad():
|
251 |
+
# embeddings = self.visual_encoder(inputs)
|
252 |
+
# image_embeds = embeddings['thermal'] # bsz x 1024
|
253 |
+
# inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
|
254 |
+
# atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
|
255 |
+
# return inputs_llama, atts_llama
|
256 |
+
|
257 |
+
def encode_image(self, image_paths):
|
258 |
+
"""encode images to llama inputs
|
259 |
+
|
260 |
+
:param tupe image_paths: (bsz, )
|
261 |
+
:return tensor, tensor: input feature to llama, attention mask to llama
|
262 |
+
"""
|
263 |
+
if self.encoder_pretrain == 'imagebind':
|
264 |
+
inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
|
265 |
+
# convert into visual dtype
|
266 |
+
inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
|
267 |
+
with torch.no_grad():
|
268 |
+
embeddings = self.visual_encoder(inputs)
|
269 |
+
image_embeds = embeddings['vision'] # bsz x 1024
|
270 |
+
inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
|
271 |
+
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
|
272 |
+
return inputs_llama, atts_llama
|
273 |
+
elif self.encoder_pretrain == 'clip':
|
274 |
+
inputs = self.load_and_transform_vision_data_clip(image_paths, self.device) # bsz x 3 x 224 x 224
|
275 |
+
inputs = inputs.to(self.llama_model.dtype) # clip requires torch.float32
|
276 |
+
inputs_llama = self.clip_encode_image(inputs)
|
277 |
+
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1/256
|
278 |
+
return inputs_llama, atts_llama
|
279 |
+
|
280 |
+
def my_encode_image(self, images):
|
281 |
+
"""encoder loaded image objects"""
|
282 |
+
# if self.encoder_pretrain == 'imagebind':
|
283 |
+
# inputs = {ModalityType.VISION: data.transform_vision_data(images, self.device)}
|
284 |
+
# # convert into visual dtype
|
285 |
+
# inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
|
286 |
+
# with torch.no_grad():
|
287 |
+
# embeddings = self.visual_encoder(inputs)
|
288 |
+
# image_embeds = embeddings['vision'] # bsz x 1024
|
289 |
+
# inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
|
290 |
+
# atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
|
291 |
+
# return inputs_llama, atts_llama
|
292 |
+
if self.encoder_pretrain == 'clip':
|
293 |
+
inputs = data.transform_vision_data(images, self.device) # bsz x 3 x 224 x 224
|
294 |
+
inputs_llama = self.clip_encode_image(inputs) # bsz x 1/256 x llama_size
|
295 |
+
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1/256
|
296 |
+
return inputs_llama, atts_llama
|
297 |
+
else:
|
298 |
+
raise NotImplementedError("Encoder pretrain [{}] not implemented".format(self.encoder_pretrain))
|
299 |
+
|
300 |
+
def encode_pcl(self, pcl_paths):
|
301 |
+
# load pcl data
|
302 |
+
inputs = self.load_and_transform_pcl_data(pcl_paths, self.device) # bsz x 40000 x 3
|
303 |
+
|
304 |
+
inputs = inputs.to(self.llama_model.dtype) # clip requires torch.float32
|
305 |
+
with torch.no_grad():
|
306 |
+
if self.vision_feature_type == 'global':
|
307 |
+
raise NotImplementedError("Global feature not implemented for pcl")
|
308 |
+
elif self.vision_feature_type == 'local':
|
309 |
+
embeddings = self.visual_encoder(inputs)[1][:, :self.num_vision_token] # bsz x 256 x 1024;
|
310 |
+
image_embeds = embeddings.reshape(-1, self.vision_hidden_size).to(self.llama_model.dtype) # bsz*num vision token x 1024
|
311 |
+
inputs_llama = self.llama_proj(image_embeds).reshape(-1, self.num_vision_token, self.llama_model.config.hidden_size) # bsz x num_vision_token x llama_size
|
312 |
+
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1/256
|
313 |
+
return inputs_llama, atts_llama
|
314 |
+
|
315 |
+
def clip_encode_image(self, inputs):
|
316 |
+
inputs = inputs.to(self.llama_model.dtype) # clip requires torch.float32
|
317 |
+
with torch.no_grad():
|
318 |
+
if self.vision_feature_type == 'global':
|
319 |
+
embeddings = self.visual_encoder(inputs) # bsz x 768
|
320 |
+
image_embeds = embeddings.to(self.llama_model.dtype)
|
321 |
+
inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
|
322 |
+
elif self.vision_feature_type == 'local':
|
323 |
+
embeddings = self.visual_encoder.forward_patch_features(inputs)[:, :self.num_vision_token] # bsz x self.num_vision_token x 1024
|
324 |
+
image_embeds = embeddings.reshape(-1, self.vision_hidden_size).to(self.llama_model.dtype) # bsz*num vision token x 1024
|
325 |
+
inputs_llama = self.llama_proj(image_embeds).reshape(-1, self.num_vision_token, self.llama_model.config.hidden_size) # bsz x num_vision_token x llama_size
|
326 |
+
else:
|
327 |
+
raise NotImplementedError("{} not Implemented".format(self.vision_feature_type))
|
328 |
+
return inputs_llama
|
329 |
+
|
330 |
+
def load_and_transform_vision_data_clip(self, image_paths, device):
|
331 |
+
if image_paths is None:
|
332 |
+
return None
|
333 |
+
image_ouputs = []
|
334 |
+
for image_path in image_paths:
|
335 |
+
if os.path.exists(image_path):
|
336 |
+
image = Image.open(image_path)
|
337 |
+
elif image_path.startswith('s3://') and self.client is not None:
|
338 |
+
image = Image.open(io.BytesIO(self.client.get(image_path, update_cache=True))).convert("RGB")
|
339 |
+
elif image_path.startswith('http://'):
|
340 |
+
image = Image.open(requests.get(image_path, stream=True).raw)
|
341 |
+
else:
|
342 |
+
print("can not load image: ", image_path)
|
343 |
+
image_outpt = self.visual_preprocess(image).to(device) # 3 x 224 x 224
|
344 |
+
image_ouputs.append(image_outpt)
|
345 |
+
return torch.stack(image_ouputs, dim=0) # B x 3 x 224 x 224
|
346 |
+
|
347 |
+
def load_and_transform_pcl_data(self, pcl_paths, device):
|
348 |
+
if pcl_paths is None:
|
349 |
+
return None
|
350 |
+
pcl_output = []
|
351 |
+
for pcl_path in pcl_paths:
|
352 |
+
mesh_vertices = np.load(pcl_path) # 150000, 3
|
353 |
+
if not self.use_color:
|
354 |
+
point_cloud = mesh_vertices[:, 0:3] # do not use color for now
|
355 |
+
else:
|
356 |
+
point_cloud = mesh_vertices[:, 0:6]
|
357 |
+
point_cloud[:, 3:] = (point_cloud[:, 3:] - MEAN_COLOR_RGB) / 256.0
|
358 |
+
|
359 |
+
if self.use_height:
|
360 |
+
floor_height = np.percentile(point_cloud[:, 2], 0.99)
|
361 |
+
height = point_cloud[:, 2] - floor_height
|
362 |
+
point_cloud = np.concatenate([point_cloud, np.expand_dims(height, 1)], 1)
|
363 |
+
|
364 |
+
point_cloud, _ = random_sampling(
|
365 |
+
point_cloud, self.num_points, return_choices=True
|
366 |
+
)
|
367 |
+
pcl_output.append(torch.from_numpy(point_cloud))
|
368 |
+
return torch.stack(pcl_output, dim=0).to(device) # bsz x num_points x 3
|
369 |
+
|
370 |
+
def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask, system_header, task_type):
|
371 |
+
'''
|
372 |
+
input_ids, target_ids, attention_mask: bsz x s2
|
373 |
+
'''
|
374 |
+
input_ids = input_ids.to(self.device) # bsz x s2
|
375 |
+
target_ids = target_ids.to(self.device) # bsz x s2
|
376 |
+
attention_mask = attention_mask.to(self.device) # bsz x s2
|
377 |
+
|
378 |
+
batch_size = img_embeds.shape[0]
|
379 |
+
|
380 |
+
# return list of headers if multiple tasks
|
381 |
+
p_before = make_prompt_start(system_header=system_header, vision_type=self.vision_type, task_type=task_type)
|
382 |
+
if isinstance(p_before, list):
|
383 |
+
p_before_tokens = [self.llama_tokenizer(p,
|
384 |
+
return_tensors="pt", add_special_tokens=False).input_ids[0].to(self.device) for p in p_before]
|
385 |
+
# TODO: fix bug here
|
386 |
+
p_before_token_ids = rnn.pad_sequence(p_before_tokens, batch_first=True, padding_value=self.llama_tokenizer.pad_token_id) # bsz x s1
|
387 |
+
p_before_attn_mask = p_before_token_ids.ne(self.llama_tokenizer.pad_token_id)
|
388 |
+
else:
|
389 |
+
p_before_tokens = self.llama_tokenizer(p_before,
|
390 |
+
return_tensors="pt", add_special_tokens=False).to(self.device) # [s1, s1...] list of batch size
|
391 |
+
p_before_token_ids = p_before_tokens.input_ids.expand(batch_size, -1) # bsz x s1
|
392 |
+
p_before_attn_mask = p_before_tokens.attention_mask.expand(batch_size, -1) # bsz x s1
|
393 |
+
# peft model need deeper call
|
394 |
+
p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_token_ids) #.expand(batch_size, -1, -1) # bsz x s1 x embed_dim
|
395 |
+
p_after_embeds = self.llama_model.model.model.embed_tokens(input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim
|
396 |
+
bos = torch.ones([batch_size, 1],
|
397 |
+
dtype=p_before_token_ids.dtype,
|
398 |
+
device=p_before_token_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
|
399 |
+
bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
|
400 |
+
inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_after_embeds], dim=1) # bsz x (1+s1+NumToken+s2) x embed_dim
|
401 |
+
|
402 |
+
# make target ids for prefix part
|
403 |
+
empty_targets = (
|
404 |
+
torch.ones([batch_size, 1 + p_before_embeds.size()[1] + self.num_vision_token], # 1 (bos) + s1 + num_image_tokens (image vector)
|
405 |
+
dtype=torch.long).to(self.device).fill_(-100)
|
406 |
+
) # bsz x (1 + s1 + 1)
|
407 |
+
targets = torch.cat([empty_targets, target_ids], dim=1) # bsz x (1 + s1 + num_image_tokens + s2)
|
408 |
+
assert inputs_embeds.size()[1] == targets.size()[1]
|
409 |
+
|
410 |
+
# atts_prefix = torch.ones([batch_size, 1 + p_before_embeds.size()[1] + self.num_vision_token], dtype=torch.long).to(self.device) # bsz x (1[bos] + s1 +num_image_tokens)
|
411 |
+
atts_bos = torch.ones([batch_size, 1], dtype=torch.long).to(self.device) # bsz x 1
|
412 |
+
atts_img = torch.ones([batch_size, self.num_vision_token], dtype=torch.long).to(self.device) # bsz x num_image_tokens
|
413 |
+
attention_mask = torch.cat([atts_bos, p_before_attn_mask, atts_img, attention_mask], dim=1)
|
414 |
+
assert attention_mask.size() == targets.size() # bsz x (1 + s1 + num_image_tokens + s2)
|
415 |
+
return inputs_embeds, targets, attention_mask
|
416 |
+
|
417 |
+
def forward(self, inputs):
|
418 |
+
"""Model Forward in training
|
419 |
+
|
420 |
+
:param class inputs: model itself
|
421 |
+
:raises ValueError: valueerror if not image or pcl
|
422 |
+
:return list: loss & token acc
|
423 |
+
"""
|
424 |
+
# image_paths = inputs['image_paths']
|
425 |
+
assert self.vision_type == inputs['vision_type'] # single modal case
|
426 |
+
task_type = inputs['task_type']
|
427 |
+
vision_paths = inputs['vision_paths']
|
428 |
+
if self.vision_type == 'image':
|
429 |
+
vision_embeds, _ = self.encode_image(vision_paths)
|
430 |
+
elif self.vision_type == 'pcl':
|
431 |
+
vision_embeds, _ = self.encode_pcl(vision_paths) # Bsz x N token x C
|
432 |
+
else:
|
433 |
+
raise ValueError('vision type [{}] not supported'.format(self.vision_type))
|
434 |
+
|
435 |
+
output_texts = inputs['output_texts']
|
436 |
+
input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len, self.vision_type)
|
437 |
+
inputs_embeds, targets, attention_mask = self.prompt_wrap(vision_embeds, input_ids, target_ids, attention_mask, self.system_header, task_type)
|
438 |
+
|
439 |
+
outputs = self.llama_model(
|
440 |
+
inputs_embeds=inputs_embeds,
|
441 |
+
attention_mask=attention_mask,
|
442 |
+
return_dict=True,
|
443 |
+
labels=targets,
|
444 |
+
)
|
445 |
+
loss = outputs.loss
|
446 |
+
# calculate the token accuarcy
|
447 |
+
chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1: -1] # [B, S-1]
|
448 |
+
labels = targets[:, 2:]
|
449 |
+
gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long) # [B*S]
|
450 |
+
valid_mask = (labels != -100).reshape(-1)
|
451 |
+
valid_tokens = gen_acc & valid_mask # [B*S]
|
452 |
+
gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()
|
453 |
+
return loss, gen_acc
|
454 |
+
|
455 |
+
def extract_multimodal_feature(self, inputs):
|
456 |
+
"""Extract multimodal features from the input in Generation (Test)
|
457 |
+
|
458 |
+
:param Dict inputs: input dict; modality: path
|
459 |
+
:return _type_: _description_
|
460 |
+
"""
|
461 |
+
features = []
|
462 |
+
if inputs['image_paths']:
|
463 |
+
image_embeds, _ = self.encode_image(inputs['image_paths'])
|
464 |
+
features.append(image_embeds)
|
465 |
+
if 'images' in inputs and inputs['images']: # image objects input in testing
|
466 |
+
image_embeds, _ = self.my_encode_image(inputs['images'])
|
467 |
+
return image_embeds
|
468 |
+
# features.append(image_embeds)
|
469 |
+
if 'pcl_paths' in inputs and inputs['pcl_paths']:
|
470 |
+
pcl_embeds, _ = self.encode_pcl(inputs['pcl_paths'])
|
471 |
+
features.append(pcl_embeds)
|
472 |
+
# TODO: Cautions HERE! Multimodality allowed in test ONLY!
|
473 |
+
feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0) # sum all modality features together
|
474 |
+
return feature_embeds
|
475 |
+
|
476 |
+
def prepare_generation_embedding(self, inputs):
|
477 |
+
"""prepare for generation
|
478 |
+
|
479 |
+
:param class inputs: model
|
480 |
+
:return Dict: generation input
|
481 |
+
"""
|
482 |
+
eov = VISION_TAGS['eov'][self.vision_type]
|
483 |
+
# TODO: add System header & image token size
|
484 |
+
prompt_list = inputs['prompt'] # questions from user
|
485 |
+
if len(inputs['modality_embeds']) == 1:
|
486 |
+
feature_embeds = inputs['modality_embeds'][0]
|
487 |
+
else:
|
488 |
+
feature_embeds = self.extract_multimodal_feature(inputs)
|
489 |
+
inputs['modality_embeds'].append(feature_embeds)
|
490 |
+
|
491 |
+
batch_size = feature_embeds.shape[0]
|
492 |
+
p_before = make_prompt_start(vision_type=self.vision_type) # no system header in test
|
493 |
+
p_before_tokens = self.llama_tokenizer(p_before,
|
494 |
+
return_tensors="pt", add_special_tokens=False).to(self.device)
|
495 |
+
p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
|
496 |
+
p_after_embeds_list = []
|
497 |
+
p_after_tokens_list = []
|
498 |
+
for prompt in prompt_list:
|
499 |
+
# text = '</Img> ' + prompt + '\n### Assistant:'
|
500 |
+
text = f'{eov} ' + prompt + '\n### Assistant:'
|
501 |
+
p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
|
502 |
+
|
503 |
+
p_after_tokens_list.append(p_after_tokens.input_ids.squeeze(0))
|
504 |
+
|
505 |
+
p_after_tokens = rnn.pad_sequence(p_after_tokens_list, batch_first=True, padding_value=self.llama_tokenizer.pad_token_id)
|
506 |
+
|
507 |
+
p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens)
|
508 |
+
|
509 |
+
# text = f'{eov} ' + prompt + '\n### Assistant:'
|
510 |
+
# p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
|
511 |
+
# p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
|
512 |
+
bos = torch.ones([batch_size, 1],
|
513 |
+
dtype=p_before_tokens.input_ids.dtype,
|
514 |
+
device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
|
515 |
+
bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
|
516 |
+
# print(bos_embeds.shape, p_before_embeds.shape, feature_embeds.shape, p_after_embeds.shape)
|
517 |
+
inputs_embeds = torch.cat([bos_embeds, p_before_embeds, feature_embeds, p_after_embeds], dim=1) # bsz x (1+s1+NumVisionToken+s2) x embed_dim
|
518 |
+
return inputs_embeds
|
519 |
+
|
520 |
+
def generate(self, inputs):
|
521 |
+
'''
|
522 |
+
inputs = {
|
523 |
+
'image_paths': optional,
|
524 |
+
'audio_paths': optional
|
525 |
+
'video_paths': optional
|
526 |
+
'thermal_paths': optional
|
527 |
+
'mode': generation mode,
|
528 |
+
'prompt': human input prompt,
|
529 |
+
'max_tgt_len': generation length,
|
530 |
+
'top_p': top_p,
|
531 |
+
'temperature': temperature
|
532 |
+
'modality_embeds': None or torch.tensor
|
533 |
+
'modality_cache': save the image cache
|
534 |
+
}
|
535 |
+
'''
|
536 |
+
input_embeds = self.prepare_generation_embedding(inputs)
|
537 |
+
# stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277], encounters=1)])
|
538 |
+
stopping_criteria = StoppingCriteriaList([MyStoppingCriteria([[2277]], input_embeds)])
|
539 |
+
outputs = self.llama_model.generate(
|
540 |
+
inputs_embeds=input_embeds,
|
541 |
+
max_new_tokens=inputs['max_tgt_len'],
|
542 |
+
top_p=inputs['top_p'],
|
543 |
+
temperature=inputs['temperature'],
|
544 |
+
do_sample=True,
|
545 |
+
use_cache=True,
|
546 |
+
stopping_criteria=stopping_criteria,
|
547 |
+
)
|
548 |
+
#output_text = self.llama_tokenizer.decode(outputs[0][:-2], skip_special_tokens=True)
|
549 |
+
output_text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
550 |
+
return output_text
|
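Editorial note: a minimal inference sketch (not part of the commit) showing how `LAMMPEFTModel.generate` is driven, assuming `model` is an already-built `LAMMPEFTModel` on a CUDA device with its delta checkpoint loaded; the image path is a placeholder and the dict keys follow the `generate` docstring above.

import torch

inputs = {
    'prompt': ['What is unusual about this image?'],   # one question per sample
    'image_paths': ['example.jpg'],                    # hypothetical local file
    'pcl_paths': None,
    'modality_embeds': [],    # filled on the first call, then reused across turns
    'modality_cache': None,
    'max_tgt_len': 128,
    'top_p': 0.9,
    'temperature': 1.0,
}
with torch.no_grad():
    answers = model.generate(inputs)   # list of decoded strings, one per prompt
print(answers[0])
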
model/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .ceph_utils import *
from .pcl_utils import *

model/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (224 Bytes).
model/utils/__pycache__/ceph_utils.cpython-310.pyc
ADDED
Binary file (7.87 kB).
model/utils/__pycache__/pcl_utils.cpython-310.pyc
ADDED
Binary file (3.22 kB).
model/utils/ceph_utils.py
ADDED
@@ -0,0 +1,260 @@
"""
Date: 2022-07-18 2:15:47 pm
Author: dihuangdh
Descriptions:
-----
LastEditTime: 2022-09-14 3:44:19 pm
LastEditors: dihuangdh
"""

import json
import pickle
import warnings
from io import BytesIO, StringIO  # TODO:
from pathlib import Path
from typing import Any, Generator, Iterator, Optional, Tuple, Union

import cv2
import numpy as np


def has_method(obj: object, method: str) -> bool:
    """Check whether the object has a method.
    Args:
        method (str): The method name to check.
        obj (object): The object to check.
    Returns:
        bool: True if the object has the method else False.
    """
    return hasattr(obj, method) and callable(getattr(obj, method))


class PetrelBackend(object):
    """Petrel storage backend - simple version"""

    def __init__(self, enable_mc: bool = False) -> None:
        try:
            from petrel_client.client import Client
        except ImportError:
            raise ImportError(
                "Please install petrel_client to enable " "PetrelBackend."
            )

        self._client = Client('~/petreloss.conf')

    def get(self, filepath) -> memoryview:
        value = self._client.Get(filepath)
        value_buf = memoryview(value)
        return value_buf

    def get_text(self, filepath, warning=False) -> str:
        try:
            value = self._client.Get(filepath)
        except:
            if warning:
                warnings.warn("Failed to get text from {}".format(filepath))
                value = None
            else:
                raise Exception("Failed to get text from {}".format(filepath))
        return str(value, encoding="utf-8")

    def get_uint16_png(self, filepath, warning=False) -> np.ndarray:
        try:
            value = np.frombuffer(self._client.get(filepath), np.uint8)
            value = cv2.imdecode(value, cv2.IMREAD_UNCHANGED)
        except:
            if warning:
                warnings.warn("Failed to get uint16_png from {}".format(filepath))
                value = None
            else:
                raise Exception("Failed to get uint16_png from {}".format(filepath))
        return value

    def get_uint8_jpg(self, filepath, warning=False) -> np.ndarray:
        try:
            value = np.frombuffer(self._client.get(filepath), np.uint8)
            value = cv2.imdecode(value, cv2.IMREAD_UNCHANGED)
        except:
            if warning:
                warnings.warn("Failed to get uint8_jpg from {}".format(filepath))
                value = None
            else:
                raise Exception("Failed to get uint8_jpg from {}".format(filepath))
        return value

    def get_npz(self, filepath, warning=False) -> Any:
        try:
            value = self._client.get(filepath)
            value = pickle.loads(value)  # put_npz stores pickled bytes; np.loads no longer exists in NumPy
        except Exception as e:
            if warning:
                warnings.warn("Failed to get npz from {}".format(filepath))
                value = None
            else:
                print(e)
                raise Exception("Failed to get npz from {}".format(filepath))
        return value

    def get_numpy_txt(self, filepath, warning=False) -> np.ndarray:
        try:
            value = np.loadtxt(StringIO(self.get_text(filepath)))
        except:
            if warning:
                warnings.warn("Failed to get numpy_txt from {}".format(filepath))
                value = None
            else:
                raise Exception("Failed to get numpy_txt from {}".format(filepath))
        return value

    def get_json(self, filepath, warning=False) -> Any:
        try:
            value = self._client.get(filepath)
            value = json.loads(value)
        except:
            if warning:
                warnings.warn("Failed to get json from {}".format(filepath))
                value = None
            else:
                raise Exception("Failed to get json from {}".format(filepath))
        return value

    def put_uint16_png(self, filepath, value) -> None:
        success, img_array = cv2.imencode(".png", value, params=[cv2.CV_16U])
        assert success
        img_bytes = img_array.tobytes()
        self._client.put(filepath, img_bytes)
        # self._client.put(filepath, img_bytes, update_cache=True)

    def put_uint8_jpg(self, filepath, value) -> None:
        success, img_array = cv2.imencode(".jpg", value)
        assert success
        img_bytes = img_array.tobytes()
        self._client.put(filepath, img_bytes)
        # self._client.put(filepath, img_bytes, update_cache=True)

    def put_npz(self, filepath, value) -> None:
        value = pickle.dumps(value)
        self._client.put(filepath, value)
        # self._client.put(filepath, value, update_cache=True)

    def put_json(self, filepath, value) -> None:
        value = json.dumps(value).encode()
        self._client.put(filepath, value)
        # self._client.put(filepath, value, update_cache=True)

    def put_text(self, filepath, value) -> None:
        self._client.put(filepath, bytes(value, encoding="utf-8"))
        # self._client.put(filepath, bytes(value, encoding='utf-8'), update_cache=True)

    def join_path(
        self, filepath: Union[str, Path], *filepaths: Union[str, Path]
    ) -> str:
        """Concatenate all file paths.
        Args:
            filepath (str or Path): Path to be concatenated.
        Returns:
            str: The result after concatenation.
        """
        # filepath = self._format_path(self._map_path(filepath))
        if filepath.endswith("/"):
            filepath = filepath[:-1]
        formatted_paths = [filepath]
        for path in filepaths:
            formatted_paths.append(path)
        return "/".join(formatted_paths)

    # from mmcv
    def list_dir_or_file(
        self,
        dir_path: Union[str, Path],
        list_dir: bool = True,
        list_file: bool = True,
        suffix: Optional[Union[str, Tuple[str]]] = None,
        recursive: bool = False,
    ) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.
        Note:
            Petrel has no concept of directories but it simulates the directory
            hierarchy in the filesystem through public prefixes. In addition,
            if the returned path ends with '/', it means the path is a public
            prefix which is a logical directory.
        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
            In addition, the returned path of directory will not contains the
            suffix '/' which is consistent with other backends.
        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.
        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        # if not has_method(self._client, 'list'):
        #     raise NotImplementedError(
        #         'Current version of Petrel Python SDK has not supported '
        #         'the `list` method, please use a higher version or dev'
        #         ' branch instead.')

        # dir_path = self._map_path(dir_path)
        # dir_path = self._format_path(dir_path)
        # if list_dir and suffix is not None:
        #     raise TypeError(
        #         '`list_dir` should be False when `suffix` is not None')

        # if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        #     raise TypeError('`suffix` must be a string or tuple of strings')

        # Petrel's simulated directory hierarchy assumes that directory paths
        # should end with `/`
        if not dir_path.endswith("/"):
            dir_path += "/"

        root = dir_path

        def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
            for path in self._client.list(dir_path):
                # the `self.isdir` is not used here to determine whether path
                # is a directory, because `self.isdir` relies on
                # `self._client.list`
                if path.endswith("/"):  # a directory path
                    next_dir_path = self.join_path(dir_path, path)
                    if list_dir:
                        # get the relative path and exclude the last
                        # character '/'
                        rel_dir = next_dir_path[len(root) : -1]
                        yield rel_dir
                    if recursive:
                        yield from _list_dir_or_file(
                            next_dir_path, list_dir, list_file, suffix, recursive
                        )
                else:  # a file path
                    absolute_path = self.join_path(dir_path, path)
                    rel_path = absolute_path[len(root) :]
                    if (suffix is None or rel_path.endswith(suffix)) and list_file:
                        yield rel_path

        return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)

    # from mmcv
    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.
        Args:
            filepath (str or Path): Path to be checked whether exists.
        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        if not (
            has_method(self._client, "contains") and has_method(self._client, "isdir")
        ):
            raise NotImplementedError(
                "Current version of Petrel Python SDK has not supported "
                "the `contains` and `isdir` methods, please use a higher"
                "version or dev branch instead."
            )

        return self._client.contains(filepath) or self._client.isdir(filepath)

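Editorial note: a small usage sketch for `PetrelBackend` (not part of the commit), assuming `petrel_client` is installed, `~/petreloss.conf` points at a reachable cluster, and the `s3://` paths are placeholders.

backend = PetrelBackend()

# JSON / image reads return None on failure when warning=True, otherwise they raise
meta = backend.get_json('s3://my-bucket/annotations/meta.json', warning=True)
img = backend.get_uint8_jpg('s3://my-bucket/images/000001.jpg')

# list_dir_or_file yields paths relative to the given prefix
for rel_path in backend.list_dir_or_file('s3://my-bucket/images/', list_dir=False, suffix='.jpg'):
    print(rel_path)
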
model/utils/pcl_utils.py
ADDED
@@ -0,0 +1,131 @@
# Copyright (c) Facebook, Inc. and its affiliates.

""" Utility functions for processing point clouds.

Author: Charles R. Qi and Or Litany
"""

import os
import sys
import torch

# Point cloud IO
import numpy as np
from plyfile import PlyData, PlyElement

# Mesh IO
import trimesh


MEAN_COLOR_RGB = np.array([109.8, 97.2, 83.8])

# ----------------------------------------
# Point Cloud Sampling
# ----------------------------------------


def random_sampling(pc, num_sample, replace=None, return_choices=False):
    """Input is NxC, output is num_samplexC"""
    if replace is None:
        replace = pc.shape[0] < num_sample
    choices = np.random.choice(pc.shape[0], num_sample, replace=replace)
    if return_choices:
        return pc[choices], choices
    else:
        return pc[choices]


def check_aspect(crop_range, aspect_min):
    xy_aspect = np.min(crop_range[:2]) / np.max(crop_range[:2])
    xz_aspect = np.min(crop_range[[0, 2]]) / np.max(crop_range[[0, 2]])
    yz_aspect = np.min(crop_range[1:]) / np.max(crop_range[1:])
    return (
        (xy_aspect >= aspect_min)
        or (xz_aspect >= aspect_min)
        or (yz_aspect >= aspect_min)
    )


class RandomCuboid(object):
    """
    RandomCuboid augmentation from DepthContrast [https://arxiv.org/abs/2101.02691]
    We slightly modify this operation to account for object detection.
    This augmentation randomly crops a cuboid from the input and
    ensures that the cropped cuboid contains at least one bounding box
    """

    def __init__(
        self,
        min_points,
        aspect=0.8,
        min_crop=0.5,
        max_crop=1.0,
        box_filter_policy="center",
    ):
        self.aspect = aspect
        self.min_crop = min_crop
        self.max_crop = max_crop
        self.min_points = min_points
        self.box_filter_policy = box_filter_policy

    def __call__(self, point_cloud, target_boxes, per_point_labels=None):
        range_xyz = np.max(point_cloud[:, 0:3], axis=0) - np.min(
            point_cloud[:, 0:3], axis=0
        )

        for _ in range(100):
            crop_range = self.min_crop + np.random.rand(3) * (
                self.max_crop - self.min_crop
            )
            if not check_aspect(crop_range, self.aspect):
                continue

            sample_center = point_cloud[np.random.choice(len(point_cloud)), 0:3]

            new_range = range_xyz * crop_range / 2.0

            max_xyz = sample_center + new_range
            min_xyz = sample_center - new_range

            upper_idx = (
                np.sum((point_cloud[:, 0:3] <= max_xyz).astype(np.int32), 1) == 3
            )
            lower_idx = (
                np.sum((point_cloud[:, 0:3] >= min_xyz).astype(np.int32), 1) == 3
            )

            new_pointidx = (upper_idx) & (lower_idx)

            if np.sum(new_pointidx) < self.min_points:
                continue

            new_point_cloud = point_cloud[new_pointidx, :]

            # filtering policy is the only modification from DepthContrast
            if self.box_filter_policy == "center":
                # remove boxes whose center does not lie within the new_point_cloud
                new_boxes = target_boxes
                if (
                    target_boxes.sum() > 0
                ):  # ground truth contains no bounding boxes. Common in SUNRGBD.
                    box_centers = target_boxes[:, 0:3]
                    new_pc_min_max = np.min(new_point_cloud[:, 0:3], axis=0), np.max(
                        new_point_cloud[:, 0:3], axis=0
                    )
                    keep_boxes = np.logical_and(
                        np.all(box_centers >= new_pc_min_max[0], axis=1),
                        np.all(box_centers <= new_pc_min_max[1], axis=1),
                    )
                    if keep_boxes.sum() == 0:
                        # current data augmentation removes all boxes in the pointcloud. fail!
                        continue
                    new_boxes = target_boxes[keep_boxes]
                if per_point_labels is not None:
                    new_per_point_labels = [x[new_pointidx] for x in per_point_labels]
                else:
                    new_per_point_labels = None
                # if we are here, all conditions are met. return boxes
                return new_point_cloud, new_boxes, new_per_point_labels

        # fallback
        return point_cloud, target_boxes, per_point_labels

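Editorial note: a short sketch (not part of the commit) exercising the point-cloud helpers on synthetic data; array shapes follow the NxC convention used above, and the box layout (center plus size) is assumed for illustration only.

import numpy as np

pc = np.random.rand(150000, 3).astype(np.float32)   # fake scene points
boxes = np.random.rand(4, 6)                        # fake boxes: cx, cy, cz, dx, dy, dz

sampled, choices = random_sampling(pc, num_sample=40000, return_choices=True)
print(sampled.shape)                                # (40000, 3)

cropper = RandomCuboid(min_points=30000, aspect=0.75, min_crop=0.5, max_crop=1.0)
new_pc, new_boxes, _ = cropper(pc, boxes)
print(new_pc.shape, new_boxes.shape)                # cropped cloud, or the originals as fallback
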
pretrained_ckpt/lamm98k/config.json
ADDED
@@ -0,0 +1,23 @@
{
  "_name_or_path": "../model_zoo/vicuna_ckpt/13b_v0/",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 5120,
  "initializer_range": 0.02,
  "intermediate_size": 13824,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 40,
  "num_hidden_layers": 40,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.29.1",
  "use_cache": true,
  "vocab_size": 32001
}

pretrained_ckpt/lamm98k/pytorch_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e43143d64def44689a2c197045795e0fa241a13ea4f33c99575fe0336d8212f
size 115418859

pretrained_ckpt/lamm98k/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "</s>",
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}

pretrained_ckpt/lamm98k/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723

pretrained_ckpt/lamm98k/tokenizer_config.json
ADDED
@@ -0,0 +1,34 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "clean_up_tokenization_spaces": false,
  "eos_token": {
    "__type": "AddedToken",
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "use_fast": false
}

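Editorial note: the `pretrained_ckpt/lamm98k` folder above ships a LLaMA config and tokenizer, so both load directly with transformers 4.29.1 (sentencepiece must be available); the `pytorch_model.pt` LFS pointer is a small (~115 MB) delta checkpoint rather than full LLaMA weights. A loading sketch (not part of the commit):

from transformers import LlamaConfig, LlamaTokenizer

config = LlamaConfig.from_pretrained('pretrained_ckpt/lamm98k')
tokenizer = LlamaTokenizer.from_pretrained('pretrained_ckpt/lamm98k', use_fast=False)

print(config.hidden_size, config.vocab_size)       # 5120, 32001
print(tokenizer.bos_token, tokenizer.eos_token)    # <s>, </s>
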
requirements.txt
ADDED
@@ -0,0 +1,19 @@
timm==0.6.7
data
einops==0.6.1
ftfy==6.1.1
iopath==0.1.10
numpy==1.24.3
peft==0.3.0
Pillow==9.5.0
PyYAML==6.0
regex==2022.10.31
torchvision==0.14.1
torchaudio==0.13.1
pytorchvideo
fvcore
decord==0.6.0
tqdm==4.64.1
transformers==4.29.1
bigmodelvis
gradio

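Editorial note: a quick runtime check (not part of the commit) that the pinned packages above are the ones actually installed; the subset of packages verified here is arbitrary.

from importlib.metadata import version, PackageNotFoundError

pins = {'transformers': '4.29.1', 'peft': '0.3.0', 'numpy': '1.24.3', 'timm': '0.6.7'}
for pkg, want in pins.items():
    try:
        got = version(pkg)
    except PackageNotFoundError:
        got = 'missing'
    print(f'{pkg:<14} want {want:<10} got {got:<10} {"ok" if got == want else "MISMATCH"}')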