Jumon commited on
Commit
64e6e6c
β€’
1 Parent(s): 48726fa
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -6,11 +6,18 @@ from whisper.tokenizer import get_tokenizer
6
 
7
  import classify
8
 
 
 
9
 
10
  def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
11
  class_names = class_names.split(",")
12
  tokenizer = get_tokenizer(multilingual=".en" not in model_name)
13
- model = whisper.load_model(model_name)
 
 
 
 
 
14
 
15
  internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
16
  model=model,
 
6
 
7
  import classify
8
 
9
+ model_cache = {}
10
+
11
 
12
  def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]:
13
  class_names = class_names.split(",")
14
  tokenizer = get_tokenizer(multilingual=".en" not in model_name)
15
+
16
+ if model_name not in model_cache:
17
+ model = whisper.load_model(model_name)
18
+ model_cache[model_name] = model
19
+ else:
20
+ model = model_cache[model_name]
21
 
22
  internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
23
  model=model,