File size: 3,411 Bytes
e173a84
 
 
 
 
 
 
 
 
848a638
 
e173a84
 
 
 
a24abaa
e173a84
 
 
 
 
a24abaa
e173a84
4a3a204
e173a84
bb5ec6d
 
95cdb44
bb5ec6d
 
e173a84
 
242b67f
 
4a3a204
e173a84
 
 
 
bb5ec6d
 
e173a84
 
 
2a6db88
e173a84
4a3a204
848a638
 
4a3a204
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
3076001
 
928efab
 
 
 
dae33bc
 
 
 
 
 
 
 
928efab
 
dae33bc
928efab
 
dae33bc
 
 
 
 
 
 
 
 
 
 
d79a08c
928efab
 
e173a84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
from datetime import datetime



# Text/image model used by find_topk to embed the user's prompt.
# (Presumably a multilingual vision-language checkpoint, going by its name.)
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed embedding matrix, read from disk and wrapped as a torch tensor
# in a single step. find_topk uses the row positions of this matrix to index
# img_df, so the two must stay row-aligned.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

# Disabled: loading of the precomputed features (only referenced by the
# commented-out rerank sketch further down).
#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata table; find_topk reads its `photo_image_url` column.
img_df = pd.read_csv('image_data.csv')

def url2img(url, resize=False, fix_height=150, timeout=10):
    """Download the image at ``url`` and return it as a PIL image.

    Args:
        url: Address of the image to fetch; HTTP redirects are followed.
        resize: When true, shrink the image in place with ``Image.thumbnail``
            so that it fits inside a ``fix_height`` x ``fix_height`` box.
        fix_height: Side length, in pixels, of that bounding box. Only used
            when ``resize`` is true.
        timeout: Seconds to wait for the server before giving up. Passed
            straight to ``requests.get``; ``None`` restores the previous
            wait-forever behaviour.

    Returns:
        The ``PIL.Image.Image`` decoded from the response body.

    Raises:
        requests.exceptions.Timeout: If the server does not answer within
            ``timeout`` seconds.
        PIL.UnidentifiedImageError: If the response body is not an image
            that PIL can open.
    """
    # Without an explicit timeout requests waits indefinitely, so a single
    # stalled host would block the caller forever.
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    img = Image.open(io.BytesIO(response.content))
    if resize:
        # thumbnail() modifies the image in place.
        img.thumbnail((fix_height, fix_height), Image.LANCZOS)
    return img

def find_topk(text, top_k=20):
    """Return the URLs of the images whose embeddings best match ``text``.

    The prompt is encoded with ``model_multi``, compared against every row of
    the module-level ``embeddings`` tensor by cosine similarity, and the
    ``photo_image_url`` values of the best-scoring rows of ``img_df`` are
    returned (best match first).

    Args:
        text: The search prompt entered by the user.
        top_k: Maximum number of results to return. Defaults to 20, the value
            that used to be hard-coded.

    Returns:
        A NumPy array of URL values with at most ``top_k`` entries; fewer if
        there are fewer embeddings than ``top_k``.
    """
    print('text', text)

    # topk() raises when asked for more items than exist, so clamp the request
    # to the number of available embeddings (and to zero from below).
    k = max(0, min(top_k, embeddings.shape[0]))

    # Pure inference: no gradients are needed for retrieval.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        # The features are only needed by the (disabled) rerank step.
        _text_features, text_embedding = model_multi.encode_text(text_data, return_features=True)

        print('Got features', datetime.now().strftime("%H:%M:%S"))

        # assumes text_embedding is (1, dim) and embeddings is (N, dim), so
        # that this yields one score per image -- TODO confirm
        sims = F.cosine_similarity(text_embedding, embeddings)
        _vals, inds = sims.topk(k)

    # Index the DataFrame with plain Python ints rather than a torch tensor.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values

    print('Got top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))

    return top_k_urls



# def rerank(text_features, text_data):

#     # create joint embeddings & get scores
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)

#     # argmax to get top N

#     return


#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

# Log the installed Gradio version; the component API differs between releases.
print('version', gr.__version__)

# Page layout: a heading, a row with the supported-languages table next to the
# prompt box and search button, a row of clickable example prompts, and the
# results gallery that find_topk fills with image URLs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        with gr.Column():
            # Markdown table listing the supported languages.
            gr.Markdown(
                '||||||\n'
                '|:-------:  |:---: |:-------:   |:---: | :--- |\n'
                '|__English__| #    |__French__  | #    |__Russian__|\n'
                '|__German__ | #    |__Italian__ | #    |__Chinese (Simplified)__|\n'
                '|__Spanish__| #    |__Japanese__| #    |__Korean__|\n'
                '|__Turkish__| #    |__Polish__  | #    |.|\n')

        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3, container=True)
            btn_search = gr.Button("Find images")

    with gr.Row():
        # Clicking an example copies it into the prompt box.
        gr.Examples(
            [
                'a girl wandering alone in the forest',
                'морозное утро в городе',
                '카메라를 바라보는 강아지 새끼',
                'ein Schloss, das zwischen modernen Gebäuden hervorlugt',
                'un couple sirotant un café au bord de la rivière',
                'una banda de música actuando en un gran espacio al aire libre',
                '秋の静かな霧の庭園',
            ],
            inputs=[prompt_box],
        )

    # Layout options go to the constructor: the old `.style(grid=...)` helper
    # is deprecated in Gradio 3.x and removed in 4.x, and `grid` was renamed
    # to `columns`. This matches the constructor-style `container=True` used
    # for the Textbox above.
    gallery = gr.Gallery(columns=5, height="auto")
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)

if __name__ == "__main__":
    demo.launch()