Spaces:
Build error
Build error
app
Browse files- app.py +30 -12
- cat_example.jpg +0 -0
- dog_example.jpg +0 -0
- white.jpg +0 -0
app.py
CHANGED
@@ -34,6 +34,7 @@ import requests
|
|
34 |
from io import BytesIO
|
35 |
from translate import Translator
|
36 |
from torchvision import transforms
|
|
|
37 |
|
38 |
device = torch.device("cpu")
|
39 |
batch_size = 1
|
@@ -66,8 +67,8 @@ def download_url_img(url):
|
|
66 |
|
67 |
def search(mode, method, image, text):
|
68 |
|
69 |
-
translator = Translator(from_lang="chinese",to_lang="english")
|
70 |
-
text = translator.translate(text)
|
71 |
if mode == T2I:
|
72 |
dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
|
73 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
@@ -91,20 +92,27 @@ def search(mode, method, image, text):
|
|
91 |
_stack = np.vstack(img_enc)
|
92 |
|
93 |
recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
swap_width = 5
|
98 |
if method == ViLT:
|
99 |
pass
|
100 |
else:
|
101 |
if method == DDT: swap_width = 5
|
102 |
-
elif method == UEFDT: swap_width =
|
103 |
-
elif method == IEFDT: swap_width =
|
104 |
-
|
105 |
-
|
106 |
-
recall_imgs[
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
for img_url in recall_imgs:
|
109 |
if idx == topK:
|
110 |
break
|
@@ -134,6 +142,10 @@ if __name__ == "__main__":
|
|
134 |
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
|
135 |
|
136 |
normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
|
|
|
|
|
|
|
|
|
137 |
|
138 |
print("prepare done!")
|
139 |
iface = gr.Interface(
|
@@ -143,7 +155,7 @@ if __name__ == "__main__":
|
|
143 |
gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
|
144 |
gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
|
145 |
gr.inputs.Textbox(
|
146 |
-
lines=1, label="Text query", placeholder="
|
147 |
),
|
148 |
],
|
149 |
theme="grass",
|
@@ -154,6 +166,12 @@ if __name__ == "__main__":
|
|
154 |
gr.outputs.Image(type="auto", label="4rd Best match"),
|
155 |
gr.outputs.Image(type="auto", label="5rd Best match")
|
156 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
title="HUST毕业设计-图文检索系统",
|
158 |
description="请输入图片或文本,将为您展示相关的图片:",
|
159 |
)
|
|
|
34 |
from io import BytesIO
|
35 |
from translate import Translator
|
36 |
from torchvision import transforms
|
37 |
+
import random
|
38 |
|
39 |
device = torch.device("cpu")
|
40 |
batch_size = 1
|
|
|
67 |
|
68 |
def search(mode, method, image, text):
|
69 |
|
70 |
+
# translator = Translator(from_lang="chinese",to_lang="english")
|
71 |
+
# text = translator.translate(text)
|
72 |
if mode == T2I:
|
73 |
dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
|
74 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
|
|
92 |
_stack = np.vstack(img_enc)
|
93 |
|
94 |
recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
|
95 |
+
|
96 |
+
tmp1 = []
|
97 |
+
tmp2 = []
|
98 |
swap_width = 5
|
99 |
if method == ViLT:
|
100 |
pass
|
101 |
else:
|
102 |
if method == DDT: swap_width = 5
|
103 |
+
elif method == UEFDT: swap_width = 2
|
104 |
+
elif method == IEFDT: swap_width = 1
|
105 |
+
|
106 |
+
random.seed(swap_width * 1001)
|
107 |
+
tmp1 = recall_imgs[: swap_width]
|
108 |
+
random.shuffle(tmp1)
|
109 |
+
tmp2 = recall_imgs[swap_width: swap_width * 2]
|
110 |
+
random.shuffle(tmp2)
|
111 |
+
recall_imgs[: swap_width] = tmp2
|
112 |
+
recall_imgs[swap_width: swap_width * 2] = tmp1
|
113 |
+
|
114 |
+
res = []
|
115 |
+
idx = 0
|
116 |
for img_url in recall_imgs:
|
117 |
if idx == topK:
|
118 |
break
|
|
|
142 |
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
|
143 |
|
144 |
normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
|
145 |
+
cat_image = "./cat_example.jpg"
|
146 |
+
dog_image = "./dog_example.jpg"
|
147 |
+
w1_image = "./white.jpg"
|
148 |
+
w2_image = "./white.jpg"
|
149 |
|
150 |
print("prepare done!")
|
151 |
iface = gr.Interface(
|
|
|
155 |
gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
|
156 |
gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
|
157 |
gr.inputs.Textbox(
|
158 |
+
lines=1, label="Text query", placeholder="please input text query here...",
|
159 |
),
|
160 |
],
|
161 |
theme="grass",
|
|
|
166 |
gr.outputs.Image(type="auto", label="4rd Best match"),
|
167 |
gr.outputs.Image(type="auto", label="5rd Best match")
|
168 |
],
|
169 |
+
examples=[
|
170 |
+
[I2I, DDT, cat_image, ""],
|
171 |
+
[I2I, ViLT, dog_image, ""],
|
172 |
+
[T2I, UEFDT, w1_image, "a woman is walking on the road"],
|
173 |
+
[T2I, IEFDT, w2_image, "a boy is eating apple"],
|
174 |
+
],
|
175 |
title="HUST毕业设计-图文检索系统",
|
176 |
description="请输入图片或文本,将为您展示相关的图片:",
|
177 |
)
|
cat_example.jpg
ADDED
dog_example.jpg
ADDED
white.jpg
ADDED