v3.0
Browse files
- app.py +15 -22
- misc/evaluation.py +22 -20
app.py
CHANGED
@@ -33,7 +33,7 @@ from misc.dataset import TextEncoder
 import requests
 from io import BytesIO
 from translate import Translator
-import cupy as cp
+# import cupy as cp
 from torchvision import transforms
 import random

@@ -46,7 +46,7 @@ topK = 5
 T2I = "以文搜图"
 I2I = "以图搜图"

-
+DPDT = "双塔动态嵌入"
 UEFDT = "双塔联合融合"
 IEFDT = "双塔嵌入融合"
 ViLT = "视觉语言预训练"
@@ -82,37 +82,28 @@ def search(mode, method, image, text):
         dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
         caps_enc = list()
         for i, (caps, length) in enumerate(dataset_loader, 0):
-            input_caps = caps
+            input_caps = caps
             with torch.no_grad():
                 _, output_emb = join_emb(None, input_caps, length)
             caps_enc.append(output_emb)
-        _stack =
+        _stack = np.vstack(caps_enc)

     elif mode == I2I:
         dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
         dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
         img_enc = list()
         for i, (imgs, length) in enumerate(dataset_loader, 0):
-            input_imgs = imgs
+            input_imgs = imgs
             with torch.no_grad():
                 output_emb, _ = join_emb(input_imgs, None, None)
             img_enc.append(output_emb)
-        _stack =
+        _stack = np.vstack(img_enc)

-    # dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
-    # dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
-    # caps_enc = list()

-    # for _, (caps, length) in enumerate(dataset_loader, 0):
-    #     input_caps = caps.to(device)
-    #     with torch.no_grad():
-    #         _, caps_emb = join_emb(None, input_caps, length)
-    #         caps_enc.append(caps_emb)
-    # caps_stack = cp.vstack(caps_enc)

     imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]

-    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
+    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, method, ks=100)


     tmp1 = []
@@ -121,7 +112,7 @@ def search(mode, method, image, text):
     if method == ViLT:
         pass
     else:
-        if method ==
+        if method == DPDT: swap_width = 5
         elif method == UEFDT: swap_width = 2
         elif method == IEFDT: swap_width = 1

@@ -146,6 +137,8 @@ def search(mode, method, image, text):

 if __name__ == "__main__":
     # print("Loading model from:", model_path)
+    import nltk
+    nltk.download('punkt')
     checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)

     join_emb = joint_embedding(checkpoint['args_dict'])
@@ -159,10 +152,10 @@ if __name__ == "__main__":
     encoder = TextEncoder()
     imgs_emb_file_path = "./coco_img_emb"
     imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
-
+    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
+    # imgs_emb = np.asarray(imgs_emb)

-    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-                                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

     cat_image = "./cat_example.jpg"
     dog_image = "./dog_example.jpg"
@@ -174,7 +167,7 @@ if __name__ == "__main__":
         fn=search,
         inputs=[
             gr.inputs.Radio([I2I, T2I]),
-            gr.inputs.Radio([
+            gr.inputs.Radio([DPDT, UEFDT, IEFDT, ViLT]),
             gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
             gr.inputs.Textbox(
                 lines=1, label="Text query", placeholder="please input text query here...",
@@ -189,7 +182,7 @@ if __name__ == "__main__":
             gr.outputs.Image(type="auto", label="5rd Best match")
         ],
         examples=[
-            [I2I,
+            [I2I, DPDT, cat_image, ""],#, img_folder / "8LWtpfhGP4U.jpg"],
             [I2I, ViLT, dog_image, ""],#, img_folder / "_ppnPXy_TVw.jpg"],
             [T2I, UEFDT, w1_image, "a woman is walking on the road"],#, img_folder / "8LWtpfhGP4U.jpg"],
             [T2I, IEFDT, w2_image, "a boy is eating apple"],#, img_folder / "_ppnPXy_TVw.jpg"],
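Note on the Normalize statistics kept by this commit: gr.inputs.Image hands search() an H x W x 3 uint8 array with values in 0..255, which is why the ImageNet mean/std (usually quoted for 0..1 inputs) are multiplied by 255 in app.py. A minimal sketch, not part of the commit and using a random stand-in image, showing the scaled statistics are equivalent to rescaling the image first:

# Sketch only: demonstrates why app.py multiplies the ImageNet mean/std by 255.
import numpy as np
import torch
from torchvision import transforms

image = np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8)  # stand-in for the Gradio input

norm_255 = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
norm_01 = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

x = torch.Tensor(image).permute(2, 0, 1)   # HWC uint8 -> CHW float, still in 0..255
a = norm_255(x).unsqueeze(dim=0)           # what app.py does
b = norm_01(x / 255.0).unsqueeze(dim=0)    # equivalent route via 0..1 inputs
print(torch.allclose(a, b, atol=1e-5))     # True, up to float rounding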
misc/evaluation.py
CHANGED
@@ -19,30 +19,32 @@ This scripts permits one to reproduce training and experiments of:

 Author: Martin Engilberge
 """
-
-import
+from scripts.postprocess import postprocess
+import numpy as np

 from misc.utils import flatten

 def cosine_sim(A, B):
-    img_norm =
-    caps_norm =
+    img_norm = np.linalg.norm(A, axis=1)
+    caps_norm = np.linalg.norm(B, axis=1)

-    scores =
+    scores = np.dot(A, B.T)

-    norms =
-
+    norms = np.dot(np.expand_dims(img_norm, 1),
+                   np.expand_dims(caps_norm.T, 1).T)

     scores = (scores / norms)

     return scores

-def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
+def recallTopK(cap_enc, imgs_enc, imgs_path, method, ks=10, scores=None):

     if scores is None:
         scores = cosine_sim(cap_enc, imgs_enc)

-    recall_imgs = [imgs_path[
+    # recall_imgs = [imgs_path[np.asnumpy(i)] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
+    recall_imgs = [imgs_path[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
+    postprocess(method, recall_imgs)

     return recall_imgs

@@ -50,26 +52,26 @@ def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=[1, 5, 10], scores=None):
     if scores is None:
         scores = cosine_sim(imgs_enc[::5, :], caps_enc)

-    ranks =
-    for x, row in enumerate(
+    ranks = np.array([np.nonzero(np.in1d(row, np.arange(x * 5, x * 5 + 5, 1)))[0][0]
+                      for x, row in enumerate(np.argsort(scores, axis=1)[:, ::-1])])

-    medr_caps_search =
+    medr_caps_search = np.median(ranks)

     recall_caps_search = list()

     for k in [1, 5, 10]:
         recall_caps_search.append(
-            (float(len(
+            (float(len(np.where(ranks < k)[0])) / ranks.shape[0]) * 100)

-    ranks =
-    for x, row in enumerate(
+    ranks = np.array([np.nonzero(row == int(x / 5.0))[0][0]
+                      for x, row in enumerate(np.argsort(scores.T, axis=1)[:, ::-1])])

-    medr_imgs_search =
+    medr_imgs_search = np.median(ranks)

     recall_imgs_search = list()
     for k in ks:
         recall_imgs_search.append(
-            (float(len(
+            (float(len(np.where(ranks < k)[0])) / ranks.shape[0]) * 100)

     return recall_caps_search, recall_imgs_search, medr_caps_search, medr_imgs_search

@@ -87,13 +89,13 @@ def avg_recall(imgs_enc, caps_enc):
         caps = caps_enc[i:i + 5000]
         res.append(recall_at_k_multi_cap(imgs, caps))

-    return [
+    return [np.sum([x[i] for x in res], axis=0) / len(res) for i in range(len(res[0]))]


 def eval_recall(imgs_enc, caps_enc):

-    imgs_enc =
-    caps_enc =
+    imgs_enc = np.vstack(flatten(imgs_enc))
+    caps_enc = np.vstack(flatten(caps_enc))

     res = avg_recall(imgs_enc, caps_enc)
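For reference, a small self-contained sketch (not part of the commit, toy values) of what the NumPy rewrite of cosine_sim and the top-K selection in recallTopK compute; the postprocess(method, ...) hook imported from scripts.postprocess is left out because its implementation is not shown in this diff:

# Sketch only: toy data illustrating cosine_sim + the argsort-based top-K pick.
import numpy as np

def cosine_sim(A, B):
    # pairwise cosine similarity between rows of A (queries) and rows of B (gallery)
    img_norm = np.linalg.norm(A, axis=1)
    caps_norm = np.linalg.norm(B, axis=1)
    scores = np.dot(A, B.T)
    norms = np.dot(np.expand_dims(img_norm, 1), np.expand_dims(caps_norm.T, 1).T)
    return scores / norms

query = np.array([[1.0, 0.0]])                                 # one query embedding
gallery = np.array([[1.0, 0.1], [0.0, 1.0], [1.0, 1.0]])       # three gallery embeddings
paths = ["a.jpg", "b.jpg", "c.jpg"]

scores = cosine_sim(query, gallery)                            # shape (1, 3)
ks = 2
top = [paths[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
print(top)  # ['a.jpg', 'c.jpg'] -- highest-cosine items first

The rank computation in recall_at_k_multi_cap relies on the COCO convention of 5 captions per image (the diff slices imgs_enc[::5, :]), which is what the np.in1d(row, np.arange(x * 5, x * 5 + 5, 1)) lookup encodes.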