atticus committed on
Commit 267ac1e
1 Parent(s): 38a1f80
Files changed (2)
  1. app.py +15 -22
  2. misc/evaluation.py +22 -20
app.py CHANGED
@@ -33,7 +33,7 @@ from misc.dataset import TextEncoder
 import requests
 from io import BytesIO
 from translate import Translator
-import cupy as cp
+# import cupy as cp
 from torchvision import transforms
 import random
 
@@ -46,7 +46,7 @@ topK = 5
 T2I = "以文搜图"
 I2I = "以图搜图"
 
-DDT = "双塔动态嵌入"
+DPDT = "双塔动态嵌入"
 UEFDT = "双塔联合融合"
 IEFDT = "双塔嵌入融合"
 ViLT = "视觉语言预训练"
@@ -82,37 +82,28 @@ def search(mode, method, image, text):
         dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
         caps_enc = list()
         for i, (caps, length) in enumerate(dataset_loader, 0):
-            input_caps = caps.to(device)
+            input_caps = caps
             with torch.no_grad():
                 _, output_emb = join_emb(None, input_caps, length)
             caps_enc.append(output_emb)
-        _stack = cp.vstack(caps_enc)
+        _stack = np.vstack(caps_enc)
 
     elif mode == I2I:
         dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
         dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
         img_enc = list()
         for i, (imgs, length) in enumerate(dataset_loader, 0):
-            input_imgs = imgs.to(device)
+            input_imgs = imgs
            with torch.no_grad():
                 output_emb, _ = join_emb(input_imgs, None, None)
             img_enc.append(output_emb)
-        _stack = cp.vstack(img_enc)
+        _stack = np.vstack(img_enc)
 
-    # dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
-    # dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
-    # caps_enc = list()
-    # for _, (caps, length) in enumerate(dataset_loader, 0):
-    #     input_caps = caps.to(device)
-    #     with torch.no_grad():
-    #         _, caps_emb = join_emb(None, input_caps, length)
-    #     caps_enc.append(caps_emb)
-    # caps_stack = cp.vstack(caps_enc)
 
     imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
 
-    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
+    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, method, ks=100)
 
 
     tmp1 = []
@@ -121,7 +112,7 @@ def search(mode, method, image, text):
     if method == ViLT:
         pass
     else:
-        if method == DDT: swap_width = 5
+        if method == DPDT: swap_width = 5
         elif method == UEFDT: swap_width = 2
         elif method == IEFDT: swap_width = 1
 
@@ -146,6 +137,8 @@ def search(mode, method, image, text):
 
 if __name__ == "__main__":
     # print("Loading model from:", model_path)
+    import nltk
+    nltk.download('punkt')
     checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
 
     join_emb = joint_embedding(checkpoint['args_dict'])
@@ -159,10 +152,10 @@ if __name__ == "__main__":
     encoder = TextEncoder()
     imgs_emb_file_path = "./coco_img_emb"
     imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
-    imgs_emb = cp.asarray(imgs_emb)
+    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
+    # imgs_emb = np.asarray(imgs_emb)
 
-    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-                                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
 
     cat_image = "./cat_example.jpg"
     dog_image = "./dog_example.jpg"
@@ -174,7 +167,7 @@ if __name__ == "__main__":
         fn=search,
         inputs=[
             gr.inputs.Radio([I2I, T2I]),
-            gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
+            gr.inputs.Radio([DPDT, UEFDT, IEFDT, ViLT]),
             gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
             gr.inputs.Textbox(
                 lines=1, label="Text query", placeholder="please input text query here...",
@@ -189,7 +182,7 @@ if __name__ == "__main__":
             gr.outputs.Image(type="auto", label="5rd Best match")
         ],
         examples=[
-            [I2I, DDT, cat_image, ""],#, img_folder / "8LWtpfhGP4U.jpg"],
+            [I2I, DPDT, cat_image, ""],#, img_folder / "8LWtpfhGP4U.jpg"],
             [I2I, ViLT, dog_image, ""],#, img_folder / "_ppnPXy_TVw.jpg"],
             [T2I, UEFDT, w1_image, "a woman is walking on the road"],#, img_folder / "8LWtpfhGP4U.jpg"],
             [T2I, IEFDT, w2_image, "a boy is eating apple"],#, img_folder / "_ppnPXy_TVw.jpg"],
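
Note: the substance of this change to app.py is dropping the GPU path — the cupy import is commented out, every cp.* call becomes np.*, and the .to(device) transfers go away — so the demo can run on CPU-only hardware. A minimal sketch of the retrieval step as it now executes, with toy arrays (the embedding width and gallery size below are placeholders, not the real model's):

import numpy as np

# Stand-ins for the query embedding stack (_stack) and the precomputed
# COCO gallery (imgs_emb); shapes here are illustrative only.
_stack = np.random.rand(1, 2400).astype(np.float32)       # one encoded query
imgs_emb = np.random.rand(1000, 2400).astype(np.float32)  # gallery embeddings

# Cosine scores, the same math cosine_sim() in misc/evaluation.py performs.
scores = _stack @ imgs_emb.T
scores /= np.linalg.norm(_stack, axis=1, keepdims=True) * np.linalg.norm(imgs_emb, axis=1)

# Best-first top-K indices: the np.argsort(...)[0][::-1][:ks] idiom from
# recallTopK, without the cp.asnumpy() round trip CuPy required.
topk = np.argsort(scores, axis=1)[0][::-1][:100]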
misc/evaluation.py CHANGED
@@ -19,30 +19,32 @@ This scripts permits one to reproduce training and experiments of:
 
 Author: Martin Engilberge
 """
-
-import cupy as cp
+from scripts.postprocess import postprocess
+import numpy as np
 
 from misc.utils import flatten
 
 def cosine_sim(A, B):
-    img_norm = cp.linalg.norm(A, axis=1)
-    caps_norm = cp.linalg.norm(B, axis=1)
+    img_norm = np.linalg.norm(A, axis=1)
+    caps_norm = np.linalg.norm(B, axis=1)
 
-    scores = cp.dot(A, B.T)
+    scores = np.dot(A, B.T)
 
-    norms = cp.dot(cp.expand_dims(img_norm, 1),
-                   cp.expand_dims(caps_norm.T, 1).T)
+    norms = np.dot(np.expand_dims(img_norm, 1),
+                   np.expand_dims(caps_norm.T, 1).T)
 
     scores = (scores / norms)
 
     return scores
 
-def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
+def recallTopK(cap_enc, imgs_enc, imgs_path, method, ks=10, scores=None):
 
     if scores is None:
         scores = cosine_sim(cap_enc, imgs_enc)
 
-    recall_imgs = [imgs_path[cp.asnumpy(i)] for i in cp.argsort(scores, axis=1)[0][::-1][:ks]]
+    # recall_imgs = [imgs_path[np.asnumpy(i)] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
+    recall_imgs = [imgs_path[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
+    postprocess(method, recall_imgs)
 
     return recall_imgs
 
@@ -50,26 +52,26 @@ def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=[1, 5, 10], scores=None):
     if scores is None:
         scores = cosine_sim(imgs_enc[::5, :], caps_enc)
 
-    ranks = cp.array([cp.nonzero(cp.in1d(row, cp.arange(x * 5, x * 5 + 5, 1)))[0][0]
-                      for x, row in enumerate(cp.argsort(scores, axis=1)[:, ::-1])])
+    ranks = np.array([np.nonzero(np.in1d(row, np.arange(x * 5, x * 5 + 5, 1)))[0][0]
+                      for x, row in enumerate(np.argsort(scores, axis=1)[:, ::-1])])
 
-    medr_caps_search = cp.median(ranks)
+    medr_caps_search = np.median(ranks)
 
     recall_caps_search = list()
 
     for k in [1, 5, 10]:
         recall_caps_search.append(
-            (float(len(cp.where(ranks < k)[0])) / ranks.shape[0]) * 100)
+            (float(len(np.where(ranks < k)[0])) / ranks.shape[0]) * 100)
 
-    ranks = cp.array([cp.nonzero(row == int(x / 5.0))[0][0]
-                      for x, row in enumerate(cp.argsort(scores.T, axis=1)[:, ::-1])])
+    ranks = np.array([np.nonzero(row == int(x / 5.0))[0][0]
+                      for x, row in enumerate(np.argsort(scores.T, axis=1)[:, ::-1])])
 
-    medr_imgs_search = cp.median(ranks)
+    medr_imgs_search = np.median(ranks)
 
     recall_imgs_search = list()
     for k in ks:
         recall_imgs_search.append(
-            (float(len(cp.where(ranks < k)[0])) / ranks.shape[0]) * 100)
+            (float(len(np.where(ranks < k)[0])) / ranks.shape[0]) * 100)
 
     return recall_caps_search, recall_imgs_search, medr_caps_search, medr_imgs_search
 
@@ -87,13 +89,13 @@ def avg_recall(imgs_enc, caps_enc):
         caps = caps_enc[i:i + 5000]
         res.append(recall_at_k_multi_cap(imgs, caps))
 
-    return [cp.sum([x[i] for x in res], axis=0) / len(res) for i in range(len(res[0]))]
+    return [np.sum([x[i] for x in res], axis=0) / len(res) for i in range(len(res[0]))]
 
 
 def eval_recall(imgs_enc, caps_enc):
 
-    imgs_enc = cp.vstack(flatten(imgs_enc))
-    caps_enc = cp.vstack(flatten(caps_enc))
+    imgs_enc = np.vstack(flatten(imgs_enc))
+    caps_enc = np.vstack(flatten(caps_enc))
 
    res = avg_recall(imgs_enc, caps_enc)
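
Note: the evaluation helpers keep their contract after the CuPy-to-NumPy swap — recall_at_k_multi_cap still expects embeddings stacked row-wise with five captions per image (hence the imgs_enc[::5, :] slice). A quick sanity check under those assumptions, run from the repo root so misc.evaluation (and its new scripts.postprocess import) resolves; making each caption identical to its image should give 100% recall at every k and median rank 0:

import numpy as np
from misc.evaluation import recall_at_k_multi_cap

# 10 images, each row repeated 5x (the function slices imgs_enc[::5, :]),
# and 50 captions equal to their image embeddings.
rng = np.random.default_rng(0)
imgs = np.repeat(rng.random((10, 64)), 5, axis=0)
caps = imgs.copy()

r_caps, r_imgs, med_c, med_i = recall_at_k_multi_cap(imgs, caps)
print(r_caps, r_imgs, med_c, med_i)  # expect [100.0, 100.0, 100.0] twice and medians of 0.0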