openlamm committed on
Commit
94da716
1 Parent(s): 8bb945f
app.py ADDED
@@ -0,0 +1,240 @@
from transformers import AutoModel, AutoTokenizer
from copy import deepcopy
import os
import ipdb
import gradio as gr
import mdtex2html
from model.openlamm import LAMMPEFTModel
import torch
import json

# init the model
args = {
    'model': 'openllama_peft',
    'imagebind_ckpt_path': '../model_zoo/imagebind_ckpt',
    'vicuna_ckpt_path': '../model_zoo/vicuna_ckpt/13b_v0',
    'delta_ckpt_path': './pretrained_ckpt/lamm98k/pytorch_model.pt',
    'stage': 1,
    'max_tgt_len': 128,
    'lora_r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.1,
    'lora_target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    'vision_type': 'image',
    'vision_feature_type': 'local',
    'num_vision_token': 256,
    'encoder_pretrain': 'clip',
    'system_header': True,
}

model = LAMMPEFTModel(**args)
delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
model.load_state_dict(delta_ckpt, strict=False)
model = model.eval().half().cuda()
print(f'[!] init the 13b model over ...')

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def re_predict(
    input,
    image_path,
    chatbot,
    max_length,
    top_p,
    temperature,
    history,
    modality_cache,
):
    # drop the latest query and answers and generate again
    q, a = history.pop()
    chatbot.pop()
    return predict(q, image_path, chatbot, max_length, top_p, temperature, history, modality_cache)


def predict(
    input,
    image_path,
    chatbot,
    max_length,
    top_p,
    temperature,
    history,
    modality_cache,
):
    if image_path is None:
        return [(input, "There is no input data provided! Please upload your data and start the conversation.")]
    else:
        print(f'[!] image path: {image_path}\n') # [!] audio path: {audio_path}\n[!] video path: {video_path}\n[!] thermal path: {thermal_path}')

    # prepare the prompt
    prompt_text = ''
    for idx, (q, a) in enumerate(history):
        if idx == 0:
            prompt_text += f'{q}\n### Assistant: {a}\n###'
        else:
            prompt_text += f' Human: {q}\n### Assistant: {a}\n###'
    if len(history) == 0:
        prompt_text += f'{input}'
    else:
        prompt_text += f' Human: {input}'

    response = model.generate({
        'prompt': prompt_text,
        'image_paths': [image_path] if image_path else [],
        # 'audio_paths': [audio_path] if audio_path else [],
        # 'video_paths': [video_path] if video_path else [],
        # 'thermal_paths': [thermal_path] if thermal_path else [],
        'top_p': top_p,
        'temperature': temperature,
        'max_tgt_len': max_length,
        'modality_embeds': modality_cache
    })
    chatbot.append((parse_text(input), parse_text(response)))
    history.append((input, response))
    return chatbot, history, modality_cache


def reset_user_input():
    return gr.update(value='')

def reset_dialog():
    return [], []

def reset_state():
    return None, None, None, None, [], [], []


with gr.Blocks(scale=4) as demo:
    gr.HTML("""<h1 align="center">PandaGPT</h1>""")

    with gr.Row(scale=4):
        with gr.Column(scale=1):
            image_path = gr.Image(type="filepath", label="Image", value=None)
        # with gr.Column(scale=1):
        #     audio_path = gr.Audio(type="filepath", label="Audio", value=None)
        # with gr.Column(scale=1):
        #     video_path = gr.Video(type='file', label="Video")
        # with gr.Column(scale=1):
        #     thermal_path = gr.Image(type="filepath", label="Thermal Image", value=None)

    chatbot = gr.Chatbot().style(height=300)
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
            with gr.Column(min_width=32, scale=1):
                with gr.Row(scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")
                with gr.Row(scale=1):
                    resubmitBtn = gr.Button("Resubmit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 400, value=256, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.01, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=1.0, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])
    modality_cache = gr.State([])

    submitBtn.click(
        predict, [
            user_input,
            image_path,
            # audio_path,
            # video_path,
            # thermal_path,
            chatbot,
            max_length,
            top_p,
            temperature,
            history,
            modality_cache,
        ], [
            chatbot,
            history,
            modality_cache
        ],
        show_progress=True
    )

    resubmitBtn.click(
        re_predict, [
            user_input,
            image_path,
            # audio_path,
            # video_path,
            # thermal_path,
            chatbot,
            max_length,
            top_p,
            temperature,
            history,
            modality_cache,
        ], [
            chatbot,
            history,
            modality_cache
        ],
        show_progress=True
    )

    submitBtn.click(reset_user_input, [], [user_input])
    emptyBtn.click(reset_state, outputs=[
        image_path,
        # audio_path,
        # video_path,
        # thermal_path,
        chatbot,
        history,
        modality_cache
    ], show_progress=True)

demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=10050)
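For reference, the demo can also be driven without the Gradio UI. A minimal single-turn sketch (not part of the commit) that reuses only the keys `predict` passes to `model.generate`; the image path is a hypothetical placeholder:

# Assumes `model` has been initialized as in app.py above.
query = "What is in this image?"
prompt_text = f'{query}'               # first turn: no ' Human:' prefix, matching predict()
response = model.generate({
    'prompt': prompt_text,
    'image_paths': ['example.jpg'],    # placeholder local image path
    'top_p': 0.01,
    'temperature': 1.0,
    'max_tgt_len': 256,
    'modality_embeds': [],
})
print(response)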
model/CLIP/__init__.py ADDED
@@ -0,0 +1,2 @@
# remove fp32 LN & return intermediate features
from .clip import *
model/CLIP/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (196 Bytes).
model/CLIP/__pycache__/clip.cpython-310.pyc ADDED
Binary file (8.83 kB).
model/CLIP/__pycache__/model.cpython-310.pyc ADDED
Binary file (15.6 kB).
model/CLIP/__pycache__/simple_tokenizer.cpython-310.pyc ADDED
Binary file (5.73 kB).
model/CLIP/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
model/CLIP/clip.py ADDED
@@ -0,0 +1,237 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
+ return download_target
56
+ else:
57
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
+
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71
+
72
+ return download_target
73
+
74
+
75
+ def _convert_image_to_rgb(image):
76
+ return image.convert("RGB")
77
+
78
+
79
+ def _transform(n_px):
80
+ return Compose([
81
+ Resize((n_px, n_px), interpolation=BICUBIC),
82
+ # CenterCrop(n_px),
83
+ _convert_image_to_rgb,
84
+ ToTensor(),
85
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
+ ])
87
+
88
+
89
+ def available_models() -> List[str]:
90
+ """Returns the names of available CLIP models"""
91
+ return list(_MODELS.keys())
92
+
93
+
94
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95
+ """Load a CLIP model
96
+
97
+ Parameters
98
+ ----------
99
+ name : str
100
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101
+
102
+ device : Union[str, torch.device]
103
+ The device to put the loaded model
104
+
105
+ jit : bool
106
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
107
+
108
+ download_root: str
109
+ path to download the model files; by default, it uses "~/.cache/clip"
110
+
111
+ Returns
112
+ -------
113
+ model : torch.nn.Module
114
+ The CLIP model
115
+
116
+ preprocess : Callable[[PIL.Image], torch.Tensor]
117
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118
+ """
119
+ if name in _MODELS:
120
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121
+ elif os.path.isfile(name):
122
+ model_path = name
123
+ else:
124
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125
+
126
+ with open(model_path, 'rb') as opened_file:
127
+ try:
128
+ # loading JIT archive
129
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130
+ state_dict = None
131
+ except RuntimeError:
132
+ # loading saved state dict
133
+ if jit:
134
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135
+ jit = False
136
+ state_dict = torch.load(opened_file, map_location="cpu")
137
+
138
+ if not jit:
139
+ model = build_model(state_dict or model.state_dict()).to(device)
140
+ if str(device) == "cpu":
141
+ model.float()
142
+ return model, _transform(model.visual.input_resolution)
143
+
144
+ # patch the device names
145
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147
+
148
+ def patch_device(module):
149
+ try:
150
+ graphs = [module.graph] if hasattr(module, "graph") else []
151
+ except RuntimeError:
152
+ graphs = []
153
+
154
+ if hasattr(module, "forward1"):
155
+ graphs.append(module.forward1.graph)
156
+
157
+ for graph in graphs:
158
+ for node in graph.findAllNodes("prim::Constant"):
159
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160
+ node.copyAttributes(device_node)
161
+
162
+ model.apply(patch_device)
163
+ patch_device(model.encode_image)
164
+ patch_device(model.encode_text)
165
+
166
+ # patch dtype to float32 on CPU
167
+ if str(device) == "cpu":
168
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170
+ float_node = float_input.node()
171
+
172
+ def patch_float(module):
173
+ try:
174
+ graphs = [module.graph] if hasattr(module, "graph") else []
175
+ except RuntimeError:
176
+ graphs = []
177
+
178
+ if hasattr(module, "forward1"):
179
+ graphs.append(module.forward1.graph)
180
+
181
+ for graph in graphs:
182
+ for node in graph.findAllNodes("aten::to"):
183
+ inputs = list(node.inputs())
184
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185
+ if inputs[i].node()["value"] == 5:
186
+ inputs[i].node().copyAttributes(float_node)
187
+
188
+ model.apply(patch_float)
189
+ patch_float(model.encode_image)
190
+ patch_float(model.encode_text)
191
+
192
+ model.float()
193
+
194
+ return model, _transform(model.input_resolution.item())
195
+
196
+
197
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
198
+ """
199
+ Returns the tokenized representation of given input string(s)
200
+
201
+ Parameters
202
+ ----------
203
+ texts : Union[str, List[str]]
204
+ An input string or a list of input strings to tokenize
205
+
206
+ context_length : int
207
+ The context length to use; all CLIP models use 77 as the context length
208
+
209
+ truncate: bool
210
+ Whether to truncate the text in case its encoding is longer than the context length
211
+
212
+ Returns
213
+ -------
214
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216
+ """
217
+ if isinstance(texts, str):
218
+ texts = [texts]
219
+
220
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
221
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
222
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225
+ else:
226
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227
+
228
+ for i, tokens in enumerate(all_tokens):
229
+ if len(tokens) > context_length:
230
+ if truncate:
231
+ tokens = tokens[:context_length]
232
+ tokens[-1] = eot_token
233
+ else:
234
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235
+ result[i, :len(tokens)] = torch.tensor(tokens)
236
+
237
+ return result
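`load` and `tokenize` above keep the standard OpenAI CLIP interface; the local changes in this fork are the plain 224x224 resize in `_transform` (center crop commented out) and the patch-feature hook in model.py. A minimal usage sketch, assuming a CUDA device, network access for the checkpoint download, and a placeholder image path:

import torch
from PIL import Image
from model.CLIP import load, tokenize

clip_model, preprocess = load("ViT-L/14", device="cuda")             # cached under ~/.cache/clip
image = preprocess(Image.open("example.jpg")).unsqueeze(0).cuda()     # placeholder image file
text = tokenize(["a photo of a cat", "a photo of a dog"]).cuda()

with torch.no_grad():
    logits_per_image, _ = clip_model(image, text)                     # [1, 2] similarity logits
    patch_tokens = clip_model.visual.forward_patch_features(image.type(clip_model.dtype))
    # patch_tokens: [1, 256, 1024] for ViT-L/14 at 224 px -- the 256 local tokens matching num_vision_token in app.py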
model/CLIP/model.py ADDED
@@ -0,0 +1,449 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x[:1], key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+ return x.squeeze(0)
92
+
93
+
94
+ class ModifiedResNet(nn.Module):
95
+ """
96
+ A ResNet class that is similar to torchvision's but contains the following changes:
97
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
+ - The final pooling layer is a QKV attention instead of an average pool
100
+ """
101
+
102
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
+ super().__init__()
104
+ self.output_dim = output_dim
105
+ self.input_resolution = input_resolution
106
+
107
+ # the 3-layer stem
108
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(width // 2)
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.relu2 = nn.ReLU(inplace=True)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.relu3 = nn.ReLU(inplace=True)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
+
126
+ embed_dim = width * 32 # the ResNet feature dimension
127
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
+
129
+ def _make_layer(self, planes, blocks, stride=1):
130
+ layers = [Bottleneck(self._inplanes, planes, stride)]
131
+
132
+ self._inplanes = planes * Bottleneck.expansion
133
+ for _ in range(1, blocks):
134
+ layers.append(Bottleneck(self._inplanes, planes))
135
+
136
+ return nn.Sequential(*layers)
137
+
138
+ def forward(self, x):
139
+ def stem(x):
140
+ x = self.relu1(self.bn1(self.conv1(x)))
141
+ x = self.relu2(self.bn2(self.conv2(x)))
142
+ x = self.relu3(self.bn3(self.conv3(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ # orig_type = x.dtype
162
+ ret = super().forward(x) # .type(torch.float16)) # Warning: Originally avoid fp32 in clip
163
+ return ret # .type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisionTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ x = self.ln_post(x[:, 0, :])
236
+
237
+ if self.proj is not None:
238
+ x = x @ self.proj
239
+
240
+ return x
241
+
242
+ def forward_patch_features(self, x: torch.Tensor):
243
+ x = self.conv1(x) # shape = [*, width, grid, grid]
244
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
245
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
246
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
247
+ x = x + self.positional_embedding.to(x.dtype)
248
+ x = self.ln_pre(x)
249
+
250
+ x = x.permute(1, 0, 2) # NLD -> LND
251
+ x = self.transformer(x)
252
+ x = x.permute(1, 0, 2) # LND -> NLD
253
+ return x[:, 1:, :]
254
+
255
+
256
+ class CLIP(nn.Module):
257
+ def __init__(self,
258
+ embed_dim: int,
259
+ # vision
260
+ image_resolution: int,
261
+ vision_layers: Union[Tuple[int, int, int, int], int],
262
+ vision_width: int,
263
+ vision_patch_size: int,
264
+ # text
265
+ context_length: int,
266
+ vocab_size: int,
267
+ transformer_width: int,
268
+ transformer_heads: int,
269
+ transformer_layers: int
270
+ ):
271
+ super().__init__()
272
+
273
+ self.context_length = context_length
274
+
275
+ if isinstance(vision_layers, (tuple, list)):
276
+ vision_heads = vision_width * 32 // 64
277
+ self.visual = ModifiedResNet(
278
+ layers=vision_layers,
279
+ output_dim=embed_dim,
280
+ heads=vision_heads,
281
+ input_resolution=image_resolution,
282
+ width=vision_width
283
+ )
284
+ else:
285
+ vision_heads = vision_width // 64
286
+ self.visual = VisionTransformer(
287
+ input_resolution=image_resolution,
288
+ patch_size=vision_patch_size,
289
+ width=vision_width,
290
+ layers=vision_layers,
291
+ heads=vision_heads,
292
+ output_dim=embed_dim
293
+ )
294
+
295
+ self.transformer = Transformer(
296
+ width=transformer_width,
297
+ layers=transformer_layers,
298
+ heads=transformer_heads,
299
+ attn_mask=self.build_attention_mask()
300
+ )
301
+
302
+ self.vocab_size = vocab_size
303
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
304
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
305
+ self.ln_final = LayerNorm(transformer_width)
306
+
307
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
308
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
309
+
310
+ self.initialize_parameters()
311
+
312
+ def initialize_parameters(self):
313
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
314
+ nn.init.normal_(self.positional_embedding, std=0.01)
315
+
316
+ if isinstance(self.visual, ModifiedResNet):
317
+ if self.visual.attnpool is not None:
318
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
319
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
320
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
321
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
322
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
323
+
324
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
325
+ for name, param in resnet_block.named_parameters():
326
+ if name.endswith("bn3.weight"):
327
+ nn.init.zeros_(param)
328
+
329
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
330
+ attn_std = self.transformer.width ** -0.5
331
+ fc_std = (2 * self.transformer.width) ** -0.5
332
+ for block in self.transformer.resblocks:
333
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
334
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
335
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
336
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
337
+
338
+ if self.text_projection is not None:
339
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
340
+
341
+ def build_attention_mask(self):
342
+ # lazily create causal attention mask, with full attention between the vision tokens
343
+ # pytorch uses additive attention mask; fill with -inf
344
+ mask = torch.empty(self.context_length, self.context_length)
345
+ mask.fill_(float("-inf"))
346
+ mask.triu_(1) # zero out the lower diagonal
347
+ return mask
348
+
349
+ @property
350
+ def dtype(self):
351
+ return self.visual.conv1.weight.dtype
352
+
353
+ def encode_image(self, image):
354
+ return self.visual(image.type(self.dtype))
355
+
356
+ def encode_text(self, text):
357
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
358
+
359
+ x = x + self.positional_embedding.type(self.dtype)
360
+ x = x.permute(1, 0, 2) # NLD -> LND
361
+ x = self.transformer(x)
362
+ x = x.permute(1, 0, 2) # LND -> NLD
363
+ x = self.ln_final(x).type(self.dtype)
364
+
365
+ # x.shape = [batch_size, n_ctx, transformer.width]
366
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
367
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
368
+
369
+ return x
370
+
371
+ def forward(self, image, text):
372
+ image_features = self.encode_image(image)
373
+ text_features = self.encode_text(text)
374
+
375
+ # normalized features
376
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
377
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
378
+
379
+ # cosine similarity as logits
380
+ logit_scale = self.logit_scale.exp()
381
+ logits_per_image = logit_scale * image_features @ text_features.t()
382
+ logits_per_text = logits_per_image.t()
383
+
384
+ # shape = [global_batch_size, global_batch_size]
385
+ return logits_per_image, logits_per_text
386
+
387
+
388
+ def convert_weights(model: nn.Module):
389
+ """Convert applicable model parameters to fp16"""
390
+
391
+ def _convert_weights_to_fp16(l):
392
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
393
+ l.weight.data = l.weight.data.half()
394
+ if l.bias is not None:
395
+ l.bias.data = l.bias.data.half()
396
+
397
+ if isinstance(l, nn.MultiheadAttention):
398
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
399
+ tensor = getattr(l, attr)
400
+ if tensor is not None:
401
+ tensor.data = tensor.data.half()
402
+
403
+ for name in ["text_projection", "proj"]:
404
+ if hasattr(l, name):
405
+ attr = getattr(l, name)
406
+ if attr is not None:
407
+ attr.data = attr.data.half()
408
+
409
+ model.apply(_convert_weights_to_fp16)
410
+
411
+
412
+ def build_model(state_dict: dict):
413
+ vit = "visual.proj" in state_dict
414
+
415
+ if vit:
416
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
417
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
418
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
419
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
420
+ image_resolution = vision_patch_size * grid_size
421
+ else:
422
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
423
+ vision_layers = tuple(counts)
424
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
425
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
426
+ vision_patch_size = None
427
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
428
+ image_resolution = output_width * 32
429
+
430
+ embed_dim = state_dict["text_projection"].shape[1]
431
+ context_length = state_dict["positional_embedding"].shape[0]
432
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
433
+ transformer_width = state_dict["ln_final.weight"].shape[0]
434
+ transformer_heads = transformer_width // 64
435
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
436
+
437
+ model = CLIP(
438
+ embed_dim,
439
+ image_resolution, vision_layers, vision_width, vision_patch_size,
440
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
441
+ )
442
+
443
+ for key in ["input_resolution", "context_length", "vocab_size"]:
444
+ if key in state_dict:
445
+ del state_dict[key]
446
+
447
+ convert_weights(model)
448
+ model.load_state_dict(state_dict)
449
+ return model.eval()
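`build_model` above infers all architecture hyper-parameters (ViT vs. ResNet stack, width, layers, patch size) from the checkpoint's state dict and converts the weights to fp16 via `convert_weights`. A short sketch, assuming a locally available OpenAI CLIP checkpoint (the .pt filename is a placeholder):

import torch
from model.CLIP.model import build_model

jit_archive = torch.jit.load("ViT-L-14.pt", map_location="cpu")   # placeholder checkpoint path
clip_model = build_model(jit_archive.state_dict()).cuda()         # fp16 weights after convert_weights()
with torch.no_grad():
    feats = clip_model.encode_image(torch.randn(1, 3, 224, 224, device="cuda"))  # [1, embed_dim]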
model/CLIP/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
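The tokenizer above is the standard CLIP BPE tokenizer, backed by the bundled `bpe_simple_vocab_16e6.txt.gz`. A small round-trip sketch (illustrative only):

from model.CLIP.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()               # default_bpe() resolves the vocab file next to this module
ids = tokenizer.encode("a photo of a cat")  # lower-cased, whitespace-cleaned BPE ids (no SOT/EOT; clip.tokenize adds those)
print(tokenizer.decode(ids))                # "a photo of a cat " -- decode re-inserts word-boundary spaces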
model/PROCESS/__init__.py ADDED
@@ -0,0 +1 @@
from . import data  # package-relative import; a bare `import data` fails when PROCESS is imported as a package
model/PROCESS/bpe/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
model/PROCESS/data.py ADDED
@@ -0,0 +1,406 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import io
9
+ import os
10
+ import math
11
+ import requests
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torchaudio
16
+ import logging
17
+
18
+ # from .ImageBind.models.multimodal_preprocessors import SimpleTokenizer
19
+ from .multimodal_preprocessors import SimpleTokenizer
20
+ from PIL import Image
21
+ # from pytorchvideo import transforms as pv_transforms
22
+ # from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
23
+ # from pytorchvideo.data.encoded_video import EncodedVideo
24
+
25
+ from torchvision import transforms
26
+ from torchvision.transforms._transforms_video import NormalizeVideo
27
+
28
+ DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
29
+
30
+ BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
31
+
32
+
33
+ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
34
+ # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
35
+ waveform -= waveform.mean()
36
+ fbank = torchaudio.compliance.kaldi.fbank(
37
+ waveform,
38
+ htk_compat=True,
39
+ sample_frequency=sample_rate,
40
+ use_energy=False,
41
+ window_type="hanning",
42
+ num_mel_bins=num_mel_bins,
43
+ dither=0.0,
44
+ frame_length=25,
45
+ frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
46
+ )
47
+ # Convert to [mel_bins, num_frames] shape
48
+ fbank = fbank.transpose(0, 1)
49
+ # Pad to target_length
50
+ n_frames = fbank.size(1)
51
+ p = target_length - n_frames
52
+ # if p is too large (say >20%), flash a warning
53
+ if abs(p) / n_frames > 0.2:
54
+ logging.warning(
55
+ "Large gap between audio n_frames(%d) and "
56
+ "target_length (%d). Is the audio_target_length "
57
+ "setting correct?",
58
+ n_frames,
59
+ target_length,
60
+ )
61
+ # cut and pad
62
+ if p > 0:
63
+ fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
64
+ elif p < 0:
65
+ fbank = fbank[:, 0:target_length]
66
+ # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
67
+ # channel image
68
+ fbank = fbank.unsqueeze(0)
69
+ return fbank
70
+
71
+
72
+ def get_clip_timepoints(clip_sampler, duration):
73
+ # Read out all clips in this video
74
+ all_clips_timepoints = []
75
+ is_last_clip = False
76
+ end = 0.0
77
+ while not is_last_clip:
78
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
79
+ all_clips_timepoints.append((start, end))
80
+ return all_clips_timepoints
81
+
82
+
83
+ def load_and_transform_vision_data(image_paths, device, client=None):
84
+ if image_paths is None:
85
+ return None
86
+
87
+ image_ouputs = []
88
+ for image_path in image_paths:
89
+ data_transform = transforms.Compose(
90
+ [
91
+ transforms.Resize(
92
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
93
+ ),
94
+ transforms.CenterCrop(224),
95
+ transforms.ToTensor(),
96
+ transforms.Normalize(
97
+ mean=(0.48145466, 0.4578275, 0.40821073),
98
+ std=(0.26862954, 0.26130258, 0.27577711),
99
+ ),
100
+ ]
101
+ )
102
+ if os.path.exists(image_path):
103
+ with open(image_path, "rb") as fopen:
104
+ image = Image.open(fopen).convert("RGB")
105
+ elif image_path.startswith("s3://") and client is not None:
106
+ image = Image.open(io.BytesIO(client.get(image_path))).convert("RGB")
107
+ elif image_path.startswith("http"):
108
+ image = Image.open(requests.get(image_path, stream=True).raw).convert(
109
+ "RGB"
110
+ )
111
+ else:
112
+ raise ValueError(f"Invalid image path: {image_path}")
113
+
114
+ image = data_transform(image).to(device)
115
+ image_ouputs.append(image)
116
+ return torch.stack(image_ouputs, dim=0)
117
+
118
+ def transform_vision_data(images, device):
119
+ image_ouputs = []
120
+ for img in images:
121
+ data_transform = transforms.Compose(
122
+ [
123
+ transforms.Resize(
124
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
125
+ ),
126
+ transforms.CenterCrop(224),
127
+ transforms.ToTensor(),
128
+ transforms.Normalize(
129
+ mean=(0.48145466, 0.4578275, 0.40821073),
130
+ std=(0.26862954, 0.26130258, 0.27577711),
131
+ ),
132
+ ]
133
+ )
134
+ image = data_transform(img).to(device)
135
+ image_ouputs.append(image)
136
+ return torch.stack(image_ouputs, dim=0)
137
+
138
+
139
+ def load_and_transform_thermal_data(thermal_paths, device):
140
+ if thermal_paths is None:
141
+ return None
142
+
143
+ thermal_ouputs = []
144
+ for thermal_path in thermal_paths:
145
+ data_transform = transforms.Compose(
146
+ [
147
+ transforms.Resize(
148
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
149
+ ),
150
+ transforms.CenterCrop(224),
151
+ transforms.ToTensor(),
152
+ ]
153
+ )
154
+ with open(thermal_path, "rb") as fopen:
155
+ thermal = Image.open(fopen).convert("L")
156
+ thermal = data_transform(thermal).to(device)
157
+ thermal_ouputs.append(thermal)
158
+ return torch.stack(thermal_ouputs, dim=0)
159
+
160
+
161
+ def load_and_transform_text(text, device):
162
+ if text is None:
163
+ return None
164
+ tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
165
+ tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
166
+ tokens = torch.cat(tokens, dim=0)
167
+ return tokens
168
+
169
+
170
+ # def load_and_transform_audio_data(
171
+ # audio_paths,
172
+ # device,
173
+ # num_mel_bins=128,
174
+ # target_length=204,
175
+ # sample_rate=16000,
176
+ # clip_duration=2,
177
+ # clips_per_video=3,
178
+ # mean=-4.268,
179
+ # std=9.138,
180
+ # ):
181
+ # if audio_paths is None:
182
+ # return None
183
+
184
+ # audio_outputs = []
185
+ # clip_sampler = ConstantClipsPerVideoSampler(
186
+ # clip_duration=clip_duration, clips_per_video=clips_per_video
187
+ # )
188
+
189
+ # for audio_path in audio_paths:
190
+ # waveform, sr = torchaudio.load(audio_path)
191
+ # if sample_rate != sr:
192
+ # waveform = torchaudio.functional.resample(
193
+ # waveform, orig_freq=sr, new_freq=sample_rate
194
+ # )
195
+ # all_clips_timepoints = get_clip_timepoints(
196
+ # clip_sampler, waveform.size(1) / sample_rate
197
+ # )
198
+ # all_clips = []
199
+ # for clip_timepoints in all_clips_timepoints:
200
+ # waveform_clip = waveform[
201
+ # :,
202
+ # int(clip_timepoints[0] * sample_rate) : int(
203
+ # clip_timepoints[1] * sample_rate
204
+ # ),
205
+ # ]
206
+ # waveform_melspec = waveform2melspec(
207
+ # waveform_clip, sample_rate, num_mel_bins, target_length
208
+ # )
209
+ # all_clips.append(waveform_melspec)
210
+
211
+ # normalize = transforms.Normalize(mean=mean, std=std)
212
+ # all_clips = [normalize(ac).to(device) for ac in all_clips]
213
+
214
+ # all_clips = torch.stack(all_clips, dim=0)
215
+ # audio_outputs.append(all_clips)
216
+
217
+ # return torch.stack(audio_outputs, dim=0)
218
+
219
+
220
+ def get_clip_timepoints(clip_sampler, duration):
221
+ # Read out all clips in this video
222
+ all_clips_timepoints = []
223
+ is_last_clip = False
224
+ end = 0.0
225
+ while not is_last_clip:
226
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
227
+ all_clips_timepoints.append((start, end))
228
+ return all_clips_timepoints
229
+
230
+
231
+ def crop_boxes(boxes, x_offset, y_offset):
232
+ """
233
+ Perform crop on the bounding boxes given the offsets.
234
+ Args:
235
+ boxes (ndarray or None): bounding boxes to peform crop. The dimension
236
+ is `num boxes` x 4.
237
+ x_offset (int): cropping offset in the x axis.
238
+ y_offset (int): cropping offset in the y axis.
239
+ Returns:
240
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
241
+ `num boxes` x 4.
242
+ """
243
+ cropped_boxes = boxes.copy()
244
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
245
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
246
+
247
+ return cropped_boxes
248
+
249
+
250
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
251
+ """
252
+ Perform uniform spatial sampling on the images and corresponding boxes.
253
+ Args:
254
+ images (tensor): images to perform uniform crop. The dimension is
255
+ `num frames` x `channel` x `height` x `width`.
256
+ size (int): size of height and width to crop the images.
257
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
258
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
259
+ crop if height is larger than width.
260
+ boxes (ndarray or None): optional. Corresponding boxes to images.
261
+ Dimension is `num boxes` x 4.
262
+ scale_size (int): optional. If not None, resize the images to scale_size before
263
+ performing any crop.
264
+ Returns:
265
+ cropped (tensor): images with dimension of
266
+ `num frames` x `channel` x `size` x `size`.
267
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
268
+ `num boxes` x 4.
269
+ """
270
+ assert spatial_idx in [0, 1, 2]
271
+ ndim = len(images.shape)
272
+ if ndim == 3:
273
+ images = images.unsqueeze(0)
274
+ height = images.shape[2]
275
+ width = images.shape[3]
276
+
277
+ if scale_size is not None:
278
+ if width <= height:
279
+ width, height = scale_size, int(height / width * scale_size)
280
+ else:
281
+ width, height = int(width / height * scale_size), scale_size
282
+ images = torch.nn.functional.interpolate(
283
+ images,
284
+ size=(height, width),
285
+ mode="bilinear",
286
+ align_corners=False,
287
+ )
288
+
289
+ y_offset = int(math.ceil((height - size) / 2))
290
+ x_offset = int(math.ceil((width - size) / 2))
291
+
292
+ if height > width:
293
+ if spatial_idx == 0:
294
+ y_offset = 0
295
+ elif spatial_idx == 2:
296
+ y_offset = height - size
297
+ else:
298
+ if spatial_idx == 0:
299
+ x_offset = 0
300
+ elif spatial_idx == 2:
301
+ x_offset = width - size
302
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
303
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
304
+ if ndim == 3:
305
+ cropped = cropped.squeeze(0)
306
+ return cropped, cropped_boxes
307
+
308
+
309
+ class SpatialCrop(nn.Module):
310
+ """
311
+ Convert the video into 3 smaller clips spatially. Must be used after the
312
+ temporal crops to get spatial crops, and should be used with
313
+ -2 in the spatial crop at the slowfast augmentation stage (so full
314
+ frames are passed in here). Will return a larger list with the
315
+ 3x spatial crops as well.
316
+ """
317
+
318
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
319
+ super().__init__()
320
+ self.crop_size = crop_size
321
+ if num_crops == 3:
322
+ self.crops_to_ext = [0, 1, 2]
323
+ self.flipped_crops_to_ext = []
324
+ elif num_crops == 1:
325
+ self.crops_to_ext = [1]
326
+ self.flipped_crops_to_ext = []
327
+ else:
328
+ raise NotImplementedError("Nothing else supported yet")
329
+
330
+ def forward(self, videos):
331
+ """
332
+ Args:
333
+ videos: A list of C, T, H, W videos.
334
+ Returns:
335
+ videos: A list with 3x the number of elements. Each video converted
336
+ to C, T, H', W' by spatial cropping.
337
+ """
338
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
339
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
340
+ res = []
341
+ for video in videos:
342
+ for spatial_idx in self.crops_to_ext:
343
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
344
+ if not self.flipped_crops_to_ext:
345
+ continue
346
+ flipped_video = transforms.functional.hflip(video)
347
+ for spatial_idx in self.flipped_crops_to_ext:
348
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
349
+ return res
350
+
351
+ """
352
+ def load_and_transform_video_data(
353
+ video_paths,
354
+ device,
355
+ clip_duration=2,
356
+ clips_per_video=5,
357
+ sample_rate=16000,
358
+ ):
359
+ if video_paths is None:
360
+ return None
361
+
362
+ video_outputs = []
363
+ video_transform = transforms.Compose(
364
+ [
365
+ pv_transforms.ShortSideScale(224),
366
+ NormalizeVideo(
367
+ mean=(0.48145466, 0.4578275, 0.40821073),
368
+ std=(0.26862954, 0.26130258, 0.27577711),
369
+ ),
370
+ ]
371
+ )
372
+
373
+ clip_sampler = ConstantClipsPerVideoSampler(
374
+ clip_duration=clip_duration, clips_per_video=clips_per_video
375
+ )
376
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
377
+
378
+ for video_path in video_paths:
379
+ video = EncodedVideo.from_path(
380
+ video_path,
381
+ decoder="decord",
382
+ decode_audio=False,
383
+ **{"sample_rate": sample_rate},
384
+ )
385
+
386
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
387
+
388
+ all_video = []
389
+ for clip_timepoints in all_clips_timepoints:
390
+ # Read the clip, get frames
391
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
392
+ if clip is None:
393
+ raise ValueError("No clip found")
394
+ video_clip = frame_sampler(clip["video"])
395
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
396
+
397
+ all_video.append(video_clip)
398
+
399
+ all_video = [video_transform(clip) for clip in all_video]
400
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
401
+
402
+ all_video = torch.stack(all_video, dim=0)
403
+ video_outputs.append(all_video)
404
+
405
+ return torch.stack(video_outputs, dim=0).to(device)
406
+ """
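Only the vision, thermal, and text loaders above are active; the audio and video variants are commented out in this commit. A minimal sketch of the image loader, assuming the package imports cleanly in your environment and using a placeholder file name:

from model.PROCESS.data import load_and_transform_vision_data

# Accepts local paths, http(s) URLs, or s3:// keys (the latter only when a client is supplied).
images = load_and_transform_vision_data(["example.jpg"], device="cuda")
print(images.shape)   # torch.Size([1, 3, 224, 224]), CLIP-normalized float32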
model/PROCESS/helpers.py ADDED
@@ -0,0 +1,141 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+
10
+ import einops
11
+ import numpy as np
12
+ import torch
13
+
14
+ import torch.nn as nn
15
+
16
+
17
+ class Normalize(nn.Module):
18
+ def __init__(self, dim: int) -> None:
19
+ super().__init__()
20
+ self.dim = dim
21
+
22
+ def forward(self, x):
23
+ return torch.nn.functional.normalize(x, dim=self.dim, p=2)
24
+
25
+
26
+ class LearnableLogitScaling(nn.Module):
27
+ def __init__(
28
+ self,
29
+ logit_scale_init: float = 1 / 0.07,
30
+ learnable: bool = True,
31
+ max_logit_scale: float = 100,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.max_logit_scale = max_logit_scale
35
+ self.logit_scale_init = logit_scale_init
36
+ self.learnable = learnable
37
+ log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
38
+ if learnable:
39
+ self.log_logit_scale = nn.Parameter(log_logit_scale)
40
+ else:
41
+ self.register_buffer("log_logit_scale", log_logit_scale)
42
+
43
+ def forward(self, x):
44
+ return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
45
+
46
+ def extra_repr(self):
47
+ st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
48
+ return st
49
+
50
+
51
+ class EinOpsRearrange(nn.Module):
52
+ def __init__(self, rearrange_expr: str, **kwargs) -> None:
53
+ super().__init__()
54
+ self.rearrange_expr = rearrange_expr
55
+ self.kwargs = kwargs
56
+
57
+ def forward(self, x):
58
+ assert isinstance(x, torch.Tensor)
59
+ return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
60
+
61
+
62
+ class VerboseNNModule(nn.Module):
63
+ """
64
+ Wrapper around nn.Module that prints registered buffers and parameter names.
65
+ """
66
+
67
+ @staticmethod
68
+ def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
69
+ st = (
70
+ "("
71
+ + name
72
+ + "): "
73
+ + "tensor("
74
+ + str(tuple(tensor[1].shape))
75
+ + ", requires_grad="
76
+ + str(tensor[1].requires_grad)
77
+ + ")\n"
78
+ )
79
+ return st
80
+
81
+ def extra_repr(self) -> str:
82
+ named_modules = set()
83
+ for p in self.named_modules():
84
+ named_modules.update([p[0]])
85
+ named_modules = list(named_modules)
86
+
87
+ string_repr = ""
88
+ for p in self.named_parameters():
89
+ name = p[0].split(".")[0]
90
+ if name not in named_modules:
91
+ string_repr += self.get_readable_tensor_repr(name, p)
92
+
93
+ for p in self.named_buffers():
94
+ name = p[0].split(".")[0]
95
+ string_repr += self.get_readable_tensor_repr(name, p)
96
+
97
+ return string_repr
98
+
99
+
100
+ def cast_if_src_dtype(
101
+ tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
102
+ ):
103
+ updated = False
104
+ if tensor.dtype == src_dtype:
105
+ tensor = tensor.to(dtype=tgt_dtype)
106
+ updated = True
107
+ return tensor, updated
108
+
109
+
110
+ class QuickGELU(nn.Module):
111
+ # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
112
+ def forward(self, x: torch.Tensor):
113
+ return x * torch.sigmoid(1.702 * x)
114
+
115
+
116
+ class SelectElement(nn.Module):
117
+ def __init__(self, index) -> None:
118
+ super().__init__()
119
+ self.index = index
120
+
121
+ def forward(self, x):
122
+ assert x.ndim >= 3
123
+ return x[:, self.index, ...]
124
+
125
+
126
+ class SelectEOSAndProject(nn.Module):
127
+ """
128
+ Text Pooling used in OpenCLIP
129
+ """
130
+
131
+ def __init__(self, proj: nn.Module) -> None:
132
+ super().__init__()
133
+ self.proj = proj
134
+
135
+ def forward(self, x, seq_len):
136
+ assert x.ndim == 3
137
+ # x is of shape B x L x D
138
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
139
+ x = x[torch.arange(x.shape[0]), seq_len]
140
+ x = self.proj(x)
141
+ return x
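A quick sanity check of the helper modules added above; the import assumes the repository root is on PYTHONPATH.

import torch
from model.PROCESS.helpers import Normalize, LearnableLogitScaling, SelectElement

feats = torch.randn(4, 512)

# L2-normalize embeddings, then apply the CLIP-style learnable temperature.
head = torch.nn.Sequential(Normalize(dim=-1), LearnableLogitScaling(learnable=True))
scaled = head(feats)
print(scaled.shape)                          # torch.Size([4, 512])
print(float(scaled.norm(dim=-1)[0]))         # ~14.29, i.e. 1 / 0.07 applied to unit-norm rows

# SelectElement picks one token (e.g. a CLS token) out of a B x L x D sequence.
tokens = torch.randn(4, 10, 512)
print(SelectElement(index=0)(tokens).shape)  # torch.Size([4, 512])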
model/PROCESS/multimodal_preprocessors.py ADDED
@@ -0,0 +1,687 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import gzip
9
+ import html
10
+ import io
11
+ import math
12
+ from functools import lru_cache
13
+ from typing import Callable, List, Optional
14
+
15
+ import ftfy
16
+
17
+ import numpy as np
18
+ import regex as re
19
+ import torch
20
+ import torch.nn as nn
21
+ from iopath.common.file_io import g_pathmgr
22
+ from timm.models.layers import trunc_normal_
23
+
24
+ from .helpers import cast_if_src_dtype, VerboseNNModule
25
+
26
+
27
+ def get_sinusoid_encoding_table(n_position, d_hid):
28
+ """Sinusoid position encoding table"""
29
+
30
+ # TODO: make it with torch instead of numpy
31
+ def get_position_angle_vec(position):
32
+ return [
33
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
34
+ for hid_j in range(d_hid)
35
+ ]
36
+
37
+ sinusoid_table = np.array(
38
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
39
+ )
40
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
41
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
42
+
43
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
44
+
45
+
46
+ def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
47
+ N = pos_embed.shape[1]
48
+ if N == target_spatial_size:
49
+ return pos_embed
50
+ dim = pos_embed.shape[-1]
51
+ # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
52
+ pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
53
+ pos_embed = nn.functional.interpolate(
54
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
55
+ 0, 3, 1, 2
56
+ ),
57
+ scale_factor=math.sqrt(target_spatial_size / N),
58
+ mode="bicubic",
59
+ )
60
+ if updated:
61
+ pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
62
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
63
+ return pos_embed
64
+
65
+
66
+ def interpolate_pos_encoding(
67
+ npatch_per_img,
68
+ pos_embed,
69
+ patches_layout,
70
+ input_shape=None,
71
+ first_patch_idx=1,
72
+ ):
73
+ assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
74
+ N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
75
+ if npatch_per_img == N:
76
+ return pos_embed
77
+
78
+ assert (
79
+ patches_layout[-1] == patches_layout[-2]
80
+ ), "Interpolation of pos embed not supported for non-square layouts"
81
+
82
+ class_emb = pos_embed[:, :first_patch_idx]
83
+ pos_embed = pos_embed[:, first_patch_idx:]
84
+
85
+ if input_shape is None or patches_layout[0] == 1:
86
+ # simple 2D pos embedding, no temporal component
87
+ pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
88
+ elif patches_layout[0] > 1:
89
+ # pos embed has a temporal component
90
+ assert len(input_shape) == 4, "temporal interpolation not supported"
91
+ # we only support 2D interpolation in this case
92
+ num_frames = patches_layout[0]
93
+ num_spatial_tokens = patches_layout[1] * patches_layout[2]
94
+ pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
95
+ # interpolate embedding for zeroth frame
96
+ pos_embed = interpolate_pos_encoding_2d(
97
+ npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
98
+ )
99
+ else:
100
+ raise ValueError("This type of interpolation isn't implemented")
101
+
102
+ return torch.cat((class_emb, pos_embed), dim=1)
103
+
104
+
105
+ def _get_pos_embedding(
106
+ npatch_per_img,
107
+ pos_embed,
108
+ patches_layout,
109
+ input_shape,
110
+ first_patch_idx=1,
111
+ ):
112
+ pos_embed = interpolate_pos_encoding(
113
+ npatch_per_img,
114
+ pos_embed,
115
+ patches_layout,
116
+ input_shape=input_shape,
117
+ first_patch_idx=first_patch_idx,
118
+ )
119
+ return pos_embed
120
+
121
+
122
+ class PatchEmbedGeneric(nn.Module):
123
+ """
124
+ PatchEmbed from Hydra
125
+ """
126
+
127
+ def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
128
+ super().__init__()
129
+
130
+ if len(proj_stem) > 1:
131
+ self.proj = nn.Sequential(*proj_stem)
132
+ else:
133
+ # Special case to be able to load pre-trained models that were
134
+ # trained with a standard stem
135
+ self.proj = proj_stem[0]
136
+ self.norm_layer = norm_layer
137
+
138
+ def get_patch_layout(self, img_size):
139
+ with torch.no_grad():
140
+ dummy_img = torch.zeros(
141
+ [
142
+ 1,
143
+ ]
144
+ + img_size
145
+ )
146
+ dummy_out = self.proj(dummy_img)
147
+ embed_dim = dummy_out.shape[1]
148
+ patches_layout = tuple(dummy_out.shape[2:])
149
+ num_patches = np.prod(patches_layout)
150
+ return patches_layout, num_patches, embed_dim
151
+
152
+ def forward(self, x):
153
+ x = self.proj(x)
154
+ # B C (T) H W -> B (T)HW C
155
+ x = x.flatten(2).transpose(1, 2)
156
+ if self.norm_layer is not None:
157
+ x = self.norm_layer(x)
158
+ return x
159
+
160
+
161
+ class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
162
+ def __init__(
163
+ self,
164
+ patches_layout: List,
165
+ num_patches: int,
166
+ num_cls_tokens: int,
167
+ embed_dim: int,
168
+ learnable: bool,
169
+ ) -> None:
170
+ super().__init__()
171
+ self.num_cls_tokens = num_cls_tokens
172
+ self.patches_layout = patches_layout
173
+ self.num_patches = num_patches
174
+ self.num_tokens = num_cls_tokens + num_patches
175
+ self.learnable = learnable
176
+ if self.learnable:
177
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
178
+ trunc_normal_(self.pos_embed, std=0.02)
179
+ else:
180
+ self.register_buffer(
181
+ "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
182
+ )
183
+
184
+ def get_pos_embedding(self, vision_input, all_vision_tokens):
185
+ input_shape = vision_input.shape
186
+ pos_embed = _get_pos_embedding(
187
+ all_vision_tokens.size(1) - self.num_cls_tokens,
188
+ pos_embed=self.pos_embed,
189
+ patches_layout=self.patches_layout,
190
+ input_shape=input_shape,
191
+ first_patch_idx=self.num_cls_tokens,
192
+ )
193
+ return pos_embed
194
+
195
+
196
+ class RGBDTPreprocessor(VerboseNNModule):
197
+ def __init__(
198
+ self,
199
+ rgbt_stem: PatchEmbedGeneric,
200
+ depth_stem: PatchEmbedGeneric,
201
+ img_size: List = (3, 224, 224),
202
+ num_cls_tokens: int = 1,
203
+ pos_embed_fn: Callable = None,
204
+ use_type_embed: bool = False,
205
+ init_param_style: str = "openclip",
206
+ ) -> None:
207
+ super().__init__()
208
+ stem = rgbt_stem if rgbt_stem is not None else depth_stem
209
+ (
210
+ self.patches_layout,
211
+ self.num_patches,
212
+ self.embed_dim,
213
+ ) = stem.get_patch_layout(img_size)
214
+ self.rgbt_stem = rgbt_stem
215
+ self.depth_stem = depth_stem
216
+ self.use_pos_embed = pos_embed_fn is not None
217
+ self.use_type_embed = use_type_embed
218
+ self.num_cls_tokens = num_cls_tokens
219
+
220
+ if self.use_pos_embed:
221
+ self.pos_embedding_helper = pos_embed_fn(
222
+ patches_layout=self.patches_layout,
223
+ num_cls_tokens=num_cls_tokens,
224
+ num_patches=self.num_patches,
225
+ embed_dim=self.embed_dim,
226
+ )
227
+ if self.num_cls_tokens > 0:
228
+ self.cls_token = nn.Parameter(
229
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
230
+ )
231
+ if self.use_type_embed:
232
+ self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
233
+
234
+ self.init_parameters(init_param_style)
235
+
236
+ @torch.no_grad()
237
+ def init_parameters(self, init_param_style):
238
+ if init_param_style == "openclip":
239
+ # OpenCLIP style initialization
240
+ scale = self.embed_dim**-0.5
241
+ if self.use_pos_embed:
242
+ nn.init.normal_(self.pos_embedding_helper.pos_embed)
243
+ self.pos_embedding_helper.pos_embed *= scale
244
+
245
+ if self.num_cls_tokens > 0:
246
+ nn.init.normal_(self.cls_token)
247
+ self.cls_token *= scale
248
+ elif init_param_style == "vit":
249
+ self.cls_token.data.fill_(0)
250
+ else:
251
+ raise ValueError(f"Unknown init {init_param_style}")
252
+
253
+ if self.use_type_embed:
254
+ nn.init.normal_(self.type_embed)
255
+
256
+ def tokenize_input_and_cls_pos(self, input, stem, mask):
257
+ # tokens is of shape B x L x D
258
+ tokens = stem(input)
259
+ assert tokens.ndim == 3
260
+ assert tokens.shape[2] == self.embed_dim
261
+ B = tokens.shape[0]
262
+ if self.num_cls_tokens > 0:
263
+ class_tokens = self.cls_token.expand(
264
+ B, -1, -1
265
+ ) # stole class_tokens impl from Phil Wang, thanks
266
+ tokens = torch.cat((class_tokens, tokens), dim=1)
267
+ if self.use_pos_embed:
268
+ pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
269
+ tokens = tokens + pos_embed
270
+ if self.use_type_embed:
271
+ tokens = tokens + self.type_embed.expand(B, -1, -1)
272
+ return tokens
273
+
274
+ def forward(self, vision=None, depth=None, patch_mask=None):
275
+ if patch_mask is not None:
276
+ raise NotImplementedError()
277
+
278
+ if vision is not None:
279
+ vision_tokens = self.tokenize_input_and_cls_pos(
280
+ vision, self.rgbt_stem, patch_mask
281
+ )
282
+
283
+ if depth is not None:
284
+ depth_tokens = self.tokenize_input_and_cls_pos(
285
+ depth, self.depth_stem, patch_mask
286
+ )
287
+
288
+ # aggregate tokens
289
+ if vision is not None and depth is not None:
290
+ final_tokens = vision_tokens + depth_tokens
291
+ else:
292
+ final_tokens = vision_tokens if vision is not None else depth_tokens
293
+ return_dict = {
294
+ "trunk": {
295
+ "tokens": final_tokens,
296
+ },
297
+ "head": {},
298
+ }
299
+ return return_dict
300
+
301
+
302
+ class AudioPreprocessor(RGBDTPreprocessor):
303
+ def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
304
+ super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
305
+
306
+ def forward(self, audio=None):
307
+ return super().forward(vision=audio)
308
+
309
+
310
+ class ThermalPreprocessor(RGBDTPreprocessor):
311
+ def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
312
+ super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
313
+
314
+ def forward(self, thermal=None):
315
+ return super().forward(vision=thermal)
316
+
317
+
318
+ def build_causal_attention_mask(context_length):
319
+ # lazily create causal attention mask, with full attention between the vision tokens
320
+ # pytorch uses additive attention mask; fill with -inf
321
+ mask = torch.empty(context_length, context_length, requires_grad=False)
322
+ mask.fill_(float("-inf"))
323
+ mask.triu_(1) # zero out the lower diagonal
324
+ return mask
325
+
326
+
327
+ class TextPreprocessor(VerboseNNModule):
328
+ def __init__(
329
+ self,
330
+ vocab_size: int,
331
+ context_length: int,
332
+ embed_dim: int,
333
+ causal_masking: bool,
334
+ supply_seq_len_to_head: bool = True,
335
+ num_cls_tokens: int = 0,
336
+ init_param_style: str = "openclip",
337
+ ) -> None:
338
+ super().__init__()
339
+ self.vocab_size = vocab_size
340
+ self.context_length = context_length
341
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
342
+ self.pos_embed = nn.Parameter(
343
+ torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
344
+ )
345
+ self.causal_masking = causal_masking
346
+ if self.causal_masking:
347
+ mask = build_causal_attention_mask(self.context_length)
348
+ # register the mask as a buffer so it can be moved to the right device
349
+ self.register_buffer("mask", mask)
350
+
351
+ self.supply_seq_len_to_head = supply_seq_len_to_head
352
+ self.num_cls_tokens = num_cls_tokens
353
+ self.embed_dim = embed_dim
354
+ if num_cls_tokens > 0:
355
+ assert self.causal_masking is False, "Masking + CLS token isn't implemented"
356
+ self.cls_token = nn.Parameter(
357
+ torch.zeros(1, self.num_cls_tokens, embed_dim)
358
+ )
359
+
360
+ self.init_parameters(init_param_style)
361
+
362
+ @torch.no_grad()
363
+ def init_parameters(self, init_param_style="openclip"):
364
+ # OpenCLIP style initialization
365
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
366
+ nn.init.normal_(self.pos_embed, std=0.01)
367
+
368
+ if init_param_style == "openclip":
369
+ # OpenCLIP style initialization
370
+ scale = self.embed_dim**-0.5
371
+ if self.num_cls_tokens > 0:
372
+ nn.init.normal_(self.cls_token)
373
+ self.cls_token *= scale
374
+ elif init_param_style == "vit":
375
+ self.cls_token.data.fill_(0)
376
+ else:
377
+ raise ValueError(f"Unknown init {init_param_style}")
378
+
379
+ def forward(self, text):
380
+ # text tokens are of shape B x L x D
381
+ text_tokens = self.token_embedding(text)
382
+ # concat CLS tokens if any
383
+ if self.num_cls_tokens > 0:
384
+ B = text_tokens.shape[0]
385
+ class_tokens = self.cls_token.expand(
386
+ B, -1, -1
387
+ ) # stole class_tokens impl from Phil Wang, thanks
388
+ text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
389
+ text_tokens = text_tokens + self.pos_embed
390
+ return_dict = {
391
+ "trunk": {
392
+ "tokens": text_tokens,
393
+ },
394
+ "head": {},
395
+ }
396
+ # Compute sequence length after adding CLS tokens
397
+ if self.supply_seq_len_to_head:
398
+ text_lengths = text.argmax(dim=-1)
399
+ return_dict["head"] = {
400
+ "seq_len": text_lengths,
401
+ }
402
+ if self.causal_masking:
403
+ return_dict["trunk"].update({"attn_mask": self.mask})
404
+ return return_dict
405
+
406
+
407
+ class Im2Video(nn.Module):
408
+ """Convert an image into a trivial video."""
409
+
410
+ def __init__(self, time_dim=2):
411
+ super().__init__()
412
+ self.time_dim = time_dim
413
+
414
+ def forward(self, x):
415
+ if x.ndim == 4:
416
+ # B, C, H, W -> B, C, T, H, W
417
+ return x.unsqueeze(self.time_dim)
418
+ elif x.ndim == 5:
419
+ return x
420
+ else:
421
+ raise ValueError(f"Dimension incorrect {x.shape}")
422
+
423
+
424
+ class PadIm2Video(Im2Video):
425
+ def __init__(self, ntimes, pad_type, time_dim=2):
426
+ super().__init__(time_dim=time_dim)
427
+ assert ntimes > 0
428
+ assert pad_type in ["zero", "repeat"]
429
+ self.ntimes = ntimes
430
+ self.pad_type = pad_type
431
+
432
+ def forward(self, x):
433
+ x = super().forward(x)
434
+ if x.shape[self.time_dim] == 1:
435
+ if self.pad_type == "repeat":
436
+ new_shape = [1] * len(x.shape)
437
+ new_shape[self.time_dim] = self.ntimes
438
+ x = x.repeat(new_shape)
439
+ elif self.pad_type == "zero":
440
+ padarg = [0, 0] * len(x.shape)
441
+ padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
442
+ x = nn.functional.pad(x, padarg)
443
+ return x
444
+
445
+
446
+ # Modified from github.com/openai/CLIP
447
+ @lru_cache()
448
+ def bytes_to_unicode():
449
+ """
450
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
451
+ The reversible bpe codes work on unicode strings.
452
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
453
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
454
+ This is a significant percentage of your normal, say, 32K bpe vocab.
455
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
456
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
457
+ """
458
+ bs = (
459
+ list(range(ord("!"), ord("~") + 1))
460
+ + list(range(ord("¡"), ord("¬") + 1))
461
+ + list(range(ord("®"), ord("ÿ") + 1))
462
+ )
463
+ cs = bs[:]
464
+ n = 0
465
+ for b in range(2**8):
466
+ if b not in bs:
467
+ bs.append(b)
468
+ cs.append(2**8 + n)
469
+ n += 1
470
+ cs = [chr(n) for n in cs]
471
+ return dict(zip(bs, cs))
472
+
473
+
474
+ def get_pairs(word):
475
+ """Return set of symbol pairs in a word.
476
+ Word is represented as tuple of symbols (symbols being variable-length strings).
477
+ """
478
+ pairs = set()
479
+ prev_char = word[0]
480
+ for char in word[1:]:
481
+ pairs.add((prev_char, char))
482
+ prev_char = char
483
+ return pairs
484
+
485
+
486
+ def basic_clean(text):
487
+ text = ftfy.fix_text(text)
488
+ text = html.unescape(html.unescape(text))
489
+ return text.strip()
490
+
491
+
492
+ def whitespace_clean(text):
493
+ text = re.sub(r"\s+", " ", text)
494
+ text = text.strip()
495
+ return text
496
+
497
+
498
+ class SimpleTokenizer(object):
499
+ def __init__(self, bpe_path: str, context_length=77):
500
+ self.byte_encoder = bytes_to_unicode()
501
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
502
+
503
+ with g_pathmgr.open(bpe_path, "rb") as fh:
504
+ bpe_bytes = io.BytesIO(fh.read())
505
+ merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
506
+ merges = merges[1 : 49152 - 256 - 2 + 1]
507
+ merges = [tuple(merge.split()) for merge in merges]
508
+ vocab = list(bytes_to_unicode().values())
509
+ vocab = vocab + [v + "</w>" for v in vocab]
510
+ for merge in merges:
511
+ vocab.append("".join(merge))
512
+ vocab.extend(["<|startoftext|>", "<|endoftext|>"])
513
+ self.encoder = dict(zip(vocab, range(len(vocab))))
514
+ self.decoder = {v: k for k, v in self.encoder.items()}
515
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
516
+ self.cache = {
517
+ "<|startoftext|>": "<|startoftext|>",
518
+ "<|endoftext|>": "<|endoftext|>",
519
+ }
520
+ self.pat = re.compile(
521
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
522
+ re.IGNORECASE,
523
+ )
524
+ self.context_length = context_length
525
+
526
+ def bpe(self, token):
527
+ if token in self.cache:
528
+ return self.cache[token]
529
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
530
+ pairs = get_pairs(word)
531
+
532
+ if not pairs:
533
+ return token + "</w>"
534
+
535
+ while True:
536
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
537
+ if bigram not in self.bpe_ranks:
538
+ break
539
+ first, second = bigram
540
+ new_word = []
541
+ i = 0
542
+ while i < len(word):
543
+ try:
544
+ j = word.index(first, i)
545
+ new_word.extend(word[i:j])
546
+ i = j
547
+ except:
548
+ new_word.extend(word[i:])
549
+ break
550
+
551
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
552
+ new_word.append(first + second)
553
+ i += 2
554
+ else:
555
+ new_word.append(word[i])
556
+ i += 1
557
+ new_word = tuple(new_word)
558
+ word = new_word
559
+ if len(word) == 1:
560
+ break
561
+ else:
562
+ pairs = get_pairs(word)
563
+ word = " ".join(word)
564
+ self.cache[token] = word
565
+ return word
566
+
567
+ def encode(self, text):
568
+ bpe_tokens = []
569
+ text = whitespace_clean(basic_clean(text)).lower()
570
+ for token in re.findall(self.pat, text):
571
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
572
+ bpe_tokens.extend(
573
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
574
+ )
575
+ return bpe_tokens
576
+
577
+ def decode(self, tokens):
578
+ text = "".join([self.decoder[token] for token in tokens])
579
+ text = (
580
+ bytearray([self.byte_decoder[c] for c in text])
581
+ .decode("utf-8", errors="replace")
582
+ .replace("</w>", " ")
583
+ )
584
+ return text
585
+
586
+ def __call__(self, texts, context_length=None):
587
+ if not context_length:
588
+ context_length = self.context_length
589
+
590
+ if isinstance(texts, str):
591
+ texts = [texts]
592
+
593
+ sot_token = self.encoder["<|startoftext|>"]
594
+ eot_token = self.encoder["<|endoftext|>"]
595
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
596
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
597
+
598
+ for i, tokens in enumerate(all_tokens):
599
+ tokens = tokens[:context_length]
600
+ result[i, : len(tokens)] = torch.tensor(tokens)
601
+
602
+ if len(result) == 1:
603
+ return result[0]
604
+ return result
605
+
606
+
607
+ class IMUPreprocessor(VerboseNNModule):
608
+ def __init__(
609
+ self,
610
+ kernel_size: int,
611
+ imu_stem: PatchEmbedGeneric,
612
+ embed_dim: int,
613
+ img_size: List = (6, 2000),
614
+ num_cls_tokens: int = 1,
615
+ pos_embed_fn: Callable = None,
616
+ init_param_style: str = "openclip",
617
+ ) -> None:
618
+ super().__init__()
619
+ stem = imu_stem
620
+ self.imu_stem = imu_stem
621
+ self.embed_dim = embed_dim
622
+ self.use_pos_embed = pos_embed_fn is not None
623
+ self.num_cls_tokens = num_cls_tokens
624
+ self.kernel_size = kernel_size
625
+ self.pos_embed = nn.Parameter(
626
+ torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
627
+ )
628
+
629
+ if self.num_cls_tokens > 0:
630
+ self.cls_token = nn.Parameter(
631
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
632
+ )
633
+
634
+ self.init_parameters(init_param_style)
635
+
636
+ @torch.no_grad()
637
+ def init_parameters(self, init_param_style):
638
+ nn.init.normal_(self.pos_embed, std=0.01)
639
+
640
+ if init_param_style == "openclip":
641
+ # OpenCLIP style initialization
642
+ scale = self.embed_dim**-0.5
643
+
644
+ if self.num_cls_tokens > 0:
645
+ nn.init.normal_(self.cls_token)
646
+ self.cls_token *= scale
647
+ elif init_param_style == "vit":
648
+ self.cls_token.data.fill_(0)
649
+ else:
650
+ raise ValueError(f"Unknown init {init_param_style}")
651
+
652
+ def tokenize_input_and_cls_pos(self, input, stem):
653
+ # tokens is of shape B x L x D
654
+ tokens = stem.norm_layer(stem.proj(input))
655
+ assert tokens.ndim == 3
656
+ assert tokens.shape[2] == self.embed_dim
657
+ B = tokens.shape[0]
658
+ if self.num_cls_tokens > 0:
659
+ class_tokens = self.cls_token.expand(
660
+ B, -1, -1
661
+ ) # stole class_tokens impl from Phil Wang, thanks
662
+ tokens = torch.cat((class_tokens, tokens), dim=1)
663
+ if self.use_pos_embed:
664
+ tokens = tokens + self.pos_embed
665
+ return tokens
666
+
667
+ def forward(self, imu):
668
+ # Patchify
669
+ imu = imu.unfold(
670
+ -1,
671
+ self.kernel_size,
672
+ self.kernel_size,
673
+ ).permute(0, 2, 1, 3)
674
+ imu = imu.reshape(imu.size(0), imu.size(1), -1)
675
+
676
+ imu_tokens = self.tokenize_input_and_cls_pos(
677
+ imu,
678
+ self.imu_stem,
679
+ )
680
+
681
+ return_dict = {
682
+ "trunk": {
683
+ "tokens": imu_tokens,
684
+ },
685
+ "head": {},
686
+ }
687
+ return return_dict
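Three of the utilities above are easy to exercise in isolation, again assuming the repo root is importable and the module's dependencies (ftfy, timm, iopath) are installed.

import torch
from model.PROCESS.multimodal_preprocessors import (
    PadIm2Video, build_causal_attention_mask, get_sinusoid_encoding_table)

# A single RGB image becomes a 2-frame clip by repeating it along a new time dim.
img = torch.randn(1, 3, 224, 224)
print(PadIm2Video(ntimes=2, pad_type="repeat")(img).shape)   # torch.Size([1, 3, 2, 224, 224])

# Additive causal mask: 0 on and below the diagonal, -inf strictly above it.
print(build_causal_attention_mask(4))

# Fixed sinusoidal position table for 197 tokens of width 768.
print(get_sinusoid_encoding_table(197, 768).shape)           # torch.Size([1, 197, 768])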
model/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .agent import DeepSpeedAgent
2
+ from .openlamm import LAMMPEFTModel
3
+
4
+
5
+ def load_model(args):
6
+ agent_name = args['models'][args['model']]['agent_name']
7
+ model_name = args['models'][args['model']]['model_name']
8
+ model = globals()[model_name](**args)
9
+ agent = globals()[agent_name](model, args)
10
+ return agent
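load_model reads a nested config: a 'models' table keyed by the active 'model' name, where each entry names a model class and an agent class exported by this package. A hypothetical sketch of that shape follows; the 'my_model' key and the trailing comments are illustrative, not the repo's actual training config.

args = {
    'model': 'my_model',                      # which entry of 'models' to use (hypothetical key)
    'models': {
        'my_model': {
            'model_name': 'LAMMPEFTModel',    # class name resolved via globals() in load_model
            'agent_name': 'DeepSpeedAgent',
        },
    },
    # ...plus whatever LAMMPEFTModel and DeepSpeedAgent themselves expect
    # (checkpoint paths, LoRA settings, ds_config_path, total_steps, ...)
}
# agent = load_model(args)                    # returns a DeepSpeedAgent wrapping a LAMMPEFTModel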
model/agent.py ADDED
@@ -0,0 +1,76 @@
1
+ from header import *
2
+ from torch.utils.tensorboard import SummaryWriter
3
+
4
+
5
+ class DeepSpeedAgent:
6
+
7
+ def __init__(self, model, args):
8
+ super(DeepSpeedAgent, self).__init__()
9
+ self.args = args
10
+ self.model = model
11
+ self.writer = SummaryWriter(args['log_path'])
12
+ if args['stage'] == 2:
13
+ self.load_stage_1_parameters(args["delta_ckpt_path"])
14
+ print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}')
15
+
16
+ # load config parameters of deepspeed
17
+ ds_params = json.load(open(self.args['ds_config_path']))
18
+ ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps']
19
+ ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate']))
20
+ self.ds_engine, self.optimizer, _ , _ = deepspeed.initialize(
21
+ model=self.model,
22
+ model_parameters=self.model.parameters(),
23
+ config_params=ds_params,
24
+ dist_init_required=True,
25
+ args=types.SimpleNamespace(**args)
26
+ )
27
+
28
+ @torch.no_grad()
29
+ def predict(self, batch):
30
+ self.model.eval()
31
+ string = self.model.generate_one_sample(batch)
32
+ return string
33
+
34
+ def train_model(self, batch, current_step=0, pbar=None):
35
+ self.ds_engine.module.train()
36
+ loss, mle_acc = self.ds_engine(batch)
37
+
38
+ self.ds_engine.backward(loss)
39
+ self.ds_engine.step()
40
+ pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
41
+ pbar.update(1)
42
+ if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0:
43
+ elapsed = pbar.format_dict['elapsed']
44
+ rate = pbar.format_dict['rate']
45
+ remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0
46
+ remaining = str(datetime.timedelta(seconds=remaining))
47
+ self.writer.add_scalar('train/loss', loss.item(), current_step)
48
+ self.writer.add_scalar('train/token_acc', mle_acc*100, current_step)
49
+ logging.info(f'[!] progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
50
+
51
+ mle_acc *= 100
52
+ return mle_acc
53
+
54
+ def save_model(self, path, current_step):
55
+ # only save trainable model parameters
56
+ param_grad_dic = {
57
+ k: v.requires_grad for (k, v) in self.ds_engine.module.named_parameters()
58
+ }
59
+ state_dict = self.ds_engine.module.state_dict()
60
+ checkpoint = OrderedDict()
61
+ for k, v in self.ds_engine.module.named_parameters():
62
+ if v.requires_grad:
63
+ checkpoint[k] = v
64
+ if current_step <= 0:
65
+ torch.save(checkpoint, f'{path}/pytorch_model.pt')
66
+ else:
67
+ torch.save(checkpoint, f'{path}/pytorch_model_ep{current_step}.pt')
68
+ # save tokenizer
69
+ self.model.llama_tokenizer.save_pretrained(path)
70
+ # save configuration
71
+ self.model.llama_model.config.save_pretrained(path)
72
+ print(f'[!] save model into {path}')
73
+
74
+ def load_stage_1_parameters(self, path):
75
+ delta_ckpt = torch.load(path, map_location=torch.device('cpu'))
76
+ self.model.load_state_dict(delta_ckpt, strict=False)
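save_model and load_stage_1_parameters above persist only the trainable (delta) parameters and restore them with strict=False. A self-contained toy round-trip, independent of DeepSpeed, shows why the frozen keys are simply reported as missing:

import io
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
net[0].weight.requires_grad_(False)           # pretend layer 0 is the frozen backbone
net[0].bias.requires_grad_(False)

# Mirror save_model: keep only parameters that still require grad.
delta = {k: v for k, v in net.named_parameters() if v.requires_grad}
buf = io.BytesIO()
torch.save(delta, buf)
buf.seek(0)
print(sorted(delta))                          # ['1.bias', '1.weight']

# Mirror load_stage_1_parameters: strict=False tolerates the missing frozen keys.
fresh = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
result = fresh.load_state_dict(torch.load(buf, map_location='cpu'), strict=False)
print(result.missing_keys)                    # ['0.weight', '0.bias']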
model/conversations.py ADDED
@@ -0,0 +1,386 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class Conversation:
15
+ """A class that keeps all conversation history."""
16
+ system: str
17
+ roles: List[str]
18
+ messages: List[List[str]]
19
+ offset: int
20
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
21
+ sep: str = "###"
22
+ sep2: str = None
23
+ version: str = "Unknown"
24
+
25
+ skip_next: bool = False
26
+
27
+ def get_prompt(self):
28
+ if self.sep_style == SeparatorStyle.SINGLE:
29
+ ret = self.system + self.sep
30
+ for role, message in self.messages:
31
+ if message:
32
+ if type(message) is tuple:
33
+ message, _, _ = message
34
+ ret += role + ": " + message + self.sep
35
+ else:
36
+ ret += role + ":"
37
+ return ret
38
+ elif self.sep_style == SeparatorStyle.TWO:
39
+ seps = [self.sep, self.sep2]
40
+ ret = self.system + seps[0]
41
+ for i, (role, message) in enumerate(self.messages):
42
+ if message:
43
+ if type(message) is tuple:
44
+ message, _, _ = message
45
+ ret += role + ": " + message + seps[i % 2]
46
+ else:
47
+ ret += role + ":"
48
+ return ret
49
+ if self.sep_style == SeparatorStyle.MPT:
50
+ ret = self.system + self.sep
51
+ for role, message in self.messages:
52
+ if message:
53
+ if type(message) is tuple:
54
+ message, _, _ = message
55
+ ret += role + message + self.sep
56
+ else:
57
+ ret += role
58
+ return ret
59
+ else:
60
+ raise ValueError(f"Invalid style: {self.sep_style}")
61
+
62
+ def append_message(self, role, message):
63
+ self.messages.append([role, message])
64
+
65
+ def get_images(self, return_pil=False):
66
+ images = []
67
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
68
+ if i % 2 == 0:
69
+ if type(msg) is tuple:
70
+ import base64
71
+ from io import BytesIO
72
+ from PIL import Image
73
+ msg, image, image_process_mode = msg
74
+ if image_process_mode == "Pad":
75
+ def expand2square(pil_img, background_color=(122, 116, 104)):
76
+ width, height = pil_img.size
77
+ if width == height:
78
+ return pil_img
79
+ elif width > height:
80
+ result = Image.new(pil_img.mode, (width, width), background_color)
81
+ result.paste(pil_img, (0, (width - height) // 2))
82
+ return result
83
+ else:
84
+ result = Image.new(pil_img.mode, (height, height), background_color)
85
+ result.paste(pil_img, ((height - width) // 2, 0))
86
+ return result
87
+ image = expand2square(image)
88
+ elif image_process_mode == "Crop":
89
+ pass
90
+ elif image_process_mode == "Resize":
91
+ image = image.resize((224, 224))
92
+ else:
93
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
94
+ max_hw, min_hw = max(image.size), min(image.size)
95
+ aspect_ratio = max_hw / min_hw
96
+ max_len, min_len = 800, 400
97
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
98
+ longest_edge = int(shortest_edge * aspect_ratio)
99
+ W, H = image.size
100
+ if H > W:
101
+ H, W = longest_edge, shortest_edge
102
+ else:
103
+ H, W = shortest_edge, longest_edge
104
+ image = image.resize((W, H))
105
+ if return_pil:
106
+ images.append(image)
107
+ else:
108
+ buffered = BytesIO()
109
+ image.save(buffered, format="JPEG")
110
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
111
+ images.append(img_b64_str)
112
+ return images
113
+
114
+ def to_gradio_chatbot(self):
115
+ ret = []
116
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
117
+ if i % 2 == 0:
118
+ if type(msg) is tuple:
119
+ import base64
120
+ from io import BytesIO
121
+ msg, image, image_process_mode = msg
122
+ max_hw, min_hw = max(image.size), min(image.size)
123
+ aspect_ratio = max_hw / min_hw
124
+ max_len, min_len = 800, 400
125
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
126
+ longest_edge = int(shortest_edge * aspect_ratio)
127
+ W, H = image.size
128
+ if H > W:
129
+ H, W = longest_edge, shortest_edge
130
+ else:
131
+ H, W = shortest_edge, longest_edge
132
+ image = image.resize((W, H))
133
+ # image = image.resize((224, 224))
134
+ buffered = BytesIO()
135
+ image.save(buffered, format="JPEG")
136
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
137
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
138
+ msg = msg.replace('<image>', img_str)
139
+ ret.append([msg, None])
140
+ else:
141
+ ret[-1][-1] = msg
142
+ return ret
143
+
144
+ def copy(self):
145
+ return Conversation(
146
+ system=self.system,
147
+ roles=self.roles,
148
+ messages=[[x, y] for x, y in self.messages],
149
+ offset=self.offset,
150
+ sep_style=self.sep_style,
151
+ sep=self.sep,
152
+ sep2=self.sep2)
153
+
154
+ def dict(self):
155
+ if len(self.get_images()) > 0:
156
+ return {
157
+ "system": self.system,
158
+ "roles": self.roles,
159
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
160
+ "offset": self.offset,
161
+ "sep": self.sep,
162
+ "sep2": self.sep2,
163
+ }
164
+ return {
165
+ "system": self.system,
166
+ "roles": self.roles,
167
+ "messages": self.messages,
168
+ "offset": self.offset,
169
+ "sep": self.sep,
170
+ "sep2": self.sep2,
171
+ }
172
+
173
+
174
+ conv_v1 = Conversation(
175
+ system="A chat between a curious human and an artificial intelligence assistant. "
176
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
177
+ roles=("Human", "Assistant"),
178
+ messages=(
179
+ ("Human", "Give three tips for staying healthy."),
180
+ ("Assistant",
181
+ "Sure, here are three tips for staying healthy:\n"
182
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
183
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
184
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
185
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
186
+ "activities at least two days per week.\n"
187
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
188
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
189
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
190
+ "and aim to drink plenty of water throughout the day.\n"
191
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
192
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
193
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
194
+ "help improve the quality of your sleep.")
195
+ ),
196
+ offset=2,
197
+ sep_style=SeparatorStyle.SINGLE,
198
+ sep="###",
199
+ )
200
+
201
+ conv_v1_2 = Conversation(
202
+ system="A chat between a curious human and an artificial intelligence assistant. "
203
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
204
+ roles=("Human", "Assistant"),
205
+ messages=(
206
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
207
+ ("Assistant",
208
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
209
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
210
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
211
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
212
+ "renewable and non-renewable energy sources:\n"
213
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
214
+ "energy sources are finite and will eventually run out.\n"
215
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
216
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
217
+ "and other negative effects.\n"
218
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
219
+ "have lower operational costs than non-renewable sources.\n"
220
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
221
+ "locations than non-renewable sources.\n"
222
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
223
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
224
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
225
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
226
+ ),
227
+ offset=2,
228
+ sep_style=SeparatorStyle.SINGLE,
229
+ sep="###",
230
+ )
231
+
232
+ conv_vicuna_v1_1 = Conversation(
233
+ system="A chat between a curious user and an artificial intelligence assistant. "
234
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
235
+ roles=("USER", "ASSISTANT"),
236
+ version="v1",
237
+ messages=(),
238
+ offset=0,
239
+ sep_style=SeparatorStyle.TWO,
240
+ sep=" ",
241
+ sep2="</s>",
242
+ )
243
+
244
+ conv_mpt = Conversation(
245
+ system="""<|im_start|>system
246
+ - You are a helpful language and vision assistant.
247
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
248
+ - You should follow the instructions carefully and explain your answers in detail.""",
249
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
250
+ version="mpt",
251
+ messages=(),
252
+ offset=0,
253
+ sep_style=SeparatorStyle.MPT,
254
+ sep="<|im_end|>",
255
+ )
256
+
257
+ conv_mpt_text = Conversation(
258
+ system="""<|im_start|>system
259
+ - You are a helpful assistant chatbot trained by MosaicML.
260
+ - You answer questions.
261
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
262
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
263
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
264
+ version="mpt",
265
+ messages=(),
266
+ offset=0,
267
+ sep_style=SeparatorStyle.MPT,
268
+ sep="<|im_end|>",
269
+ )
270
+
271
+ conv_bair_v1 = Conversation(
272
+ system="BEGINNING OF CONVERSATION:",
273
+ roles=("USER", "GPT"),
274
+ messages=(),
275
+ offset=0,
276
+ sep_style=SeparatorStyle.TWO,
277
+ sep=" ",
278
+ sep2="</s>",
279
+ )
280
+
281
+ simple_conv = Conversation(
282
+ system="You are a large language model that can recognize visual contents based on LLaMA architecture."
283
+ "You are designed to assist human with a variety of tasks using natural language."
284
+ "Follow the instructions carefully.",
285
+ roles=("Human", "Assistant"),
286
+ messages=(
287
+ ("Human", "Hi!"),
288
+ ("Assistant", "Hi there! How can I help you today?\n")
289
+ ),
290
+ offset=2,
291
+ sep_style=SeparatorStyle.SINGLE,
292
+ sep="###",
293
+ )
294
+
295
+ simple_conv_multimodal = Conversation(
296
+ system="You are a large language and vision assistant trained with multi-modality vision signals."
297
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
298
+ "Follow the instructions carefully and explain your answers in detail.",
299
+ roles=("Human", "Assistant"),
300
+ messages=(
301
+ ("Human", "Hi!"),
302
+ ("Assistant", "Hi there! How can I help you today?\n")
303
+ ),
304
+ offset=2,
305
+ sep_style=SeparatorStyle.SINGLE,
306
+ sep="###",
307
+ )
308
+
309
+ simple_conv_mpt_multimodal = Conversation(
310
+ system="""<|im_start|>system
311
+ - You are a large language and vision assistant.
312
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
313
+ - You should follow the instructions carefully and explain your answers in detail.""",
314
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
315
+ version="mpt",
316
+ messages=(),
317
+ offset=0,
318
+ sep_style=SeparatorStyle.MPT,
319
+ sep="<|im_end|>",
320
+ )
321
+
322
+ simple_conv_legacy = Conversation(
323
+ system="You are a large language model that trained on multi-modality visual contents."
324
+ "You are designed to assist human with a variety of tasks using natural language."
325
+ "Follow the instructions carefully.",
326
+ roles=("Human", "Assistant"),
327
+ messages=(
328
+ ("Human", "Hi!\n\n### Response:"),
329
+ ("Assistant", "Hi there! How can I help you today?\n")
330
+ ),
331
+ offset=2,
332
+ sep_style=SeparatorStyle.SINGLE,
333
+ sep="###",
334
+ )
335
+
336
+ conv_llava_v1 = Conversation(
337
+ system="You are a large language and vision assistant."
338
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
339
+ "Follow the instructions carefully and explain your answers in detail.",
340
+ roles=("USER", "ASSISTANT"),
341
+ version="v1",
342
+ messages=(),
343
+ offset=0,
344
+ sep_style=SeparatorStyle.TWO,
345
+ sep=" ",
346
+ sep2="</s>",
347
+ )
348
+
349
+ default_conversation = conv_v1_2
350
+ conv_templates = {
351
+ "default": conv_v1_2,
352
+ "simple": simple_conv,
353
+ "simple_legacy": simple_conv_legacy,
354
+ "multimodal": simple_conv_multimodal,
355
+ "mpt_multimodal": simple_conv_mpt_multimodal,
356
+ "llava_v1": conv_llava_v1,
357
+
358
+ # fastchat
359
+ "v1": conv_v1_2,
360
+ "bair_v1": conv_bair_v1,
361
+ "vicuna_v1_1": conv_vicuna_v1_1,
362
+ "mpt": conv_mpt,
363
+ "mpt_text": conv_mpt_text,
364
+ }
365
+
366
+ conversation_dict = {
367
+ "classification": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a classification task, and your goal is not just to provide a class label for a given image, you also need to ensure that the classification is accurate and reliable, as this information is critical for users to make informed decisions based on image data.",
368
+ "detection": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an object detection task, and your goal is to locate all instances of objects in an image, such as people, cars, animals, or other objects, and give the corresponding coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y.",
369
+ "VQA": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a visual question answering task, and your goal is to generate natural language answers that accurately solve the question. In order to generate accurate answers to questions about visual content, you must be able to understand the content of images, understand the meaning of questions, and perform complex reasoning processes.",
370
+ "counting": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an object counting task, and your goal is to accurately count the number of objects in an image. Object counting is a computer vision task that involves detecting and counting the number of instances of specific objects within an image. You need to analyze the input image and accurately count the number of objects in it.",
371
+ "conversation": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a conversation task, and your goal is to engage in a natural language conversation with a human about images and provide helpful and informative responses to their queries or requests. When answering questions related to images, you will do so in a tone that conveys that you are seeing the image and answering the question based on my analysis of the visual content. The conversation task involves understanding the user's input, generating an appropriate response, and maintaining a coherent and engaging conversation.",
372
+ "description": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a image detail description task, and your goal is to generate a natural language description of an image that accurately and comprehensively conveys its visual content. When answering questions related to images, you will do so in a tone that conveys that you are seeing the image and answering the question based on my analysis of the visual content. The Image detail description task involves generating a textual description of an image that captures its salient features, objects, and context.",
373
+ "commomsenseqa": "You are an AI visual assistant that can analyze a single image. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a external knowledge Q&A task, and your goal is to provide accurate and informative answers to questions that require external knowledge beyond the scope of the input text. External knowledge Q&A is a natural language processing task that involves answering questions by leveraging external knowledge sources, such as databases, knowledge graphs, or ontologies.",
374
+ "normal": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
375
+
376
+ "classification3d": "You are an AI visual assistant that can analyze a point cloud of object, a chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a classification task, and your goal is not just to provide a class label for a given point clou, you also need to ensure that the classification is accurate and reliable, as this information is critical for users to make informed decisions based on point cloud data.",
377
+ "detection3d": "You are an AI visual assistant that can analyze a scan of scene. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an object detection task, and your goal is to locate all instances of objects in an point cloud, such as people, cars, animals, or other objects, and give the corresponding coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, z1, x2, y2, z2) with unit of meters. These values correspond to the top left x, top left y, top left z, bottom right x, bottom right y and bottom right z.",
378
+ "VQA3d": "You are an AI visual assistant that can analyze a point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a visual question answering task, and your goal is to generate natural language answers that accurately solve the question. In order to generate accurate answers to questions about visual content, you must be able to understand the content of point cloud, understand the meaning of questions, and perform complex reasoning processes.",
379
+ "conversation3d": "You are an AI visual assistant that can analyze a point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a conversation task, and your goal is to engage in a natural language conversation with a human about point cloud and provide helpful and informative responses to their queries or requests. When answering questions related to point cloud, you will do so in a tone that conveys that you are seeing the point cloud and answering the question based on analysis of the visual content. The conversation task involves understanding the user's input, generating an appropriate response, and maintaining a coherent and engaging conversation.",
380
+ "description3d": "You are an AI visual assistant that can analyze a single point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing a detail description task for point cloud, and your goal is to generate a natural language description of an point cloud that accurately and comprehensively conveys its visual content. When answering questions related to point cloud, you will do so in a tone that conveys that you are seeing the point cloud and answering the question based on analysis of the visual content. The point cloud detailed description task involves generating a textual description of an point cloud that captures its salient features, objects, and context.",
381
+ "commomsenseqa3d": "You are an AI visual assistant that can analyze a single point cloud. A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. As an AI assistant, you are performing an external knowledge Q&A task, and your goal is to provide accurate and informative answers to questions that require external knowledge beyond the scope of the input text. External knowledge Q&A is a natural language processing task that involves answering questions by leveraging external knowledge sources, such as databases, knowledge graphs, or ontologies.",
382
+ }
383
+
384
+
385
+ if __name__ == "__main__":
386
+ print(default_conversation.get_prompt())
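A short sketch of how these templates are turned into prompt strings, with the import path matching this file's location in the repo:

from model.conversations import conv_templates

conv = conv_templates["vicuna_v1_1"].copy()
conv.append_message(conv.roles[0], "Describe the image in one sentence.")
conv.append_message(conv.roles[1], None)       # leave the assistant turn open for generation
print(conv.get_prompt())
# "A chat between a curious user and an artificial intelligence assistant. ... USER: Describe the image in one sentence. ASSISTANT:"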
model/modeling_llama.py ADDED
@@ -0,0 +1,755 @@
1
+ # This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
2
+
3
+ """ PyTorch LLaMA model."""
4
+ import math
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.utils.checkpoint
9
+ from torch import nn
10
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
+
12
+ from transformers.activations import ACT2FN
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
16
+ from transformers.models.llama.configuration_llama import LlamaConfig
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+ _CONFIG_FOR_DOC = "LlamaConfig"
22
+
23
+
24
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
25
+ def _make_causal_mask(
26
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
27
+ ):
28
+ """
29
+ Make causal mask used for bi-directional self-attention.
30
+ """
31
+ bsz, tgt_len = input_ids_shape
32
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
33
+ mask_cond = torch.arange(mask.size(-1), device=device)
34
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
35
+ mask = mask.to(dtype)
36
+
37
+ if past_key_values_length > 0:
38
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
39
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
40
+
41
+
42
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
43
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
44
+ """
45
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
46
+ """
47
+ bsz, src_len = mask.size()
48
+ tgt_len = tgt_len if tgt_len is not None else src_len
49
+
50
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
51
+
52
+ inverted_mask = 1.0 - expanded_mask
53
+
54
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
55
+
56
+
57
+ class LlamaRMSNorm(nn.Module):
58
+ def __init__(self, hidden_size, eps=1e-6):
59
+ """
60
+ LlamaRMSNorm is equivalent to T5LayerNorm
61
+ """
62
+ super().__init__()
63
+ self.weight = nn.Parameter(torch.ones(hidden_size))
64
+ self.variance_epsilon = eps
65
+
66
+ def forward(self, hidden_states):
67
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
68
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
69
+
70
+ # convert into half-precision if necessary
71
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
72
+ hidden_states = hidden_states.to(self.weight.dtype)
73
+
74
+ return self.weight * hidden_states
75
+
76
+
77
+ class LlamaRotaryEmbedding(torch.nn.Module):
78
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
79
+ super().__init__()
80
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
81
+ self.register_buffer("inv_freq", inv_freq)
82
+
83
+ # Build here to make `torch.jit.trace` work.
84
+ self.max_seq_len_cached = max_position_embeddings
85
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
86
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
87
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
88
+ emb = torch.cat((freqs, freqs), dim=-1)
89
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
90
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
91
+
92
+ def forward(self, x, seq_len=None):
93
+ # x: [bs, num_attention_heads, seq_len, head_size]
94
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
95
+ if seq_len > self.max_seq_len_cached:
96
+ self.max_seq_len_cached = seq_len
97
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
98
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
99
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
100
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
101
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
102
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
103
+ return (
104
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
105
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
106
+ )
107
+
108
+
109
+ def rotate_half(x):
110
+ """Rotates half the hidden dims of the input."""
111
+ x1 = x[..., : x.shape[-1] // 2]
112
+ x2 = x[..., x.shape[-1] // 2 :]
113
+ return torch.cat((-x2, x1), dim=-1)
114
+
115
+
116
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
117
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
118
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
119
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
120
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
121
+ q_embed = (q * cos) + (rotate_half(q) * sin)
122
+ k_embed = (k * cos) + (rotate_half(k) * sin)
123
+ return q_embed, k_embed
124
+
125
+
126
+ class LlamaMLP(nn.Module):
127
+ def __init__(
128
+ self,
129
+ hidden_size: int,
130
+ intermediate_size: int,
131
+ hidden_act: str,
132
+ ):
133
+ super().__init__()
134
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
135
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
136
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
137
+ self.act_fn = ACT2FN[hidden_act]
138
+
139
+ def forward(self, x):
140
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
141
+
142
+
143
+ class LlamaAttention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: LlamaConfig):
147
+ super().__init__()
148
+ self.config = config
149
+ self.hidden_size = config.hidden_size
150
+ self.num_heads = config.num_attention_heads
151
+ self.head_dim = self.hidden_size // self.num_heads
152
+ self.max_position_embeddings = config.max_position_embeddings
153
+
154
+ if (self.head_dim * self.num_heads) != self.hidden_size:
155
+ raise ValueError(
156
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
157
+ f" and `num_heads`: {self.num_heads})."
158
+ )
159
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
160
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
161
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
162
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
163
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
164
+
165
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
166
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
167
+
168
+ def forward(
169
+ self,
170
+ hidden_states: torch.Tensor,
171
+ attention_mask: Optional[torch.Tensor] = None,
172
+ position_ids: Optional[torch.LongTensor] = None,
173
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
174
+ output_attentions: bool = False,
175
+ use_cache: bool = False,
176
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
177
+ bsz, q_len, _ = hidden_states.size()
178
+
179
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
180
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
181
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
182
+
183
+ kv_seq_len = key_states.shape[-2]
184
+ if past_key_value is not None:
185
+ kv_seq_len += past_key_value[0].shape[-2]
186
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
187
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
188
+ # [bsz, nh, t, hd]
189
+
190
+ if past_key_value is not None:
191
+ # reuse k, v, self_attention
192
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
193
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
194
+
195
+ past_key_value = (key_states, value_states) if use_cache else None
196
+
197
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
198
+
199
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
200
+ raise ValueError(
201
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
202
+ f" {attn_weights.size()}"
203
+ )
204
+
205
+ if attention_mask is not None:
206
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
207
+ raise ValueError(
208
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
209
+ )
210
+ attn_weights = attn_weights + attention_mask
211
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
212
+
213
+ # upcast attention to fp32
214
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
215
+ attn_output = torch.matmul(attn_weights, value_states)
216
+
217
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
218
+ raise ValueError(
219
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
220
+ f" {attn_output.size()}"
221
+ )
222
+
223
+ attn_output = attn_output.transpose(1, 2)
224
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
225
+
226
+ attn_output = self.o_proj(attn_output)
227
+
228
+ if not output_attentions:
229
+ attn_weights = None
230
+
231
+ return attn_output, attn_weights, past_key_value
232
+
233
+
234
+ class LlamaDecoderLayer(nn.Module):
235
+ def __init__(self, config: LlamaConfig):
236
+ super().__init__()
237
+ self.hidden_size = config.hidden_size
238
+ self.self_attn = LlamaAttention(config=config)
239
+ self.mlp = LlamaMLP(
240
+ hidden_size=self.hidden_size,
241
+ intermediate_size=config.intermediate_size,
242
+ hidden_act=config.hidden_act,
243
+ )
244
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
245
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
246
+
247
+ def forward(
248
+ self,
249
+ hidden_states: torch.Tensor,
250
+ attention_mask: Optional[torch.Tensor] = None,
251
+ position_ids: Optional[torch.LongTensor] = None,
252
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
253
+ output_attentions: Optional[bool] = False,
254
+ use_cache: Optional[bool] = False,
255
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
256
+ """
257
+ Args:
258
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
259
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
260
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
261
+ output_attentions (`bool`, *optional*):
262
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
263
+ returned tensors for more detail.
264
+ use_cache (`bool`, *optional*):
265
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
266
+ (see `past_key_values`).
267
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
268
+ """
269
+
270
+ residual = hidden_states
271
+
272
+ hidden_states = self.input_layernorm(hidden_states)
273
+
274
+ # Self Attention
275
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
276
+ hidden_states=hidden_states,
277
+ attention_mask=attention_mask,
278
+ position_ids=position_ids,
279
+ past_key_value=past_key_value,
280
+ output_attentions=output_attentions,
281
+ use_cache=use_cache,
282
+ )
283
+ hidden_states = residual + hidden_states
284
+
285
+ # Fully Connected
286
+ residual = hidden_states
287
+ hidden_states = self.post_attention_layernorm(hidden_states)
288
+ hidden_states = self.mlp(hidden_states)
289
+ hidden_states = residual + hidden_states
290
+
291
+ outputs = (hidden_states,)
292
+
293
+ if output_attentions:
294
+ outputs += (self_attn_weights,)
295
+
296
+ if use_cache:
297
+ outputs += (present_key_value,)
298
+
299
+ return outputs
300
+
301
+
302
+ LLAMA_START_DOCSTRING = r"""
303
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
304
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
305
+ etc.)
306
+
307
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
308
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
309
+ and behavior.
310
+
311
+ Parameters:
312
+ config ([`LlamaConfig`]):
313
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
314
+ load the weights associated with the model, only the configuration. Check out the
315
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
316
+ """
317
+
318
+
319
+ @add_start_docstrings(
320
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
321
+ LLAMA_START_DOCSTRING,
322
+ )
323
+ class LlamaPreTrainedModel(PreTrainedModel):
324
+ config_class = LlamaConfig
325
+ base_model_prefix = "model"
326
+ supports_gradient_checkpointing = True
327
+ _no_split_modules = ["LlamaDecoderLayer"]
328
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
329
+
330
+ def _init_weights(self, module):
331
+ std = self.config.initializer_range
332
+ if isinstance(module, nn.Linear):
333
+ module.weight.data.normal_(mean=0.0, std=std)
334
+ if module.bias is not None:
335
+ module.bias.data.zero_()
336
+ elif isinstance(module, nn.Embedding):
337
+ module.weight.data.normal_(mean=0.0, std=std)
338
+ if module.padding_idx is not None:
339
+ module.weight.data[module.padding_idx].zero_()
340
+
341
+ def _set_gradient_checkpointing(self, module, value=False):
342
+ if isinstance(module, LlamaModel):
343
+ module.gradient_checkpointing = value
344
+
345
+
346
+ LLAMA_INPUTS_DOCSTRING = r"""
347
+ Args:
348
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
349
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
350
+ it.
351
+
352
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
353
+ [`PreTrainedTokenizer.__call__`] for details.
354
+
355
+ [What are input IDs?](../glossary#input-ids)
356
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
357
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
358
+
359
+ - 1 for tokens that are **not masked**,
360
+ - 0 for tokens that are **masked**.
361
+
362
+ [What are attention masks?](../glossary#attention-mask)
363
+
364
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
365
+ [`PreTrainedTokenizer.__call__`] for details.
366
+
367
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
368
+ `past_key_values`).
369
+
370
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
371
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
372
+ information on the default strategy.
373
+
374
+ - 1 indicates the head is **not masked**,
375
+ - 0 indicates the head is **masked**.
376
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
377
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
378
+ config.n_positions - 1]`.
379
+
380
+ [What are position IDs?](../glossary#position-ids)
381
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
382
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
383
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
384
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
385
+
386
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
387
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
388
+
389
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
390
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
391
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
392
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
393
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
394
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
395
+ model's internal embedding lookup matrix.
396
+ use_cache (`bool`, *optional*):
397
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
398
+ `past_key_values`).
399
+ output_attentions (`bool`, *optional*):
400
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
401
+ tensors for more detail.
402
+ output_hidden_states (`bool`, *optional*):
403
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
404
+ more detail.
405
+ return_dict (`bool`, *optional*):
406
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
407
+ """
408
+
409
+
410
+ @add_start_docstrings(
411
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
412
+ LLAMA_START_DOCSTRING,
413
+ )
414
+ class LlamaModel(LlamaPreTrainedModel):
415
+ """
416
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
417
+
418
+ Args:
419
+ config: LlamaConfig
420
+ """
421
+
422
+ def __init__(self, config: LlamaConfig):
423
+ super().__init__(config)
424
+ self.padding_idx = config.pad_token_id
425
+ self.vocab_size = config.vocab_size
426
+
427
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
428
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
429
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
430
+
431
+ self.gradient_checkpointing = False
432
+ # Initialize weights and apply final processing
433
+ self.post_init()
434
+
435
+ def get_input_embeddings(self):
436
+ return self.embed_tokens
437
+
438
+ def set_input_embeddings(self, value):
439
+ self.embed_tokens = value
440
+
441
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
442
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
443
+ # create causal mask
444
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
445
+ combined_attention_mask = None
446
+ if input_shape[-1] > 1:
447
+ combined_attention_mask = _make_causal_mask(
448
+ input_shape,
449
+ inputs_embeds.dtype,
450
+ device=inputs_embeds.device,
451
+ past_key_values_length=past_key_values_length,
452
+ )
453
+
454
+ if attention_mask is not None:
455
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
456
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
457
+ inputs_embeds.device
458
+ )
459
+ combined_attention_mask = (
460
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
461
+ )
462
+
463
+ return combined_attention_mask
464
+
465
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
466
+ def forward(
467
+ self,
468
+ input_ids: torch.LongTensor = None,
469
+ attention_mask: Optional[torch.Tensor] = None,
470
+ position_ids: Optional[torch.LongTensor] = None,
471
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
472
+ inputs_embeds: Optional[torch.FloatTensor] = None,
473
+ query_embeds: Optional[torch.FloatTensor] = None,
474
+ use_cache: Optional[bool] = None,
475
+ output_attentions: Optional[bool] = None,
476
+ output_hidden_states: Optional[bool] = None,
477
+ return_dict: Optional[bool] = None,
478
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
479
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
480
+ output_hidden_states = (
481
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
482
+ )
483
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
484
+
485
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
486
+
487
+ # retrieve input_ids and inputs_embeds
488
+ if input_ids is not None and inputs_embeds is not None:
489
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
490
+ elif input_ids is not None:
491
+ batch_size, seq_length = input_ids.shape
492
+ elif inputs_embeds is not None:
493
+ batch_size, seq_length, _ = inputs_embeds.shape
494
+ else:
495
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
496
+
497
+ if inputs_embeds is None:
498
+ inputs_embeds = self.embed_tokens(input_ids)
499
+ if query_embeds is not None:
500
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
501
+ batch_size, seq_length, _ = inputs_embeds.shape
502
+
503
+ seq_length_with_past = seq_length
504
+ past_key_values_length = 0
505
+
506
+ if past_key_values is not None:
507
+ past_key_values_length = past_key_values[0][0].shape[2]
508
+ seq_length_with_past = seq_length_with_past + past_key_values_length
509
+
510
+ if position_ids is None:
511
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
512
+ position_ids = torch.arange(
513
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
514
+ )
515
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
516
+ else:
517
+ position_ids = position_ids.view(-1, seq_length).long()
518
+
519
+ # embed positions
520
+ if attention_mask is None:
521
+ attention_mask = torch.ones(
522
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
523
+ )
524
+ attention_mask = self._prepare_decoder_attention_mask(
525
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
526
+ )
527
+
528
+ hidden_states = inputs_embeds
529
+
530
+ if self.gradient_checkpointing and self.training:
531
+ if use_cache:
532
+ logger.warning_once(
533
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
534
+ )
535
+ use_cache = False
536
+
537
+ # decoder layers
538
+ all_hidden_states = () if output_hidden_states else None
539
+ all_self_attns = () if output_attentions else None
540
+ next_decoder_cache = () if use_cache else None
541
+
542
+ for idx, decoder_layer in enumerate(self.layers):
543
+ if output_hidden_states:
544
+ all_hidden_states += (hidden_states,)
545
+
546
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
547
+
548
+ if self.gradient_checkpointing and self.training:
549
+
550
+ def create_custom_forward(module):
551
+ def custom_forward(*inputs):
552
+ # None for past_key_value
553
+ return module(*inputs, output_attentions, None)
554
+
555
+ return custom_forward
556
+
557
+ layer_outputs = torch.utils.checkpoint.checkpoint(
558
+ create_custom_forward(decoder_layer),
559
+ hidden_states,
560
+ attention_mask,
561
+ position_ids,
562
+ None,
563
+ )
564
+ else:
565
+ layer_outputs = decoder_layer(
566
+ hidden_states,
567
+ attention_mask=attention_mask,
568
+ position_ids=position_ids,
569
+ past_key_value=past_key_value,
570
+ output_attentions=output_attentions,
571
+ use_cache=use_cache,
572
+ )
573
+
574
+ hidden_states = layer_outputs[0]
575
+
576
+ if use_cache:
577
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
578
+
579
+ if output_attentions:
580
+ all_self_attns += (layer_outputs[1],)
581
+
582
+ hidden_states = self.norm(hidden_states)
583
+
584
+ # add hidden states from the last decoder layer
585
+ if output_hidden_states:
586
+ all_hidden_states += (hidden_states,)
587
+
588
+ next_cache = next_decoder_cache if use_cache else None
589
+ if not return_dict:
590
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
591
+ return BaseModelOutputWithPast(
592
+ last_hidden_state=hidden_states,
593
+ past_key_values=next_cache,
594
+ hidden_states=all_hidden_states,
595
+ attentions=all_self_attns,
596
+ )
597
+
598
+
599
+ class LlamaForCausalLM(LlamaPreTrainedModel):
600
+ def __init__(self, config):
601
+ super().__init__(config)
602
+ self.model = LlamaModel(config)
603
+
604
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
605
+
606
+ # Initialize weights and apply final processing
607
+ self.post_init()
608
+
609
+ def get_input_embeddings(self):
610
+ return self.model.embed_tokens
611
+
612
+ def set_input_embeddings(self, value):
613
+ self.model.embed_tokens = value
614
+
615
+ def get_output_embeddings(self):
616
+ return self.lm_head
617
+
618
+ def set_output_embeddings(self, new_embeddings):
619
+ self.lm_head = new_embeddings
620
+
621
+ def set_decoder(self, decoder):
622
+ self.model = decoder
623
+
624
+ def get_decoder(self):
625
+ return self.model
626
+
627
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
628
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
629
+ def forward(
630
+ self,
631
+ input_ids: torch.LongTensor = None,
632
+ attention_mask: Optional[torch.Tensor] = None,
633
+ position_ids: Optional[torch.LongTensor] = None,
634
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
635
+ inputs_embeds: Optional[torch.FloatTensor] = None,
636
+ query_embeds: Optional[torch.FloatTensor] = None,
637
+ labels: Optional[torch.LongTensor] = None,
638
+ use_cache: Optional[bool] = None,
639
+ output_attentions: Optional[bool] = None,
640
+ output_hidden_states: Optional[bool] = None,
641
+ return_dict: Optional[bool] = None,
642
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
643
+ r"""
644
+ Args:
645
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
646
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
647
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
648
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
649
+
650
+ Returns:
651
+
652
+ Example:
653
+
654
+ ```python
655
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
656
+
657
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
658
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
659
+
660
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
661
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
662
+
663
+ >>> # Generate
664
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
665
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
666
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
667
+ ```"""
668
+
669
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
670
+ output_hidden_states = (
671
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
672
+ )
673
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
674
+
675
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
676
+ outputs = self.model(
677
+ input_ids=input_ids,
678
+ attention_mask=attention_mask,
679
+ position_ids=position_ids,
680
+ past_key_values=past_key_values,
681
+ inputs_embeds=inputs_embeds,
682
+ query_embeds=query_embeds,
683
+ use_cache=use_cache,
684
+ output_attentions=output_attentions,
685
+ output_hidden_states=output_hidden_states,
686
+ return_dict=return_dict,
687
+ )
688
+
689
+ hidden_states = outputs[0]
690
+ logits = self.lm_head(hidden_states)
691
+
692
+ loss = None
693
+ if labels is not None:
694
+ # Shift so that tokens < n predict n
695
+ shift_logits = logits[..., :-1, :].contiguous()
696
+ shift_labels = labels[..., 1:].contiguous()
697
+ # Flatten the tokens
698
+ loss_fct = CrossEntropyLoss()
699
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
700
+ shift_labels = shift_labels.view(-1)
701
+ # Enable model parallelism
702
+ shift_labels = shift_labels.to(shift_logits.device)
703
+ loss = loss_fct(shift_logits, shift_labels)
704
+
705
+ if not return_dict:
706
+ output = (logits,) + outputs[1:]
707
+ return (loss,) + output if loss is not None else output
708
+
709
+ return CausalLMOutputWithPast(
710
+ loss=loss,
711
+ logits=logits,
712
+ past_key_values=outputs.past_key_values,
713
+ hidden_states=outputs.hidden_states,
714
+ attentions=outputs.attentions,
715
+ )
716
+
717
+ def prepare_inputs_for_generation(
718
+ self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
719
+ ):
720
+ if past_key_values:
721
+ input_ids = input_ids[:, -1:]
722
+
723
+ position_ids = kwargs.get("position_ids", None)
724
+ if attention_mask is not None and position_ids is None:
725
+ # create position_ids on the fly for batch generation
726
+ position_ids = attention_mask.long().cumsum(-1) - 1
727
+ position_ids.masked_fill_(attention_mask == 0, 1)
728
+ if past_key_values:
729
+ position_ids = position_ids[:, -1].unsqueeze(-1)
730
+ query_embeds = None
731
+
732
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
733
+ if inputs_embeds is not None and past_key_values is None:
734
+ model_inputs = {"inputs_embeds": inputs_embeds}
735
+ else:
736
+ model_inputs = {"input_ids": input_ids}
737
+
738
+ model_inputs.update(
739
+ {
740
+ "position_ids": position_ids,
741
+ "query_embeds": query_embeds,
742
+ "past_key_values": past_key_values,
743
+ "use_cache": kwargs.get("use_cache"),
744
+ "attention_mask": attention_mask,
745
+ }
746
+ )
747
+ return model_inputs
748
+
749
+ @staticmethod
750
+ def _reorder_cache(past_key_values, beam_idx):
751
+ reordered_past = ()
752
+ for layer_past in past_key_values:
753
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
754
+ return reordered_past
755
+
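
The copy above tracks the upstream HuggingFace implementation but threads an extra `query_embeds` argument through `LlamaModel.forward`, `LlamaForCausalLM.forward`, and `prepare_inputs_for_generation`, so projected vision tokens can be prepended to the text embeddings. A minimal sketch of that hook (the checkpoint path is a placeholder):

```python
import torch
from transformers import LlamaTokenizer
from model.modeling_llama import LlamaForCausalLM

ckpt = "path/to/vicuna_ckpt"  # placeholder path
model = LlamaForCausalLM.from_pretrained(ckpt)
tokenizer = LlamaTokenizer.from_pretrained(ckpt, use_fast=False)

text_ids = tokenizer("Describe the scene.", return_tensors="pt").input_ids
# e.g. 256 projected vision tokens matching the model's hidden size
vision_tokens = torch.zeros(1, 256, model.config.hidden_size)

out = model(input_ids=text_ids, query_embeds=vision_tokens)
print(out.logits.shape)  # (1, 256 + text length, vocab_size)
```
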
model/openlamm.py ADDED
@@ -0,0 +1,550 @@
1
+ import io
2
+ import os
3
+
4
+ import requests
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ # from petrel_client.client import Client
9
+ from PIL import Image, ImageFile
10
+ from torch.nn.utils import rnn
11
+ from types import SimpleNamespace
12
+ from peft import LoraConfig, TaskType, get_peft_model
13
+ from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
14
+ import conversations
15
+ import numpy as np
16
+ # from header import *
17
+
18
+ from transformers import StoppingCriteria, StoppingCriteriaList
19
+
20
+ from .CLIP import load as load_clip
21
+ from .PROCESS import data
22
+
23
+ from .modeling_llama import LlamaForCausalLM
24
+ from .utils.pcl_utils import MEAN_COLOR_RGB, RandomCuboid, random_sampling
25
+
26
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
27
+
28
+ # sov: start of vision part; eov: end of vision part
29
+ VISION_TAGS = {
30
+ 'pos': {'image': '<image>', 'pcl': '<pcl>'},
31
+ 'sov': {'image': '<Img>', 'pcl': '<Pcl>'},
32
+ 'eov': {'image': '</Img>', 'pcl': '</Pcl>'},
33
+ }
34
+ ModalityType = SimpleNamespace(
35
+ VISION="vision",
36
+ TEXT="text",
37
+ AUDIO="audio",
38
+ THERMAL="thermal",
39
+ DEPTH="depth",
40
+ IMU="imu",
41
+ )
42
+
43
+ class StoppingCriteriaSub(StoppingCriteria):
44
+
45
+ def __init__(self, stops = [], encounters=1):
46
+ super().__init__()
47
+ self.stops = stops
48
+ self.ENCOUNTERS = encounters
49
+
50
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
51
+ stop_count = 0
52
+ for stop in self.stops:
53
+ stop_count = (stop == input_ids[0]).sum().item()
54
+ if stop_count >= self.ENCOUNTERS:
55
+ return True
56
+ return False
57
+
58
+
59
+ class MyStoppingCriteria(StoppingCriteria):
60
+ def __init__(self, stops, input_ids):
61
+ super().__init__()
62
+ self.stops = [torch.tensor(stop).to('cuda:0') for stop in stops]
63
+ self.stop_flag = [0]*input_ids.shape[0]
64
+
65
+ def check_stop(self, input_ids):
66
+ for stop in self.stops:
67
+ if torch.all((stop == input_ids[-len(stop):])).item():
68
+ return True
69
+ return False
70
+
71
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
72
+ flag = 1
73
+ for id, output_id in enumerate(output_ids):
74
+ if self.stop_flag[id] == 1:
75
+ continue
76
+ if self.check_stop(output_id):
77
+ self.stop_flag[id] = 1
78
+ else:
79
+ flag = 0
80
+ if flag == 1:
81
+ return True
82
+ return False
83
+
84
+
85
+ def build_one_instance(tokenizer, conversation, vision_type='image'):
86
+ pos = VISION_TAGS['pos'][vision_type]
87
+ # sov = VISION_TAGS['sov'][vision_type]
88
+ eov = VISION_TAGS['eov'][vision_type]
89
+
90
+ text_list = []
91
+ turn_num = len(conversation)
92
+ input_ids, target_ids = [], []
93
+ for i in range(turn_num):
94
+ turn = conversation[i]
95
+ role = turn['from']
96
+ if i == 0: # the first human turn
97
+ assert role == 'human'
98
+ turn['value'] = turn['value'].replace(f'{pos}\n', '').replace(f'\n{pos}', '')
99
+ text = f'{eov} ' + turn['value'] + '\n### Assistant:'
100
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
101
+ input_ids += one_input_id
102
+ target_ids += [-100]*len(one_input_id) # do not compute loss on the human prompt
103
+ else:
104
+ if role == 'human':
105
+ text = 'Human: ' + turn['value'] + '\n### Assistant:'
106
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
107
+ input_ids += one_input_id
108
+ target_ids += [-100]*len(one_input_id)
109
+ elif role == 'gpt':
110
+ text = turn['value'] + '\n###'
111
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
112
+ input_ids += one_input_id
113
+ target_ids += one_input_id
114
+ else:
115
+ raise Exception('Wrong Role!!!')
116
+ text_list.append(text)
117
+ assert len(input_ids) == len(target_ids)
118
+ return text_list, input_ids, target_ids
119
+
120
+
121
+ def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len, vision_type='image'):
122
+ batch_input_ids, batch_target_ids = [], []
123
+ for conversation in batch_of_conversations:
124
+ _, one_input_ids, one_target_ids = build_one_instance(tokenizer, conversation, vision_type=vision_type)
125
+ batch_input_ids.append(torch.LongTensor(one_input_ids))
126
+ batch_target_ids.append(torch.LongTensor(one_target_ids))
127
+ input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
128
+ target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100)
129
+ assert input_ids.size() == target_ids.size()
130
+ input_ids = input_ids[:,:max_tgt_len]
131
+ target_ids = target_ids[:,:max_tgt_len]
132
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
133
+ assert attention_mask.size() == input_ids.size()
134
+ return input_ids, target_ids, attention_mask.long()
135
+
136
+
137
+ def make_prompt_start(system_header=False, vision_type='image', task_type='normal'):
138
+ # TODO: choose prefix according to task type
139
+ PROMPT_START = f'### Human: {VISION_TAGS["sov"][vision_type]}'
140
+ if system_header:
141
+ if task_type == 'normal':
142
+ return f"{conversations.default_conversation.system}\n\n" + PROMPT_START
143
+ else:
144
+ return [f"{conversations.conversation_dict[task]}\n\n" + PROMPT_START for task in task_type]
145
+ else:
146
+ return PROMPT_START
147
+
148
+
149
+ class LAMMPEFTModel(nn.Module):
150
+
151
+ '''LoRA for LLaMa model'''
152
+
153
+ def __init__(self, **args):
154
+ super(LAMMPEFTModel, self).__init__()
155
+ self.args = args
156
+ # self.client = Client('~/petreloss.conf')
157
+ self.client = None
158
+
159
+ self.vision_type = args['vision_type'] if 'vision_type' in args else 'image'
160
+ encoder_pretrain = args['encoder_pretrain'] if 'encoder_pretrain' in args else 'clip'
161
+ self.encoder_pretrain = encoder_pretrain
162
+ assert encoder_pretrain in ['imagebind', 'clip', 'epcl'], f'Encoder_pretrain: {encoder_pretrain} Not Implemented'
163
+ encoder_ckpt_path = args['encoder_ckpt_path'] if not encoder_pretrain == 'clip' else '~/.cache/clip/ViT-L-14.pt'
164
+ vicuna_ckpt_path = args['vicuna_ckpt_path']
165
+
166
+ system_header = args['system_header'] if 'system_header' in args else False
167
+ stage = args['stage']
168
+
169
+ # TODO: check the vision token number; for ImageBind it is 1; by default use 1 global token
170
+ # -1 for last embedding; -2 for transformer output
171
+ self.vision_feature_type = args['vision_feature_type']
172
+ self.num_vision_token = args['num_vision_token']
173
+
174
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
175
+ print (f'Initializing [{encoder_pretrain}] visual encoder from {encoder_ckpt_path} [{device}]...')
176
+
177
+ # TODO: Make sure the number of vision tokens is correct
178
+ if args['encoder_pretrain'].lower() == 'clip':
179
+ clip_encoder, self.visual_preprocess = load_clip('ViT-L/14', device=device)
180
+ self.visual_encoder = clip_encoder.visual
181
+ if self.vision_feature_type == 'global': # global feature from CLIP
182
+ self.vision_hidden_size = 768
183
+ self.num_vision_token = 1
184
+ assert self.num_vision_token == 1, 'Only 1 global token is available!'
185
+ elif self.vision_feature_type == 'local': # patch features from CLIP ViT
186
+ self.vision_hidden_size = 1024
187
+ self.num_vision_token = min(self.num_vision_token, 256) # may cut partial tokens
188
+
189
+ # freeze vision encoder
190
+ for name, param in self.visual_encoder.named_parameters():
191
+ param.requires_grad = False
192
+ self.visual_encoder.eval()
193
+ print ('Visual encoder initialized.')
194
+
195
+ print (f'Initializing language decoder from {vicuna_ckpt_path} ...')
196
+ # add the lora module
197
+ peft_config = LoraConfig(
198
+ task_type=TaskType.CAUSAL_LM,
199
+ inference_mode=False,
200
+ r=self.args['lora_r'],
201
+ lora_alpha=self.args['lora_alpha'],
202
+ lora_dropout=self.args['lora_dropout'],
203
+ target_modules=self.args['lora_target_modules']
204
+ )
205
+
206
+ self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)
207
+ self.llama_model = get_peft_model(self.llama_model, peft_config)
208
+ self.llama_model.print_trainable_parameters()
209
+
210
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
211
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
212
+ self.llama_tokenizer.padding_side = "right"
213
+ print ('Language decoder initialized.')
214
+
215
+ self.llama_proj = nn.Linear(
216
+ self.vision_hidden_size, self.llama_model.config.hidden_size
217
+ )
218
+ print ('LLaMa projection layer initialized.')
219
+
220
+ self.max_tgt_len = args['max_tgt_len']
221
+ self.system_header = system_header
222
+ self.device = torch.cuda.current_device()
223
+
224
+ # def encode_video(self, video_paths):
225
+ # inputs = {ModalityType.VISION: data.load_and_transform_video_data(video_paths, self.device)}
226
+ # # convert into visual dtype
227
+ # inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
228
+ # with torch.no_grad():
229
+ # embeddings = self.visual_encoder(inputs)
230
+ # video_embeds = embeddings[ModalityType.VISION] # bsz x 1024
231
+ # inputs_llama = self.llama_proj(video_embeds).unsqueeze(1) # bsz x 1 x llama_size
232
+ # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
233
+ # return inputs_llama, atts_llama
234
+
235
+ # def encode_audio(self, audio_paths):
236
+ # inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, self.device)}
237
+ # # convert into visual dtype
238
+ # inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
239
+ # with torch.no_grad():
240
+ # embeddings = self.visual_encoder(inputs)
241
+ # audio_embeds = embeddings[ModalityType.AUDIO] # bsz x 1024
242
+ # inputs_llama = self.llama_proj(audio_embeds).unsqueeze(1) # bsz x 1 x llama_size
243
+ # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
244
+ # return inputs_llama, atts_llama
245
+
246
+ # def encode_thermal(self, thermal_paths):
247
+ # inputs = {ModalityType.THERMAL: data.load_and_transform_thermal_data(thermal_paths, self.device)}
248
+ # # convert into visual dtype
249
+ # inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
250
+ # with torch.no_grad():
251
+ # embeddings = self.visual_encoder(inputs)
252
+ # image_embeds = embeddings['thermal'] # bsz x 1024
253
+ # inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
254
+ # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
255
+ # return inputs_llama, atts_llama
256
+
257
+ def encode_image(self, image_paths):
258
+ """encode images to llama inputs
259
+
260
+ :param tuple image_paths: (bsz, )
261
+ :return tensor, tensor: input feature to llama, attention mask to llama
262
+ """
263
+ if self.encoder_pretrain == 'imagebind':
264
+ inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
265
+ # convert into visual dtype
266
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
267
+ with torch.no_grad():
268
+ embeddings = self.visual_encoder(inputs)
269
+ image_embeds = embeddings['vision'] # bsz x 1024
270
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
271
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
272
+ return inputs_llama, atts_llama
273
+ elif self.encoder_pretrain == 'clip':
274
+ inputs = self.load_and_transform_vision_data_clip(image_paths, self.device) # bsz x 3 x 224 x 224
275
+ inputs = inputs.to(self.llama_model.dtype) # clip requires torch.float32
276
+ inputs_llama = self.clip_encode_image(inputs)
277
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1/256
278
+ return inputs_llama, atts_llama
279
+
280
+ def my_encode_image(self, images):
281
+ """encoder loaded image objects"""
282
+ # if self.encoder_pretrain == 'imagebind':
283
+ # inputs = {ModalityType.VISION: data.transform_vision_data(images, self.device)}
284
+ # # convert into visual dtype
285
+ # inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
286
+ # with torch.no_grad():
287
+ # embeddings = self.visual_encoder(inputs)
288
+ # image_embeds = embeddings['vision'] # bsz x 1024
289
+ # inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
290
+ # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
291
+ # return inputs_llama, atts_llama
292
+ if self.encoder_pretrain == 'clip':
293
+ inputs = data.transform_vision_data(images, self.device) # bsz x 3 x 224 x 224
294
+ inputs_llama = self.clip_encode_image(inputs) # bsz x 1/256 x llama_size
295
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1/256
296
+ return inputs_llama, atts_llama
297
+ else:
298
+ raise NotImplementedError("Encoder pretrain [{}] not implemented".format(self.encoder_pretrain))
299
+
300
+ def encode_pcl(self, pcl_paths):
301
+ # load pcl data
302
+ inputs = self.load_and_transform_pcl_data(pcl_paths, self.device) # bsz x 40000 x 3
303
+
304
+ inputs = inputs.to(self.llama_model.dtype) # clip requires torch.float32
305
+ with torch.no_grad():
306
+ if self.vision_feature_type == 'global':
307
+ raise NotImplementedError("Global feature not implemented for pcl")
308
+ elif self.vision_feature_type == 'local':
309
+ embeddings = self.visual_encoder(inputs)[1][:, :self.num_vision_token] # bsz x 256 x 1024;
310
+ image_embeds = embeddings.reshape(-1, self.vision_hidden_size).to(self.llama_model.dtype) # bsz*num vision token x 1024
311
+ inputs_llama = self.llama_proj(image_embeds).reshape(-1, self.num_vision_token, self.llama_model.config.hidden_size) # bsz x num_vision_token x llama_size
312
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1/256
313
+ return inputs_llama, atts_llama
314
+
315
+ def clip_encode_image(self, inputs):
316
+ inputs = inputs.to(self.llama_model.dtype) # clip requires torch.float32
317
+ with torch.no_grad():
318
+ if self.vision_feature_type == 'global':
319
+ embeddings = self.visual_encoder(inputs) # bsz x 768
320
+ image_embeds = embeddings.to(self.llama_model.dtype)
321
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
322
+ elif self.vision_feature_type == 'local':
323
+ embeddings = self.visual_encoder.forward_patch_features(inputs)[:, :self.num_vision_token] # bsz x self.num_vision_token x 1024
324
+ image_embeds = embeddings.reshape(-1, self.vision_hidden_size).to(self.llama_model.dtype) # bsz*num vision token x 1024
325
+ inputs_llama = self.llama_proj(image_embeds).reshape(-1, self.num_vision_token, self.llama_model.config.hidden_size) # bsz x num_vision_token x llama_size
326
+ else:
327
+ raise NotImplementedError("{} not Implemented".format(self.vision_feature_type))
328
+ return inputs_llama
329
+
330
+ def load_and_transform_vision_data_clip(self, image_paths, device):
331
+ if image_paths is None:
332
+ return None
333
+ image_ouputs = []
334
+ for image_path in image_paths:
335
+ if os.path.exists(image_path):
336
+ image = Image.open(image_path)
337
+ elif image_path.startswith('s3://') and self.client is not None:
338
+ image = Image.open(io.BytesIO(self.client.get(image_path, update_cache=True))).convert("RGB")
339
+ elif image_path.startswith('http://'):
340
+ image = Image.open(requests.get(image_path, stream=True).raw)
341
+ else:
342
+ print("can not load image: ", image_path)
343
+ image_outpt = self.visual_preprocess(image).to(device) # 3 x 224 x 224
344
+ image_ouputs.append(image_outpt)
345
+ return torch.stack(image_ouputs, dim=0) # B x 3 x 224 x 224
346
+
347
+ def load_and_transform_pcl_data(self, pcl_paths, device):
348
+ if pcl_paths is None:
349
+ return None
350
+ pcl_output = []
351
+ for pcl_path in pcl_paths:
352
+ mesh_vertices = np.load(pcl_path) # 150000, 3
353
+ if not self.use_color:
354
+ point_cloud = mesh_vertices[:, 0:3] # do not use color for now
355
+ else:
356
+ point_cloud = mesh_vertices[:, 0:6]
357
+ point_cloud[:, 3:] = (point_cloud[:, 3:] - MEAN_COLOR_RGB) / 256.0
358
+
359
+ if self.use_height:
360
+ floor_height = np.percentile(point_cloud[:, 2], 0.99)
361
+ height = point_cloud[:, 2] - floor_height
362
+ point_cloud = np.concatenate([point_cloud, np.expand_dims(height, 1)], 1)
363
+
364
+ point_cloud, _ = random_sampling(
365
+ point_cloud, self.num_points, return_choices=True
366
+ )
367
+ pcl_output.append(torch.from_numpy(point_cloud))
368
+ return torch.stack(pcl_output, dim=0).to(device) # bsz x num_points x 3
369
+
370
+ def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask, system_header, task_type):
371
+ '''
372
+ input_ids, target_ids, attention_mask: bsz x s2
373
+ '''
374
+ input_ids = input_ids.to(self.device) # bsz x s2
375
+ target_ids = target_ids.to(self.device) # bsz x s2
376
+ attention_mask = attention_mask.to(self.device) # bsz x s2
377
+
378
+ batch_size = img_embeds.shape[0]
379
+
380
+ # return list of headers if multiple tasks
381
+ p_before = make_prompt_start(system_header=system_header, vision_type=self.vision_type, task_type=task_type)
382
+ if isinstance(p_before, list):
383
+ p_before_tokens = [self.llama_tokenizer(p,
384
+ return_tensors="pt", add_special_tokens=False).input_ids[0].to(self.device) for p in p_before]
385
+ # TODO: fix bug here
386
+ p_before_token_ids = rnn.pad_sequence(p_before_tokens, batch_first=True, padding_value=self.llama_tokenizer.pad_token_id) # bsz x s1
387
+ p_before_attn_mask = p_before_token_ids.ne(self.llama_tokenizer.pad_token_id)
388
+ else:
389
+ p_before_tokens = self.llama_tokenizer(p_before,
390
+ return_tensors="pt", add_special_tokens=False).to(self.device) # [s1, s1...] list of batch size
391
+ p_before_token_ids = p_before_tokens.input_ids.expand(batch_size, -1) # bsz x s1
392
+ p_before_attn_mask = p_before_tokens.attention_mask.expand(batch_size, -1) # bsz x s1
393
+ # peft model need deeper call
394
+ p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_token_ids) #.expand(batch_size, -1, -1) # bsz x s1 x embed_dim
395
+ p_after_embeds = self.llama_model.model.model.embed_tokens(input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim
396
+ bos = torch.ones([batch_size, 1],
397
+ dtype=p_before_token_ids.dtype,
398
+ device=p_before_token_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
399
+ bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
400
+ inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_after_embeds], dim=1) # bsz x (1+s1+NumToken+s2) x embed_dim
401
+
402
+ # make target ids for prefix part
403
+ empty_targets = (
404
+ torch.ones([batch_size, 1 + p_before_embeds.size()[1] + self.num_vision_token], # 1 (bos) + s1 + num_image_tokens (image vector)
405
+ dtype=torch.long).to(self.device).fill_(-100)
406
+ ) # bsz x (1 + s1 + 1)
407
+ targets = torch.cat([empty_targets, target_ids], dim=1) # bsz x (1 + s1 + num_image_tokens + s2)
408
+ assert inputs_embeds.size()[1] == targets.size()[1]
409
+
410
+ # atts_prefix = torch.ones([batch_size, 1 + p_before_embeds.size()[1] + self.num_vision_token], dtype=torch.long).to(self.device) # bsz x (1[bos] + s1 +num_image_tokens)
411
+ atts_bos = torch.ones([batch_size, 1], dtype=torch.long).to(self.device) # bsz x 1
412
+ atts_img = torch.ones([batch_size, self.num_vision_token], dtype=torch.long).to(self.device) # bsz x num_image_tokens
413
+ attention_mask = torch.cat([atts_bos, p_before_attn_mask, atts_img, attention_mask], dim=1)
414
+ assert attention_mask.size() == targets.size() # bsz x (1 + s1 + num_image_tokens + s2)
415
+ return inputs_embeds, targets, attention_mask
416
+
417
+ def forward(self, inputs):
418
+ """Model Forward in training
419
+
420
+ :param dict inputs: batch of training inputs (vision type/paths, task type, output texts)
421
+ :raises ValueError: if the vision type is neither image nor pcl
422
+ :return list: loss & token accuracy
423
+ """
424
+ # image_paths = inputs['image_paths']
425
+ assert self.vision_type == inputs['vision_type'] # single modal case
426
+ task_type = inputs['task_type']
427
+ vision_paths = inputs['vision_paths']
428
+ if self.vision_type == 'image':
429
+ vision_embeds, _ = self.encode_image(vision_paths)
430
+ elif self.vision_type == 'pcl':
431
+ vision_embeds, _ = self.encode_pcl(vision_paths) # Bsz x N token x C
432
+ else:
433
+ raise ValueError('vision type [{}] not supported'.format(self.vision_type))
434
+
435
+ output_texts = inputs['output_texts']
436
+ input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len, self.vision_type)
437
+ inputs_embeds, targets, attention_mask = self.prompt_wrap(vision_embeds, input_ids, target_ids, attention_mask, self.system_header, task_type)
438
+
439
+ outputs = self.llama_model(
440
+ inputs_embeds=inputs_embeds,
441
+ attention_mask=attention_mask,
442
+ return_dict=True,
443
+ labels=targets,
444
+ )
445
+ loss = outputs.loss
446
+ # calculate the token accuracy
447
+ chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1: -1] # [B, S-1]
448
+ labels = targets[:, 2:]
449
+ gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long) # [B*S]
450
+ valid_mask = (labels != -100).reshape(-1)
451
+ valid_tokens = gen_acc & valid_mask # [B*S]
452
+ gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()
453
+ return loss, gen_acc
454
+
455
+ def extract_multimodal_feature(self, inputs):
456
+ """Extract multimodal features from the input in Generation (Test)
457
+
458
+ :param Dict inputs: input dict; modality: path
459
+ :return torch.Tensor: fused multimodal feature embeddings
460
+ """
461
+ features = []
462
+ if inputs['image_paths']:
463
+ image_embeds, _ = self.encode_image(inputs['image_paths'])
464
+ features.append(image_embeds)
465
+ if 'images' in inputs and inputs['images']: # image objects input in testing
466
+ image_embeds, _ = self.my_encode_image(inputs['images'])
467
+ return image_embeds
468
+ # features.append(image_embeds)
469
+ if 'pcl_paths' in inputs and inputs['pcl_paths']:
470
+ pcl_embeds, _ = self.encode_pcl(inputs['pcl_paths'])
471
+ features.append(pcl_embeds)
472
+ # NOTE: summing multiple modality features is allowed in test ONLY!
473
+ feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0) # sum all modality features together
474
+ return feature_embeds
475
+
476
+ def prepare_generation_embedding(self, inputs):
477
+ """prepare for generation
478
+
479
+ :param dict inputs: generation inputs (user prompts and cached modality embeddings)
480
+ :return torch.Tensor: input embeddings for generation
481
+ """
482
+ eov = VISION_TAGS['eov'][self.vision_type]
483
+ # TODO: add System header & image token size
484
+ prompt_list = inputs['prompt'] # questions from user
485
+ if len(inputs['modality_embeds']) == 1:
486
+ feature_embeds = inputs['modality_embeds'][0]
487
+ else:
488
+ feature_embeds = self.extract_multimodal_feature(inputs)
489
+ inputs['modality_embeds'].append(feature_embeds)
490
+
491
+ batch_size = feature_embeds.shape[0]
492
+ p_before = make_prompt_start(vision_type=self.vision_type) # no system header in test
493
+ p_before_tokens = self.llama_tokenizer(p_before,
494
+ return_tensors="pt", add_special_tokens=False).to(self.device)
495
+ p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
496
+ p_after_embeds_list = []
497
+ p_after_tokens_list = []
498
+ for prompt in prompt_list:
499
+ # text = '</Img> ' + prompt + '\n### Assistant:'
500
+ text = f'{eov} ' + prompt + '\n### Assistant:'
501
+ p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
502
+
503
+ p_after_tokens_list.append(p_after_tokens.input_ids.squeeze(0))
504
+
505
+ p_after_tokens = rnn.pad_sequence(p_after_tokens_list, batch_first=True, padding_value=self.llama_tokenizer.pad_token_id)
506
+
507
+ p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens)
508
+
509
+ # text = f'{eov} ' + prompt + '\n### Assistant:'
510
+ # p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
511
+ # p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
512
+ bos = torch.ones([batch_size, 1],
513
+ dtype=p_before_tokens.input_ids.dtype,
514
+ device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
515
+ bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
516
+ # print(bos_embeds.shape, p_before_embeds.shape, feature_embeds.shape, p_after_embeds.shape)
517
+ inputs_embeds = torch.cat([bos_embeds, p_before_embeds, feature_embeds, p_after_embeds], dim=1) # bsz x (1+s1+NumVisionToken+s2) x embed_dim
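+ # per-sample layout: <bos>, the vision prefix from make_prompt_start, the vision
+ # token embeddings, then "{eov} {question}\n### Assistant:"; generation continues
+ # from the trailing "### Assistant:" marker.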
518
+ return inputs_embeds
519
+
520
+ def generate(self, inputs):
521
+ '''
522
+ inputs = {
523
+ 'image_paths': optional,
524
+ 'audio_paths': optional
525
+ 'video_paths': optional
526
+ 'thermal_paths': optional
527
+ 'mode': generation mode,
528
+ 'prompt': human input prompt,
529
+ 'max_tgt_len': generation length,
530
+ 'top_p': top_p,
531
+ 'temperature': temperature
532
+ 'modality_embeds': None or torch.tensor
533
+ 'modality_cache': save the image cache
534
+ }
535
+ '''
536
+ input_embeds = self.prepare_generation_embedding(inputs)
537
+ # stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277], encounters=1)])
538
+ stopping_criteria = StoppingCriteriaList([MyStoppingCriteria([[2277]], input_embeds)])
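+ # stop once token id 2277 appears; in the LLaMA tokenizer it forms part of the
+ # "###" role separator, so decoding halts before the next dialogue turn begins.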
539
+ outputs = self.llama_model.generate(
540
+ inputs_embeds=input_embeds,
541
+ max_new_tokens=inputs['max_tgt_len'],
542
+ top_p=inputs['top_p'],
543
+ temperature=inputs['temperature'],
544
+ do_sample=True,
545
+ use_cache=True,
546
+ stopping_criteria=stopping_criteria,
547
+ )
548
+ #output_text = self.llama_tokenizer.decode(outputs[0][:-2], skip_special_tokens=True)
549
+ output_text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True)
550
+ return output_text
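A minimal usage sketch for the inference path above (prompt, image path, and sampling values are placeholders; keys follow the generate() docstring):

    response = model.generate({
        'prompt': ['Describe the image in detail.'],
        'image_paths': ['./examples/example.jpg'],   # placeholder path
        'modality_embeds': [],                       # filled and reused across turns
        'modality_cache': [],
        'max_tgt_len': 128,
        'top_p': 0.9,
        'temperature': 1.0,
    })
    print(response[0])   # batch_decode returns a list of strings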
model/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .ceph_utils import *
2
+ from .pcl_utils import *
model/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (224 Bytes).
model/utils/__pycache__/ceph_utils.cpython-310.pyc ADDED
Binary file (7.87 kB).
model/utils/__pycache__/pcl_utils.cpython-310.pyc ADDED
Binary file (3.22 kB).
model/utils/ceph_utils.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ Date: 2022-07-18 2:15:47 pm
3
+ Author: dihuangdh
4
+ Descriptions:
5
+ -----
6
+ LastEditTime: 2022-09-14 3:44:19 pm
7
+ LastEditors: dihuangdh
8
+ """
9
+
10
+ import json
11
+ import pickle
12
+ import warnings
13
+ from io import BytesIO, StringIO # TODO:
14
+ from pathlib import Path
15
+ from typing import Any, Generator, Iterator, Optional, Tuple, Union
16
+
17
+ import cv2
18
+ import numpy as np
19
+
20
+
21
+ def has_method(obj: object, method: str) -> bool:
22
+ """Check whether the object has a method.
23
+ Args:
24
+ method (str): The method name to check.
25
+ obj (object): The object to check.
26
+ Returns:
27
+ bool: True if the object has the method else False.
28
+ """
29
+ return hasattr(obj, method) and callable(getattr(obj, method))
30
+
31
+
32
+ class PetrelBackend(object):
33
+ """Petrel storage backend - simple version"""
34
+
35
+ def __init__(self, enable_mc: bool = False) -> None:
36
+ try:
37
+ from petrel_client.client import Client
38
+ except ImportError:
39
+ raise ImportError(
40
+ "Please install petrel_client to enable " "PetrelBackend."
41
+ )
42
+
43
+ self._client = Client('~/petreloss.conf')
44
+
45
+ def get(self, filepath) -> memoryview:
46
+ value = self._client.Get(filepath)
47
+ value_buf = memoryview(value)
48
+ return value_buf
49
+
50
+ def get_text(self, filepath, warning=False) -> str:
51
+ try:
52
+ value = self._client.Get(filepath)
53
+ except:
54
+ if warning:
55
+ warnings.warn("Failed to get text from {}".format(filepath))
56
+ value = None
57
+ else:
58
+ raise Exception("Failed to get text from {}".format(filepath))
59
+ return str(value, encoding="utf-8") if value is not None else None
60
+
61
+ def get_uint16_png(self, filepath, warning=False) -> np.ndarray:
62
+ try:
63
+ value = np.frombuffer(self._client.get(filepath), np.uint8)
64
+ value = cv2.imdecode(value, cv2.IMREAD_UNCHANGED)
65
+ except:
66
+ if warning:
67
+ warnings.warn("Failed to get uint16_png from {}".format(filepath))
68
+ value = None
69
+ else:
70
+ raise Exception("Failed to get uint16_png from {}".format(filepath))
71
+ return value
72
+
73
+ def get_uint8_jpg(self, filepath, warning=False) -> np.ndarray:
74
+ try:
75
+ value = np.frombuffer(self._client.get(filepath), np.uint8)
76
+ value = cv2.imdecode(value, cv2.IMREAD_UNCHANGED)
77
+ except:
78
+ if warning:
79
+ warnings.warn("Failed to get uint8_jpg from {}".format(filepath))
80
+ value = None
81
+ else:
82
+ raise Exception("Failed to get uint8_jpg from {}".format(filepath))
83
+ return value
84
+
85
+ def get_npz(self, filepath, warning=False) -> Any:
86
+ try:
87
+ value = self._client.get(filepath)
88
+ value = pickle.loads(value)  # counterpart of put_npz, which stores with pickle.dumps
89
+ except Exception as e:
90
+ if warning:
91
+ warnings.warn("Failed to get npz from {}".format(filepath))
92
+ value = None
93
+ else:
94
+ print(e)
95
+ raise Exception("Failed to get npz from {}".format(filepath))
96
+ return value
97
+
98
+ def get_numpy_txt(self, filepath, warning=False) -> np.ndarray:
99
+ try:
100
+ value = np.loadtxt(StringIO(self.get_text(filepath)))
101
+ except:
102
+ if warning:
103
+ warnings.warn("Failed to get numpy_txt from {}".format(filepath))
104
+ value = None
105
+ else:
106
+ raise Exception("Failed to get numpy_txt from {}".format(filepath))
107
+ return value
108
+
109
+ def get_json(self, filepath, warning=False) -> Any:
110
+ try:
111
+ value = self._client.get(filepath)
112
+ value = json.loads(value)
113
+ except:
114
+ if warning:
115
+ warnings.warn("Failed to get json from {}".format(filepath))
116
+ value = None
117
+ else:
118
+ raise Exception("Failed to get json from {}".format(filepath))
119
+ return value
120
+
121
+ def put_uint16_png(self, filepath, value) -> None:
122
+ success, img_array = cv2.imencode(".png", value, params=[cv2.CV_16U])
123
+ assert success
124
+ img_bytes = img_array.tobytes()
125
+ self._client.put(filepath, img_bytes)
126
+ # self._client.put(filepath, img_bytes, update_cache=True)
127
+
128
+ def put_uint8_jpg(self, filepath, value) -> None:
129
+ success, img_array = cv2.imencode(".jpg", value)
130
+ assert success
131
+ img_bytes = img_array.tobytes()
132
+ self._client.put(filepath, img_bytes)
133
+ # self._client.put(filepath, img_bytes, update_cache=True)
134
+
135
+ def put_npz(self, filepath, value) -> None:
136
+ value = pickle.dumps(value)
137
+ self._client.put(filepath, value)
138
+ # self._client.put(filepath, value, update_cache=True)
139
+
140
+ def put_json(self, filepath, value) -> None:
141
+ value = json.dumps(value).encode()
142
+ self._client.put(filepath, value)
143
+ # self._client.put(filepath, value, update_cache=True)
144
+
145
+ def put_text(self, filepath, value) -> None:
146
+ self._client.put(filepath, bytes(value, encoding="utf-8"))
147
+ # self._client.put(filepath, bytes(value, encoding='utf-8'), update_cache=True)
148
+
149
+ def join_path(
150
+ self, filepath: Union[str, Path], *filepaths: Union[str, Path]
151
+ ) -> str:
152
+ """Concatenate all file paths.
153
+ Args:
154
+ filepath (str or Path): Path to be concatenated.
155
+ Returns:
156
+ str: The result after concatenation.
157
+ """
158
+ # filepath = self._format_path(self._map_path(filepath))
159
+ if filepath.endswith("/"):
160
+ filepath = filepath[:-1]
161
+ formatted_paths = [filepath]
162
+ for path in filepaths:
163
+ formatted_paths.append(path)
164
+ return "/".join(formatted_paths)
165
+
166
+ # from mmcv
167
+ def list_dir_or_file(
168
+ self,
169
+ dir_path: Union[str, Path],
170
+ list_dir: bool = True,
171
+ list_file: bool = True,
172
+ suffix: Optional[Union[str, Tuple[str]]] = None,
173
+ recursive: bool = False,
174
+ ) -> Iterator[str]:
175
+ """Scan a directory to find the interested directories or files in
176
+ arbitrary order.
177
+ Note:
178
+ Petrel has no concept of directories but it simulates the directory
179
+ hierarchy in the filesystem through public prefixes. In addition,
180
+ if the returned path ends with '/', it means the path is a public
181
+ prefix which is a logical directory.
182
+ Note:
183
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
184
+ In addition, the returned path of directory will not contains the
185
+ suffix '/' which is consistent with other backends.
186
+ Args:
187
+ dir_path (str | Path): Path of the directory.
188
+ list_dir (bool): List the directories. Default: True.
189
+ list_file (bool): List the path of files. Default: True.
190
+ suffix (str or tuple[str], optional): File suffix
191
+ that we are interested in. Default: None.
192
+ recursive (bool): If set to True, recursively scan the
193
+ directory. Default: False.
194
+ Yields:
195
+ Iterable[str]: A relative path to ``dir_path``.
196
+ """
197
+ # if not has_method(self._client, 'list'):
198
+ # raise NotImplementedError(
199
+ # 'Current version of Petrel Python SDK has not supported '
200
+ # 'the `list` method, please use a higher version or dev'
201
+ # ' branch instead.')
202
+
203
+ # dir_path = self._map_path(dir_path)
204
+ # dir_path = self._format_path(dir_path)
205
+ # if list_dir and suffix is not None:
206
+ # raise TypeError(
207
+ # '`list_dir` should be False when `suffix` is not None')
208
+
209
+ # if (suffix is not None) and not isinstance(suffix, (str, tuple)):
210
+ # raise TypeError('`suffix` must be a string or tuple of strings')
211
+
212
+ # Petrel's simulated directory hierarchy assumes that directory paths
213
+ # should end with `/`
214
+ if not dir_path.endswith("/"):
215
+ dir_path += "/"
216
+
217
+ root = dir_path
218
+
219
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
220
+ for path in self._client.list(dir_path):
221
+ # the `self.isdir` is not used here to determine whether path
222
+ # is a directory, because `self.isdir` relies on
223
+ # `self._client.list`
224
+ if path.endswith("/"): # a directory path
225
+ next_dir_path = self.join_path(dir_path, path)
226
+ if list_dir:
227
+ # get the relative path and exclude the last
228
+ # character '/'
229
+ rel_dir = next_dir_path[len(root) : -1]
230
+ yield rel_dir
231
+ if recursive:
232
+ yield from _list_dir_or_file(
233
+ next_dir_path, list_dir, list_file, suffix, recursive
234
+ )
235
+ else: # a file path
236
+ absolute_path = self.join_path(dir_path, path)
237
+ rel_path = absolute_path[len(root) :]
238
+ if (suffix is None or rel_path.endswith(suffix)) and list_file:
239
+ yield rel_path
240
+
241
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
242
+
243
+ # from mmcv
244
+ def exists(self, filepath: Union[str, Path]) -> bool:
245
+ """Check whether a file path exists.
246
+ Args:
247
+ filepath (str or Path): Path to be checked whether exists.
248
+ Returns:
249
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
250
+ """
251
+ if not (
252
+ has_method(self._client, "contains") and has_method(self._client, "isdir")
253
+ ):
254
+ raise NotImplementedError(
255
+ "Current version of Petrel Python SDK has not supported "
256
+ "the `contains` and `isdir` methods, please use a higher"
257
+ "version or dev branch instead."
258
+ )
259
+
260
+ return self._client.contains(filepath) or self._client.isdir(filepath)
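A brief usage sketch for the backend above (the s3 bucket path is a placeholder):

    backend = PetrelBackend()
    if backend.exists('s3://bucket/images/0001.jpg'):
        img = backend.get_uint8_jpg('s3://bucket/images/0001.jpg')    # decoded image array
        meta = backend.get_json('s3://bucket/annotations/0001.json')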
model/utils/pcl_utils.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ """ Utility functions for processing point clouds.
4
+
5
+ Author: Charles R. Qi and Or Litany
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+
12
+ # Point cloud IO
13
+ import numpy as np
14
+ from plyfile import PlyData, PlyElement
15
+
16
+ # Mesh IO
17
+ import trimesh
18
+
19
+
20
+ MEAN_COLOR_RGB = np.array([109.8, 97.2, 83.8])
21
+
22
+ # ----------------------------------------
23
+ # Point Cloud Sampling
24
+ # ----------------------------------------
25
+
26
+
27
+ def random_sampling(pc, num_sample, replace=None, return_choices=False):
28
+ """Input is NxC, output is num_samplexC"""
29
+ if replace is None:
30
+ replace = pc.shape[0] < num_sample
31
+ choices = np.random.choice(pc.shape[0], num_sample, replace=replace)
32
+ if return_choices:
33
+ return pc[choices], choices
34
+ else:
35
+ return pc[choices]
36
+
37
+
38
+ def check_aspect(crop_range, aspect_min):
39
+ xy_aspect = np.min(crop_range[:2]) / np.max(crop_range[:2])
40
+ xz_aspect = np.min(crop_range[[0, 2]]) / np.max(crop_range[[0, 2]])
41
+ yz_aspect = np.min(crop_range[1:]) / np.max(crop_range[1:])
42
+ return (
43
+ (xy_aspect >= aspect_min)
44
+ or (xz_aspect >= aspect_min)
45
+ or (yz_aspect >= aspect_min)
46
+ )
47
+
48
+
49
+ class RandomCuboid(object):
50
+ """
51
+ RandomCuboid augmentation from DepthContrast [https://arxiv.org/abs/2101.02691]
52
+ We slightly modify this operation to account for object detection.
53
+ This augmentation randomly crops a cuboid from the input and
54
+ ensures that the cropped cuboid contains at least one bounding box
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ min_points,
60
+ aspect=0.8,
61
+ min_crop=0.5,
62
+ max_crop=1.0,
63
+ box_filter_policy="center",
64
+ ):
65
+ self.aspect = aspect
66
+ self.min_crop = min_crop
67
+ self.max_crop = max_crop
68
+ self.min_points = min_points
69
+ self.box_filter_policy = box_filter_policy
70
+
71
+ def __call__(self, point_cloud, target_boxes, per_point_labels=None):
72
+ range_xyz = np.max(point_cloud[:, 0:3], axis=0) - np.min(
73
+ point_cloud[:, 0:3], axis=0
74
+ )
75
+
76
+ for _ in range(100):
77
+ crop_range = self.min_crop + np.random.rand(3) * (
78
+ self.max_crop - self.min_crop
79
+ )
80
+ if not check_aspect(crop_range, self.aspect):
81
+ continue
82
+
83
+ sample_center = point_cloud[np.random.choice(len(point_cloud)), 0:3]
84
+
85
+ new_range = range_xyz * crop_range / 2.0
86
+
87
+ max_xyz = sample_center + new_range
88
+ min_xyz = sample_center - new_range
89
+
90
+ upper_idx = (
91
+ np.sum((point_cloud[:, 0:3] <= max_xyz).astype(np.int32), 1) == 3
92
+ )
93
+ lower_idx = (
94
+ np.sum((point_cloud[:, 0:3] >= min_xyz).astype(np.int32), 1) == 3
95
+ )
96
+
97
+ new_pointidx = (upper_idx) & (lower_idx)
98
+
99
+ if np.sum(new_pointidx) < self.min_points:
100
+ continue
101
+
102
+ new_point_cloud = point_cloud[new_pointidx, :]
103
+
104
+ # filtering policy is the only modification from DepthContrast
105
+ if self.box_filter_policy == "center":
106
+ # remove boxes whose center does not lie within the new_point_cloud
107
+ new_boxes = target_boxes
108
+ if (
109
+ target_boxes.sum() > 0
110
+ ):  # only filter when ground truth contains boxes; empty GT is common in SUNRGBD
111
+ box_centers = target_boxes[:, 0:3]
112
+ new_pc_min_max = np.min(new_point_cloud[:, 0:3], axis=0), np.max(
113
+ new_point_cloud[:, 0:3], axis=0
114
+ )
115
+ keep_boxes = np.logical_and(
116
+ np.all(box_centers >= new_pc_min_max[0], axis=1),
117
+ np.all(box_centers <= new_pc_min_max[1], axis=1),
118
+ )
119
+ if keep_boxes.sum() == 0:
120
+ # current data augmentation removes all boxes in the pointcloud. fail!
121
+ continue
122
+ new_boxes = target_boxes[keep_boxes]
123
+ if per_point_labels is not None:
124
+ new_per_point_labels = [x[new_pointidx] for x in per_point_labels]
125
+ else:
126
+ new_per_point_labels = None
127
+ # if we are here, all conditions are met. return boxes
128
+ return new_point_cloud, new_boxes, new_per_point_labels
129
+
130
+ # fallback
131
+ return point_cloud, target_boxes, per_point_labels
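A short example of how these utilities compose (shapes and the box layout are assumptions):

    pc = np.random.rand(50000, 6)                  # xyz + rgb point cloud
    boxes = np.zeros((0, 7))                       # empty ground truth (center, size, heading assumed)
    cropper = RandomCuboid(min_points=30000)
    cropped_pc, cropped_boxes, _ = cropper(pc, boxes)
    sampled = random_sampling(cropped_pc, num_sample=40000)   # resamples with replacement if needed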
pretrained_ckpt/lamm98k/config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "_name_or_path": "../model_zoo/vicuna_ckpt/13b_v0/",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "bos_token_id": 1,
7
+ "eos_token_id": 2,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 5120,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 13824,
12
+ "max_position_embeddings": 2048,
13
+ "model_type": "llama",
14
+ "num_attention_heads": 40,
15
+ "num_hidden_layers": 40,
16
+ "pad_token_id": 0,
17
+ "rms_norm_eps": 1e-06,
18
+ "tie_word_embeddings": false,
19
+ "torch_dtype": "float16",
20
+ "transformers_version": "4.29.1",
21
+ "use_cache": true,
22
+ "vocab_size": 32001
23
+ }
pretrained_ckpt/lamm98k/pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e43143d64def44689a2c197045795e0fa241a13ea4f33c99575fe0336d8212f
3
+ size 115418859
pretrained_ckpt/lamm98k/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
pretrained_ckpt/lamm98k/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
pretrained_ckpt/lamm98k/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "model_max_length": 1000000000000000019884624838656,
22
+ "pad_token": null,
23
+ "sp_model_kwargs": {},
24
+ "tokenizer_class": "LlamaTokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ },
33
+ "use_fast": false
34
+ }
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ timm==0.6.7
2
+ data
3
+ einops==0.6.1
4
+ ftfy==6.1.1
5
+ iopath==0.1.10
6
+ numpy==1.24.3
7
+ peft==0.3.0
8
+ Pillow==9.5.0
9
+ PyYAML==6.0
10
+ regex==2022.10.31
11
+ torchvision==0.14.1
12
+ torchaudio==0.13.1
13
+ pytorchvideo
14
+ fvcore
15
+ decord==0.6.0
16
+ tqdm==4.64.1
17
+ transformers==4.29.1
18
+ bigmodelvis
19
+ gradio