Video-Search / app.py
Diangle's picture
Update app.py
66cbe94
raw
history blame
6.3 kB
import gradio as gr
import os
import numpy as np
import pandas as pd
from IPython import display
import faiss
import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection
# Static page header: inline CSS plus a container holding the Searchium logo
# and the demo title. Passed to gr.HTML when the UI is built.
HTML="""
<!DOCTYPE html>
<html>
<style>
.container {
align-items: center;
justify-content: center;
}
img {
max-width: 10%;
max-height:10%;
float: left;
}
.text {
font-size: 32px;
padding-top: 15%;
padding-left: 15%;
padding-bottom: 5%;
float: left
}
</style>
<body>
<div class="container">
<div class="image">
<img src="https://huggingface.co/spaces/Searchium-ai/Video-Search/resolve/main/Searchium.png" width="333" height="216">
</div>
<div class="text">
<h1 style="font-size: 32px;"><b> Large Scale Video Search </b></h1>
</div>
</div>
</body>
</html>
"""
# Introductory text (Markdown with inline HTML tags) rendered below the header
# through gr.Markdown.
DESCRIPTION="""<b> Exciting News! </b> <br>
<b> We've added another 4 million video embeddings to our collection! </b> <br>
Welcome to our video retrieval demo powered by [Searchium-ai/clip4clip-webvid150k](https://huggingface.co/Searchium-ai/clip4clip-webvid150k)! <br>
Using free text search - you will find the top 5 most relevant clips among a dataset of <b> 5.5 million </b> video clips. <br>
Discover, explore, and enjoy the world of video search at your fingertips.
"""
# Closing Markdown line rendered at the end of the page through gr.Markdown.
ENDING = """For search acceleration capabilities, please refer to [Searchium.ai](https://www.searchium.ai)
"""
# Directory holding the precomputed video features and their metadata table.
DATA_PATH = './new_data'

# Locations of the three database artefacts.
ft_visual_features_file = f"{DATA_PATH}/video_half_dataset_visual_features.npy"
ft_visual_features_file_bin = f"{DATA_PATH}/video_half_dataset_visual_features_binary_packed.npy"
database_csv_path = os.path.join(DATA_PATH, 'video_half_dataset.csv')

# The bit-packed features are read fully into memory, while the original
# features are memory-mapped read-only so they are not loaded eagerly.
ft_visual_features_database_bin = np.load(ft_visual_features_file_bin)
ft_visual_features_database = np.load(ft_visual_features_file, mmap_mode='r')

# Per-video metadata; the search function reads its 'contentUrl' column.
database_df = pd.read_csv(database_csv_path)
class NearestNeighbors:
    """
    Nearest-neighbour search over a fixed database, backed by faiss.

    Supported metrics:

    * ``'cosine'`` - inner-product search over L2-normalised vectors.
    * ``'binary'`` - search over bit-packed vectors with a faiss binary
      index, optionally followed by a cosine rerank of the best
      ``rerank_from`` candidates against the original vectors.
    """

    def __init__(self, n_neighbors=10, metric='cosine', rerank_from=-1):
        """
        Store the search configuration; no index is built until ``fit``.

        n_neighbors: number of neighbours returned per query.
        metric: 'cosine' or 'binary'.
        rerank_from: for the 'binary' metric only. When greater than
            ``n_neighbors``, that many candidates are fetched from the binary
            index and reranked by cosine similarity.
        """
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.rerank_from = rerank_from

    def normalize(self, a):
        """Return ``a`` with every row scaled to unit L2 norm."""
        # Divide by the norm itself, not by the squared norm: only then does
        # an inner product between normalised rows equal cosine similarity.
        return a / np.linalg.norm(a, axis=1, keepdims=True)

    def fit(self, data, o_data=None):
        """
        Build the index over ``data``.

        data: 2-D array. For 'cosine' these are the raw vectors; for 'binary'
            they are assumed to be already bit-packed (8 dimensions per byte).
        o_data: for 'binary' only, the original vectors used by the cosine
            rerank. Falls back to ``data`` when omitted.

        Raises ValueError for an unsupported metric.
        """
        if self.metric == 'cosine':
            # faiss float indexes operate on contiguous float32 arrays.
            data = np.ascontiguousarray(self.normalize(data), dtype=np.float32)
            self.index = faiss.IndexFlatIP(data.shape[1])
        elif self.metric == 'binary':
            self.o_data = data if o_data is None else o_data
            data = np.ascontiguousarray(data)
            # The binary index dimension is expressed in bits, not bytes.
            self.index = faiss.IndexBinaryFlat(data.shape[1] * 8)
        else:
            raise ValueError(
                "Unsupported metric {!r}; expected 'cosine' or 'binary'".format(self.metric)
            )
        self.index.add(data)

    def kneighbors(self, q_data):
        """
        Search the fitted index for the neighbours of each row of ``q_data``.

        Returns ``(sim, idx)``, each with one row per query. With the
        'binary' metric and rerank enabled, ``sim`` holds the rerank scores
        and ``idx`` holds integer row indices into the database; otherwise
        both arrays are returned as produced by the faiss index search.

        Raises ValueError for an unsupported metric.
        """
        if self.metric == 'cosine':
            queries = np.ascontiguousarray(self.normalize(q_data), dtype=np.float32)
            return self.index.search(queries, self.n_neighbors)

        if self.metric != 'binary':
            raise ValueError(
                "Unsupported metric {!r}; expected 'cosine' or 'binary'".format(self.metric)
            )

        # Binarise the query by sign and pack 8 bits per byte, mirroring the
        # packed layout the database is assumed to use.
        bq_data = np.packbits(np.asarray(q_data) > 0.0, axis=1)
        sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))

        if self.rerank_from <= self.n_neighbors:
            return sim, idx

        # Rerank: score each query's candidates by cosine similarity against
        # the original vectors and keep the best n_neighbors.
        re_sims = np.zeros([len(q_data), self.n_neighbors], dtype=float)
        # Indices are stored as integers so they can be used for indexing.
        re_idxs = np.zeros([len(q_data), self.n_neighbors], dtype=np.int64)
        for i, q in enumerate(q_data):
            candidates = idx[i]
            rerank_search = NearestNeighbors(n_neighbors=self.n_neighbors, metric='cosine')
            rerank_search.fit(self.o_data[candidates])
            re_sim, re_idx = rerank_search.kneighbors(np.asarray([q]))
            re_sims[i, :] = re_sim[0]
            # re_idx is relative to the candidate list; map it back to
            # database row indices.
            re_idxs[i, :] = candidates[re_idx[0]]
        return re_sims, re_idxs
# Tokenizer and text encoder are loaded from the same pretrained checkpoint.
tokenizer = CLIPTokenizer.from_pretrained("Searchium-ai/clip4clip-webvid150k")
model = CLIPTextModelWithProjection.from_pretrained("Searchium-ai/clip4clip-webvid150k")

# Search setup: fetch 100 candidates from the binary index, then keep the
# 5 best after a cosine rerank against the original features.
nn_search = NearestNeighbors(metric='binary', n_neighbors=5, rerank_from=100)
nn_search.fit(ft_visual_features_database_bin, o_data=ft_visual_features_database)
def search(search_sentence):
    """
    Return HTML video snippets for the clips that best match a text query.

    The sentence is tokenized and embedded with the module-level text model,
    the embedding is L2-normalised and handed to ``nn_search``, and the
    resulting row indices are used to read the 'contentUrl' column of
    ``database_df``.

    search_sentence: free-text query.
    Returns a list with one ``<video>`` HTML string per retrieved clip.
    """
    from html import escape

    # Truncate to the text encoder's position limit so that an over-long
    # query is shortened instead of failing inside the model.
    inputs = tokenizer(
        text=search_sentence,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Normalizing the embeddings:
    final_output = outputs[0] / outputs[0].norm(dim=-1, keepdim=True)
    sequence_output = final_output.cpu().detach().numpy()

    sims, idxs = nn_search.kneighbors(sequence_output)
    # Positional indexing needs integer indices, whatever dtype the search returns.
    row_ids = np.asarray(idxs[0], dtype=np.int64)
    urls = database_df.iloc[row_ids]['contentUrl'].to_list()

    # Quote and escape the URL so characters such as spaces, quotes or '>'
    # cannot break out of the src attribute.
    return [
        '<video controls muted autoplay>\n'
        '<source src="{}" type="video/mp4">\n'
        '</video>'.format(escape(str(url), quote=True))
        for url in urls
    ]
# Page layout: header and description on top, then a row with the query
# controls on the left and five result slots on the right, then the footer.
# NOTE(review): the nesting below was reconstructed from syntax because the
# original indentation was not available; confirm the depth of the click
# binding and the footer against the deployed layout.
with gr.Blocks(
    theme=gr.themes.Default(
        spacing_size=gr.themes.sizes.spacing_lg,
        radius_size=gr.themes.sizes.radius_lg,
        text_size=gr.themes.sizes.text_lg,
    )
) as demo:
    gr.HTML(HTML)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Query box, search button and a few clickable example queries.
            inp = gr.Textbox(placeholder="Write a sentence.")
            btn = gr.Button(value="Search")
            ex = [
                ["natural wonders of the world"],
                ["yoga routines for morning energy"],
                ["baking chocolate cake"],
                ["birds fly in the sky"],
            ]
            gr.Examples(examples=ex, inputs=[inp])
        with gr.Column():
            # One HTML slot per returned clip.
            out = [gr.HTML() for _ in range(5)]
    # Clicking the button runs the search and fills the result slots.
    btn.click(search, inputs=inp, outputs=out)
    gr.Markdown(ENDING)

demo.launch()