draw_to_search / app.py
osanseviero's picture
osanseviero HF staff
Update app.py
9a64f12
raw history blame
No virus
1.43 kB
import os

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from transformers import CLIPModel, CLIPProcessor, CLIPTextModel
# Pretrained CLIP ViT-B/32 checkpoint and its matching pre-processor; both are
# shared by every request handled below.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Metadata table: predict() reads the 'path' column of row i for the i-th
# embedding row.
# NOTE(review): assumes row order of data2.csv matches embeddings.npy — confirm.
df = pd.read_csv('data2.csv')

# Precomputed embeddings, rescaled so every row has unit L2 norm; a dot product
# with a query vector then ranks rows by cosine similarity.
embeddings_npy = np.load('embeddings.npy')
embeddings = embeddings_npy / np.sqrt(
    (embeddings_npy * embeddings_npy).sum(axis=1, keepdims=True)
)
def compute_text_embeddings(list_of_strings):
    """Encode text queries with CLIP's text tower.

    Args:
        list_of_strings: Batch of text queries. They are tokenized together by
            the module-level ``processor`` and padded to the longest one.

    Returns:
        The torch tensor produced by ``model.get_text_features`` for the batch.
    """
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Pure inference: don't record an autograd graph. Without this, every
    # query pays memory/time for gradients that are never used (callers here
    # immediately .detach() the result).
    with torch.no_grad():
        return model.get_text_features(**inputs)
def download_img(path, timeout=30):
    """Fetch the resource at URL ``path`` and save it in the working directory.

    The local filename is the last ``/``-separated segment of ``path`` with
    ``.jpg`` appended; an existing file with that name is overwritten.

    Args:
        path: URL to download.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout a stalled server would block the caller forever.

    Returns:
        The local filename the response body was written to.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status
            (previously the error page was silently saved as a ``.jpg``).
        requests.RequestException: on connection failures or timeouts.
    """
    response = requests.get(path, timeout=timeout)
    response.raise_for_status()
    local_path = path.split("/")[-1] + ".jpg"
    with open(local_path, 'wb') as handler:
        handler.write(response.content)
    return local_path
def predict(query):
    """Return local copies of the images that best match a text query.

    The query is embedded with CLIP and scored against every row of the
    module-level ``embeddings`` matrix by dot product. The top-scoring rows'
    ``path`` values from ``df`` are downloaded to disk.

    Args:
        query: Free-text search string from the Gradio text box.

    Returns:
        Up to 3 local file paths, highest-scoring image first.
    """
    n_results = 3
    query_vec = compute_text_embeddings([query]).detach().numpy()
    scores = (embeddings @ query_vec.T)[:, 0]
    # argsort is ascending; reverse it so the best match comes first.
    best_first = np.argsort(scores)[::-1][:n_results]
    paths = []
    for row in best_first:
        paths.append(download_img(df.iloc[row]['path']))
    print(paths)
    return paths
title = "Draw to Search"

# One multi-line text box in; three image panels out, one per path that
# predict() returns.
# NOTE(review): gr.inputs / gr.outputs and Image(type="file") are the legacy
# Gradio API (deprecated in 3.x, removed in 4.x) — confirm the pinned version.
iface = gr.Interface(
    fn=predict,
    inputs=[gr.inputs.Textbox(label="text", lines=3)],
    outputs=[gr.outputs.Image(type="file") for _ in range(3)],
    title=title,
    examples=[["Sunset"]],
)

iface.launch(debug=True)