|
""" |
|
CLIP Image Search Application |
|
|
|
A Gradio-based application for searching similar images using OpenAI's CLIP model. |
|
Supports multiple image formats and provides a web interface for uploading and searching images. |
|
""" |
|
|
|
import gradio as gr |
|
from transformers import CLIPProcessor, CLIPModel |
|
from PIL import Image |
|
import torch |
|
import pickle |
|
from pathlib import Path |
|
import os |
|
import spaces |
|
from typing import List, Dict, Tuple, Optional, Union |
|
|
|
|
|
# Hugging Face checkpoint id shared by the CLIP model and its matching processor.
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"

model: CLIPModel = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor: CLIPProcessor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# The model is only used for inference: put it in evaluation mode.
model.eval()
|
|
|
# Folder whose images are indexed and searched.
DATASET_DIR: Path = Path("dataset")

# Pickle file persisting the {file name: embedding} mapping between runs.
CACHE_FILE: str = "cache.pkl"

# Glob patterns (lowercase) for the supported image formats.
IMAGE_EXTENSIONS: List[str] = [
    f"*.{suffix}"
    for suffix in ("jpg", "jpeg", "png", "bmp", "gif", "webp", "tiff", "tif")
]
|
|
|
def get_all_image_files() -> List[Path]:
    """
    Get all image files from the dataset directory.

    A regular file counts as an image when its suffix matches one of
    IMAGE_EXTENSIONS, compared case-insensitively (``a.jpg``, ``a.JPG`` and
    ``a.Jpg`` all match). Each file is returned exactly once, also on
    case-insensitive file systems, and subdirectories are ignored.

    Returns:
        List[Path]: Image files directly inside DATASET_DIR, sorted by file
        name. Empty if DATASET_DIR does not exist (yet).
    """
    if not DATASET_DIR.is_dir():
        # e.g. first start-up before the folder was created: treat as empty.
        return []

    # "*.jpg" -> ".jpg". A single directory scan with a lowercased suffix check
    # replaces one glob per casing, which returned duplicates on
    # case-insensitive file systems and missed mixed-case suffixes.
    allowed_suffixes = {pattern.lstrip("*").lower() for pattern in IMAGE_EXTENSIONS}
    return sorted(
        (
            path
            for path in DATASET_DIR.iterdir()
            if path.is_file() and path.suffix.lower() in allowed_suffixes
        ),
        key=lambda path: path.name,
    )
|
|
|
def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
    """
    Encode a single image with CLIP and return its unit-length feature vector.

    Args:
        image (Image.Image): PIL image to encode.
        device (str, optional): Torch device used for the forward pass.
            Defaults to "cpu".

    Returns:
        torch.Tensor: Output of ``model.get_image_features`` divided by its L2
        norm over the last dimension; the tensor stays on ``device``.

    Note:
        The shared module-level ``model`` is moved to ``device`` in place and
        stays there for later calls. Requesting an unavailable device (e.g.
        "cuda" on a CPU-only host) makes torch raise.
    """
    batch = processor(images=image, return_tensors="pt").to(device)
    # nn.Module.to() moves the parameters in place and returns the same module.
    encoder = model.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        features: torch.Tensor = encoder.get_image_features(**batch)

    l2_norm = features.norm(p=2, dim=-1, keepdim=True)
    return features / l2_norm
|
|
|
def _load_cached_embeddings() -> Dict[str, torch.Tensor]:
    """
    Read the {file name: embedding} cache from CACHE_FILE.

    Returns an empty dict when the file is missing, cannot be unpickled, or
    does not hold a dict, so a damaged cache triggers recomputation instead of
    crashing start-up.
    """
    if not os.path.exists(CACHE_FILE):
        return {}
    try:
        # NOTE(security): pickle.load can execute arbitrary code. CACHE_FILE must
        # only ever be written by this application, never supplied by users.
        with open(CACHE_FILE, "rb") as f:
            cached = pickle.load(f)
    except Exception as e:  # unpickling corrupt data can raise many exception types
        print(f"Ignoring unreadable cache {CACHE_FILE}: {e}")
        return {}
    if not isinstance(cached, dict):
        print(f"Ignoring cache {CACHE_FILE}: expected a dict, got {type(cached).__name__}")
        return {}
    return cached


def _save_cached_embeddings(embeddings: Dict[str, torch.Tensor]) -> None:
    """
    Write embeddings to CACHE_FILE atomically.

    The data is written to a temporary file that then replaces CACHE_FILE, so
    an interrupted write never leaves a truncated cache behind.
    """
    tmp_path = f"{CACHE_FILE}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(embeddings, f)
    os.replace(tmp_path, CACHE_FILE)


@spaces.GPU
def get_reference_embeddings() -> Dict[str, torch.Tensor]:
    """
    Load or compute embeddings for all reference images in the dataset.

    The cache is considered up to date when its keys equal the set of image
    file names currently in the dataset (file contents are not compared).
    Otherwise, embeddings of images that are still present are reused, only
    the missing images are embedded, and the cache file is rewritten.

    Returns:
        Dict[str, torch.Tensor]: Mapping of image file name to its embedding
        (moved to the CPU). Images that fail to load or embed are skipped.

    Raises:
        OSError: If the cache file cannot be written (e.g. PermissionError).
    """
    current_image_files: List[Path] = get_all_image_files()
    current_images: set = {img_path.name for img_path in current_image_files}

    cached_embeddings: Dict[str, torch.Tensor] = _load_cached_embeddings()
    cached_images: set = set(cached_embeddings)

    if current_images == cached_images:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings

    print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")

    # Reuse embeddings of images that are still present, so that one new (or
    # one permanently unreadable) file does not force the whole dataset to be
    # re-embedded; search_similar() calls this function on every query.
    embeddings: Dict[str, torch.Tensor] = {
        name: emb for name, emb in cached_embeddings.items() if name in current_images
    }
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    for img_path in current_image_files:
        if img_path.name in embeddings:
            continue
        print(f"Processing {img_path.name}...")
        try:
            # The context manager closes the file handle once the pixels are decoded.
            with Image.open(img_path) as raw_img:
                img: Image.Image = raw_img.convert("RGB")
            emb: torch.Tensor = get_embedding(img, device=device)
        except Exception as e:
            # Best effort: a single bad file must not abort the whole rebuild.
            print(f"Error processing {img_path.name}: {e}")
            continue
        embeddings[img_path.name] = emb.cpu()

    _save_cached_embeddings(embeddings)
    print(f"Cache updated with {len(embeddings)} images")
    return embeddings
|
|
|
|
|
# Computed once at import time; get_reference_embeddings() also rewrites
# CACHE_FILE when it is out of date. search_similar() reassigns this global
# before every query, and add_image() inserts newly added images into it.
reference_embeddings: Dict[str, torch.Tensor] = get_reference_embeddings()
|
|
|
@spaces.GPU
def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
    """
    Find dataset images similar to the query image using CLIP embeddings.

    Args:
        query_img (Image.Image): Query image to find similar images for.

    Returns:
        List[Tuple[str, str]]: Up to 5 ``(image_path, "Score: x.xxxx")`` pairs
        whose cosine similarity to the query exceeds 0.2, best match first.
        Empty when no image was given, the dataset is empty, or nothing passes
        the threshold; in those cases a warning is shown in the UI instead.
    """
    global reference_embeddings

    top_k = 5
    similarity_threshold = 0.2

    if query_img is None:
        gr.Warning("Please upload a query image.")
        return []

    # Reload so that images added since start-up (or changed on disk) are searchable.
    reference_embeddings = get_reference_embeddings()
    if not reference_embeddings:
        gr.Warning("The dataset is empty. Add images first.")
        return []

    # Pick the device like the rest of the file does; hard-coding "cuda"
    # crashed on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)

    names = list(reference_embeddings)
    # get_embedding() encodes one image per call, so each stored embedding is
    # a (1, dim) row; stacking yields an (N, dim) matrix and all similarities
    # are computed in a single broadcasted call.
    ref_matrix = torch.cat([reference_embeddings[name] for name in names], dim=0).to(device)
    scores = torch.nn.functional.cosine_similarity(query_emb, ref_matrix, dim=1).tolist()

    ranked = sorted(zip(names, scores), key=lambda item: item[1], reverse=True)
    matches = [(name, score) for name, score in ranked if score > similarity_threshold][:top_k]

    if not matches:
        # A placeholder text tuple is not an image path, so a Gallery cannot
        # render it; show a warning and return an empty gallery instead.
        gr.Warning("No similar images found above the similarity threshold.")
        return []

    return [(str(DATASET_DIR / name), f"Score: {score:.4f}") for name, score in matches]
|
|
|
@spaces.GPU
def add_image(name: str, image: Image.Image) -> str:
    """
    Add a new image to the dataset and update embeddings.

    The image is saved as ``<name>.png`` in DATASET_DIR (an existing file with
    the same name is overwritten). It is then embedded with CLIP, stored in
    ``reference_embeddings``, and the embedding cache file is rewritten.

    Args:
        name (str): File name for the new image, without extension. Leading and
            trailing whitespace is ignored; names containing path components
            (e.g. ``../x`` or ``sub/x``) are rejected.
        image (Image.Image): PIL image to add to the dataset.

    Returns:
        str: Either a success message with the total image count, or a message
        explaining why the input was rejected.

    Raises:
        OSError: If the image or the cache file cannot be written.
    """
    clean_name = (name or "").strip()
    if not clean_name:
        return "Please provide a valid image name."
    # Only accept a bare file name so that input from the (shared) web UI cannot
    # write outside DATASET_DIR via "../", absolute paths or drive prefixes.
    if Path(clean_name).name != clean_name:
        return "Image name must not contain path separators."
    if image is None:
        return "Please upload an image."

    file_name = f"{clean_name}.png"
    # The folder may be missing when the module is imported rather than run.
    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    image.save(DATASET_DIR / file_name, "PNG")

    # Embed the RGB version, matching how get_reference_embeddings() loads files.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    emb: torch.Tensor = get_embedding(image.convert("RGB"), device=device)
    reference_embeddings[file_name] = emb.cpu()

    # Write to a temporary file first so an interrupted write cannot corrupt the cache.
    tmp_path = f"{CACHE_FILE}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(reference_embeddings, f)
    os.replace(tmp_path, CACHE_FILE)

    return f"Image '{clean_name}' added to dataset. Total images: {len(reference_embeddings)}"
|
|
|
|
|
# Search tab: upload a query image and show the best-matching dataset images
# with their similarity scores (search_similar returns at most 5 results).
search_interface: gr.Interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    # `allow_flagging` is deprecated in Gradio 5 in favour of `flagging_mode`.
    flagging_mode="never",
    title="Image Similarity Search",
    description="Upload an image to find similar images in the dataset",
)
|
|
|
# Add tab: store a named image in the dataset and index it for searching.
add_interface: gr.Interface = gr.Interface(
    fn=add_image,
    inputs=[
        gr.Text(label="Image Name", placeholder="Enter a unique name for your image"),
        gr.Image(type="pil", label="Product Image"),
    ],
    outputs="text",
    # `allow_flagging` is deprecated in Gradio 5 in favour of `flagging_mode`.
    flagging_mode="never",
    title="Add Image to Dataset",
    description="Add a new image to the searchable dataset",
)
|
|
|
|
|
# Top-level app: both interfaces as tabs; tab labels follow the list order.
demo: gr.TabbedInterface = gr.TabbedInterface(
    [search_interface, add_interface],
    theme=gr.themes.Soft(),
    title="CLIP Image Search System",
    tab_names=["Search", "Add Product"],
)
|
|
|
if __name__ == "__main__":
    # Ensure the dataset folder exists before the UI can save images into it.
    os.makedirs(DATASET_DIR, exist_ok=True)
    demo.launch(mcp_server=True, share=True)