Spaces:
Runtime error
Runtime error
test
Browse files- .vscode/sftp.json +32 -0
- app.py +46 -57
- coco_img_emb.pkl +2 -2
- data/best_model.pth.tar +2 -2
- data/utable.npy +2 -2
- flagged/1st Best match/0.png +0 -0
- flagged/2nd Best match/0.png +0 -0
- flagged/3rd Best match/0.png +0 -0
- flagged/4rd Best match/0.png +0 -0
- flagged/5rd Best match/0.png +0 -0
- flagged/log.csv +3 -0
- misc/__pycache__/dataset.cpython-37.pyc +0 -0
- misc/__pycache__/evaluation.cpython-37.pyc +0 -0
- misc/dataset.py +1 -1
- misc/evaluation.py +8 -7
- requirements.txt +16 -0
- run.sh +5 -0
- run_train.sh +1 -0
- tmp.py +23 -0
.vscode/sftp.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "itr-ddt",
|
3 |
+
"host": "192.168.0.109",
|
4 |
+
"protocol": "sftp",
|
5 |
+
"port": 22,
|
6 |
+
"username": "atticus",
|
7 |
+
"password":"qs123",
|
8 |
+
"passphrase": "null",
|
9 |
+
"passive": false,
|
10 |
+
"interactiveAuth": true,
|
11 |
+
"remotePath": "/home/atticus/proj/matching/itr-ddt",
|
12 |
+
"context": "D:/Projects/MultiModal/itr-ddt",
|
13 |
+
"uploadOnSave": true,
|
14 |
+
"downloadOnOpen":true,
|
15 |
+
"syncMode": "update",
|
16 |
+
"ignore": [
|
17 |
+
"**/.vscode/**",
|
18 |
+
"**/.git/**",
|
19 |
+
"**/.DS_Store",
|
20 |
+
"**/*.tar",
|
21 |
+
"**/*.zip",
|
22 |
+
"**/*.pkl",
|
23 |
+
"**/*.json",
|
24 |
+
"**/*.npy"
|
25 |
+
],
|
26 |
+
"watcher": {
|
27 |
+
"files": "*",
|
28 |
+
"autoUpload": false,
|
29 |
+
"autoDelete": false
|
30 |
+
}
|
31 |
+
}
|
32 |
+
|
app.py
CHANGED
@@ -32,18 +32,22 @@ import sys
|
|
32 |
from misc.dataset import TextEncoder
|
33 |
import requests
|
34 |
import cv2
|
|
|
|
|
|
|
35 |
|
36 |
|
37 |
device = torch.device("cuda")
|
38 |
-
batch_size =
|
39 |
-
|
40 |
|
41 |
T2I = "Text 2 Image"
|
42 |
-
|
43 |
I2I = "Image 2 Image"
|
44 |
model_path = "data/best_model.pth.tar"
|
45 |
# model = SentenceTransformer("clip-ViT-B-32")
|
46 |
|
|
|
|
|
47 |
|
48 |
|
49 |
def download_url_img(url):
|
@@ -55,83 +59,68 @@ def download_url_img(url):
|
|
55 |
return False, []
|
56 |
if response is not None and response.status_code == 200:
|
57 |
input_image_data = response.content
|
58 |
-
np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
|
59 |
-
parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
print("Loading model from:", model_path)
|
65 |
-
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
|
66 |
-
|
67 |
-
join_emb = joint_embedding(checkpoint['args_dict'])
|
68 |
-
join_emb.load_state_dict(checkpoint["state_dict"])
|
69 |
|
70 |
-
for param in join_emb.parameters():
|
71 |
-
param.requires_grad = False
|
72 |
-
|
73 |
-
join_emb.to(device)
|
74 |
-
join_emb.eval()
|
75 |
-
|
76 |
-
encoder = TextEncoder()
|
77 |
-
print("Loading model done")
|
78 |
-
# (4) design intersection mode.
|
79 |
-
print("Please input your description of the image that you wanna search >>>")
|
80 |
|
81 |
-
|
82 |
-
# with open(args.data_path, 'w') as cap_file:
|
83 |
-
# cap_file.writelines(cap_str)
|
84 |
|
|
|
|
|
85 |
dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
|
86 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
87 |
caps_enc = list()
|
|
|
88 |
for _, (caps, length) in enumerate(dataset_loader, 0):
|
89 |
input_caps = caps.to(device)
|
90 |
with torch.no_grad():
|
91 |
_, caps_emb = join_emb(None, input_caps, length)
|
92 |
-
caps_enc.append(caps_emb
|
|
|
93 |
|
|
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
print("recall from resources ...")
|
99 |
-
# (1) load candidate imgs from saved embeding pkl file.
|
100 |
-
imgs_emb_file_path = "./coco_img_emb"
|
101 |
-
# imgs_emb(40775, 2400)
|
102 |
-
imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
|
103 |
-
# (2) calculate the sim between cap and imgs.
|
104 |
-
# (3) rank imgs and display the searching result.
|
105 |
-
imgs_url = os.path.join("http://images.cocodataset.org/train2017", imgs_path.strip().split('_')[-1])
|
106 |
-
|
107 |
-
recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=5)
|
108 |
|
109 |
# Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
|
110 |
# cat_image = "./cat_example.jpg"
|
111 |
# Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
|
112 |
# dog_image = "./dog_example.jpg"
|
113 |
-
res = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
|
116 |
-
#
|
117 |
-
|
118 |
-
# emb = model.encode([Image.fromarray(image)], convert_to_tensor=True)
|
119 |
-
# elif mode == T2I:
|
120 |
-
# logger.info(f"Processing text in mode {mode}")
|
121 |
-
# emb = model.encode([text], convert_to_tensor=True)
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
# return [Image.open(img_folder / img_names[top_k_best_image]) for top_k_best_image in torch.topk(cos_sim, 5, 0).indices]
|
126 |
-
return res
|
127 |
|
128 |
-
|
|
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
iface = gr.Interface(
|
131 |
fn=search,
|
132 |
inputs=[
|
133 |
-
gr.inputs.
|
134 |
-
gr.inputs.Radio([T2I, I2I]),
|
135 |
gr.inputs.Textbox(
|
136 |
lines=1, label="Text query", placeholder="Introduce the search text...",
|
137 |
),
|
@@ -147,4 +136,4 @@ if __name__ == "__main__":
|
|
147 |
title="HUST毕业设计-图文检索系统",
|
148 |
description="请输入图片或文本,将为您展示相关的图片:",
|
149 |
)
|
150 |
-
iface.launch()
|
|
|
32 |
from misc.dataset import TextEncoder
|
33 |
import requests
|
34 |
import cv2
|
35 |
+
from io import BytesIO
|
36 |
+
from translate import Translator
|
37 |
+
import cupy as cp
|
38 |
|
39 |
|
40 |
device = torch.device("cuda")
|
41 |
+
batch_size = 1
|
42 |
+
topK = 5
|
43 |
|
44 |
T2I = "Text 2 Image"
|
|
|
45 |
I2I = "Image 2 Image"
|
46 |
model_path = "data/best_model.pth.tar"
|
47 |
# model = SentenceTransformer("clip-ViT-B-32")
|
48 |
|
49 |
+
img_folder = Path("./photos/")
|
50 |
+
|
51 |
|
52 |
|
53 |
def download_url_img(url):
|
|
|
59 |
return False, []
|
60 |
if response is not None and response.status_code == 200:
|
61 |
input_image_data = response.content
|
62 |
+
# np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
|
63 |
+
# parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
|
64 |
+
image=Image.open(BytesIO(input_image_data))
|
65 |
+
return True, image
|
66 |
+
return False, []
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
def search(mode, text):
|
|
|
|
|
70 |
|
71 |
+
# translator = Translator(from_lang="chinese",to_lang="english")
|
72 |
+
# text = translator.translate(text)
|
73 |
dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
|
74 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
75 |
caps_enc = list()
|
76 |
+
|
77 |
for _, (caps, length) in enumerate(dataset_loader, 0):
|
78 |
input_caps = caps.to(device)
|
79 |
with torch.no_grad():
|
80 |
_, caps_emb = join_emb(None, input_caps, length)
|
81 |
+
caps_enc.append(caps_emb)
|
82 |
+
caps_stack = cp.vstack(caps_enc)
|
83 |
|
84 |
+
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
|
85 |
|
86 |
+
recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
# Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
|
89 |
# cat_image = "./cat_example.jpg"
|
90 |
# Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
|
91 |
# dog_image = "./dog_example.jpg"
|
92 |
+
res = []
|
93 |
+
idx = 0
|
94 |
+
for img_url in recall_imgs:
|
95 |
+
if idx == topK:
|
96 |
+
break
|
97 |
+
b, img = download_url_img(img_url)
|
98 |
+
if b:
|
99 |
+
res.append(img)
|
100 |
+
idx += 1
|
101 |
+
return res
|
102 |
|
103 |
+
if __name__ == "__main__":
|
104 |
+
# print("Loading model from:", model_path)
|
105 |
+
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
|
|
|
|
|
|
|
|
|
106 |
|
107 |
+
join_emb = joint_embedding(checkpoint['args_dict'])
|
108 |
+
join_emb.load_state_dict(checkpoint["state_dict"])
|
|
|
|
|
109 |
|
110 |
+
for param in join_emb.parameters():
|
111 |
+
param.requires_grad = False
|
112 |
|
113 |
+
join_emb.to(device)
|
114 |
+
join_emb.eval()
|
115 |
+
encoder = TextEncoder()
|
116 |
+
imgs_emb_file_path = "./coco_img_emb"
|
117 |
+
imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
|
118 |
+
imgs_emb = cp.asarray(imgs_emb)
|
119 |
+
print("prepare done!")
|
120 |
iface = gr.Interface(
|
121 |
fn=search,
|
122 |
inputs=[
|
123 |
+
gr.inputs.Radio([T2I]),
|
|
|
124 |
gr.inputs.Textbox(
|
125 |
lines=1, label="Text query", placeholder="Introduce the search text...",
|
126 |
),
|
|
|
136 |
title="HUST毕业设计-图文检索系统",
|
137 |
description="请输入图片或文本,将为您展示相关的图片:",
|
138 |
)
|
139 |
+
iface.launch(share=True)
|
coco_img_emb.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:012377f7e09f9f95cc15a391f2da541ede470d4c6d6c36f9239bb59def6ec269
|
3 |
+
size 108068864
|
data/best_model.pth.tar
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8ada75eacbe26ecf1c3507238b542e1db689254a1dac3825ffe4842443d2947
|
3 |
+
size 108068864
|
data/utable.npy
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c8af23b32fcfb69ad00bc22f39c557e2926b66e2edb3275437157967b5f8257
|
3 |
+
size 120258560
|
flagged/1st Best match/0.png
ADDED
flagged/2nd Best match/0.png
ADDED
flagged/3rd Best match/0.png
ADDED
flagged/4rd Best match/0.png
ADDED
flagged/5rd Best match/0.png
ADDED
flagged/log.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Image to search,Text query,1st Best match,2nd Best match,3rd Best match,4rd Best match,5rd Best match,timestamp
|
2 |
+
,,,,,,,2022-03-10 00:37:13.783708
|
3 |
+
,Text 2 Image,,1st Best match/0.png,2nd Best match/0.png,3rd Best match/0.png,4rd Best match/0.png,5rd Best match/0.png,2022-03-10 01:17:09.801153
|
misc/__pycache__/dataset.cpython-37.pyc
CHANGED
Binary files a/misc/__pycache__/dataset.cpython-37.pyc and b/misc/__pycache__/dataset.cpython-37.pyc differ
|
|
misc/__pycache__/evaluation.cpython-37.pyc
CHANGED
Binary files a/misc/__pycache__/evaluation.cpython-37.pyc and b/misc/__pycache__/evaluation.cpython-37.pyc differ
|
|
misc/dataset.py
CHANGED
@@ -269,7 +269,7 @@ class TextEncoder(object):
|
|
269 |
def __init__(self, word_dict_path=path["WORD_DICT"]):
|
270 |
|
271 |
path_params = os.path.join(word_dict_path, 'utable.npy')
|
272 |
-
self.params = np.load(path_params, encoding='latin1')
|
273 |
self.dico = _load_dictionary(word_dict_path)
|
274 |
|
275 |
def encode(self, text):
|
|
|
269 |
def __init__(self, word_dict_path=path["WORD_DICT"]):
|
270 |
|
271 |
path_params = os.path.join(word_dict_path, 'utable.npy')
|
272 |
+
self.params = np.load(path_params, encoding='latin1', allow_pickle=True)
|
273 |
self.dico = _load_dictionary(word_dict_path)
|
274 |
|
275 |
def encode(self, text):
|
misc/evaluation.py
CHANGED
@@ -23,26 +23,27 @@ Author: Martin Engilberge
|
|
23 |
import numpy as np
|
24 |
|
25 |
from misc.utils import flatten
|
|
|
26 |
|
27 |
def cosine_sim(A, B):
|
28 |
-
img_norm =
|
29 |
-
caps_norm =
|
30 |
|
31 |
-
scores =
|
32 |
|
33 |
-
norms =
|
34 |
-
|
35 |
|
36 |
scores = (scores / norms)
|
37 |
|
38 |
return scores
|
39 |
|
40 |
-
def recallTopK(cap_enc, imgs_enc,
|
41 |
|
42 |
if scores is None:
|
43 |
scores = cosine_sim(cap_enc, imgs_enc)
|
44 |
|
45 |
-
recall_imgs = [
|
46 |
|
47 |
return recall_imgs
|
48 |
|
|
|
23 |
import numpy as np
|
24 |
|
25 |
from misc.utils import flatten
|
26 |
+
import cupy as cp
|
27 |
|
28 |
def cosine_sim(A, B):
|
29 |
+
img_norm = cp.linalg.norm(A, axis=1)
|
30 |
+
caps_norm = cp.linalg.norm(B, axis=1)
|
31 |
|
32 |
+
scores = cp.dot(A, B.T)
|
33 |
|
34 |
+
norms = cp.dot(cp.expand_dims(img_norm, 1),
|
35 |
+
cp.expand_dims(caps_norm.T, 1).T)
|
36 |
|
37 |
scores = (scores / norms)
|
38 |
|
39 |
return scores
|
40 |
|
41 |
+
def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
|
42 |
|
43 |
if scores is None:
|
44 |
scores = cosine_sim(cap_enc, imgs_enc)
|
45 |
|
46 |
+
recall_imgs = [imgs_path[cp.asnumpy(i)] for i in cp.argsort(scores, axis=1)[0][::-1][:ks]]
|
47 |
|
48 |
return recall_imgs
|
49 |
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cupy==10.2.0
|
2 |
+
cupy_cuda101==9.6.0
|
3 |
+
gradio==2.8.9
|
4 |
+
matplotlib==2.2.2
|
5 |
+
nltk==3.3
|
6 |
+
numpy==1.21.5
|
7 |
+
Pillow==9.0.1
|
8 |
+
pycocotools==2.0.4
|
9 |
+
requests==2.27.1
|
10 |
+
scipy==1.1.0
|
11 |
+
sru==2.6.0
|
12 |
+
torch==1.10.2
|
13 |
+
torchvision==0.2.1
|
14 |
+
tqdm==4.63.0
|
15 |
+
translate==3.6.1
|
16 |
+
visual_genome==1.1.1
|
run.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
echo "Welcome to image search system !"
|
3 |
+
echo "Please enjoy your time !"
|
4 |
+
|
5 |
+
python pred_retrieval.py -p "data/best_model.pth.tar" -d "data/cap_file.txt" -bs 1
|
run_train.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python train.py -bs 160 -gpu 1,2,3
|
tmp.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import requests
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def download_url_img(url):
|
6 |
+
"""
|
7 |
+
下载url图像
|
8 |
+
"""
|
9 |
+
|
10 |
+
try:
|
11 |
+
response = requests.get(url, timeout=3)
|
12 |
+
except Exception as e:
|
13 |
+
print(str(e))
|
14 |
+
return False, []
|
15 |
+
if response is not None and response.status_code == 200:
|
16 |
+
input_image_data = response.content
|
17 |
+
np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1)
|
18 |
+
parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
|
19 |
+
return True, parsed_image
|
20 |
+
|
21 |
+
download_url_img("http://images.cocodataset.org/train2017/000000146722.jpg")
|
22 |
+
|
23 |
+
|