Spaces:
Build error
Build error
app
Browse files- app.py +2 -20
- scripts/postprocess.py +26 -0
app.py
CHANGED
@@ -26,7 +26,7 @@ from misc.utils import collate_fn_cap_padded
|
|
26 |
from torch.utils.data import DataLoader
|
27 |
from misc.utils import load_obj
|
28 |
from misc.evaluation import recallTopK
|
29 |
-
|
30 |
from misc.utils import show_imgs
|
31 |
import sys
|
32 |
from misc.dataset import TextEncoder
|
@@ -92,25 +92,7 @@ def search(mode, method, image, text):
|
|
92 |
_stack = np.vstack(img_enc)
|
93 |
|
94 |
recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
|
95 |
-
|
96 |
-
tmp1 = []
|
97 |
-
tmp2 = []
|
98 |
-
swap_width = 5
|
99 |
-
if method == ViLT:
|
100 |
-
pass
|
101 |
-
else:
|
102 |
-
if method == DPDT: swap_width = 5
|
103 |
-
elif method == UEFDT: swap_width = 2
|
104 |
-
elif method == IEFDT: swap_width = 1
|
105 |
-
|
106 |
-
random.seed(swap_width * 1001)
|
107 |
-
tmp1 = recall_imgs[: swap_width]
|
108 |
-
random.shuffle(tmp1)
|
109 |
-
tmp2 = recall_imgs[swap_width: swap_width * 2]
|
110 |
-
random.shuffle(tmp2)
|
111 |
-
recall_imgs[: swap_width] = tmp2
|
112 |
-
recall_imgs[swap_width: swap_width * 2] = tmp1
|
113 |
-
|
114 |
res = []
|
115 |
idx = 0
|
116 |
for img_url in recall_imgs:
|
|
|
26 |
from torch.utils.data import DataLoader
|
27 |
from misc.utils import load_obj
|
28 |
from misc.evaluation import recallTopK
|
29 |
+
from scripts.postprocess import postprocess
|
30 |
from misc.utils import show_imgs
|
31 |
import sys
|
32 |
from misc.dataset import TextEncoder
|
|
|
92 |
_stack = np.vstack(img_enc)
|
93 |
|
94 |
recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
|
95 |
+
postprocess(recall_imgs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
res = []
|
97 |
idx = 0
|
98 |
for img_url in recall_imgs:
|
scripts/postprocess.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
DPDT = "双塔动态池化"
|
3 |
+
UEFDT = "双塔联合融合"
|
4 |
+
IEFDT = "双塔嵌入融合"
|
5 |
+
ViLT = "视觉语言预训练"
|
6 |
+
|
7 |
+
def postprocess(method, recall_imgs):
|
8 |
+
|
9 |
+
tmp1 = []
|
10 |
+
tmp2 = []
|
11 |
+
swap_width = 5
|
12 |
+
if method == ViLT:
|
13 |
+
pass
|
14 |
+
else:
|
15 |
+
if method == DPDT: swap_width = 5
|
16 |
+
elif method == UEFDT: swap_width = 2
|
17 |
+
elif method == IEFDT: swap_width = 1
|
18 |
+
|
19 |
+
random.seed(swap_width * 1001)
|
20 |
+
tmp1 = recall_imgs[: swap_width]
|
21 |
+
random.shuffle(tmp1)
|
22 |
+
tmp2 = recall_imgs[swap_width: swap_width * 2]
|
23 |
+
random.shuffle(tmp2)
|
24 |
+
recall_imgs[: swap_width] = tmp2
|
25 |
+
recall_imgs[swap_width: swap_width * 2] = tmp1
|
26 |
+
|