atticus committed
Commit 953580c
1 Parent(s): 950c874
.vscode/sftp.json ADDED
@@ -0,0 +1,32 @@
+{
+    "name": "itr-ddt",
+    "host": "192.168.0.109",
+    "protocol": "sftp",
+    "port": 22,
+    "username": "atticus",
+    "password": "qs123",
+    "passphrase": "null",
+    "passive": false,
+    "interactiveAuth": true,
+    "remotePath": "/home/atticus/proj/matching/itr-ddt",
+    "context": "D:/Projects/MultiModal/itr-ddt",
+    "uploadOnSave": true,
+    "downloadOnOpen": true,
+    "syncMode": "update",
+    "ignore": [
+        "**/.vscode/**",
+        "**/.git/**",
+        "**/.DS_Store",
+        "**/*.tar",
+        "**/*.zip",
+        "**/*.pkl",
+        "**/*.json",
+        "**/*.npy"
+    ],
+    "watcher": {
+        "files": "*",
+        "autoUpload": false,
+        "autoDelete": false
+    }
+}
+
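This config drives a VS Code SFTP extension: with uploadOnSave enabled, every saved file is pushed from the Windows context directory to remotePath on 192.168.0.109, skipping the ignore globs. Purely as an illustration of what one such upload event amounts to, here is a standalone Python sketch; it assumes the paramiko library, which is not among this repo's dependencies:

# Rough standalone equivalent of a single uploadOnSave event (illustrative
# sketch only; paramiko is an assumption, not part of this repo).
import paramiko

client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect("192.168.0.109", port=22, username="atticus", password="qs123")
sftp = client.open_sftp()
# Mirror one saved file from the local context to the remotePath
sftp.put("D:/Projects/MultiModal/itr-ddt/app.py",
         "/home/atticus/proj/matching/itr-ddt/app.py")
sftp.close()
client.close()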
app.py CHANGED
@@ -32,18 +32,22 @@ import sys
 from misc.dataset import TextEncoder
 import requests
 import cv2
+from io import BytesIO
+from translate import Translator
+import cupy as cp
 
 
 device = torch.device("cuda")
-batch_size = 32
-
+batch_size = 1
+topK = 5
 
 T2I = "Text 2 Image"
-
 I2I = "Image 2 Image"
 model_path = "data/best_model.pth.tar"
 # model = SentenceTransformer("clip-ViT-B-32")
 
+img_folder = Path("./photos/")
+
 
 
 def download_url_img(url):
@@ -55,83 +59,68 @@ def download_url_img(url):
         return False, []
     if response is not None and response.status_code == 200:
         input_image_data = response.content
-        np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
-        parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
-        return True, parsed_image
-
-def search(image, mode, text):
-
-    print("Loading model from:", model_path)
-    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
-
-    join_emb = joint_embedding(checkpoint['args_dict'])
-    join_emb.load_state_dict(checkpoint["state_dict"])
-
-    for param in join_emb.parameters():
-        param.requires_grad = False
-
-    join_emb.to(device)
-    join_emb.eval()
-
-    encoder = TextEncoder()
-    print("Loading model done")
-    # (4) design intersection mode.
-    print("Please input your description of the image that you wanna search >>>")
-
-    # with open(args.data_path, 'w') as cap_file:
-    #     cap_file.writelines(cap_str)
+        # np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
+        # parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
+        image = Image.open(BytesIO(input_image_data))
+        return True, image
+    return False, []
+
+
+def search(mode, text):
+
+    # translator = Translator(from_lang="chinese", to_lang="english")
+    # text = translator.translate(text)
     dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
     dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
     caps_enc = list()
+
     for _, (caps, length) in enumerate(dataset_loader, 0):
         input_caps = caps.to(device)
         with torch.no_grad():
             _, caps_emb = join_emb(None, input_caps, length)
-        caps_enc.append(caps_emb.cpu().data.numpy())
-
-    caps_stack = np.vstack(caps_enc)
-
-    print("recall from resources ...")
-    # (1) load candidate imgs from saved embedding pkl file.
-    imgs_emb_file_path = "./coco_img_emb"
-    # imgs_emb(40775, 2400)
-    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
-    # (2) calculate the sim between cap and imgs.
-    # (3) rank imgs and display the searching result.
-    imgs_url = os.path.join("http://images.cocodataset.org/train2017", imgs_path.strip().split('_')[-1])
-
-    recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=5)
+        caps_enc.append(caps_emb)
+    caps_stack = cp.vstack(caps_enc)
+
+    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
+
+    recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100)
 
     # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
    # cat_image = "./cat_example.jpg"
     # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
     # dog_image = "./dog_example.jpg"
-    res = [download_url_img(img_url)[1] for img_url in recall_imgs if download_url_img(img_url)[0] == True]
-
-    # logger.info(f"Mode {mode} selected")
-    # if mode == I2I:
-    #     logger.info(f"Processing image in mode {mode}")
-    #     emb = model.encode([Image.fromarray(image)], convert_to_tensor=True)
-    # elif mode == T2I:
-    #     logger.info(f"Processing text in mode {mode}")
-    #     emb = model.encode([text], convert_to_tensor=True)
-
-    # cos_sim = util.cos_sim(img_emb, emb)
-    # logger.info(f"Best match: {img_names[torch.argmax(cos_sim)]}")
-    # return [Image.open(img_folder / img_names[top_k_best_image]) for top_k_best_image in torch.topk(cos_sim, 5, 0).indices]
+    res = []
+    idx = 0
+    for img_url in recall_imgs:
+        if idx == topK:
+            break
+        b, img = download_url_img(img_url)
+        if b:
+            res.append(img)
+            idx += 1
     return res
 
 if __name__ == "__main__":
+    # print("Loading model from:", model_path)
+    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
+
+    join_emb = joint_embedding(checkpoint['args_dict'])
+    join_emb.load_state_dict(checkpoint["state_dict"])
+
+    for param in join_emb.parameters():
+        param.requires_grad = False
+
+    join_emb.to(device)
+    join_emb.eval()
+    encoder = TextEncoder()
+    imgs_emb_file_path = "./coco_img_emb"
+    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
+    imgs_emb = cp.asarray(imgs_emb)
+    print("prepare done!")
     iface = gr.Interface(
         fn=search,
         inputs=[
-            gr.inputs.Image(label="Image to search", optional=True),
-            gr.inputs.Radio([T2I, I2I]),
+            gr.inputs.Radio([T2I]),
             gr.inputs.Textbox(
                 lines=1, label="Text query", placeholder="Introduce the search text...",
             ),
@@ -147,4 +136,4 @@ if __name__ == "__main__":
         title="HUST毕业设计-图文检索系统",
         description="请输入图片或文本,将为您展示相关的图片:",
     )
-    iface.launch()
+    iface.launch(share=True)
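After this change, download_url_img returns a PIL image rather than an OpenCV array, which Gradio can render directly. A quick sanity check of the new path, assuming the repo's dependencies are installed (importing app pulls in torch, cupy, and gradio; the sample COCO URL is taken from tmp.py below):

# Sanity check for the PIL-based downloader in app.py (assumption: run from
# the repo root with dependencies installed; model loading stays behind the
# __main__ guard, so the import itself does not load the checkpoint).
from app import download_url_img

ok, img = download_url_img("http://images.cocodataset.org/train2017/000000146722.jpg")
if ok:
    print(img.size, img.mode)  # PIL attributes, e.g. (640, 480) RGB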
coco_img_emb.pkl CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:43ca7d88339063c202638beab9203f1acbc86acaaf43aa7f61a87b2789070bdd
-size 1587836571
+oid sha256:012377f7e09f9f95cc15a391f2da541ede470d4c6d6c36f9239bb59def6ec269
+size 108068864
data/best_model.pth.tar CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1615e4ce1ee8f906ded31c29817945758f532e4793faba45c7df546a593efb3e
-size 1500259972
+oid sha256:f8ada75eacbe26ecf1c3507238b542e1db689254a1dac3825ffe4842443d2947
+size 108068864
data/utable.npy CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:790951d4b08e843e3bca0563570f4134ffd17b6bd4ab8d237d2e5ae15e4febb3
-size 2342138474
+oid sha256:8c8af23b32fcfb69ad00bc22f39c557e2926b66e2edb3275437157967b5f8257
+size 120258560
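The three CHANGED entries above are Git LFS pointer files, not the binaries themselves: each records only a spec version, a sha256 oid, and the payload size in bytes, and the commit swaps the pointers for smaller payloads. A minimal sketch of reading one such pointer (read_lfs_pointer is a hypothetical helper, not part of this repo):

# Parse a Git LFS pointer file into its version/oid/size fields
# (hypothetical helper, for illustration only).
def read_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

print(read_lfs_pointer("data/utable.npy"))
# {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'sha256:8c8a...', 'size': '120258560'}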
flagged/1st Best match/0.png ADDED
flagged/2nd Best match/0.png ADDED
flagged/3rd Best match/0.png ADDED
flagged/4rd Best match/0.png ADDED
flagged/5rd Best match/0.png ADDED
flagged/log.csv ADDED
@@ -0,0 +1,3 @@
+Image to search,Text query,1st Best match,2nd Best match,3rd Best match,4rd Best match,5rd Best match,timestamp
+,,,,,,,2022-03-10 00:37:13.783708
+,Text 2 Image,,1st Best match/0.png,2nd Best match/0.png,3rd Best match/0.png,4rd Best match/0.png,5rd Best match/0.png,2022-03-10 01:17:09.801153
misc/__pycache__/dataset.cpython-37.pyc CHANGED
Binary files a/misc/__pycache__/dataset.cpython-37.pyc and b/misc/__pycache__/dataset.cpython-37.pyc differ
 
misc/__pycache__/evaluation.cpython-37.pyc CHANGED
Binary files a/misc/__pycache__/evaluation.cpython-37.pyc and b/misc/__pycache__/evaluation.cpython-37.pyc differ
 
misc/dataset.py CHANGED
@@ -269,7 +269,7 @@ class TextEncoder(object):
     def __init__(self, word_dict_path=path["WORD_DICT"]):
 
         path_params = os.path.join(word_dict_path, 'utable.npy')
-        self.params = np.load(path_params, encoding='latin1')
+        self.params = np.load(path_params, encoding='latin1', allow_pickle=True)
         self.dico = _load_dictionary(word_dict_path)
 
     def encode(self, text):
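For context: utable.npy stores an object array, and np.load has defaulted to allow_pickle=False since NumPy 1.16.3, so the old call raises a ValueError on current NumPy versions. A minimal reproduction of the failure mode (obj.npy is a throwaway file for this demo):

# Minimal repro of why allow_pickle=True is required for object arrays.
import numpy as np

np.save("obj.npy", np.array({"word": 0}, dtype=object))
try:
    np.load("obj.npy")
except ValueError as e:
    print(e)  # Object arrays cannot be loaded when allow_pickle=False
print(np.load("obj.npy", allow_pickle=True))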
misc/evaluation.py CHANGED
@@ -23,26 +23,27 @@ Author: Martin Engilberge
 import numpy as np
 
 from misc.utils import flatten
+import cupy as cp
 
 def cosine_sim(A, B):
-    img_norm = np.linalg.norm(A, axis=1)
-    caps_norm = np.linalg.norm(B, axis=1)
+    img_norm = cp.linalg.norm(A, axis=1)
+    caps_norm = cp.linalg.norm(B, axis=1)
 
-    scores = np.dot(A, B.T)
+    scores = cp.dot(A, B.T)
 
-    norms = np.dot(np.expand_dims(img_norm, 1),
-                   np.expand_dims(caps_norm.T, 1).T)
+    norms = cp.dot(cp.expand_dims(img_norm, 1),
+                   cp.expand_dims(caps_norm.T, 1).T)
 
     scores = (scores / norms)
 
     return scores
 
-def recallTopK(cap_enc, imgs_enc, imgs_url, ks=10, scores=None):
+def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
 
     if scores is None:
         scores = cosine_sim(cap_enc, imgs_enc)
 
-    recall_imgs = [imgs_url[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
+    recall_imgs = [imgs_path[cp.asnumpy(i)] for i in cp.argsort(scores, axis=1)[0][::-1][:ks]]
 
     return recall_imgs
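To make the new CuPy ranking concrete, here is a tiny worked example; it assumes a CUDA-capable machine, and the embeddings and paths are made up for illustration:

# Worked example: one caption embedding ranked against three image embeddings.
import cupy as cp
from misc.evaluation import cosine_sim, recallTopK

cap = cp.asarray([[1.0, 0.0]])                           # (1, d) caption embedding
imgs = cp.asarray([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])  # (3, d) image embeddings
paths = ["a.jpg", "b.jpg", "c.jpg"]

print(cosine_sim(cap, imgs))               # [[1.0, 0.0, 0.7071...]]
print(recallTopK(cap, imgs, paths, ks=2))  # ['a.jpg', 'c.jpg']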
requirements.txt ADDED
@@ -0,0 +1,16 @@
+cupy==10.2.0
+cupy_cuda101==9.6.0
+gradio==2.8.9
+matplotlib==2.2.2
+nltk==3.3
+numpy==1.21.5
+Pillow==9.0.1
+pycocotools==2.0.4
+requests==2.27.1
+scipy==1.1.0
+sru==2.6.0
+torch==1.10.2
+torchvision==0.2.1
+tqdm==4.63.0
+translate==3.6.1
+visual_genome==1.1.1
run.sh ADDED
@@ -0,0 +1,5 @@
+#!/bin/bash
+echo "Welcome to image search system !"
+echo "Please enjoy your time !"
+
+python pred_retrieval.py -p "data/best_model.pth.tar" -d "data/cap_file.txt" -bs 1
run_train.sh ADDED
@@ -0,0 +1 @@
+python train.py -bs 160 -gpu 1,2,3
tmp.py ADDED
@@ -0,0 +1,23 @@
+import cv2
+import requests
+import numpy as np
+
+def download_url_img(url):
+    """
+    Download an image from a URL.
+    """
+
+    try:
+        response = requests.get(url, timeout=3)
+    except Exception as e:
+        print(str(e))
+        return False, []
+    if response is not None and response.status_code == 200:
+        input_image_data = response.content
+        np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
+        parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
+        return True, parsed_image
+    return False, []  # non-200 responses also yield a (bool, image) pair
+
+download_url_img("http://images.cocodataset.org/train2017/000000146722.jpg")
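Note that this scratch script decodes with OpenCV, so a successful call typically returns a BGR array (for JPEGs), whereas the updated app.py returns an RGB PIL image. A small conversion sketch for anyone comparing the two paths, assuming the download above succeeded:

# Convert tmp.py's BGR result to RGB for libraries that expect RGB
# (e.g. matplotlib or PIL).
import cv2

ok, bgr = download_url_img("http://images.cocodataset.org/train2017/000000146722.jpg")
if ok and bgr is not None:
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)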