# RobotJelly / app.py (revision 087fe06, 3.59 kB; virus scan: clean)
# Import Libraries
from functools import lru_cache
from io import BytesIO
from pathlib import Path

import clip
import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
# Load OpenAI's CLIP model (ViT-B/32 backbone) once at import time.
# ``preprocess`` is the matching transform that turns a PIL image into the
# tensor expected by ``model.encode_image``.
CLIP_MODEL_NAME = "ViT-B/32"
model, preprocess = clip.load(CLIP_MODEL_NAME, jit=False)
# Display output photos
def show_output_image(matched_images, timeout=30):
    """Download the matched Unsplash photos as PIL images.

    Args:
        matched_images: Iterable of Unsplash photo IDs, best match first.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A list of PIL images in the same order as ``matched_images``.

    Raises:
        requests.HTTPError: If Unsplash answers a download with an error status.
        requests.Timeout: If a download does not respond within ``timeout``.
    """
    images = []
    for photo_id in matched_images:
        photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
        # Without a timeout a stalled download would block this request forever.
        response = requests.get(photo_image_url, timeout=timeout)
        # Report HTTP failures directly instead of letting PIL fail later with
        # an opaque "cannot identify image file" on an error page.
        response.raise_for_status()
        images.append(Image.open(BytesIO(response.content)))
    return images
# Encode and normalize the search query using CLIP
def encode_search_query(search_query, model, device):
    """Embed a text query with CLIP and scale it to unit L2 norm.

    Args:
        search_query: Text passed to ``clip.tokenize``.
        model: CLIP model whose ``encode_text`` produces the embedding.
        device: Device the token tensor is moved to before encoding.

    Returns:
        The normalized embedding as a NumPy array, copied back to the CPU.
    """
    tokens = clip.tokenize(search_query).to(device)
    with torch.no_grad():
        embedding = model.encode_text(tokens)
        # Unit-normalize each row so a dot product with other normalized
        # vectors gives cosine similarity.
        embedding /= embedding.norm(dim=-1, keepdim=True)
    # Bring the result back from the accelerator (if any) as a NumPy array.
    return embedding.cpu().numpy()
# Find all matched photos
def find_matches(text_features, photo_features, photo_ids, results_count=4):
    """Return the IDs of the photos most similar to a single query embedding.

    Args:
        text_features: Query embedding with exactly one row, since the
            ``axis=1`` squeeze below requires a single column of scores.
        photo_features: One feature row per photo, aligned with ``photo_ids``.
        photo_ids: Photo identifiers indexed by feature row.
        results_count: Maximum number of IDs to return.

    Returns:
        Up to ``results_count`` photo IDs, highest dot-product score first.
    """
    # The dot product equals cosine similarity only if both sides are
    # L2-normalized. Queries are normalized before reaching here; presumably
    # the stored photo features are as well (TODO confirm for features.npy).
    scores = np.squeeze(np.matmul(photo_features, text_features.T), axis=1)
    # Negate before argsort to get a descending order of scores.
    ranking = np.argsort(-scores)
    return [photo_ids[index] for index in ranking[:results_count]]
@lru_cache(maxsize=1)
def _load_photo_index():
    """Read the photo ID list and the precomputed feature matrix once per process."""
    photo_ids = list(pd.read_csv("./photo_ids.csv")["photo_id"])
    photo_features = np.load("./features.npy")
    return photo_ids, photo_features


def image_search(search_text, search_image, option):
    """Return the four Unsplash photos that best match a text or image query.

    Args:
        search_text: Free-text query, used when ``option`` is "Text-To-Image".
        search_image: PIL image query, used when ``option`` is "Image-To-Image".
        option: Search mode, either "Text-To-Image" or "Image-To-Image".

    Returns:
        A list of PIL images downloaded from Unsplash, best match first.

    Raises:
        ValueError: If ``option`` is not a supported mode, or the input that
            the chosen mode needs is missing.
    """
    # Validate inputs before touching the model or the data files so that bad
    # requests fail fast with a readable message.
    if option == "Text-To-Image":
        if not search_text or not search_text.strip():
            raise ValueError("Please enter a text query for Text-To-Image search.")
    elif option == "Image-To-Image":
        if search_image is None:
            raise ValueError("Please upload an image for Image-To-Image search.")
    else:
        raise ValueError(f"Unsupported search option: {option!r}")

    photo_ids, photo_features = _load_photo_index()

    # check if CUDA available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Use a new local name: assigning to ``model`` inside this function would
    # make ``model`` local for the whole body and raise UnboundLocalError when
    # the module-level model is read.
    clip_model = model.to(device)

    if option == "Text-To-Image":
        query_features = encode_search_query(search_text, clip_model, device)
    else:
        with torch.no_grad():
            image_input = preprocess(search_image).unsqueeze(0).to(device)
            image_feature = clip_model.encode_image(image_input)
        # Unit-normalize so the ranking in find_matches is by cosine similarity.
        query_features = (
            image_feature / image_feature.norm(dim=-1, keepdim=True)
        ).cpu().numpy()

    matched_images = find_matches(query_features, photo_features, photo_ids, 4)
    return show_output_image(matched_images)
# Gradio UI: a text box, an optional PIL query image and a mode selector as
# inputs; a carousel of four PIL images as output.
search_inputs = [
    gr.inputs.Textbox(lines=7, label="Input Text"),
    gr.inputs.Image(type="pil", optional=True),
    gr.inputs.Dropdown(["Text-To-Image", "Image-To-Image"]),
]
result_carousel = gr.outputs.Carousel(
    [gr.outputs.Image(type="pil") for _ in range(4)]
)
demo = gr.Interface(
    fn=image_search,
    inputs=search_inputs,
    outputs=result_carousel,
    enable_queue=True,
)
demo.launch(debug=True)