import os
import pickle

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors

# NOTE(review): not used anywhere in this script; kept in case other code reads it.
num_nn = 20

# Root directory that the image paths stored in the feature pickles are joined onto.
data_root = '/dccstor/elishc1/datasets/DomainNet'
# Directory holding the pre-computed feature pickles.
feat_dir = 'brad_feats'
domains = ['real', 'painting', 'clipart', 'sketch']
# Shot-count tag embedded in the pickle file names.
shots = '-1'
# Either 'all' (search every domain) or the name of a single entry of `domains`.
search_domain = 'all'
num_results_per_domain = 5

# Domains that get a nearest-neighbour index. The UI builds one result row per entry,
# so query() output and the wired Gradio outputs always have the same length.
search_domains = domains if search_domain == 'all' else [search_domain]


def _feat_path(prefix, domain):
    """Return the feature-pickle path for *prefix* ('src' or 'dst') and *domain*."""
    return os.path.join(feat_dir, f'{prefix}_{domain}_{shots}.pkl')


def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    pickle can execute arbitrary code on load; only point this at feature files
    you produced yourself.
    """
    with open(path, 'rb') as fp:
        return pickle.load(fp)


# Gallery side: for each searched domain keep (data, NearestNeighbors fitted on data[1]).
# NOTE(review): the gallery is loaded from the 'dst_*' pickles and the queries from the
# 'src_*' pickles -- the naming looks swapped; confirm this is intended.
src_data_dict = {}
for d in search_domains:
    src_data = _load_pickle(_feat_path('dst', d))
    src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(src_data[1])
    src_data_dict[d] = (src_data, src_nn_fit)

# Query side: loaded for every domain so any domain can be picked in the dropdown.
dst_data_dict = {d: _load_pickle(_feat_path('src', d)) for d in domains}


def _class_name(path):
    """Return the name of *path*'s parent directory, used as the image's class caption."""
    return path.split('/')[-2]


def _open_image(path):
    """Load the image at *path* fully into memory and close its file handle."""
    with Image.open(path) as img:
        # copy() forces the pixel data to load, so the result no longer needs the file.
        return img.copy()


def query(query_index, query_domain):
    """Retrieve the nearest gallery images for one query image in every searched domain.

    Args:
        query_index: Row index into the query domain's data. The slider may deliver
            it as a float, so it is cast to int before indexing.
        query_domain: Key of ``dst_data_dict`` (one of ``domains``).

    Returns:
        A flat tuple, in this order:
        - the query image;
        - ``num_results_per_domain`` result images per searched domain;
        - the query caption;
        - one caption per result image.
        This matches the ``outputs`` list wired to the Run button.
    """
    query_index = int(query_index)
    dst_data = dst_data_dict[query_domain]
    dst_img_path = os.path.join(data_root, dst_data[0][query_index])
    img_paths = [dst_img_path]
    captions = [f'Query: {_class_name(dst_img_path)}']
    # Slice rather than index so the query keeps a leading row axis for kneighbors.
    # This assumes dst_data[1] is a 2-D feature matrix -- TODO confirm.
    query_feat = dst_data[1][query_index:query_index + 1]
    for s_data, s_nn in src_data_dict.values():
        _, top_n_matches_ids = s_nn.kneighbors(query_feat)
        src_img_paths = [os.path.join(data_root, s_data[0][ix]) for ix in top_n_matches_ids[0]]
        img_paths += src_img_paths
        captions += [_class_name(p) for p in src_img_paths]
    return tuple(_open_image(p) for p in img_paths) + tuple(captions)


# When this script is re-run in the same interpreter (e.g. a notebook), shut down the
# previous app first. On the first run `demo` is not defined yet.
try:
    demo.close()
except NameError:
    pass
except Exception as err:  # best-effort cleanup; never block relaunching
    print(f'Could not close previous demo: {err}')

demo = gr.Blocks()
with demo:
    gr.Markdown('## Select Query Domain: ')
    domain_drop = gr.Dropdown(domains)
    # NOTE(review): the upper bound 1000 is not tied to the dataset sizes;
    # indices past a domain's length raise IndexError in query().
    slider = gr.Slider(0, 1000, step=1)
    image_button = gr.Button("Run")
    gr.Markdown('# Query Image')
    src_cap = gr.Label()
    src_img = gr.Image()
    out_images = []
    out_captions = []
    # One row per searched domain, in the same order query() emits its results.
    for d in search_domains:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())
    image_button.click(
        query,
        inputs=[slider, domain_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )
demo.launch(share=True)