aarbelle commited on
Commit
77733ea
1 Parent(s): 099eca2

add class selection

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -14,11 +14,16 @@ num_nn = 20
14
  search_domain = 'all'
15
  num_results_per_domain = 5
16
  src_data_dict = {}
 
17
  if search_domain == 'all':
18
  for d in domains:
19
  with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
20
  src_data = pickle.load(fp)
21
-
 
 
 
 
22
  src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(src_data[1])
23
  src_data_dict[d] = (src_data,src_nn_fit)
24
  else:
@@ -32,12 +37,17 @@ dst_data_dict = {}
32
  min_len = 1e10
33
  for d in domains:
34
  with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
35
- dst_data_dict[d] = pickle.load(fp)
36
- min_len = min(min_len, len(dst_data_dict[d][0]))
 
 
 
 
 
37
 
38
- def query(query_index, query_domain):
39
  dst_data = dst_data_dict[query_domain]
40
- dst_img_path = os.path.join(data_root, dst_data[0][query_index])
41
  img_paths = [dst_img_path]
42
  q_cl = dst_img_path.split('/')[-2]
43
  captions = [f'Query: {q_cl}'.title()]
@@ -61,17 +71,21 @@ with demo:
61
  gr.Markdown('## Instructions:')
62
  gr.Markdown('Select a query domain from the dropdown menu and the select any random image from the domain using the slider below. The retrieved results from each of the four domains, along with the class label will be presented.')
63
  gr.Markdown('## Select Query Domain: ')
64
- domain_drop = gr.Dropdown(domains)
 
 
65
  # domain_select_button = gr.Button("Select Domain")
66
- slider = gr.Slider(0, min_len)
67
  # slider = gr.Slider(0, 10000)
68
  image_button = gr.Button("Run")
69
-
70
  with gr.Row():
71
- gr.Markdown('# Query Image: \t\t\t\t ')
72
- gr.Markdown('\t')
73
- gr.Markdown('\t')
74
- gr.Markdown('\t')
 
 
 
75
  with gr.Column():
76
  src_cap = gr.Label()
77
  src_img = gr.Image()
@@ -87,6 +101,6 @@ with demo:
87
  out_captions.append(gr.Label())
88
  out_images.append(gr.Image())
89
 
90
- image_button.click(query, inputs=[slider, domain_drop], outputs=[src_img]+out_images +[src_cap]+ out_captions)
91
 
92
  demo.launch(share=True)
 
14
  search_domain = 'all'
15
  num_results_per_domain = 5
16
  src_data_dict = {}
17
+ class_list = []
18
  if search_domain == 'all':
19
  for d in domains:
20
  with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
21
  src_data = pickle.load(fp)
22
+ if class_list == []:
23
+ for p in src_data[0]:
24
+ cl = p.split('/')[-2]
25
+ if cl not in class_list
26
+ class_list.append(cl)
27
  src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(src_data[1])
28
  src_data_dict[d] = (src_data,src_nn_fit)
29
  else:
 
37
  min_len = 1e10
38
  for d in domains:
39
  with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
40
+ dest_data = pickle.load(fp)
41
+ dst_data_dict[d] = {cl: [] for cl in class_list}
42
+ for p in dst_data[0]:
43
+ cl = p.split('/')[-2]
44
+ dst_data_dict[d][cl].append(p)
45
+ for cl in class_list:
46
+ min_len = min(min_len, len(dst_data_dict[d][cl]))
47
 
48
+ def query(query_index, query_domain, cl):
49
  dst_data = dst_data_dict[query_domain]
50
+ dst_img_path = os.path.join(data_root, dst_data[cl][query_index])
51
  img_paths = [dst_img_path]
52
  q_cl = dst_img_path.split('/')[-2]
53
  captions = [f'Query: {q_cl}'.title()]
 
71
  gr.Markdown('## Instructions:')
72
  gr.Markdown('Select a query domain from the dropdown menu and the select any random image from the domain using the slider below. The retrieved results from each of the four domains, along with the class label will be presented.')
73
  gr.Markdown('## Select Query Domain: ')
74
+ gr.Markdown('# Query Image: \t\t\t\t')
75
+ # domain_drop = gr.Dropdown(domains)
76
+ # cl_drop = gr.Dropdown(class_list)
77
  # domain_select_button = gr.Button("Select Domain")
78
+ # slider = gr.Slider(0, min_len)
79
  # slider = gr.Slider(0, 10000)
80
  image_button = gr.Button("Run")
 
81
  with gr.Row():
82
+ with gr.Column():
83
+ domain_drop = gr.Dropdown(domains)
84
+ cl_drop = gr.Dropdown(class_list)
85
+ slider = gr.Slider(0, min_len)
86
+ # gr.Markdown('\t')
87
+ # gr.Markdown('\t')
88
+ # gr.Markdown('\t')
89
  with gr.Column():
90
  src_cap = gr.Label()
91
  src_img = gr.Image()
 
101
  out_captions.append(gr.Label())
102
  out_images.append(gr.Image())
103
 
104
+ image_button.click(query, inputs=[slider, domain_drop, cl_drop], outputs=[src_img]+out_images +[src_cap]+ out_captions)
105
 
106
  demo.launch(share=True)