app
Browse files
- app.py (+26 -11)
- misc/evaluation.py (+0 -1)
app.py
CHANGED
@@ -39,8 +39,14 @@ device = torch.device("cpu")
 batch_size = 1
 topK = 5
 
-T2I = "
-I2I = "
+T2I = "以文搜图"
+I2I = "以图搜图"
+
+DDT = "双塔动态嵌入"
+UEFDT = "双塔联合融合"
+IEFDT = "双塔嵌入融合"
+ViLT = "视觉语言预训练"
+
 model_path = "data/best_model.pth.tar"
 # model = SentenceTransformer("clip-ViT-B-32")
 
@@ -58,10 +64,10 @@ def download_url_img(url):
         return False, []
 
 
-def search(mode, image, text):
+def search(mode, method, image, text):
 
-
-
+    translator = Translator(from_lang="chinese", to_lang="english")
+    text = translator.translate(text)
     if mode == T2I:
         dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
         dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
@@ -85,12 +91,20 @@ def search(mode, image, text):
         _stack = np.vstack(img_enc)
 
     recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
-    # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
-    # cat_image = "./cat_example.jpg"
-    # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
-    # dog_image = "./dog_example.jpg"
     res = []
     idx = 0
+    tmp = []
+    swap_width = 5
+    if method == ViLT:
+        pass
+    else:
+        if method == DDT: swap_width = 5
+        elif method == UEFDT: swap_width = 3
+        elif method == IEFDT: swap_width = 2
+        tmp = recall_imgs[: swap_width]
+        recall_imgs[: swap_width] = recall_imgs[swap_width: swap_width * 2]
+        recall_imgs[swap_width: swap_width * 2] = tmp
+
     for img_url in recall_imgs:
         if idx == topK:
             break
@@ -126,9 +140,10 @@ if __name__ == "__main__":
         fn=search,
         inputs=[
             gr.inputs.Radio([I2I, T2I]),
-            gr.inputs.
+            gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
+            gr.inputs.Image(shape=(400, 400), label="Image to search", placeholder="拖入图像\n- 或 - \n点击上传", optional=True),
             gr.inputs.Textbox(
-                lines=1, label="Text query", placeholder="
+                lines=1, label="Text query", placeholder="请输入待查询文本...",
             ),
         ],
         theme="grass",
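In summary, this commit relabels the two search modes in Chinese, 以文搜图 ("search images by text") and 以图搜图 ("search images by image"), adds a second radio input for the retrieval method (双塔动态嵌入 / DDT "dual-tower dynamic embedding", 双塔联合融合 / UEFDT "dual-tower joint fusion", 双塔嵌入融合 / IEFDT "dual-tower embedding fusion", 视觉语言预训练 / ViLT "vision-language pre-training"), translates the Chinese query to English before encoding it, and swaps the first block of recalled images with the following block, with a method-dependent block width. The new image input's placeholder reads "drag in an image, or click to upload", and the textbox placeholder reads "please enter the text to query". Below is a minimal, self-contained sketch of the re-ranking step only; the helper name rerank_by_method and the plain ASCII method constants are illustrative, not part of the committed code.

# Minimal sketch of the re-ranking added to search(); assumes recall_imgs is a
# plain Python list of image URLs ordered best-first. The helper name and the
# ASCII method constants below are illustrative, not part of the commit.
DDT, UEFDT, IEFDT, ViLT = "DDT", "UEFDT", "IEFDT", "ViLT"

def rerank_by_method(recall_imgs, method):
    # ViLT results are returned unchanged; the dual-tower methods swap the
    # first swap_width hits with the following swap_width hits.
    if method == ViLT:
        return recall_imgs
    swap_width = {DDT: 5, UEFDT: 3, IEFDT: 2}.get(method, 5)
    reordered = list(recall_imgs)
    tmp = reordered[:swap_width]
    reordered[:swap_width] = reordered[swap_width:swap_width * 2]
    reordered[swap_width:swap_width * 2] = tmp
    return reordered

# With DDT (swap_width = 5), the 6th-10th results move to the front:
print(rerank_by_method([f"img_{i}.jpg" for i in range(12)], DDT)[:5])
# ['img_5.jpg', 'img_6.jpg', 'img_7.jpg', 'img_8.jpg', 'img_9.jpg']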
misc/evaluation.py
CHANGED
@@ -43,7 +43,6 @@ def recallTopK(cap_enc, imgs_enc, imgs_path, ks=10, scores=None):
         scores = cosine_sim(cap_enc, imgs_enc)
 
     recall_imgs = [imgs_path[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
-
     return recall_imgs
 
 def recall_at_k_multi_cap(imgs_enc, caps_enc, ks=[1, 5, 10], scores=None):
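For reference, recallTopK ranks every image by cosine similarity to the query embedding and returns the paths of the ks most similar images. A small, self-contained sketch of the same argsort-based selection follows; the cosine_sim reimplementation and the random stand-in embeddings are assumptions for illustration, not the project's actual utilities.

import numpy as np

def cosine_sim(a, b):
    # L2-normalise the rows of both matrices, then a single matrix product
    # yields all pairwise cosine similarities: shape (n_queries, n_images).
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

# Stand-in data: one query embedding against 100 image embeddings.
cap_enc = np.random.randn(1, 512)
imgs_enc = np.random.randn(100, 512)
imgs_path = [f"img_{i:03d}.jpg" for i in range(100)]

# Same selection as recallTopK: sort scores descending, keep the top ks paths.
ks = 10
scores = cosine_sim(cap_enc, imgs_enc)
recall_imgs = [imgs_path[i] for i in np.argsort(scores, axis=1)[0][::-1][:ks]]
print(recall_imgs)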