File size: 2,716 Bytes
45f9b3d 44157d6 0dc6935 cf83b3d ef13ec4 44157d6 cf83b3d ef13ec4 3373ce1 cf83b3d 44157d6 e0a390e 44157d6 e0a390e 44157d6 fa73fe7 cf83b3d e0a390e cf83b3d fa73fe7 44157d6 cf83b3d 3373ce1 44157d6 cf83b3d 44157d6 cf83b3d 44157d6 cf83b3d 44157d6 cf83b3d 44157d6 cf83b3d 44157d6 3373ce1 44157d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import gradio as gr
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import torch
import os
# Load the model and processor
# --- One-time model initialization (runs at import time) ---
# PaliGemma is a gated checkpoint on the Hugging Face Hub, so an access
# token must be provided via the HF_TOKEN environment variable.
model_id = "google/paligemma-3b-mix-224"
HF_TOKEN = os.getenv('HF_TOKEN')
# .eval() switches off dropout etc. for deterministic inference.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, token=HF_TOKEN).eval()
processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
def generate_caption(image, prompt="What is in this image?", max_tokens=100):
    """Generate a text description of *image* with PaliGemma.

    Args:
        image: PIL image to describe, or ``None`` when nothing was uploaded.
        prompt: Instruction/question placed after the ``<image>`` token.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded model output, or a user-facing message when no image
        was provided.
    """
    if image is None:
        return "Please upload an image."
    # Surface progress in the Gradio UI while the model runs.
    gr.Info("Analysis starting. This may take up to 119 seconds.")
    # Fix: normalize to RGB — RGBA/palette-mode uploads can break the processor.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # PaliGemma expects the special <image> token ahead of the text prompt.
    full_prompt = "<image> " + prompt
    model_inputs = processor(text=full_prompt, images=image, return_tensors="pt")
    # Fix: keep inputs on the same device as the model. This is a no-op on
    # CPU but required if the model is ever moved to GPU.
    model_inputs = model_inputs.to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    # Greedy decoding under inference_mode: no grad tracking, deterministic.
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=max_tokens, do_sample=False)
    # Drop the prompt tokens; decode only the newly generated portion.
    generation = generation[0][input_len:]
    return processor.decode(generation, skip_special_tokens=True)
# Load local example images
def load_local_images():
    """Collect the bundled example images shipped with the repository.

    Returns a list of PIL images for whichever of the known example files
    exist in the working directory; unreadable files are skipped.
    """
    candidates = ['image1.jpg', 'image2.jpg', 'image3.jpg']
    loaded = []
    for name in candidates:
        path = os.path.join('.', name)
        try:
            if os.path.exists(path):
                loaded.append(Image.open(path))
        except Exception as err:
            # Best-effort: a corrupt/unreadable file must not stop the app.
            print(f"Could not load {name}: {err}")
    return loaded
# Example images are resolved once at startup.
EXAMPLE_IMAGES = load_local_images()

# --- Gradio UI definition ---
with gr.Blocks() as demo:
    gr.Markdown("# PaliGemma Image Analysis")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload or Select Image")
            custom_prompt = gr.Textbox(label="Custom Prompt", value="What is in this image?")
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_text = gr.Textbox(label="Image Description")

    # Run the captioner when the button is pressed.
    submit_btn.click(
        fn=generate_caption,
        inputs=[input_image, custom_prompt],
        outputs=output_text,
    )

    # Clickable example rows: each pairs a bundled image with the default prompt.
    gr.Examples(
        examples=[[img, "What is in this image?"] for img in EXAMPLE_IMAGES],
        inputs=[input_image, custom_prompt],
        fn=generate_caption,
        outputs=output_text,
    )

# Launch the app only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()