"""Gradio demo: multilingual text-to-image search with UForm embeddings.

A text prompt is embedded with the multilingual UForm model and compared by
cosine similarity against precomputed image embeddings loaded from
``tensors/embeddings.npy``. The URLs of the best matches, looked up in
``image_data.csv``, are shown in a gallery.
"""

import io
from datetime import datetime

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
import uform
from PIL import Image

model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed image embeddings. Row i is assumed to belong to row i of
# image_data.csv (NOTE(review): confirm both files were generated together).
embeddings = np.load('tensors/embeddings.npy')
embeddings = torch.tensor(embeddings)
#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

img_df = pd.read_csv('image_data.csv')


def url2img(url, resize=False, fix_height=150, timeout=10):
    """Download an image from *url* and return it as a PIL image.

    Args:
        url: Image URL to fetch. Redirects are followed.
        resize: If true, shrink the image in place with ``Image.thumbnail``
            so that it fits in a ``fix_height`` x ``fix_height`` box.
        fix_height: Bounding-box size in pixels, used only when ``resize``
            is true.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block forever on a stalled connection.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Fail with a clear HTTP error instead of letting PIL try to decode an
    # HTML error page as an image.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    if resize:
        img.thumbnail([fix_height, fix_height], Image.LANCZOS)
    return img


def find_topk(text, top_k=20):
    """Return the URLs of the images whose embeddings best match *text*.

    Args:
        text: Search prompt in any language the model supports.
        top_k: Maximum number of URLs to return. It is clamped to the number
            of similarity scores, so ``topk`` cannot fail on a small index.

    Returns:
        Array of ``photo_image_url`` values from ``img_df``, best match
        first. An empty list is returned for a blank prompt.
    """
    print('text', text)
    if not text or not text.strip():
        # Nothing to search for; skip running the model on an empty prompt.
        return []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        # Only the pooled embedding is used here. The per-token features
        # would only matter for a multimodal re-ranking step.
        _, text_embedding = model_multi.encode_text(
            text_data, return_features=True
        )
        print('Got features', datetime.now().strftime("%H:%M:%S"))
        # Assumes text_embedding is (1, dim) and embeddings is (N, dim), so
        # the result is one score per stored image. TODO confirm shapes.
        sims = F.cosine_similarity(text_embedding, embeddings)
        k = min(top_k, sims.numel())
        inds = sims.topk(k).indices

    # Hand pandas a plain NumPy positional indexer rather than a torch tensor.
    top_k_urls = img_df.iloc[inds.cpu().numpy()]['photo_image_url'].values
    print('Got top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))
    return top_k_urls


# def rerank(text_features, text_data):
#     # create joint embeddings & get scores
#     # NOTE(review): image_features is not defined anywhere in this file.
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)
#     # argmax to get top N
#     return

#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

print('version', gr.__version__)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                '||||||\n'
                '|:-------: |:---: |:-------: |:---: | :--- |\n'
                '|__English__| # |__French__ | # |__Russian__|\n'
                '|__German__ | # |__Italian__ | # |__Chinese (Simplified)__|\n'
                '|__Spanish__| # |__Japanese__| # |__Korean__|\n'
                '|__Turkish__| # |__Polish__ | # |.|\n')
        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3, container=True)
            btn_search = gr.Button("Find images")
    with gr.Row():
        gr.Examples(['a girl wandering alone in the forest',
                     'морозное утро в городе',
                     '카메라를 바라보는 강아지 새끼',
                     'ein Schloss, das zwischen modernen Gebäuden hervorlugt',
                     'un couple sirotant un café au bord de la rivière',
                     'una banda de música actuando en un gran espacio al aire libre',
                     '秋の静かな霧の庭園'
                     ], inputs=[prompt_box])
    # NOTE(review): Component.style() is deprecated in late Gradio 3.x and
    # removed in 4.x. On newer versions the equivalent is likely
    # gr.Gallery(columns=5, height="auto"). Confirm against the pinned version.
    gallery = gr.Gallery().style(grid=[5], height="auto")
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)

if __name__ == "__main__":
    demo.launch()