Spaces:
Build error
Build error
transform
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ from misc.dataset import TextEncoder
|
|
33 |
import requests
|
34 |
from io import BytesIO
|
35 |
from translate import Translator
|
36 |
-
|
37 |
|
38 |
device = torch.device("cpu")
|
39 |
batch_size = 1
|
@@ -74,13 +74,13 @@ def search(mode, image, text):
|
|
74 |
_stack = np.vstack(caps_enc)
|
75 |
|
76 |
elif mode == I2I:
|
77 |
-
dataset = torch.Tensor(image).unsqueeze(dim=0)
|
78 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
79 |
img_enc = list()
|
80 |
for i, (imgs, length) in enumerate(dataset_loader, 0):
|
81 |
input_imgs = imgs
|
82 |
with torch.no_grad():
|
83 |
-
|
84 |
img_enc.append(output_emb)
|
85 |
_stack = np.vstack(img_enc)
|
86 |
|
@@ -118,12 +118,15 @@ if __name__ == "__main__":
|
|
118 |
imgs_emb_file_path = "./coco_img_emb"
|
119 |
imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
|
120 |
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
|
|
|
|
|
|
|
121 |
print("prepare done!")
|
122 |
iface = gr.Interface(
|
123 |
fn=search,
|
124 |
inputs=[
|
125 |
gr.inputs.Radio([I2I, T2I]),
|
126 |
-
gr.inputs.Image(shape=(
|
127 |
gr.inputs.Textbox(
|
128 |
lines=1, label="Text query", placeholder="Introduce the search text...",
|
129 |
),
|
|
|
33 |
import requests
|
34 |
from io import BytesIO
|
35 |
from translate import Translator
|
36 |
+
from torchvision import transforms
|
37 |
|
38 |
device = torch.device("cpu")
|
39 |
batch_size = 1
|
|
|
74 |
_stack = np.vstack(caps_enc)
|
75 |
|
76 |
elif mode == I2I:
|
77 |
+
dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
|
78 |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
|
79 |
img_enc = list()
|
80 |
for i, (imgs, length) in enumerate(dataset_loader, 0):
|
81 |
input_imgs = imgs
|
82 |
with torch.no_grad():
|
83 |
+
output_emb, _ = join_emb(input_imgs, None, None)
|
84 |
img_enc.append(output_emb)
|
85 |
_stack = np.vstack(img_enc)
|
86 |
|
|
|
118 |
imgs_emb_file_path = "./coco_img_emb"
|
119 |
imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
|
120 |
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
|
121 |
+
|
122 |
+
normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
|
123 |
+
|
124 |
print("prepare done!")
|
125 |
iface = gr.Interface(
|
126 |
fn=search,
|
127 |
inputs=[
|
128 |
gr.inputs.Radio([I2I, T2I]),
|
129 |
+
gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
|
130 |
gr.inputs.Textbox(
|
131 |
lines=1, label="Text query", placeholder="Introduce the search text...",
|
132 |
),
|