# kaveh's picture
# fixed examples repetition
# 1a3b019
import gradio as gr
import torch
import pickle
import numpy as np
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor
from sklearn.metrics.pairwise import cosine_similarity
import csv
from PIL import Image
# Identifier passed to `from_pretrained` for the RCLIP model/processor
# (looks like a Hugging Face Hub repo id — confirm).
model_path_rclip = "kaveh/rclip"
# Pickle file holding the image embeddings used with the RCLIP model.
embeddings_file_rclip = './image_embeddings_rclip.pkl'
# Identifier passed to `from_pretrained` for the PubMedCLIP model/processor.
model_path_pubmedclip = "flaviagiammarino/pubmed-clip-vit-base-patch32"
# Pickle file holding the image embeddings used with the PubMedCLIP model.
embeddings_file_pubmedclip = './image_embeddings_pubmedclip.pkl'
# Tab-delimited text file read by `load_image_ids`: image id in the first
# field, caption in the second.
csv_path = "./captions.txt"
def load_image_ids(csv_file):
    """Read image ids and captions from a tab-delimited captions file.

    Every non-empty row contributes its first field as the image id and its
    second field as the caption.

    Args:
        csv_file: Path to the tab-delimited text file.

    Returns:
        A ``(ids, captions)`` tuple of two equally long lists in file order.
        A row that has an id but no caption field gets an empty-string
        caption so both lists stay index-aligned.
    """
    ids = []
    captions = []
    # newline="" is what the csv module asks for so that it can handle line
    # endings itself; the explicit encoding avoids depending on the platform
    # locale (assumes the captions file is UTF-8 — confirm).
    with open(csv_file, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f, delimiter="\t"):
            if not row:
                # Blank lines (e.g. a trailing newline) produce empty rows;
                # indexing them would raise IndexError.
                continue
            ids.append(row[0])
            captions.append(row[1] if len(row) > 1 else "")
    return ids, captions
def load_embeddings(embeddings_file):
    """Deserialize and return the object pickled in ``embeddings_file``.

    Args:
        embeddings_file: Path to a pickle file.

    Returns:
        The unpickled object, exactly as stored in the file.
    """
    # NOTE: unpickling can execute arbitrary code; only point this at files
    # that ship with the application.
    with open(embeddings_file, mode="rb") as handle:
        return pickle.load(handle)
def find_similar_images(query_embedding, image_embeddings, k=2):
    """Rank image embeddings by cosine similarity to a query embedding.

    Args:
        query_embedding: Query vector; it is reshaped to a single row before
            being compared.
        image_embeddings: Collection of image embeddings passed as the second
            argument to ``cosine_similarity`` (presumably one row per image).
        k: Maximum number of results to return.

    Returns:
        A ``(closest_indices, scores)`` tuple: the indices of the ``k`` most
        similar embeddings in descending order of similarity, and a list
        with the similarity score of each returned index.
    """
    similarities = cosine_similarity(query_embedding.reshape(1, -1), image_embeddings)[0]
    closest_indices = np.argsort(similarities)[::-1][:k]
    # Look the scores up through the ranked indices rather than sorting the
    # values a second time: this avoids a Python-level sort over every image
    # and guarantees scores[i] is the similarity of closest_indices[i].
    scores = list(similarities[closest_indices])
    return closest_indices, scores
# Per-process cache of (model, processor, image_embeddings) keyed by model id,
# so repeated queries do not reload the weights and embeddings from disk.
_LOADED_MODEL_RESOURCES = {}


def _load_model_resources(model_id):
    """Return ``(model, processor, image_embeddings)`` for ``model_id``, loading them once."""
    if model_id not in _LOADED_MODEL_RESOURCES:
        if model_id == "RCLIP":
            model = VisionTextDualEncoderModel.from_pretrained(model_path_rclip)
            processor = VisionTextDualEncoderProcessor.from_pretrained(model_path_rclip)
            image_embeddings = load_embeddings(embeddings_file_rclip)
        elif model_id == "PubMedCLIP":
            model = CLIPModel.from_pretrained(model_path_pubmedclip)
            processor = CLIPProcessor.from_pretrained(model_path_pubmedclip)
            image_embeddings = load_embeddings(embeddings_file_pubmedclip)
        else:
            # Previously an unknown id fell through and crashed later with an
            # UnboundLocalError; fail early with a clear message instead.
            raise ValueError(
                f"Unknown model_id {model_id!r}; expected 'RCLIP' or 'PubMedCLIP'"
            )
        _LOADED_MODEL_RESOURCES[model_id] = (model, processor, image_embeddings)
    return _LOADED_MODEL_RESOURCES[model_id]


def main(query, model_id="RCLIP", k=2):
    """Retrieve the images that best match a text query.

    The query is embedded with the selected model's text encoder and compared
    against that model's precomputed image embeddings.

    Args:
        query: Text to search for.
        model_id: ``"RCLIP"`` or ``"PubMedCLIP"``.
        k: Number of results to return; coerced with ``int()``.

    Returns:
        A ``(images, table)`` tuple: the opened PIL images of the matches,
        and a DataFrame with the columns ``#``, ``path``, ``caption`` and
        ``score`` (one row per match, best match first).

    Raises:
        ValueError: If ``model_id`` is not one of the supported models.
    """
    # Coerce once up front: range() rejects floats, and the value is also
    # used as a slice bound when ranking.
    top_k = int(k)
    model, processor, image_embeddings = _load_model_resources(model_id)

    # Embed the query
    inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
    with torch.no_grad():
        query_embedding = model.get_text_features(**inputs)[0].numpy()

    # Get image names
    ids, captions = load_image_ids(csv_path)

    # Find similar images
    similar_image_indices, scores = find_similar_images(query_embedding, image_embeddings, k=top_k)

    # Return the results
    similar_image_names = [f"./images/{ids[index]}.jpg" for index in similar_image_indices]
    similar_image_captions = [captions[index] for index in similar_image_indices]
    similar_images = [Image.open(name) for name in similar_image_names]
    # Number the rows by how many results actually came back, which can be
    # fewer than top_k when there are fewer images than requested.
    ranks = list(range(1, len(similar_image_names) + 1))
    results = pd.DataFrame(
        [ranks, similar_image_names, similar_image_captions, scores],
        index=["#", "path", "caption", "score"],
    ).T
    return similar_images, results
# Static content for the Gradio interface.
# Each example row is [query text, model name, number of results], in the
# same order as the inputs handed to `gr.Examples` below.
examples = [
["Chest X-ray photos", "RCLIP", 5],
["Chest X-ray photos", "PubMedCLIP", 5],
["Orthopantogram (OPG)", "RCLIP",5],
["Brain MRI", "RCLIP",5],
["Ultrasound", "RCLIP",5],
]
# Used both as the Blocks title and as the page heading.
title="RCLIP Image Retrieval"
# Rendered as Markdown under the heading.
description = "CLIP model fine-tuned on the ROCO dataset"
# Build the page: header, query controls, results, then clickable examples.
with gr.Blocks(title=title) as demo:
    # Header row: heading and description on the left, logo on the right.
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("# "+title)
            gr.Markdown(description)
        gr.HTML(value="<img src=\"https://newresults.co.uk/wp-content/uploads/2022/02/teesside-university-logo.png\" alt=\"teesside logo\" width=\"120\" height=\"70\">", show_label=False,scale=1)

    # Query row: free-text query plus the search button.
    with gr.Row(variant="compact"):
        query = gr.Textbox(value="Chest X-Ray Photos", label="Enter your query", show_label=False, placeholder= "Enter your query" , scale=5)
        btn = gr.Button("Search query", variant="primary", scale=1)

    # Options row: which model to search with and how many results to show.
    with gr.Row(variant="compact"):
        model_id = gr.Dropdown(["RCLIP", "PubMedCLIP"], value="RCLIP", label="Model", type="value", scale=1)
        n_s = gr.Slider(2, 10, label='Number of Top Results', value=5, step=1.0, show_label=True, scale=1)

    # Results: image gallery followed by a table describing each hit.
    with gr.Column(variant="compact"):
        gr.Markdown("## Results")
        gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="400px", preview=True)
        gr.Markdown("Information of the found images")
        df = gr.DataFrame()

    # Clicking the button runs `main` on the three inputs and fills both outputs.
    btn.click(main, [query, model_id, n_s], [gallery, df])

    # Predefined queries that populate the inputs when selected.
    with gr.Column(variant="compact"):
        gr.Markdown("## Examples")
        gr.Examples(examples, [query, model_id, n_s])

# `debug` is a boolean flag; the string 'True' only worked because any
# non-empty string is truthy.
demo.launch(debug=True)