File size: 8,398 Bytes
55bb1f4
 
 
 
 
 
 
fb6458d
e6e631c
 
 
 
 
 
 
74ce846
fb6458d
e6e631c
cf7175c
 
e6e631c
fb6458d
cf7175c
 
e6e631c
c0e2011
cf7175c
c0e2011
cf7175c
55bb1f4
 
 
 
 
 
 
 
cf7175c
c0e2011
 
 
 
 
55bb1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
c0e2011
e6e631c
 
 
cf7175c
a4d053b
e6e631c
 
e1286f2
c0e2011
55bb1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
e1286f2
cf7175c
 
e1286f2
 
cf7175c
e6e631c
 
e1286f2
a4d053b
e1286f2
cf7175c
a4d053b
e1286f2
 
 
cf7175c
 
e1286f2
c0e2011
e1286f2
c0e2011
cf7175c
 
c0e2011
 
 
 
e1286f2
 
 
 
 
 
 
 
 
e6e631c
cf7175c
 
e6e631c
 
55bb1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
e1286f2
 
 
 
cf7175c
 
a4d053b
e6e631c
a4d053b
cf7175c
a4d053b
cf7175c
e6e631c
a4d053b
e6e631c
e1286f2
 
cf7175c
 
e1286f2
 
 
 
 
 
e6e631c
30bbdee
55bb1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1286f2
 
 
c0e2011
cf7175c
c0e2011
a4d053b
 
cf7175c
 
a4d053b
e1286f2
c0e2011
a4d053b
e6e631c
 
e1286f2
 
e6e631c
cf7175c
 
 
 
 
 
 
 
 
e6e631c
cf7175c
 
 
 
 
 
 
 
 
 
 
e6e631c
cf7175c
 
 
 
 
 
 
 
 
 
 
ded7d37
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
CLIP Image Search Application

A Gradio-based application for searching similar images using OpenAI's CLIP model.
Supports multiple image formats and provides a web interface for uploading and searching images.
"""

import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces
from typing import List, Dict, Tuple, Optional, Union

# Single source of truth for the checkpoint, so the model weights and the
# preprocessing configuration can never be loaded from different checkpoints.
CLIP_MODEL_NAME: str = "openai/clip-vit-large-patch14"

# Load model/processor once at import time; both are module-level state shared
# by every embedding call in this file.
model: CLIPModel = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
processor: CLIPProcessor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
model.eval()  # inference only: put the module in evaluation mode

# Directory scanned for reference images (relative to the working directory).
DATASET_DIR: Path = Path("dataset")
# Pickle file holding the {filename: embedding} cache (relative to the working directory).
CACHE_FILE: str = "cache.pkl"

# Define supported image formats (glob patterns, one per file extension)
IMAGE_EXTENSIONS: List[str] = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]

def get_all_image_files() -> List[Path]:
    """
    Get all image files from the dataset directory.

    A file is selected when its suffix, compared case-insensitively, matches
    one of the extensions listed in IMAGE_EXTENSIONS. The directory is scanned
    once, so each file appears exactly once regardless of whether the
    filesystem is case-sensitive, and mixed-case suffixes such as ".Jpg" are
    found as well. Only regular files are returned (a directory named
    "x.png" is skipped).

    Returns:
        List[Path]: Sorted list of Path objects for all found image files.
            Empty if the dataset directory does not exist.
    """
    # The module computes embeddings at import time, possibly before the
    # dataset directory has been created, so a missing directory is not an error.
    if not DATASET_DIR.is_dir():
        return []

    # "*.jpg" -> ".jpg"; patterns without a suffix are ignored so they cannot
    # accidentally match every extension-less file.
    wanted_suffixes = {
        Path(pattern).suffix.lower()
        for pattern in IMAGE_EXTENSIONS
        if Path(pattern).suffix
    }

    # Sorting gives a deterministic processing order across runs and platforms.
    return sorted(
        entry
        for entry in DATASET_DIR.iterdir()
        if entry.is_file() and entry.suffix.lower() in wanted_suffixes
    )

def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
    """
    Compute the unit-length CLIP feature vector for a single image.

    The image is run through the module-level `processor`, the module-level
    `model` is moved to `device`, and the resulting image features are divided
    by their L2 norm along the last dimension.

    Args:
        image (Image.Image): PIL Image object to embed
        device (str, optional): Torch device string to compute on. Defaults to "cpu".

    Returns:
        torch.Tensor: L2-normalized image features, located on `device`

    Note:
        Nothing is caught here: any error raised by torch or the model (for
        example when `device` is not available) propagates to the caller.
    """
    # Preprocess with CLIP's own processor and place the batch on the target device
    batch = processor(images=image, return_tensors="pt").to(device)
    # `.to()` relocates the shared module-level model as a side effect
    clip = model.to(device)

    # Gradients are never needed for retrieval
    with torch.no_grad():
        features: torch.Tensor = clip.get_image_features(**batch)

    # Scale to unit length so that dot products equal cosine similarity
    magnitude = features.norm(p=2, dim=-1, keepdim=True)
    return features / magnitude

@spaces.GPU
def get_reference_embeddings() -> Dict[str, torch.Tensor]:
    """
    Load or compute embeddings for all reference images in the dataset.

    The cache file maps image filenames to embeddings. If the set of cached
    filenames equals the set of filenames currently in the dataset, the cache
    is returned as-is. Otherwise the cache is brought up to date incrementally:
    embeddings of files that are still present are reused, embeddings of
    removed files are dropped, and only files missing from the cache are
    embedded. The updated mapping is then written back to the cache file.

    Files are identified by name only; replacing the contents of an image
    while keeping its filename is not detected.

    Images that cannot be opened or embedded are reported on stdout and left
    out of the result (they are retried on the next call). An unreadable or
    malformed cache file is reported and treated as an empty cache.

    Returns:
        Dict[str, torch.Tensor]: Dictionary mapping image filenames to their
            embeddings (stored on the CPU)

    Raises:
        OSError: If the cache file cannot be written (includes PermissionError)
    """
    # Map filename -> path; this also collapses any duplicate listings.
    paths_by_name: Dict[str, Path] = {
        img_path.name: img_path for img_path in get_all_image_files()
    }
    current_images: set = set(paths_by_name)

    # Load existing cache if it exists
    cached_embeddings: Dict[str, torch.Tensor] = {}
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, "rb") as f:
                # NOTE(security): pickle executes arbitrary code on load. This is
                # only acceptable because the cache file is written by this
                # application itself; never point CACHE_FILE at untrusted data.
                loaded = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError) as e:
            print(f"Ignoring unreadable cache file {CACHE_FILE}: {e}")
        else:
            if isinstance(loaded, dict):
                cached_embeddings = loaded
            else:
                print(f"Ignoring cache file {CACHE_FILE}: unexpected content")

    cached_images: set = set(cached_embeddings)

    # Cache matches the dataset exactly: nothing to do.
    if current_images == cached_images:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings

    print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
    embeddings: Dict[str, torch.Tensor] = {}
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    for name, img_path in paths_by_name.items():
        # Reuse what is already known instead of re-embedding the whole
        # dataset; otherwise a single unreadable file would trigger a full
        # rebuild on every call.
        if name in cached_embeddings:
            embeddings[name] = cached_embeddings[name]
            continue

        print(f"Processing {name}...")
        try:
            # The context manager closes the underlying file handle promptly;
            # convert() returns an independent in-memory copy.
            with Image.open(img_path) as raw:
                img: Image.Image = raw.convert("RGB")
            emb: torch.Tensor = get_embedding(img, device=device)
        except Exception as e:
            # Deliberately best-effort: one bad image must not abort the rebuild.
            print(f"Error processing {name}: {e}")
            continue
        # Keep cached tensors on the CPU so the pickle loads on any machine.
        embeddings[name] = emb.cpu()

    # Save updated cache
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    print(f"Cache updated with {len(embeddings)} images")
    return embeddings

# Initialize reference embeddings.
# This runs at import time, i.e. before the `__main__` guard at the bottom of
# the file creates the dataset directory. The resulting dict is module-level
# shared state: search_similar() rebinds it and add_image() inserts into it.
reference_embeddings: Dict[str, torch.Tensor] = get_reference_embeddings()

@spaces.GPU
def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
    """
    Find similar images to the query image using CLIP embeddings.

    The reference embeddings are refreshed first (so images added to the
    dataset directory since the last call are included), then the query is
    compared against all of them in one batched cosine-similarity call.
    CUDA is used when available, otherwise the CPU.

    Args:
        query_img (Image.Image): Query image to find similar images for

    Returns:
        List[Tuple[str, str]]: List of (image_path, caption) tuples, best match
            first, limited to the top 5 results whose similarity exceeds the
            threshold. The list is empty when no query image was supplied,
            the dataset is empty, or nothing passes the threshold.
    """
    global reference_embeddings

    SIMILARITY_THRESHOLD: float = 0.2  # Only show results above 20% similarity
    MAX_RESULTS: int = 5

    # The image input can be submitted empty.
    if query_img is None:
        return []

    # Refresh embeddings to catch any new images
    reference_embeddings = get_reference_embeddings()
    if not reference_embeddings:
        print("No similar images found: dataset is empty")
        return []

    # Same device selection as the other entry points, so the search also
    # works on machines without CUDA.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb: torch.Tensor = get_embedding(query_img, device=device)

    # Stack all references into one (N, D) matrix: a single device transfer and
    # a single similarity call instead of one transfer + sync per image.
    # reshape(1, -1) tolerates embeddings stored with or without a batch axis.
    names: List[str] = list(reference_embeddings)
    ref_matrix: torch.Tensor = torch.cat(
        [reference_embeddings[name].reshape(1, -1) for name in names], dim=0
    ).to(device)
    scores: List[float] = torch.nn.functional.cosine_similarity(
        query_emb, ref_matrix, dim=1
    ).tolist()

    ranked: List[Tuple[str, float]] = sorted(
        zip(names, scores), key=lambda pair: pair[1], reverse=True
    )

    # Filter out low similarity results, then keep only the best few
    top_matches: List[Tuple[str, float]] = [
        (name, score) for name, score in ranked if score > SIMILARITY_THRESHOLD
    ][:MAX_RESULTS]

    if not top_matches:
        # The output is a gallery whose first tuple element must be an image
        # path; a text message in that slot is not a loadable image, so an
        # empty gallery is returned instead.
        print("No similar images found: no matches above similarity threshold")
        return []

    return [
        ((DATASET_DIR / name).as_posix(), f"Score: {score:.4f}")
        for name, score in top_matches
    ]

@spaces.GPU
def add_image(name: str, image: Image.Image) -> str:
    """
    Add a new image to the dataset and update embeddings.

    The image is saved as "<name>.png" inside the dataset directory, its
    embedding is stored in the module-level `reference_embeddings` dict, and
    the whole dict is written back to the cache file. An existing image with
    the same name is overwritten.

    Args:
        name (str): Name for the new image (without extension). Surrounding
            whitespace is stripped. The name must be a plain file name: it may
            not contain path separators or otherwise resolve outside the
            dataset directory.
        image (Image.Image): PIL Image object to add to dataset

    Returns:
        str: Success message with total image count, or a message asking for
            valid input when the name or image is missing/invalid (no
            exception is raised for invalid input).

    Raises:
        OSError: If the image or the cache file cannot be written
    """
    clean_name: str = name.strip() if name else ""
    filename: str = f"{clean_name}.png"

    # `name` comes straight from a web form. Reject anything that is not a
    # single path component, so values such as "../x" or "/tmp/x" cannot write
    # outside the dataset directory.
    if not clean_name or "\\" in clean_name or Path(filename).name != filename:
        return "Please provide a valid image name."

    # The image input can be submitted empty.
    if image is None:
        return "Please provide an image to add."

    # The directory is otherwise only created when the file is run as a script.
    DATASET_DIR.mkdir(parents=True, exist_ok=True)

    # Save as PNG to preserve quality for all input formats
    path: Path = DATASET_DIR / filename
    image.save(path, "PNG")

    # Use GPU for consistency if available
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Embed the RGB version, matching how dataset images are embedded when the
    # cache is rebuilt from disk.
    emb: torch.Tensor = get_embedding(image.convert("RGB"), device=device)

    # Add to current embeddings and save cache
    reference_embeddings[filename] = emb.cpu()

    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)

    return f"Image '{clean_name}' added to dataset. Total images: {len(reference_embeddings)}"

# Create Gradio interfaces.
# `flagging_mode` is the current name of the option formerly spelled
# `allow_flagging`; the old spelling is deprecated in the Gradio 5 line, which
# this app already requires for `launch(mcp_server=True)`.

# Search tab: one query image in, a gallery of (image path, caption) pairs out.
search_interface: gr.Interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    flagging_mode="never",
    title="Image Similarity Search",
    description="Upload an image to find similar images in the dataset"
)

# Add tab: a name plus an image in, a status message out.
add_interface: gr.Interface = gr.Interface(
    fn=add_image,
    inputs=[
        gr.Text(label="Image Name", placeholder="Enter a unique name for your image"),
        gr.Image(type="pil", label="Product Image")
    ],
    outputs="text",
    flagging_mode="never",
    title="Add Image to Dataset",
    description="Add a new image to the searchable dataset"
)

# Create main application: both interfaces combined into one tabbed UI.
# The tab order follows the order of the interface list.
demo: gr.TabbedInterface = gr.TabbedInterface(
    [search_interface, add_interface], 
    tab_names=["Search", "Add Product"],
    title="CLIP Image Search System",
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    # Ensure dataset directory exists.
    # Note: the module-level embedding load above has already run by this
    # point, so on a first start it sees no dataset directory yet.
    DATASET_DIR.mkdir(exist_ok=True)
    # share=True requests a public share link; mcp_server=True presumably also
    # exposes the app's functions as MCP tools — confirm against the installed
    # Gradio version.
    demo.launch(share=True, mcp_server=True)