Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import os | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel | |
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load both backbones and their processors once, at import time.
# The models are moved to `device`; the processors are used as loaded.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
dino_processor = AutoProcessor.from_pretrained("facebook/dinov2-base")
def get_image_embedding(image, model, processor, model_type):
    """Return an L2-normalised embedding for one image.

    Args:
        image: A PIL image, or a filesystem path (``str``) to an image file.
        model: The loaded model that matches ``model_type``.
        processor: The matching preprocessor; it is called as
            ``processor(images=image, return_tensors="pt")``.
        model_type: ``"clip"`` uses ``model.get_image_features``;
            ``"dinov2"`` mean-pools ``last_hidden_state`` over dim 1.

    Returns:
        A torch tensor on ``device`` whose last dimension has unit L2 norm.

    Raises:
        ValueError: If ``model_type`` is not ``"clip"`` or ``"dinov2"``.
    """
    # Fail fast with a clear message. Previously an unknown type fell
    # through both branches and raised UnboundLocalError on `embedding`.
    if model_type not in ("clip", "dinov2"):
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    if isinstance(image, str):  # Handle file input
        # Decode eagerly and close the file handle right away (Image.open is
        # lazy and would otherwise keep the file open). Converting to RGB
        # gives the processor 3-channel input for RGBA/greyscale/palette files.
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        if model_type == "clip":
            embedding = model.get_image_features(**inputs)
        else:  # "dinov2"
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1)  # Global pooling

    # Normalise so that a dot product between two embeddings is cosine similarity.
    return embedding / embedding.norm(dim=-1, keepdim=True)
def retrieve_images(query_img, gallery_imgs, model, processor, model_type, top_k=10):
    """Rank gallery images by similarity to a query and return the best ones.

    Each image is embedded with ``get_image_embedding``, which L2-normalises
    its output, so the dot-product score used here is cosine similarity.

    Args:
        query_img: The query image (PIL image or file path).
        gallery_imgs: Iterable of gallery images (PIL images or file paths).
        model, processor, model_type: Passed through to ``get_image_embedding``.
        top_k: Maximum number of results. It is coerced to ``int`` because UI
            sliders may deliver it as a float, and negative values are
            treated as 0.

    Returns:
        A list of at most ``top_k`` gallery items (the original objects),
        ordered from most to least similar. Ties keep their gallery order.
    """
    top_k = max(int(top_k), 0)
    gallery_imgs = list(gallery_imgs) if gallery_imgs is not None else []
    # Nothing to rank, so skip the (expensive) query embedding entirely.
    if top_k == 0 or not gallery_imgs:
        return []

    query_embedding = get_image_embedding(query_img, model, processor, model_type)

    scored = []
    for img in gallery_imgs:
        emb = get_image_embedding(img, model, processor, model_type)
        scored.append(((query_embedding @ emb.T).item(), img))

    # list.sort is stable, so equal scores keep their gallery order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [img for _, img in scored[:top_k]]
def display_results(query_img, gallery_imgs, top_k):
    """Run the same retrieval with CLIP and then DINOv2.

    Returns:
        A ``(clip_gallery, dino_gallery)`` pair. Each gallery is a list that
        starts with the query image, followed by that model's top-k matches.
    """
    def _gallery_for(model, processor, model_type):
        # Prepend the query so it appears first in each output gallery.
        matches = retrieve_images(
            query_img, gallery_imgs, model, processor, model_type, top_k
        )
        return [query_img, *matches]

    clip_gallery = _gallery_for(clip_model, clip_processor, "clip")
    dino_gallery = _gallery_for(dino_model, dino_processor, "dinov2")
    return clip_gallery, dino_gallery
def gradio_interface(query_img, gallery_imgs, top_k):
    """Gradio callback: normalise the UI inputs and run both retrievals.

    Args:
        query_img: The query image from the ``gr.Image`` input, or ``None``
            when nothing was uploaded.
        gallery_imgs: Value of the multi-file ``gr.File`` input. It may be
            ``None``, a single item, or a list/tuple of items. Items that
            have a ``.name`` attribute are replaced by that name (the path).
        top_k: Number of matches to return per model.

    Returns:
        ``(clip_results, dino_results)`` galleries. Both are empty when no
        query image was provided.
    """
    # Without a query there is nothing to compare against. Return empty
    # galleries instead of crashing inside the image processor.
    if query_img is None:
        return [], []

    if gallery_imgs is None:
        gallery_imgs = []  # No files uploaded; [None] would crash downstream.
    elif isinstance(gallery_imgs, tuple):
        gallery_imgs = list(gallery_imgs)
    elif not isinstance(gallery_imgs, list):
        gallery_imgs = [gallery_imgs]

    # Handle NamedString / tempfile wrappers: use their path via `.name`.
    gallery_paths = [getattr(img, "name", img) for img in gallery_imgs]
    clip_res, dino_res = display_results(query_img, gallery_paths, top_k)
    return clip_res, dino_res
import copy

# Example data bundled with the Space: one gallery directory plus query images.
gallery_path = "dataset/gallery"
# os.listdir returns entries in arbitrary order; sort for stable examples.
filenames = sorted(os.listdir(gallery_path))
# Bucket gallery files by a keyword in each file's name.
# (Bug fix: this previously tested `"flag" in filenames` -- the whole list --
# so the flag bucket was almost always empty.)
flag_filenames = [filename for filename in filenames if "flag" in filename]
tattoo_filenames = [filename for filename in filenames if "tattoo" in filename]
gallery_examples_flags = [os.path.join(gallery_path, filename) for filename in flag_filenames]
gallery_examples_tattoos = [os.path.join(gallery_path, filename) for filename in tattoo_filenames]
query_examples = ["dataset/query/american_flag46.jpg", "dataset/query/bird.jpg"]
print(gallery_examples_flags)
print(gallery_examples_tattoos)

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Query Image"),
        gr.File(file_types=["image"], label="Gallery Images", file_count="multiple", elem_id="gallery-files"),
        gr.Slider(1, 30, value=10, step=1, label="Top-K Matches"),
    ],
    outputs=[
        gr.Gallery(label="CLIP Retrieval Results", elem_id="clip-results", rows=[1], columns=[30]),
        gr.Gallery(label="DINOv2 Retrieval Results", elem_id="dino-results", rows=[1], columns=[30]),
    ],
    title="CLIP vs DINOv2 Image Retrieval",
    description="Upload a query image and gallery images to see the top-k retrieval results side by side using CLIP and DINOv2.",
    examples=[[query_examples[1], gallery_examples_tattoos, 10]],
    # Keep the file list and the result galleries compact.
    css="""
#gallery-files {
    max-height: 150px;
    overflow-y: scroll;
}
#clip-results, #dino-results {
    max-height: 150px;
}
""",
)
demo.launch(share=True)