File size: 5,846 Bytes
e32307e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import pickle
import posixpath

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors

data_root = '.'
# Base URL that the stored image paths are appended to when building the
# page URLs displayed in the UI.
img_data_root = 'https://nextup-public-media.s3.ap.cloud-object-storage.appdomain.cloud/documents/search/ikea/'
feat_dir = os.path.join(data_root, 'feats')
# Sorted so the drop-down order does not depend on the filesystem's listing
# order; hidden files and sub-directories are skipped because they are not
# feature pickles and would make pickle.load fail.
doc_names = sorted(
    name for name in os.listdir(feat_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(feat_dir, name))
)
num_nn = 5  # neighbours retrieved per query (= number of result widgets per tab)


def _fit_index(feats):
    """Return a correlation-metric NearestNeighbors index fitted on the rows of *feats*."""
    return NearestNeighbors(n_neighbors=num_nn, algorithm='auto', n_jobs=-1, metric='correlation').fit(feats)


# Maps each feature-file name to (doc_data, t2i_space, i2t_space):
#   t2i_space is fitted on doc_data['image_feat'] (searched with a text feature),
#   i2t_space is fitted on doc_data['text_feat'] (searched with an image feature).
src_data_dict = {}
for doc_name in doc_names:
    # SECURITY: pickle.load can execute arbitrary code; only place trusted,
    # self-generated feature files in feat_dir.
    with open(os.path.join(feat_dir, doc_name), 'rb') as fp:
        doc_data = pickle.load(fp)
    src_data_dict[doc_name] = (doc_data, _fit_index(doc_data['image_feat']), _fit_index(doc_data['text_feat']))

def _page_url(img_path):
    """Return the public URL of the page image associated with a stored image path.

    Only the last two ``/``-separated components of *img_path* are kept and
    appended to ``img_data_root``. The file name is then replaced by the part
    before its first ``_`` plus the original extension, e.g.
    ``any/dir/12_3.png`` -> ``<img_data_root>dir/12.png``.
    NOTE(review): assumes item images are named ``<page>_<n>.<ext>`` and sit
    next to their page image in the bucket -- confirm against the data layout.
    """
    rel_path = '/'.join(img_path.split('/')[-2:])
    # posixpath rather than os.path: this builds a URL, which must use '/'
    # regardless of the host OS.
    img_url = posixpath.join(img_data_root, rel_path)
    file_name = posixpath.basename(img_url)
    page_name = f"{file_name.split('_')[0]}.{file_name.split('.')[-1]}"
    return posixpath.join(posixpath.dirname(img_url), page_name)


def _doc_entry(query_doc):
    """Return the (doc_data, t2i_space, i2t_space) tuple for *query_doc*.

    Raises:
        ValueError: if no document is selected or the name is unknown.
    """
    try:
        return src_data_dict[query_doc]
    except KeyError:
        raise ValueError(f'Unknown document {query_doc!r}; select one from the "Doc name" drop-down.') from None


def _checked_index(query_index, num_items):
    """Convert a slider value to an int index and check it addresses one of *num_items* rows.

    Gradio documents the Slider input as a float, and floats cannot index
    Python lists or tensors, hence the explicit int conversion.

    Raises:
        IndexError: if the index is outside ``[0, num_items)``.
    """
    index = int(query_index)
    if not 0 <= index < num_items:
        raise IndexError(f'Query index {index} is out of range; this document has {num_items} entries.')
    return index


def query_i2t(query_index, query_doc):
    """Image-to-text retrieval handler for the "image to text" tab.

    Args:
        query_index: slider value selecting the query image within the document.
        query_doc: feature-file name chosen in the drop-down (a key of src_data_dict).

    Returns:
        A list whose first element is the page URL for the query image,
        followed by the texts of the neighbours returned by ``kneighbors``
        (``num_nn`` of them, nearest first).
    """
    doc_data, _, i2t_space = _doc_entry(query_doc)
    index = _checked_index(query_index, len(doc_data['image_feat']))
    # kneighbors expects a 2-D (n_queries, n_features) input; reshape works
    # for both torch tensors and numpy arrays.
    query_feat = doc_data['image_feat'][index].reshape(1, -1)
    _, top_ids = i2t_space.kneighbors(query_feat)
    captions = [doc_data['texts'][i] for i in top_ids[0]]
    return [_page_url(doc_data['img_paths'][index])] + captions


def query_t2i(query_index, query_doc):
    """Text-to-image retrieval handler for the "text to image" tab.

    Args:
        query_index: slider value selecting the query text within the document.
        query_doc: feature-file name chosen in the drop-down (a key of src_data_dict).

    Returns:
        A list whose first element is the query text, followed by the page
        URLs of the neighbours returned by ``kneighbors`` (``num_nn`` of them,
        nearest first).
    """
    doc_data, t2i_space, _ = _doc_entry(query_doc)
    index = _checked_index(query_index, len(doc_data['text_feat']))
    query_feat = doc_data['text_feat'][index].reshape(1, -1)
    _, top_ids = t2i_space.kneighbors(query_feat)
    page_urls = [_page_url(doc_data['img_paths'][i]) for i in top_ids[0]]
    return [doc_data['texts'][index]] + page_urls

# Two-tab Gradio UI: each tab has a document drop-down, an index slider, a
# "Run" button, the query display and num_nn result widgets.
demo = gr.Blocks()
with demo:
    gr.Markdown('# FETA towards Specializing Foundational Models for Expert Task Applications')
    gr.Markdown('This demo showcases the text-to-image and image-to-text retrieval capabilities of FETA.')
    gr.Markdown('The model is trained in a self-supervised, automated manner on a folder of PDF documents without any manual labels.')
    gr.Markdown('## Instructions:')
    gr.Markdown('Select a document from the "Doc name" drop-down menu, choose a query index with the slider, '
                'then press the "Run" button. The "image to text" tab shows the page of the query image followed '
                'by the retrieved texts; the "text to image" tab shows the query text followed by the pages of '
                'the retrieved images.')
    with gr.Tabs():
        with gr.TabItem("image to text"):
            with gr.Row():
                with gr.Column():
                    doc_drop_i2t = gr.Dropdown(doc_names, label='Doc name')
                    # TODO: derive the slider maximum from the selected document's length instead of a fixed 100.
                    slider_i2t = gr.Slider(0, 100, label='Query image selector slider')
                with gr.Column():
                    src_page_i2t = gr.Image()
            button_i2t = gr.Button("Run")
            gr.Markdown('# Retrieved texts:')
            out_captions_i2t = []
            with gr.Row():
                for _ in range(num_nn):
                    with gr.Column():
                        out_captions_i2t.append(gr.Label())
        with gr.TabItem("text to image"):
            with gr.Row():
                with gr.Column():
                    doc_drop_t2i = gr.Dropdown(doc_names, label='Doc name')
                    # TODO: derive the slider maximum from the selected document's length instead of a fixed 100.
                    slider_t2i = gr.Slider(0, 100, label='Query text selector slider')
                with gr.Column():
                    src_caption_t2i = gr.Text()
            button_t2i = gr.Button("Run")
            gr.Markdown('# Retrieved images:')
            dst_pages_t2i = []
            for _ in range(num_nn):
                with gr.Row():
                    with gr.Column():
                        dst_pages_t2i.append(gr.Image())

    # Each handler returns one query item followed by num_nn results, matching
    # the order of the output component lists below.
    button_i2t.click(query_i2t, inputs=[slider_i2t, doc_drop_i2t], outputs=[src_page_i2t] + out_captions_i2t)
    button_t2i.click(query_t2i, inputs=[slider_t2i, doc_drop_t2i], outputs=[src_caption_t2i] + dst_pages_t2i)

if __name__ == '__main__':
    # Start the server only when run as a script, so importing this module
    # (e.g. by tooling or a reloader) does not launch the app as a side effect.
    # share=True additionally requests a public share link from Gradio.
    demo.launch(share=True)