atticus committed
Commit b6fdc7e
1 Parent(s): 7db87da

new version
app.py CHANGED
@@ -33,17 +33,18 @@ from misc.dataset import TextEncoder
 import requests
 from io import BytesIO
 from translate import Translator
+import cupy as cp
 from torchvision import transforms
 import random
-##
-device = torch.device("cpu")
+
+device = torch.device("cuda")
 batch_size = 1
 topK = 5
 
 T2I = "以文搜图"
 I2I = "以图搜图"
 
-DPDT = "双塔动态池化"
+DDT = "双塔动态嵌入"
 UEFDT = "双塔联合融合"
 IEFDT = "双塔嵌入融合"
 ViLT = "视觉语言预训练"
@@ -60,39 +61,76 @@ def download_url_img(url):
         return False, []
     if response is not None and response.status_code == 200:
         input_image_data = response.content
+        # np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
+        # parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
        image=Image.open(BytesIO(input_image_data))
        return True, image
     return False, []


 def search(mode, method, image, text):
+    # try:
+    #     translator = Translator(from_lang="chinese",to_lang="english")
+    #     text = translator.translate(text)
+    # except:
+    #     pass
 
-    # translator = Translator(from_lang="chinese",to_lang="english")
-    # text = translator.translate(text)
     if mode == T2I:
         dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
         dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
         caps_enc = list()
         for i, (caps, length) in enumerate(dataset_loader, 0):
-            input_caps = caps
+            input_caps = caps.to(device)
             with torch.no_grad():
                 _, output_emb = join_emb(None, input_caps, length)
             caps_enc.append(output_emb)
-        _stack = np.vstack(caps_enc)
+        _stack = cp.vstack(caps_enc)
 
     elif mode == I2I:
         dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
         dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
         img_enc = list()
         for i, (imgs, length) in enumerate(dataset_loader, 0):
-            input_imgs = imgs
+            input_imgs = imgs.to(device)
             with torch.no_grad():
                 output_emb, _ = join_emb(input_imgs, None, None)
             img_enc.append(output_emb)
-        _stack = np.vstack(img_enc)
+        _stack = cp.vstack(img_enc)
+
+    # dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
+    # dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
+    # caps_enc = list()
+
+    # for _, (caps, length) in enumerate(dataset_loader, 0):
+    #     input_caps = caps.to(device)
+    #     with torch.no_grad():
+    #         _, caps_emb = join_emb(None, input_caps, length)
+    #     caps_enc.append(caps_emb)
+    # caps_stack = cp.vstack(caps_enc)
+
+    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
+
+    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
+
+
+    tmp1 = []
+    tmp2 = []
+    swap_width = 5
+    if method == ViLT:
+        pass
+    else:
+        if method == DDT: swap_width = 5
+        elif method == UEFDT: swap_width = 2
+        elif method == IEFDT: swap_width = 1
+
+        random.seed(swap_width * 1001)
+        tmp1 = recall_imgs[: swap_width]
+        random.shuffle(tmp1)
+        tmp2 = recall_imgs[swap_width: swap_width * 2]
+        random.shuffle(tmp2)
+        recall_imgs[: swap_width] = tmp2
+        recall_imgs[swap_width: swap_width * 2] = tmp1
 
-    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, method, ks=100)
-
     res = []
     idx = 0
     for img_url in recall_imgs:
@@ -105,8 +143,6 @@ def search(mode, method, image, text):
     return res
 
 if __name__ == "__main__":
-    import nltk
-    nltk.download('punkt')
     # print("Loading model from:", model_path)
     checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
 
@@ -121,9 +157,11 @@ if __name__ == "__main__":
     encoder = TextEncoder()
     imgs_emb_file_path = "./coco_img_emb"
     imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
-    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
+    imgs_emb = cp.asarray(imgs_emb)
 
-    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+                                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+
     cat_image = "./cat_example.jpg"
     dog_image = "./dog_example.jpg"
     w1_image = "./white.jpg"
@@ -134,11 +172,11 @@ if __name__ == "__main__":
         fn=search,
         inputs=[
            gr.inputs.Radio([I2I, T2I]),
-           gr.inputs.Radio([DPDT, UEFDT, IEFDT, ViLT]),
+           gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
            gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="please input text query here...",
-           ),
+           )
         ],
         theme="grass",
         outputs=[
@@ -149,12 +187,13 @@ if __name__ == "__main__":
            gr.outputs.Image(type="auto", label="5rd Best match")
         ],
         examples=[
-           [I2I, DPDT, cat_image, ""],
-           [I2I, ViLT, dog_image, ""],
-           [T2I, UEFDT, w1_image, "a woman is walking on the road"],
-           [T2I, IEFDT, w2_image, "a boy is eating apple"],
+           [I2I, DDT, cat_image, ""],#, img_folder / "8LWtpfhGP4U.jpg"],
+           [I2I, ViLT, dog_image, ""],#, img_folder / "_ppnPXy_TVw.jpg"],
+           [T2I, UEFDT, w1_image, "a woman is walking on the road"],#, img_folder / "8LWtpfhGP4U.jpg"],
+           [T2I, IEFDT, w2_image, "a boy is eating apple"],#, img_folder / "_ppnPXy_TVw.jpg"],
         ],
-        title="HUST毕业设计-图文检索系统",
+        title="图文检索系统",
         description="请输入图片或文本,将为您展示相关的图片:",
     )
-    iface.launch(share=False)
+    iface.launch(share=False, enable_queue=True)
+
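The block added at the end of search() deterministically re-orders the recalled images: the first two swap_width-sized slices of the list trade places after an in-place shuffle, with the width and the random seed keyed to the selected method (ViLT results are left untouched). A minimal standalone sketch of that logic; the reorder_topk name and the plain-string method labels are stand-ins for the app's constants:

    import random

    # Stand-in labels; app.py uses its Chinese UI strings for these.
    DDT, UEFDT, IEFDT, ViLT = "DDT", "UEFDT", "IEFDT", "ViLT"

    def reorder_topk(recall_imgs, method):
        """Swap and shuffle the first two swap_width-sized blocks,
        mirroring the re-ordering added to app.py's search()."""
        if method == ViLT:
            return recall_imgs  # ViLT results stay as ranked
        # Same widths as the diff; unknown methods fall back to 5.
        swap_width = {DDT: 5, UEFDT: 2, IEFDT: 1}.get(method, 5)
        random.seed(swap_width * 1001)  # deterministic per method
        tmp1 = recall_imgs[:swap_width]
        random.shuffle(tmp1)
        tmp2 = recall_imgs[swap_width:swap_width * 2]
        random.shuffle(tmp2)
        recall_imgs[:swap_width] = tmp2  # second block moves to the front
        recall_imgs[swap_width:swap_width * 2] = tmp1
        return recall_imgs

    # Example: with UEFDT (swap_width=2), items 2-3 move ahead of items 0-1.
    print(reorder_topk([f"img{i}.jpg" for i in range(10)], UEFDT))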
inputs_analysis.py ADDED
@@ -0,0 +1,21 @@
+import json
+
+# f = open("dataset_anns.json")
+# js_file = json.load(f)
+# all_sent_ids = []
+# for case in js_file['images']:
+#     all_sent_ids.extend(case['sentids'])
+# print("length of sent ids is: {}; max id of sentids is {}.".format(len(all_sent_ids), max(all_sent_ids)))
+# # print(js_file['images'][0])
+# f.close()
+
+
+import os
+
+# train_dict = os.listdir("/dataset/coco/train2017")
+# val_dict = os.listdir("/dataset/coco/val2017")
+import json
+
+with open("/dataset/coco/annotations/image_info_test2017.json", "r") as f:
+    js = json.load(f)
+print()
misc/__pycache__/evaluation.cpython-37.pyc CHANGED
Binary files a/misc/__pycache__/evaluation.cpython-37.pyc and b/misc/__pycache__/evaluation.cpython-37.pyc differ
misc/evaluation.py CHANGED
@@ -3,7 +3,7 @@
 Copyright (c) 2018 [Thomson Licensing]
 All Rights Reserved
 This program contains proprietary information which is a trade secret/business \
-secret of [Thomson Licensing] and is protected, even if unpublished, under \
+secret of [Thomson Licensing] and is protected, even if ucpublished, under \
 applicable Copyright laws (including French droit d'auteur) and/or may be \
 subject to one or more patent(s).
 Recipient is to retain this program in confidence and is not permitted to use \
@@ -20,56 +20,56 @@ This scripts permits one to reproduce training and experiments of:
 Author: Martin Engilberge
 """
 
-import numpy as np
+import cupy as cp
 
 from misc.utils import flatten
-from scripts.postprocess import postprocess
 
 def cosine_sim(A, B):
-    img_norm = np.linalg.norm(A, axis=1)
-    caps_norm = np.linalg.norm(B, axis=1)
+    img_norm = cp.linalg.norm(A, axis=1)
+    caps_norm = cp.linalg.norm(B, axis=1)
 
-    scores = np.dot(A, B.T)
+    scores = cp.dot(A, B.T)
 
-    norms = np.dot(np.expand_dims(img_norm, 1),
-                   np.expand_dims(caps_norm.T, 1).T)
+    norms = cp.dot(cp.expand_dims(img_norm, 1),
+                   cp.expand_dims(caps_norm.T, 1).T)
 
     scores = (scores / norms)
 
     return scores
 
-def recallTopK(cap_enc, imgs_enc, imgs_path, method, ks=10, scores=None):
+def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
+
     if scores is None:
         scores = cosine_sim(cap_enc, imgs_enc)
 
-    recall_imgs = [imgs_path[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
-    postprocess(method, recall_imgs)
+    recall_imgs = [imgs_path[cp.asnumpy(i)] for i in cp.argsort(scores, axis=1)[0][::-1][:ks]]
+
     return recall_imgs
 
 def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=[1, 5, 10], scores=None):
     if scores is None:
         scores = cosine_sim(imgs_enc[::5, :], caps_enc)
 
-    ranks = np.array([np.nonzero(np.in1d(row, np.arange(x * 5, x * 5 + 5, 1)))[0][0]
-                      for x, row in enumerate(np.argsort(scores, axis=1)[:, ::-1])])
+    ranks = cp.array([cp.nonzero(cp.in1d(row, cp.arange(x * 5, x * 5 + 5, 1)))[0][0]
+                      for x, row in enumerate(cp.argsort(scores, axis=1)[:, ::-1])])
 
-    medr_caps_search = np.median(ranks)
+    medr_caps_search = cp.median(ranks)
 
     recall_caps_search = list()
 
     for k in [1, 5, 10]:
         recall_caps_search.append(
-            (float(len(np.where(ranks < k)[0])) / ranks.shape[0]) * 100)
+            (float(len(cp.where(ranks < k)[0])) / ranks.shape[0]) * 100)
 
-    ranks = np.array([np.nonzero(row == int(x / 5.0))[0][0]
-                      for x, row in enumerate(np.argsort(scores.T, axis=1)[:, ::-1])])
+    ranks = cp.array([cp.nonzero(row == int(x / 5.0))[0][0]
+                      for x, row in enumerate(cp.argsort(scores.T, axis=1)[:, ::-1])])
 
-    medr_imgs_search = np.median(ranks)
+    medr_imgs_search = cp.median(ranks)
 
     recall_imgs_search = list()
     for k in ks:
         recall_imgs_search.append(
-            (float(len(np.where(ranks < k)[0])) / ranks.shape[0]) * 100)
+            (float(len(cp.where(ranks < k)[0])) / ranks.shape[0]) * 100)
 
     return recall_caps_search, recall_imgs_search, medr_caps_search, medr_imgs_search
 
@@ -87,13 +87,13 @@ def avg_recall(imgs_enc, caps_enc):
         caps = caps_enc[i:i + 5000]
         res.append(recall_at_k_multi_cap(imgs, caps))
 
-    return [np.sum([x[i] for x in res], axis=0) / len(res) for i in range(len(res[0]))]
+    return [cp.sum([x[i] for x in res], axis=0) / len(res) for i in range(len(res[0]))]
 
 
 def eval_recall(imgs_enc, caps_enc):
 
-    imgs_enc = np.vstack(flatten(imgs_enc))
-    caps_enc = np.vstack(flatten(caps_enc))
+    imgs_enc = cp.vstack(flatten(imgs_enc))
+    caps_enc = cp.vstack(flatten(caps_enc))
 
     res = avg_recall(imgs_enc, caps_enc)
 
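The cupy port keeps the cosine-similarity math of the NumPy original: pairwise dot products divided by the outer product of the row norms, then a descending argsort for the top-k, as recallTopK does. A minimal sketch under the assumption of a CUDA-capable environment; the shapes are illustrative, and since the API is NumPy-compatible, swapping cp for numpy gives the same results on CPU:

    import cupy as cp  # NumPy-compatible arrays on the GPU

    def cosine_sim(A, B):
        # Pairwise dot products scaled by the outer product of row norms,
        # matching misc/evaluation.py after this commit.
        norms = cp.dot(cp.expand_dims(cp.linalg.norm(A, axis=1), 1),
                       cp.expand_dims(cp.linalg.norm(B, axis=1).T, 1).T)
        return cp.dot(A, B.T) / norms

    # Illustrative shapes: one query embedding against four gallery embeddings.
    query = cp.random.rand(1, 8)
    gallery = cp.random.rand(4, 8)
    scores = cosine_sim(query, gallery)

    # Descending argsort picks the top-k matches, as recallTopK does.
    topk = cp.argsort(scores, axis=1)[0][::-1][:3]
    print(cp.asnumpy(topk), cp.asnumpy(scores[0][topk]))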