safinal's picture
Update app.py
c902692 verified
import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset
# Run inference on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of database images preprocessed and encoded per step in encode_database.
batch_size = 512
import zipfile
import os
def unzip_file(zip_path, extract_path):
    """
    Extract every member of a zip archive into a directory.

    The destination directory (including missing parents) is created before
    the archive is opened; an already existing directory is reused as is.

    Args:
        zip_path: Path of the zip archive to read.
        extract_path: Directory that receives the extracted contents.
    """
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=extract_path)
# Unpack the bundled sample evaluation archive at import time; the dataset below
# reads "sample_evaluation/images" and "sample_evaluation/data.csv" from it.
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"
unzip_file(zip_path, extract_path)
# Download the checkpoint file "weights.pth" from the Hugging Face Hub into the
# current directory; load_model() reads it from that relative path.
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="safinal/compositional-image-retrieval", filename="weights.pth", local_dir='.')
def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Process database images and generate L2-normalized embeddings.

    Images are read in chunks of the module-level ``batch_size``, preprocessed
    with ``model.processor``, moved to the module-level ``device`` and encoded
    with ``model.feature_extractor.encode_image``. The model is switched to
    evaluation mode and no gradients are tracked.

    Args:
        model: Object exposing ``eval()``, ``processor`` and
            ``feature_extractor.encode_image``.
        df (pd.DataFrame): DataFrame with column:
            - target_image: str, paths to database images

    Returns:
        np.ndarray: Embeddings array (num_images, embedding_dim), one row per
        row of ``df`` in the same order.

    Raises:
        ValueError: If ``df`` has no rows (there is nothing to concatenate).
    """
    if df.empty:
        raise ValueError("Cannot encode an empty database: `df` has no rows.")

    image_paths = df['target_image']
    model.eval()
    all_embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(image_paths), batch_size)):
            processed = []
            # .iloc keeps the slice positional regardless of the DataFrame's index.
            for target_image_path in image_paths.iloc[start:start + batch_size]:
                # Image.open keeps the file open until the pixel data is read;
                # the context manager releases the handle once the image has
                # been preprocessed instead of leaving it to garbage collection.
                with Image.open(target_image_path) as target_image:
                    processed.append(model.processor(target_image))
            target_imgs = torch.stack(processed).to(device)
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
            target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
            all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)
# Load model and configurations
def load_model():
    """
    Build the retrieval model and restore its weights from disk.

    Returns:
        The ``Model`` instance created with ``model_name="ViTamin-L-384"`` and
        ``pretrained=None``, after loading ``weights.pth`` and switching to
        evaluation mode.
    """
    retrieval_model = Model(pretrained=None, model_name="ViTamin-L-384")
    retrieval_model.load("weights.pth")
    retrieval_model.eval()
    return retrieval_model
# Hub repository passed to load_token_classifier for tagging the query text.
_TOKEN_CLASSIFIER_REPO = "safinal/compositional-image-retrieval-token-classifier"
# Filled on first use so load_token_classifier runs once per process, not once per query.
_token_classifier_cache = {}


def _get_token_classifier():
    """Return the (classifier, tokenizer) pair, loading it on the first call only."""
    if "loaded" not in _token_classifier_cache:
        _token_classifier_cache["loaded"] = load_token_classifier(
            _TOKEN_CLASSIFIER_REPO,
            device
        )
    return _token_classifier_cache["loaded"]


def _build_object_prompts(predictions):
    """Turn (token, label) pairs into text prompts for positive and negative objects."""
    prompts = {'<positive_object>': [], '<negative_object>': []}
    last_tag = ''
    for token, label in predictions:
        bucket = prompts.get(label)
        if bucket is not None:
            if label == last_tag:
                # Consecutive tokens with the same tag form one multi-word
                # object: drop the trailing period and extend the last prompt.
                bucket[-1] = bucket[-1][:-1] + f" {token}."
            else:
                bucket.append(f"a photo of a {token}.")
        last_tag = label
    return prompts['<positive_object>'], prompts['<negative_object>']


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """
    Retrieve the database image that best matches an image + text query.

    The query image embedding is shifted by adding the text embedding of every
    object tagged ``<positive_object>`` in ``query_text`` and subtracting the
    text embedding of every object tagged ``<negative_object>``. The result is
    L2-normalized and compared with ``database_embeddings`` by cosine
    similarity.

    Args:
        model: Object exposing ``processor``, ``tokenizer`` and
            ``feature_extractor.encode_image`` / ``encode_text``.
        query_image_path: Path of the query image file.
        query_text: Modification text, passed to ``predict`` as ``tokens``.
        database_embeddings: Array of database embeddings, one row per
            database image.
        database_df: DataFrame with a ``target_image`` column whose row
            positions correspond to the rows of ``database_embeddings``.

    Returns:
        The ``target_image`` value of the most similar database row.
    """
    # Preprocess the query image; the context manager closes the file as soon
    # as the tensor has been built.
    with Image.open(query_image_path) as query_image:
        query_img = model.processor(query_image).unsqueeze(0).to(device)

    # Tag the query text and derive one prompt per added / removed object.
    token_classifier, token_classifier_tokenizer = _get_token_classifier()
    predictions = predict(
        tokens=query_text,
        model=token_classifier,
        tokenizer=token_classifier_tokenizer,
        device=device,
        max_length=128
    )
    pos, neg = _build_object_prompts(predictions)

    # Every encoder call runs without autograd so the combined embedding never
    # tracks gradients and can be converted to NumPy afterwards.
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    # Rank the database by cosine similarity and return the best match.
    query_embedding = query_img_embd.detach().cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    most_similar_idx = np.argmax(similarities)
    return database_df.iloc[most_similar_idx]['target_image']
# Initialize model and database
model = load_model()
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)
# Load the database table once and share it between embedding and lookup, so
# the row positions of `database_embeddings` always refer to this exact
# DataFrame and it is not rebuilt on every request.
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)


def interface_fn(selected_image, query_text):
    """
    Gradio callback: retrieve the best matching database image for a query.

    Args:
        selected_image: File path of the query image supplied by the image
            input; empty when the user has not selected an image.
        query_text: Text entered in the query textbox.

    Returns:
        PIL.Image.Image: The retrieved database image.

    Raises:
        gr.Error: If no query image was provided.
    """
    if not selected_image:
        # Show a readable message in the UI instead of failing inside Image.open.
        raise gr.Error("Please select a query image.")
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )
    # Deliberately left open: the image is handed to Gradio, which reads it
    # after this function returns.
    return Image.open(result_image_path)
# Create Gradio interface: a query image plus a modification text go in, the
# retrieved database image comes out.
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image", image_mode="RGB"),
        gr.Textbox(label="Enter Query Text", lines=2)
    ],
    outputs=gr.Image(label="Retrieved Image", type="pil"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
        ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
        ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
        ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."]
    ],
    # Flagging stays disabled. The string "never" is used instead of the
    # boolean False because the parameter is documented as taking
    # "never" / "auto" / "manual".
    # NOTE(review): newer Gradio releases appear to reject a boolean here and
    # prefer `flagging_mode` - confirm against the pinned Gradio version.
    allow_flagging="never",
    cache_examples=False
)
if __name__ == "__main__":
    # Enable request queuing and serve on all interfaces at port 7860. A launch
    # failure is reported on stdout and then re-raised unchanged.
    try:
        queued_app = demo.queue()
        queued_app.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as launch_error:
        print(f"Error launching app: {str(launch_error)}")
        raise