import base64
from io import BytesIO

import torch
import clip
import faiss
import gradio as gr
import requests
from datasets import load_dataset
from PIL import Image

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
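# clip.load also returns the preprocessing pipeline matched to ViT-B/32 (resize/center-crop to 224, CLIP normalization)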
model_clip, preprocess = clip.load("ViT-B/32", device=device)

# Load your fine-tuned model weights
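# (assumes best_model.pt stores a plain state_dict for the full CLIP model;
#  adjust the loading if it was saved inside a larger checkpoint dict)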
fine_tuned_state_dict = torch.load("best_model.pt", map_location=device)
model_clip.load_state_dict(fine_tuned_state_dict)

model_clip.eval()

# Load 50 samples from Flickr30k test split
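# (nlphuji/flickr30k exposes the whole dataset under a single "test" split)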
dataset = load_dataset("nlphuji/flickr30k", split="test[:50]")

# Precompute embeddings
image_embeddings = []
images = []
captions = []
valid_indices = []

print("Extracting embeddings...")

for i, example in enumerate(dataset):
    try:
        # The dataset's "image" field is typically already a decoded PIL image;
        # fall back to downloading it if it is a URL string instead.
        img = example["image"]
        if isinstance(img, str):
            img = Image.open(BytesIO(requests.get(img, timeout=10).content))
        img = img.convert("RGB")

        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
            # Cast to float32 so FAISS (which expects float32) can consume it later
            image_embeddings.append(img_feat.float().cpu())
        images.append(img)

        # Flickr30k stores several reference captions per image; keep the first one
        caption = example.get("caption", example.get("sentence"))
        if isinstance(caption, (list, tuple)):
            caption = caption[0]
        captions.append(caption)
        valid_indices.append(i)
    except Exception as e:
        print(f"Skipping sample {i} due to error: {e}")
        continue

# Stack image features
image_embeddings = torch.cat(image_embeddings, dim=0)

# Build FAISS index (inner product over L2-normalized vectors = cosine similarity)
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.numpy().astype("float32"))

# Search function
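# Embed the query with CLIP's text encoder and return the top-5 images by cosine similarity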
def search_by_text(query):
    with torch.no_grad():
        tokens = clip.tokenize([query]).to(device)
        text_feat = model_clip.encode_text(tokens)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
        text_feat_np = text_feat.float().cpu().numpy()  # float32 for FAISS

    D, I = image_index.search(text_feat_np, 5)
    results = []
    for idx in I[0]:
        img = images[idx]
        caption = captions[idx]
        results.append((img, caption))
    return results

# Gradio interface
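# Render the retrieved images and captions as markdown with inline base64-encoded images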
def display_results(text_query):
    results = search_by_text(text_query)
    output = ""
    for i, (img, caption) in enumerate(results):
        output += f"### Result {i+1}\n"
        output += f"**Caption:** {caption}\n\n"
        output += f"![img](data:image/png;base64,{image_to_base64(img)})\n\n"
    return output

# Convert a PIL image to base64 for inline embedding in the markdown output
def image_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()

iface = gr.Interface(fn=display_results,
                     inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
                     outputs="markdown",
                     title="Text-to-Image Retrieval with CLIP",
                     description="Enter a sentence to retrieve similar images using a fine-tuned CLIP model.")

iface.launch()