atticus commited on
Commit
1e7fce7
1 Parent(s): 9666011
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -33,7 +33,7 @@ from misc.dataset import TextEncoder
33
  import requests
34
  from io import BytesIO
35
  from translate import Translator
36
-
37
 
38
  device = torch.device("cpu")
39
  batch_size = 1
@@ -74,13 +74,13 @@ def search(mode, image, text):
74
  _stack = np.vstack(caps_enc)
75
 
76
  elif mode == I2I:
77
- dataset = torch.Tensor(image).unsqueeze(dim=0)
78
  dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
79
  img_enc = list()
80
  for i, (imgs, length) in enumerate(dataset_loader, 0):
81
  input_imgs = imgs
82
  with torch.no_grad():
83
- _, output_emb = join_emb(input_imgs, None, length)
84
  img_enc.append(output_emb)
85
  _stack = np.vstack(img_enc)
86
 
@@ -118,12 +118,15 @@ if __name__ == "__main__":
118
  imgs_emb_file_path = "./coco_img_emb"
119
  imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
120
  imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
 
 
 
121
  print("prepare done!")
122
  iface = gr.Interface(
123
  fn=search,
124
  inputs=[
125
  gr.inputs.Radio([I2I, T2I]),
126
- gr.inputs.Image(shape=(512, 512), label="Image to search", optional=True),
127
  gr.inputs.Textbox(
128
  lines=1, label="Text query", placeholder="Introduce the search text...",
129
  ),
 
33
  import requests
34
  from io import BytesIO
35
  from translate import Translator
36
+ from torchvision import transforms
37
 
38
  device = torch.device("cpu")
39
  batch_size = 1
 
74
  _stack = np.vstack(caps_enc)
75
 
76
  elif mode == I2I:
77
+ dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
78
  dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
79
  img_enc = list()
80
  for i, (imgs, length) in enumerate(dataset_loader, 0):
81
  input_imgs = imgs
82
  with torch.no_grad():
83
+ output_emb, _ = join_emb(input_imgs, None, None)
84
  img_enc.append(output_emb)
85
  _stack = np.vstack(img_enc)
86
 
 
118
  imgs_emb_file_path = "./coco_img_emb"
119
  imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
120
  imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
121
+
122
+ normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
123
+
124
  print("prepare done!")
125
  iface = gr.Interface(
126
  fn=search,
127
  inputs=[
128
  gr.inputs.Radio([I2I, T2I]),
129
+ gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
130
  gr.inputs.Textbox(
131
  lines=1, label="Text query", placeholder="Introduce the search text...",
132
  ),