Armen Gabrielyan commited on
Commit
68da745
1 Parent(s): 5e95a58

add video summarization pre-trained model

Browse files
Files changed (2) hide show
  1. app.py +0 -2
  2. inference.py +3 -3
app.py CHANGED
@@ -8,14 +8,12 @@ import numpy as np
8
  from inference import Inference
9
  import utils
10
 
11
- model_checkpoint = 'saved_model'
12
  encoder_model_name = 'google/vit-large-patch32-224-in21k'
13
  decoder_model_name = 'gpt2'
14
  frame_step = 300
15
 
16
  inference = Inference(
17
  decoder_model_name=decoder_model_name,
18
- model_checkpoint=model_checkpoint,
19
  )
20
 
21
  model = SentenceTransformer('all-mpnet-base-v2')
 
8
  from inference import Inference
9
  import utils
10
 
 
11
  encoder_model_name = 'google/vit-large-patch32-224-in21k'
12
  decoder_model_name = 'gpt2'
13
  frame_step = 300
14
 
15
  inference = Inference(
16
  decoder_model_name=decoder_model_name,
 
17
  )
18
 
19
  model = SentenceTransformer('all-mpnet-base-v2')
inference.py CHANGED
@@ -1,14 +1,14 @@
1
  import torch
2
- from transformers import AutoTokenizer, VisionEncoderDecoderModel
3
 
4
  import utils
5
 
6
  class Inference:
7
- def __init__(self, decoder_model_name, model_checkpoint, max_length=32):
8
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
  self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
11
- self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)
12
  self.encoder_decoder_model.to(self.device)
13
 
14
  self.max_length = max_length
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModel
3
 
4
  import utils
5
 
6
  class Inference:
7
+ def __init__(self, decoder_model_name, max_length=32):
8
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
  self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
11
+ self.encoder_decoder_model = AutoModel.from_pretrained('armgabrielyan/video-summarization')
12
  self.encoder_decoder_model.to(self.device)
13
 
14
  self.max_length = max_length