# ImgSearch / app.py
# Last commit ded7d37 ("Update app.py", verified) by AkinyemiAra
"""
CLIP Image Search Application
A Gradio-based application for searching similar images using OpenAI's CLIP model.
Supports multiple image formats and provides a web interface for uploading and searching images.
"""
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces
from typing import List, Dict, Tuple, Optional, Union
# Load the CLIP checkpoint once at import time; the model and its matching
# processor come from the same Hugging Face repo id.
_CLIP_CHECKPOINT: str = "openai/clip-vit-large-patch14"
model: CLIPModel = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor: CLIPProcessor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
# The model is only used for inference here, so put it in evaluation mode.
model.eval()

# Folder scanned for reference images, and the pickle file holding their embeddings.
DATASET_DIR: Path = Path("dataset")
CACHE_FILE: str = "cache.pkl"

# Glob patterns ("*.<ext>") for the supported image file formats.
IMAGE_EXTENSIONS: List[str] = [
    f"*.{ext}" for ext in ("jpg", "jpeg", "png", "bmp", "gif", "webp", "tiff", "tif")
]
def get_all_image_files(
    directory: Optional[Path] = None,
    extensions: Optional[List[str]] = None,
) -> List[Path]:
    """
    Get all image files from the dataset directory.

    Extension matching is case-insensitive ("a.jpg", "a.JPG" and "a.Jpg" all
    match). Each file is returned exactly once. Globbing "*.jpg" and "*.JPG"
    separately would return the same file twice on case-insensitive
    filesystems, so this function avoids that.

    Args:
        directory (Optional[Path]): Directory to scan. Defaults to DATASET_DIR.
        extensions (Optional[List[str]]): Patterns of the form "*.<ext>".
            Defaults to IMAGE_EXTENSIONS.

    Returns:
        List[Path]: Regular files with a supported extension, sorted by file
        name. Empty if the directory does not exist.
    """
    if directory is None:
        directory = DATASET_DIR
    if extensions is None:
        extensions = IMAGE_EXTENSIONS
    # "*.jpg" -> ".jpg" so it can be compared directly with Path.suffix.
    suffixes = {pattern.lstrip("*").lower() for pattern in extensions}
    if not directory.is_dir():
        return []
    return sorted(
        (
            path
            for path in directory.iterdir()
            # Skip directories that merely look like image files (e.g. "x.png/").
            if path.is_file() and path.suffix.lower() in suffixes
        ),
        key=lambda path: path.name,
    )
def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
    """
    Encode one PIL image into a CLIP image-feature vector of unit L2 length.

    Args:
        image (Image.Image): Picture to encode. Preprocessing is delegated to
            the global CLIP ``processor``.
        device (str, optional): Torch device for the forward pass. Defaults to
            "cpu". The global ``model`` is moved to this device as a side effect.

    Returns:
        torch.Tensor: Output of ``model.get_image_features`` divided by its L2
        norm along the last dimension. The tensor stays on ``device``.

    Raises:
        RuntimeError: If CUDA is requested but not available (raised by torch).
    """
    pixel_inputs = processor(images=image, return_tensors="pt").to(device)
    # nn.Module.to() moves the shared model in place and returns that same module.
    clip = model.to(device)
    with torch.no_grad():
        features: torch.Tensor = clip.get_image_features(**pixel_inputs)
    unit_features = features / features.norm(p=2, dim=-1, keepdim=True)
    return unit_features
def _load_embedding_cache() -> Dict[str, torch.Tensor]:
    """Return the embedding cache from CACHE_FILE, or {} if it is missing or unreadable."""
    if not os.path.exists(CACHE_FILE):
        return {}
    try:
        with open(CACHE_FILE, "rb") as f:
            # NOTE(review): pickle.load can execute arbitrary code. This is only
            # acceptable because CACHE_FILE is written by this app itself.
            cached = pickle.load(f)
    except Exception as e:  # truncated/corrupt/incompatible pickle: rebuild instead of crashing
        print(f"Could not read cache {CACHE_FILE}: {e}. Rebuilding.")
        return {}
    if not isinstance(cached, dict):
        print(f"Unexpected cache contents in {CACHE_FILE}. Rebuilding.")
        return {}
    return cached


def _save_embedding_cache(embeddings: Dict[str, torch.Tensor]) -> None:
    """Write embeddings to CACHE_FILE via a temp file + os.replace so readers never see a partial file."""
    tmp_path = f"{CACHE_FILE}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(embeddings, f)
    os.replace(tmp_path, CACHE_FILE)


@spaces.GPU
def get_reference_embeddings() -> Dict[str, torch.Tensor]:
    """
    Load or compute embeddings for all reference images in the dataset.

    The cache is considered up to date when its keys equal the set of current
    image file names. Otherwise it is updated incrementally:
    - embeddings for files that are still present are reused;
    - embeddings for deleted files are dropped;
    - only files without an embedding are encoded.
    An image that fails to load is skipped and logged. It is retried on the
    next call, but no longer forces every other image to be re-encoded.

    Returns:
        Dict[str, torch.Tensor]: Mapping of image file name to its CPU embedding.

    Raises:
        OSError: If the cache file cannot be written.
    """
    current_image_files: List[Path] = get_all_image_files()
    current_images: set = {img_path.name for img_path in current_image_files}

    cached_embeddings: Dict[str, torch.Tensor] = _load_embedding_cache()
    cached_images: set = set(cached_embeddings)

    if current_images == cached_images:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings

    print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
    # Keep embeddings for images that still exist; this drops deleted files.
    embeddings: Dict[str, torch.Tensor] = {
        name: emb for name, emb in cached_embeddings.items() if name in current_images
    }
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    for img_path in current_image_files:
        if img_path.name in embeddings:
            continue
        print(f"Processing {img_path.name}...")
        try:
            # Context manager closes the file handle (lazy/multi-frame formats keep it open).
            with Image.open(img_path) as raw:
                img: Image.Image = raw.convert("RGB")
            emb: torch.Tensor = get_embedding(img, device=device)
        except Exception as e:
            print(f"Error processing {img_path.name}: {e}")
            continue
        embeddings[img_path.name] = emb.cpu()

    _save_embedding_cache(embeddings)
    print(f"Cache updated with {len(embeddings)} images")
    return embeddings


# Initialize reference embeddings
reference_embeddings: Dict[str, torch.Tensor] = get_reference_embeddings()
@spaces.GPU
def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
    """
    Find similar images to the query image using CLIP embeddings.

    Args:
        query_img (Image.Image): Query image to find similar images for

    Returns:
        List[Tuple[str, str]]: Up to 5 (image_path, "Score: x.xxxx") tuples,
        best first. Only dataset images whose cosine similarity exceeds the
        threshold are included. The list is empty (and a warning is shown) when
        no image was supplied, the dataset is empty, or nothing clears the
        threshold.

    Raises:
        RuntimeError: If torch device operations fail.
    """
    # Refresh embeddings to catch any new images
    global reference_embeddings

    # Only show results above 20% similarity (adjust threshold as needed).
    SIMILARITY_THRESHOLD: float = 0.2
    TOP_K: int = 5

    if query_img is None:
        gr.Warning("Please upload a query image.")
        return []

    reference_embeddings = get_reference_embeddings()
    if not reference_embeddings:
        gr.Warning("The dataset is empty; add images first.")
        return []

    # Fall back to CPU like the other handlers instead of hard-coding "cuda".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb: torch.Tensor = get_embedding(query_img, device=device)

    names: List[str] = list(reference_embeddings)
    # Stored embeddings come from get_embedding (one row each). Stack them into one
    # matrix so there is a single device transfer and one batched similarity.
    ref_matrix: torch.Tensor = torch.cat(
        [reference_embeddings[name] for name in names], dim=0
    ).to(device)
    scores: List[float] = torch.nn.functional.cosine_similarity(
        query_emb, ref_matrix, dim=1
    ).tolist()

    ranked = sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)
    matches = [(name, score) for name, score in ranked if score > SIMILARITY_THRESHOLD][:TOP_K]

    if not matches:
        # A text sentinel would be interpreted by gr.Gallery as a file path and fail,
        # so report via a toast and return an empty gallery instead.
        gr.Warning("No similar images found above the similarity threshold.")
        return []

    return [(str(DATASET_DIR / name), f"Score: {score:.4f}") for name, score in matches]
@spaces.GPU
def add_image(name: str, image: Image.Image) -> str:
    """
    Add a new image to the dataset and update embeddings.

    The image is saved as "<name>.png" in DATASET_DIR, overwriting any existing
    file with that name. Its CLIP embedding is stored in the in-memory
    reference embeddings, and the embedding cache file is rewritten.

    Args:
        name (str): Name for the new image (without extension). Surrounding
            whitespace is stripped; names containing path separators are rejected.
        image (Image.Image): PIL Image object to add to dataset

    Returns:
        str: Success message with total image count, or a validation message
        if the name or image is missing/invalid.

    Raises:
        OSError: If unable to save the image or update the cache.
        RuntimeError: If embedding computation fails.
    """
    clean_name = (name or "").strip()
    # The name comes from a public web form. Reject separators (and NUL) so it
    # cannot escape DATASET_DIR, e.g. "../app" -> "dataset/../app.png".
    if not clean_name or any(ch in clean_name for ch in ("/", "\\", "\x00")):
        return "Please provide a valid image name."
    if image is None:
        return "Please provide an image to add."

    # The directory is otherwise only created under __main__.
    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    filename = f"{clean_name}.png"
    # Save as PNG to preserve quality for all input formats
    image.save(DATASET_DIR / filename, "PNG")

    # Embed the RGB version, matching how get_reference_embeddings loads files from disk.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    emb: torch.Tensor = get_embedding(image.convert("RGB"), device=device)

    # Add to current embeddings and save cache
    reference_embeddings[filename] = emb.cpu()
    # Write via temp file + atomic replace so a crash never leaves a truncated cache.
    tmp_cache = f"{CACHE_FILE}.tmp"
    with open(tmp_cache, "wb") as f:
        pickle.dump(reference_embeddings, f)
    os.replace(tmp_cache, CACHE_FILE)

    return f"Image '{clean_name}' added to dataset. Total images: {len(reference_embeddings)}"
# ---- Gradio UI: a "Search" tab and an "Add Product" tab, combined into `demo`. ----

# Search tab: one PIL image in, a 5-column gallery of (path, caption) matches out.
_query_image_input = gr.Image(type="pil", label="Query Image")
_matches_gallery = gr.Gallery(label="Top Matches", columns=5)
search_interface: gr.Interface = gr.Interface(
    fn=search_similar,
    inputs=_query_image_input,
    outputs=_matches_gallery,
    allow_flagging="never",
    title="Image Similarity Search",
    description="Upload an image to find similar images in the dataset",
)

# Add tab: a name plus an image in, a plain status message out.
_add_inputs = [
    gr.Text(label="Image Name", placeholder="Enter a unique name for your image"),
    gr.Image(type="pil", label="Product Image"),
]
add_interface: gr.Interface = gr.Interface(
    fn=add_image,
    inputs=_add_inputs,
    outputs="text",
    allow_flagging="never",
    title="Add Image to Dataset",
    description="Add a new image to the searchable dataset",
)

# Top-level app shown at launch.
demo: gr.TabbedInterface = gr.TabbedInterface(
    [search_interface, add_interface],
    tab_names=["Search", "Add Product"],
    title="CLIP Image Search System",
    theme=gr.themes.Soft(),
)
def _launch_app() -> None:
    """Create the dataset folder if it is missing, then serve the tabbed UI (share link + MCP server)."""
    DATASET_DIR.mkdir(exist_ok=True)
    demo.launch(share=True, mcp_server=True)


if __name__ == "__main__":
    _launch_app()