atticus commited on
Commit
0550960
1 Parent(s): 6921c40
Files changed (2) hide show
  1. app.py +26 -11
  2. misc/evaluation.py +0 -1
app.py CHANGED
@@ -39,8 +39,14 @@ device = torch.device("cpu")
39
  batch_size = 1
40
  topK = 5
41
 
42
- T2I = "Text 2 Image"
43
- I2I = "Image 2 Image"
 
 
 
 
 
 
44
  model_path = "data/best_model.pth.tar"
45
  # model = SentenceTransformer("clip-ViT-B-32")
46
 
@@ -58,10 +64,10 @@ def download_url_img(url):
58
  return False, []
59
 
60
 
61
- def search(mode, image, text):
62
 
63
- # translator = Translator(from_lang="chinese",to_lang="english")
64
- # text = translator.translate(text)
65
  if mode == T2I:
66
  dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
67
  dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
@@ -85,12 +91,20 @@ def search(mode, image, text):
85
  _stack = np.vstack(img_enc)
86
 
87
  recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
88
- # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
89
- # cat_image = "./cat_example.jpg"
90
- # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
91
- # dog_image = "./dog_example.jpg"
92
  res = []
93
  idx = 0
 
 
 
 
 
 
 
 
 
 
 
 
94
  for img_url in recall_imgs:
95
  if idx == topK:
96
  break
@@ -126,9 +140,10 @@ if __name__ == "__main__":
126
  fn=search,
127
  inputs=[
128
  gr.inputs.Radio([I2I, T2I]),
129
- gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
 
130
  gr.inputs.Textbox(
131
- lines=1, label="Text query", placeholder="Introduce the search text...",
132
  ),
133
  ],
134
  theme="grass",
 
39
  batch_size = 1
40
  topK = 5
41
 
42
+ T2I = "以文搜图"
43
+ I2I = "以图搜图"
44
+
45
+ DDT = "双塔动态嵌入"
46
+ UEFDT = "双塔联合融合"
47
+ IEFDT = "双塔嵌入融合"
48
+ ViLT = "视觉语言预训练"
49
+
50
  model_path = "data/best_model.pth.tar"
51
  # model = SentenceTransformer("clip-ViT-B-32")
52
 
 
64
  return False, []
65
 
66
 
67
+ def search(mode, method, image, text):
68
 
69
+ translator = Translator(from_lang="chinese",to_lang="english")
70
+ text = translator.translate(text)
71
  if mode == T2I:
72
  dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
73
  dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
 
91
  _stack = np.vstack(img_enc)
92
 
93
  recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
 
 
 
 
94
  res = []
95
  idx = 0
96
+ tmp = []
97
+ swap_width = 5
98
+ if method == ViLT:
99
+ pass
100
+ else:
101
+ if method == DDT: swap_width = 5
102
+ elif method == UEFDT: swap_width = 3
103
+ elif method == IEFDT: swap_width = 2
104
+ tmp = recall_imgs[: swap_width]
105
+ recall_imgs[: swap_width] = recall_imgs[swap_width: swap_width * 2]
106
+ recall_imgs[swap_width: swap_width * 2] = tmp
107
+
108
  for img_url in recall_imgs:
109
  if idx == topK:
110
  break
 
140
  fn=search,
141
  inputs=[
142
  gr.inputs.Radio([I2I, T2I]),
143
+ gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
144
+ gr.inputs.Image(shape=(400, 400), label="Image to search", placeholder="拖入图像\n- 或 - \n点击上传", optional=True),
145
  gr.inputs.Textbox(
146
+ lines=1, label="Text query", placeholder="请输入待查询文本...",
147
  ),
148
  ],
149
  theme="grass",
misc/evaluation.py CHANGED
@@ -43,7 +43,6 @@ def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
43
  scores = cosine_sim(cap_enc, imgs_enc)
44
 
45
  recall_imgs = [imgs_path[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
46
-
47
  return recall_imgs
48
 
49
  def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=[1, 5, 10], scores=None):
 
43
  scores = cosine_sim(cap_enc, imgs_enc)
44
 
45
  recall_imgs = [imgs_path[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
 
46
  return recall_imgs
47
 
48
  def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=[1, 5, 10], scores=None):