"""Gradio demo: find the segments of a video that best match a text query.

The uploaded video is cut into fixed-length segments, text is generated for
the segments with ``Inference.generate_texts``, and the generated texts are
ranked against the query by cosine similarity of their sentence embeddings.
"""
from datetime import timedelta

import gradio as gr
from sentence_transformers import SentenceTransformer
import torchvision
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

from inference import Inference
import utils

# Passed to utils.video2image when converting each segment to pixel values.
encoder_model_name = 'google/vit-large-patch32-224-in21k'
# Passed to Inference, which generates the text for each segment.
decoder_model_name = 'gpt2-large'
# Number of frames in each video segment (the last segment may be shorter).
frame_step = 300
# Number of best-matching segments exported as clips; one video output each.
num_top_segments = 3

# Both models are loaded once at import time and shared by every request.
inference = Inference(
    decoder_model_name=decoder_model_name,
)
model = SentenceTransformer('all-mpnet-base-v2')


def search_in_video(video, query):
    """Rank the segments of ``video`` by similarity to ``query``.

    Args:
        video: Video source handed straight to ``torchvision.io.read_video``
            (presumably the file path supplied by the Gradio 'video' input).
        query: Free-text description to search for.

    Returns:
        A 4-tuple ``(top1, top2, top3, labels_to_scores)``. The first three
        items are the paths of the clips written for the best, second-best
        and third-best segments, or ``None`` for slots the video is too short
        to fill. ``labels_to_scores`` maps each segment's ``'start:end'``
        time label to its cosine similarity with the query, in ascending
        score order.

    Raises:
        ValueError: If no frames could be decoded from ``video``.
    """
    frames, _, info = torchvision.io.read_video(video)
    total_frames = frames.shape[0]
    if total_frames == 0:
        # Fail early with a readable message instead of a KeyError on the
        # missing fps entry or an opaque torch.stack error further down.
        raise ValueError('No frames could be decoded from the input video.')
    video_fps = info['video_fps']

    # Frame ranges [start, end) of every segment; computed once and reused
    # for both slicing the frames and building the time labels.
    segment_bounds = [
        (start, min(start + frame_step, total_frames))
        for start in range(0, total_frames, frame_step)
    ]
    video_segments = [frames[start:end] for start, end in segment_bounds]

    pixel_values = torch.stack(
        [
            utils.video2image(segment, encoder_model_name)
            for segment in video_segments
        ]
    )
    generated_texts = inference.generate_texts(pixel_values)

    # Embed the query together with the generated texts; row 0 is the query.
    sentence_embeddings = model.encode([query, *generated_texts])
    similarities = cosine_similarity(
        [sentence_embeddings[0]],
        sentence_embeddings[1:],
    )[0]
    ascending_order = np.argsort(similarities)

    # Export the best segments, best first. A short video may have fewer than
    # `num_top_segments` segments; the original unconditional [-3] lookup
    # raised IndexError in that case, so only the existing ones are written.
    best_first = ascending_order[::-1][:num_top_segments]
    clip_paths = []
    for rank, segment_index in enumerate(best_first, start=1):
        clip_path = f'top{rank}.mp4'
        torchvision.io.write_video(
            clip_path, video_segments[segment_index], video_fps
        )
        clip_paths.append(clip_path)
    # Unfilled slots are reported as None so a stale clip from an earlier
    # request is never shown.
    # NOTE(review): relies on gr.Video accepting None as "no output" --
    # confirm against the pinned Gradio version.
    clip_paths.extend([None] * (num_top_segments - len(clip_paths)))

    segment_labels = []
    for start, end in segment_bounds:
        s = timedelta(seconds=(start / video_fps))
        e = timedelta(seconds=(end / video_fps))
        segment_labels.append(f'{s}:{e}')

    labels_to_scores = dict(
        zip(
            [segment_labels[index] for index in ascending_order],
            similarities[ascending_order].tolist(),
        )
    )
    return (*clip_paths, labels_to_scores)


# NOTE(review): gr.outputs.Label belongs to Gradio's legacy output namespace;
# confirm it still exists in the Gradio version this app is deployed with.
app = gr.Interface(
    fn=search_in_video,
    inputs=['video', 'text'],
    outputs=[
        *(
            gr.Video(format='mp4', label=f'Top{rank}')
            for rank in range(1, num_top_segments + 1)
        ),
        gr.outputs.Label(num_top_classes=5, type='auto', label='Scores'),
    ],
)
app.launch()