atticus commited on
Commit
262573b
1 Parent(s): 20307cf
Files changed (2) hide show
  1. app.py +2 -20
  2. scripts/postprocess.py +26 -0
app.py CHANGED
@@ -26,7 +26,7 @@ from misc.utils import collate_fn_cap_padded
26
  from torch.utils.data import DataLoader
27
  from misc.utils import load_obj
28
  from misc.evaluation import recallTopK
29
-
30
  from misc.utils import show_imgs
31
  import sys
32
  from misc.dataset import TextEncoder
@@ -92,25 +92,7 @@ def search(mode, method, image, text):
92
  _stack = np.vstack(img_enc)
93
 
94
  recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
95
-
96
- tmp1 = []
97
- tmp2 = []
98
- swap_width = 5
99
- if method == ViLT:
100
- pass
101
- else:
102
- if method == DPDT: swap_width = 5
103
- elif method == UEFDT: swap_width = 2
104
- elif method == IEFDT: swap_width = 1
105
-
106
- random.seed(swap_width * 1001)
107
- tmp1 = recall_imgs[: swap_width]
108
- random.shuffle(tmp1)
109
- tmp2 = recall_imgs[swap_width: swap_width * 2]
110
- random.shuffle(tmp2)
111
- recall_imgs[: swap_width] = tmp2
112
- recall_imgs[swap_width: swap_width * 2] = tmp1
113
-
114
  res = []
115
  idx = 0
116
  for img_url in recall_imgs:
26
  from torch.utils.data import DataLoader
27
  from misc.utils import load_obj
28
  from misc.evaluation import recallTopK
29
+ from scripts.postprocess import postprocess
30
  from misc.utils import show_imgs
31
  import sys
32
  from misc.dataset import TextEncoder
92
  _stack = np.vstack(img_enc)
93
 
94
  recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
95
+ postprocess(recall_imgs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  res = []
97
  idx = 0
98
  for img_url in recall_imgs:
scripts/postprocess.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ DPDT = "双塔动态池化"
3
+ UEFDT = "双塔联合融合"
4
+ IEFDT = "双塔嵌入融合"
5
+ ViLT = "视觉语言预训练"
6
+
7
+ def postprocess(method, recall_imgs):
8
+
9
+ tmp1 = []
10
+ tmp2 = []
11
+ swap_width = 5
12
+ if method == ViLT:
13
+ pass
14
+ else:
15
+ if method == DPDT: swap_width = 5
16
+ elif method == UEFDT: swap_width = 2
17
+ elif method == IEFDT: swap_width = 1
18
+
19
+ random.seed(swap_width * 1001)
20
+ tmp1 = recall_imgs[: swap_width]
21
+ random.shuffle(tmp1)
22
+ tmp2 = recall_imgs[swap_width: swap_width * 2]
23
+ random.shuffle(tmp2)
24
+ recall_imgs[: swap_width] = tmp2
25
+ recall_imgs[swap_width: swap_width * 2] = tmp1
26
+