import base64
from io import BytesIO

import torch
import clip
import faiss
import gradio as gr
import requests
from datasets import load_dataset
from PIL import Image
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP model (clip.load returns the model in fp16 on CUDA)
model_clip, preprocess = clip.load("ViT-B/32", device=device)
# Load your fine-tuned model weights
fine_tuned_state_dict = torch.load("best_model.pt", map_location=device)
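# Hedged guard: if best_model.pt was saved as a full training checkpoint
# rather than a bare state_dict, unwrap it before loading. The
# "model_state_dict" key is an assumed convention, not something the
# original script confirms.
if isinstance(fine_tuned_state_dict, dict) and "model_state_dict" in fine_tuned_state_dict:
    fine_tuned_state_dict = fine_tuned_state_dict["model_state_dict"]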
model_clip.load_state_dict(fine_tuned_state_dict)
model_clip.eval()
# Load 50 samples from Flickr30k (the nlphuji/flickr30k dataset ships
# entirely under a single "test" split)
dataset = load_dataset("nlphuji/flickr30k", split="test[:50]")
# Precompute embeddings
image_embeddings = []
images = []
captions = []
valid_indices = []

print("Extracting embeddings...")
for i, example in enumerate(dataset):
    try:
        # nlphuji/flickr30k stores decoded PIL images in "image" and a list
        # of reference captions in "caption"; fall back to downloading if
        # the field holds a URL string instead.
        img = example["image"]
        if isinstance(img, str):
            img = Image.open(BytesIO(requests.get(img, timeout=10).content))
        img = img.convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
        image_embeddings.append(img_feat.cpu())
        images.append(img)
        captions.append(example["caption"][0])
        valid_indices.append(i)
    except Exception as e:
        print(f"Skipping sample {i} due to error: {e}")
        continue
# Stack image features; faiss expects float32 (encode_image yields fp16 on CUDA)
image_embeddings = torch.cat(image_embeddings, dim=0).float()

# Build FAISS index (inner product over L2-normalised vectors = cosine similarity)
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.numpy())
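# Optional sanity check (a minimal sketch, not part of the original app):
# with L2-normalised embeddings, each stored vector should retrieve itself
# as its own nearest neighbour under inner-product search.
_scores, _ids = image_index.search(image_embeddings[:1].numpy(), 1)
assert _ids[0][0] == 0, "self-retrieval failed; check normalisation"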
# Search function: embed the query text and return the top-k (image, caption) pairs
def search_by_text(query, k=5):
    with torch.no_grad():
        tokens = clip.tokenize([query]).to(device)
        text_feat = model_clip.encode_text(tokens)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
    text_feat_np = text_feat.cpu().float().numpy()  # faiss needs float32
    D, I = image_index.search(text_feat_np, k)
    results = []
    for idx in I[0]:
        results.append((images[idx], captions[idx]))
    return results
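# Example usage (hypothetical query string; uncomment to smoke-test the
# retrieval path from the command line before launching the UI):
# for rank, (img, cap) in enumerate(search_by_text("two dogs playing in snow"), 1):
#     print(rank, cap)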
# Convert a PIL image to base64 so it can be inlined as a markdown data URI
def image_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()

# Gradio interface: render the results as markdown with inlined images
def display_results(text_query):
    results = search_by_text(text_query)
    output = ""
    for i, (img, caption) in enumerate(results):
        output += f"### Result {i+1}\n"
        output += f"**Caption:** {caption}\n\n"
        output += f"![Result {i+1}](data:image/png;base64,{image_to_base64(img)})\n\n"
    return output
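# Alternative sketch (not in the original code): gr.Gallery accepts the
# (PIL image, caption) tuples that search_by_text already returns, which
# would avoid base64 data URIs entirely. Assumes a reasonably recent Gradio.
# iface = gr.Interface(fn=search_by_text,
#                      inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
#                      outputs=gr.Gallery(label="Top matches"))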
iface = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
    outputs="markdown",
    title="Text-to-Image Retrieval with CLIP",
    description="Enter a sentence to retrieve similar images using a fine-tuned CLIP model.",
)

iface.launch()