"""
CLIP Image Search Application
A Gradio-based application for searching similar images using OpenAI's CLIP model.
Supports multiple image formats and provides a web interface for uploading and searching images.
"""
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator
from typing import List, Dict, Tuple
# Load model/processor
model: CLIPModel = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor: CLIPProcessor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model.eval()  # inference only; disables dropout and other training-time behavior
DATASET_DIR: Path = Path("dataset")
CACHE_FILE: str = "cache.pkl"
# Define supported image formats
IMAGE_EXTENSIONS: List[str] = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
def get_all_image_files() -> List[Path]:
"""
Get all image files from the dataset directory.
Searches for images with supported extensions in both lowercase and uppercase.
Returns:
List[Path]: List of Path objects for all found image files
"""
image_files: List[Path] = []
for ext in IMAGE_EXTENSIONS:
image_files.extend(DATASET_DIR.glob(ext))
image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
return image_files
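# Note (sketch of the rationale): Path.glob() matching is case-sensitive on
# POSIX systems, hence the extra .upper() pass; on case-insensitive systems
# both passes can match the same file, which is harmless here since the
# embeddings are keyed by filename.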
def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
"""
Generate CLIP embedding for an image.
Args:
image (Image.Image): PIL Image object to process
device (str, optional): Device to run computation on. Defaults to "cpu".
Returns:
torch.Tensor: L2-normalized image embedding tensor
Raises:
RuntimeError: If CUDA is requested but not available
"""
# Use CLIP's built-in preprocessing
inputs = processor(images=image, return_tensors="pt").to(device)
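    # nn.Module.to() moves the parameters in place and returns the module itself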
model_device = model.to(device)
with torch.no_grad():
emb: torch.Tensor = model_device.get_image_features(**inputs)
# L2 normalize the embeddings
emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
return emb
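# Usage sketch (assumes a local "query.png"; clip-vit-large-patch14 projects
# images to 768-dimensional embeddings):
#   emb = get_embedding(Image.open("query.png").convert("RGB"))
#   emb.shape   # torch.Size([1, 768]); emb.norm() is ~1.0 after L2 normalization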
@spaces.GPU
def get_reference_embeddings() -> Dict[str, torch.Tensor]:
"""
Load or compute embeddings for all reference images in the dataset.
Checks if cached embeddings are up to date with the current dataset.
If not, recomputes embeddings for all images and updates the cache.
Returns:
Dict[str, torch.Tensor]: Dictionary mapping image filenames to their embeddings
Raises:
FileNotFoundError: If dataset directory doesn't exist
PermissionError: If unable to write cache file
"""
# Get all current image files
current_image_files: List[Path] = get_all_image_files()
current_images: set = set(img_path.name for img_path in current_image_files)
# Load existing cache if it exists
cached_embeddings: Dict[str, torch.Tensor] = {}
if os.path.exists(CACHE_FILE):
with open(CACHE_FILE, "rb") as f:
cached_embeddings = pickle.load(f)
# Check if cache is up to date
cached_images: set = set(cached_embeddings.keys())
# If cache is missing images or has extra images, rebuild
if current_images != cached_images:
print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
embeddings: Dict[str, torch.Tensor] = {}
device: str = "cuda" if torch.cuda.is_available() else "cpu"
for img_path in current_image_files:
print(f"Processing {img_path.name}...")
try:
img: Image.Image = Image.open(img_path).convert("RGB")
emb: torch.Tensor = get_embedding(img, device=device)
embeddings[img_path.name] = emb.cpu()
except Exception as e:
print(f"Error processing {img_path.name}: {e}")
continue
# Save updated cache
with open(CACHE_FILE, "wb") as f:
pickle.dump(embeddings, f)
print(f"Cache updated with {len(embeddings)} images")
return embeddings
else:
print(f"Using cached embeddings for {len(cached_embeddings)} images")
return cached_embeddings
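# Note: cache validity is keyed on filenames only; replacing an image's content
# under the same name will not trigger recomputation. Delete cache.pkl to force
# a full rebuild.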
# Initialize reference embeddings
reference_embeddings: Dict[str, torch.Tensor] = get_reference_embeddings()
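# Computed once at startup; search_similar() refreshes them on each query, so
# images added while the app is running are picked up.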
@spaces.GPU
def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
"""
Find similar images to the query image using CLIP embeddings.
Args:
query_img (Image.Image): Query image to find similar images for
Returns:
        List[Tuple[str, str]]: (image_path, caption) gallery entries for the
            top 5 matches above the similarity threshold; empty if none match
Raises:
RuntimeError: If CUDA operations fail
"""
# Refresh embeddings to catch any new images
global reference_embeddings
reference_embeddings = get_reference_embeddings()
    # Fall back to CPU so local runs without CUDA don't crash
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb: torch.Tensor = get_embedding(query_img, device=device)
results: List[Tuple[str, float]] = []
for name, ref_emb in reference_embeddings.items():
        # Move the cached reference embedding to the same device as the query
        ref_emb_dev: torch.Tensor = ref_emb.to(device)
        # Cosine similarity (a dot product here, since both vectors are L2-normalized)
        sim: float = torch.nn.functional.cosine_similarity(query_emb, ref_emb_dev, dim=1).item()
results.append((name, sim))
results.sort(key=lambda x: x[1], reverse=True)
# Filter out low similarity results (adjust threshold as needed)
    SIMILARITY_THRESHOLD: float = 0.2  # minimum cosine similarity to count as a match
filtered_results: List[Tuple[str, float]] = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]
    if not filtered_results:
        # Return an empty gallery; a placeholder tuple would make gr.Gallery
        # try (and fail) to load the message string as an image
        return []
# Return top 5 results
return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
@spaces.GPU
def add_image(name: str, image: Image.Image) -> str:
"""
Add a new image to the dataset and update embeddings.
Args:
name (str): Name for the new image (without extension)
image (Image.Image): PIL Image object to add to dataset
Returns:
str: Success message with total image count
    Raises:
        PermissionError: If unable to save the image or update the cache
        RuntimeError: If embedding computation fails
"""
    name = name.strip()
    if not name:
        return "Please provide a valid image name."
    # Save as PNG to preserve quality regardless of the uploaded format
    DATASET_DIR.mkdir(exist_ok=True)  # guard against a missing dataset directory
    path: Path = DATASET_DIR / f"{name}.png"
    image.save(path, "PNG")
# Use GPU for consistency if available
device: str = "cuda" if torch.cuda.is_available() else "cpu"
emb: torch.Tensor = get_embedding(image, device=device)
# Add to current embeddings and save cache
reference_embeddings[f"{name}.png"] = emb.cpu()
with open(CACHE_FILE, "wb") as f:
pickle.dump(reference_embeddings, f)
return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}"
# Create Gradio interfaces
search_interface: gr.Interface = gr.Interface(
fn=search_similar,
inputs=gr.Image(type="pil", label="Query Image"),
outputs=gr.Gallery(label="Top Matches", columns=5),
allow_flagging="never",
title="Image Similarity Search",
description="Upload an image to find similar images in the dataset"
)
add_interface: gr.Interface = gr.Interface(
fn=add_image,
inputs=[
gr.Text(label="Image Name", placeholder="Enter a unique name for your image"),
gr.Image(type="pil", label="Product Image")
],
outputs="text",
allow_flagging="never",
title="Add Image to Dataset",
description="Add a new image to the searchable dataset"
)
# Create main application
demo: gr.TabbedInterface = gr.TabbedInterface(
[search_interface, add_interface],
tab_names=["Search", "Add Product"],
title="CLIP Image Search System",
theme=gr.themes.Soft()
)
if __name__ == "__main__":
# Ensure dataset directory exists
DATASET_DIR.mkdir(exist_ok=True)
demo.launch(share=True, mcp_server=True) |