paligemma / app.py
taufiqdp's picture
Update app.py
dde305b verified
raw
history blame
No virus
1.53 kB
import os
import torch
import spaces
import gradio as gr
from huggingface_hub import login
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
# Authenticate with the Hugging Face Hub only when a token is configured.
# login(None) falls back to an interactive token prompt, which cannot be
# answered when the app runs headless (e.g. inside a Space).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

model_id = "google/paligemma-3b-mix-448"

# device_map={"": 0} places every module of the model on GPU 0;
# weights are loaded in bfloat16.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    device_map={"": 0},
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(model_id)

# Inference only: put the model in evaluation mode (disables dropout etc.).
model.eval()
@spaces.GPU()
def answer_question(image, prompt):
    """Run PaliGemma on one image/prompt pair and return the generated text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user has not uploaded anything.
        prompt: Text prompt from the input textbox.

    Returns:
        The decoded continuation only (the prompt tokens are sliced off),
        with special tokens removed.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a processor traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Put the inputs on the same device the model weights live on.
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

    # generate() returns prompt + continuation; remember the prompt length
    # so only the newly generated tokens are decoded.
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Greedy decoding (do_sample=False), capped at 100 new tokens.
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)

    generation = generation[0][input_len:]
    return processor.decode(generation, skip_special_tokens=True)
with gr.Blocks() as demo:
    # Page header with a link to the model card.
    gr.Markdown(
        "\n"
        "# PaliGemma\n"
        "Lightweight open vision-language model (VLM). "
        "[Model card](https://huggingface.co/google/paligemma-3b-mix-448)\n"
    )

    # Top row: a wide prompt box next to the submit button.
    with gr.Row():
        prompt = gr.Textbox(label="Input", value="Describe this image.", scale=4)
        submit = gr.Button("Submit")

    # Bottom row: the image input beside the model's response.
    with gr.Row():
        image = gr.Image(type="pil", label="Upload an Image")
        output = gr.TextArea(label="Response")

    # Clicking the button and pressing Enter in the prompt box both run
    # the model with the same inputs and write to the same output.
    for trigger in (submit.click, prompt.submit):
        trigger(answer_question, [image, prompt], output)

# Enable the request queue, then start the web server.
demo.queue().launch()