import base64
from io import BytesIO
import torch
import clip
import faiss
import gradio as gr
import numpy as np
import requests
from datasets import load_dataset
from PIL import Image
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP model
model_clip, preprocess = clip.load("ViT-B/32", device=device)
# Load your fine-tuned model weights
fine_tuned_state_dict = torch.load("best_model.pt", map_location=device)
model_clip.load_state_dict(fine_tuned_state_dict)
model_clip.eval()
# Load 50 samples from Flickr30k test split
dataset = load_dataset("nlphuji/flickr30k", split="test[:50]")
# Precompute embeddings
image_embeddings = []
images = []
captions = []
valid_indices = []
print("Extracting embeddings...")
for i, example in enumerate(dataset):
    try:
        # "image" may already be a decoded PIL image; fetch it over HTTP only
        # if the dataset stores a URL string instead.
        img = example["image"]
        if isinstance(img, str):
            img = Image.open(requests.get(img, stream=True).raw)
        img = img.convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
        image_embeddings.append(img_feat.cpu())
        images.append(img)
        # Flickr30k provides a list of reference captions per image ("caption"
        # on the nlphuji card); keep the first one.
        caption = example.get("caption", example.get("sentence"))
        if isinstance(caption, list):
            caption = caption[0]
        captions.append(caption)
        valid_indices.append(i)
    except Exception as e:
        print(f"Skipping sample {i} due to error: {e}")
        continue
# Stack image features
image_embeddings = torch.cat(image_embeddings, dim=0)
# Build FAISS index; inner product over L2-normalised vectors is cosine similarity.
# FAISS expects float32, and CLIP may emit float16 features on GPU, so cast explicitly.
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.numpy().astype(np.float32))
# Search function: embed the query and retrieve the top-5 images by cosine similarity
def search_by_text(query):
    with torch.no_grad():
        tokens = clip.tokenize([query]).to(device)
        text_feat = model_clip.encode_text(tokens)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
    text_feat_np = text_feat.cpu().numpy().astype(np.float32)
    D, I = image_index.search(text_feat_np, 5)
    results = []
    for idx in I[0]:
        results.append((images[idx], captions[idx]))
    return results
# Convert a PIL image to a base64 string so it can be inlined in the Markdown output
def image_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()

# Gradio callback: format the retrieved images and captions as Markdown
def display_results(text_query):
    results = search_by_text(text_query)
    output = ""
    for i, (img, caption) in enumerate(results):
        output += f"### Result {i+1}\n"
        output += f"**Caption:** {caption}\n\n"
        output += f"![Result {i+1}](data:image/png;base64,{image_to_base64(img)})\n\n"
    return output
iface = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
    outputs="markdown",
    title="Text-to-Image Retrieval with CLIP",
    description="Enter a sentence to retrieve similar images using a fine-tuned CLIP model.",
)
iface.launch()