lucio committed on
Commit
2915c9d
1 Parent(s): 3363bfd

make this work

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -20,6 +20,11 @@ lang_classifier = EncoderClassifier.from_hparams(
20
  savedir="pretrained_models/lang-id-commonlanguage_ecapa"
21
  )
22
 
 
 
 
 
 
23
 
24
  # download STT model
25
  model_info = {
@@ -30,6 +35,8 @@ model_info = {
30
  "inglés": ("facebook/wav2vec2-large-robust-ft-swbd-300h", "english_xlsr"),
31
  }
32
 
 
 
33
 
34
  def client(audio_data: np.array, sample_rate: int, default_lang: str):
35
  output_audio = _convert_audio(audio_data, sample_rate)
@@ -55,7 +62,7 @@ def client(audio_data: np.array, sample_rate: int, default_lang: str):
55
  return f"{text_lab}: {result}"
56
 
57
 
58
- def load_models(language):
59
 
60
  model_path, file_name = model_info.get("language", ("", ""))
61
 
@@ -69,9 +76,8 @@ def load_models(language):
69
  print(f"Found {file_name}. Skipping download...")
70
  return Model(file_name)
71
 
72
- processor = Wav2Vec2Processor.from_pretrained(model_path)
73
- model = AutoModelForCTC.from_pretrained(model_path)
74
- return processor, model
75
 
76
 
77
 
 
20
  savedir="pretrained_models/lang-id-commonlanguage_ecapa"
21
  )
22
 
23
+ @st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
24
+ def load_hf_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
25
+ processor = Wav2Vec2Processor.from_pretrained(model_path)
26
+ model = AutoModelForCTC.from_pretrained(model_path).to(device)
27
+ return processor, model
28
 
29
  # download STT model
30
  model_info = {
 
35
  "inglés": ("facebook/wav2vec2-large-robust-ft-swbd-300h", "english_xlsr"),
36
  }
37
 
38
+ STT_MODELS = {lang: load_hf_model(model_info[lang][0]) for lang in ("inglés", "español")}
39
+
40
 
41
  def client(audio_data: np.array, sample_rate: int, default_lang: str):
42
  output_audio = _convert_audio(audio_data, sample_rate)
 
62
  return f"{text_lab}: {result}"
63
 
64
 
65
+ def load_coqui_models(language):
66
 
67
  model_path, file_name = model_info.get("language", ("", ""))
68
 
 
76
  print(f"Found {file_name}. Skipping download...")
77
  return Model(file_name)
78
 
79
+ for lang in ('mixteco', 'chatino', 'totonaco'):
80
+ STT_MODELS[lang] = load_coqui_models(lang)
 
81
 
82
 
83