Csplk's picture
Update app.py
ee5e19e verified
raw history blame
No virus
2.19 kB
import spaces
import torch
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
# Choose the compute target once at import time: half precision when a CUDA
# GPU is present, full precision otherwise (presumably because float16 is
# slow or unsupported for many CPU ops).
device, dtype = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)
# Hub repository and pinned revision; tokenizer and model are loaded from the
# same revision so the remote modelling code and weights stay in sync.
model_id = "vikhyatk/moondream2"
revision = "2024-04-02"

tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# moondream2 ships custom modelling code on the Hub, hence trust_remote_code.
moondream = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
    torch_dtype=dtype,
)
moondream = moondream.to(device=device)
# Inference only: switch off training-mode behaviour such as dropout.
moondream.eval()
@spaces.GPU(duration=10)
def answer_questions(image_tuples, prompt_text):
    """Answer one prompt per uploaded image with moondream2 in a single batch.

    Args:
        image_tuples: Value of the ``gr.Gallery`` component. Each entry is
            indexed with ``[0]`` to obtain the image. May be ``None`` when
            nothing was uploaded. (Assumes entries are ``(image, caption)``
            pairs as the Gallery emits them -- confirm against the installed
            Gradio version.)
        prompt_text: Comma-separated prompts, matched to images by position.

    Returns:
        A string with one ``"Q: ...\\nA: ...\\n\\n"`` section per
        image/prompt pair, or an error message when no images were given or
        the number of images and prompts differ.
    """
    print(f"prompt_text: {prompt_text}\n")
    # Split on commas and drop blank entries so an empty textbox or a
    # trailing comma does not create a phantom empty prompt.
    prompts = [p.strip() for p in (prompt_text or "").split(",") if p.strip()]
    print(f"prompts: {prompts}\n")

    # The Gallery value is None when empty; also skip entries without an image.
    images = [item[0] for item in (image_tuples or []) if item[0] is not None]

    # Without this guard, zero images + zero prompts would pass the length
    # check below and call the model with empty batches.
    if not images:
        return "Error: Please upload at least one image."
    if len(images) != len(prompts):
        return (
            "Error: The number of images input and prompts input "
            "(separate by commas in input text field) must be the same."
        )

    answers = moondream.batch_answer(
        images=images,
        prompts=prompts,
        tokenizer=tokenizer,
    )

    sections = []
    for question, answer in zip(prompts, answers):
        print(f"Q: {question}")
        print(f"A: {answer}")
        print()
        sections.append(f"Q: {question}\nA: {answer}\n\n")
    return "".join(sections)
# Gradio UI: a gallery of images plus one comma-separated prompt per image;
# answers are produced by answer_questions in a single batch.
with gr.Blocks() as demo:
    gr.Markdown("# moondream2 unofficial batch processing demo")
    # The emoji here had been mis-decoded into "πŸŒ”" (the UTF-8 bytes of
    # 🌔 read as a legacy codepage); restored to the real character.
    gr.Markdown("# 🌔 moondream2\nA tiny vision language model. [GitHub](https://github.com/vikhyatk/moondream)")
    with gr.Row():
        img = gr.Gallery(label="Upload Images", type="pil")
        prompt = gr.Textbox(label="Input Prompts", placeholder="Enter prompts (one prompt for each image provided) separated by commas. Ex: Describe this image, What is in this image?", lines=2)
    submit = gr.Button("Submit")
    output = gr.TextArea(label="Responses", lines=4)
    submit.click(answer_questions, [img, prompt], output)

demo.queue().launch()