# moondream2 / app.py
# Last commit: eccdcf9 "Update app.py" (author: sessex, verified)
import spaces
import torch
import re
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
import subprocess
import os
import sys

# Install FlashAttention at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is
# passed to flash-attn's installer (presumably so it uses a prebuilt wheel
# instead of compiling CUDA kernels -- confirm against flash-attn's setup).
# The flag is merged INTO the current environment: passing a dict with only
# this key would replace the environment wholesale and drop PATH, HOME, etc.
# Running pip via sys.executable installs into the interpreter running this
# app, and an argv list avoids going through a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

model_id = "vikhyatk/moondream2"
# Tokenizer and weights are pinned to the same dated revision. Because
# trust_remote_code executes code fetched from the Hub, pinning also keeps
# that code from changing underneath the app.
revision = "2024-05-20"
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
moondream = AutoModelForCausalLM.from_pretrained(
    model_id,
    # Needed because encode_image / answer_question (used below) are not
    # standard transformers methods; they come from the model repo's code.
    trust_remote_code=True,
    revision=revision,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},  # place the whole model on the GPU
    attn_implementation="flash_attention_2",  # relies on the install above
)
# Inference only: switch off training-mode behaviour such as dropout.
moondream.eval()
@spaces.GPU(duration=20)
def answer_question(img, prompt):
    """Answer a free-form question about an image with moondream2.

    Generation runs on a worker thread that writes into a
    ``TextIteratorStreamer``; this function drains the streamer and returns
    the concatenated text once generation is done.

    Args:
        img: The uploaded image (a PIL image, given the ``gr.Image(type="pil")``
            input this is wired to), or ``None`` when nothing was uploaded.
        prompt: The question / instruction for the model.

    Returns:
        The model's full answer with surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was provided.
        Exception: Any error raised by the model during generation is
            re-raised here instead of leaving the caller blocked.
    """
    if img is None:
        # encode_image(None) would fail with an opaque error; show a
        # readable message in the UI instead.
        raise gr.Error("Please upload an image.")

    image_embeds = moondream.encode_image(img)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    errors = []

    def _generate():
        # Runs on the worker thread. If generation raises, the streamer never
        # receives its end-of-stream signal and the join below would block
        # forever, so close the stream and hand the error back to the caller.
        try:
            moondream.answer_question(
                image_embeds=image_embeds,
                question=prompt,
                tokenizer=tokenizer,
                streamer=streamer,
            )
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    # Iterating the streamer yields decoded text chunks until end-of-stream.
    answer = "".join(streamer)
    # Wait for the worker thread to finish.
    thread.join()
    if errors:
        raise errors[0]
    return answer.strip()
# def extract_floats(text):
# # Regular expression to match an array of four floating point numbers
# pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"
# match = re.search(pattern, text)
# if match:
# # Extract the numbers and convert them to floats
# return [float(num) for num in match.groups()]
# return None # Return None if no match is found
# def extract_bbox(text):
# bbox = None
# if extract_floats(text) is not None:
# x1, y1, x2, y2 = extract_floats(text)
# bbox = (x1, y1, x2, y2)
# return bbox
# def process_answer(img, answer):
# if extract_bbox(answer) is not None:
# x1, y1, x2, y2 = extract_bbox(answer)
# draw_image = Resize(768)(img)
# width, height = draw_image.size
# x1, x2 = int(x1 * width), int(x2 * width)
# y1, y2 = int(y1 * height), int(y2 * height)
# bbox = (x1, y1, x2, y2)
# ImageDraw.Draw(draw_image).rectangle(bbox, outline="red", width=3)
# return gr.update(visible=True, value=draw_image)
# return gr.update(visible=False, value=None)
with gr.Blocks() as demo:
    # Header text rendered as Markdown at the top of the page.
    gr.Markdown(
        "\n"
        "# 🌔 moondream2\n"
        "A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream)\n"
    )

    # Top row: the question box (wide) next to the submit button.
    with gr.Row():
        prompt = gr.Textbox(label="Input", value="Describe this image.", scale=4)
        submit = gr.Button("Submit")

    # Second row: image upload on the left, results stacked on the right.
    with gr.Row():
        img = gr.Image(type="pil", label="Upload an Image")
        with gr.Column():
            output = gr.Text(label="Response")
            # Hidden placeholder for a bounding-box preview; nothing updates
            # it while the output.change wiring below stays disabled.
            ann = gr.Image(visible=False, label="Annotated Image")

    # Clicking the button and pressing Enter in the textbox both run the
    # same handler with the same inputs and output.
    for register in (submit.click, prompt.submit):
        register(answer_question, [img, prompt], output)
    # output.change(process_answer, [img, output], ann, show_progress=False)

demo.queue().launch(debug=True, show_error=True)