Video-Search / app.py
Diangle's picture
Update app.py
8c3fa1e
raw
history blame
6.33 kB
import gradio as gr
import os
import numpy as np
import pandas as pd
from IPython import display
import faiss
import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection
# Page header markup handed to gr.HTML below: the Searchium logo floated left,
# followed by the demo title.
HTML="""
<!DOCTYPE html>
<html>
<style>
.container {
align-items: center;
justify-content: center;
}
img {
max-width: 10%;
max-height:10%;
float: left;
}
.text {
font-size: 32px;
padding-top: 15%;
padding-left: 15%;
padding-bottom: 5%;
float: left
}
</style>
<body>
<div class="container">
<div class="image">
<img src="https://huggingface.co/spaces/Searchium-ai/Video-Search/resolve/main/Searchium.png" width="333" height="216">
</div>
<div class="text">
<h1 style="font-size: 32px;"> Large Scale Video Search </h1>
</div>
</div>
</body>
</html>
"""
# Introductory Markdown rendered under the header via gr.Markdown.
# NOTE(review): the linked model name says "webvid150k" while the text says
# "1.5 million video clips" - confirm which figure describes the dataset.
DESCRIPTION="""This space is currently under development to become even more impressive!
Welcome to our video retrieval demo powered by [Searchium-ai/clip4clip-webvid150k](https://huggingface.co/Searchium-ai/clip4clip-webvid150k)! <br>
Using free text search - you will find the top 5 most relevant clips among a dataset of 1.5 million video clips. <br>
Discover, explore, and enjoy the world of video search at your fingertips.
"""
# Closing Markdown rendered at the bottom of the page via gr.Markdown.
ENDING = """For search acceleration capabilities, please refer to [Searchium.ai](https://www.searchium.ai)
"""
# Directory holding the precomputed dataset artefacts.
DATA_PATH = './data'
ft_visual_features_file = DATA_PATH + '/dataset_v1_visual_features.npy'
ft_visual_features_file_bin = DATA_PATH + '/dataset_v1_visual_features_binary_packed.npy'

# The bit-packed binary codes are loaded fully into memory.
ft_visual_features_database_bin = np.load(ft_visual_features_file_bin)

# The float32 features are memory-mapped read-only instead of being loaded.
# Their shape is derived from the packed codes: same number of rows, and eight
# float columns for every packed byte.
# NOTE(review): offset=128 presumably skips the .npy file header - confirm
# that the header of this file really is 128 bytes long.
_feature_rows = ft_visual_features_database_bin.shape[0]
_feature_cols = ft_visual_features_database_bin.shape[1] * 8
ft_visual_features_database = np.memmap(
    ft_visual_features_file,
    dtype='float32',
    mode='r',
    offset=128,
    shape=(_feature_rows, _feature_cols),
)

# Per-clip metadata table, read from the CSV that sits next to the features.
database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
database_df = pd.read_csv(database_csv_path)
class NearestNeighbors:
    """
    Nearest-neighbour search on top of a FAISS index.

    Supported metrics:

    * 'cosine' - inner-product search (faiss.IndexFlatIP) over vectors that
      are scaled to unit L2 norm first.
    * 'binary' - search over bit-packed codes (faiss.IndexBinaryFlat),
      optionally followed by a cosine rerank of the best candidates using the
      original (unpacked) vectors.
    """

    _SUPPORTED_METRICS = ('cosine', 'binary')

    def __init__(self, n_neighbors=10, metric='cosine', rerank_from=-1):
        """
        n_neighbors: number of neighbours returned per query.
        metric = 'cosine' / 'binary'
        rerank_from: if metric != 'cosine' and rerank_from > n_neighbors then
            rerank_from candidates are retrieved first and a cosine rerank
            reduces them to n_neighbors.

        Raises ValueError for an unsupported metric, instead of failing later
        inside fit() / kneighbors() with an unrelated error.
        """
        if metric not in self._SUPPORTED_METRICS:
            raise ValueError(
                "metric must be one of {}, got {!r}".format(self._SUPPORTED_METRICS, metric)
            )
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.rerank_from = rerank_from

    def normalize(self, a):
        """
        Return the rows of the 2-D array `a` scaled to unit L2 norm.

        All-zero rows are returned unchanged rather than turned into NaN.
        """
        norms = np.linalg.norm(a, axis=1, keepdims=True)
        # Dividing by 1 leaves a zero row as it is and avoids 0/0.
        norms = np.where(norms == 0, 1.0, norms)
        return a / norms

    def fit(self, data, o_data=None):
        """
        Build the index from `data`.

        For 'cosine', `data` holds float vectors, one per row. For 'binary',
        `data` holds codes that are already bit-packed (8 bits per byte) and
        `o_data` holds the vectors used for the optional cosine rerank; when
        `o_data` is omitted, `data` itself is kept for that purpose.
        """
        if self.metric == 'cosine':
            data = self.normalize(data)
            self.index = faiss.IndexFlatIP(data.shape[1])
        else:
            # 'binary' (the metric was validated in __init__).
            self.o_data = data if o_data is None else o_data
            # assuming data already packed
            self.index = faiss.IndexBinaryFlat(data.shape[1] * 8)
        self.index.add(np.ascontiguousarray(data))

    def kneighbors(self, q_data):
        """
        Search the index for every row of the 2-D float array `q_data`.

        Returns `(sim, idx)`, both with one row per query: `idx` holds row
        indices into the fitted data and `sim` the matching scores reported
        by the index (or by the cosine rerank when one is performed).
        """
        if self.metric == 'cosine':
            q_data = self.normalize(q_data)
            sim, idx = self.index.search(q_data, self.n_neighbors)
            return sim, idx

        print('This is binary search.')
        # Binarise each query by sign and pack the bits like the database codes.
        bq_data = np.packbits(q_data > 0.0, axis=1)
        sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))

        if self.rerank_from > self.n_neighbors:
            re_sims = np.zeros([len(q_data), self.n_neighbors], dtype=float)
            # Indices are integers; keeping them as int64 lets callers use
            # them directly for positional indexing.
            re_idxs = np.zeros([len(q_data), self.n_neighbors], dtype=np.int64)
            for i, q in enumerate(q_data):
                # Rerank only this query's candidates with an exact cosine search.
                rerank_data = self.o_data[idx[i]]
                rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
                rerank_search.fit(rerank_data)
                re_sim, re_idx = rerank_search.kneighbors(np.asarray([q]))
                print("re_idx: ", re_idx)
                re_sims[i, :] = re_sim[0]
                # Map positions within the candidate list back to database rows.
                re_idxs[i, :] = idx[i][re_idx[0]]
            idx = re_idxs
            sim = re_sims
        return sim, idx
# Text encoder and tokenizer are loaded from the same pretrained checkpoint.
_CHECKPOINT = "Searchium-ai/clip4clip-webvid150k"
model = CLIPTextModelWithProjection.from_pretrained(_CHECKPOINT)
tokenizer = CLIPTokenizer.from_pretrained(_CHECKPOINT)

# Binary search retrieves 100 candidates per query; a cosine rerank over the
# float features then keeps the best 5.
nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
nn_search.fit(
    ft_visual_features_database_bin,
    o_data=ft_visual_features_database,
)
def search(search_sentence):
    """
    Find the video clips that best match a free-text query.

    The sentence is tokenized and encoded with the module-level `tokenizer`
    and `model`, the embedding is scaled to unit length, and `nn_search`
    returns the nearest database rows. The 'contentUrl' of each row in
    `database_df` is wrapped in an autoplaying, muted HTML <video> element.

    search_sentence: query text; None is treated as an empty string.
    Returns a list of HTML snippets, one per neighbour returned by `nn_search`.
    """
    # Local import so the function carries its own dependency.
    from html import escape

    if search_sentence is None:
        search_sentence = ""

    # Truncate to the text encoder's position limit; longer queries would
    # otherwise fail inside the model.
    inputs = tokenizer(
        text=search_sentence,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Normalizing the embeddings:
    text_embeds = outputs[0]
    final_output = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    sequence_output = final_output.cpu().detach().numpy()

    sims, idxs = nn_search.kneighbors(sequence_output)

    # The index array may come back as floats; .iloc needs integer positions.
    row_ids = np.asarray(idxs[0]).astype(np.int64)
    urls = database_df.iloc[row_ids]['contentUrl'].to_list()

    # Escape and quote the URL so characters such as spaces, quotes or '&'
    # cannot break the src attribute.
    video_template = (
        '<video controls muted autoplay>\n'
        '<source src="{}" type="video/mp4">\n'
        '</video>'
    )
    return [video_template.format(escape(str(url), quote=True)) for url in urls]
# Theme: the default Gradio theme with large spacing, corner radius and text.
_theme = gr.themes.Default(
    spacing_size=gr.themes.sizes.spacing_lg,
    radius_size=gr.themes.sizes.radius_lg,
    text_size=gr.themes.sizes.text_lg,
)

# Example queries offered under the search box (one single-input row each).
ex = [
    ["natural wonders of the world"],
    ["baking chocolate cake"],
    ["birds fly in the sky"],
]

with gr.Blocks(theme=_theme) as demo:
    gr.HTML(HTML)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: query box, search button and example queries.
        with gr.Column():
            inp = gr.Textbox(placeholder="Write a sentence.")
            btn = gr.Button(value="Search")
            gr.Examples(examples=ex, inputs=[inp])
        # Right column: one HTML slot per returned video.
        with gr.Column():
            out = []
            for _ in range(5):
                out.append(gr.HTML())
        # Clicking the button runs the search and fills the five slots.
        btn.click(search, inputs=inp, outputs=out)
    gr.Markdown(ENDING)

demo.launch()