import os
import pickle
import posixpath

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors

# Directory that contains the precomputed retrieval data.
data_root = '.'

# Base URL that the last two components of every stored image path are
# appended to when building the image URLs handed to the UI.
img_data_root = 'https://nextup-public-media.s3.ap.cloud-object-storage.appdomain.cloud/documents/search/ikea/'

# One pickle file per document lives in this directory.
feat_dir = os.path.join(data_root, 'feats')

# os.listdir() returns entries in arbitrary order; sort them so the
# document drop-downs list the same order on every start-up.
doc_names = sorted(os.listdir(feat_dir))

# Number of nearest neighbours retrieved per query; the UI creates exactly
# this many output widgets per tab.
num_nn = 5


def _fit_neighbor_index(features):
    """Return a correlation-distance nearest-neighbour index fitted on ``features``."""
    index = NearestNeighbors(n_neighbors=num_nn, algorithm='auto', n_jobs=-1, metric='correlation')
    return index.fit(features)


# Maps each document file name to a 3-tuple:
#   (unpickled document data,
#    index fitted on the image features  -> used for text-to-image queries,
#    index fitted on the text features   -> used for image-to-text queries)
src_data_dict = {}
for doc_name in doc_names:
    feat_path = os.path.join(feat_dir, doc_name)
    # NOTE: pickle.load can execute arbitrary code; feat_dir must only ever
    # contain files from a trusted source.
    with open(feat_path, 'rb') as feat_file:
        payload = pickle.load(feat_file)
    src_data_dict[doc_name] = (
        payload,
        _fit_neighbor_index(payload['image_feat']),
        _fit_neighbor_index(payload['text_feat']),
    )


def query_i2t(query_index, query_doc):
    """Retrieve the texts nearest to one image of a document.

    Args:
        query_index: Position of the query image within the document's
            ``img_paths`` / ``image_feat`` entries. The value is truncated
            to ``int`` (UI sliders may deliver floats) and clamped to the
            valid index range of ``img_paths``.
        query_doc: Key into ``src_data_dict`` selecting the document.

    Returns:
        A list whose first element is the page-image URL derived from the
        query image path, followed by the texts of the nearest neighbours
        found in the document's text-feature index.

    Raises:
        KeyError: If ``query_doc`` is not a key of ``src_data_dict``.
        IndexError: If the selected document contains no images.
    """
    doc_data, _, i2t_space = src_data_dict[query_doc]

    # The slider range is fixed and independent of the document size, so
    # coerce to int and clamp instead of failing on a float or on an index
    # past the end of the document.
    last_index = len(doc_data['img_paths']) - 1
    query_index = min(max(int(query_index), 0), last_index)

    # Keep only "<parent dir>/<file name>" of the stored path and hang it
    # under the public base URL. posixpath is used (not os.path) because
    # URLs always use '/', whatever the host OS is.
    rel_img_path = '/'.join(doc_data['img_paths'][query_index].split('/')[-2:])
    src_img_url = posixpath.join(img_data_root, rel_img_path)

    # The page file name is the image file name cut at its first underscore,
    # with the original extension re-attached (presumably the page the image
    # was cropped from -- confirm against the data layout).
    img_dir, img_name = posixpath.split(src_img_url)
    page_file_name = f"{img_name.split('_')[0]}.{img_name.split('.')[-1]}"
    src_page_url = posixpath.join(img_dir, page_file_name)

    # Feature rows must expose ``unsqueeze`` (e.g. torch tensors); it adds
    # the leading batch axis that ``kneighbors`` expects.
    query_feat = doc_data['image_feat'][query_index].unsqueeze(0)
    _, neighbor_ids = i2t_space.kneighbors(query_feat)

    captions = [doc_data['texts'][i] for i in neighbor_ids[0]]
    return [src_page_url] + captions


def query_t2i(query_index, query_doc):
    """Retrieve the page images nearest to one text of a document.

    Args:
        query_index: Position of the query text within the document's
            ``texts`` / ``text_feat`` entries. The value is truncated to
            ``int`` (UI sliders may deliver floats) and clamped to the
            valid index range of ``texts``.
        query_doc: Key into ``src_data_dict`` selecting the document.

    Returns:
        A list whose first element is the query text, followed by one
        page-image URL per nearest neighbour found in the document's
        image-feature index.

    Raises:
        KeyError: If ``query_doc`` is not a key of ``src_data_dict``.
        IndexError: If the selected document contains no texts.
    """
    doc_data, t2i_space, _ = src_data_dict[query_doc]

    # The slider range is fixed and independent of the document size, so
    # coerce to int and clamp instead of failing on a float or on an index
    # past the end of the document.
    last_index = len(doc_data['texts']) - 1
    query_index = min(max(int(query_index), 0), last_index)

    src_txt = doc_data['texts'][query_index]

    # Feature rows must expose ``unsqueeze`` (e.g. torch tensors); it adds
    # the leading batch axis that ``kneighbors`` expects.
    query_feat = doc_data['text_feat'][query_index].unsqueeze(0)
    _, neighbor_ids = t2i_space.kneighbors(query_feat)

    dst_page_urls = []
    for match_id in neighbor_ids[0]:
        # Keep only "<parent dir>/<file name>" of the stored path and hang
        # it under the public base URL. posixpath is used (not os.path)
        # because URLs always use '/', whatever the host OS is.
        rel_img_path = '/'.join(doc_data['img_paths'][match_id].split('/')[-2:])
        dst_img_url = posixpath.join(img_data_root, rel_img_path)

        # Page file name: image file name cut at its first underscore, with
        # the original extension re-attached.
        img_dir, img_name = posixpath.split(dst_img_url)
        page_file_name = f"{img_name.split('_')[0]}.{img_name.split('.')[-1]}"
        dst_page_urls.append(posixpath.join(img_dir, page_file_name))

    return [src_txt] + dst_page_urls


# Gradio front-end: two tabs (image->text and text->image), each with a
# document drop-down, an index slider, a "Run" button and ``num_nn`` result
# widgets.
# NOTE(review): the source this was recovered from had lost its indentation;
# the nesting of the layout containers below is a reconstruction -- confirm
# it matches the intended page layout.
demo = gr.Blocks()

with demo:
    gr.Markdown('# FETA towards Specializing Foundational Models for Expert Task Applications')
    gr.Markdown('This demo showcases the txt to image and image to text retrieval capabilities of FETA.')
    gr.Markdown('The model is trained in an self-supervised automated manner on a folder of PDF documents without any manual labels.')
    gr.Markdown('## Instructions:')
    # NOTE(review): the instruction text and the two headings below mention
    # query domains, classes and "four domains", none of which exist in this
    # UI (it offers a document drop-down and an index slider). The wording
    # looks stale and should be revised by the demo owner.
    gr.Markdown('Select a query domain and a class from the drop-down menus and select any random image index from the domain using the slider below, then press the "Run" button. The query image and the retrieved results from each of the four domains, along with the class label will be presented.')
    gr.Markdown('## Select Query Domain: ')
    gr.Markdown('# Query Image: \t\t\t\t')

    with gr.Tabs():
        with gr.TabItem("image to text"):
            with gr.Row():
                with gr.Column():
                    doc_drop_i2t = gr.Dropdown(doc_names, label='Doc name')
                    # step=1: the slider value is used as a list index, so it
                    # must be a whole number. The 0-100 range is fixed and not
                    # tied to the number of images in the selected document.
                    slider_i2t = gr.Slider(0, 100, step=1, label='Query image selector slider')
                with gr.Column():
                    src_page_i2t = gr.Image()
            button_i2t = gr.Button("Run")
            out_captions_i2t = []
            gr.Markdown('# Retrieved texts:')
            with gr.Row():
                # One label per retrieved text, laid out side by side.
                for _ in range(num_nn):
                    with gr.Column():
                        out_captions_i2t.append(gr.Label())

        with gr.TabItem("text to image"):
            with gr.Row():
                with gr.Column():
                    doc_drop_t2i = gr.Dropdown(doc_names, label='Doc name')
                    # step=1 for the same reason as on the image-to-text tab.
                    slider_t2i = gr.Slider(0, 100, step=1, label='Query text selector slider')
                with gr.Column():
                    src_caption_t2i = gr.Text()
            button_t2i = gr.Button("Run")
            # Currently unused: only the page images are wired to an output.
            dst_images_t2i = []
            dst_pages_t2i = []
            gr.Markdown('# Retrieved images:')
            # One row per retrieved page image, stacked vertically.
            for _ in range(num_nn):
                with gr.Row():
                    with gr.Column():
                        dst_pages_t2i.append(gr.Image())

    # Each handler returns one leading value (source page / query text)
    # followed by num_nn results, matching the output lists built above.
    button_i2t.click(query_i2t, inputs=[slider_i2t, doc_drop_i2t], outputs=[src_page_i2t] + out_captions_i2t)
    button_t2i.click(query_t2i, inputs=[slider_t2i, doc_drop_t2i], outputs=[src_caption_t2i] + dst_pages_t2i)

demo.launch(share=True)