import base64
from io import BytesIO

import clip
import faiss
import gradio as gr
import requests
import torch
from datasets import load_dataset
from PIL import Image
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP model
model_clip, preprocess = clip.load("ViT-B/32", device=device)
# Load your fine-tuned model weights
fine_tuned_state_dict = torch.load("best_model.pt", map_location=device)
model_clip.load_state_dict(fine_tuned_state_dict)
model_clip.eval()
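# Note: clip.load() returns an fp16 model on CUDA and fp32 on CPU; the
# fine-tuned checkpoint is assumed to match the ViT-B/32 architecture loaded
# above, and the fp16/fp32 difference is handled before indexing (see .float()
# casts below).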
# Load 50 samples from Flickr30k test split
dataset = load_dataset("nlphuji/flickr30k", split="test[:50]")
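# Assumption about the nlphuji/flickr30k schema: it exposes a single "test"
# split, with decoded PIL images under "image" and a list of five captions
# per image under "caption". Adjust the field names if your copy differs.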
# Precompute embeddings
image_embeddings = []
images = []
captions = []
valid_indices = []
print("Extracting embeddings...")
for i, example in enumerate(dataset):
    try:
        # The "image" field is usually already a decoded PIL image (the
        # default for HF Image features); fall back to fetching by URL if it
        # turns out to be a string instead.
        if isinstance(example["image"], Image.Image):
            img = example["image"].convert("RGB")
        else:
            img = Image.open(requests.get(example["image"], stream=True).raw).convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)  # L2-normalize for cosine similarity
        image_embeddings.append(img_feat.cpu())
        images.append(img)
        # Flickr30k stores five captions per image under "caption"; keep the first
        captions.append(example["caption"][0])
        valid_indices.append(i)  # track which dataset rows were kept
    except Exception as e:
        print(f"Skipping sample {i} due to error: {e}")
# Stack image features; cast to float32 because CLIP emits fp16 on CUDA and
# FAISS only accepts float32 arrays
image_embeddings = torch.cat(image_embeddings, dim=0).float()
# Build FAISS index (inner product equals cosine similarity on unit-norm vectors)
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.numpy())
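# Quick sanity check: with unit-norm vectors, IndexFlatIP scores are cosine
# similarities in [-1, 1], and ntotal should match the number of kept images.
print(f"Indexed {image_index.ntotal} of {len(dataset)} images")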
# Search function
def search_by_text(query):
    with torch.no_grad():
        tokens = clip.tokenize([query]).to(device)
        text_feat = model_clip.encode_text(tokens)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
    text_feat_np = text_feat.cpu().float().numpy()  # FAISS expects float32
    D, I = image_index.search(text_feat_np, 5)  # top-5 nearest images
    results = []
    for idx in I[0]:
        results.append((images[idx], captions[idx]))
    return results
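# Example usage (hypothetical query): returns five (PIL.Image, caption) pairs
# ordered by descending cosine similarity, e.g.:
#   results = search_by_text("a dog catching a frisbee")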
# Gradio interface
def display_results(text_query):
    results = search_by_text(text_query)
    output = ""
    for i, (img, caption) in enumerate(results):
        output += f"### Result {i+1}\n"
        output += f"**Caption:** {caption}\n\n"
        output += f"![img](data:image/png;base64,{image_to_base64(img)})\n\n"
    return output
# Convert a PIL image to a base64-encoded PNG for inline markdown embedding
def image_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()
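# Design note: inlining base64 PNGs keeps the output a single markdown string,
# which is fine for a 50-image demo; a gr.Gallery output component would avoid
# the encoding overhead for larger result sets.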
iface = gr.Interface(
    fn=display_results,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to search..."),
    outputs="markdown",
    title="Text-to-Image Retrieval with CLIP",
    description="Enter a sentence to retrieve similar images using a fine-tuned CLIP model.",
)
iface.launch()