Armen Gabrielyan committed on
Commit
5e95a58
1 Parent(s): deb4867

add initial app

Files changed (4)
  1. app.py +84 -0
  2. inference.py +29 -0
  3. requirements.txt +4 -0
  4. utils.py +44 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ from datetime import timedelta
+ import gradio as gr
+ from sentence_transformers import SentenceTransformer
+ import torchvision
+ from sklearn.metrics.pairwise import cosine_similarity
+ import numpy as np
+
+ from inference import Inference
+ import utils
+
+ model_checkpoint = 'saved_model'
+ encoder_model_name = 'google/vit-large-patch32-224-in21k'
+ decoder_model_name = 'gpt2'
+ frame_step = 300
+
+ inference = Inference(
+     decoder_model_name=decoder_model_name,
+     model_checkpoint=model_checkpoint,
+ )
+
+ model = SentenceTransformer('all-mpnet-base-v2')
+
+ def search_in_video(video, query):
+     result = torchvision.io.read_video(video)
+     video = result[0]
+     video_fps = result[2]['video_fps']
+
+     video_segments = [
+         video[idx:idx + frame_step, :, :, :] for idx in range(0, video.shape[0], frame_step)
+     ]
+
+     generated_texts = []
+
+     for video_seg in video_segments:
+         pixel_values = utils.video2image(video_seg, encoder_model_name)
+
+         generated_text = inference.generate_text(pixel_values, encoder_model_name)
+         generated_texts.append(generated_text)
+
+     sentences = [query] + generated_texts
+
+     sentence_embeddings = model.encode(sentences)
+
+     similarities = cosine_similarity(
+         [sentence_embeddings[0]],
+         sentence_embeddings[1:]
+     )
+     arg_sorted_similarities = np.argsort(similarities)
+
+     ordered_similarity_scores = similarities[0][arg_sorted_similarities]
+
+     best_video = video_segments[arg_sorted_similarities[0, -1]]
+     torchvision.io.write_video('best.mp4', best_video, video_fps)
+
+     total_frames = video.shape[0]
+
+     video_frame_segs = [
+         [idx, min(idx + frame_step, total_frames)] for idx in range(0, total_frames, frame_step)
+     ]
+     ordered_start_ends = []
+
+     for [start, end] in video_frame_segs:
+         td = timedelta(seconds=(start / video_fps))
+         s = round(td.total_seconds(), 2)
+
+         td = timedelta(seconds=(end / video_fps))
+         e = round(td.total_seconds(), 2)
+
+         ordered_start_ends.append(f'{s}:{e}')
+
+     ordered_start_ends = np.array(ordered_start_ends)[arg_sorted_similarities]
+
+     labels_to_scores = dict(
+         zip(ordered_start_ends[0].tolist(), ordered_similarity_scores[0].tolist())
+     )
+
+     return 'best.mp4', labels_to_scores
+
+ app = gr.Interface(
+     fn=search_in_video,
+     inputs=['video', 'text'],
+     outputs=['video', gr.outputs.Label(num_top_classes=3, type='auto')],
+ )
+ app.launch(share=True)
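
For reference, a minimal sketch of the ranking step inside search_in_video, reusing the same all-mpnet-base-v2 sentence encoder; the query and captions below are hypothetical placeholders for the per-segment texts the model would generate.

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

model = SentenceTransformer('all-mpnet-base-v2')

query = 'a dog catching a frisbee'  # hypothetical search query
captions = [                        # hypothetical per-segment captions
    'a man riding a bicycle down a street',
    'a dog jumping to catch a frisbee',
    'people sitting around a table eating',
]

# Embed the query together with the captions and score each caption against the query.
embeddings = model.encode([query] + captions)
scores = cosine_similarity([embeddings[0]], embeddings[1:])  # shape (1, num_segments)
ranking = np.argsort(scores)[0][::-1]                        # highest-scoring segment first

for idx in ranking:
    print(captions[idx], round(float(scores[0][idx]), 3))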
inference.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from transformers import AutoTokenizer, VisionEncoderDecoderModel
+
+ import utils
+
+ class Inference:
+     def __init__(self, decoder_model_name, model_checkpoint, max_length=32):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
+         self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)
+         self.encoder_decoder_model.to(self.device)
+
+         self.max_length = max_length
+
+     def generate_text(self, video, encoder_model_name):
+         if isinstance(video, str):
+             pixel_values = utils.video2image_from_path(video, encoder_model_name)
+         else:
+             pixel_values = video
+
+         if not self.tokenizer.pad_token:
+             self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+             self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))
+
+         generated_ids = self.encoder_decoder_model.generate(pixel_values.unsqueeze(0).to(self.device), max_length=self.max_length)
+         generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+         return generated_text
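
A usage sketch for the Inference class, assuming a fine-tuned checkpoint exists at saved_model and sample.mp4 is a local clip (both paths are placeholders):

from inference import Inference

# 'saved_model' and 'sample.mp4' are placeholders for a real checkpoint and video file.
inference = Inference(decoder_model_name='gpt2', model_checkpoint='saved_model')

# generate_text accepts either a video path or precomputed pixel values;
# given a path, it calls utils.video2image_from_path internally.
caption = inference.generate_text('sample.mp4', 'google/vit-large-patch32-224-in21k')
print(caption)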
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ nltk==3.7
+ tqdm==4.64.0
+ scikit-learn==1.1.1
+ sentence-transformers==2.2.0
utils.py ADDED
@@ -0,0 +1,44 @@
+ from transformers import ViTFeatureExtractor
+ import torchvision
+ import torchvision.transforms.functional as fn
+ import torch as th
+ import os
+ import pickle
+
+
+ def video2image_from_path(video_path, feature_extractor_name):
+     video = torchvision.io.read_video(video_path)
+
+     return video2image(video[0], feature_extractor_name)
+
+
+ def video2image(video, feature_extractor_name):
+     feature_extractor = ViTFeatureExtractor.from_pretrained(
+         feature_extractor_name
+     )
+
+     vid = th.permute(video, (3, 0, 1, 2))
+     samp = th.linspace(0, vid.shape[1]-1, 49, dtype=th.long)
+     vid = vid[:, samp, :, :]
+
+     im_l = list()
+     for i in range(vid.shape[1]):
+         im_l.append(vid[:, i, :, :])
+
+     inputs = feature_extractor(im_l, return_tensors="pt")
+
+     inputs = inputs['pixel_values']
+
+     im_h = list()
+     for i in range(7):
+         im_v = th.cat((inputs[0+i*7, :, :, :],
+                        inputs[1+i*7, :, :, :],
+                        inputs[2+i*7, :, :, :],
+                        inputs[3+i*7, :, :, :],
+                        inputs[4+i*7, :, :, :],
+                        inputs[5+i*7, :, :, :],
+                        inputs[6+i*7, :, :, :]), 2)
+         im_h.append(im_v)
+     resize = fn.resize(th.cat(im_h, 1), size=[224])
+
+     return resize
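
A quick sketch of what video2image_from_path returns, assuming a placeholder clip sample.mp4: 49 uniformly sampled frames are preprocessed to 224x224 by the ViT feature extractor, tiled into a 7x7 grid, and the grid is resized so its smaller edge is 224.

import utils

# 'sample.mp4' is a placeholder for any local video file.
pixel_values = utils.video2image_from_path(
    'sample.mp4', 'google/vit-large-patch32-224-in21k'
)
print(pixel_values.shape)  # torch.Size([3, 224, 224]) -- the 7x7 frame grid resized to 224x224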