# Spaces: Sleeping
import os
import pickle
import posixpath

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors
# --- Configuration ------------------------------------------------------------
# Base URL onto which the relative image paths stored in the feature files
# are joined to form the displayed image locations.
data_root: str = (
    'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud'
    '/DomainNet'
)
# Local directory holding the pickled feature files.
feat_dir: str = 'brad_feats'
# Domains offered as query choices (and searched when search_domain == 'all').
domains = ['sketch', 'painting', 'clipart', 'real']
# Tag embedded in every feature file name: f'<prefix>_{domain}_{shots}.pkl'.
shots: str = '-1'
# NOTE(review): not referenced anywhere else in this script.
num_nn: int = 20
# Either 'all' or the name of a single domain to restrict the search to.
search_domain: str = 'all'
# Neighbours retrieved, and result slots shown, for every searched domain.
num_results_per_domain: int = 5
# --- Gallery (searched) features ----------------------------------------------
# For each searched domain, load its pickled features and fit a k-NN index that
# returns `num_results_per_domain` neighbours per query.
# Judging by how `query` indexes it, each loaded object appears to be a
# sequence of (relative image paths, feature matrix, labels) — TODO confirm.
# Note: the gallery side reads the `dst_*` files, the query side the `src_*`
# files. pickle.load can execute arbitrary code: only load trusted files.
search_domains = domains if search_domain == 'all' else [search_domain]
src_data_dict = {}
for d in search_domains:
    with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
        src_data = pickle.load(fp)
    # Index the feature matrix (element [1]) of this domain.
    src_nn_fit = NearestNeighbors(
        n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1
    ).fit(src_data[1])
    # Insertion order here fixes the order of result groups returned by query.
    src_data_dict[d] = (src_data, src_nn_fit)
# --- Query features -------------------------------------------------------------
# Every domain in `domains` can be chosen as the query domain in the UI.
dst_data_dict = {}
for d in domains:
    with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
        dst_data_dict[d] = pickle.load(fp)
if not dst_data_dict:
    # Fail fast instead of falling back to a huge float placeholder that would
    # otherwise become the slider maximum.
    raise ValueError('`domains` is empty: there are no query features to load.')
# Item count of the smallest query domain: any index below it is valid
# whichever query domain is selected.
min_len = min(len(feats[0]) for feats in dst_data_dict.values())
def query(query_index, query_domain):
    """Retrieve the nearest gallery images for one query image.

    Args:
        query_index: Position of the query item in the selected domain's
            query features. Slider values may arrive as floats, so the value
            is rounded to the nearest integer.
        query_domain: Key into ``dst_data_dict`` (the dropdown selection).

    Returns:
        A flat tuple: the query image URL, then the result URLs for every
        searched domain (in ``src_data_dict`` order), followed by the query
        caption and one caption per result. This layout must match the
        ``outputs`` list wired to the Run button.

    Raises:
        ValueError: If ``query_domain`` is missing or not a loaded domain.
        IndexError: If ``query_index`` is outside that domain's items.
    """
    if query_domain not in dst_data_dict:
        raise ValueError(f'Unknown or missing query domain: {query_domain!r}')
    dst_data = dst_data_dict[query_domain]

    idx = int(round(query_index))
    n_items = len(dst_data[0])
    if not 0 <= idx < n_items:
        raise IndexError(
            f'query_index {idx} is out of range for domain '
            f'{query_domain!r} ({n_items} items)'
        )

    # posixpath (not os.path) so URLs always use '/' separators.
    query_path = posixpath.join(data_root, dst_data[0][idx])
    # The class name is the parent directory of the image path.
    query_class = query_path.split('/')[-2]
    img_paths = [query_path]
    captions = [f'Query: {query_class}'.title()]

    # Slice rather than index so kneighbors receives a one-row batch.
    query_feat = dst_data[1][idx:idx + 1]
    for gallery, nn_fit in src_data_dict.values():
        _, match_ids = nn_fit.kneighbors(query_feat)
        result_paths = [posixpath.join(data_root, gallery[0][ix]) for ix in match_ids[0]]
        img_paths.extend(result_paths)
        captions.extend(p.split('/')[-2].title() for p in result_paths)
    return tuple(img_paths) + tuple(captions)
# --- User interface ---------------------------------------------------------------
demo = gr.Blocks()
with demo:
    gr.Markdown('# Unsupervised Domain Generalization by Learning a Bridge Across Domains')
    gr.Markdown('This demo showcases the cross-domain retrieval capabilities of our self-supervised cross domain training as presented @CVPR 2022. For details please refer to [the paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Harary_Unsupervised_Domain_Generalization_by_Learning_a_Bridge_Across_Domains_CVPR_2022_paper.pdf)')
    gr.Markdown('## Select Query Domain: ')
    # Pre-select a domain so a first click never sends None to query().
    domain_drop = gr.Dropdown(domains, value=domains[0])
    # min_len is the item count of the smallest query domain, so the last
    # index valid for every domain is min_len - 1; step=1 keeps it integral.
    slider = gr.Slider(0, max(min_len - 1, 0), step=1)
    image_button = gr.Button("Run")
    # NOTE(review): the original indentation was lost. Nesting the query
    # column inside this row is inferred from the spacer cells, which pad the
    # row to the same 5 slots as the result rows. Confirm against the layout.
    with gr.Row():
        gr.Markdown('# Query Image: \t\t\t\t ')
        gr.Markdown('\t')
        gr.Markdown('\t')
        gr.Markdown('\t')
        with gr.Column():
            src_cap = gr.Label()
            src_img = gr.Image()
    out_images = []
    out_captions = []
    # One row of result slots per *searched* domain, in the same order that
    # query() iterates src_data_dict, so the outputs line up with its return.
    for d in src_data_dict:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())
    image_button.click(
        query,
        inputs=[slider, domain_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )
# Start the server only when run as a script, not when the module is imported
# (e.g. by Gradio's reload mode, which looks up `demo` itself).
if __name__ == '__main__':
    demo.launch(share=True)