Spaces:
Running
on
Zero
Running
on
Zero
Jitesh Jain
committed on
Commit
·
20b4d0d
1
Parent(s):
297e5e9
:zap: Fix version
Browse files
README.md
CHANGED
|
@@ -4,7 +4,7 @@ emoji: 🔍
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 4.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
|
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.42.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
app.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import spaces
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
-
|
| 6 |
from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
|
| 7 |
|
| 8 |
from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
|
@@ -23,6 +22,14 @@ import math
|
|
| 23 |
from transformers import TextIteratorStreamer
|
| 24 |
from threading import Thread
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def make_grid(pil_images, layer_indices=None):
|
| 27 |
new_images = []
|
| 28 |
new_captions = []
|
|
@@ -242,48 +249,51 @@ def regenerate(state, image_process_mode):
|
|
| 242 |
state.skip_next = False
|
| 243 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
| 244 |
|
| 245 |
-
@spaces.GPU
|
| 246 |
-
def get_interm_outs(state):
|
| 247 |
-
|
| 248 |
-
images = state.get_images(return_pil=True)
|
| 249 |
-
#prompt, image_args = process_image(prompt, images)
|
| 250 |
-
|
| 251 |
-
if images is not None and len(images) > 0:
|
| 252 |
-
if len(images) > 0:
|
| 253 |
-
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
| 254 |
-
raise ValueError("Number of images does not match number of <image> tokens in prompt")
|
| 255 |
-
|
| 256 |
-
#images = [load_image_from_base64(image) for image in images]
|
| 257 |
-
image_sizes = [image.size for image in images]
|
| 258 |
-
inp_images = process_images(images, image_processor, model.config)
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
else:
|
| 263 |
-
inp_images =
|
|
|
|
|
|
|
| 264 |
else:
|
| 265 |
inp_images = None
|
| 266 |
-
|
| 267 |
-
image_args = {"images": inp_images, "image_sizes": image_sizes}
|
| 268 |
-
else:
|
| 269 |
-
inp_images = None
|
| 270 |
-
image_args = {}
|
| 271 |
|
| 272 |
-
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
|
| 283 |
-
|
| 284 |
|
| 285 |
-
@spaces.GPU
|
| 286 |
-
def generate(state, temperature, top_p, max_output_tokens):
|
| 287 |
prompt = state.get_prompt()
|
| 288 |
images = state.get_images(return_pil=True)
|
| 289 |
#prompt, image_args = process_image(prompt, images)
|
|
@@ -439,9 +449,9 @@ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as dem
|
|
| 439 |
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
| 440 |
|
| 441 |
inter_vis_btn.click(
|
| 442 |
-
|
| 443 |
[state],
|
| 444 |
-
[depth_box, seg_box, gen_box],
|
| 445 |
)
|
| 446 |
|
| 447 |
clear_btn.click(
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
+
import spaces
|
| 5 |
from ola_vlm.constants import DEFAULT_IMAGE_TOKEN
|
| 6 |
|
| 7 |
from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
|
|
|
| 22 |
from transformers import TextIteratorStreamer
|
| 23 |
from threading import Thread
|
| 24 |
|
| 25 |
+
import subprocess
|
| 26 |
+
# Install flash attention, skipping CUDA build if necessary
|
| 27 |
+
subprocess.run(
|
| 28 |
+
"pip install flash-attn --no-build-isolation",
|
| 29 |
+
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
|
| 30 |
+
shell=True,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
def make_grid(pil_images, layer_indices=None):
|
| 34 |
new_images = []
|
| 35 |
new_captions = []
|
|
|
|
| 249 |
state.skip_next = False
|
| 250 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
| 251 |
|
| 252 |
+
# @spaces.GPU
|
| 253 |
+
# def get_interm_outs(state):
|
| 254 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
+
@spaces.GPU
|
| 257 |
+
def generate(state, temperature, top_p, max_output_tokens, is_inter=False):
|
| 258 |
+
if is_inter:
|
| 259 |
+
prompt = state.get_prompt()
|
| 260 |
+
images = state.get_images(return_pil=True)
|
| 261 |
+
#prompt, image_args = process_image(prompt, images)
|
| 262 |
+
|
| 263 |
+
if images is not None and len(images) > 0:
|
| 264 |
+
if len(images) > 0:
|
| 265 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
| 266 |
+
raise ValueError("Number of images does not match number of <image> tokens in prompt")
|
| 267 |
+
|
| 268 |
+
#images = [load_image_from_base64(image) for image in images]
|
| 269 |
+
image_sizes = [image.size for image in images]
|
| 270 |
+
inp_images = process_images(images, image_processor, model.config)
|
| 271 |
+
|
| 272 |
+
if type(inp_images) is list:
|
| 273 |
+
inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
|
| 274 |
+
else:
|
| 275 |
+
inp_images = inp_images.to(model.device, dtype=torch.float16)
|
| 276 |
else:
|
| 277 |
+
inp_images = None
|
| 278 |
+
image_sizes = None
|
| 279 |
+
image_args = {"images": inp_images, "image_sizes": image_sizes}
|
| 280 |
else:
|
| 281 |
inp_images = None
|
| 282 |
+
image_args = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
| 285 |
|
| 286 |
+
interm_outs = model.get_visual_interpretations(
|
| 287 |
+
input_ids,
|
| 288 |
+
**image_args
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
depth_outs = get_depth_images(interm_outs, image_sizes[0])
|
| 292 |
+
seg_outs = get_seg_images(interm_outs, images[0])
|
| 293 |
+
gen_outs = get_gen_images(interm_outs)
|
| 294 |
|
| 295 |
+
return depth_outs, seg_outs, gen_outs
|
| 296 |
|
|
|
|
|
|
|
| 297 |
prompt = state.get_prompt()
|
| 298 |
images = state.get_images(return_pil=True)
|
| 299 |
#prompt, image_args = process_image(prompt, images)
|
|
|
|
| 449 |
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
| 450 |
|
| 451 |
inter_vis_btn.click(
|
| 452 |
+
generate,
|
| 453 |
[state],
|
| 454 |
+
[depth_box, seg_box, gen_box, True],
|
| 455 |
)
|
| 456 |
|
| 457 |
clear_btn.click(
|