import random
import re
from dataclasses import dataclass

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from PIL import ImageDraw
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
# --- Configuration ---
@dataclass
class Configuration:
    dataset_id: str = "ariG23498/license-detection-paligemma"
    model_id: str = "google/gemma-3-4b-pt"
    checkpoint_id: str = "ariG23498/gemma-3-4b-pt-object-detection"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16
    batch_size: int = 4
    learning_rate: float = 2e-5
    epochs: int = 1
# --- Utils ---
def parse_paligemma_label(label, width, height):
    # Extract the <locXXXX> location codes (each a normalized value in [0, 1023])
    loc_pattern = r"<loc(\d{4})>"
    locations = [int(loc) for loc in re.findall(loc_pattern, label)]
    # Extract the category (everything after the last location code)
    category = label.split(">")[-1].strip()
    # PaliGemma orders coordinates as y1, x1, y2, x2;
    # keep only the first box if the model emits several
    y1_norm, x1_norm, y2_norm, x2_norm = locations[:4]
    # Convert normalized coordinates to pixel coordinates
    x1 = (x1_norm / 1024) * width
    y1 = (y1_norm / 1024) * height
    x2 = (x2_norm / 1024) * width
    y2 = (y2_norm / 1024) * height
    return category, [x1, y1, x2, y2]
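
# Illustrative example (hypothetical label; values follow the arithmetic above):
#   parse_paligemma_label("<loc0256><loc0128><loc0768><loc0896> plate", 640, 480)
#   -> ("plate", [80.0, 120.0, 560.0, 360.0])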
def visualize_bounding_boxes(image, label, width, height):
    # Draw on a copy so the original image stays untouched
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)
    category, bbox = parse_paligemma_label(label, width, height)
    draw.rectangle(bbox, outline="red", width=2)
    draw.text((bbox[0], max(0, bbox[1] - 10)), category, fill="red")
    return draw_image
def test_collate_function(batch_of_samples, processor, dtype):
    images = []
    prompts = []
    for sample in batch_of_samples:
        images.append([sample["image"]])
        # The prompt mirrors the format used when fine-tuning the checkpoint
        prompts.append(f"{processor.tokenizer.boi_token} detect \n\n")
    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
    batch["pixel_values"] = batch["pixel_values"].to(dtype)
    return batch, images
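
# Batched-evaluation sketch (hypothetical usage; the demo below only ever
# processes one sample at a time):
#   from functools import partial
#   from torch.utils.data import DataLoader
#   loader = DataLoader(
#       test_dataset,
#       batch_size=cfg.batch_size,
#       collate_fn=partial(test_collate_function, processor=processor, dtype=cfg.dtype),
#   )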
# --- Initialize ---
cfg = Configuration()
processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)
# Load on CPU; run_prediction() moves the model to the GPU per request
model = Gemma3ForConditionalGeneration.from_pretrained(
    cfg.checkpoint_id,
    torch_dtype=cfg.dtype,
    device_map="cpu",
)
model.eval()
test_dataset = load_dataset(cfg.dataset_id, split="test")
def get_sample():
    # Reuse the collate function on a single randomly chosen test sample
    sample = random.choice(test_dataset)
    batch, _ = test_collate_function([sample], processor, cfg.dtype)
    return batch, sample["image"]
# --- Prediction Logic ---
@spaces.GPU
def run_prediction():
    model.to(cfg.device)
    batch, raw_image = get_sample()
    batch = {k: v.to(cfg.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
    with torch.no_grad():
        generation = model.generate(**batch, max_new_tokens=100)
    # The decoded string still contains the prompt; parse_paligemma_label only
    # reads the <loc> codes and the trailing category, so that is harmless.
    decoded = processor.batch_decode(generation, skip_special_tokens=True)[0]
    # raw_image is already a PIL.Image, so it can be drawn on directly
    width, height = raw_image.size
    return visualize_bounding_boxes(raw_image, decoded, width, height)
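
# Local smoke test (hypothetical; bypasses Gradio and assumes a GPU is available):
#   image = run_prediction()
#   image.save("prediction.png")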
# --- Gradio Interface ---
demo = gr.Interface(
    fn=run_prediction,
    inputs=[],
    outputs=gr.Image(type="pil", label="Detected Bounding Box"),
    title="Gemma3 Object Detector",
    description="Click 'Generate' to visualize a prediction from a randomly sampled test image.",
)
if __name__ == "__main__":
    demo.launch()