Spaces: Diangle/Clip4Clip-webvid (status: Runtime error)

Commit: Update app.py

app.py CHANGED
@@ -8,7 +8,7 @@ import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection
 
 
-TITLE="""<h1 style="font-size:
+TITLE="""<h1 style="font-size: 64px;" align="center">Video Retrieval</h1>"""
 
 DESCRIPTION="""This is a video retrieval demo using [Diangle/clip4clip-webvid](https://huggingface.co/Diangle/clip4clip-webvid)."""
 IMAGE='<div style="text-align: right;"><img src="https://huggingface.co/spaces/Diangle/Clip4Clip-webvid/resolve/main/Searchium.png" width="333" height="216"/>'
@@ -23,7 +23,6 @@ ft_visual_features_database = np.load(ft_visual_features_file)
 database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
 database_df = pd.read_csv(database_csv_path)
 
-
 class NearestNeighbors:
     """
     Class for NearestNeighbors.
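For context, the only field the rest of this commit relies on from `dataset_v1.csv` is the `contentUrl` column, which `search()` below reads by row position to build the video players. A quick sanity check, assuming the CSV lives under the app's `DATA_PATH` (the path and column name come from the code above; everything else here is illustrative):

```python
# Hypothetical sanity check for the WebVid metadata CSV. DATA_PATH and the
# 'contentUrl' column name come from the surrounding code; nothing else does.
import os
import pandas as pd

DATA_PATH = 'data'  # assumption: whatever the app actually sets DATA_PATH to
df = pd.read_csv(os.path.join(DATA_PATH, 'dataset_v1.csv'))
assert 'contentUrl' in df.columns  # search() indexes this column with .iloc
print(df['contentUrl'].head())
```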
@@ -56,28 +55,33 @@ class NearestNeighbors:
             sim, idx = self.index.search(q_data, self.n_neighbors)
         else:
             if self.metric == 'binary':
-                print('binary search')
-                bq_data = np.packbits((q_data > 0.0).astype(bool), axis=1)
-                print(bq_data.shape, self.index.d)
+                print('This is binary search.')
+                bq_data = np.packbits((q_data > 0.0).astype(bool), axis=1)
                 sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))
 
         if self.rerank_from > self.n_neighbors:
-
-
-
-
-
-
+            re_sims = np.zeros([len(q_data), self.n_neighbors], dtype=float)
+            re_idxs = np.zeros([len(q_data), self.n_neighbors], dtype=float)
+            for i, q in enumerate(q_data):
+                rerank_data = self.o_data[idx[i]]
+                rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
+                rerank_search.fit(rerank_data)
+                re_sim, re_idx = rerank_search.kneighbors(np.asarray([q]))
+                re_sims[i, :] = re_sim
+                re_idxs[i, :] = idx[i][re_idx]
+            idx = re_idxs
+            sim = re_sims
+
         return sim, idx
 
+
 model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
-tokenizer =
+tokenizer = CLIPTokenizer.from_pretrained("Diangle/clip4clip-webvid")
 
 def search(search_sentence):
     inputs = tokenizer(text=search_sentence , return_tensors="pt", padding=True)
 
-    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)
-    # Customized projection layer
+    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)
     text_projection = model.state_dict()['text_projection.weight']
     text_embeds = outputs[1] @ text_projection
     final_output = text_embeds[torch.arange(text_embeds.shape[0]), inputs["input_ids"].argmax(dim=-1)]
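Two notes on this hunk. The new tokenizer line uses `CLIPTokenizer`, but the import at the top of the file (unchanged above) only brings in `AutoTokenizer` and `CLIPTextModelWithProjection`; unless `CLIPTokenizer` is imported elsewhere in the file, this line raises a `NameError` at startup, which would be consistent with the Space's "Runtime error" status (`AutoTokenizer.from_pretrained` on the same checkpoint would also work). The `final_output` indexing pools the embedding at `input_ids.argmax(dim=-1)`, i.e. at the end-of-text token, which carries the highest token id in CLIP's vocabulary.

The rerank branch itself is a two-stage retrieval: a coarse Hamming search over sign-binarized, bit-packed features, then an exact cosine rerank of the top `rerank_from` candidates. Below is a minimal self-contained sketch of the same pattern, assuming a FAISS binary index (the `self.index.d` attribute above suggests FAISS); the function and variable names are illustrative, not the app's:

```python
# A minimal sketch of the two-stage search the diff implements; all names
# here are illustrative. Assumes FAISS and d divisible by 8.
import numpy as np
import faiss

def two_stage_search(db, query, n_neighbors=5, rerank_from=100):
    """db: (N, d) float32 features; query: (d,) float32."""
    d = db.shape[1]
    # Stage 1: coarse Hamming search over sign-binarized, bit-packed features.
    bindex = faiss.IndexBinaryFlat(d)
    bindex.add(np.packbits(db > 0.0, axis=1))
    _, coarse = bindex.search(np.packbits(query[None, :] > 0.0, axis=1), rerank_from)
    # Stage 2: exact cosine rerank of the shortlisted original vectors.
    cand = db[coarse[0]].copy()
    faiss.normalize_L2(cand)      # in-place L2 normalization
    q = query[None, :].copy()
    faiss.normalize_L2(q)
    sims = cand @ q[0]            # cosine similarity after normalization
    order = np.argsort(-sims)[:n_neighbors]
    return sims[order], coarse[0][order]

rng = np.random.default_rng(0)
db = rng.standard_normal((1000, 512)).astype(np.float32)
sims, idxs = two_stage_search(db, db[42])
print(idxs)  # idxs[0] == 42: the query's own row ranks first
```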
@@ -89,8 +93,15 @@ def search(search_sentence):
 
     nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
     nn_search.fit(np.packbits((ft_visual_features_database > 0.0).astype(bool), axis=1), o_data=ft_visual_features_database)
-    sims, idxs = nn_search.kneighbors(sequence_output)
-
+    sims, idxs = nn_search.kneighbors(sequence_output)
+    # print(database_df.iloc[idxs[0]]['contentUrl'])
+    urls = database_df.iloc[idxs[0]]['contentUrl'].to_list()
+    AUTOPLAY_VIDEOS = []
+    for url in urls:
+        AUTOPLAY_VIDEOS.append("""<video controls muted autoplay>
+    <source src={} type="video/mp4">
+    </video>""".format(url))
+    return AUTOPLAY_VIDEOS
 
 
 with gr.Blocks() as demo:
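Two review notes on this hunk. First, `re_idxs` in the rerank branch above is allocated with `dtype=float`, so `idxs` arrives here as a float array, and `database_df.iloc` rejects non-integer indexers; a cast is needed (or allocate `re_idxs` as int upstream). Second, the `src` attribute is interpolated without quotes, which breaks for URLs containing spaces or quote characters. A suggested variant, hypothetical and not part of the commit:

```python
# Suggested variant (not in the commit): cast the neighbor indices to int
# for .iloc and quote the src attribute in the generated HTML.
urls = database_df.iloc[idxs[0].astype(int)]['contentUrl'].to_list()
AUTOPLAY_VIDEOS = [
    '<video controls muted autoplay><source src="{}" type="video/mp4"></video>'.format(url)
    for url in urls
]
```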
@@ -102,12 +113,13 @@ with gr.Blocks() as demo:
         with gr.Column():
             inp = gr.Textbox(placeholder="Write a sentence.")
             btn = gr.Button(value="Retrieve")
-            ex = [["
+            ex = [["mind-blowing magic tricks"],["baking chocolate cake"],
+                  ["birds fly in the sky"], ["natural wonders of the world"]]
             gr.Examples(examples=ex,
-
-
+                        inputs=[inp]
+                        )
         with gr.Column():
-            out = [gr.
+            out = [gr.HTML() for _ in range(5)]
     btn.click(search, inputs=inp, outputs=out)
 
 demo.launch()
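The five `gr.HTML()` outputs match the five HTML strings `search()` returns (`n_neighbors=5`), since Gradio unpacks a returned list across a list of output components. A minimal runnable sketch of that fan-out wiring, with a stub standing in for the real `search`:

```python
# Minimal sketch of the Blocks wiring used above. fake_search is a stub;
# the real app returns five <video> snippets from search().
import gradio as gr

def fake_search(sentence):
    return ["<p>result {} for: {}</p>".format(i, sentence) for i in range(5)]

with gr.Blocks() as demo:
    inp = gr.Textbox(placeholder="Write a sentence.")
    btn = gr.Button(value="Retrieve")
    out = [gr.HTML() for _ in range(5)]
    btn.click(fake_search, inputs=inp, outputs=out)

demo.launch()
```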