# Gradio demo for FETA: cross-modal (text-to-image and image-to-text) retrieval
# over features extracted from a folder of PDF documents (IKEA manuals).
import os
import pickle

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors

data_root = '.'
img_data_root = 'https://nextup-public-media.s3.ap.cloud-object-storage.appdomain.cloud/documents/search/ikea/'
feat_dir = os.path.join(data_root, 'feats')
doc_names = os.listdir(feat_dir)
num_nn = 5  # number of nearest neighbours retrieved per query
# search_domain = 'all'
# num_results_per_domain = 5

# Load the pre-extracted features for every document and build one
# nearest-neighbour index per modality (correlation distance).
src_data_dict = {}
for doc_name in doc_names:
    with open(os.path.join(feat_dir, doc_name), 'rb') as fp:
        doc_data = pickle.load(fp)
    t2i_space = NearestNeighbors(n_neighbors=num_nn, algorithm='auto', n_jobs=-1,
                                 metric='correlation').fit(doc_data['image_feat'])
    i2t_space = NearestNeighbors(n_neighbors=num_nn, algorithm='auto', n_jobs=-1,
                                 metric='correlation').fit(doc_data['text_feat'])
    src_data_dict[doc_name] = (doc_data, t2i_space, i2t_space)


def query_i2t(query_index, query_doc):
    """Image-to-text: retrieve the top-k texts for the selected query image."""
    query_index = int(query_index)  # gr.Slider passes a float; cast before indexing
    doc_data = src_data_dict[query_doc][0]
    # src_img_path = os.path.join(data_root, doc_data['img_paths'][query_index])
    src_img_path = os.path.join(img_data_root, '/'.join(doc_data['img_paths'][query_index].split('/')[-2:]))
    # print(src_img_path)
    # src_page_path = 'pages'.join(src_img_path.split('images'))
    # src_page_path = '_'.join(src_page_path.split('_')[:-1]) + '.png'
    # Recover the full-page image name from the crop name, e.g. 'page3_0.png' -> 'page3.png'.
    page_file_name = os.path.basename(src_img_path).split('_')[0] + f".{os.path.basename(src_img_path).split('.')[-1]}"
    src_page_path = os.path.join(os.path.dirname(src_img_path), page_file_name)
    _, top_n_matches_ids = src_data_dict[query_doc][2].kneighbors(doc_data['image_feat'][query_index].unsqueeze(0))
    captions = [doc_data['texts'][i] for i in top_n_matches_ids[0]]
    return [src_page_path] + captions


def query_t2i(query_index, query_doc):
    """Text-to-image: retrieve the top-k page images for the selected query text."""
    query_index = int(query_index)  # gr.Slider passes a float; cast before indexing
    doc_data = src_data_dict[query_doc][0]
    src_txt = doc_data['texts'][query_index]
    _, top_n_matches_ids = src_data_dict[query_doc][1].kneighbors(doc_data['text_feat'][query_index].unsqueeze(0))
    dst_image_paths = []  # kept for the commented-out per-crop image outputs below
    dst_page_paths = []
    for i in range(num_nn):
        # dst_img_path = os.path.join(data_root, doc_data['img_paths'][top_n_matches_ids[0][i]])
        dst_img_path = os.path.join(img_data_root, '/'.join(doc_data['img_paths'][top_n_matches_ids[0][i]].split('/')[-2:]))
        dst_image_paths.append(dst_img_path)
        # dst_page_path = 'pages'.join(dst_img_path.split('images'))
        # dst_page_path = '_'.join(dst_page_path.split('_')[:-1]) + '.png'
        page_file_name = os.path.basename(dst_img_path).split('_')[0] + f".{os.path.basename(dst_img_path).split('.')[-1]}"
        dst_page_path = os.path.join(os.path.dirname(dst_img_path), page_file_name)
        dst_page_paths.append(dst_page_path)
    return [src_txt] + dst_page_paths


demo = gr.Blocks()

with demo:
    gr.Markdown('# FETA: Towards Specializing Foundation Models for Expert Task Applications')
    gr.Markdown('This demo showcases the text-to-image and image-to-text retrieval capabilities of FETA.')
    gr.Markdown('The model is trained in a self-supervised, automated manner on a folder of PDF documents without any manual labels.')
    gr.Markdown('## Instructions:')
    gr.Markdown('Select a query document from the drop-down menu, pick a query index with the slider, '
                'then press the "Run" button. The query (page image or text) and the retrieved results '
                'will be presented.')
    gr.Markdown('## Select Query Document: ')
    gr.Markdown('# Query Image: \t\t\t\t')
    # domain_drop = gr.Dropdown(domains)
    # cl_drop = gr.Dropdown(class_list)
    # domain_select_button = gr.Button("Select Domain")
    # slider = gr.Slider(0, min_len)
    # slider = gr.Slider(0, 10000)

    with gr.Tabs():
        with gr.TabItem("image to text"):
            with gr.Row():
                with gr.Column():
                    doc_drop_i2t = gr.Dropdown(doc_names, label='Doc name')
                    slider_i2t = gr.Slider(0, 100, label='Query image selector slider')  # TODO: make this len(doc_drop) instead
                with gr.Column():
                    # src_img_i2t = gr.Image()
                    src_page_i2t = gr.Image()
                    button_i2t = gr.Button("Run")
            out_captions_i2t = []
            gr.Markdown('# Retrieved texts:')
            with gr.Row():
                for _ in range(num_nn):
                    with gr.Column():
                        out_captions_i2t.append(gr.Label())

        with gr.TabItem("text to image"):
            with gr.Row():
                with gr.Column():
                    doc_drop_t2i = gr.Dropdown(doc_names, label='Doc name')
                    slider_t2i = gr.Slider(0, 100, label='Query text selector slider')  # TODO: make this len(doc_drop) instead
                with gr.Column():
                    src_caption_t2i = gr.Text()
                    button_t2i = gr.Button("Run")
            dst_images_t2i = []
            dst_pages_t2i = []
            gr.Markdown('# Retrieved images:')
            for _ in range(num_nn):
                with gr.Row():
                    # with gr.Column():
                    #     dst_images_t2i.append(gr.Image())
                    with gr.Column():
                        dst_pages_t2i.append(gr.Image())

    button_i2t.click(query_i2t, inputs=[slider_i2t, doc_drop_i2t], outputs=[src_page_i2t] + out_captions_i2t)
    button_t2i.click(query_t2i, inputs=[slider_t2i, doc_drop_t2i], outputs=[src_caption_t2i] + dst_pages_t2i)

demo.launch(share=True)
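
# ---------------------------------------------------------------------------
# Reference sketch (an assumption, not part of the original demo): the layout
# of each pickle in `feats/`, inferred from the field accesses above (the
# .unsqueeze(0) calls suggest the features are stored as torch tensors). The
# helper name `make_dummy_doc_pickle`, the feature dimension, and the random
# values are hypothetical; the function is never called here and is only meant
# to be run separately to smoke-test the UI without the real FETA features.
# ---------------------------------------------------------------------------
def make_dummy_doc_pickle(path, num_images=20, num_texts=50, dim=512):
    import torch  # the demo's pickles contain torch tensors, so torch must be available

    doc_data = {
        'image_feat': torch.randn(num_images, dim),  # one row per cropped page image
        'text_feat': torch.randn(num_texts, dim),    # one row per extracted text snippet
        'img_paths': [f'ikea/demo_doc/images/page{i}_0.png' for i in range(num_images)],
        'texts': [f'placeholder caption {i}' for i in range(num_texts)],
    }
    with open(path, 'wb') as fp:
        pickle.dump(doc_data, fp)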