Armen Gabrielyan commited on
Commit
cde7ed6
1 Parent(s): 4820fa1

add batch generation

Browse files
Files changed (3) hide show
  1. app.py +4 -6
  2. inference.py +4 -10
  3. utils.py +0 -7
app.py CHANGED
@@ -2,6 +2,7 @@ from datetime import timedelta
2
  import gradio as gr
3
  from sentence_transformers import SentenceTransformer
4
  import torchvision
 
5
  from sklearn.metrics.pairwise import cosine_similarity
6
  import numpy as np
7
 
@@ -27,13 +28,10 @@ def search_in_video(video, query):
27
  video[idx:idx + frame_step, :, :, :] for idx in range(0, video.shape[0], frame_step)
28
  ]
29
 
30
- generated_texts = []
 
31
 
32
- for video_seg in video_segments:
33
- pixel_values = utils.video2image(video_seg, encoder_model_name)
34
-
35
- generated_text = inference.generate_text(pixel_values, encoder_model_name)
36
- generated_texts.append(generated_text)
37
 
38
  sentences = [query] + generated_texts
39
 
2
  import gradio as gr
3
  from sentence_transformers import SentenceTransformer
4
  import torchvision
5
+ import torch
6
  from sklearn.metrics.pairwise import cosine_similarity
7
  import numpy as np
8
 
28
  video[idx:idx + frame_step, :, :, :] for idx in range(0, video.shape[0], frame_step)
29
  ]
30
 
31
+ pixel_values = [utils.video2image(video_seg, encoder_model_name) for video_seg in video_segments]
32
+ pixel_values = torch.stack(pixel_values)
33
 
34
+ generated_texts = inference.generate_texts(pixel_values)
 
 
 
 
35
 
36
  sentences = [query] + generated_texts
37
 
inference.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
  from transformers import AutoTokenizer, VisionEncoderDecoderModel
3
 
4
- import utils
5
 
6
  class Inference:
7
  def __init__(self, decoder_model_name, max_length=32):
@@ -13,22 +12,17 @@ class Inference:
13
 
14
  self.max_length = max_length
15
 
16
- def generate_text(self, video, encoder_model_name):
17
- if isinstance(video, str):
18
- pixel_values = utils.video2image_from_path(video, encoder_model_name)
19
- else:
20
- pixel_values = video
21
-
22
  if not self.tokenizer.pad_token:
23
  self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
24
  self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))
25
 
26
  generated_ids = self.encoder_decoder_model.generate(
27
- pixel_values.unsqueeze(0).to(self.device),
28
  max_length=self.max_length,
29
  num_beams=4,
30
  no_repeat_ngram_size=2,
31
  )
32
- generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
33
 
34
- return generated_text
1
  import torch
2
  from transformers import AutoTokenizer, VisionEncoderDecoderModel
3
 
 
4
 
5
  class Inference:
6
  def __init__(self, decoder_model_name, max_length=32):
12
 
13
  self.max_length = max_length
14
 
15
+ def generate_texts(self, pixel_values):
 
 
 
 
 
16
  if not self.tokenizer.pad_token:
17
  self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
18
  self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))
19
 
20
  generated_ids = self.encoder_decoder_model.generate(
21
+ pixel_values.to(self.device),
22
  max_length=self.max_length,
23
  num_beams=4,
24
  no_repeat_ngram_size=2,
25
  )
26
+ generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
27
 
28
+ return generated_texts
utils.py CHANGED
@@ -1,15 +1,8 @@
1
  from transformers import ViTFeatureExtractor
2
- import torchvision
3
  import torchvision.transforms.functional as fn
4
  import torch as th
5
 
6
 
7
- def video2image_from_path(video_path, feature_extractor_name):
8
- video = torchvision.io.read_video(video_path)
9
-
10
- return video2image(video[0], feature_extractor_name)
11
-
12
-
13
  def video2image(video, feature_extractor_name):
14
  feature_extractor = ViTFeatureExtractor.from_pretrained(
15
  feature_extractor_name
1
  from transformers import ViTFeatureExtractor
 
2
  import torchvision.transforms.functional as fn
3
  import torch as th
4
 
5
 
 
 
 
 
 
 
6
  def video2image(video, feature_extractor_name):
7
  feature_extractor = ViTFeatureExtractor.from_pretrained(
8
  feature_extractor_name