"""Gradio demo: cross-domain nearest-neighbour image retrieval on DomainNet.

Pre-computed features are unpickled from ``feat_dir``.  For every searched
domain a ``NearestNeighbors`` index is fitted on the gallery features; the UI
lets the user pick a query image (by domain and index) and shows the closest
gallery images of each searched domain together with their class names.
"""
import os
import pickle
import posixpath

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors

data_root = 'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud/DomainNet'
feat_dir = 'brad_feats'
domains = ['real', 'painting', 'clipart', 'sketch']
shots = '-1'
num_nn = 20
search_domain = 'all'
num_results_per_domain = 5

# Domains that are actually searched; shared by the index construction and the
# UI so that the number of output widgets always matches what query() returns.
search_domains = domains if search_domain == 'all' else [search_domain]


def _load_features(prefix, domain):
    """Unpickle ``{prefix}_{domain}_{shots}.pkl`` from ``feat_dir``.

    The loaded object is indexed positionally by the rest of this script:
    element 0 holds image paths relative to ``data_root`` and element 1 holds
    the feature matrix.
    """
    # NOTE: pickle.load executes arbitrary code embedded in the file; only
    # point feat_dir at feature files from a trusted source.
    with open(os.path.join(feat_dir, f'{prefix}_{domain}_{shots}.pkl'), 'rb') as fp:
        return pickle.load(fp)


def _class_name(img_path):
    """Return the parent directory name of *img_path* (used as the class name)."""
    return img_path.split('/')[-2]


# Gallery side: one (data, fitted NN index) pair per searched domain.
# NOTE(review): the gallery is read from the 'dst_*' files and the queries from
# the 'src_*' files, opposite to the variable names -- confirm this is intended.
src_data_dict = {}
for d in search_domains:
    src_data = _load_features('dst', d)
    src_nn_fit = NearestNeighbors(
        n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1
    ).fit(src_data[1])
    src_data_dict[d] = (src_data, src_nn_fit)

# Query side: raw data for every domain the user can pick a query from.
dst_data_dict = {d: _load_features('src', d) for d in domains}


def query(query_index, query_domain):
    """Retrieve the nearest gallery images for one query image.

    Args:
        query_index: Position of the query image inside the chosen domain.
            Converted with ``int()`` because the slider may deliver a float.
        query_domain: Key of ``dst_data_dict`` selecting the query domain.

    Returns:
        A flat tuple: the query image URL, then the result URLs of every
        searched domain, then the query caption, then one class-name caption
        per result -- the order expected by the ``outputs`` of the Run button.

    Raises:
        KeyError: If ``query_domain`` is not a known domain (e.g. nothing was
            selected in the dropdown).
        IndexError: If ``query_index`` is outside the chosen domain's data.
    """
    if query_domain not in dst_data_dict:
        raise KeyError(
            f'Unknown query domain {query_domain!r}; choose one of {domains}'
        )
    dst_data = dst_data_dict[query_domain]

    query_index = int(query_index)
    num_queries = len(dst_data[0])
    if not 0 <= query_index < num_queries:
        raise IndexError(
            f'Query index {query_index} is out of range for domain '
            f'{query_domain!r} ({num_queries} images)'
        )

    # URLs always use forward slashes, so join with posixpath rather than
    # os.path (which would insert backslashes on Windows).
    dst_img_path = posixpath.join(data_root, dst_data[0][query_index])
    img_paths = [dst_img_path]
    captions = [f'Query: {_class_name(dst_img_path)}'.title()]

    # Slice (not index) to keep the 2-D shape kneighbors() expects.
    query_feat = dst_data[1][query_index:query_index + 1]
    for s_data, s_nn_fit in src_data_dict.values():
        _, top_n_matches_ids = s_nn_fit.kneighbors(query_feat)
        src_img_pths = [
            posixpath.join(data_root, s_data[0][ix]) for ix in top_n_matches_ids[0]
        ]
        img_paths += src_img_pths
        captions += [_class_name(p).title() for p in src_img_pths]

    print(img_paths)
    return tuple(img_paths) + tuple(captions)


# NOTE(review): the widget nesting below was reconstructed from a flattened
# source; confirm the layout matches the intended one.
demo = gr.Blocks()
with demo:
    gr.Markdown('## Select Query Domain: ')
    domain_drop = gr.Dropdown(domains)
    # step=1 so the slider only yields whole-number indices.
    slider = gr.Slider(0, 1000, step=1)
    image_button = gr.Button("Run")

    with gr.Row():
        gr.Markdown('# Query Image: \t\t\t\t ')
        gr.Markdown('\t')
        gr.Markdown('\t')
        gr.Markdown('\t')
        with gr.Column():
            src_cap = gr.Label()
            src_img = gr.Image()

    out_images = []
    out_captions = []
    # One row of result widgets per *searched* domain, in the same order in
    # which query() emits its results.
    for d in search_domains:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())

    image_button.click(
        query,
        inputs=[slider, domain_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )

demo.launch(share=True)