aarbelle commited on
Commit
fb018ef
1 Parent(s): d269720
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -38,16 +38,19 @@ min_len = 1e10
38
  for d in domains:
39
  with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
40
  dest_data = pickle.load(fp)
41
- dst_data_dict[d] = {cl: [] for cl in class_list}
42
- for p in dest_data[0]:
43
  cl = p.split('/')[-2]
44
- dst_data_dict[d][cl].append(p)
 
 
45
  for cl in class_list:
46
- min_len = min(min_len, len(dst_data_dict[d][cl]))
47
 
48
  def query(query_index, query_domain, cl):
49
  dst_data = dst_data_dict[query_domain]
50
- dst_img_path = os.path.join(data_root, dst_data[cl][query_index])
 
51
  img_paths = [dst_img_path]
52
  q_cl = dst_img_path.split('/')[-2]
53
  captions = [f'Query: {q_cl}'.title()]
@@ -61,7 +64,7 @@ def query(query_index, query_domain, cl):
61
  src_cl = p.split('/')[-2]
62
  src_file = p.split('/')[-1]
63
  captions.append(src_cl.title())
64
- print(img_paths)
65
  return tuple([p for p in img_paths])+ tuple(captions)
66
 
67
  demo = gr.Blocks()
@@ -80,8 +83,8 @@ with demo:
80
  image_button = gr.Button("Run")
81
  with gr.Row():
82
  with gr.Column():
83
- domain_drop = gr.Dropdown(domains)
84
- cl_drop = gr.Dropdown(class_list)
85
  slider = gr.Slider(0, 100)
86
  # gr.Markdown('\t')
87
  # gr.Markdown('\t')
 
38
  for d in domains:
39
  with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
40
  dest_data = pickle.load(fp)
41
+ dst_data_dict[d] = ({cl: ([],[]) for cl in class_list},dest_data[1])
42
+ for c, p in enumerate(dest_data[0]):
43
  cl = p.split('/')[-2]
44
+ dst_data_dict[d][0][cl][0].append(p)
45
+ dst_data_dict[d][0][cl][1].append(c)
46
+
47
  for cl in class_list:
48
+ min_len = min(min_len, len(dst_data_dict[d][0][cl]))
49
 
50
  def query(query_index, query_domain, cl):
51
  dst_data = dst_data_dict[query_domain]
52
+ dst_img_path = os.path.join(data_root, dst_data[0][cl][0][query_index])
53
+ query_index = dst_data[0][cl][1][query_index]
54
  img_paths = [dst_img_path]
55
  q_cl = dst_img_path.split('/')[-2]
56
  captions = [f'Query: {q_cl}'.title()]
 
64
  src_cl = p.split('/')[-2]
65
  src_file = p.split('/')[-1]
66
  captions.append(src_cl.title())
67
+ # print(img_paths)
68
  return tuple([p for p in img_paths])+ tuple(captions)
69
 
70
  demo = gr.Blocks()
 
83
  image_button = gr.Button("Run")
84
  with gr.Row():
85
  with gr.Column():
86
+ domain_drop = gr.Dropdown(domains, label='Domain')
87
+ cl_drop = gr.Dropdown(class_list, label='Query Class')
88
  slider = gr.Slider(0, 100)
89
  # gr.Markdown('\t')
90
  # gr.Markdown('\t')