# FETA_IKEA / app.py
# Author: Amit Alfassy
# first commit (e32307e)
import pickle
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
import gradio as gr
from PIL import Image
# Local root for feature files; page/crop images are served from a public S3 bucket.
data_root = '.'
img_data_root = 'https://nextup-public-media.s3.ap.cloud-object-storage.appdomain.cloud/documents/search/ikea/'
feat_dir = os.path.join(data_root, 'feats')
# One pickled feature file per PDF document; the file names double as the
# document names shown in the UI dropdowns.
doc_names = os.listdir(feat_dir)
# Number of nearest neighbours returned per retrieval query.
num_nn = 5
# search_domain = 'all'
# num_results_per_domain = 5
# doc name -> (raw doc data, text-to-image search index, image-to-text search index)
src_data_dict = {}
for doc_name in doc_names:
    with open(os.path.join(feat_dir, doc_name), 'rb') as fp:
        # NOTE(review): pickle.load is only safe on trusted files — these are
        # assumed to be produced by this project's own feature extractor.
        doc_data = pickle.load(fp)
    # t2i space indexes image features (queried with text embeddings);
    # i2t space indexes text features (queried with image embeddings).
    t2i_space = NearestNeighbors(n_neighbors=num_nn, algorithm='auto', n_jobs=-1, metric='correlation').fit(doc_data['image_feat'])
    i2t_space = NearestNeighbors(n_neighbors=num_nn, algorithm='auto', n_jobs=-1, metric='correlation').fit(doc_data['text_feat'])
    src_data_dict[doc_name] = (doc_data, t2i_space, i2t_space)
def query_i2t(query_index, query_doc):
    """Image-to-text retrieval for one document.

    Args:
        query_index: index of the query image within the document. Comes from
            a gr.Slider, which delivers a float, so it is coerced to int
            before being used as an index (indexing with a float raises).
        query_doc: key into src_data_dict (a feature-file / document name).

    Returns:
        [page_image_url, caption_1, ..., caption_num_nn] — the flat list shape
        expected by the gradio outputs wiring ([src_page_i2t] + out_captions_i2t).
    """
    query_index = int(query_index)  # gr.Slider values arrive as floats
    doc_data = src_data_dict[query_doc][0]
    # src_img_path = os.path.join(data_root, doc_data['img_paths'][query_index])
    # Rebuild the public S3 URL from the last two components of the locally
    # recorded path: "<doc dir>/<image file>".
    src_img_path = os.path.join(
        img_data_root,
        '/'.join(doc_data['img_paths'][query_index].split('/')[-2:]))
    # Crop images are named "<page>_<crop>.<ext>"; the full page image is
    # "<page>.<ext>" in the same directory.
    base_name = os.path.basename(src_img_path)
    page_file_name = base_name.split('_')[0] + f".{base_name.split('.')[-1]}"
    src_page_path = os.path.join(os.path.dirname(src_img_path), page_file_name)
    # Tuple index 2 is the i2t NearestNeighbors space (fit on text features).
    _, top_n_matches_ids = src_data_dict[query_doc][2].kneighbors(
        doc_data['image_feat'][query_index].unsqueeze(0))
    captions = [doc_data['texts'][i] for i in top_n_matches_ids[0]]
    return [src_page_path] + captions
def query_t2i(query_index, query_doc):
    """Text-to-image retrieval for one document.

    Args:
        query_index: index of the query text within the document. Comes from
            a gr.Slider, which delivers a float, so it is coerced to int
            before being used as an index (indexing with a float raises).
        query_doc: key into src_data_dict (a feature-file / document name).

    Returns:
        [query_text, page_url_1, ..., page_url_num_nn] — the flat list shape
        expected by the gradio outputs wiring ([src_caption_t2i] + dst_pages_t2i).
    """
    query_index = int(query_index)  # gr.Slider values arrive as floats
    doc_data = src_data_dict[query_doc][0]
    src_txt = doc_data['texts'][query_index]
    # Tuple index 1 is the t2i NearestNeighbors space (fit on image features).
    _, top_n_matches_ids = src_data_dict[query_doc][1].kneighbors(
        doc_data['text_feat'][query_index].unsqueeze(0))
    dst_page_paths = []
    # kneighbors was built with n_neighbors=num_nn, so iterating the returned
    # ids directly yields exactly num_nn matches.
    for match_id in top_n_matches_ids[0]:
        # dst_img_path = os.path.join(data_root, doc_data['img_paths'][match_id])
        # Rebuild the public S3 URL from "<doc dir>/<image file>".
        dst_img_path = os.path.join(
            img_data_root,
            '/'.join(doc_data['img_paths'][match_id].split('/')[-2:]))
        # Crop images are named "<page>_<crop>.<ext>"; the full page image is
        # "<page>.<ext>" in the same directory.
        base_name = os.path.basename(dst_img_path)
        page_file_name = base_name.split('_')[0] + f".{base_name.split('.')[-1]}"
        dst_page_paths.append(os.path.join(os.path.dirname(dst_img_path), page_file_name))
    return [src_txt] + dst_page_paths
# ---- Gradio UI: two tabs (image-to-text, text-to-image), each with a
# document dropdown, an index slider, a Run button, and num_nn result slots.
demo = gr.Blocks()
with demo:
    gr.Markdown('# FETA towards Specializing Foundational Models for Expert Task Applications')
    gr.Markdown('This demo showcases the txt to image and image to text retrieval capabilities of FETA.')
    gr.Markdown('The model is trained in an self-supervised automated manner on a folder of PDF documents without any manual labels.')
    gr.Markdown('## Instructions:')
    gr.Markdown('Select a query domain and a class from the drop-down menus and select any random image index from the domain using the slider below, then press the "Run" button. The query image and the retrieved results from each of the four domains, along with the class label will be presented.')
    gr.Markdown('## Select Query Domain: ')
    gr.Markdown('# Query Image: \t\t\t\t')
    # domain_drop = gr.Dropdown(domains)
    # cl_drop = gr.Dropdown(class_list)
    # domain_select_button = gr.Button("Select Domain")
    # slider = gr.Slider(0, min_len)
    # slider = gr.Slider(0, 10000)
    with gr.Tabs():
        with gr.TabItem("image to text"):
            with gr.Row():
                with gr.Column():
                    doc_drop_i2t = gr.Dropdown(doc_names, label='Doc name')
                    # Slider value is passed to query_i2t as the image index.
                    slider_i2t = gr.Slider(0, 100, label='Query image selector slider') # TODO: make this len(doc_drop) instead
                    # gr.Markdown('\t')
                    # gr.Markdown('\t')
                    # gr.Markdown('\t')
                with gr.Column():
                    # src_img_i2t = gr.Image()
                    # Shows the full page containing the query image.
                    src_page_i2t = gr.Image()
            button_i2t = gr.Button("Run")
            # One gr.Label per retrieved caption, filled by query_i2t.
            out_captions_i2t = []
            gr.Markdown(f'# Retrieved texts:')
            with gr.Row():
                for _ in range(num_nn):
                    with gr.Column():
                        out_captions_i2t.append(gr.Label())
        with gr.TabItem("text to image"):
            with gr.Row():
                with gr.Column():
                    doc_drop_t2i = gr.Dropdown(doc_names, label='Doc name')
                    # Slider value is passed to query_t2i as the text index.
                    slider_t2i = gr.Slider(0, 100, label='Query text selector slider') # TODO: make this len(doc_drop) instead
                    # gr.Markdown('\t')
                    # gr.Markdown('\t')
                    # gr.Markdown('\t')
                with gr.Column():
                    # Shows the query text selected by the slider.
                    src_caption_t2i = gr.Text()
            button_t2i = gr.Button("Run")
            dst_images_t2i = []
            # One gr.Image per retrieved page, filled by query_t2i.
            dst_pages_t2i = []
            gr.Markdown(f'# Retrieved images:')
            for _ in range(num_nn):
                with gr.Row():
                    # with gr.Column():
                    #     dst_images_t2i.append(gr.Image())
                    with gr.Column():
                        dst_pages_t2i.append(gr.Image())
    # Wire the Run buttons: outputs lists must match the flat list shapes
    # returned by query_i2t / query_t2i.
    button_i2t.click(query_i2t, inputs=[slider_i2t, doc_drop_i2t], outputs=[src_page_i2t] + out_captions_i2t)
    button_t2i.click(query_t2i, inputs=[slider_t2i, doc_drop_t2i], outputs= [src_caption_t2i] + dst_pages_t2i)
demo.launch(share=True)