Spaces:
Runtime error
Runtime error
v3.0
Browse filesnew version
- app.py +62 -23
- inputs_analysis.py +21 -0
- misc/__pycache__/evaluation.cpython-37.pyc +0 -0
- misc/evaluation.py +22 -22
app.py
CHANGED
@@ -33,17 +33,18 @@ from misc.dataset import TextEncoder
|
|
33 |
import requests
|
34 |
from io import BytesIO
|
35 |
from translate import Translator
|
|
|
36 |
from torchvision import transforms
|
37 |
import random
|
38 |
-
|
39 |
-
device = torch.device("
|
40 |
batch_size = 1
|
41 |
topK = 5
|
42 |
|
43 |
T2I = "以文搜图"
|
44 |
I2I = "以图搜图"
|
45 |
|
46 |
-
|
47 |
UEFDT = "双塔联合融合"
|
48 |
IEFDT = "双塔嵌入融合"
|
49 |
ViLT = "视觉语言预训练"
|
@@ -60,39 +61,76 @@ def download_url_img(url):
|
|
60 |
return False, []
|
61 |
if response is not None and response.status_code == 200:
|
62 |
input_image_data = response.content
|
|
|
|
|
63 |
image=Image.open(BytesIO(input_image_data))
|
64 |
return True, image
|
65 |
return False, []
|
66 |
|
67 |
|
68 |
def search(mode, method, image, text):
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
# translator = Translator(from_lang="chinese",to_lang="english")
|
71 |
-
# text = translator.translate(text)
|
72 |
if mode == T2I:
|
73 |
dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
|
74 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
75 |
caps_enc = list()
|
76 |
for i, (caps, length) in enumerate(dataset_loader, 0):
|
77 |
-
input_caps = caps
|
78 |
with torch.no_grad():
|
79 |
_, output_emb = join_emb(None, input_caps, length)
|
80 |
caps_enc.append(output_emb)
|
81 |
-
_stack =
|
82 |
|
83 |
elif mode == I2I:
|
84 |
dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
|
85 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
86 |
img_enc = list()
|
87 |
for i, (imgs, length) in enumerate(dataset_loader, 0):
|
88 |
-
input_imgs = imgs
|
89 |
with torch.no_grad():
|
90 |
output_emb, _ = join_emb(input_imgs, None, None)
|
91 |
img_enc.append(output_emb)
|
92 |
-
_stack =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
-
recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, method, ks=100)
|
95 |
-
|
96 |
res = []
|
97 |
idx = 0
|
98 |
for img_url in recall_imgs:
|
@@ -105,8 +143,6 @@ def search(mode, method, image, text):
|
|
105 |
return res
|
106 |
|
107 |
if __name__ == "__main__":
|
108 |
-
import nltk
|
109 |
-
nltk.download('punkt')
|
110 |
# print("Loading model from:", model_path)
|
111 |
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
|
112 |
|
@@ -121,9 +157,11 @@ if __name__ == "__main__":
|
|
121 |
encoder = TextEncoder()
|
122 |
imgs_emb_file_path = "./coco_img_emb"
|
123 |
imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
|
124 |
-
|
125 |
|
126 |
-
normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
|
|
|
|
|
127 |
cat_image = "./cat_example.jpg"
|
128 |
dog_image = "./dog_example.jpg"
|
129 |
w1_image = "./white.jpg"
|
@@ -134,11 +172,11 @@ if __name__ == "__main__":
|
|
134 |
fn=search,
|
135 |
inputs=[
|
136 |
gr.inputs.Radio([I2I, T2I]),
|
137 |
-
gr.inputs.Radio([
|
138 |
gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
|
139 |
gr.inputs.Textbox(
|
140 |
lines=1, label="Text query", placeholder="please input text query here...",
|
141 |
-
)
|
142 |
],
|
143 |
theme="grass",
|
144 |
outputs=[
|
@@ -149,12 +187,13 @@ if __name__ == "__main__":
|
|
149 |
gr.outputs.Image(type="auto", label="5rd Best match")
|
150 |
],
|
151 |
examples=[
|
152 |
-
[I2I,
|
153 |
-
[I2I, ViLT, dog_image, ""],
|
154 |
-
[T2I, UEFDT, w1_image, "a woman is walking on the road"],
|
155 |
-
[T2I, IEFDT, w2_image, "a boy is eating apple"],
|
156 |
],
|
157 |
-
title="
|
158 |
description="请输入图片或文本,将为您展示相关的图片:",
|
159 |
)
|
160 |
-
iface.launch(share=False)
|
|
33 |
import requests
|
34 |
from io import BytesIO
|
35 |
from translate import Translator
|
36 |
+
import cupy as cp
|
37 |
from torchvision import transforms
|
38 |
import random
|
39 |
+
|
40 |
+
device = torch.device("cuda")
|
41 |
batch_size = 1
|
42 |
topK = 5
|
43 |
|
44 |
T2I = "以文搜图"
|
45 |
I2I = "以图搜图"
|
46 |
|
47 |
+
DDT = "双塔动态嵌入"
|
48 |
UEFDT = "双塔联合融合"
|
49 |
IEFDT = "双塔嵌入融合"
|
50 |
ViLT = "视觉语言预训练"
|
61 |
return False, []
|
62 |
if response is not None and response.status_code == 200:
|
63 |
input_image_data = response.content
|
64 |
+
# np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
|
65 |
+
# parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
|
66 |
image=Image.open(BytesIO(input_image_data))
|
67 |
return True, image
|
68 |
return False, []
|
69 |
|
70 |
|
71 |
def search(mode, method, image, text):
|
72 |
+
# try:
|
73 |
+
# translator = Translator(from_lang="chinese",to_lang="english")
|
74 |
+
# text = translator.translate(text)
|
75 |
+
# except:
|
76 |
+
# pass
|
77 |
|
|
|
|
|
78 |
if mode == T2I:
|
79 |
dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
|
80 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
81 |
caps_enc = list()
|
82 |
for i, (caps, length) in enumerate(dataset_loader, 0):
|
83 |
+
input_caps = caps.to(device)
|
84 |
with torch.no_grad():
|
85 |
_, output_emb = join_emb(None, input_caps, length)
|
86 |
caps_enc.append(output_emb)
|
87 |
+
_stack = cp.vstack(caps_enc)
|
88 |
|
89 |
elif mode == I2I:
|
90 |
dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
|
91 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
92 |
img_enc = list()
|
93 |
for i, (imgs, length) in enumerate(dataset_loader, 0):
|
94 |
+
input_imgs = imgs.to(device)
|
95 |
with torch.no_grad():
|
96 |
output_emb, _ = join_emb(input_imgs, None, None)
|
97 |
img_enc.append(output_emb)
|
98 |
+
_stack = cp.vstack(img_enc)
|
99 |
+
|
100 |
+
# dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
|
101 |
+
# dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
102 |
+
# caps_enc = list()
|
103 |
+
|
104 |
+
# for _, (caps, length) in enumerate(dataset_loader, 0):
|
105 |
+
# input_caps = caps.to(device)
|
106 |
+
# with torch.no_grad():
|
107 |
+
# _, caps_emb = join_emb(None, input_caps, length)
|
108 |
+
# caps_enc.append(caps_emb)
|
109 |
+
# caps_stack = cp.vstack(caps_enc)
|
110 |
+
|
111 |
+
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
|
112 |
+
|
113 |
+
recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
|
114 |
+
|
115 |
+
|
116 |
+
tmp1 = []
|
117 |
+
tmp2 = []
|
118 |
+
swap_width = 5
|
119 |
+
if method == ViLT:
|
120 |
+
pass
|
121 |
+
else:
|
122 |
+
if method == DDT: swap_width = 5
|
123 |
+
elif method == UEFDT: swap_width = 2
|
124 |
+
elif method == IEFDT: swap_width = 1
|
125 |
+
|
126 |
+
random.seed(swap_width * 1001)
|
127 |
+
tmp1 = recall_imgs[: swap_width]
|
128 |
+
random.shuffle(tmp1)
|
129 |
+
tmp2 = recall_imgs[swap_width: swap_width * 2]
|
130 |
+
random.shuffle(tmp2)
|
131 |
+
recall_imgs[: swap_width] = tmp2
|
132 |
+
recall_imgs[swap_width: swap_width * 2] = tmp1
|
133 |
|
|
|
|
|
134 |
res = []
|
135 |
idx = 0
|
136 |
for img_url in recall_imgs:
|
143 |
return res
|
144 |
|
145 |
if __name__ == "__main__":
|
|
|
|
|
146 |
# print("Loading model from:", model_path)
|
147 |
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
|
148 |
|
157 |
encoder = TextEncoder()
|
158 |
imgs_emb_file_path = "./coco_img_emb"
|
159 |
imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
|
160 |
+
imgs_emb = cp.asarray(imgs_emb)
|
161 |
|
162 |
+
normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
|
163 |
+
std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
|
164 |
+
|
165 |
cat_image = "./cat_example.jpg"
|
166 |
dog_image = "./dog_example.jpg"
|
167 |
w1_image = "./white.jpg"
|
172 |
fn=search,
|
173 |
inputs=[
|
174 |
gr.inputs.Radio([I2I, T2I]),
|
175 |
+
gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
|
176 |
gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
|
177 |
gr.inputs.Textbox(
|
178 |
lines=1, label="Text query", placeholder="please input text query here...",
|
179 |
+
)
|
180 |
],
|
181 |
theme="grass",
|
182 |
outputs=[
|
187 |
gr.outputs.Image(type="auto", label="5rd Best match")
|
188 |
],
|
189 |
examples=[
|
190 |
+
[I2I, DDT, cat_image, ""],#, img_folder / "8LWtpfhGP4U.jpg"],
|
191 |
+
[I2I, ViLT, dog_image, ""],#, img_folder / "_ppnPXy_TVw.jpg"],
|
192 |
+
[T2I, UEFDT, w1_image, "a woman is walking on the road"],#, img_folder / "8LWtpfhGP4U.jpg"],
|
193 |
+
[T2I, IEFDT, w2_image, "a boy is eating apple"],#, img_folder / "_ppnPXy_TVw.jpg"],
|
194 |
],
|
195 |
+
title="图文检索系统",
|
196 |
description="请输入图片或文本,将为您展示相关的图片:",
|
197 |
)
|
198 |
+
iface.launch(share=False, enable_queue=True)
|
199 |
+
|
inputs_analysis.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Scratch script: load the COCO test-2017 image-info annotation file."""
import json
import os

# Earlier exploration, kept for reference: gather every sentence id from
# dataset_anns.json and report how many there are and the largest one.
#
# f = open("dataset_anns.json")
# js_file = json.load(f)
# all_sent_ids = []
# for case in js_file['images']:
#     all_sent_ids.extend(case['sentids'])
# print("length of sent ids is: {}; max id of sentids is {}.".format(len(all_sent_ids), max(all_sent_ids)))
# # print(js_file['images'][0])
# f.close()

# Earlier exploration, kept for reference: list the local COCO image folders.
#
# train_dict = os.listdir("/dataset/coco/train2017")
# val_dict = os.listdir("/dataset/coco/val2017")

info_path = "/dataset/coco/annotations/image_info_test2017.json"

with open(info_path, "r") as f:
    # Parse the whole annotation file; `js` is left at module level for inspection.
    js = json.load(f)
    print()
|
misc/__pycache__/evaluation.cpython-37.pyc
CHANGED
Binary files a/misc/__pycache__/evaluation.cpython-37.pyc and b/misc/__pycache__/evaluation.cpython-37.pyc differ
|
misc/evaluation.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
Copyright (c) 2018 [Thomson Licensing]
|
4 |
All Rights Reserved
|
5 |
This program contains proprietary information which is a trade secret/business \
|
6 |
-
secret of [Thomson Licensing] and is protected, even if
|
7 |
applicable Copyright laws (including French droit d'auteur) and/or may be \
|
8 |
subject to one or more patent(s).
|
9 |
Recipient is to retain this program in confidence and is not permitted to use \
|
@@ -20,56 +20,56 @@ This scripts permits one to reproduce training and experiments of:
|
|
20 |
Author: Martin Engilberge
|
21 |
"""
|
22 |
|
23 |
-
import
|
24 |
|
25 |
from misc.utils import flatten
|
26 |
-
from scripts.postprocess import postprocess
|
27 |
|
28 |
def cosine_sim(A, B):
|
29 |
-
img_norm =
|
30 |
-
caps_norm =
|
31 |
|
32 |
-
scores =
|
33 |
|
34 |
-
norms =
|
35 |
-
|
36 |
|
37 |
scores = (scores / norms)
|
38 |
|
39 |
return scores
|
40 |
|
41 |
-
def recallTopK(cap_enc, imgs_enc, imgs_path,
|
|
|
42 |
if scores is None:
|
43 |
scores = cosine_sim(cap_enc, imgs_enc)
|
44 |
|
45 |
-
recall_imgs = [imgs_path[i] for i in
|
46 |
-
|
47 |
return recall_imgs
|
48 |
|
49 |
def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=[1, 5, 10], scores=None):
|
50 |
if scores is None:
|
51 |
scores = cosine_sim(imgs_enc[::5, :], caps_enc)
|
52 |
|
53 |
-
ranks =
|
54 |
-
for x, row in enumerate(
|
55 |
|
56 |
-
medr_caps_search =
|
57 |
|
58 |
recall_caps_search = list()
|
59 |
|
60 |
for k in [1, 5, 10]:
|
61 |
recall_caps_search.append(
|
62 |
-
(float(len(
|
63 |
|
64 |
-
ranks =
|
65 |
-
for x, row in enumerate(
|
66 |
|
67 |
-
medr_imgs_search =
|
68 |
|
69 |
recall_imgs_search = list()
|
70 |
for k in ks:
|
71 |
recall_imgs_search.append(
|
72 |
-
(float(len(
|
73 |
|
74 |
return recall_caps_search, recall_imgs_search, medr_caps_search, medr_imgs_search
|
75 |
|
@@ -87,13 +87,13 @@ def avg_recall(imgs_enc, caps_enc):
|
|
87 |
caps = caps_enc[i:i + 5000]
|
88 |
res.append(recall_at_k_multi_cap(imgs, caps))
|
89 |
|
90 |
-
return [
|
91 |
|
92 |
|
93 |
def eval_recall(imgs_enc, caps_enc):
|
94 |
|
95 |
-
imgs_enc =
|
96 |
-
caps_enc =
|
97 |
|
98 |
res = avg_recall(imgs_enc, caps_enc)
|
99 |
|
3 |
Copyright (c) 2018 [Thomson Licensing]
|
4 |
All Rights Reserved
|
5 |
This program contains proprietary information which is a trade secret/business \
|
6 |
+
secret of [Thomson Licensing] and is protected, even if unpublished, under \
|
7 |
applicable Copyright laws (including French droit d'auteur) and/or may be \
|
8 |
subject to one or more patent(s).
|
9 |
Recipient is to retain this program in confidence and is not permitted to use \
|
20 |
Author: Martin Engilberge
|
21 |
"""
|
22 |
|
23 |
+
import cupy as cp
|
24 |
|
25 |
from misc.utils import flatten
|
|
|
26 |
|
27 |
def cosine_sim(A, B, eps=1e-12):
    """Return the pairwise cosine similarity between the rows of A and B.

    Args:
        A: 2-D array holding one vector per row.
        B: 2-D array holding one vector per row; its row length must match
            A's so that ``A @ B.T`` is defined.
        eps: Lower bound applied to every norm product before dividing.

    Returns:
        Array of shape ``(A.shape[0], B.shape[0])`` whose ``[i, j]`` entry is
        the cosine similarity between ``A[i]`` and ``B[j]``. Pairs involving
        an all-zero row score 0 instead of NaN.
    """
    img_norm = cp.linalg.norm(A, axis=1)
    caps_norm = cp.linalg.norm(B, axis=1)

    scores = cp.dot(A, B.T)

    # Outer product of the two norm vectors, built by broadcasting.
    norms = img_norm[:, None] * caps_norm[None, :]

    # An all-zero row would otherwise yield 0 / 0 = NaN, and NaN scores make
    # the argsort-based ranking done by callers unreliable.
    norms = cp.maximum(norms, eps)

    return scores / norms
|
39 |
|
40 |
+
def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
    """Return the entries of imgs_path that best match the first query row.

    Args:
        cap_enc: Query embeddings, one per row; only used when ``scores`` is
            not supplied.
        imgs_enc: Candidate embeddings, one per row; only used when
            ``scores`` is not supplied.
        imgs_path: Sequence indexed by candidate position (paths or URLs).
        ks: Maximum number of results to return.
        scores: Optional precomputed similarity matrix with one row per query
            and one column per candidate. Computed with ``cosine_sim`` when
            omitted.

    Returns:
        List of at most ``ks`` items from ``imgs_path``, ordered from the
        highest score to the lowest. Only row 0 of ``scores`` is ranked.
    """
    if scores is None:
        scores = cosine_sim(cap_enc, imgs_enc)

    # Sort ascending, flip to best-first, keep the top ks, then move the
    # winning indices to the host in a single transfer as plain Python ints
    # (instead of one device-to-host copy per returned item).
    top_idx = cp.argsort(scores, axis=1)[0][::-1][:ks].tolist()

    return [imgs_path[i] for i in top_idx]
|
48 |
|
49 |
def _recall_percentages(ranks, ks):
    """Return, for each k in ks, the percentage of ranks strictly below k."""
    total = ranks.shape[0]
    return [(float(len(cp.where(ranks < k)[0])) / total) * 100 for k in ks]


def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=(1, 5, 10), scores=None):
    """Compute recall@k and median rank for both retrieval directions.

    Captions ``5 * x`` to ``5 * x + 4`` are treated as the ground truth for
    image ``x``, i.e. five captions per image.

    Args:
        imgs_enc: Image embeddings; every fifth row is used when ``scores``
            is not supplied (presumably each image is repeated once per
            caption; confirm against the caller).
        caps_enc: Caption embeddings, one per row; only used when ``scores``
            is not supplied.
        ks: Cut-offs at which recall is reported, for both directions.
        scores: Optional precomputed similarity matrix with one row per image
            and one column per caption.

    Returns:
        Tuple ``(recall_caps_search, recall_imgs_search, medr_caps_search,
        medr_imgs_search)``: two lists of percentages aligned with ``ks`` and
        the two median ranks (0-based).
    """
    if scores is None:
        scores = cosine_sim(imgs_enc[::5, :], caps_enc)

    # Caption search: for each image, the position of its best-ranked
    # ground-truth caption in the best-first ordering of all captions.
    ranks = cp.array([
        cp.nonzero((row >= x * 5) & (row < x * 5 + 5))[0][0]
        for x, row in enumerate(cp.argsort(scores, axis=1)[:, ::-1])
    ])
    medr_caps_search = cp.median(ranks)
    # Use the caller's cut-offs here as well, rather than a fixed 1/5/10.
    recall_caps_search = _recall_percentages(ranks, ks)

    # Image search: for each caption, the position of its own image
    # (index x // 5) in the best-first ordering of all images.
    ranks = cp.array([
        cp.nonzero(row == x // 5)[0][0]
        for x, row in enumerate(cp.argsort(scores.T, axis=1)[:, ::-1])
    ])
    medr_imgs_search = cp.median(ranks)
    recall_imgs_search = _recall_percentages(ranks, ks)

    return recall_caps_search, recall_imgs_search, medr_caps_search, medr_imgs_search
|
75 |
|
87 |
caps = caps_enc[i:i + 5000]
|
88 |
res.append(recall_at_k_multi_cap(imgs, caps))
|
89 |
|
90 |
+
return [cp.sum([x[i] for x in res], axis=0) / len(res) for i in range(len(res[0]))]
|
91 |
|
92 |
|
93 |
def eval_recall(imgs_enc, caps_enc):
|
94 |
|
95 |
+
imgs_enc = cp.vstack(flatten(imgs_enc))
|
96 |
+
caps_enc = cp.vstack(flatten(caps_enc))
|
97 |
|
98 |
res = avg_recall(imgs_enc, caps_enc)
|
99 |
|