# CLIPvsDINOv2 / app.py
import gradio as gr
import torch
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load models and processors
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
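
# Note: both checkpoints load in eval mode by default. CLIP exposes a projected
# image embedding via get_image_features; the DINOv2 base model has no such
# projection head, so its patch tokens are mean-pooled below instead.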

def get_image_embedding(image, model, processor, model_type):
    if isinstance(image, str):  # Handle file path input
        image = Image.open(image)
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embedding = model.get_image_features(**inputs)
        elif model_type == "dinov2":
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1)  # Global average pooling over patch tokens
        else:
            raise ValueError(f"Unknown model_type: {model_type}")
    embedding /= embedding.norm(dim=-1, keepdim=True)  # L2-normalize so dot products are cosine similarities
    return embedding
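
# Illustrative usage (not executed by the app): since the embeddings are
# L2-normalized, the dot product of two of them is their cosine similarity.
# The file names below are hypothetical placeholders.
#   emb_a = get_image_embedding("example_a.jpg", clip_model, clip_processor, "clip")
#   emb_b = get_image_embedding("example_b.jpg", clip_model, clip_processor, "clip")
#   similarity = (emb_a @ emb_b.T).item()  # roughly in [-1, 1]; higher means more similar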

def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    query_embedding = get_image_embedding(query_img, model, processor, model_type)
    gallery_embeddings = []
    for img in gallery_imgs:
        emb = get_image_embedding(img, model, processor, model_type)
        gallery_embeddings.append((emb, img))
    rank_list = []
    for emb, img in gallery_embeddings:
        # Embeddings are normalized, so this dot product is the cosine similarity
        similarity_score = (query_embedding @ emb.T).item()
        rank_list.append((similarity_score, img))
    rank_list = sorted(rank_list, key=lambda x: x[0], reverse=True)[:top_k]
    return [img for _, img in rank_list]

def display_results(query_img, gallery_imgs, top_k):
    clip_results = retrieve_images(query_img, gallery_imgs, clip_model, clip_processor, "clip", top_k)
    dino_results = retrieve_images(query_img, gallery_imgs, dino_model, dino_processor, "dinov2", top_k)
    return [query_img] + clip_results, [query_img] + dino_results

def gradio_interface(query_img, gallery_imgs, top_k):
    if not isinstance(gallery_imgs, list):
        gallery_imgs = [gallery_imgs]
    # gr.File may return tempfile wrappers; use their .name attribute to get the file path
    gallery_imgs = [img.name if hasattr(img, 'name') else img for img in gallery_imgs]
    clip_res, dino_res = display_results(query_img, gallery_imgs, top_k)
    return clip_res, dino_res

# Build example galleries from the bundled dataset
gallery_path = "dataset/gallery"
filenames = os.listdir(gallery_path)
flag_filenames = [filename for filename in filenames if "flag" in filename]
tattoo_filenames = [filename for filename in filenames if "tattoo" in filename]
gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]
query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]
print(gallery_examples_flags)
print(gallery_examples_tattoos)

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
    ],
    outputs=[
        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=[1], columns=[30]),
        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=[1], columns=[30]),
    ],
    title="CLIP vs DINOv2 Image Retrieval",
    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
    examples=[[query_examples[1], gallery_examples_tattoos, 10]],
    css="""
    #gallery-files {
        max-height: 150px;
        overflow-y: scroll;
    }
    #clip-results, #dino-results {
        max-height: 150px;
    }
    """,
)
demo.launch(share=True)
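# share=True asks Gradio to create a temporary public link in addition to the
# local server (by default at http://127.0.0.1:7860 when run with `python app.py`).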