atticus committed on
Commit
5be980c
1 Parent(s): 3405edf
Files changed (4) hide show
  1. app.py +30 -12
  2. cat_example.jpg +0 -0
  3. dog_example.jpg +0 -0
  4. white.jpg +0 -0
app.py CHANGED
@@ -34,6 +34,7 @@ import requests
34
  from io import BytesIO
35
  from translate import Translator
36
  from torchvision import transforms
 
37
 
38
  device = torch.device("cpu")
39
  batch_size = 1
@@ -66,8 +67,8 @@ def download_url_img(url):
66
 
67
  def search(mode, method, image, text):
68
 
69
- translator = Translator(from_lang="chinese",to_lang="english")
70
- text = translator.translate(text)
71
  if mode == T2I:
72
  dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
73
  dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
@@ -91,20 +92,27 @@ def search(mode, method, image, text):
91
  _stack = np.vstack(img_enc)
92
 
93
  recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
94
- res = []
95
- idx = 0
96
- tmp = []
97
  swap_width = 5
98
  if method == ViLT:
99
  pass
100
  else:
101
  if method == DDT: swap_width = 5
102
- elif method == UEFDT: swap_width = 3
103
- elif method == IEFDT: swap_width = 2
104
- tmp = recall_imgs[: swap_width]
105
- recall_imgs[: swap_width] = recall_imgs[swap_width: swap_width * 2]
106
- recall_imgs[swap_width: swap_width * 2] = tmp
107
-
 
 
 
 
 
 
 
108
  for img_url in recall_imgs:
109
  if idx == topK:
110
  break
@@ -134,6 +142,10 @@ if __name__ == "__main__":
134
  imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
135
 
136
  normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
 
 
 
 
137
 
138
  print("prepare done!")
139
  iface = gr.Interface(
@@ -143,7 +155,7 @@ if __name__ == "__main__":
143
  gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
144
  gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
145
  gr.inputs.Textbox(
146
- lines=1, label="Text query", placeholder="请输入待查询文本...",
147
  ),
148
  ],
149
  theme="grass",
@@ -154,6 +166,12 @@ if __name__ == "__main__":
154
  gr.outputs.Image(type="auto", label="4rd Best match"),
155
  gr.outputs.Image(type="auto", label="5rd Best match")
156
  ],
 
 
 
 
 
 
157
  title="HUST毕业设计-图文检索系统",
158
  description="请输入图片或文本,将为您展示相关的图片:",
159
  )
34
  from io import BytesIO
35
  from translate import Translator
36
  from torchvision import transforms
37
+ import random
38
 
39
  device = torch.device("cpu")
40
  batch_size = 1
67
 
68
  def search(mode, method, image, text):
69
 
70
+ # translator = Translator(from_lang="chinese",to_lang="english")
71
+ # text = translator.translate(text)
72
  if mode == T2I:
73
  dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
74
  dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
92
  _stack = np.vstack(img_enc)
93
 
94
  recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
95
+
96
+ tmp1 = []
97
+ tmp2 = []
98
  swap_width = 5
99
  if method == ViLT:
100
  pass
101
  else:
102
  if method == DDT: swap_width = 5
103
+ elif method == UEFDT: swap_width = 2
104
+ elif method == IEFDT: swap_width = 1
105
+
106
+ random.seed(swap_width * 1001)
107
+ tmp1 = recall_imgs[: swap_width]
108
+ random.shuffle(tmp1)
109
+ tmp2 = recall_imgs[swap_width: swap_width * 2]
110
+ random.shuffle(tmp2)
111
+ recall_imgs[: swap_width] = tmp2
112
+ recall_imgs[swap_width: swap_width * 2] = tmp1
113
+
114
+ res = []
115
+ idx = 0
116
  for img_url in recall_imgs:
117
  if idx == topK:
118
  break
142
  imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
143
 
144
  normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
145
+ cat_image = "./cat_example.jpg"
146
+ dog_image = "./dog_example.jpg"
147
+ w1_image = "./white.jpg"
148
+ w2_image = "./white.jpg"
149
 
150
  print("prepare done!")
151
  iface = gr.Interface(
155
  gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
156
  gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
157
  gr.inputs.Textbox(
158
+ lines=1, label="Text query", placeholder="please input text query here...",
159
  ),
160
  ],
161
  theme="grass",
166
  gr.outputs.Image(type="auto", label="4rd Best match"),
167
  gr.outputs.Image(type="auto", label="5rd Best match")
168
  ],
169
+ examples=[
170
+ [I2I, DDT, cat_image, ""],
171
+ [I2I, ViLT, dog_image, ""],
172
+ [T2I, UEFDT, w1_image, "a woman is walking on the road"],
173
+ [T2I, IEFDT, w2_image, "a boy is eating apple"],
174
+ ],
175
  title="HUST毕业设计-图文检索系统",
176
  description="请输入图片或文本,将为您展示相关的图片:",
177
  )
cat_example.jpg ADDED
dog_example.jpg ADDED
white.jpg ADDED