import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, CLIPModel, CLIPProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models and processors
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
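# CLIP ViT-B/32 yields 512-dim image embeddings from get_image_features;
# DINOv2-base yields 768-dim token features, pooled into one vector below.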
def get_image_embedding(image, model, processor, model_type):
    if isinstance(image, str):  # Handle file-path input
        image = Image.open(image).convert("RGB")  # Guard against grayscale/RGBA files
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embedding = model.get_image_features(**inputs)
        elif model_type == "dinov2":
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1)  # Global pooling over tokens
    embedding = embedding / embedding.norm(dim=-1, keepdim=True)  # L2-normalize
    return embedding
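# Note: mean pooling over all patch tokens is one pooling choice for DINOv2.
# An alternative (a sketch, not used above) is the CLS token, which ViT-style
# models expose as the first token of last_hidden_state:
#
#     cls_embedding = outputs.last_hidden_state[:, 0]
#     cls_embedding = cls_embedding / cls_embedding.norm(dim=-1, keepdim=True)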
def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    query_embedding = get_image_embedding(query_img, model, processor, model_type)
    gallery_embeddings = []
    for img in gallery_imgs:
        emb = get_image_embedding(img, model, processor, model_type)
        gallery_embeddings.append((emb, img))
    # Embeddings are L2-normalized, so the dot product is cosine similarity
    rank_list = []
    for emb, img in gallery_embeddings:
        similarity_score = (query_embedding @ emb.T).item()
        rank_list.append((similarity_score, img))
    rank_list = sorted(rank_list, key=lambda x: x[0], reverse=True)[:top_k]
    return [img for _, img in rank_list]
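# Optional batched variant of gallery embedding (a sketch; get_batch_embeddings
# is a hypothetical helper, not wired into the UI below). Both processors accept
# a list of images, which avoids one forward pass per gallery image, assuming
# the whole gallery fits in a single batch on the device.
def get_batch_embeddings(images, model, processor, model_type):
    pil_images = [Image.open(img).convert("RGB") if isinstance(img, str) else img for img in images]
    inputs = processor(images=pil_images, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embeddings = model.get_image_features(**inputs)
        else:  # dinov2
            embeddings = model(**inputs).last_hidden_state.mean(dim=1)
    return embeddings / embeddings.norm(dim=-1, keepdim=True)  # L2-normalize rows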
def display_results(query_img, gallery_imgs, top_k):
    clip_results = retrieve_images(query_img, gallery_imgs, clip_model, clip_processor, "clip", top_k)
    dino_results = retrieve_images(query_img, gallery_imgs, dino_model, dino_processor, "dinov2", top_k)
    return [query_img] + clip_results, [query_img] + dino_results

def gradio_interface(query_img, gallery_imgs, top_k):
    if not isinstance(gallery_imgs, list):
        gallery_imgs = [gallery_imgs]
    gallery_imgs = [img.name if hasattr(img, 'name') else img for img in gallery_imgs]  # gr.File may return NamedString objects
    clip_res, dino_res = display_results(query_img, gallery_imgs, top_k)
    return clip_res, dino_res
gallery_path = "dataset/gallery"
filenames = os.listdir(gallery_path)
flag_filenames = [filename for filename in filenames if "flag" in filename]
tattoo_filenames = [filename for filename in filenames if "tattoo" in filename]
gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]
query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]
print(gallery_examples_flags)
print(gallery_examples_tattoos)
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
    ],
    outputs=[
        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=[1], columns=[30]),
        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=[1], columns=[30]),
    ],
    title="CLIP vs DINOv2 Image Retrieval",
    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
    examples=[[query_examples[1], gallery_examples_tattoos, 10]],
    css="""
    #gallery-files {
        max-height: 150px;
        overflow-y: scroll;
    }
    #clip-results, #dino-results {
        max-height: 150px;
    }
    """,
)

demo.launch(share=True)