# Video-Search / app.py
# Hugging Face Space by Diangle — commit eec4792 ("Update app.py"), 5.07 kB
import functools
import os

import faiss
import gradio
import numpy as np
import pandas as pd
import torch
from IPython import display
from transformers import AutoTokenizer, CLIPTextModelWithProjection
# Root directory holding the precomputed feature files and the dataset metadata CSV.
DATA_PATH = '/data'

# Float visual embeddings used by search() (binarised for the Hamming index and kept for reranking).
ft_visual_features_file = f'{DATA_PATH}/features/features_Cl4Cl_ckpt_webvid_retrieval_looseType_bs26_gpus2_lr7_150k_finalsample/dataset_v1_visual_features_database.npy'
# Pre-packed binary features; loaded here but not referenced anywhere else in this file.
binary_visual_features_file = f'{DATA_PATH}/features/features_Cl4Cl_ckpt_webvid_retrieval_looseType_bs26_gpus2_lr7_150k_finalsample_binary20/dataset_v1_visual_features_database_packed.npy'

ft_visual_features_database = np.load(ft_visual_features_file)
binary_visual_features = np.load(binary_visual_features_file)

# Video metadata; search() reads its 'contentUrl' column by the row indices returned
# from the feature index (so rows are presumably aligned with the feature rows — verify).
database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
database_df = pd.read_csv(database_csv_path)
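# Notebook (IPython) helper. The Gradio interface below does not need it: gradio.Video outputs accept the URLs that search() returns directly.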
def build_videos_html(display_df):
    """Return an HTML snippet with one muted, looping, autoplaying <video> per row.

    display_df: DataFrame with a 'contentUrl' column (video URL) and a 'name'
        column (caption text).

    URLs and captions are HTML-escaped, so characters such as '&', '<' or '"'
    in a caption or query string render literally instead of breaking the markup.
    An empty DataFrame yields an empty string.
    """
    from html import escape

    return ''.join(
        '<video autoplay loop muted> <source src="{}" type="video/mp4"> </video> '
        '<div class="caption">{}</div><br/>'.format(escape(str(path)), escape(str(text)))
        for path, text in zip(display_df['contentUrl'], display_df['name'])
    )


def display_videos(display_df):
    """Wrap the videos in `display_df` in an IPython HTML object for inline notebook display."""
    return display.HTML(build_videos_html(display_df))
class NearestNeighbors:
    """
    Exact k-nearest-neighbour search backed by a flat (brute-force) FAISS index.

    metric = 'cosine' / 'binary'
      * 'cosine': vectors are L2-normalised and searched by inner product
        (faiss.IndexFlatIP), so returned scores are cosine similarities.
      * 'binary': `fit` expects bit-packed codes and searches by Hamming distance
        (faiss.IndexBinaryFlat); queries are binarised by sign (> 0) and packed.
        If metric != 'cosine' and rerank_from > n_neighbors, the top `rerank_from`
        Hamming candidates are re-scored by cosine similarity against the float
        vectors given to `fit` as `o_data`, and the best `n_neighbors` are returned.
    """

    _METRICS = ('cosine', 'binary')

    def __init__(self, n_neighbors=10, metric='cosine', rerank_from=-1):
        # Fail fast: an unknown metric would otherwise only surface later in fit().
        if metric not in self._METRICS:
            raise ValueError(f"metric must be one of {self._METRICS}, got {metric!r}")
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.rerank_from = rerank_from

    def normalize(self, a):
        """Return `a` scaled so every vector along the last axis has unit L2 norm.

        All-zero vectors stay zero instead of becoming NaN.
        """
        a = np.asarray(a)
        norms = np.linalg.norm(a, axis=-1, keepdims=True)
        # Divide by the norm itself (not the squared norm) so vectors end up unit length.
        return a / np.where(norms == 0, 1, norms)

    def fit(self, data, o_data=None):
        """Build the FAISS index over `data`.

        data: for 'cosine', float vectors of shape (n_items, dim); for 'binary',
            bit-packed uint8 codes of shape (n_items, n_bits // 8).
        o_data: 'binary' only — the original float vectors, row-aligned with `data`,
            used for the cosine rerank. Only required when a rerank is requested.
        """
        if self.metric == 'cosine':
            # FAISS float indexes operate on contiguous float32 arrays.
            data = np.ascontiguousarray(self.normalize(data), dtype=np.float32)
            self.index = faiss.IndexFlatIP(data.shape[1])
        else:
            # 'binary': data is assumed to be already bit-packed, 8 bits per byte.
            self.o_data = o_data
            data = np.ascontiguousarray(data)
            self.index = faiss.IndexBinaryFlat(data.shape[1] * 8)
        self.index.add(data)

    def kneighbors(self, q_data):
        """Return (scores, indices) of the nearest indexed rows for each query row.

        q_data: float query vectors of shape (n_queries, dim).

        Scores are cosine similarities (higher is better) for 'cosine' and for the
        reranked 'binary' path; for 'binary' without a rerank they are Hamming
        distances (lower is better). The number of columns is capped at the number
        of indexed rows so no padding ids (-1) are ever returned.
        """
        if self.metric == 'cosine':
            q = np.ascontiguousarray(self.normalize(q_data), dtype=np.float32)
            return self.index.search(q, min(self.n_neighbors, self.index.ntotal))

        # 'binary': binarise queries by sign and pack 8 bits per byte.
        q_data = np.asarray(q_data)
        bq_data = np.packbits(q_data > 0.0, axis=1)
        k = min(max(self.rerank_from, self.n_neighbors), self.index.ntotal)
        dist, idx = self.index.search(bq_data, k)
        if self.rerank_from <= self.n_neighbors:
            return dist, idx
        return self._rerank(q_data, idx)

    def _rerank(self, q_data, idx):
        """Re-score candidate rows `idx` (n_queries, k) by cosine similarity; keep the top n_neighbors."""
        if getattr(self, 'o_data', None) is None:
            raise ValueError("cosine rerank requires fit(..., o_data=<float vectors>)")
        queries = self.normalize(np.asarray(q_data, dtype=float))                      # (n_queries, dim)
        candidates = self.normalize(np.asarray(self.o_data)[idx].astype(float, copy=False))  # (n_queries, k, dim)
        sims = np.einsum('qd,qkd->qk', queries, candidates)
        # Descending similarity per query; stable so equal scores keep Hamming order.
        order = np.argsort(-sims, axis=1, kind='stable')[:, :self.n_neighbors]
        return np.take_along_axis(sims, order, axis=1), np.take_along_axis(idx, order, axis=1)
@functools.lru_cache(maxsize=1)
def _load_text_encoder():
    """Load the CLIP4Clip text encoder and its tokenizer once per process."""
    model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
    tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
    return model, tokenizer


@functools.lru_cache(maxsize=1)
def _load_video_index():
    """Build the binary search index over the video features once per process.

    Database codes are the sign bits (> 0) of the float features, packed 8 per byte;
    the float features are passed along for reranking the top 100 Hamming candidates.
    """
    nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
    packed = np.packbits(ft_visual_features_database > 0.0, axis=1)
    nn_search.fit(packed, o_data=ft_visual_features_database)
    return nn_search


def search(search_sentence):
    """Return the 'contentUrl' of the 5 database videos that best match `search_sentence`.

    The sentence is embedded with the CLIP4Clip text encoder, then searched against
    the sign-binarised video features by Hamming distance, with the top 100
    candidates reranked against the float features (see NearestNeighbors).
    The model and the index are built on the first call and reused afterwards.
    """
    model, tokenizer = _load_text_encoder()
    inputs = tokenizer(text=search_sentence, return_tensors="pt", padding=True)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict=False)
        # NOTE(review): outputs[1] is presumably the per-token last hidden state; it is multiplied
        # by the raw projection weight (not its transpose). Kept as-is since it presumably matches
        # how the CLIP4Clip checkpoint was converted — confirm before switching to outputs[0].
        text_embeds = outputs[1] @ model.text_projection.weight
        # Take each sentence's embedding at the position of its largest token id
        # (presumably the end-of-text token for the CLIP tokenizer — verify).
        final_output = text_embeds[torch.arange(text_embeds.shape[0]), inputs["input_ids"].argmax(dim=-1)]
        final_output = final_output / final_output.norm(dim=-1, keepdim=True)
    query = final_output.cpu().numpy()
    _, idxs = _load_video_index().kneighbors(query)
    return database_df.iloc[idxs[0]]['contentUrl'].to_list()
# Close any Gradio apps still running in this process before starting a new one.
gradio.close_all()

# One text box in; five video players out, matching the 5 URLs that search() returns.
interface = gradio.Interface(
    search,
    inputs=[gradio.Textbox()],
    outputs=[gradio.Video(format='mp4') for _ in range(5)],
    title='Video Search Demo',
    description='Type some text to search by content within a video database!',
)
# Launch separately so `interface` stays bound to the Interface object, not launch()'s return value.
interface.launch()