praeclarumjj3 committed on
Commit
9fa3d89
1 Parent(s): d62bbd6

:zap: add code

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +4 -4
  2. app.py +471 -48
  3. demo.py +486 -0
  4. ola_vlm/.DS_Store +0 -0
  5. ola_vlm/__init__.py +2 -0
  6. ola_vlm/constants.py +13 -0
  7. ola_vlm/conversation.py +255 -0
  8. ola_vlm/eval/.DS_Store +0 -0
  9. ola_vlm/eval/eval_cv_bench.py +78 -0
  10. ola_vlm/eval/eval_mmstar.py +17 -0
  11. ola_vlm/eval/eval_probe_task.py +223 -0
  12. ola_vlm/eval/eval_sherlock_dsg.py +282 -0
  13. ola_vlm/eval/get_all_stats.py +132 -0
  14. ola_vlm/eval/get_probe_task_scores.py +197 -0
  15. ola_vlm/eval/get_sherlock_dsg_scores.py +49 -0
  16. ola_vlm/eval/merge_json.py +30 -0
  17. ola_vlm/eval/mmstar/evaluate/__init__.py +1 -0
  18. ola_vlm/eval/mmstar/evaluate/__pycache__/__init__.cpython-310.pyc +0 -0
  19. ola_vlm/eval/mmstar/evaluate/__pycache__/mmstar.cpython-310.pyc +0 -0
  20. ola_vlm/eval/mmstar/evaluate/mmstar.py +87 -0
  21. ola_vlm/eval/mmstar/smp/__init__.py +3 -0
  22. ola_vlm/eval/mmstar/smp/__pycache__/__init__.cpython-310.pyc +0 -0
  23. ola_vlm/eval/mmstar/smp/__pycache__/file.cpython-310.pyc +0 -0
  24. ola_vlm/eval/mmstar/smp/__pycache__/log.cpython-310.pyc +0 -0
  25. ola_vlm/eval/mmstar/smp/__pycache__/misc.cpython-310.pyc +0 -0
  26. ola_vlm/eval/mmstar/smp/__pycache__/vlm.cpython-310.pyc +0 -0
  27. ola_vlm/eval/mmstar/smp/file.py +147 -0
  28. ola_vlm/eval/mmstar/smp/log.py +43 -0
  29. ola_vlm/eval/mmstar/smp/misc.py +174 -0
  30. ola_vlm/eval/model_cvbench_loader.py +166 -0
  31. ola_vlm/eval/model_mmstar_loader.py +164 -0
  32. ola_vlm/mm_utils.py +398 -0
  33. ola_vlm/model/.DS_Store +0 -0
  34. ola_vlm/model/__init__.py +5 -0
  35. ola_vlm/model/apply_delta.py +48 -0
  36. ola_vlm/model/aux_heads/.DS_Store +0 -0
  37. ola_vlm/model/aux_heads/__init__.py +3 -0
  38. ola_vlm/model/aux_heads/da_v2_head.py +457 -0
  39. ola_vlm/model/aux_heads/depth_anything_v2/dinov2.py +415 -0
  40. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/__init__.py +11 -0
  41. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/attention.py +83 -0
  42. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/block.py +252 -0
  43. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/drop_path.py +35 -0
  44. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
  45. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/mlp.py +41 -0
  46. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/patch_embed.py +90 -0
  47. ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
  48. ola_vlm/model/aux_heads/depth_anything_v2/dpt.py +219 -0
  49. ola_vlm/model/aux_heads/depth_anything_v2/util/blocks.py +148 -0
  50. ola_vlm/model/aux_heads/depth_anything_v2/util/transform.py +158 -0
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: OLA VLM
- emoji: 💬
- colorFrom: yellow
+ title: OLA-VLM
+ emoji: 🔍
+ colorFrom: blue
  colorTo: purple
  sdk: gradio
- sdk_version: 5.0.1
+ sdk_version: 4.16.0
  app_file: app.py
  pinned: false
  license: apache-2.0
app.py CHANGED
@@ -1,64 +1,487 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
          temperature=temperature,
          top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
1
  import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import numpy as np
5
 
6
+ from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
7
+
8
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
9
+ from ola_vlm.conversation import conv_templates, SeparatorStyle
10
+ from ola_vlm.model.builder import load_pretrained_model
11
+ from ola_vlm.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
12
+
13
+ from diffusers import StableUnCLIPImg2ImgPipeline
14
+ from diffusers import DPMSolverMultistepScheduler
15
+ from transformers import OneFormerProcessor
16
+ from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead
17
+ from ola_vlm.ola_utils import visualize_oneformer_masks_on_image, oneformer_prepare_panoptic_instance_prediction
18
+ import matplotlib
19
+ from PIL import Image, ImageDraw, ImageFont
20
+ import argparse
21
+ import math
22
+
23
+ from transformers import TextIteratorStreamer
24
+ from threading import Thread
25
+
26
+ def make_grid(pil_images, layer_indices=None):
27
+ new_images = []
28
+ new_captions = []
29
+
30
+ # Resize images and prepare captions
31
+ for i, pil_image in enumerate(pil_images):
32
+ pil_image = pil_image.resize((256, 256))
33
+ new_images.append(pil_image)
34
+ if layer_indices is not None:
35
+ new_captions.append(f"Layer: {layer_indices[i]}")
36
+ else:
37
+ new_captions.append(f"Layer: {i+1}")
38
+
39
+ images = new_images
40
+ captions = new_captions
41
+
42
+ width, height = images[0].size
43
+ font_size = 18
44
+
45
+ # Calculate the number of rows and columns for the grid
46
+ images_per_row = min(len(images), 4) # Max 4 images per row
47
+ row_count = math.ceil(len(images) / images_per_row)
48
+ total_width = width * images_per_row
49
+ total_height = height * row_count
50
+
51
+ # Create a new blank image
52
+ new_image = Image.new("RGB", (total_width, total_height), "white")
53
+ draw = ImageDraw.Draw(new_image)
54
+
55
+ # Load a default font
56
+ try:
57
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
58
+ except:
59
+ font = ImageFont.load_default()
60
+
61
+ # Place images and captions in the grid
62
+ for i, (image, caption) in enumerate(zip(images, captions)):
63
+ row = i // images_per_row
64
+ col = i % images_per_row
65
+ x_offset = col * width
66
+ y_offset = row * height
67
+
68
+ # Paste the image
69
+ new_image.paste(image, (x_offset, y_offset))
70
+
71
+ # Calculate text and background positions
72
+ text_width, text_height = draw.textsize(caption, font=font)
73
+ text_position = (x_offset + 10, y_offset + height - text_height - 10)
74
+ background_position = (
75
+ text_position[0] - 5,
76
+ text_position[1] - 5,
77
+ text_position[0] + text_width + 5,
78
+ text_position[1] + text_height + 5,
79
+ )
80
+
81
+ # Draw background rectangle and text
82
+ draw.rectangle(background_position, fill="white", outline="black")
83
+ draw.text(text_position, caption, fill="black", font=font)
84
+
85
+ return new_image
86
+
87
+ def reload_from_ckpt(model_path, model, cache_dir=None):
88
+ import os
89
+ from safetensors import safe_open
90
+ from huggingface_hub import hf_hub_download, list_repo_files
91
+
92
+ state_dict = {}
93
+
94
+ # Check if the path is a local directory or HF Hub model
95
+ if os.path.isdir(model_path):
96
+ # Local directory: Load safetensors files
97
+ safetensors_paths = [os.path.join(model_path, f) for f in os.listdir(model_path) if f.endswith('.safetensors')]
98
+ else:
99
+ # HF Hub: Get list of safetensors files and download them
100
+ repo_files = list_repo_files(model_path)
101
+ safetensors_paths = [
102
+ hf_hub_download(model_path, file_name, cache_dir=cache_dir)
103
+ for file_name in repo_files if file_name.endswith('.safetensors')
104
+ ]
105
+
106
+ # Load safetensors files into the state_dict
107
+ for path in safetensors_paths:
108
+ with safe_open(path, framework="pt", device="cpu") as f:
109
+ for key in f.keys():
110
+ state_dict[key] = f.get_tensor(key)
111
+
112
+ # Load the state dict into the model
113
+ model.load_state_dict(state_dict, strict=False)
114
+ return model
115
+
116
+ # os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp'
117
+ no_change_btn = gr.Button()
118
+ enable_btn = gr.Button(interactive=True)
119
+ disable_btn = gr.Button(interactive=False)
120
+
121
+ argparser = argparse.ArgumentParser()
122
+ argparser.add_argument("--server_name", default="0.0.0.0", type=str)
123
+ argparser.add_argument("--port", default="6324", type=str)
124
+ argparser.add_argument("--model-path", default="shi-labs/pretrain_dsg_OLA-VLM-CLIP-ViT-Llama3-8b", type=str)
125
+ argparser.add_argument("--model-base", type=str, default=None)
126
+ argparser.add_argument("--num-gpus", type=int, default=1)
127
+ argparser.add_argument("--conv-mode", type=str, default="llava_llama_3")
128
+ argparser.add_argument("--temperature", type=float, default=0.2)
129
+ argparser.add_argument("--max-new-tokens", type=int, default=512)
130
+ argparser.add_argument("--num_frames", type=int, default=16)
131
+ argparser.add_argument("--load-8bit", action="store_true")
132
+ argparser.add_argument("--load-4bit", action="store_true")
133
+ argparser.add_argument("--debug", action="store_true")
134
+
135
+ args = argparser.parse_args()
136
+ model_path = args.model_path
137
+ conv_mode = args.conv_mode
138
+ filt_invalid="cut"
139
+ model_name = get_model_name_from_path(args.model_path)
140
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
141
+ model = reload_from_ckpt("shi-labs/OLA-VLM-CLIP-ViT-Llama3-8b", model)
142
+ our_chatbot = None
143
+
144
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16")
145
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
146
+ pipe = pipe.to("cuda")
147
+
148
+ oneformer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
149
+ oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large").to("cuda")
150
+
151
+ gen_layer_indices = model.config.image_gen["img_layer_indices"].split("-")
152
+ seg_layer_indices = model.config.image_seg["seg_layer_indices"].split("-")
153
+ depth_layer_indices = model.config.image_depth["depth_layer_indices"].split("-")
154
+
155
+
156
+ def clear_history():
157
+ state = conv_templates[conv_mode].copy()
158
+ return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5
159
+
160
+ def add_text(state, imagebox, textbox, image_process_mode):
161
+ if state is None:
162
+ state = conv_templates[conv_mode].copy()
163
+
164
+ if imagebox is not None:
165
+ textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
166
+ image = Image.open(imagebox).convert('RGB')
167
+
168
+ if imagebox is not None:
169
+ textbox = (textbox, image, image_process_mode)
170
+
171
+ state.append_message(state.roles[0], textbox)
172
+ state.append_message(state.roles[1], None)
173
+
174
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
175
+
176
+ def get_gen_images(out):
177
+ img_embeds = out.image_embs
178
+ if len(img_embeds) == 0:
179
+ return None
180
+ images = []
181
+ for img_embed in img_embeds:
182
+ gen_image = pipe(image_embeds=img_embed.squeeze(1),
183
+ num_inference_steps=25,
184
+ ).images[0]
185
+ images.append(gen_image)
186
+ grid_image = make_grid(images, gen_layer_indices)
187
+ return grid_image
188
+
189
+ def get_depth_images(out, org_size):
190
+ depth_preds = out.depth_preds
191
 
192
+ if len(depth_preds) == 0:
193
+ return None
194
+ depths = []
195
 
196
+ for i, depth_pred in enumerate(depth_preds):
197
+ depth = (depth_pred - depth_pred.min()) / (depth_pred.max() - depth_pred.min()) * 255.0
198
+ depth = depth.squeeze(0).cpu().numpy()
199
+ depth = depth.astype(np.uint8)
200
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
201
+ depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
202
+ depth = Image.fromarray(depth)
203
+ depth = depth.resize(org_size)
204
+ depths.append(depth)
205
+ grid_image = make_grid(depths, depth_layer_indices)
206
+ return grid_image
207
 
208
+ def get_seg_images(out, image):
209
+ seg_embs = out.seg_embs
210
+
211
+ if len(seg_embs) == 0:
212
+ return None
213
+
214
+ seg_preds = []
215
+ inputs = oneformer_processor(image, ["semantic"], return_tensors="pt")
216
+ inputs["pixel_values"] = inputs["pixel_values"].to(out.logits.device, out.logits.dtype)
217
+ inputs["task_inputs"] = inputs["task_inputs"].to(out.logits.device, out.logits.dtype)
218
+ backbone_features = oneformer.get_backbone_feats(**inputs)
219
+ for i, seg_emb in enumerate(seg_embs):
220
+ pred = oneformer.get_masks(**inputs, backbone_last_feature=seg_emb.float(), all_backbone_features=backbone_features)
221
+ pred = oneformer_processor.post_process_panoptic_segmentation(
222
+ pred, target_sizes=[image.size[::-1]]
223
+ )[0]
224
+ pred_msk, pred_cls = oneformer_prepare_panoptic_instance_prediction(**pred, oneformer=oneformer)
225
+ pred = visualize_oneformer_masks_on_image(image, pred_msk, pred_cls)
226
+ seg_preds.append(pred)
227
+ grid_image = make_grid(seg_preds, seg_layer_indices)
228
+ return grid_image
229
 
230
+ def delete_text(state, image_process_mode):
231
+ state.messages[-1][-1] = None
232
+ prev_human_msg = state.messages[-2]
233
+ if type(prev_human_msg[1]) in (tuple, list):
234
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
235
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
236
 
237
+ def regenerate(state, image_process_mode):
238
+ state.messages[-1][-1] = None
239
+ prev_human_msg = state.messages[-2]
240
+ if type(prev_human_msg[1]) in (tuple, list):
241
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
242
+ state.skip_next = False
243
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
244
 
245
+ @spaces.GPU
246
+ def get_interm_outs(state):
247
+ prompt = state.get_prompt()
248
+ images = state.get_images(return_pil=True)
249
+ #prompt, image_args = process_image(prompt, images)
250
+
251
+ if images is not None and len(images) > 0:
252
+ if len(images) > 0:
253
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
254
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
255
+
256
+ #images = [load_image_from_base64(image) for image in images]
257
+ image_sizes = [image.size for image in images]
258
+ inp_images = process_images(images, image_processor, model.config)
259
+
260
+ if type(inp_images) is list:
261
+ inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
262
+ else:
263
+ inp_images = inp_images.to(model.device, dtype=torch.float16)
264
+ else:
265
+ inp_images = None
266
+ image_sizes = None
267
+ image_args = {"images": inp_images, "image_sizes": image_sizes}
268
+ else:
269
+ inp_images = None
270
+ image_args = {}
271
+
272
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
273
+
274
+ interm_outs = model.get_visual_interpretations(
275
+ input_ids,
276
+ **image_args
277
+ )
278
+
279
+ depth_outs = get_depth_images(interm_outs, image_sizes[0])
280
+ seg_outs = get_seg_images(interm_outs, images[0])
281
+ gen_outs = get_gen_images(interm_outs)
282
+
283
+ return depth_outs, seg_outs, gen_outs
284
+
285
+ @spaces.GPU
286
+ def generate(state, temperature, top_p, max_output_tokens):
287
+ prompt = state.get_prompt()
288
+ images = state.get_images(return_pil=True)
289
+ #prompt, image_args = process_image(prompt, images)
290
+
291
+ ori_prompt = prompt
292
+ num_image_tokens = 0
293
+
294
+ if images is not None and len(images) > 0:
295
+ if len(images) > 0:
296
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
297
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
298
+
299
+ #images = [load_image_from_base64(image) for image in images]
300
+ image_sizes = [image.size for image in images]
301
+ images = process_images(images, image_processor, model.config)
302
+
303
+ if type(images) is list:
304
+ images = [image.to(model.device, dtype=torch.float16) for image in images]
305
+ else:
306
+ images = images.to(model.device, dtype=torch.float16)
307
+ else:
308
+ images = None
309
+ image_sizes = None
310
+ image_args = {"images": images, "image_sizes": image_sizes}
311
+ else:
312
+ images = None
313
+ image_args = {}
314
+
315
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
316
+ max_new_tokens = max_output_tokens
317
+ do_sample = True if temperature > 0.001 else False
318
+ stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
319
+
320
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
321
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
322
+
323
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
324
+
325
+ if max_new_tokens < 1:
326
+ return
327
+
328
+ thread = Thread(target=model.generate, kwargs=dict(
329
+ inputs=input_ids,
330
+ do_sample=do_sample,
331
  temperature=temperature,
332
  top_p=top_p,
333
+ max_new_tokens=max_new_tokens,
334
+ streamer=streamer,
335
+ use_cache=True,
336
+ pad_token_id=tokenizer.eos_token_id,
337
+ **image_args
338
+ ))
339
+ thread.start()
340
+ generated_text = ''
341
+ for new_text in streamer:
342
+ generated_text += new_text
343
+ if generated_text.endswith(stop_str):
344
+ generated_text = generated_text[:-len(stop_str)]
345
+ state.messages[-1][-1] = generated_text
346
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
347
+
348
+ yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
349
+
350
+ torch.cuda.empty_cache()
351
+
352
+ txt = gr.Textbox(
353
+ scale=4,
354
+ show_label=False,
355
+ placeholder="Enter text and press enter.",
356
+ container=False,
357
+ )
358
 
 
 
359
 
360
+ title = "<h1 style='margin-bottom: -10px; text-align: center'>OLA-VLM: Optimizing Language Model Representations for Enhanced Visual Quality and Alignment</h1>"
361
+ description = "<p style='font-size: 16px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://praeclarumjj3.github.io/' style='text-decoration:none' target='_blank'>Jitesh Jain</a> &nbsp;&nbsp <a href='https://zyang-ur.github.io/' style='text-decoration:none' target='_blank'>Zhengyuan Yang</a> &nbsp;&nbsp <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Humphrey Shi<sup>*</sup></a> &nbsp;&nbsp <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Jianfeng Gao<sup>*</sup></a> &nbsp;&nbsp <a href='https://jwyang.github.io/' style='text-decoration:none' target='_blank'>Jianwei Yang<sup>*</sup></a></p>" \
362
+ + "<p style='font-size: 12px; margin: 5px; font-weight: w300; text-align: center'><sup>*</sup>Equal Advising</p>" \
363
+ + "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://praeclarumjj3.github.io/ola_vlm/' target='_blank'>Project Page</a> | <a href='https://youtu.be/' target='_blank'>Video</a> | <a href='https://arxiv.org/abs/' target='_blank'>ArXiv</a> | <a href='https://github.com/SHI-Labs/OLA-VLM' target='_blank'>Github</a></p>"
364
 
365
+ tos_markdown = ("""
366
+ ### Terms of use
367
+ By using this service, users are required to agree to the following terms:
368
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
369
+ """)
370
+
371
+
372
+ learn_more_markdown = ("""
373
+ ### License
374
+ The service is a research preview intended for non-commercial use only, subject to the [License](https://huggingface.co/lmsys/vicuna-7b-v1.5) of Vicuna-v1.5, [License](https://github.com/haotian-liu/LLaVA/blob/main/LICENSE) of LLaVA, [Terms of Use](https://cocodataset.org/#termsofuse) of the COCO dataset, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
375
+ """)
376
+
377
+ block_css = """
378
+ #buttons button {
379
+ min-width: min(120px,100%);
380
+ }
381
  """
382
 
383
 
384
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
385
+ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as demo:
386
+ state = gr.State()
387
+
388
+ gr.Markdown(title)
389
+ gr.Markdown(description)
390
+
391
+ with gr.Row():
392
+ with gr.Column(scale=4):
393
+ imagebox = gr.Image(label="Input Image", type="filepath")
394
+ image_process_mode = gr.Radio(
395
+ ["Crop", "Resize", "Pad", "Default"],
396
+ value="Default",
397
+ label="Preprocess for non-square image", visible=False)
398
+
399
+ # with gr.Accordion("Parameters", open=False) as parameter_row:
400
+ with gr.Row():
401
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
402
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
403
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
404
+
405
+ with gr.Column(scale=8):
406
+ chatbot = gr.Chatbot(
407
+ elem_id="chatbot",
408
+ label="OLA-VLM",
409
+ height=300,
410
+ layout="panel",
411
+ )
412
+ textbox.render()
413
+ with gr.Row(elem_id="buttons") as button_row:
414
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False, visible=False)
415
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False, visible=False)
416
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False)
417
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
418
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
419
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
420
+ submit_btn = gr.Button(value="Send", variant="primary")
421
+
422
+ with gr.Accordion("Representations from selected layers of the LLM (expects only a single image input)", open=False) as interm_out:
423
+ inter_vis_btn = gr.Button(value="✨ Visualize")
424
+ with gr.Row():
425
+ depth_box = gr.Image(label="depth", type="pil", visible=True)
426
+ seg_box = gr.Image(label="seg", type="pil", visible=True)
427
+ gen_box = gr.Image(label="gen", type="pil", visible=True)
428
+
429
+ gr.Examples(examples=[
430
+ [f"assets/cars.jpg", "Which car is in front: the blue or the brown one?"],
431
+ [f"assets/pb.jpg", "Where is the bulding located with respect to the man?"],
432
+ ], inputs=[imagebox, textbox], cache_examples=False)
433
+
434
+ # gr.Markdown(tos_markdown)
435
+ # gr.Markdown(learn_more_markdown)
436
+ # url_params = gr.JSON(visible=False)
437
+
438
+ # Register listeners
439
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
440
+
441
+ inter_vis_btn.click(
442
+ get_interm_outs,
443
+ [state],
444
+ [depth_box, seg_box, gen_box],
445
+ )
446
+
447
+ clear_btn.click(
448
+ clear_history,
449
+ None,
450
+ [state, chatbot, textbox, imagebox, depth_box, gen_box, seg_box] + btn_list,
451
+ queue=False
452
+ )
453
+
454
+ regenerate_btn.click(
455
+ delete_text,
456
+ [state, image_process_mode],
457
+ [state, chatbot, textbox, imagebox] + btn_list,
458
+ ).then(
459
+ generate,
460
+ [state, temperature, top_p, max_output_tokens],
461
+ [state, chatbot, textbox, imagebox] + btn_list,
462
+ )
463
+ textbox.submit(
464
+ add_text,
465
+ [state, imagebox, textbox, image_process_mode],
466
+ [state, chatbot, textbox, imagebox] + btn_list,
467
+ ).then(
468
+ generate,
469
+ [state, temperature, top_p, max_output_tokens],
470
+ [state, chatbot, textbox, imagebox] + btn_list,
471
+ )
472
+
473
+ submit_btn.click(
474
+ add_text,
475
+ [state, imagebox, textbox, image_process_mode],
476
+ [state, chatbot, textbox, imagebox] + btn_list,
477
+ ).then(
478
+ generate,
479
+ [state, temperature, top_p, max_output_tokens],
480
+ [state, chatbot, textbox, imagebox] + btn_list,
481
+ )
482
+
483
+ demo.queue(
484
+ status_update_rate=10,
485
+ api_open=False
486
+ ).launch(share=False)
487
+ demo.queue()
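Note on the new `generate()` above: it uses the standard `transformers` streaming recipe of running `model.generate` on a worker thread and iterating a `TextIteratorStreamer` on the main thread, which is what lets the Gradio chatbot update token by token. A minimal, self-contained sketch of just that pattern, using `gpt2` as a stand-in model (the real app streams from the OLA-VLM checkpoint and additionally passes image tensors through `**image_args`):

```python
# Minimal sketch of the threaded streaming pattern used by generate() above.
# "gpt2" is a placeholder model, not the OLA-VLM checkpoint.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "A chat between a curious user and an artificial intelligence assistant. USER: Hello! ASSISTANT:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# skip_prompt drops the echoed prompt; the timeout guards against a stalled worker.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

thread = Thread(target=model.generate, kwargs=dict(
    inputs=input_ids,
    do_sample=False,
    max_new_tokens=64,
    streamer=streamer,
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id,
))
thread.start()

generated_text = ""
for new_text in streamer:          # yields decoded chunks as the worker produces them
    generated_text += new_text
    print(new_text, end="", flush=True)
thread.join()
```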
demo.py ADDED
@@ -0,0 +1,486 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+
6
+ from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
7
+
8
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
9
+ from ola_vlm.conversation import conv_templates, SeparatorStyle
10
+ from ola_vlm.model.builder import load_pretrained_model
11
+ from ola_vlm.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
12
+
13
+ from diffusers import StableUnCLIPImg2ImgPipeline
14
+ from diffusers import DPMSolverMultistepScheduler
15
+ from transformers import OneFormerProcessor
16
+ from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead
17
+ from ola_vlm.ola_utils import visualize_oneformer_masks_on_image, oneformer_prepare_panoptic_instance_prediction
18
+ import matplotlib
19
+ from PIL import Image, ImageDraw, ImageFont
20
+ import argparse
21
+ import math
22
+
23
+ from transformers import TextIteratorStreamer
24
+ from threading import Thread
25
+
26
+ def make_grid(pil_images, layer_indices=None):
27
+ new_images = []
28
+ new_captions = []
29
+
30
+ # Resize images and prepare captions
31
+ for i, pil_image in enumerate(pil_images):
32
+ pil_image = pil_image.resize((256, 256))
33
+ new_images.append(pil_image)
34
+ if layer_indices is not None:
35
+ new_captions.append(f"Layer: {layer_indices[i]}")
36
+ else:
37
+ new_captions.append(f"Layer: {i+1}")
38
+
39
+ images = new_images
40
+ captions = new_captions
41
+
42
+ width, height = images[0].size
43
+ font_size = 18
44
+
45
+ # Calculate the number of rows and columns for the grid
46
+ images_per_row = min(len(images), 4) # Max 4 images per row
47
+ row_count = math.ceil(len(images) / images_per_row)
48
+ total_width = width * images_per_row
49
+ total_height = height * row_count
50
+
51
+ # Create a new blank image
52
+ new_image = Image.new("RGB", (total_width, total_height), "white")
53
+ draw = ImageDraw.Draw(new_image)
54
+
55
+ # Load a default font
56
+ try:
57
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
58
+ except:
59
+ font = ImageFont.load_default()
60
+
61
+ # Place images and captions in the grid
62
+ for i, (image, caption) in enumerate(zip(images, captions)):
63
+ row = i // images_per_row
64
+ col = i % images_per_row
65
+ x_offset = col * width
66
+ y_offset = row * height
67
+
68
+ # Paste the image
69
+ new_image.paste(image, (x_offset, y_offset))
70
+
71
+ # Calculate text and background positions
72
+ text_width, text_height = draw.textsize(caption, font=font)
73
+ text_position = (x_offset + 10, y_offset + height - text_height - 10)
74
+ background_position = (
75
+ text_position[0] - 5,
76
+ text_position[1] - 5,
77
+ text_position[0] + text_width + 5,
78
+ text_position[1] + text_height + 5,
79
+ )
80
+
81
+ # Draw background rectangle and text
82
+ draw.rectangle(background_position, fill="white", outline="black")
83
+ draw.text(text_position, caption, fill="black", font=font)
84
+
85
+ return new_image
86
+
87
+ def reload_from_ckpt(model_path, model, cache_dir=None):
88
+ import os
89
+ from safetensors import safe_open
90
+ from huggingface_hub import hf_hub_download, list_repo_files
91
+
92
+ state_dict = {}
93
+
94
+ # Check if the path is a local directory or HF Hub model
95
+ if os.path.isdir(model_path):
96
+ # Local directory: Load safetensors files
97
+ safetensors_paths = [os.path.join(model_path, f) for f in os.listdir(model_path) if f.endswith('.safetensors')]
98
+ else:
99
+ # HF Hub: Get list of safetensors files and download them
100
+ repo_files = list_repo_files(model_path)
101
+ safetensors_paths = [
102
+ hf_hub_download(model_path, file_name, cache_dir=cache_dir)
103
+ for file_name in repo_files if file_name.endswith('.safetensors')
104
+ ]
105
+
106
+ # Load safetensors files into the state_dict
107
+ for path in safetensors_paths:
108
+ with safe_open(path, framework="pt", device="cpu") as f:
109
+ for key in f.keys():
110
+ state_dict[key] = f.get_tensor(key)
111
+
112
+ # Load the state dict into the model
113
+ model.load_state_dict(state_dict, strict=False)
114
+ return model
115
+
116
+ # os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp'
117
+ no_change_btn = gr.Button()
118
+ enable_btn = gr.Button(interactive=True)
119
+ disable_btn = gr.Button(interactive=False)
120
+
121
+ argparser = argparse.ArgumentParser()
122
+ argparser.add_argument("--server_name", default="0.0.0.0", type=str)
123
+ argparser.add_argument("--port", default="6324", type=str)
124
+ argparser.add_argument("--model-path", default="shi-labs/pretrain_dsg_OLA-VLM-CLIP-ViT-Llama3-8b", type=str)
125
+ argparser.add_argument("--model-base", type=str, default=None)
126
+ argparser.add_argument("--num-gpus", type=int, default=1)
127
+ argparser.add_argument("--conv-mode", type=str, default="llava_llama_3")
128
+ argparser.add_argument("--temperature", type=float, default=0.2)
129
+ argparser.add_argument("--max-new-tokens", type=int, default=512)
130
+ argparser.add_argument("--num_frames", type=int, default=16)
131
+ argparser.add_argument("--load-8bit", action="store_true")
132
+ argparser.add_argument("--load-4bit", action="store_true")
133
+ argparser.add_argument("--debug", action="store_true")
134
+
135
+ args = argparser.parse_args()
136
+ model_path = args.model_path
137
+ conv_mode = args.conv_mode
138
+ filt_invalid="cut"
139
+ model_name = get_model_name_from_path(args.model_path)
140
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
141
+ model = reload_from_ckpt("shi-labs/OLA-VLM-CLIP-ViT-Llama3-8b", model)
142
+ our_chatbot = None
143
+
144
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16")
145
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
146
+ pipe = pipe.to("cuda")
147
+
148
+ oneformer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
149
+ oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large").to("cuda")
150
+
151
+ gen_layer_indices = model.config.image_gen["img_layer_indices"].split("-")
152
+ seg_layer_indices = model.config.image_seg["seg_layer_indices"].split("-")
153
+ depth_layer_indices = model.config.image_depth["depth_layer_indices"].split("-")
154
+
155
+
156
+ def clear_history():
157
+ state = conv_templates[conv_mode].copy()
158
+ return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5
159
+
160
+ def add_text(state, imagebox, textbox, image_process_mode):
161
+ if state is None:
162
+ state = conv_templates[conv_mode].copy()
163
+
164
+ if imagebox is not None:
165
+ textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
166
+ image = Image.open(imagebox).convert('RGB')
167
+
168
+ if imagebox is not None:
169
+ textbox = (textbox, image, image_process_mode)
170
+
171
+ state.append_message(state.roles[0], textbox)
172
+ state.append_message(state.roles[1], None)
173
+
174
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
175
+
176
+ def get_gen_images(out):
177
+ img_embeds = out.image_embs
178
+ if len(img_embeds) == 0:
179
+ return None
180
+ images = []
181
+ for img_embed in img_embeds:
182
+ gen_image = pipe(image_embeds=img_embed.squeeze(1),
183
+ num_inference_steps=25,
184
+ ).images[0]
185
+ images.append(gen_image)
186
+ grid_image = make_grid(images, gen_layer_indices)
187
+ return grid_image
188
+
189
+ def get_depth_images(out, org_size):
190
+ depth_preds = out.depth_preds
191
+
192
+ if len(depth_preds) == 0:
193
+ return None
194
+ depths = []
195
+
196
+ for i, depth_pred in enumerate(depth_preds):
197
+ depth = (depth_pred - depth_pred.min()) / (depth_pred.max() - depth_pred.min()) * 255.0
198
+ depth = depth.squeeze(0).cpu().numpy()
199
+ depth = depth.astype(np.uint8)
200
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
201
+ depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
202
+ depth = Image.fromarray(depth)
203
+ depth = depth.resize(org_size)
204
+ depths.append(depth)
205
+ grid_image = make_grid(depths, depth_layer_indices)
206
+ return grid_image
207
+
208
+ def get_seg_images(out, image):
209
+ seg_embs = out.seg_embs
210
+
211
+ if len(seg_embs) == 0:
212
+ return None
213
+
214
+ seg_preds = []
215
+ inputs = oneformer_processor(image, ["semantic"], return_tensors="pt")
216
+ inputs["pixel_values"] = inputs["pixel_values"].to(out.logits.device, out.logits.dtype)
217
+ inputs["task_inputs"] = inputs["task_inputs"].to(out.logits.device, out.logits.dtype)
218
+ backbone_features = oneformer.get_backbone_feats(**inputs)
219
+ for i, seg_emb in enumerate(seg_embs):
220
+ pred = oneformer.get_masks(**inputs, backbone_last_feature=seg_emb.float(), all_backbone_features=backbone_features)
221
+ pred = oneformer_processor.post_process_panoptic_segmentation(
222
+ pred, target_sizes=[image.size[::-1]]
223
+ )[0]
224
+ pred_msk, pred_cls = oneformer_prepare_panoptic_instance_prediction(**pred, oneformer=oneformer)
225
+ pred = visualize_oneformer_masks_on_image(image, pred_msk, pred_cls)
226
+ seg_preds.append(pred)
227
+ grid_image = make_grid(seg_preds, seg_layer_indices)
228
+ return grid_image
229
+
230
+ def delete_text(state, image_process_mode):
231
+ state.messages[-1][-1] = None
232
+ prev_human_msg = state.messages[-2]
233
+ if type(prev_human_msg[1]) in (tuple, list):
234
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
235
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
236
+
237
+ def regenerate(state, image_process_mode):
238
+ state.messages[-1][-1] = None
239
+ prev_human_msg = state.messages[-2]
240
+ if type(prev_human_msg[1]) in (tuple, list):
241
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
242
+ state.skip_next = False
243
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
244
+
245
+ def get_interm_outs(state):
246
+ prompt = state.get_prompt()
247
+ images = state.get_images(return_pil=True)
248
+ #prompt, image_args = process_image(prompt, images)
249
+
250
+ if images is not None and len(images) > 0:
251
+ if len(images) > 0:
252
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
253
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
254
+
255
+ #images = [load_image_from_base64(image) for image in images]
256
+ image_sizes = [image.size for image in images]
257
+ inp_images = process_images(images, image_processor, model.config)
258
+
259
+ if type(inp_images) is list:
260
+ inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
261
+ else:
262
+ inp_images = inp_images.to(model.device, dtype=torch.float16)
263
+ else:
264
+ inp_images = None
265
+ image_sizes = None
266
+ image_args = {"images": inp_images, "image_sizes": image_sizes}
267
+ else:
268
+ inp_images = None
269
+ image_args = {}
270
+
271
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
272
+
273
+ interm_outs = model.get_visual_interpretations(
274
+ input_ids,
275
+ **image_args
276
+ )
277
+
278
+ depth_outs = get_depth_images(interm_outs, image_sizes[0])
279
+ seg_outs = get_seg_images(interm_outs, images[0])
280
+ gen_outs = get_gen_images(interm_outs)
281
+
282
+ return depth_outs, seg_outs, gen_outs
283
+
284
+ # @spaces.GPU
285
+ def generate(state, temperature, top_p, max_output_tokens):
286
+ prompt = state.get_prompt()
287
+ images = state.get_images(return_pil=True)
288
+ #prompt, image_args = process_image(prompt, images)
289
+
290
+ ori_prompt = prompt
291
+ num_image_tokens = 0
292
+
293
+ if images is not None and len(images) > 0:
294
+ if len(images) > 0:
295
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
296
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
297
+
298
+ #images = [load_image_from_base64(image) for image in images]
299
+ image_sizes = [image.size for image in images]
300
+ images = process_images(images, image_processor, model.config)
301
+
302
+ if type(images) is list:
303
+ images = [image.to(model.device, dtype=torch.float16) for image in images]
304
+ else:
305
+ images = images.to(model.device, dtype=torch.float16)
306
+ else:
307
+ images = None
308
+ image_sizes = None
309
+ image_args = {"images": images, "image_sizes": image_sizes}
310
+ else:
311
+ images = None
312
+ image_args = {}
313
+
314
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
315
+ max_new_tokens = max_output_tokens
316
+ do_sample = True if temperature > 0.001 else False
317
+ stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
318
+
319
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
320
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
321
+
322
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
323
+
324
+ if max_new_tokens < 1:
325
+ return
326
+
327
+ thread = Thread(target=model.generate, kwargs=dict(
328
+ inputs=input_ids,
329
+ do_sample=do_sample,
330
+ temperature=temperature,
331
+ top_p=top_p,
332
+ max_new_tokens=max_new_tokens,
333
+ streamer=streamer,
334
+ use_cache=True,
335
+ pad_token_id=tokenizer.eos_token_id,
336
+ **image_args
337
+ ))
338
+ thread.start()
339
+ generated_text = ''
340
+ for new_text in streamer:
341
+ generated_text += new_text
342
+ if generated_text.endswith(stop_str):
343
+ generated_text = generated_text[:-len(stop_str)]
344
+ state.messages[-1][-1] = generated_text
345
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
346
+
347
+ yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
348
+
349
+ torch.cuda.empty_cache()
350
+
351
+ txt = gr.Textbox(
352
+ scale=4,
353
+ show_label=False,
354
+ placeholder="Enter text and press enter.",
355
+ container=False,
356
+ )
357
+
358
+
359
+ title = "<h1 style='margin-bottom: -10px; text-align: center'>OLA-VLM: Optimizing Language Model Representations for Enhanced Visual Quality and Alignment</h1>"
360
+ description = "<p style='font-size: 16px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://praeclarumjj3.github.io/' style='text-decoration:none' target='_blank'>Jitesh Jain</a> &nbsp;&nbsp <a href='https://zyang-ur.github.io/' style='text-decoration:none' target='_blank'>Zhengyuan Yang</a> &nbsp;&nbsp <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Humphrey Shi<sup>*</sup></a> &nbsp;&nbsp <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Jianfeng Gao<sup>*</sup></a> &nbsp;&nbsp <a href='https://jwyang.github.io/' style='text-decoration:none' target='_blank'>Jianwei Yang<sup>*</sup></a></p>" \
361
+ + "<p style='font-size: 12px; margin: 5px; font-weight: w300; text-align: center'><sup>*</sup>Equal Advising</p>" \
362
+ + "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://praeclarumjj3.github.io/ola_vlm/' target='_blank'>Project Page</a> | <a href='https://youtu.be/' target='_blank'>Video</a> | <a href='https://arxiv.org/abs/' target='_blank'>ArXiv</a> | <a href='https://github.com/SHI-Labs/OLA-VLM' target='_blank'>Github</a></p>"
363
+
364
+ tos_markdown = ("""
365
+ ### Terms of use
366
+ By using this service, users are required to agree to the following terms:
367
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
368
+ """)
369
+
370
+
371
+ learn_more_markdown = ("""
372
+ ### License
373
+ The service is a research preview intended for non-commercial use only, subject to the [License](https://huggingface.co/lmsys/vicuna-7b-v1.5) of Vicuna-v1.5, [License](https://github.com/haotian-liu/LLaVA/blob/main/LICENSE) of LLaVA, [Terms of Use](https://cocodataset.org/#termsofuse) of the COCO dataset, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
374
+ """)
375
+
376
+ block_css = """
377
+ #buttons button {
378
+ min-width: min(120px,100%);
379
+ }
380
+ """
381
+
382
+
383
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
384
+ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as demo:
385
+ state = gr.State()
386
+
387
+ gr.Markdown(title)
388
+ gr.Markdown(description)
389
+
390
+ with gr.Row():
391
+ with gr.Column(scale=4):
392
+ imagebox = gr.Image(label="Input Image", type="filepath")
393
+ image_process_mode = gr.Radio(
394
+ ["Crop", "Resize", "Pad", "Default"],
395
+ value="Default",
396
+ label="Preprocess for non-square image", visible=False)
397
+
398
+ # with gr.Accordion("Parameters", open=False) as parameter_row:
399
+ with gr.Row():
400
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
401
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
402
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
403
+
404
+ with gr.Column(scale=8):
405
+ chatbot = gr.Chatbot(
406
+ elem_id="chatbot",
407
+ label="OLA-VLM",
408
+ height=300,
409
+ layout="panel",
410
+ )
411
+ textbox.render()
412
+ with gr.Row(elem_id="buttons") as button_row:
413
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False, visible=False)
414
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False, visible=False)
415
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False)
416
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
417
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
418
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
419
+ submit_btn = gr.Button(value="Send", variant="primary")
420
+
421
+ with gr.Accordion("Representations from selected layers of the LLM (expects only a single image input)", open=False) as interm_out:
422
+ inter_vis_btn = gr.Button(value="✨ Visualize")
423
+ with gr.Row():
424
+ depth_box = gr.Image(label="depth", type="pil", visible=True)
425
+ seg_box = gr.Image(label="seg", type="pil", visible=True)
426
+ gen_box = gr.Image(label="gen", type="pil", visible=True)
427
+
428
+ gr.Examples(examples=[
429
+ [f"assets/cars.jpg", "Which car is in front: the blue or the brown one?"],
430
+ [f"assets/pb.jpg", "Where is the bulding located with respect to the man?"],
431
+ ], inputs=[imagebox, textbox], cache_examples=False)
432
+
433
+ # gr.Markdown(tos_markdown)
434
+ # gr.Markdown(learn_more_markdown)
435
+ # url_params = gr.JSON(visible=False)
436
+
437
+ # Register listeners
438
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
439
+
440
+ inter_vis_btn.click(
441
+ get_interm_outs,
442
+ [state],
443
+ [depth_box, seg_box, gen_box],
444
+ )
445
+
446
+ clear_btn.click(
447
+ clear_history,
448
+ None,
449
+ [state, chatbot, textbox, imagebox, depth_box, gen_box, seg_box] + btn_list,
450
+ queue=False
451
+ )
452
+
453
+ regenerate_btn.click(
454
+ delete_text,
455
+ [state, image_process_mode],
456
+ [state, chatbot, textbox, imagebox] + btn_list,
457
+ ).then(
458
+ generate,
459
+ [state, temperature, top_p, max_output_tokens],
460
+ [state, chatbot, textbox, imagebox] + btn_list,
461
+ )
462
+ textbox.submit(
463
+ add_text,
464
+ [state, imagebox, textbox, image_process_mode],
465
+ [state, chatbot, textbox, imagebox] + btn_list,
466
+ ).then(
467
+ generate,
468
+ [state, temperature, top_p, max_output_tokens],
469
+ [state, chatbot, textbox, imagebox] + btn_list,
470
+ )
471
+
472
+ submit_btn.click(
473
+ add_text,
474
+ [state, imagebox, textbox, image_process_mode],
475
+ [state, chatbot, textbox, imagebox] + btn_list,
476
+ ).then(
477
+ generate,
478
+ [state, temperature, top_p, max_output_tokens],
479
+ [state, chatbot, textbox, imagebox] + btn_list,
480
+ )
481
+
482
+ demo.queue(
483
+ status_update_rate=10,
484
+ api_open=False
485
+ ).launch(share=True)
486
+ demo.queue()
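The demo.py added above is the local twin of app.py: it drops the ZeroGPU hooks (`import spaces` and the `@spaces.GPU` decorators) and launches with `share=True` instead of `share=False`. For context, a minimal sketch of the ZeroGPU pattern that app.py relies on; the function body is a placeholder, and the `spaces` package is only available inside a Hugging Face Space:

```python
# Sketch of the ZeroGPU pattern that separates app.py from demo.py.
# On a ZeroGPU Space, @spaces.GPU attaches a GPU only for the duration of the
# decorated call; demo.py simply omits the decorator for local runs.
import spaces  # provided inside Hugging Face Spaces; assumed unavailable locally
import torch

@spaces.GPU  # a duration can also be requested, e.g. @spaces.GPU(duration=120)
def describe_device(prompt: str) -> str:
    # placeholder body; the real GPU work happens in generate() / get_interm_outs()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"[{device}] {prompt}"

if __name__ == "__main__":
    print(describe_device("hello"))
```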
ola_vlm/.DS_Store ADDED
Binary file (6.15 kB).
 
ola_vlm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .model import LlavaLlamaForCausalLM
+ from .model import LlavaPhi3ForCausalLM
ola_vlm/constants.py ADDED
@@ -0,0 +1,13 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
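The two constants that matter most downstream are `DEFAULT_IMAGE_TOKEN` and `IMAGE_TOKEN_INDEX`: prompts carry the literal `<image>` marker, and the tokenized sequence carries `-200` at the positions where image features are later spliced in. The real logic lives in `ola_vlm/mm_utils.tokenizer_image_token` (not shown in this 50-file view); the sketch below is a simplified LLaVA-style illustration, not the exact implementation:

```python
# Simplified illustration of how DEFAULT_IMAGE_TOKEN / IMAGE_TOKEN_INDEX are used.
# Not the exact ola_vlm.mm_utils.tokenizer_image_token implementation.
from typing import Callable, List

IMAGE_TOKEN_INDEX = -200          # values from ola_vlm/constants.py above
DEFAULT_IMAGE_TOKEN = "<image>"

def splice_image_tokens(prompt: str, tokenize: Callable[[str], List[int]]) -> List[int]:
    """Tokenize the text around each <image> marker and splice in the placeholder id."""
    chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
    input_ids: List[int] = []
    for i, chunk in enumerate(chunks):
        input_ids.extend(tokenize(chunk))
        if i < len(chunks) - 1:   # one placeholder id per <image> occurrence
            input_ids.append(IMAGE_TOKEN_INDEX)
    return input_ids

# Toy word-level tokenizer, purely for illustration.
vocab: dict = {}
toy_tokenize = lambda text: [vocab.setdefault(w, len(vocab)) for w in text.split()]
print(splice_image_tokens("<image>\nWhich car is in front?", toy_tokenize))
# -> [-200, 0, 1, 2, 3, 4]
```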
ola_vlm/conversation.py ADDED
@@ -0,0 +1,255 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_3 = auto()
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Conversation:
20
+ """A class that keeps all conversation history."""
21
+ system: str
22
+ roles: List[str]
23
+ messages: List[List[str]]
24
+ offset: int
25
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26
+ sep: str = "###"
27
+ sep2: str = None
28
+ version: str = "Unknown"
29
+
30
+ skip_next: bool = False
31
+
32
+ def get_prompt(self):
33
+ messages = self.messages
34
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
35
+ messages = self.messages.copy()
36
+ init_role, init_msg = messages[0].copy()
37
+ init_msg = init_msg[0].replace("<image>", "").strip()
38
+ if 'mmtag' in self.version:
39
+ messages[0] = (init_role, init_msg)
40
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41
+ messages.insert(1, (self.roles[1], "Received."))
42
+ else:
43
+ messages[0] = (init_role, "<image>\n" + init_msg)
44
+
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep
47
+ for role, message in messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ elif self.sep_style == SeparatorStyle.TWO:
55
+ seps = [self.sep, self.sep2]
56
+ ret = self.system + seps[0]
57
+ for i, (role, message) in enumerate(messages):
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + seps[i % 2]
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.MPT:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + message + self.sep
71
+ else:
72
+ ret += role
73
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
74
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
75
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
76
+ ret = ""
77
+
78
+ for i, (role, message) in enumerate(messages):
79
+ if i == 0:
80
+ assert message, "first message should not be none"
81
+ assert role == self.roles[0], "first message should come from user"
82
+ if message:
83
+ if type(message) is tuple:
84
+ message, _, _ = message
85
+ if i == 0: message = wrap_sys(self.system) + message
86
+ if i % 2 == 0:
87
+ message = wrap_inst(message)
88
+ ret += self.sep + message
89
+ else:
90
+ ret += " " + message + " " + self.sep2
91
+ else:
92
+ ret += ""
93
+ ret = ret.lstrip(self.sep)
94
+ elif self.sep_style == SeparatorStyle.CHATML:
95
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
96
+ for role, message in messages:
97
+ if message:
98
+ if type(message) is tuple:
99
+ message, images, _ = message
100
+ message = "<image>" * len(images) + message
101
+ ret += role + "\n" + message + self.sep + "\n"
102
+ else:
103
+ ret += role + "\n"
104
+ return ret
105
+ else:
106
+ raise ValueError(f"Invalid style: {self.sep_style}")
107
+
108
+ return ret
109
+
110
+ def append_message(self, role, message):
111
+ if isinstance(self.messages, tuple):
112
+ self.messages = list(self.messages)
113
+ self.messages.append([role, message])
114
+
115
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
116
+ if image_process_mode == "Pad":
117
+ def expand2square(pil_img, background_color=(122, 116, 104)):
118
+ width, height = pil_img.size
119
+ if width == height:
120
+ return pil_img
121
+ elif width > height:
122
+ result = Image.new(pil_img.mode, (width, width), background_color)
123
+ result.paste(pil_img, (0, (width - height) // 2))
124
+ return result
125
+ else:
126
+ result = Image.new(pil_img.mode, (height, height), background_color)
127
+ result.paste(pil_img, ((height - width) // 2, 0))
128
+ return result
129
+ image = expand2square(image)
130
+ elif image_process_mode in ["Default", "Crop"]:
131
+ pass
132
+ elif image_process_mode == "Resize":
133
+ image = image.resize((336, 336))
134
+ else:
135
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
136
+ if max(image.size) > max_len:
137
+ max_hw, min_hw = max(image.size), min(image.size)
138
+ aspect_ratio = max_hw / min_hw
139
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
140
+ longest_edge = int(shortest_edge * aspect_ratio)
141
+ W, H = image.size
142
+ if H > W:
143
+ H, W = longest_edge, shortest_edge
144
+ else:
145
+ H, W = shortest_edge, longest_edge
146
+ image = image.resize((W, H))
147
+ if return_pil:
148
+ return image
149
+ else:
150
+ buffered = BytesIO()
151
+ image.save(buffered, format=image_format)
152
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
153
+ return img_b64_str
154
+
155
+ def get_images(self, return_pil=False):
156
+ images = []
157
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
158
+ if i % 2 == 0:
159
+ if type(msg) is tuple:
160
+ msg, image, image_process_mode = msg
161
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
162
+ images.append(image)
163
+ return images
164
+
165
+ def to_gradio_chatbot(self):
166
+ ret = []
167
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
168
+ if i % 2 == 0:
169
+ if type(msg) is tuple:
170
+ msg, image, image_process_mode = msg
171
+ img_b64_str = self.process_image(
172
+ image, "Default", return_pil=False,
173
+ image_format='JPEG')
174
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
175
+ msg = img_str + msg.replace('<image>', '').strip()
176
+ ret.append([msg, None])
177
+ else:
178
+ ret.append([msg, None])
179
+ else:
180
+ ret[-1][-1] = msg
181
+ return ret
182
+
183
+ def copy(self):
184
+ return Conversation(
185
+ system=self.system,
186
+ roles=self.roles,
187
+ messages=[[x, y] for x, y in self.messages],
188
+ offset=self.offset,
189
+ sep_style=self.sep_style,
190
+ sep=self.sep,
191
+ sep2=self.sep2,
192
+ version=self.version)
193
+
194
+ def dict(self):
195
+ if len(self.get_images()) > 0:
196
+ return {
197
+ "system": self.system,
198
+ "roles": self.roles,
199
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
200
+ "offset": self.offset,
201
+ "sep": self.sep,
202
+ "sep2": self.sep2,
203
+ }
204
+ return {
205
+ "system": self.system,
206
+ "roles": self.roles,
207
+ "messages": self.messages,
208
+ "offset": self.offset,
209
+ "sep": self.sep,
210
+ "sep2": self.sep2,
211
+ }
212
+
213
+ conv_vicuna_v1 = Conversation(
214
+ system="A chat between a curious user and an artificial intelligence assistant. "
215
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
216
+ roles=("USER", "ASSISTANT"),
217
+ version="v1",
218
+ messages=(),
219
+ offset=0,
220
+ sep_style=SeparatorStyle.TWO,
221
+ sep=" ",
222
+ sep2="</s>",
223
+ )
224
+
225
+ conv_llava_llama_3 = Conversation(
226
+ system="""<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",
227
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
228
+ version="llama3",
229
+ messages=(),
230
+ offset=0,
231
+ sep_style=SeparatorStyle.MPT,
232
+ sep="<|eot_id|>",
233
+ )
234
+
235
+ conv_llava_phi_3 = Conversation(
236
+ system="""<|system|>\nYou are a helpful AI assistant.""",
237
+ roles=("\n<|user|>\n", "\n<|assistant|>\n"),
238
+ version="phi3",
239
+ messages=(),
240
+ offset=0,
241
+ sep_style=SeparatorStyle.MPT,
242
+ sep="<|end|>",
243
+ )
244
+
245
+ default_conversation = conv_llava_phi_3
246
+ conv_templates = {
247
+ "v1": conv_vicuna_v1,
248
+ "vicuna_v1": conv_vicuna_v1,
249
+ "llava_phi_3": conv_llava_phi_3,
250
+ "llava_llama_3": conv_llava_llama_3,
251
+ }
252
+
253
+
254
+ if __name__ == "__main__":
255
+ print(default_conversation.get_prompt())
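A minimal sketch of how the templates defined above are driven (the eval scripts later in this diff follow the same pattern); the question text is invented, while the imports and method names come from this file:

from ola_vlm.conversation import conv_templates

conv = conv_templates["llava_phi_3"].copy()   # copy() so the shared template object stays empty
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)      # None marks the slot the assistant will fill
prompt = conv.get_prompt()                    # serialized with the phi-3 separators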
ola_vlm/eval/.DS_Store ADDED
Binary file (6.15 kB).
 
ola_vlm/eval/eval_cv_bench.py ADDED
@@ -0,0 +1,78 @@
1
+ import pandas as pd
2
+ import json
3
+ import argparse
4
+
5
+ def load_jsonl(f):
6
+ lines = open(f, encoding='utf-8').readlines()
7
+ lines = [x.strip() for x in lines]
8
+ if lines[-1] == '':
9
+ lines = lines[:-1]
10
+ data = [json.loads(x) for x in lines]
11
+ return data
12
+
13
+ if __name__ == '__main__':
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--results_file", type=str, default="cv-bench_answer.jsonl")
17
+ args = parser.parse_args()
18
+
19
+ answers = load_jsonl(args.results_file)
20
+
21
+ data = {
22
+ "source": [],
23
+ "result": [],
24
+ "task": [],
25
+ }
26
+ import re
27
+ for a in answers:
28
+ data["source"].append(a["source"][0])
29
+ if "(" in a["prediction"]:
30
+ match = re.search(r'\(([A-Z])\)', a["prediction"])
31
+ if match:
32
+ pred = "(" + match.group(1) + ")"
33
+ else:
34
+ pred = "(" + a["prediction"][0] + ")"
35
+ data["result"].append(pred == a["answer"][0])
36
+ data["task"].append(a["task"][0])
37
+
38
+ df = pd.DataFrame(data)
39
+
40
+ def calculate_accuracy(df, source):
41
+ source_df = df[df['source'] == source]
42
+ accuracy = (source_df['result']).mean()
43
+ return accuracy
44
+
45
+ def calculate_task_accuracy(df, task):
46
+ source_df = df[df['task'] == task]
47
+ accuracy = (source_df['result']).mean()
48
+ return accuracy
49
+
50
+ accuracy_2d_ade = calculate_accuracy(df, 'ADE20K')
51
+ accuracy_2d_coco = calculate_accuracy(df, 'COCO')
52
+ accuracy_3d_omni = calculate_accuracy(df, 'Omni3D')
53
+
54
+ tasks = ["Count", "Depth", "Relation", "Distance"]
55
+
56
+ scores = {}
57
+
58
+ accuracy_2d = (accuracy_2d_ade + accuracy_2d_coco) / 2
59
+ accuracy_3d = accuracy_3d_omni
60
+
61
+ combined_accuracy = (accuracy_2d + accuracy_3d) / 2
62
+
63
+ scores["Overall"] = combined_accuracy
64
+
65
+ scores["3D"] = accuracy_3d
66
+ scores["2D"] = accuracy_2d
67
+
68
+ for t in tasks:
69
+ accuracy = calculate_task_accuracy(df, t)
70
+ scores[t] = accuracy
71
+
72
+ print("\n=========================CV-Bench Scores===============================")
73
+ for key, value in scores.items():
74
+ print(f"{key} -> {value}")
75
+ print("================================================================")
76
+
77
+ with open(args.results_file.replace('.jsonl', '_score.json'), "w") as f:
78
+ json.dump(scores, f, indent=2)
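A worked example of the aggregation above, with invented accuracies, to make the weighting explicit: the 2D score is the mean of ADE20K and COCO, and the overall score is the mean of the 2D and 3D results.

accuracy_2d_ade, accuracy_2d_coco, accuracy_3d_omni = 0.60, 0.70, 0.58

accuracy_2d = (accuracy_2d_ade + accuracy_2d_coco) / 2   # 0.65
accuracy_3d = accuracy_3d_omni                           # 0.58
overall = (accuracy_2d + accuracy_3d) / 2                # 0.615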
ola_vlm/eval/eval_mmstar.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+ import argparse
3
+ import json
4
+
5
+ from ola_vlm.eval.mmstar.evaluate import MMStar_eval
6
+
7
+
8
+ def parse_args():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--results_file', type=str, default="./playground/data/eval/mmstar_results.jsonl")
11
+ return parser.parse_args()
12
+
13
+
14
+ if __name__ == '__main__':
15
+
16
+ args = parse_args()
17
+ MMStar_eval(args.results_file)
ola_vlm/eval/eval_probe_task.py ADDED
@@ -0,0 +1,223 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
+ from ola_vlm.conversation import conv_templates
6
+ from ola_vlm.model.builder import load_pretrained_model
7
+ from ola_vlm.utils import disable_torch_init
8
+ from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9
+ from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead
10
+ from transformers import OneFormerProcessor
11
+
12
+ from PIL import Image
13
+ import json
14
+ import os
15
+ from tqdm import tqdm
16
+ from icecream import ic
17
+ import warnings
18
+ warnings.filterwarnings("ignore")
19
+ import random
20
+ import numpy as np
21
+ from analyze.analyze_utils import prepare_coco, prepare_da2k
22
+ import math
23
+ from diffusers import StableUnCLIPImg2ImgPipeline
24
+ from diffusers import DPMSolverMultistepScheduler
25
+
26
+
27
+ def split_list(lst, n):
28
+ """Split a list into n (roughly) equal-sized chunks"""
29
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division, so the last chunk may be shorter
30
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
31
+
32
+
33
+ def get_chunk(lst, n, k):
34
+ chunks = split_list(lst, n)
35
+ return chunks[k]
36
+
37
+ def set_seed(seed):
38
+ random.seed(seed)
39
+ np.random.seed(seed)
40
+ torch.manual_seed(seed)
41
+ torch.cuda.manual_seed_all(seed)
42
+
43
+ def load_image(image_file):
44
+ image = Image.open(image_file).convert('RGB')
45
+ return image
46
+
47
+ import glob
48
+
49
+ def list_image_files(directory):
50
+ image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.gif', '*.bmp', '*.tiff']
51
+ image_files = []
52
+ for extension in image_extensions:
53
+ image_files.extend(glob.glob(os.path.join(directory, extension)))
54
+ return image_files
55
+
56
+ def prep_seginw(dir):
57
+ image_files = list_image_files(dir)
58
+ prompts = []
59
+ for image_file in image_files:
60
+ prompts.append("Describe the image")
61
+ return image_files, prompts, prompts
62
+
63
+ def predict(args):
64
+
65
+ mode = args.mode
66
+
67
+ name = args.model_path.split("/")[-1]
68
+ os.makedirs(f"plots/probes_task/{name}/", exist_ok=True)
69
+
70
+ # Model
71
+ disable_torch_init()
72
+
73
+ if mode == 'gen' or mode == 'seg':
74
+ images, prompts, answers = prepare_coco(args.json_file)
75
+ elif mode == 'depth':
76
+ images, prompts, answers = prepare_da2k("/mnt/vlpdatasets/sherlock/eval/DA-2K/DA-2K/images", is_eval=True)
77
+
78
+ images = get_chunk(images, args.num_chunks, args.chunk_idx)
79
+ prompts = get_chunk(prompts, args.num_chunks, args.chunk_idx)
80
+ answers = get_chunk(answers, args.num_chunks, args.chunk_idx)
81
+
82
+ model_name = get_model_name_from_path(args.model_path)
83
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
84
+
85
+ if mode == "gen":
86
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"playground/jiteshjain_sherlock/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16")
87
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
88
+ pipe = pipe.to("cuda")
89
+
90
+ elif mode == "seg":
91
+ oneformer_processor = OneFormerProcessor.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
92
+ oneformer = OneFormerHead.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
93
+ oneformer = oneformer.to("cuda")
94
+
95
+ if "mistral" in model_name.lower():
96
+ conv_mode = "mistral_instruct"
97
+ elif "v1.6-34b" in model_name.lower():
98
+ conv_mode = "chatml_direct"
99
+ elif "llama3" in model_name.lower():
100
+ conv_mode = "llava_llama_3"
101
+ elif "qwen" in model_name.lower():
102
+ conv_mode = "qwen_1_5"
103
+ elif "v1" in model_name.lower():
104
+ conv_mode = "llava_v1"
105
+ elif "phi" in model_name.lower():
106
+ conv_mode = "llava_phi_3"
107
+
108
+ set_seed(42)
109
+
110
+ if mode == "gen":
111
+ try:
112
+ layers = model.config.image_gen["layer_indices"]
113
+ except:
114
+ layers = [i+1 for i in range(32)]
115
+ elif mode == "depth":
116
+ try:
117
+ layers = model.config.image_depth["layer_indices"]
118
+ except:
119
+ layers = [i+1 for i in range(32)]
120
+ elif mode == "seg":
121
+ try:
122
+ layers = model.config.image_seg["layer_indices"]
123
+ except:
124
+ layers = [i+1 for i in range(32)]
125
+
126
+ from tqdm import tqdm
127
+ for fname, prompt, answer in tqdm(zip(images, prompts, answers), total=len(prompts)):
128
+
129
+ conv = conv_templates[conv_mode].copy()
130
+ im = fname.split("/")[-1].split(".")[0]
131
+
132
+ image = load_image(fname)
133
+
134
+ image_size = image.size
135
+ image_tensor = process_images([image], image_processor, model.config)
136
+ if type(image_tensor) is list:
137
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
138
+ else:
139
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
140
+
141
+ inp = prompt
142
+ if image is not None:
143
+ if model.config.mm_use_im_start_end:
144
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
145
+ else:
146
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
147
+
148
+ conv.append_message(conv.roles[0], inp)
149
+ conv.append_message(conv.roles[1], None)
150
+ prompt = conv.get_prompt()
151
+
152
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
153
+
154
+ with torch.inference_mode():
155
+ out = model.get_visual_interpretations(
156
+ input_ids,
157
+ images=image_tensor,
158
+ image_sizes=[image_size],  # wrapped in a list to match the call in eval_sherlock_dsg.py
159
+ )
160
+
161
+ if mode == "seg":
162
+ seg_embs = out.seg_embs
163
+ inputs = oneformer_processor(image, ["semantic"], return_tensors="pt")
164
+ inputs["pixel_values"] = inputs["pixel_values"].to(out.logits.device, out.logits.dtype)
165
+ inputs["task_inputs"] = inputs["task_inputs"].to(out.logits.device, out.logits.dtype)
166
+ backbone_features = oneformer.get_backbone_feats(**inputs)
167
+ for i, seg_emb in enumerate(seg_embs):
168
+ pred = oneformer.get_masks(**inputs, backbone_last_feature=seg_emb.float(), all_backbone_features=backbone_features)
169
+ pred = oneformer_processor.post_process_semantic_segmentation(
170
+ pred, target_sizes=[image.size[::-1]]
171
+ )[0]
172
+ pred = pred.squeeze().cpu().numpy().astype(np.uint8)
173
+ pred = Image.fromarray(pred)
174
+ if not os.path.exists(f"plots/probes_task/{name}/seg/layer_{layers[i]}"):
175
+ os.makedirs(f"plots/probes_task/{name}/seg/layer_{layers[i]}", exist_ok=True)
176
+ save_path = os.path.join(f"plots/probes_task/{name}/seg/layer_{layers[i]}", fname.split("/")[-1].replace("jpg", "png"))
177
+ pred.save(save_path)
178
+
179
+
180
+ elif mode == "gen":
181
+ img_embeds = out.image_embs
182
+ gen_images = []  # keep the decoded images separate from the dataset image list
183
+
184
+ for img_emb in img_embeds:
185
+ gen_image = pipe(image_embeds=img_emb.squeeze(1),
186
+ num_inference_steps=25,
187
+ ).images[0]
188
+ gen_images.append(gen_image)
189
+
190
+ for i, image in enumerate(gen_images):
191
+ image = image.resize((256, 256), Image.LANCZOS)
192
+ if not os.path.exists(f"plots/probes_task/{name}/gen/layer_{layers[i]}"):
193
+ os.makedirs(f"plots/probes_task/{name}/gen/layer_{layers[i]}", exist_ok=True)
194
+ save_path = os.path.join(f"plots/probes_task/{name}/gen/layer_{layers[i]}", fname.split("/")[-1])
195
+ image.save(save_path)
196
+
197
+ elif mode == "depth":
198
+ depth_preds = out.depth_preds
199
+
200
+ for i, depth_pred in enumerate(depth_preds):
201
+ if not os.path.exists(f"plots/probes_task/{name}/depth/layer_{layers[i]}"):
202
+ os.makedirs(f"plots/probes_task/{name}/depth/layer_{layers[i]}", exist_ok=True)
203
+ depth = depth_pred.squeeze(0).cpu().numpy() * 255.0
204
+ depth = depth.astype(np.uint8)
205
+ depth = Image.fromarray(depth)
206
+ save_path = os.path.join(f"plots/probes_task/{name}/depth/layer_{layers[i]}", fname.split("/")[-1])
207
+ depth.save(save_path)
208
+
209
+ if __name__ == "__main__":
210
+ parser = argparse.ArgumentParser()
211
+ parser.add_argument("--model-path", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/llava-v1.5-7b")
212
+ parser.add_argument("--model-base", type=str, default=None)
213
+ parser.add_argument("--json-file", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/datasets/coco/annotations/captions_val2017.json")
214
+ parser.add_argument("--device", type=str, default="cuda")
215
+ parser.add_argument("--temperature", type=float, default=0.2)
216
+ parser.add_argument("--max-new-tokens", type=int, default=10)
217
+ parser.add_argument("--load-8bit", action="store_true")
218
+ parser.add_argument("--load-4bit", action="store_true")
219
+ parser.add_argument("--mode", type=str, default="gen")
220
+ parser.add_argument("--num-chunks", type=int, default=1)
221
+ parser.add_argument("--chunk-idx", type=int, default=0)
222
+ args = parser.parse_args()
223
+ predict(args)
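The --num-chunks / --chunk-idx flags shard the evaluation set with the split_list / get_chunk helpers defined above; a self-contained sketch of the resulting split, with invented file names:

import math

files = [f"img_{i}.jpg" for i in range(10)]
chunk_size = math.ceil(len(files) / 4)   # 3
chunks = [files[i:i + chunk_size] for i in range(0, len(files), chunk_size)]
# chunks[1] == ['img_3.jpg', 'img_4.jpg', 'img_5.jpg'], i.e. get_chunk(files, 4, 1)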
ola_vlm/eval/eval_sherlock_dsg.py ADDED
@@ -0,0 +1,282 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
+ from ola_vlm.conversation import conv_templates
6
+ from ola_vlm.model.builder import load_pretrained_model
7
+ from ola_vlm.utils import disable_torch_init
8
+ from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9
+ from ola_vlm.model.aux_heads.sam_utils.build_sam import sam_model_registry
10
+ from ola_vlm.model.aux_heads.sam_utils.automatic_mask_generator import SamAutomaticMaskGenerator
11
+ from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead, OneFormerSegHead, OneFormerTaskTokenSegHead
12
+ from ola_vlm.model.aux_heads.depth_anything_v2.dpt import DepthAnythingV2
13
+ from transformers import OneFormerProcessor
14
+
15
+ from diffusers import (
16
+ DPMSolverMultistepScheduler,
17
+ StableUnCLIPImg2ImgPipeline,
18
+ )
19
+
20
+ from PIL import Image
21
+ import json
22
+ import os
23
+ from tqdm import tqdm
24
+ from icecream import ic
25
+ import warnings
26
+ warnings.filterwarnings("ignore")
27
+ import random
28
+ import numpy as np
29
+ from analyze.analyze_utils import prepare_coco
30
+ import math
31
+
32
+ def split_list(lst, n):
33
+ """Split a list into n (roughly) equal-sized chunks"""
34
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division, so the last chunk may be shorter
35
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
36
+
37
+
38
+ def get_chunk(lst, n, k):
39
+ chunks = split_list(lst, n)
40
+ return chunks[k]
41
+
42
+ def set_seed(seed):
43
+ random.seed(seed)
44
+ np.random.seed(seed)
45
+ torch.manual_seed(seed)
46
+ torch.cuda.manual_seed_all(seed)
47
+
48
+ def load_image(image_file):
49
+ image = Image.open(image_file).convert('RGB')
50
+ return image
51
+
52
+ import glob
53
+
54
+ def list_image_files(directory):
55
+ image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.gif', '*.bmp', '*.tiff']
56
+ image_files = []
57
+ for extension in image_extensions:
58
+ image_files.extend(glob.glob(os.path.join(directory, extension)))
59
+ return image_files
60
+
61
+ def get_gen_feats(pipe, image):
62
+ with torch.no_grad():
63
+ clip_ims = pipe.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
64
+ feat = pipe.image_encoder(clip_ims).image_embeds
65
+ return feat
66
+
67
+ def get_dav2_feats(dav2, image):
68
+ image = image.resize((336, 336))
69
+ image = np.array(image)
70
+ with torch.no_grad():
71
+ feat = dav2.infer_image(image, is_dsg=True)
72
+ return feat[-1][0]
73
+
74
+ def get_seg_feats(mask_generator, oneformer, oneformer_processor, seg_teacher, image):
75
+ if seg_teacher == "oneformer":
76
+ img = image.resize((768, 768))
77
+ inputs = oneformer_processor(img, ["panoptic"], return_tensors="pt")
78
+ inputs["pixel_values"] = inputs["pixel_values"].to("cuda")
79
+ with torch.no_grad():
80
+ feats = oneformer.forward_features(**inputs)
81
+ else:
82
+ img = np.array(image)
83
+ with torch.no_grad():
84
+ mask_generator.predictor.set_image(img)
85
+ feats = mask_generator.predictor.features
86
+ mask_generator.predictor.reset_image()
87
+ return feats
88
+
89
+
90
+ def predict(args):
91
+
92
+ mode = args.mode
93
+
94
+ name = args.model_path.split("/")[-1]
95
+ os.makedirs(f"plots/probe_scores/{name}/", exist_ok=True)
96
+
97
+ if "cambrian" in name:
98
+ from ola_vlm.cambrian.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
99
+ from ola_vlm.cambrian.conversation import conv_templates, SeparatorStyle
100
+ from ola_vlm.cambrian.model.builder import load_pretrained_model
101
+ from ola_vlm.cambrian.utils import disable_torch_init
102
+ from ola_vlm.cambrian.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
103
+
104
+ disable_torch_init()
105
+ model_name = get_model_name_from_path(args.model_path)
106
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
107
+
108
+ if 'llama-2' in model_name.lower():
109
+ conv_mode = "cambrian_llama_2"
110
+ elif "v1" in model_name.lower():
111
+ conv_mode = "cambrian_v1"
112
+ elif "mpt" in model_name.lower():
113
+ conv_mode = "mpt"
114
+ else:
115
+ conv_mode = "cambrian_v0"
116
+
117
+ else:
118
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
119
+ from ola_vlm.conversation import conv_templates
120
+ from ola_vlm.model.builder import load_pretrained_model
121
+ from ola_vlm.utils import disable_torch_init
122
+ from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
123
+
124
+ disable_torch_init()
125
+ model_name = get_model_name_from_path(args.model_path)
126
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
127
+ if "mistral" in model_name.lower():
128
+ conv_mode = "mistral_instruct"
129
+ elif "v1.6-34b" in model_name.lower():
130
+ conv_mode = "chatml_direct"
131
+ elif "llama3" in model_name.lower():
132
+ conv_mode = "llava_llama_3"
133
+ elif "qwen" in model_name.lower():
134
+ conv_mode = "llava_qwen"
135
+ elif "v1" in model_name.lower():
136
+ conv_mode = "llava_v1"
137
+ elif "phi" in model_name.lower():
138
+ conv_mode = "llava_phi_3"
139
+
140
+ images, prompts, answers = prepare_coco(args.json_file)
141
+
142
+ images = get_chunk(images, args.num_chunks, args.chunk_idx)
143
+ prompts = get_chunk(prompts, args.num_chunks, args.chunk_idx)
144
+ answers = get_chunk(answers, args.num_chunks, args.chunk_idx)
145
+
146
+ if mode == "gen":
147
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"playground/jiteshjain_sherlock/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16")
148
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
149
+ pipe = pipe.to("cuda")
150
+
151
+ elif mode == "seg":
152
+ oneformer_processor, oneformer, mask_generator = None, None, None
153
+ seg_teacher = model.config.image_seg.get("seg_teacher", "sam")
154
+ if seg_teacher == "sam":
155
+ sam = sam_model_registry["vit_l"](checkpoint="/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
156
+ sam = sam.to("cuda")
157
+ mask_generator = SamAutomaticMaskGenerator(sam.float())
158
+ else:
159
+ oneformer_processor = OneFormerProcessor.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
160
+ oneformer = OneFormerHead.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
161
+ oneformer = oneformer.to("cuda")
162
+
163
+ elif mode == "depth":
164
+ dav2_cfg = {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
165
+ dav2_backbone = DepthAnythingV2(**dav2_cfg)
166
+ dav2_backbone.load_state_dict(torch.load("/mnt/projects4jw/jiteshjain_sherlock/depth_anything_v2_vitl.pth", map_location='cpu'))
167
+ dav2_backbone = dav2_backbone.to("cuda")
168
+
169
+
170
+ set_seed(42)
171
+
172
+ if mode == "gen":
173
+ try:
174
+ layers = model.config.image_gen["layer_indices"]
175
+ except:
176
+ layers = [i+1 for i in range(32)]
177
+ elif mode == "depth":
178
+ try:
179
+ layers = model.config.image_depth["layer_indices"]
180
+ except:
181
+ layers = [i+1 for i in range(32)]
182
+ elif mode == "seg":
183
+ try:
184
+ layers = model.config.image_seg["layer_indices"]
185
+ except:
186
+ layers = [i+1 for i in range(32)]
187
+
188
+
189
+ os.makedirs(f"plots/probe_scores/{name}/{mode}/", exist_ok=True)
190
+
191
+ if os.path.exists(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json"):
192
+ with open(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json", 'r') as f:
193
+ diff_dict = json.load(f)
194
+ else:
195
+ diff_dict = {}
196
+
197
+ img_idx = 0  # image counter for periodic checkpointing (kept separate from the per-layer loop index below)
198
+ from tqdm import tqdm
199
+ for fname, prompt, answer in tqdm(zip(images, prompts, answers), total=len(prompts)):
200
+
201
+ # if fname.split("/")[-1] in diff_dict.keys():
202
+ # continue
203
+
204
+ conv = conv_templates[conv_mode].copy()
205
+ image = load_image(fname)
206
+ image = image.resize((640, 640))
207
+
208
+ image_size = image.size
209
+
210
+ image_tensor = process_images([image], image_processor, model.config)
211
+ if type(image_tensor) is list:
212
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
213
+ else:
214
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
215
+
216
+ inp = prompt
217
+ if image is not None:
218
+ if model.config.mm_use_im_start_end:
219
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
220
+ else:
221
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
222
+
223
+ conv.append_message(conv.roles[0], inp)
224
+ conv.append_message(conv.roles[1], None)
225
+ prompt = conv.get_prompt()
226
+
227
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
228
+
229
+ with torch.inference_mode():
230
+ out = model.get_visual_interpretations(
231
+ input_ids,
232
+ images=image_tensor,
233
+ image_sizes=[image_size],
234
+ )
235
+
236
+ if mode == "gen":
237
+ embeds = out.image_embs
238
+ feats = get_gen_feats(pipe, image)
239
+ elif mode == "depth":
240
+ embeds = out.depth_embs
241
+ embeds = [emb[0][0] for emb in embeds]
242
+ feats = get_dav2_feats(dav2_backbone, image)
243
+ elif mode == "seg":
244
+ embeds = out.seg_embs
245
+ feats = get_seg_feats(mask_generator, oneformer, oneformer_processor, seg_teacher, image)
246
+
247
+ layer_diff = {}
248
+ for li, emb in enumerate(embeds):
+     emb = emb.to("cuda")
+     # cosine-embedding loss against the teacher feature (target +1, i.e. 1 - cosine similarity)
+     layer_diff[layers[li]] = torch.nn.CosineEmbeddingLoss(reduction="mean")(
+         emb.reshape(1, -1).float(), feats.reshape(1, -1).float(),
+         torch.ones(len(emb)).to(feats.device)
+     ).cpu().item()
+     ic(layer_diff[layers[li]])
256
+ diff_dict[fname.split("/")[-1]] = layer_diff
257
+
258
+ if img_idx % 200 == 0:
259
+ # Save progress intermittently
260
+ with open(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json", 'w') as f:
261
+ json.dump(diff_dict, f, indent=2)
262
+
263
+ img_idx += 1
264
+
265
+ with open(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json", 'w') as f:
266
+ json.dump(diff_dict, f, indent=2)
267
+
268
+ if __name__ == "__main__":
269
+ parser = argparse.ArgumentParser()
270
+ parser.add_argument("--model-path", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/llava-v1.5-7b")
271
+ parser.add_argument("--model-base", type=str, default=None)
272
+ parser.add_argument("--json-file", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/datasets/coco/annotations/captions_val2017.json")
273
+ parser.add_argument("--device", type=str, default="cuda")
274
+ parser.add_argument("--temperature", type=float, default=0.2)
275
+ parser.add_argument("--max-new-tokens", type=int, default=10)
276
+ parser.add_argument("--load-8bit", action="store_true")
277
+ parser.add_argument("--load-4bit", action="store_true")
278
+ parser.add_argument("--mode", type=str, default="gen")
279
+ parser.add_argument("--num-chunks", type=int, default=1)
280
+ parser.add_argument("--chunk-idx", type=int, default=0)
281
+ args = parser.parse_args()
282
+ predict(args)
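The per-layer probe score computed above is a cosine-embedding loss against the teacher feature with a target of +1, which equals 1 - cosine similarity: lower is better and 0.0 is a perfect match. A self-contained illustration with random tensors:

import torch

emb, feat = torch.randn(1, 1024), torch.randn(1, 1024)
loss_fn = torch.nn.CosineEmbeddingLoss(reduction="mean")

mismatch = loss_fn(emb, feat, torch.ones(1))   # close to 1.0 for unrelated random vectors
perfect = loss_fn(feat, feat, torch.ones(1))   # tensor(0.) when the embedding equals the feature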
ola_vlm/eval/get_all_stats.py ADDED
@@ -0,0 +1,132 @@
1
+ import json
2
+ import argparse
3
+ from icecream import ic
4
+ import os
5
+ import numpy as np
6
+
7
+
8
+ if __name__ == "__main__":
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--results_folder", type=str, default="./playground/data/eval/results")
11
+ parser.add_argument("--ckpt", type=str)
12
+ args = parser.parse_args()
13
+
14
+ scores = {}
15
+
16
+ dirs = os.listdir(f"{args.results_folder}/{args.ckpt}")
17
+ for dir in dirs:
18
+ if args.ckpt in dir and dir not in args.ckpt:
19
+ break
20
+
21
+
22
+ try:
23
+ with open(f"{args.results_folder}/{args.ckpt}/mmstar/merge_score.json", "r") as f:
24
+ data = json.load(f)
25
+ scores["MMStar"] = round(data.get("final score", 0)*100, 1) if data.get("final score") is not None else None
26
+ except:
27
+ scores["MMStar"] = None
28
+
29
+ cv_scores = {}
30
+
31
+ with open(f"{args.results_folder}/{args.ckpt}/cv-bench/merge_score.json", "r") as f:
32
+ data = json.load(f)
33
+ scores["CV-Bench"] = round(data.get("Overall", 0)*100, 1) if data.get("Overall") is not None else None
34
+ cv_scores["CV-Bench (2D)"] = round(data.get("2D", 0)*100, 1) if data.get("2D") is not None else None
35
+ cv_scores["CV-Bench (3D)"] = round(data.get("3D", 0)*100, 1) if data.get("3D") is not None else None
36
+ cv_scores["CV-Bench (Count)"] = round(data.get("Count", 0)*100, 1) if data.get("Count") is not None else None
37
+ cv_scores["CV-Bench (Depth)"] = round(data.get("Depth", 0)*100, 1) if data.get("Depth") is not None else None
38
+ cv_scores["CV-Bench (Relation)"] = round(data.get("Relation", 0)*100, 1) if data.get("Relation") is not None else None
39
+ cv_scores["CV-Bench (Distance)"] = round(data.get("Distance", 0)*100, 1) if data.get("Distance") is not None else None
40
+
41
+
42
+ with open(f"{args.results_folder}/{args.ckpt}/{dir}/results.json", "r") as f:
43
+ results = json.load(f).get("results", {})
44
+ # scores["MME-Cognition"] = round(results.get("mme", {}).get("mme_cognition_score,none", 0), 1) if results.get("mme", {}).get("mme_cognition_score,none") is not None else None
45
+ # scores["MME-Perception"] = round(results.get("mme", {}).get("mme_percetion_score,none", 0), 1) if results.get("mme", {}).get("mme_percetion_score,none") is not None else None
46
+
47
+ scores["Realworld-QA"] = round(results.get("realworldqa", {}).get("exact_match,flexible-extract", 0)*100, 1) if results.get("realworldqa", {}).get("exact_match,flexible-extract") is not None else None
48
+ scores["VizWiz-VQA-Val"] = round(results.get("vizwiz_vqa_val", {}).get("exact_match,none", 0)*100, 1) if results.get("vizwiz_vqa_val", {}).get("exact_match,none") is not None else None
49
+ # scores["SEEDBench-Image"] = round(results.get("seedbench", {}).get("seed_image,none", 0)*100, 1) if results.get("seedbench", {}).get("seed_image,none") is not None else None
50
+ # scores["VQAv2-Val"] = round(results.get("vqav2_val", {}).get("exact_match,none", 0)*100, 1) if results.get("vqav2_val", {}).get("exact_match,none") is not None else None
51
+
52
+ # scores["Science-QA-Img"] = round(results.get("scienceqa_img", {}).get("exact_match,none", 0)*100, 1) if results.get("scienceqa_img", {}).get("exact_match,none") is not None else None
53
+ scores["MMMU-Val"] = round(results.get("mmmu_val", {}).get("mmmu_acc,none", 0)*100, 1) if results.get("mmmu_val", {}).get("mmmu_acc,none") is not None else None
54
+ # scores["MMBench"] = round(results.get("mmbench_en_dev", {}).get("gpt_eval_score,none", 0), 1) if results.get("mmbench_en_dev", {}).get("gpt_eval_score,none") is not None else None
55
+
56
+ # scores["NaturalBench"] = round(results.get("naturalbench", {}).get("mme_score,none", 0)*100, 1) if results.get("naturalbench", {}).get("mme_score,none") is not None else None
57
+
58
+ # scores["GQA"] = round(results.get("gqa", {}).get("exact_match,none", 0)*100, 1) if results.get("gqa", {}).get("exact_match,none") is not None else None
59
+ scores["POPE"] = round(results.get("pope", {}).get("pope_accuracy,none", 0)*100, 1) if results.get("pope", {}).get("pope_accuracy,none") is not None else None
60
+ scores["MMVet"] = round(results.get("mmvet", {}).get("gpt_eval_score", 0)*100, 1) if results.get("mmvet", {}).get("gpt_eval_score") is not None else None
61
+ scores["OK-VQA"] = round(results.get("ok_vqa", {}).get("exact_match,none", 0)*100, 1) if results.get("ok_vqa", {}).get("exact_match,none") is not None else None
62
+ # scores["ChartQA"] = round(results.get("chartqa", {}).get("relaxed_overall,none", 0)*100, 1) if results.get("chartqa", {}).get("relaxed_overall,none") is not None else None
63
+ # scores["DocVQA"] = round(results.get("docvqa_val", {}).get("anls,none", 0)*100, 1) if results.get("docvqa_val", {}).get("anls,none") is not None else None
64
+ # scores["TextVQA"] = round(results.get("textvqa_val", {}).get("exact_match,none", 0)*100, 1) if results.get("textvqa_val", {}).get("exact_match,none") is not None else None
65
+
66
+ try:
67
+ with open(f"{args.results_folder}/{args.ckpt}/mmvp/merge_score.json", "r") as f:
68
+ data = json.load(f)
69
+ scores["MMVP"] = round(data.get("mmvp", 0)*100, 1) if data.get("mmvp") is not None else None
70
+ except:
71
+ scores["MMVP"] = None
72
+
73
+ keys = list(scores.keys())
74
+ str_scores = [str(scores[key]) if scores[key] is not None else 'None' for key in keys]
75
+
76
+ abl_keys = ["CV-Bench", "MMStar", "VizWiz-VQA-Val", "MMVet", "MMVP", "MMMU-Val"]
77
+
78
+ abl_scores = [scores[key] for key in abl_keys if scores[key] is not None]
79
+
80
+ small_abl_keys = ["CV-Bench", "MMStar", "OK-VQA", "MMMU-Val"]
81
+ small_abl_scores = [scores[key] for key in small_abl_keys if scores[key] is not None]
82
+
83
+ cv_bench_keys = ["CV-Bench (2D)", "CV-Bench (3D)", "CV-Bench (Count)", "CV-Bench (Depth)", "CV-Bench (Relation)", "CV-Bench (Distance)"]
84
+ cv_bench_scores = [cv_scores[key] for key in cv_bench_keys if cv_scores[key] is not None]
85
+
86
+ # cat_scores = {}
87
+ # if os.path.exists(f"{args.results_folder}/{args.ckpt}/categorized_scores.json"):
88
+ # with open(f"{args.results_folder}/{args.ckpt}/categorized_scores.json", "r") as f:
89
+ # cat_scores = json.load(f)
90
+ # cat_scores.pop("Both")
91
+
92
+ print("\n====================All-Scores===========================================")
93
+ print(" & ".join(keys))
94
+ print(" & ".join(str_scores))
95
+ if abl_scores:
96
+ print("\n====================Abl-Scores===========================================")
97
+ print(" & ".join(abl_keys))
98
+ print(" & ".join([str(a) for a in abl_scores]))
99
+ print(f"Ablation Avg: {round(np.mean(abl_scores), 1)}")
100
+ else:
101
+ print("Ablation Avg: None")
102
+
103
+ if small_abl_scores:
104
+ print("\n====================Small-Abl-Scores===========================================")
105
+ print(" & ".join(small_abl_keys))
106
+ print(" & ".join([str(a) for a in small_abl_scores]))
107
+ print(f"Small-Ablation Avg: {round(np.mean(small_abl_scores), 1)}")
108
+ else:
109
+ print("Small-Ablation Avg: None")
110
+
111
+ if cv_bench_scores:
112
+ print("\n====================CV-Bench-Scores===========================================")
113
+ print(" & ".join(cv_bench_keys))
114
+ print(" & ".join([str(c) for c in cv_bench_scores]))
115
+ print(f"CV-Bench Overall: {round(np.mean(cv_bench_scores[:2]), 1)}")
116
+ else:
117
+ print("CV-Bench Avg: None")
118
+
119
+ # if cat_scores is not None:
120
+ # print("\n====================Categorized-Scores===========================================")
121
+ # cats = []
122
+ # class_scores = []
123
+ # benches = []
124
+ # for k, v in cat_scores.items():
125
+ # cats.append(k)
126
+ # for bench, score in v.items():
127
+ # benches.append(bench)
128
+ # class_scores.append(round(score*100, 1))
129
+ # print(" & ".join(cats))
130
+ # print(" & ".join(benches))
131
+ # print(" & ".join([str(c) for c in class_scores]))
132
+ # print("================================================================")
ola_vlm/eval/get_probe_task_scores.py ADDED
@@ -0,0 +1,197 @@
1
+ import argparse
2
+ import torch
3
+ from PIL import Image
4
+ import json
5
+ import os
6
+ from tqdm import tqdm
7
+ import warnings
8
+ import random
9
+ import numpy as np
10
+ import multiprocessing as mp
11
+ from ola_vlm.eval.probe_metrics.fid_score import compute_fid
12
+ from analyze.analyze_utils import prepare_coco, prepare_da2k, parse_json
13
+ from multiprocessing import Pool
14
+ warnings.filterwarnings("ignore")
15
+
16
+ def set_seed(seed):
17
+ random.seed(seed)
18
+ np.random.seed(seed)
19
+ torch.manual_seed(seed)
20
+ torch.cuda.manual_seed_all(seed)
21
+
22
+ def load_image(image_file):
23
+ image = Image.open(image_file)
24
+ return image
25
+
26
+ def mask_iou(gt, pred):
27
+ gt = np.array(gt).astype(np.uint8)
28
+ pred = np.array(pred).astype(np.uint8)
29
+
30
+ iou_scores = []
31
+ for category in np.unique(gt):
32
+ if category == 255:
33
+ continue
34
+ gt_mask = (gt == category)
35
+ pred_mask = (pred == category)
36
+
37
+ intersection = np.logical_and(gt_mask, pred_mask)
38
+ union = np.logical_or(gt_mask, pred_mask)
39
+ if np.sum(union) == 0:
40
+ iou_scores.append(1.0)
41
+ else:
42
+ iou_scores.append(np.sum(intersection) / np.sum(union))
43
+
44
+ return np.mean(iou_scores)
45
+
46
+ def load_json(path):
47
+ with open(path) as f:
48
+ data = json.load(f)
49
+ return data
50
+
51
+ # Helper function for multiprocessing in evaluate_seg
52
+ def process_iou(args):
53
+ gt_path, layer_folder, dir, fname = args
54
+ gt_data = load_image(os.path.join(gt_path, fname.replace("jpg", "png")))
55
+ pred = load_image(os.path.join(layer_folder, dir, fname))
56
+ return mask_iou(gt_data, pred)
57
+
58
+ def evaluate_seg(args):
59
+ images, _, _ = prepare_coco("/mnt/vlpdatasets/coco/annotations/captions_val2017.json")
60
+ fnames = [img.split("/")[-1] for img in images][:8]
61
+
62
+ name = args.ckpt
63
+ gt_path = "/mnt/vlpdatasets/sherlock/eval/coco/annotations/panoptic_semseg_val2017"
64
+ layer_folder = f"plots/probes_task/{name}/seg"
65
+
66
+ scores = {"m_iou": []}
67
+ dirs = os.listdir(layer_folder)
68
+
69
+ with mp.Pool() as pool:
70
+ for dir in dirs:
71
+ print(f"Evaluating mask iou for {dir}")
72
+ args_list = [(gt_path, layer_folder, dir, fname) for fname in fnames]
73
+ m_iou = list(tqdm(pool.imap(process_iou, args_list), total=len(args_list), desc=f"Processing {dir}"))
74
+ scores["m_iou"].append({dir: round(np.mean(m_iou) * 100, 2)})
75
+
76
+ return scores
77
+
78
+ # Helper function for multiprocessing in evaluate_depth
79
+ def process_depth(args):
80
+ depth_map, point_1, point_2, answer = args
81
+ return score_points(depth_map, point_1, point_2, answer)
82
+
83
+ def score_points(depth_map, point_1, point_2, answer):
84
+ pt1_depth = depth_map[point_1[0], point_1[1]]
85
+ pt2_depth = depth_map[point_2[0], point_2[1]]
86
+
87
+ if isinstance(pt1_depth, np.ndarray):
88
+ pt1_depth = pt1_depth.mean()
89
+ if isinstance(pt2_depth, np.ndarray):
90
+ pt2_depth = pt2_depth.mean()
91
+
92
+ return (answer == "point2") if pt1_depth < pt2_depth else (answer == "point1")
93
+
94
+ def load_and_process_image(args):
95
+ folder, fname, entry = args
96
+ gt_path = os.path.join("/mnt/vlpdatasets/sherlock/plots/dav2_da2k", fname.split("/")[-1].split(".")[0] + ".jpg")
97
+ pred_path = os.path.join(folder, fname.split("/")[-1])
98
+
99
+ gt = load_image(gt_path)
100
+ pred = load_image(pred_path)
101
+ pred = pred.resize(gt.size)
102
+ pred = np.array(pred) / 255.0
103
+
104
+ # Process depth for each entry within the image
105
+ return [process_depth((pred, entry["point1"], entry["point2"], entry["closer_point"])) for entry in entry["entries"]]
106
+
107
+ def score_da2k_parallel(folder, anns):
108
+ pred_scores = []
109
+ tasks = [(folder, fname, {"entries": entries}) for fname, entries in anns.items()]
110
+
111
+ with Pool() as pool:
112
+ results = list(tqdm(pool.imap(load_and_process_image, tasks), total=len(tasks), desc="Processing images"))
113
+ for res in results:
114
+ if res is not None:
115
+ pred_scores.extend(res)
116
+
117
+ return np.mean(pred_scores) if pred_scores else 0
118
+
119
+ def evaluate_depth(args):
120
+ anns = parse_json("/mnt/vlpdatasets/sherlock/eval/DA-2K/DA-2K/annotations.json")
121
+
122
+ name = args.ckpt
123
+ layer_folder = f"plots/probes_task/{name}/depth"
124
+
125
+ scores = {"da2k_acc": []}
126
+ dirs = os.listdir(layer_folder)
127
+
128
+ for dir in dirs:
129
+ print(f"Evaluating da2k_acc for {dir}")
130
+ pred_scores = score_da2k_parallel(os.path.join(layer_folder, dir), anns)
131
+ scores["da2k_acc"].append({dir: round(pred_scores * 100, 2)})
132
+
133
+ return scores
134
+
135
+ def evaluate_fid(args):
136
+ name = args.ckpt
137
+ gt_path = os.path.join("plots/coco_gt")
138
+ layer_folder = f"plots/probes_task/{name}/gen"
139
+
140
+ scores = {"fid": []}
141
+ dirs = os.listdir(layer_folder)
142
+
143
+ for dir in dirs:
144
+ print(f"Evaluating fid for {dir}")
145
+ paths = [gt_path, os.path.join(layer_folder, dir)]
146
+ fid_score = compute_fid(paths)
147
+ scores["fid"].append({dir.replace("_", "-"): round(fid_score, 2)})
148
+
149
+ return scores
150
+
151
+ import re
152
+
153
+ def print_sorted_scores(scores, metric_name):
154
+ # Extract numeric part from layer names for sorting
155
+ sorted_scores = sorted(scores[metric_name], key=lambda x: int(re.search(r'\d+', list(x.keys())[0]).group()))
156
+
157
+ layers = [list(score.keys())[0] for score in sorted_scores]
158
+ values = [list(score.values())[0] for score in sorted_scores]
159
+
160
+ # Print sorted layers and scores in the requested format
161
+ print("\n=========================Results===============================")
162
+ print(" & ".join(layers))
163
+ print(" & ".join([f"{value}" for value in values]))
164
+ print(f"Average score: {round(np.mean(values), 2)}")
165
+ print("================================================================")
166
+
167
+ if __name__ == "__main__":
168
+ parser = argparse.ArgumentParser()
169
+ parser.add_argument("--ckpt", type=str, default="llava-1.5-7b")
170
+ parser.add_argument("--mode", type=str, default="gen")
171
+ args = parser.parse_args()
172
+
173
+ mode = args.mode
174
+
175
+ if mode == "gen":
176
+ scores = evaluate_fid(args)
177
+
178
+ print("\n=========================FID-Scores===============================")
179
+ for score in scores["fid"]:
180
+ for key, value in score.items():
181
+ print(f"{key} -> {value}")
182
+ print("================================================================")
183
+
184
+ elif mode == "seg":
185
+ scores = evaluate_seg(args)
186
+
187
+ print("\n=========================Mask-IOU===============================")
188
+ print_sorted_scores(scores, "m_iou")
189
+
190
+ elif mode == "depth":
191
+ scores = evaluate_depth(args)
192
+
193
+ print("\n=========================DA2K-Acc===============================")
194
+ print_sorted_scores(scores, "da2k_acc")
195
+
196
+ else:
197
+ print("Invalid mode. Choose from [gen, seg, depth]")
ola_vlm/eval/get_sherlock_dsg_scores.py ADDED
@@ -0,0 +1,49 @@
1
+ import argparse
2
+ import torch
3
+
4
+ import json
5
+ import os
6
+ from tqdm import tqdm
7
+ from icecream import ic
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+ import random
11
+ import numpy as np
12
+
13
+
14
+ def set_seed(seed):
15
+ random.seed(seed)
16
+ np.random.seed(seed)
17
+ torch.manual_seed(seed)
18
+ torch.cuda.manual_seed_all(seed)
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--ckpt", type=str, default="llava-1.5-7b")
23
+ parser.add_argument("--mode", type=str, default="gen")
24
+ args = parser.parse_args()
25
+
26
+ mode = args.mode
27
+ name = args.ckpt.split("/")[-1]
28
+
29
+ with open(f'plots/probe_scores/{name}/{args.mode}.json') as file:
30
+ scores = json.load(file)
31
+
32
+ layer_scores = {}
33
+
34
+ for img, v in tqdm(scores.items()):
35
+ for layer, score in v.items():
36
+ if layer not in layer_scores:
37
+ layer_scores[layer] = []
38
+ layer_scores[layer].append(score)
39
+
40
+ for layer, scores in layer_scores.items():
41
+ layer_scores[layer] = np.mean(scores)
42
+
43
+ with open(f"plots/probe_scores/{name}/{mode}_scores.json", "w") as f:
44
+ json.dump(layer_scores, f, indent=2)
45
+
46
+ print(f"================Scores: {mode}===============")
47
+ for layer, score in layer_scores.items():
48
+ print(f"Layer: {layer}, Score: {score}")
49
+ print("===========================================")
ola_vlm/eval/merge_json.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ parser = argparse.ArgumentParser(
6
+ description='Probe eval')
7
+ parser.add_argument('--ckpt',
8
+ help='ckpt',
9
+ default='probe_llava-1.5-vicuna-7b-lr-1e-3')
10
+ parser.add_argument('--mode',
11
+ help='mode',
12
+ default='gen')
13
+ parser.add_argument("--num-chunks", type=int, default=1)
14
+
15
+
16
+ def save_merged_json(data, output_file):
17
+ with open(output_file, 'w') as file:
18
+ json.dump(data, file, indent=4)
19
+
20
+ if __name__ == "__main__":
21
+ args = parser.parse_args()
22
+ merge_data = {}
23
+ name = args.ckpt.split("/")[-1]
24
+
25
+ for i in range(args.num_chunks):
26
+ with open(f'plots/probe_scores/{name}/{args.mode}/{args.num_chunks}_{i}.json', 'r') as file:
27
+ data = json.load(file)
28
+ merge_data.update(data)
29
+
30
+ save_merged_json(merge_data, f'plots/probe_scores/{name}/{args.mode}.json')
ola_vlm/eval/mmstar/evaluate/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .mmstar import MMStar_eval
ola_vlm/eval/mmstar/evaluate/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes).
 
ola_vlm/eval/mmstar/evaluate/__pycache__/mmstar.cpython-310.pyc ADDED
Binary file (2.45 kB).
 
ola_vlm/eval/mmstar/evaluate/mmstar.py ADDED
@@ -0,0 +1,87 @@
1
+ from ola_vlm.eval.mmstar.smp import *
2
+ from copy import deepcopy
3
+
4
+
5
+ def MMStar_eval(eval_file):
6
+ MMStar_score_l2 = {
7
+ 'coarse perception': {
8
+ 'image scene and topic': 0,
9
+ 'image style & quality': 0,
10
+ 'image emotion': 0
11
+ },
12
+ 'fine-grained perception': {
13
+ 'object counting': 0,
14
+ 'recognition': 0,
15
+ 'localization': 0
16
+ },
17
+ 'instance reasoning': {
18
+ 'single-instance reasoning': 0,
19
+ 'cross-instance attribute reasoning': 0,
20
+ 'cross-instance relation reasoning': 0
21
+ },
22
+ 'logical reasoning': {
23
+ 'code & sequence reasoning': 0,
24
+ 'diagram reasoning': 0,
25
+ 'common reasoning': 0
26
+ },
27
+ 'science & technology': {
28
+ 'biology & chemistry & physics': 0,
29
+ 'electronics & energy & mechanical eng.': 0,
30
+ 'geography & earth science & agriculture': 0
31
+ },
32
+ 'math': {
33
+ 'geometry': 0,
34
+ 'numeric commonsense and calculation': 0,
35
+ 'statistical reasoning': 0
36
+ },
37
+ }
38
+ MMStar_counter = deepcopy(MMStar_score_l2)
39
+ logger = get_logger('Evaluation')
40
+
41
+ data = load(eval_file)
42
+ lt = len(data)
43
+ lines = [data[i] for i in range(lt)]
44
+ for i in tqdm(range(len(lines))):
45
+ line = lines[i]
46
+ predict = str(line['prediction'])
47
+ answers = str(line['answer'])
48
+ category = str(line['category'])
49
+ l2_category = str(line['l2_category'])
50
+ MMStar_counter[category][l2_category] += 1
51
+
52
+ answer = answers.lower().strip().replace('\n', ' ')
53
+ predict = predict.lower().strip().replace('\n', ' ')
54
+
55
+ try:
56
+ if answer == predict[0]:
57
+ MMStar_score_l2[category][l2_category] += 1
58
+ elif predict[0] == '(' and answer == predict[1]:
59
+ MMStar_score_l2[category][l2_category] += 1
60
+ elif predict[0:7] == 'option ' and answer == predict[7]:
61
+ MMStar_score_l2[category][l2_category] += 1
62
+ elif predict[0:14] == 'the answer is ' and answer == predict[14]:
63
+ MMStar_score_l2[category][l2_category] += 1
64
+ except Exception as e:
65
+ pass
66
+
67
+ MMStar_score = {}
68
+ MMStar_score['final score'] = 0
69
+ for k, v in MMStar_score_l2.items():
70
+ MMStar_score[k] = 0
71
+ for l2_k, l2_v in v.items():
72
+ MMStar_score[f'{k}({l2_k})'] = float(l2_v) / \
73
+ float(MMStar_counter[k][l2_k])
74
+ MMStar_score[k] += l2_v
75
+ MMStar_score['final score'] += MMStar_score[k]
76
+ MMStar_score[k] = float(MMStar_score[k]) / 250.0
77
+ MMStar_score['final score'] = float(MMStar_score['final score']) / 1500.0
78
+
79
+ score_pth = eval_file.replace('.jsonl', '_score.json')
80
+ dump(MMStar_score, score_pth)
81
+ logger.info(
82
+ f'MMStar_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
83
+ logger.info('Score: ')
84
+ for key, value in MMStar_score.items():
85
+ logger.info('{}:{}'.format(key, value))
86
+
87
+ return MMStar_score
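For reference, the shape of one record that MMStar_eval reads from the results JSONL (field values invented). With this record the second matching rule fires: the lower-cased prediction starts with '(' and its second character equals the answer.

record = {
    "prediction": "(B) a cat sitting on a sofa",
    "answer": "B",
    "category": "coarse perception",
    "l2_category": "image scene and topic",
}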
ola_vlm/eval/mmstar/smp/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .file import *
2
+ from .misc import *
3
+ from .log import *
ola_vlm/eval/mmstar/smp/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (188 Bytes).
 
ola_vlm/eval/mmstar/smp/__pycache__/file.cpython-310.pyc ADDED
Binary file (7.12 kB).
 
ola_vlm/eval/mmstar/smp/__pycache__/log.cpython-310.pyc ADDED
Binary file (1.02 kB).
 
ola_vlm/eval/mmstar/smp/__pycache__/misc.cpython-310.pyc ADDED
Binary file (5.18 kB).
 
ola_vlm/eval/mmstar/smp/__pycache__/vlm.cpython-310.pyc ADDED
Binary file (4.99 kB).
 
ola_vlm/eval/mmstar/smp/file.py ADDED
@@ -0,0 +1,147 @@
1
+ import csv
2
+ import hashlib
3
+ import json
4
+ import os
5
+ import os.path as osp
6
+ import pickle
7
+ import time
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+
13
+ class NumpyEncoder(json.JSONEncoder):
14
+ def default(self, obj):
15
+ if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
16
+ np.int16, np.int32, np.int64, np.uint8,
17
+ np.uint16, np.uint32, np.uint64)):
18
+ return int(obj)
19
+ elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
20
+ return float(obj)
21
+ elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
22
+ return {'real': obj.real, 'imag': obj.imag}
23
+ elif isinstance(obj, (np.ndarray,)):
24
+ return obj.tolist()
25
+ elif isinstance(obj, (np.bool_)):
26
+ return bool(obj)
27
+ elif isinstance(obj, (np.void)):
28
+ return None
29
+ return json.JSONEncoder.default(self, obj)
30
+
31
+ # LOAD & DUMP
32
+ def dump(data, f, **kwargs):
33
+ def dump_pkl(data, pth, **kwargs):
34
+ pickle.dump(data, open(pth, 'wb'))
35
+
36
+ def dump_json(data, pth, **kwargs):
37
+ json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder)
38
+
39
+ def dump_jsonl(data, f, **kwargs):
40
+ lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
41
+ with open(f, 'w', encoding='utf8') as fout:
42
+ fout.write('\n'.join(lines))
43
+
44
+ def dump_xlsx(data, f, **kwargs):
45
+ data.to_excel(f, index=False, engine='xlsxwriter')
46
+
47
+ def dump_csv(data, f, quoting=csv.QUOTE_ALL):
48
+ data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
49
+
50
+ def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
51
+ data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
52
+
53
+ handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
54
+ suffix = f.split('.')[-1]
55
+ return handlers[suffix](data, f, **kwargs)
56
+
57
+ def load(f):
58
+ def load_pkl(pth):
59
+ return pickle.load(open(pth, 'rb'))
60
+
61
+ def load_json(pth):
62
+ return json.load(open(pth, 'r', encoding='utf-8'))
63
+
64
+ def load_jsonl(f):
65
+ lines = open(f, encoding='utf-8').readlines()
66
+ lines = [x.strip() for x in lines]
67
+ if lines[-1] == '':
68
+ lines = lines[:-1]
69
+ data = [json.loads(x) for x in lines]
70
+ return data
71
+
72
+ def load_xlsx(f):
73
+ return pd.read_excel(f)
74
+
75
+ def load_csv(f):
76
+ return pd.read_csv(f)
77
+
78
+ def load_tsv(f):
79
+ return pd.read_csv(f, sep='\t')
80
+
81
+ handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
82
+ suffix = f.split('.')[-1]
83
+ return handlers[suffix](f)
84
+
85
+ def download_file(url, filename=None):
86
+ import urllib.request
87
+
88
+ from tqdm import tqdm
89
+
90
+ class DownloadProgressBar(tqdm):
91
+ def update_to(self, b=1, bsize=1, tsize=None):
92
+ if tsize is not None:
93
+ self.total = tsize
94
+ self.update(b * bsize - self.n)
95
+
96
+ if filename is None:
97
+ filename = url.split('/')[-1]
98
+
99
+ with DownloadProgressBar(unit='B', unit_scale=True,
100
+ miniters=1, desc=url.split('/')[-1]) as t:
101
+ urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
102
+ return filename
103
+
104
+ def ls(dirname='.', match='', mode='all', level=1):
105
+ if dirname == '.':
106
+ ans = os.listdir(dirname)
107
+ else:
108
+ ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
109
+ assert mode in ['all', 'dir', 'file']
110
+ assert level >= 1 and isinstance(level, int)
111
+ if level == 1:
112
+ ans = [x for x in ans if match in x]
113
+ if mode == 'dir':
114
+ ans = [x for x in ans if osp.isdir(x)]
115
+ elif mode == 'file':
116
+ ans = [x for x in ans if not osp.isdir(x)]
117
+ else:
118
+ ans = [x for x in ans if osp.isdir(x)]
119
+ res = []
120
+ for d in ans:
121
+ res.extend(ls(d, match=match, mode=mode, level=level-1))
122
+ ans = res
123
+ return ans
124
+
125
+ def mrlines(fname, sp='\n'):
126
+ f = open(fname).read().split(sp)
127
+ while f != [] and f[-1] == '':
128
+ f = f[:-1]
129
+ return f
130
+
131
+ def mwlines(lines, fname):
132
+ with open(fname, 'w') as fout:
133
+ fout.write('\n'.join(lines))
134
+
135
+ def md5(file_pth):
136
+ with open(file_pth, 'rb') as f:
137
+ hash = hashlib.new('md5')
138
+ for chunk in iter(lambda: f.read(2**20), b''):
139
+ hash.update(chunk)
140
+ return str(hash.hexdigest())
141
+
142
+ def last_modified(pth):
143
+ stamp = osp.getmtime(pth)
144
+ m_ti = time.ctime(stamp)
145
+ t_obj = time.strptime(m_ti)
146
+ t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:]
147
+ return t
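dump and load above dispatch purely on the file suffix, so callers stay format-agnostic; a short usage sketch with an invented file name:

from ola_vlm.eval.mmstar.smp.file import dump, load

dump({"final score": 0.432}, "mmstar_score.json")   # routed to dump_json by the ".json" suffix
scores = load("mmstar_score.json")                  # routed to load_json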
ola_vlm/eval/mmstar/smp/log.py ADDED
@@ -0,0 +1,43 @@
1
+ import logging
2
+
3
+ logger_initialized = {}
4
+
5
+ def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
6
+ logger = logging.getLogger(name)
7
+ if name in logger_initialized:
8
+ return logger
9
+
10
+ for logger_name in logger_initialized:
11
+ if name.startswith(logger_name):
12
+ return logger
13
+
14
+ stream_handler = logging.StreamHandler()
15
+ handlers = [stream_handler]
16
+
17
+ try:
18
+ import torch.distributed as dist
19
+ if dist.is_available() and dist.is_initialized():
20
+ rank = dist.get_rank()
21
+ else:
22
+ rank = 0
23
+ except ImportError:
24
+ rank = 0
25
+
26
+ if rank == 0 and log_file is not None:
27
+ file_handler = logging.FileHandler(log_file, file_mode)
28
+ handlers.append(file_handler)
29
+
30
+ formatter = logging.Formatter(
31
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
32
+ for handler in handlers:
33
+ handler.setFormatter(formatter)
34
+ handler.setLevel(log_level)
35
+ logger.addHandler(handler)
36
+
37
+ if rank == 0:
38
+ logger.setLevel(log_level)
39
+ else:
40
+ logger.setLevel(logging.ERROR)
41
+
42
+ logger_initialized[name] = True
43
+ return logger
ola_vlm/eval/mmstar/smp/misc.py ADDED
@@ -0,0 +1,174 @@
1
+ # flake8: noqa: F401, F403
2
+ import abc
3
+ import argparse
4
+ import copy as cp
5
+ import csv
6
+ import datetime
7
+ import multiprocessing as mp
8
+ import os
9
+ import os.path as osp
10
+ import random as rd
11
+ import shutil
12
+ import subprocess
13
+ import warnings
14
+ from collections import OrderedDict, defaultdict
15
+ from multiprocessing import Pool, current_process
16
+
17
+ import matplotlib.pyplot as plt
18
+ import pandas as pd
19
+ import requests
20
+ import seaborn as sns
21
+ from huggingface_hub import scan_cache_dir
22
+ from sty import bg, ef, fg, rs
23
+ from tabulate import tabulate, tabulate_formats
24
+ from tqdm import tqdm
25
+
26
+
27
+ def process_punctuation(inText):
28
+ import re
29
+ outText = inText
30
+ punct = [
31
+ ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
32
+ '>', '<', '@', '`', ',', '?', '!'
33
+ ]
34
+ commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605
35
+ periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605
36
+ for p in punct:
37
+ if (p + ' ' in inText or ' ' + p in inText) or (re.search(
38
+ commaStrip, inText) is not None):
39
+ outText = outText.replace(p, '')
40
+ else:
41
+ outText = outText.replace(p, ' ')
42
+ outText = periodStrip.sub('', outText, re.UNICODE)
43
+ return outText
44
+
45
+
46
+ def h2r(value):
47
+ if value[0] == '#':
48
+ value = value[1:]
49
+ assert len(value) == 6
50
+ return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
51
+
52
+
53
+ def r2h(rgb):
54
+ return '#%02x%02x%02x' % rgb
55
+
56
+
57
+ def colored(s, color):
58
+ if isinstance(color, str):
59
+ if hasattr(fg, color):
60
+ return getattr(fg, color) + s + fg.rs
61
+ color = h2r(color)
62
+ return fg(*color) + s + fg.rs
63
+
64
+
65
+ def istype(s, type):
66
+ if isinstance(s, type):
67
+ return True
68
+ try:
69
+ return isinstance(eval(s), type)
70
+ except Exception as _:
71
+ return False
72
+
73
+
74
+ def bincount(lst):
75
+ bins = defaultdict(lambda: 0)
76
+ for item in lst:
77
+ bins[item] += 1
78
+ return bins
79
+
80
+
81
+ def get_cache_path(repo_id):
82
+ hf_cache_info = scan_cache_dir()
83
+ repos = list(hf_cache_info.repos)
84
+ repo = None
85
+ for r in repos:
86
+ if r.repo_id == repo_id:
87
+ repo = r
88
+ break
89
+ if repo is None:
90
+ return None
91
+ revs = list(repo.revisions)
92
+ rev2keep, last_modified = None, 0
93
+ for rev in revs:
94
+ if rev.last_modified > last_modified:
95
+ rev2keep, last_modified = rev, rev.last_modified
96
+ if rev2keep is None:
97
+ return None
98
+ return str(rev2keep.snapshot_path)
99
+
100
+
101
+ def proxy_set(s):
102
+ import os
103
+ for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
104
+ os.environ[key] = s
105
+
106
+
107
+ def get_rank_and_world_size():
108
+ local_rank = int(os.environ.get("RANK", 0))
109
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
110
+ return local_rank, world_size
111
+
112
+
113
+ def get_local_rank_and_world_size():
114
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
115
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
116
+ return local_rank, world_size
117
+
118
+
119
+ def splitlen(s, sym='/'):
120
+ return len(s.split(sym))
121
+
122
+
123
+ def listinstr(lst, s):
124
+ assert isinstance(lst, list)
125
+ for item in lst:
126
+ if item in s:
127
+ return True
128
+ return False
129
+
130
+
131
+ def d2df(D):
132
+ return pd.DataFrame({x: [D[x]] for x in D})
133
+
134
+
135
+ def cn_string(s):
136
+ import re
137
+ if re.search(u'[\u4e00-\u9fff]', s):
138
+ return True
139
+ return False
140
+
141
+
142
+ try:
143
+ import decord
144
+ except ImportError:
145
+ pass
146
+
147
+
148
+ def timestr(second=True, minute=False):
149
+ s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]
150
+ if second:
151
+ return s
152
+ elif minute:
153
+ return s[:-2]
154
+ else:
155
+ return s[:-4]
156
+
157
+
158
+ def dict_merge(dct, merge_dct):
159
+ for k, _ in merge_dct.items():
160
+ if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): # noqa
161
+ dict_merge(dct[k], merge_dct[k])
162
+ else:
163
+ dct[k] = merge_dct[k]
164
+
165
+
166
+ def youtube_dl(idx):
167
+ cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
168
+ os.system(cmd)
169
+
170
+
171
+ def run_command(cmd):
172
+ if isinstance(cmd, str):
173
+ cmd = cmd.split()
174
+ return subprocess.check_output(cmd)
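A few of the small helpers above in action (a minimal sketch; importing misc.py pulls in the heavier dependencies listed at its top, e.g. pandas, seaborn and sty, which are assumed to be installed):

from ola_vlm.eval.mmstar.smp.misc import listinstr, bincount, d2df, timestr

listinstr(["cat", "dog"], "the dog sat down")   # True: "dog" occurs in the string
bincount(["A", "B", "A"])                       # per-item counts, e.g. {"A": 2, "B": 1}
d2df({"model": "ola-vlm", "score": 0.5})        # one-row pandas DataFrame
timestr()                                       # 'YYMMDDHHMMSS'-style timestamp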
ola_vlm/eval/model_cvbench_loader.py ADDED
@@ -0,0 +1,166 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from ola_vlm.conversation import conv_templates, SeparatorStyle
10
+ from ola_vlm.model.builder import load_pretrained_model
11
+ from ola_vlm.utils import disable_torch_init
12
+ from ola_vlm.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from datasets import load_dataset
15
+ from PIL import Image
16
+ import math
17
+
18
+
19
+ def split_list(lst, n):
20
+ """Split a list into n (roughly) equal-sized chunks"""
21
+ chunk_size = math.ceil(len(lst) / n) # ceiling division
22
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23
+
24
+
25
+ def get_chunk(lst, n, k):
26
+ chunks = split_list(lst, n)
27
+ return chunks[k]
28
+
29
+ def load_jsonl(f):
30
+ lines = open(f, encoding='utf-8').readlines()
31
+ lines = [x.strip() for x in lines]
32
+ if lines[-1] == '':
33
+ lines = lines[:-1]
34
+ data = [json.loads(x) for x in lines]
35
+ return data
36
+
37
+ def prepare_CVBench(path):
38
+ dataset = load_jsonl(os.path.join(path, 'test.jsonl'))
39
+ data = []
40
+ for i in range(len(dataset)):
41
+ d = {
42
+ "image": os.path.join(path, dataset[i]["filename"]),
43
+ "question": dataset[i]["prompt"] + "\nOnly answer the option as the output. For example, if your answer is the option A, answer (A).",
44
+ "answer": dataset[i]["answer"],
45
+ "task": dataset[i]["task"],
46
+ "source": dataset[i]["source"]
47
+ }
48
+ data.append(d)
49
+ return data
50
+
51
+
52
+ # Custom dataset class
53
+ class CustomDataset(Dataset):
54
+ def __init__(self, data, tokenizer, image_processor, model_config):
55
+ self.questions = data
56
+ self.tokenizer = tokenizer
57
+ self.image_processor = image_processor
58
+ self.model_config = model_config
59
+
60
+ def __getitem__(self, index):
61
+ d = self.questions[index]
62
+ qs = d["question"]
63
+ image_file = d["image"]
64
+ ans = d["answer"]
65
+ task = d["task"]
66
+ source = d["source"]
67
+
68
+ if self.model_config.mm_use_im_start_end:
69
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
70
+ else:
71
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
72
+
73
+ conv = conv_templates[args.conv_mode].copy()
74
+ conv.append_message(conv.roles[0], qs)
75
+ conv.append_message(conv.roles[1], None)
76
+ prompt = conv.get_prompt()
77
+
78
+ image = Image.open(image_file).convert('RGB')
79
+ image_tensor = process_images([image], self.image_processor, self.model_config)[0]
80
+
81
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
82
+
83
+ return input_ids, image_tensor, image.size, ans, task, source
84
+
85
+ def __len__(self):
86
+ return len(self.questions)
87
+
88
+
89
+ def collate_fn(batch):
90
+ input_ids, image_tensors, image_sizes, answers, cats, cats_l2 = zip(*batch)
91
+ input_ids = torch.stack(input_ids, dim=0)
92
+ image_tensors = torch.stack(image_tensors, dim=0)
93
+ return input_ids, image_tensors, image_sizes, answers, cats, cats_l2
94
+
95
+
96
+ # DataLoader
97
+ def create_data_loader(questions, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
98
+ assert batch_size == 1, "batch_size must be 1"
99
+ dataset = CustomDataset(questions, tokenizer, image_processor, model_config)
100
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
101
+ return data_loader
102
+
103
+
104
+ def eval_model(args):
105
+ # Model
106
+ disable_torch_init()
107
+ model_path = os.path.expanduser(args.model_path)
108
+ model_name = get_model_name_from_path(model_path)
109
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
110
+
111
+ questions = prepare_CVBench(args.path)
112
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
113
+ answers_file = os.path.expanduser(args.answers_file)
114
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
115
+ ans_file = open(answers_file, "w")
116
+
117
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
118
+ args.conv_mode = args.conv_mode + '_mmtag'
119
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
120
+
121
+ data_loader = create_data_loader(questions, tokenizer, image_processor, model.config)
122
+
123
+ for (input_ids, image_tensor, image_sizes, answer, task, source), line in tqdm(zip(data_loader, questions), total=len(questions)):
124
+ input_ids = input_ids.to(device='cuda', non_blocking=True)
125
+
126
+ with torch.inference_mode():
127
+ output_ids = model.generate(
128
+ input_ids,
129
+ images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
130
+ image_sizes=image_sizes,
131
+ do_sample=True if args.temperature > 0 else False,
132
+ temperature=args.temperature,
133
+ top_p=args.top_p,
134
+ num_beams=args.num_beams,
135
+ max_new_tokens=args.max_new_tokens,
136
+ use_cache=True)
137
+
138
+ if not isinstance(output_ids, torch.Tensor):
139
+ output_ids = output_ids.sequences
140
+
141
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
142
+
143
+ ans_file.write(json.dumps({"prediction": outputs,
144
+ "answer": answer,
145
+ "question": line,
146
+ "source": source,
147
+ "task": task}) + "\n")
148
+ # ans_file.flush()
149
+ ans_file.close()
150
+
151
+ if __name__ == "__main__":
152
+ parser = argparse.ArgumentParser()
153
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
154
+ parser.add_argument("--model-base", type=str, default=None)
155
+ parser.add_argument("--path", type=str, default="CV-Bench")
156
+ parser.add_argument("--answers-file", type=str, default="cv-bench_answer.jsonl")
157
+ parser.add_argument("--conv-mode", type=str, default="llava_phi_3")
158
+ parser.add_argument("--num-chunks", type=int, default=1)
159
+ parser.add_argument("--chunk-idx", type=int, default=0)
160
+ parser.add_argument("--temperature", type=float, default=0.2)
161
+ parser.add_argument("--top_p", type=float, default=None)
162
+ parser.add_argument("--num_beams", type=int, default=1)
163
+ parser.add_argument("--max_new_tokens", type=int, default=128)
164
+ args = parser.parse_args()
165
+
166
+ eval_model(args)
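The --num-chunks/--chunk-idx pair shards the benchmark across processes via split_list and get_chunk defined above; a small sketch of that sharding (illustrative, mirroring the functions in this file):

import math

def split_list(lst, n):
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

questions = list(range(10))
chunks = split_list(questions, 3)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
chunks[1]                           # the slice that --num-chunks 3 --chunk-idx 1 would evaluate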
ola_vlm/eval/model_mmstar_loader.py ADDED
@@ -0,0 +1,164 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from ola_vlm.conversation import conv_templates, SeparatorStyle
10
+ from ola_vlm.model.builder import load_pretrained_model
11
+ from ola_vlm.utils import disable_torch_init
12
+ from ola_vlm.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from datasets import load_dataset
15
+ from PIL import Image
16
+ import math
17
+
18
+
19
+ def split_list(lst, n):
20
+ """Split a list into n (roughly) equal-sized chunks"""
21
+ chunk_size = math.ceil(len(lst) / n) # integer division
22
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23
+
24
+
25
+ def get_chunk(lst, n, k):
26
+ chunks = split_list(lst, n)
27
+ return chunks[k]
28
+
29
+
30
+ def prepare_MMStar(path):
31
+ os.makedirs(f"{path}/images", exist_ok=True)
32
+ dataset = load_dataset(path, "val")
33
+ dataset = dataset["val"]
34
+ data = []
35
+ for i in range(len(dataset)):
36
+ if not os.path.exists(f"{path}/images/{i}.jpeg"):
37
+ dataset[i]["image"].save(f"{path}/images/{i}.jpeg")
38
+ prompt = dataset[i]["question"] + "\n"
39
+ prompt += "Answer with the option's letter from the given choices directly, such as answer letter 'A' only. \n"
40
+
41
+ d = {
42
+ "image": f"{path}/images/{i}.jpeg",
43
+ "question": prompt,
44
+ "answer": dataset[i]["answer"],
45
+ "category": dataset[i]["category"],
46
+ "l2_category": dataset[i]["l2_category"]
47
+ }
48
+ data.append(d)
49
+ return data
50
+
51
+
52
+ # Custom dataset class
53
+ class CustomDataset(Dataset):
54
+ def __init__(self, data, tokenizer, image_processor, model_config):
55
+ self.questions = data
56
+ self.tokenizer = tokenizer
57
+ self.image_processor = image_processor
58
+ self.model_config = model_config
59
+
60
+ def __getitem__(self, index):
61
+ d = self.questions[index]
62
+ qs = d["question"]
63
+ image_file = d["image"]
64
+ ans = d["answer"]
65
+
66
+ if self.model_config.mm_use_im_start_end:
67
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
68
+ else:
69
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
70
+
71
+ conv = conv_templates[args.conv_mode].copy()
72
+ conv.append_message(conv.roles[0], qs)
73
+ conv.append_message(conv.roles[1], None)
74
+ prompt = conv.get_prompt()
75
+
76
+ image = Image.open(image_file).convert('RGB')
77
+ image_tensor = process_images([image], self.image_processor, self.model_config)[0]
78
+
79
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
80
+
81
+ return input_ids, image_tensor, image.size, ans, d["category"], d["l2_category"]
82
+
83
+ def __len__(self):
84
+ return len(self.questions)
85
+
86
+
87
+ def collate_fn(batch):
88
+ input_ids, image_tensors, image_sizes, answers, cats, cats_l2 = zip(*batch)
89
+ input_ids = torch.stack(input_ids, dim=0)
90
+ image_tensors = torch.stack(image_tensors, dim=0)
91
+ return input_ids, image_tensors, image_sizes, answers, cats, cats_l2
92
+
93
+
94
+ # DataLoader
95
+ def create_data_loader(questions, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
96
+ assert batch_size == 1, "batch_size must be 1"
97
+ dataset = CustomDataset(questions, tokenizer, image_processor, model_config)
98
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
99
+ return data_loader
100
+
101
+
102
+ def eval_model(args):
103
+ # Model
104
+ disable_torch_init()
105
+ model_path = os.path.expanduser(args.model_path)
106
+ model_name = get_model_name_from_path(model_path)
107
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
108
+
109
+ questions = prepare_MMStar(args.path)
110
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
111
+ answers_file = os.path.expanduser(args.answers_file)
112
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
113
+ ans_file = open(answers_file, "w")
114
+
115
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
116
+ args.conv_mode = args.conv_mode + '_mmtag'
117
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
118
+
119
+ data_loader = create_data_loader(questions, tokenizer, image_processor, model.config)
120
+
121
+ for (input_ids, image_tensor, image_sizes, answer, cat, cat_l2), line in tqdm(zip(data_loader, questions), total=len(questions)):
122
+ input_ids = input_ids.to(device='cuda', non_blocking=True)
123
+
124
+ with torch.inference_mode():
125
+ output_ids = model.generate(
126
+ input_ids,
127
+ images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
128
+ image_sizes=image_sizes,
129
+ do_sample=True if args.temperature > 0 else False,
130
+ temperature=args.temperature,
131
+ top_p=args.top_p,
132
+ num_beams=args.num_beams,
133
+ max_new_tokens=args.max_new_tokens,
134
+ use_cache=True)
135
+
136
+ if not isinstance(output_ids, torch.Tensor):
137
+ output_ids = output_ids.sequences
138
+
139
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
140
+
141
+ ans_file.write(json.dumps({"prediction": outputs,
142
+ "answer": answer[0],
143
+ "question": line,
144
+ "category": cat[0],
145
+ "l2_category": cat_l2[0]}) + "\n")
146
+ # ans_file.flush()
147
+ ans_file.close()
148
+
149
+ if __name__ == "__main__":
150
+ parser = argparse.ArgumentParser()
151
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
152
+ parser.add_argument("--model-base", type=str, default=None)
153
+ parser.add_argument("--path", type=str, default="MMStar")
154
+ parser.add_argument("--answers-file", type=str, default="mmstar_answer.jsonl")
155
+ parser.add_argument("--conv-mode", type=str, default="llava_phi_3")
156
+ parser.add_argument("--num-chunks", type=int, default=1)
157
+ parser.add_argument("--chunk-idx", type=int, default=0)
158
+ parser.add_argument("--temperature", type=float, default=0.2)
159
+ parser.add_argument("--top_p", type=float, default=None)
160
+ parser.add_argument("--num_beams", type=int, default=1)
161
+ parser.add_argument("--max_new_tokens", type=int, default=128)
162
+ args = parser.parse_args()
163
+
164
+ eval_model(args)
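Each run writes one JSON record per question to --answers-file; a naive sanity check over that file (illustrative only, not the repository's MMStar scorer):

import json

correct = total = 0
with open("mmstar_answer.jsonl") as f:
    for line in f:
        rec = json.loads(line)
        total += 1
        correct += int(rec["prediction"].strip().upper().startswith(rec["answer"].strip().upper()))
print(f"naive letter-match accuracy: {correct / max(total, 1):.3f}")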
ola_vlm/mm_utils.py ADDED
@@ -0,0 +1,398 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import torch
5
+ import math
6
+ import ast
7
+ import re
8
+ from transformers import StoppingCriteria
9
+ from ola_vlm.constants import IMAGE_TOKEN_INDEX
10
+
11
+ ###########################################
12
+
13
+ def resize_and_center_crop(image, shortest_edge_length):
14
+ # Calculate new dimensions and resize
15
+ aspect_ratio = float(image.width) / float(image.height)
16
+ if aspect_ratio > 1:
17
+ new_width = int(shortest_edge_length * aspect_ratio)
18
+ new_height = shortest_edge_length
19
+ else:
20
+ new_width = shortest_edge_length
21
+ new_height = int(shortest_edge_length / aspect_ratio)
22
+ resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
23
+
24
+ # Calculate the position and perform the center crop
25
+ left = (new_width - shortest_edge_length) / 2
26
+ top = (new_height - shortest_edge_length) / 2
27
+ right = (new_width + shortest_edge_length) / 2
28
+ bottom = (new_height + shortest_edge_length) / 2
29
+ cropped_image = resized_image.crop((left, top, right, bottom))
30
+
31
+ return cropped_image
32
+
33
+
34
+ def auto_pad_images(image, grid_params):
35
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
36
+ assert len(grid_params) > 0, "Grid parameters should not be empty"
37
+
38
+ # Step 1: Calculate and find the closest aspect ratio
39
+ input_width, input_height = image.size
40
+ input_aspect_ratio = input_width / input_height
41
+ candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
42
+ closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
43
+
44
+ candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
45
+
46
+ target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
47
+
48
+ resize_width, resize_height = target_resolution
49
+ if input_width > input_height:
50
+ resize_height = int(resize_width / input_aspect_ratio)
51
+ else:
52
+ resize_width = int(resize_height * input_aspect_ratio)
53
+ resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS)
54
+
55
+ # Step 5: Pad the resized image if necessary to match the target resolution
56
+ pad_width = target_resolution[0] - resize_width
57
+ pad_height = target_resolution[1] - resize_height
58
+ padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
59
+ padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
60
+
61
+ return padded_image
62
+
63
+
64
+ def extract_patches(image, patch_size, overlap_ratio):
65
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
66
+ assert patch_size > 0, "Patch size should be greater than 0"
67
+ assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
68
+
69
+ W, H = image.size
70
+ patches = []
71
+
72
+ stride = int(patch_size * (1 - overlap_ratio))
73
+
74
+ num_patches_y = (H - patch_size) // stride + 1
75
+ num_patches_x = (W - patch_size) // stride + 1
76
+
77
+ y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
78
+ x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
79
+
80
+ for y in range(y_start, y_start + num_patches_y * stride, stride):
81
+ for x in range(x_start, x_start + num_patches_x * stride, stride):
82
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
83
+ patches.append(patch)
84
+
85
+ return patches
86
+
87
+
88
+ def process_highres_image_crop_split(image, data_args, processor=None):
89
+ crop_resolution = data_args.image_crop_resolution
90
+ split_resolution = data_args.image_split_resolution
91
+ if processor is None:
92
+ processor = data_args.image_processor
93
+ image_crop = resize_and_center_crop(image, crop_resolution)
94
+ image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
95
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
96
+ return torch.stack(image_patches, dim=0)
97
+
98
+
99
+ def process_highres_image(image, processor, grid_pinpoints):
100
+ grid_params = [int(x) for x in grid_pinpoints.split(",")]
101
+ width_height = max(image.size)
102
+ fit_grid_params = [x for x in grid_params if x >= width_height]
103
+ if len(fit_grid_params) == 0:
104
+ select_size = max(grid_params)
105
+ else:
106
+ select_size = min(fit_grid_params)
107
+ # FIXME: always select the 448
108
+ select_size = max(grid_params)
109
+ image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
110
+
111
+ # FIXME: this seems to be a bug that it always resizes instead of padding
112
+ image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
113
+ image_padded = image_padded.resize((select_size, select_size))
114
+ image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
115
+ image_patches = [image_original_resize] + image_patches
116
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
117
+ return torch.stack(image_patches, dim=0)
118
+
119
+ ########################################
120
+
121
+ def select_best_resolution(original_size, possible_resolutions):
122
+ """
123
+ Selects the best resolution from a list of possible resolutions based on the original size.
124
+
125
+ Args:
126
+ original_size (tuple): The original size of the image in the format (width, height).
127
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
128
+
129
+ Returns:
130
+ tuple: The best fit resolution in the format (width, height).
131
+ """
132
+ original_width, original_height = original_size
133
+ best_fit = None
134
+ max_effective_resolution = 0
135
+ min_wasted_resolution = float('inf')
136
+
137
+ for width, height in possible_resolutions:
138
+ scale = min(width / original_width, height / original_height)
139
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
140
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
141
+ wasted_resolution = (width * height) - effective_resolution
142
+
143
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
144
+ max_effective_resolution = effective_resolution
145
+ min_wasted_resolution = wasted_resolution
146
+ best_fit = (width, height)
147
+
148
+ return best_fit
149
+
150
+
151
+ def resize_and_pad_image(image, target_resolution):
152
+ """
153
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
154
+
155
+ Args:
156
+ image (PIL.Image.Image): The input image.
157
+ target_resolution (tuple): The target resolution (width, height) of the image.
158
+
159
+ Returns:
160
+ PIL.Image.Image: The resized and padded image.
161
+ """
162
+ original_width, original_height = image.size
163
+ target_width, target_height = target_resolution
164
+
165
+ scale_w = target_width / original_width
166
+ scale_h = target_height / original_height
167
+
168
+ if scale_w < scale_h:
169
+ new_width = target_width
170
+ new_height = min(math.ceil(original_height * scale_w), target_height)
171
+ else:
172
+ new_height = target_height
173
+ new_width = min(math.ceil(original_width * scale_h), target_width)
174
+
175
+ # Resize the image
176
+ resized_image = image.resize((new_width, new_height))
177
+
178
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
179
+ paste_x = (target_width - new_width) // 2
180
+ paste_y = (target_height - new_height) // 2
181
+ new_image.paste(resized_image, (paste_x, paste_y))
182
+
183
+ return new_image
184
+
185
+
186
+ def divide_to_patches(image, patch_size):
187
+ """
188
+ Divides an image into patches of a specified size.
189
+
190
+ Args:
191
+ image (PIL.Image.Image): The input image.
192
+ patch_size (int): The size of each patch.
193
+
194
+ Returns:
195
+ list: A list of PIL.Image.Image objects representing the patches.
196
+ """
197
+ patches = []
198
+ width, height = image.size
199
+ for i in range(0, height, patch_size):
200
+ for j in range(0, width, patch_size):
201
+ box = (j, i, j + patch_size, i + patch_size)
202
+ patch = image.crop(box)
203
+ patches.append(patch)
204
+
205
+ return patches
206
+
207
+
208
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
209
+ """
210
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
211
+
212
+ Args:
213
+ image_size (tuple): The size of the input image in the format (width, height).
214
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
215
+ patch_size (int): The size of each image patch.
216
+
217
+ Returns:
218
+ tuple: The shape of the image patch grid in the format (width, height).
219
+ """
220
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
221
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
222
+ # Use regex to extract the range from the input string
223
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
224
+ range_start = tuple(map(int, matches[0]))
225
+ range_end = tuple(map(int, matches[-1]))
226
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
227
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
228
+ # Multiply all elements by patch_size
229
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
230
+ if type(grid_pinpoints) is list:
231
+ possible_resolutions = grid_pinpoints
232
+ else:
233
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
234
+ width, height = select_best_resolution(image_size, possible_resolutions)
235
+ return width // patch_size, height // patch_size
236
+
237
+
238
+ def process_anyres_image(image, processor, grid_pinpoints):
239
+ """
240
+ Process an image with variable resolutions.
241
+
242
+ Args:
243
+ image (PIL.Image.Image): The input image to be processed.
244
+ processor: The image processor object.
245
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
246
+
247
+ Returns:
248
+ torch.Tensor: A tensor containing the processed image patches.
249
+ """
250
+ # Convert grid_pinpoints from string to list
251
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
252
+ try:
253
+ patch_size = processor.size[0]
254
+ except Exception as e:
255
+ patch_size = processor.size["shortest_edge"]
256
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
257
+ # Use regex to extract the range from the input string
258
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
259
+ range_start = tuple(map(int, matches[0]))
260
+ range_end = tuple(map(int, matches[-1]))
261
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
262
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
263
+ # Multiply all elements by patch_size
264
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
265
+
266
+ if type(grid_pinpoints) is list:
267
+ possible_resolutions = grid_pinpoints
268
+ else:
269
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
270
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
271
+ image_padded = resize_and_pad_image(image, best_resolution)
272
+
273
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
274
+
275
+ # FIXME: this seems to be a bug that it resizes instead of pad.
276
+ # but to keep it consistent with previous, i will keep it as it is
277
+ # TODO: uncomment below to ablate with the padding
278
+ if isinstance(processor.size, dict):
279
+ shortest_edge = processor.size["shortest_edge"]
280
+ else:
281
+ shortest_edge = min(processor.size)
282
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
283
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
284
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
285
+
286
+ image_patches = [image_original_resize] + patches
287
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
288
+ return torch.stack(image_patches, dim=0)
289
+
290
+
291
+ def load_image_from_base64(image):
292
+ return Image.open(BytesIO(base64.b64decode(image)))
293
+
294
+
295
+ def expand2square(pil_img, background_color):
296
+ width, height = pil_img.size
297
+ if width == height:
298
+ return pil_img
299
+ elif width > height:
300
+ result = Image.new(pil_img.mode, (width, width), background_color)
301
+ result.paste(pil_img, (0, (width - height) // 2))
302
+ return result
303
+ else:
304
+ result = Image.new(pil_img.mode, (height, height), background_color)
305
+ result.paste(pil_img, ((height - width) // 2, 0))
306
+ return result
307
+
308
+
309
+ def process_images(images, image_processor, model_cfg):
310
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
311
+ new_images = []
312
+ if image_aspect_ratio == "highres":
313
+ for image in images:
314
+ image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
315
+ new_images.append(image)
316
+ elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
317
+ for image in images:
318
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
319
+ new_images.append(image)
320
+ elif image_aspect_ratio == "crop_split":
321
+ for image in images:
322
+ image = process_highres_image_crop_split(image, model_cfg, image_processor)
323
+ new_images.append(image)
324
+ elif image_aspect_ratio == "pad":
325
+ for image in images:
326
+ image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
327
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
328
+ new_images.append(image)
329
+ else:
330
+ return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
331
+ if all(x.shape == new_images[0].shape for x in new_images):
332
+ new_images = torch.stack(new_images, dim=0)
333
+ return new_images
334
+
335
+
336
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
337
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
338
+
339
+ def insert_separator(X, sep):
340
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
341
+
342
+ input_ids = []
343
+ offset = 0
344
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
345
+ offset = 1
346
+ input_ids.append(prompt_chunks[0][0])
347
+
348
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
349
+ input_ids.extend(x[offset:])
350
+
351
+ if return_tensors is not None:
352
+ if return_tensors == 'pt':
353
+ return torch.tensor(input_ids, dtype=torch.long)
354
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
355
+ return input_ids
356
+
357
+
358
+ def get_model_name_from_path(model_path):
359
+ model_path = model_path.strip("/")
360
+ model_paths = model_path.split("/")
361
+ if model_paths[-1].startswith('checkpoint-'):
362
+ return model_paths[-2] + "_" + model_paths[-1]
363
+ else:
364
+ return model_paths[-1]
365
+
366
+ class KeywordsStoppingCriteria(StoppingCriteria):
367
+ def __init__(self, keywords, tokenizer, input_ids):
368
+ self.keywords = keywords
369
+ self.keyword_ids = []
370
+ self.max_keyword_len = 0
371
+ for keyword in keywords:
372
+ cur_keyword_ids = tokenizer(keyword).input_ids
373
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
374
+ cur_keyword_ids = cur_keyword_ids[1:]
375
+ if len(cur_keyword_ids) > self.max_keyword_len:
376
+ self.max_keyword_len = len(cur_keyword_ids)
377
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
378
+ self.tokenizer = tokenizer
379
+ self.start_len = input_ids.shape[1]
380
+
381
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
382
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
383
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
384
+ for keyword_id in self.keyword_ids:
385
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
386
+ if torch.equal(truncated_output_ids, keyword_id):
387
+ return True
388
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
389
+ for keyword in self.keywords:
390
+ if keyword in outputs:
391
+ return True
392
+ return False
393
+
394
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
395
+ outputs = []
396
+ for i in range(output_ids.shape[0]):
397
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
398
+ return all(outputs)
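A minimal sketch of how the two most frequently used helpers above fit together; `tokenizer` and `image_processor` are assumed to come from load_pretrained_model, as in the eval loaders earlier in this diff:

from PIL import Image
from ola_vlm.mm_utils import expand2square, tokenizer_image_token
from ola_vlm.constants import IMAGE_TOKEN_INDEX

image = Image.open("example.jpg").convert("RGB")   # hypothetical input image
padded = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
pixel_values = image_processor.preprocess(padded, return_tensors="pt")["pixel_values"][0]

prompt = "USER: <image>\nWhat is shown here? ASSISTANT:"   # illustrative prompt
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
# every "<image>" in the prompt becomes IMAGE_TOKEN_INDEX in the returned ids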
ola_vlm/model/.DS_Store ADDED
Binary file (6.15 kB).
 
ola_vlm/model/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2
+ from .language_model.llava_phi3 import LlavaPhi3ForCausalLM, LlavaPhi3Config
3
+ from .language_model.ola_llama import OlaLlavaLlamaForCausalLM, OlaLlavaLlamaConfig
4
+ from .language_model.ola_phi3 import OlaLlavaPhi3ForCausalLM, OlaLlavaPhi3Config
5
+ from .language_model.probe_llava_llama import ProbeDSGLlavaLlamaForCausalLM, ProbeDSGLlavaLlamaConfig
ola_vlm/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava import LlavaLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
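The heart of the script is element-wise addition of the base weights into the delta weights; a toy illustration of that arithmetic (not a substitute for running the script):

import torch

base_w = torch.tensor([1.0, 2.0, 3.0])    # a weight from the base model
delta_w = torch.tensor([0.1, -0.2, 0.3])  # the same weight as stored in the delta checkpoint
delta_w += base_w                         # what `param.data += base.state_dict()[name]` does per tensor
# delta_w now holds the reconstructed target weight: tensor([1.1000, 1.8000, 3.3000])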
ola_vlm/model/aux_heads/.DS_Store ADDED
Binary file (6.15 kB).
 
ola_vlm/model/aux_heads/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .da_v2_head import DepthHead, DAv2_Head, DepthProbeHead, TaskTokenDepthHead
2
+ from .oneformer_head import OneFormerSegHead, OneFormerTaskTokenSegHead
3
+ from .gen_head import GenHead, TaskTokenGenHead
ola_vlm/model/aux_heads/da_v2_head.py ADDED
@@ -0,0 +1,457 @@
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ola_vlm.model.multimodal_projector.resampler import Resampler, TaskTokenResampler
6
+
7
+
8
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
9
+ scratch = nn.Module()
10
+
11
+ out_shape1 = out_shape
12
+ out_shape2 = out_shape
13
+ out_shape3 = out_shape
14
+ if len(in_shape) >= 4:
15
+ out_shape4 = out_shape
16
+
17
+ if expand:
18
+ out_shape1 = out_shape
19
+ out_shape2 = out_shape * 2
20
+ out_shape3 = out_shape * 4
21
+ if len(in_shape) >= 4:
22
+ out_shape4 = out_shape * 8
23
+
24
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
26
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
27
+ if len(in_shape) >= 4:
28
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
29
+
30
+ return scratch
31
+
32
+
33
+ class ResidualConvUnit(nn.Module):
34
+ """Residual convolution module.
35
+ """
36
+
37
+ def __init__(self, features, activation, bn):
38
+ """Init.
39
+
40
+ Args:
41
+ features (int): number of features
42
+ """
43
+ super().__init__()
44
+
45
+ self.bn = bn
46
+
47
+ self.groups=1
48
+
49
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
50
+
51
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
52
+
53
+ if self.bn == True:
54
+ self.bn1 = nn.BatchNorm2d(features)
55
+ self.bn2 = nn.BatchNorm2d(features)
56
+
57
+ self.activation = activation
58
+
59
+ self.skip_add = nn.quantized.FloatFunctional()
60
+
61
+ def forward(self, x):
62
+ """Forward pass.
63
+
64
+ Args:
65
+ x (tensor): input
66
+
67
+ Returns:
68
+ tensor: output
69
+ """
70
+
71
+ out = self.activation(x)
72
+ out = self.conv1(out)
73
+ if self.bn == True:
74
+ out = self.bn1(out)
75
+
76
+ out = self.activation(out)
77
+ out = self.conv2(out)
78
+ if self.bn == True:
79
+ out = self.bn2(out)
80
+
81
+ if self.groups > 1:
82
+ out = self.conv_merge(out)
83
+
84
+ return self.skip_add.add(out, x)
85
+
86
+
87
+ class FeatureFusionBlock(nn.Module):
88
+ """Feature fusion block.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ features,
94
+ activation,
95
+ deconv=False,
96
+ bn=False,
97
+ expand=False,
98
+ align_corners=True,
99
+ size=None
100
+ ):
101
+ """Init.
102
+
103
+ Args:
104
+ features (int): number of features
105
+ """
106
+ super(FeatureFusionBlock, self).__init__()
107
+
108
+ self.deconv = deconv
109
+ self.align_corners = align_corners
110
+
111
+ self.groups=1
112
+
113
+ self.expand = expand
114
+ out_features = features
115
+ if self.expand == True:
116
+ out_features = features // 2
117
+
118
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
119
+
120
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
121
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
122
+
123
+ self.skip_add = nn.quantized.FloatFunctional()
124
+
125
+ self.size=size
126
+
127
+ def forward(self, *xs, size=None):
128
+ """Forward pass.
129
+
130
+ Returns:
131
+ tensor: output
132
+ """
133
+ output = xs[0]
134
+
135
+ if len(xs) == 2:
136
+ res = self.resConfUnit1(xs[1])
137
+ output = self.skip_add.add(output, res)
138
+
139
+ output = self.resConfUnit2(output)
140
+
141
+ if (size is None) and (self.size is None):
142
+ modifier = {"scale_factor": 2}
143
+ elif size is None:
144
+ modifier = {"size": self.size}
145
+ else:
146
+ modifier = {"size": size}
147
+
148
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
149
+
150
+ output = self.out_conv(output)
151
+
152
+ return output
153
+
154
+
155
+ def _make_fusion_block(features, use_bn, size=None):
156
+ return FeatureFusionBlock(
157
+ features,
158
+ nn.ReLU(False),
159
+ deconv=False,
160
+ bn=use_bn,
161
+ expand=False,
162
+ align_corners=True,
163
+ size=size,
164
+ )
165
+
166
+
167
+ class ConvBlock(nn.Module):
168
+ def __init__(self, in_feature, out_feature):
169
+ super().__init__()
170
+
171
+ self.conv_block = nn.Sequential(
172
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
173
+ nn.BatchNorm2d(out_feature),
174
+ nn.ReLU(True)
175
+ )
176
+
177
+ def forward(self, x):
178
+ return self.conv_block(x)
179
+
180
+
181
+ class DPTHead(nn.Module):
182
+ def __init__(
183
+ self,
184
+ in_channels,
185
+ features=256,
186
+ use_bn=False,
187
+ out_channels=[256, 512, 1024, 1024],
188
+ use_clstoken=False
189
+ ):
190
+ super(DPTHead, self).__init__()
191
+
192
+ self.use_clstoken = use_clstoken
193
+
194
+ self.projects = nn.ModuleList([
195
+ nn.Conv2d(
196
+ in_channels=in_channels,
197
+ out_channels=out_channel,
198
+ kernel_size=1,
199
+ stride=1,
200
+ padding=0,
201
+ ) for out_channel in out_channels
202
+ ])
203
+
204
+ self.resize_layers = nn.ModuleList([
205
+ nn.ConvTranspose2d(
206
+ in_channels=out_channels[0],
207
+ out_channels=out_channels[0],
208
+ kernel_size=4,
209
+ stride=4,
210
+ padding=0),
211
+ nn.ConvTranspose2d(
212
+ in_channels=out_channels[1],
213
+ out_channels=out_channels[1],
214
+ kernel_size=2,
215
+ stride=2,
216
+ padding=0),
217
+ nn.Identity(),
218
+ nn.Conv2d(
219
+ in_channels=out_channels[3],
220
+ out_channels=out_channels[3],
221
+ kernel_size=3,
222
+ stride=2,
223
+ padding=1)
224
+ ])
225
+
226
+ if use_clstoken:
227
+ self.readout_projects = nn.ModuleList()
228
+ for _ in range(len(self.projects)):
229
+ self.readout_projects.append(
230
+ nn.Sequential(
231
+ nn.Linear(2 * in_channels, in_channels),
232
+ nn.GELU()))
233
+
234
+ self.scratch = _make_scratch(
235
+ out_channels,
236
+ features,
237
+ groups=1,
238
+ expand=False,
239
+ )
240
+
241
+ self.scratch.stem_transpose = None
242
+
243
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
244
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
245
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
246
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
247
+
248
+ head_features_1 = features
249
+ head_features_2 = 32
250
+
251
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
252
+ self.scratch.output_conv2 = nn.Sequential(
253
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
254
+ nn.ReLU(True),
255
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
256
+ nn.ReLU(True),
257
+ nn.Identity(),
258
+ )
259
+
260
+ def forward(self, out_features, patch_h, patch_w):
261
+ out = []
262
+ for i, x in enumerate(out_features):
263
+ if self.use_clstoken:
264
+ x, cls_token = x[0], x[1]
265
+ readout = cls_token.unsqueeze(1).expand_as(x)
266
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
267
+ else:
268
+ x = x[0]
269
+
270
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
271
+
272
+ x = self.projects[i](x)
273
+ x = self.resize_layers[i](x)
274
+
275
+ out.append(x)
276
+
277
+ layer_1, layer_2, layer_3, layer_4 = out
278
+
279
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
280
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
281
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
282
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
283
+
284
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
285
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
286
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
287
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
288
+
289
+ out = self.scratch.output_conv1(path_1)
290
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
291
+ out = self.scratch.output_conv2(out)
292
+
293
+ return out
294
+
295
+
296
+ class DAv2_Head(nn.Module):
297
+ def __init__(
298
+ self,
299
+ encoder='vitl',
300
+ features=256,
301
+ out_channels=[256, 512, 1024, 1024],
302
+ use_bn=False,
303
+ use_clstoken=False
304
+ ):
305
+ super(DAv2_Head, self).__init__()
306
+
307
+ self.embd_dims = {
308
+ 'vits': 1024,
309
+ 'vitb': 1024,
310
+ 'vitl': 1024,
311
+ 'vitg': 1024,
312
+ }
313
+
314
+ self.depth_head = DPTHead(self.embd_dims[encoder], features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
315
+
316
+ def forward(self, features):
317
+ patch_h, patch_w = 336 // 14, 336 // 14
318
+ depth = self.depth_head(features, patch_h, patch_w)
319
+ depth = F.relu(depth)
320
+
321
+ return depth.squeeze(1)
322
+
323
+ @torch.no_grad()
324
+ def infer_feats(self, feats, image_size=(336, 336)):
325
+ h, w = image_size
326
+ depth = self.forward(feats)
327
+
328
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
329
+ return depth.cpu().numpy()
330
+
331
+ def build_mlp(in_hidden_size, hidden_size):
332
+ modules = [nn.Linear(in_hidden_size, hidden_size)]
333
+ modules.append(nn.ReLU())
334
+ modules.append(nn.Linear(hidden_size, hidden_size))
335
+ return nn.Sequential(*modules)
336
+
337
+ def build_expand_mlp(in_hidden_size, hidden_size, out_size):
338
+ modules = [nn.Linear(in_hidden_size, hidden_size)]
339
+ modules.append(nn.ReLU())
340
+ modules.append(nn.Linear(hidden_size, hidden_size))
341
+ modules.append(nn.ReLU())
342
+ modules.append(nn.Linear(hidden_size, out_size))
343
+ return nn.Sequential(*modules)
344
+
345
+ class DepthProbeHead(nn.Module):
346
+ def __init__(
347
+ self,
348
+ llm_hidden_size=4096,
349
+ proj_config=None,
350
+ ):
351
+ super(DepthProbeHead, self).__init__()
352
+
353
+ self.linear_1 = build_mlp(llm_hidden_size, proj_config["output_dim"])
354
+ self.linear_2 = build_mlp(llm_hidden_size, proj_config["output_dim"])
355
+ self.linear_3 = build_mlp(llm_hidden_size, proj_config["output_dim"])
356
+ self.linear_4 = build_mlp(llm_hidden_size, proj_config["output_dim"])
357
+
358
+ # self._init_weights()
359
+
360
+ # def _init_weights(self):
361
+ # for m in self.modules():
362
+ # if isinstance(m, nn.Linear):
363
+ # nn.init.xavier_uniform_(m.weight)
364
+ # if m.bias is not None:
365
+ # nn.init.constant_(m.bias, 0)
366
+
367
+ def forward(self, llm_feats):
368
+
369
+ features = [(self.linear_1(llm_feats), None),
370
+ (self.linear_1(llm_feats), None),
371
+ (self.linear_2(llm_feats), None),
372
+ (self.linear_3(llm_feats), None)
373
+ ]
374
+
375
+ return features
376
+
377
+ class DepthHead(nn.Module):
378
+ def __init__(
379
+ self,
380
+ llm_hidden_size=4096,
381
+ proj_config=None,
382
+ use_intermediate_depth=False,
383
+ ):
384
+ super(DepthHead, self).__init__()
385
+
386
+ self.projector = Resampler(
387
+ dim=proj_config["output_dim"],
388
+ depth=proj_config["depth"],
389
+ dim_head=proj_config["dim_head"],
390
+ heads=proj_config["num_heads"],
391
+ num_queries=proj_config["num_tokens"],
392
+ embedding_dim=llm_hidden_size,
393
+ output_dim=proj_config["output_dim"],
394
+ ff_mult=proj_config["ff_mult"],
395
+ )
396
+
397
+ self.use_intermediate_depth = use_intermediate_depth
398
+
399
+ if self.use_intermediate_depth:
400
+ self.linear_1 = build_mlp(proj_config["output_dim"], proj_config["output_dim"])
401
+ self.linear_2 = build_mlp(proj_config["output_dim"], proj_config["output_dim"])
402
+ self.linear_3 = build_mlp(proj_config["output_dim"], proj_config["output_dim"])
403
+
404
+ def forward(self, llm_feats):
405
+ visual_feats = self.projector(llm_feats)
406
+
407
+ features = []
408
+
409
+ if self.use_intermediate_depth:
410
+ features.append((self.linear_1(visual_feats), None))
411
+ features.append((self.linear_2(visual_feats), None))
412
+ features.append((self.linear_3(visual_feats), None))
413
+
414
+ features.append((visual_feats, None))
415
+
416
+ return features
417
+
418
+ class TaskTokenDepthHead(nn.Module):
419
+ def __init__(
420
+ self,
421
+ proj_config=None,
422
+ llm_hidden_size=4096,
423
+ use_intermediate_depth=False,
424
+ ):
425
+ super(TaskTokenDepthHead, self).__init__()
426
+
427
+ self.projector = TaskTokenResampler(
428
+ dim=llm_hidden_size,
429
+ depth=proj_config["depth"],
430
+ dim_head=proj_config["dim_head"],
431
+ heads=proj_config["num_heads"],
432
+ num_queries=proj_config["num_tokens"],
433
+ embedding_dim=llm_hidden_size,
434
+ output_dim=proj_config["output_dim"],
435
+ ff_mult=proj_config["ff_mult"],
436
+ )
437
+ self.use_intermediate_depth = use_intermediate_depth
438
+
439
+ if self.use_intermediate_depth:
440
+ self.linear_1 = build_mlp(proj_config["output_dim"], proj_config["output_dim"])
441
+ self.linear_2 = build_mlp(proj_config["output_dim"], proj_config["output_dim"])
442
+ self.linear_3 = build_mlp(proj_config["output_dim"], proj_config["output_dim"])
443
+
444
+ def forward(self, llm_feats, latents):
445
+
446
+ visual_feats = self.projector(llm_feats, latents)
447
+
448
+ features = []
449
+
450
+ if self.use_intermediate_depth:
451
+ features.append((self.linear_1(visual_feats), None))
452
+ features.append((self.linear_2(visual_feats), None))
453
+ features.append((self.linear_3(visual_feats), None))
454
+
455
+ features.append((visual_feats, None))
456
+
457
+ return features
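As a quick shape check for the DPT-style head above (a sketch with random tensors, assuming the repository is importable; 336 / 14 = 24 patches per side and 1024-dim features match the defaults in this file):

import torch
from ola_vlm.model.aux_heads.da_v2_head import DAv2_Head

head = DAv2_Head(encoder="vitl")                      # expects 1024-dim features at every level
feats = [(torch.randn(1, 24 * 24, 1024), None)] * 4   # four (tokens, cls_token) tuples, 24x24 patch grid
depth = head(feats)                                   # -> tensor of shape [1, 336, 336]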
ola_vlm/model/aux_heads/depth_anything_v2/dinov2.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for block_chunk in self.blocks:
289
+ for blk in block_chunk[i:]: # skip past the leading nn.Identity() placeholders
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ i += 1
294
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295
+ return output
296
+
297
+ def get_intermediate_layers(
298
+ self,
299
+ x: torch.Tensor,
300
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
301
+ reshape: bool = False,
302
+ return_class_token: bool = False,
303
+ norm=True
304
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305
+ if self.chunked_blocks:
306
+ outputs = self._get_intermediate_layers_chunked(x, n)
307
+ else:
308
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
309
+ if norm:
310
+ outputs = [self.norm(out) for out in outputs]
311
+ class_tokens = [out[:, 0] for out in outputs]
312
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313
+ if reshape:
314
+ B, _, w, h = x.shape
315
+ outputs = [
316
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317
+ for out in outputs
318
+ ]
319
+ if return_class_token:
320
+ return tuple(zip(outputs, class_tokens))
321
+ return tuple(outputs)
322
+
323
+ def forward(self, *args, is_training=False, **kwargs):
324
+ ret = self.forward_features(*args, **kwargs)
325
+ if is_training:
326
+ return ret
327
+ else:
328
+ return self.head(ret["x_norm_clstoken"])
329
+
330
+
331
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
332
+ """ViT weight initialization, original timm impl (for reproducibility)"""
333
+ if isinstance(module, nn.Linear):
334
+ trunc_normal_(module.weight, std=0.02)
335
+ if module.bias is not None:
336
+ nn.init.zeros_(module.bias)
337
+
338
+
339
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340
+ model = DinoVisionTransformer(
341
+ patch_size=patch_size,
342
+ embed_dim=384,
343
+ depth=12,
344
+ num_heads=6,
345
+ mlp_ratio=4,
346
+ block_fn=partial(Block, attn_class=MemEffAttention),
347
+ num_register_tokens=num_register_tokens,
348
+ **kwargs,
349
+ )
350
+ return model
351
+
352
+
353
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ patch_size=patch_size,
356
+ embed_dim=768,
357
+ depth=12,
358
+ num_heads=12,
359
+ mlp_ratio=4,
360
+ block_fn=partial(Block, attn_class=MemEffAttention),
361
+ num_register_tokens=num_register_tokens,
362
+ **kwargs,
363
+ )
364
+ return model
365
+
366
+
367
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368
+ model = DinoVisionTransformer(
369
+ patch_size=patch_size,
370
+ embed_dim=1024,
371
+ depth=24,
372
+ num_heads=16,
373
+ mlp_ratio=4,
374
+ block_fn=partial(Block, attn_class=MemEffAttention),
375
+ num_register_tokens=num_register_tokens,
376
+ **kwargs,
377
+ )
378
+ return model
379
+
380
+
381
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382
+ """
383
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384
+ """
385
+ model = DinoVisionTransformer(
386
+ patch_size=patch_size,
387
+ embed_dim=1536,
388
+ depth=40,
389
+ num_heads=24,
390
+ mlp_ratio=4,
391
+ block_fn=partial(Block, attn_class=MemEffAttention),
392
+ num_register_tokens=num_register_tokens,
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def DINOv2(model_name):
399
+ model_zoo = {
400
+ "vits": vit_small,
401
+ "vitb": vit_base,
402
+ "vitl": vit_large,
403
+ "vitg": vit_giant2
404
+ }
405
+
406
+ return model_zoo[model_name](
407
+ img_size=518,
408
+ patch_size=14,
409
+ init_values=1.0,
410
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411
+ block_chunks=0,
412
+ num_register_tokens=0,
413
+ interpolate_antialias=False,
414
+ interpolate_offset=0.1
415
+ )
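Editor's note: a minimal usage sketch for the `DINOv2` factory above, assuming the module is importable at the path shown in this diff (the import path is an assumption). It builds the ViT-L backbone for 518x518 inputs with 14x14 patches and pulls patch tokens from the last four blocks used by the depth head.

```python
import torch
from ola_vlm.model.aux_heads.depth_anything_v2.dinov2 import DINOv2  # path assumed from this diff

backbone = DINOv2("vitl").eval()        # img_size=518, patch_size=14, block_chunks=0
x = torch.randn(1, 3, 518, 518)         # H and W must be multiples of 14
with torch.no_grad():
    feats = backbone.get_intermediate_layers(x, n=[4, 11, 17, 23], return_class_token=True)
patch_tokens, cls_token = feats[-1]
print(patch_tokens.shape, cls_token.shape)  # torch.Size([1, 1369, 1024]) torch.Size([1, 1024])
```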
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
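Editor's note: a quick shape check for the attention modules above (a sketch, assuming this file's classes are in scope). `Attention` is the plain softmax path; `MemEffAttention` only diverges from it when xFormers is installed.

```python
import torch

attn = Attention(dim=64, num_heads=8, qkv_bias=True)
tokens = torch.randn(2, 10, 64)      # (batch, sequence, dim)
out = attn(tokens)
assert out.shape == tokens.shape     # attention preserves the token shape
```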
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path2(ffn_residual_func(x))
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
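Editor's note: a minimal sketch of the two code paths in `NestedTensorBlock` (assuming this file's classes are in scope). A plain tensor runs the ordinary `Block.forward`; a list of tensors takes `forward_nested`, which additionally requires xFormers and `attn_class=MemEffAttention`.

```python
import torch

blk = NestedTensorBlock(dim=64, num_heads=4)   # default attn_class=Attention
x = torch.randn(2, 16, 64)
assert blk(x).shape == x.shape                 # plain-tensor path, no xFormers needed
# blk([x_a, x_b]) would take forward_nested, which asserts XFORMERS_AVAILABLE.
```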
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21
+ if keep_prob > 0.0:
22
+ random_tensor.div_(keep_prob)
23
+ output = x * random_tensor
24
+ return output
25
+
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
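Editor's note: behavioural sketch of `DropPath` (stochastic depth). Whole samples are zeroed during training and the survivors are rescaled; evaluation is a no-op.

```python
import torch

dp = DropPath(drop_prob=0.2).train()
x = torch.ones(8, 4, 16)
y = dp(x)                       # roughly 20% of the 8 samples zeroed, rest scaled by 1/0.8
dp.eval()
assert torch.equal(dp(x), x)    # identity at inference time
```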
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torch import nn
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
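Editor's note: `LayerScale` simply multiplies each channel by a learnable vector initialised to `init_values`; a one-line check (sketch):

```python
import torch

ls = LayerScale(dim=16, init_values=1e-5)
x = torch.randn(2, 4, 16)
assert torch.allclose(ls(x), x * 1e-5)   # gamma starts as 1e-5 * ones(16)
```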
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
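Editor's note: `Mlp` is the standard two-layer transformer feed-forward block; a shape check (sketch):

```python
import torch

mlp = Mlp(in_features=64, hidden_features=256)   # 4x expansion with GELU in between
x = torch.randn(2, 10, 64)
assert mlp(x).shape == x.shape
```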
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75
+
76
+ x = x.to(self.proj.bias.dtype)
77
+ x = self.proj(x) # B C H W
78
+ H, W = x.size(2), x.size(3)
79
+ x = x.flatten(2).transpose(1, 2) # B HW C
80
+ x = self.norm(x)
81
+ if not self.flatten_embedding:
82
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
83
+ return x
84
+
85
+ def flops(self) -> float:
86
+ Ho, Wo = self.patches_resolution
87
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
88
+ if self.norm is not None:
89
+ flops += Ho * Wo * self.embed_dim
90
+ return flops
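Editor's note: `PatchEmbed` turns an image into a flat sequence of patch tokens via a strided convolution; each spatial side must be divisible by the patch size (sketch):

```python
import torch

pe = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
x = torch.randn(1, 3, 518, 518)
tokens = pe(x)
assert tokens.shape == (1, 37 * 37, 1024)   # (518 // 14) ** 2 = 1369 patch tokens
```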
ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
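Editor's note: `SwiGLUFFNFused` shrinks the requested hidden width to 2/3 and rounds it up to a multiple of 8 before delegating to xFormers' `SwiGLU` (or the pure-PyTorch `SwiGLUFFN` fallback); the token shape is preserved either way (sketch):

```python
import torch

ffn = SwiGLUFFNFused(in_features=1536, hidden_features=4 * 1536)
x = torch.randn(2, 8, 1536)
assert ffn(x).shape == x.shape
# effective hidden width: int(6144 * 2 / 3) = 4096, already a multiple of 8
```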
ola_vlm/model/aux_heads/depth_anything_v2/dpt.py ADDED
@@ -0,0 +1,219 @@
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import Compose
6
+
7
+ from .dinov2 import DINOv2
8
+ from .util.blocks import FeatureFusionBlock, _make_scratch
9
+ from .util.transform import Resize, NormalizeImage, PrepareForNet
10
+
11
+
12
+ def _make_fusion_block(features, use_bn, size=None):
13
+ return FeatureFusionBlock(
14
+ features,
15
+ nn.ReLU(False),
16
+ deconv=False,
17
+ bn=use_bn,
18
+ expand=False,
19
+ align_corners=True,
20
+ size=size,
21
+ )
22
+
23
+
24
+ class ConvBlock(nn.Module):
25
+ def __init__(self, in_feature, out_feature):
26
+ super().__init__()
27
+
28
+ self.conv_block = nn.Sequential(
29
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
30
+ nn.BatchNorm2d(out_feature),
31
+ nn.ReLU(True)
32
+ )
33
+
34
+ def forward(self, x):
35
+ return self.conv_block(x)
36
+
37
+
38
+ class DPTHead(nn.Module):
39
+ def __init__(
40
+ self,
41
+ in_channels,
42
+ features=256,
43
+ use_bn=False,
44
+ out_channels=[256, 512, 1024, 1024],
45
+ use_clstoken=False
46
+ ):
47
+ super(DPTHead, self).__init__()
48
+
49
+ self.use_clstoken = use_clstoken
50
+
51
+ self.projects = nn.ModuleList([
52
+ nn.Conv2d(
53
+ in_channels=in_channels,
54
+ out_channels=out_channel,
55
+ kernel_size=1,
56
+ stride=1,
57
+ padding=0,
58
+ ) for out_channel in out_channels
59
+ ])
60
+
61
+ self.resize_layers = nn.ModuleList([
62
+ nn.ConvTranspose2d(
63
+ in_channels=out_channels[0],
64
+ out_channels=out_channels[0],
65
+ kernel_size=4,
66
+ stride=4,
67
+ padding=0),
68
+ nn.ConvTranspose2d(
69
+ in_channels=out_channels[1],
70
+ out_channels=out_channels[1],
71
+ kernel_size=2,
72
+ stride=2,
73
+ padding=0),
74
+ nn.Identity(),
75
+ nn.Conv2d(
76
+ in_channels=out_channels[3],
77
+ out_channels=out_channels[3],
78
+ kernel_size=3,
79
+ stride=2,
80
+ padding=1)
81
+ ])
82
+
83
+ if use_clstoken:
84
+ self.readout_projects = nn.ModuleList()
85
+ for _ in range(len(self.projects)):
86
+ self.readout_projects.append(
87
+ nn.Sequential(
88
+ nn.Linear(2 * in_channels, in_channels),
89
+ nn.GELU()))
90
+
91
+ self.scratch = _make_scratch(
92
+ out_channels,
93
+ features,
94
+ groups=1,
95
+ expand=False,
96
+ )
97
+
98
+ self.scratch.stem_transpose = None
99
+
100
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
101
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
102
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
103
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
104
+
105
+ head_features_1 = features
106
+ head_features_2 = 32
107
+
108
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
109
+ self.scratch.output_conv2 = nn.Sequential(
110
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
111
+ nn.ReLU(True),
112
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
113
+ nn.ReLU(True),
114
+ nn.Identity(),
115
+ )
116
+
117
+ def forward(self, out_features, patch_h, patch_w):
118
+ out = []
119
+ for i, x in enumerate(out_features):
120
+ if self.use_clstoken:
121
+ x, cls_token = x[0], x[1]
122
+ readout = cls_token.unsqueeze(1).expand_as(x)
123
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
124
+ else:
125
+ x = x[0]
126
+
127
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
128
+
129
+ x = self.projects[i](x)
130
+ x = self.resize_layers[i](x)
131
+
132
+ out.append(x)
133
+
134
+ layer_1, layer_2, layer_3, layer_4 = out
135
+
136
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
137
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
138
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
139
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
140
+
141
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
142
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
143
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
144
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
145
+
146
+ out = self.scratch.output_conv1(path_1)
147
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
148
+ out = self.scratch.output_conv2(out)
149
+
150
+ return out
151
+
152
+
153
+ class DepthAnythingV2(nn.Module):
154
+ def __init__(
155
+ self,
156
+ encoder='vitl',
157
+ features=256,
158
+ out_channels=[256, 512, 1024, 1024],
159
+ use_bn=False,
160
+ use_clstoken=False
161
+ ):
162
+ super(DepthAnythingV2, self).__init__()
163
+
164
+ self.intermediate_layer_idx = {
165
+ 'vits': [2, 5, 8, 11],
166
+ 'vitb': [2, 5, 8, 11],
167
+ 'vitl': [4, 11, 17, 23],
168
+ 'vitg': [9, 19, 29, 39]
169
+ }
170
+
171
+ self.encoder = encoder
172
+ self.pretrained = DINOv2(model_name=encoder)
173
+
174
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
175
+
176
+ def forward(self, x):
177
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
178
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
179
+
180
+ return features
181
+
182
+ @torch.no_grad()
183
+ def infer_image(self, raw_image, input_size=336, is_dsg=False):
184
+ image, (h, w) = self.image2tensor(raw_image, input_size)
185
+
186
+ features = self.forward(image)
187
+ if is_dsg:
188
+ return features
189
+ # feats = torch.cat([f[0] for f in features], dim=2)
190
+ feats = features[-1][0]
191
+
192
+ return feats
193
+
194
+ def image2tensor(self, raw_image, input_size=518):
195
+ transform = Compose([
196
+ Resize(
197
+ width=input_size,
198
+ height=input_size,
199
+ resize_target=False,
200
+ keep_aspect_ratio=True,
201
+ ensure_multiple_of=14,
202
+ resize_method='lower_bound',
203
+ image_interpolation_method=cv2.INTER_CUBIC,
204
+ ),
205
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
206
+ PrepareForNet(),
207
+ ])
208
+
209
+ h, w = raw_image.shape[:2]
210
+
211
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
212
+
213
+ image = transform({'image': image})['image']
214
+ image = torch.from_numpy(image).unsqueeze(0)
215
+
216
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
217
+ image = image.to(DEVICE)
218
+
219
+ return image, (h, w)
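Editor's note: in this repository's copy, `DepthAnythingV2.forward` returns the DINOv2 intermediate features and never runs the DPT head, since the model serves as a feature target rather than a depth predictor. A usage sketch with randomly initialised weights; the image path is hypothetical, and a checkpoint would need to be loaded via `load_state_dict` for meaningful features.

```python
import cv2
import torch

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = DepthAnythingV2(encoder="vitl").to(device).eval()

raw = cv2.imread("example.jpg")                  # hypothetical BGR uint8 image
feats = model.infer_image(raw, input_size=336)   # last-level DINOv2 patch tokens
print(feats.shape)                               # (1, num_patches, 1024) for vitl
```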
ola_vlm/model/aux_heads/depth_anything_v2/util/blocks.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape * 2
16
+ out_shape3 = out_shape * 4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape * 8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23
+ if len(in_shape) >= 4:
24
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25
+
26
+ return scratch
27
+
28
+
29
+ class ResidualConvUnit(nn.Module):
30
+ """Residual convolution module.
31
+ """
32
+
33
+ def __init__(self, features, activation, bn):
34
+ """Init.
35
+
36
+ Args:
37
+ features (int): number of features
38
+ """
39
+ super().__init__()
40
+
41
+ self.bn = bn
42
+
43
+ self.groups=1
44
+
45
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46
+
47
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48
+
49
+ if self.bn == True:
50
+ self.bn1 = nn.BatchNorm2d(features)
51
+ self.bn2 = nn.BatchNorm2d(features)
52
+
53
+ self.activation = activation
54
+
55
+ self.skip_add = nn.quantized.FloatFunctional()
56
+
57
+ def forward(self, x):
58
+ """Forward pass.
59
+
60
+ Args:
61
+ x (tensor): input
62
+
63
+ Returns:
64
+ tensor: output
65
+ """
66
+
67
+ out = self.activation(x)
68
+ out = self.conv1(out)
69
+ if self.bn == True:
70
+ out = self.bn1(out)
71
+
72
+ out = self.activation(out)
73
+ out = self.conv2(out)
74
+ if self.bn == True:
75
+ out = self.bn2(out)
76
+
77
+ if self.groups > 1:
78
+ out = self.conv_merge(out)
79
+
80
+ return self.skip_add.add(out, x)
81
+
82
+
83
+ class FeatureFusionBlock(nn.Module):
84
+ """Feature fusion block.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ features,
90
+ activation,
91
+ deconv=False,
92
+ bn=False,
93
+ expand=False,
94
+ align_corners=True,
95
+ size=None
96
+ ):
97
+ """Init.
98
+
99
+ Args:
100
+ features (int): number of features
101
+ """
102
+ super(FeatureFusionBlock, self).__init__()
103
+
104
+ self.deconv = deconv
105
+ self.align_corners = align_corners
106
+
107
+ self.groups=1
108
+
109
+ self.expand = expand
110
+ out_features = features
111
+ if self.expand == True:
112
+ out_features = features // 2
113
+
114
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115
+
116
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118
+
119
+ self.skip_add = nn.quantized.FloatFunctional()
120
+
121
+ self.size=size
122
+
123
+ def forward(self, *xs, size=None):
124
+ """Forward pass.
125
+
126
+ Returns:
127
+ tensor: output
128
+ """
129
+ output = xs[0]
130
+
131
+ if len(xs) == 2:
132
+ res = self.resConfUnit1(xs[1])
133
+ output = self.skip_add.add(output, res)
134
+
135
+ output = self.resConfUnit2(output)
136
+
137
+ if (size is None) and (self.size is None):
138
+ modifier = {"scale_factor": 2}
139
+ elif size is None:
140
+ modifier = {"size": self.size}
141
+ else:
142
+ modifier = {"size": size}
143
+
144
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145
+
146
+ output = self.out_conv(output)
147
+
148
+ return output
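Editor's note: a shape sketch of the fusion path (assuming this file's classes are in scope). `_make_scratch` builds the per-level 3x3 reduction convs, and `FeatureFusionBlock` merges a lateral feature into the decoder path and upsamples to `size`.

```python
import torch
import torch.nn as nn

scratch = _make_scratch([256, 512, 1024, 1024], 256)   # builds layer1_rn .. layer4_rn 3x3 convs
fuse = FeatureFusionBlock(256, nn.ReLU(False), size=(64, 64))
decoder = torch.randn(1, 256, 32, 32)    # coarser decoder path
lateral = torch.randn(1, 256, 32, 32)    # skip feature at the same resolution
out = fuse(decoder, lateral)             # fused, then upsampled to `size`
assert out.shape == (1, 256, 64, 64)
```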
ola_vlm/model/aux_heads/depth_anything_v2/util/transform.py ADDED
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+
5
+ class Resize(object):
6
+ """Resize sample to given size (width, height).
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ width,
12
+ height,
13
+ resize_target=True,
14
+ keep_aspect_ratio=False,
15
+ ensure_multiple_of=1,
16
+ resize_method="lower_bound",
17
+ image_interpolation_method=cv2.INTER_AREA,
18
+ ):
19
+ """Init.
20
+
21
+ Args:
22
+ width (int): desired output width
23
+ height (int): desired output height
24
+ resize_target (bool, optional):
25
+ True: Resize the full sample (image, mask, target).
26
+ False: Resize image only.
27
+ Defaults to True.
28
+ keep_aspect_ratio (bool, optional):
29
+ True: Keep the aspect ratio of the input sample.
30
+ Output sample might not have the given width and height, and
31
+ resize behaviour depends on the parameter 'resize_method'.
32
+ Defaults to False.
33
+ ensure_multiple_of (int, optional):
34
+ Output width and height is constrained to be multiple of this parameter.
35
+ Defaults to 1.
36
+ resize_method (str, optional):
37
+ "lower_bound": Output will be at least as large as the given size.
38
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
39
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
40
+ Defaults to "lower_bound".
41
+ """
42
+ self.__width = width
43
+ self.__height = height
44
+
45
+ self.__resize_target = resize_target
46
+ self.__keep_aspect_ratio = keep_aspect_ratio
47
+ self.__multiple_of = ensure_multiple_of
48
+ self.__resize_method = resize_method
49
+ self.__image_interpolation_method = image_interpolation_method
50
+
51
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53
+
54
+ if max_val is not None and y > max_val:
55
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if y < min_val:
58
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ return y
61
+
62
+ def get_size(self, width, height):
63
+ # determine new height and width
64
+ scale_height = self.__height / height
65
+ scale_width = self.__width / width
66
+
67
+ if self.__keep_aspect_ratio:
68
+ if self.__resize_method == "lower_bound":
69
+ # scale such that output size is lower bound
70
+ if scale_width > scale_height:
71
+ # fit width
72
+ scale_height = scale_width
73
+ else:
74
+ # fit height
75
+ scale_width = scale_height
76
+ elif self.__resize_method == "upper_bound":
77
+ # scale such that output size is upper bound
78
+ if scale_width < scale_height:
79
+ # fit width
80
+ scale_height = scale_width
81
+ else:
82
+ # fit height
83
+ scale_width = scale_height
84
+ elif self.__resize_method == "minimal":
85
+ # scale as little as possible
86
+ if abs(1 - scale_width) < abs(1 - scale_height):
87
+ # fit width
88
+ scale_height = scale_width
89
+ else:
90
+ # fit height
91
+ scale_width = scale_height
92
+ else:
93
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
94
+
95
+ if self.__resize_method == "lower_bound":
96
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98
+ elif self.__resize_method == "upper_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101
+ elif self.__resize_method == "minimal":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width)
104
+ else:
105
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
106
+
107
+ return (new_width, new_height)
108
+
109
+ def __call__(self, sample):
110
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111
+
112
+ # resize sample
113
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114
+
115
+ if self.__resize_target:
116
+ if "depth" in sample:
117
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118
+
119
+ if "mask" in sample:
120
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121
+
122
+ return sample
123
+
124
+
125
+ class NormalizeImage(object):
126
+ """Normlize image by given mean and std.
127
+ """
128
+
129
+ def __init__(self, mean, std):
130
+ self.__mean = mean
131
+ self.__std = std
132
+
133
+ def __call__(self, sample):
134
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
135
+
136
+ return sample
137
+
138
+
139
+ class PrepareForNet(object):
140
+ """Prepare sample for usage as network input.
141
+ """
142
+
143
+ def __init__(self):
144
+ pass
145
+
146
+ def __call__(self, sample):
147
+ image = np.transpose(sample["image"], (2, 0, 1))
148
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149
+
150
+ if "depth" in sample:
151
+ depth = sample["depth"].astype(np.float32)
152
+ sample["depth"] = np.ascontiguousarray(depth)
153
+
154
+ if "mask" in sample:
155
+ sample["mask"] = sample["mask"].astype(np.float32)
156
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
157
+
158
+ return sample
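Editor's note: these three transforms are composed exactly as in `DepthAnythingV2.image2tensor`; a standalone sketch with a synthetic image (the input must already be RGB with values in [0, 1]):

```python
import cv2
import numpy as np
from torchvision.transforms import Compose

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method="lower_bound",
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

rgb = np.random.rand(480, 640, 3).astype(np.float32)   # stand-in for an RGB image in [0, 1]
out = transform({"image": rgb})["image"]
print(out.shape)   # (3, 518, 686): each side >= 518, a multiple of 14, aspect ratio kept
```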