Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration | |
from PIL import Image | |
import torch | |
# Load the PaliGemma "mix" checkpoint and its matching processor once, at
# module import time; process_image() below reads these module globals.
mix_model_id = "google/paligemma-3b-mix-224"
mix_model = PaliGemmaForConditionalGeneration.from_pretrained(
    mix_model_id,
)
mix_processor = AutoProcessor.from_pretrained(
    mix_model_id,
)
# Define inference function
def process_image(image, prompt):
    """Answer ``prompt`` about ``image`` with the PaliGemma mix model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.
        prompt: Text prompt from the Gradio textbox (may be empty/``None``).

    Returns:
        The newly generated text only (the prompt tokens are excluded), or a
        short message string when no image was given or an ``IndexError``
        occurred during processing.
    """
    if image is None:
        return "Please upload an image."
    prompt = prompt or ""
    try:
        # Pass text/images by keyword: the positional order of the
        # PaliGemma processor's __call__ has differed between transformers
        # releases, and swapping them silently corrupts the inputs.
        inputs = mix_processor(
            text=prompt, images=image.convert("RGB"), return_tensors="pt"
        ).to(mix_model.device)
        # generate() returns the prompt tokens followed by the new tokens.
        # Slicing the token ids at the prompt length is exact; slicing the
        # decoded string by len(prompt) leaves separator text behind.
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            output = mix_model.generate(**inputs, max_new_tokens=20)
        return mix_processor.decode(output[0][input_len:], skip_special_tokens=True)
    except IndexError as e:
        print(f"IndexError: {e}")
        return "An error occurred during processing."
# Define the Gradio interface: an uploaded image plus a free-text prompt in,
# the model's answer out.
inputs = [
    gr.Image(type="pil", label="Image"),
    gr.Textbox(label="Prompt", placeholder="Enter your question"),
]
outputs = gr.Textbox(label="Answer")

# Create the Gradio app
demo = gr.Interface(
    fn=process_image,
    inputs=inputs,
    outputs=outputs,
    title="Image Captioning with Mix PaliGemma Model",
    description="Upload an image and get captions based on your prompt.",
)

# Launch only when run as a script (e.g. `python app.py`), so importing the
# module, as Gradio's reload mode does, does not start a second server.
if __name__ == "__main__":
    demo.launch()