Kr08 committed on
Commit
e269658
1 Parent(s): 5096cb5

Created model utils script, includes modules for loading whisper model and processor

Files changed (1)
  1. model_utils.py +39 -0
model_utils.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
+ import whisper
+ from config import WHISPER_MODEL_SIZE
+
+ # Global variables to store models
+ whisper_processor = None
+ whisper_model = None
+ whisper_model_small = None
+
+ def load_models():
+     global whisper_processor, whisper_model, whisper_model_small
+     if whisper_processor is None:
+         whisper_processor = WhisperProcessor.from_pretrained(f"openai/whisper-{WHISPER_MODEL_SIZE}")
+     if whisper_model is None:
+         whisper_model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{WHISPER_MODEL_SIZE}").to(get_device())
+     if whisper_model_small is None:
+         whisper_model_small = whisper.load_model(WHISPER_MODEL_SIZE)
+
+ def get_device():
+     return "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ def get_processor():
+     global whisper_processor
+     if whisper_processor is None:
+         load_models()
+     return whisper_processor
+
+ def get_model():
+     global whisper_model
+     if whisper_model is None:
+         load_models()
+     return whisper_model
+
+ def get_whisper_model_small():
+     global whisper_model_small
+     if whisper_model_small is None:
+         load_models()
+     return whisper_model_small
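For context, a minimal usage sketch of the new module is shown below. It is not part of this commit: the transcribe helper, its parameters, and the assumption that config.py defines WHISPER_MODEL_SIZE (e.g. "base") and that 16 kHz audio is already available as a NumPy array are illustrative only.

# Hypothetical caller of model_utils.py -- not part of this commit.
import torch
from model_utils import get_processor, get_model, get_device

def transcribe(audio_array, sampling_rate=16000):
    processor = get_processor()   # lazily loads the Whisper processor on first call
    model = get_model()           # lazily loads the seq2seq model onto GPU or CPU
    inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features.to(get_device())
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

Because the loaders cache their results in module-level globals, repeated calls to transcribe reuse the already-loaded processor and model rather than reloading them.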