atticus committed on
Commit
3b4c7a3
·
1 Parent(s): 362a148
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -58,25 +58,33 @@ def download_url_img(url):
58
  return False, []
59
 
60
 
61
- def search(mode, text):
62
-
63
- # translator = Translator(from_lang="chinese",to_lang="english")
64
- # text = translator.translate(text)
65
- dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
66
- dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
67
- caps_enc = list()
68
-
69
- for i, (caps, length) in enumerate(dataset_loader, 0):
70
- input_caps = caps.to(device)
71
- with torch.no_grad():
72
- _, output_emb = join_emb(None, input_caps, length)
73
- caps_enc.append(output_emb.cpu().data.numpy())
74
-
75
- caps_stack = np.vstack(caps_enc)
76
-
77
- imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
78
-
79
- recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100)
 
 
 
 
 
 
 
 
80
  # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
81
  # cat_image = "./cat_example.jpg"
82
  # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
@@ -109,11 +117,13 @@ if __name__ == "__main__":
109
  encoder = TextEncoder()
110
  imgs_emb_file_path = "./coco_img_emb"
111
  imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
 
112
  print("prepare done!")
113
  iface = gr.Interface(
114
  fn=search,
115
  inputs=[
116
- gr.inputs.Radio([T2I]),
 
117
  gr.inputs.Textbox(
118
  lines=1, label="Text query", placeholder="Introduce the search text...",
119
  ),
 
58
  return False, []
59
 
60
 
61
+ def search(mode, image, text):
62
+
63
+ translator = Translator(from_lang="chinese",to_lang="english")
64
+ text = translator.translate(text)
65
+ if mode == T2I:
66
+ dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
67
+ dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
68
+ caps_enc = list()
69
+ for i, (caps, length) in enumerate(dataset_loader, 0):
70
+ input_caps = caps
71
+ with torch.no_grad():
72
+ _, output_emb = join_emb(None, input_caps, length)
73
+ caps_enc.append(output_emb)
74
+ _stack = np.vstack(caps_enc)
75
+
76
+ elif mode == I2I:
77
+ dataset = torch.Tensor(image).unsqueeze(dim=0)
78
+ dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
79
+ img_enc = list()
80
+ for i, (imgs, length) in enumerate(dataset_loader, 0):
81
+ input_imgs = imgs
82
+ with torch.no_grad():
83
+ _, output_emb = join_emb(input_imgs, None, length)
84
+ img_enc.append(output_emb)
85
+ _stack = np.vstack(img_enc)
86
+
87
+ recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
88
  # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
89
  # cat_image = "./cat_example.jpg"
90
  # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
 
117
  encoder = TextEncoder()
118
  imgs_emb_file_path = "./coco_img_emb"
119
  imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
120
+ imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
121
  print("prepare done!")
122
  iface = gr.Interface(
123
  fn=search,
124
  inputs=[
125
+ gr.inputs.Radio([I2I, T2I]),
126
+ gr.inputs.Image(label="Image to search", optional=True),
127
  gr.inputs.Textbox(
128
  lines=1, label="Text query", placeholder="Introduce the search text...",
129
  ),