lucio committed on
Commit
f4b3d1b
1 Parent(s): 8ee61a8

throw everything in

Browse files
Files changed (1) hide show
  1. app.py +51 -23
app.py CHANGED
@@ -8,54 +8,82 @@ import requests
8
  from os.path import exists
9
  from stt import Model
10
 
 
11
  import torchaudio
12
  from speechbrain.pretrained import EncoderClassifier
13
 
14
  # initialize language ID model
15
# Load the pretrained SpeechBrain language-ID classifier once, at import time.
lang_classifier = EncoderClassifier.from_hparams(
    source="speechbrain/lang-id-commonlanguage_ecapa",
    savedir="pretrained_models/lang-id-commonlanguage_ecapa",
)
 
 
 
16
 
17
 
18
  # download STT model
19
# Where the released Mixtec STT model lives, the file name it has there
# (also used as the local file name), and the full download link.
storage_url = "https://coqui.gateway.scarf.sh/mixtec/jemeyer/v1.0.0"
model_name = "model.tflite"
model_link = "/".join((storage_url, model_name))
22
 
23
 
24
def client(audio_data: np.array, sample_rate: int, use_scorer=False):
    """Label the spoken language of a recording and transcribe it.

    Args:
        audio_data: Raw samples as delivered by the UI audio widget.
        sample_rate: Sampling rate of ``audio_data`` in Hz.
        use_scorer: When true, enable the external scorer ``kenlm.scorer``
            on the STT model before decoding.

    Returns:
        A string of the form ``"<language label>: <transcript>"``.
    """
    output_audio = _convert_audio(audio_data, sample_rate)
    waveform, _ = torchaudio.load(output_audio)
    # Only the text label of the classification is used.
    _, _, _, text_lab = lang_classifier.classify_batch(waveform)

    # torchaudio has consumed the buffer; rewind before reading the PCM frames.
    output_audio.seek(0)
    # The context manager closes the reader even if reading the frames fails.
    with wave.open(output_audio, 'rb') as fin:
        audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

    ds = Model(model_name)
    if use_scorer:
        ds.enableExternalScorer("kenlm.scorer")

    result = ds.stt(audio)

    return f"{text_lab}: {result}"
42
 
43
 
44
def download(url, file_name):
    """Fetch ``url`` into ``file_name`` unless that file already exists.

    Args:
        url: Address to download from (redirects are followed).
        file_name: Local path to write the downloaded bytes to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if exists(file_name):
        print(f"Found {file_name}. Skipping download...")
        return

    print(f"Downloading {file_name}")
    r = requests.get(url, allow_redirects=True)
    # Fail before touching the disk: otherwise an HTTP error page would be
    # saved as the model and treated as "found" on every later run.
    r.raise_for_status()
    with open(file_name, 'wb') as file:
        file.write(r.content)
52
 
 
53
 
54
def stt(audio: Tuple[int, np.array]):
    """Gradio callback: transcribe one recording.

    ``audio`` is a ``(sample_rate, samples)`` pair. The external scorer is
    always left disabled.
    """
    rate, samples = audio
    return client(samples, rate, False)
61
 
@@ -71,8 +99,7 @@ def _convert_audio(audio_data: np.array, sample_rate: int):
71
  sample_width=2,
72
  frame_rate=sample_rate
73
  )
74
- wav_file.set_frame_rate(16000).set_channels(
75
- 1).export(output_audio, "wav", codec="pcm_s16le")
76
  output_audio.seek(0)
77
  return output_audio
78
 
@@ -80,8 +107,8 @@ def _convert_audio(audio_data: np.array, sample_rate: int):
80
  iface = gr.Interface(
81
  fn=stt,
82
  inputs=[
83
- gr.inputs.Audio(type="numpy",
84
- label=None, optional=False),
85
  ],
86
  outputs=gr.outputs.Textbox(label="Output"),
87
  title="Coqui STT Yoloxochitl Mixtec",
@@ -97,5 +124,6 @@ iface = gr.Interface(
97
  " This demo is based on the [Ukrainian STT demo](https://huggingface.co/spaces/robinhad/ukrainian-stt).",
98
  )
99
 
100
- download(model_link, model_name)
 
101
  iface.launch()
 
8
  from os.path import exists
9
  from stt import Model
10
 
11
+ import torch
12
  import torchaudio
13
  from speechbrain.pretrained import EncoderClassifier
14
 
15
  # initialize language ID model
16
# Load the pretrained SpeechBrain language-ID classifier once, at import time.
# The model id and the local save directory share the same final component.
_LANG_ID_MODEL = "lang-id-commonlanguage_ecapa"
lang_classifier = EncoderClassifier.from_hparams(
    source=f"speechbrain/{_LANG_ID_MODEL}",
    savedir=f"pretrained_models/{_LANG_ID_MODEL}",
)
20
 
21
 
22
  # download STT model
23
# Per-language model sources: UI language name -> (model location, local name).
# Locations starting with "http" are .tflite files to download; the others are
# identifiers handed to ``from_pretrained``.
model_info = {
    "mixteco": ("https://coqui.gateway.scarf.sh/mixtec/jemeyer/v1.0.0/model.tflite", "mixtec.tflite"),
    "chatino": ("https://coqui.gateway.scarf.sh/chatino/bozden/v1.0.0/model.tflite", "chatino.tflite"),
    "totonaco": ("https://coqui.gateway.scarf.sh/totonac/bozden/v1.0.0/model.tflite", "totonac.tflite"),
    "español": ("jonatasgrosman/wav2vec2-large-xlsr-53-spanish", "spanish_xlsr"),
    "inglés": ("facebook/wav2vec2-large-robust-ft-swbd-300h", "english_xlsr"),
}

# Legacy single-model names, kept so existing references keep working.
# `storage_url` is no longer defined, so the link is taken from the Mixtec
# entry instead of being interpolated (which raised NameError at import).
model_name = "model.tflite"
model_link = model_info["mixteco"][0]
33
 
34
 
35
def client(audio_data: np.array, sample_rate: int, default_lang: str):
    """Label the spoken language of a recording and transcribe it.

    Audio labelled 'Spanish' by the language-ID model is decoded with the
    wav2vec2 model stored under ``STT_MODELS['español']``; anything else is
    decoded with the STT model for ``default_lang``.

    Args:
        audio_data: Raw samples as delivered by the UI audio widget.
        sample_rate: Sampling rate of ``audio_data`` in Hz.
        default_lang: Key of the model to use for non-Spanish audio.

    Returns:
        A string of the form ``"<language label>: <transcript>"``.
    """
    output_audio = _convert_audio(audio_data, sample_rate)
    waveform, wav_rate = torchaudio.load(output_audio)
    _, _, _, text_lab = lang_classifier.classify_batch(waveform)

    # torchaudio has consumed the buffer; rewind before reading the PCM frames.
    output_audio.seek(0)
    # The context manager closes the reader even if reading the frames fails.
    with wave.open(output_audio, 'rb') as fin:
        audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

    # classify_batch returns one label per batch item, so comparing the whole
    # container with a string is never true; look at the single item instead.
    detected = text_lab[0] if isinstance(text_lab, (list, tuple)) else text_lab

    if detected == 'Spanish':
        processor, model = STT_MODELS['español']
        # The processor needs a 1-D signal plus its sampling rate, and must
        # return tensors for the model to accept them.
        inputs = processor(
            waveform.squeeze(0).numpy(),
            sampling_rate=wav_rate,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        # batch_decode handles the (batch, time) id matrix; keep the one item.
        result = processor.batch_decode(torch.argmax(logits, dim=-1))[0]
    else:
        # Only some languages are preloaded; load and remember others on
        # first use instead of failing with KeyError.
        ds = STT_MODELS.get(default_lang)
        if ds is None:
            ds = STT_MODELS[default_lang] = load_models(default_lang)
        result = ds.stt(audio)

    return f"{text_lab}: {result}"
57
 
58
 
59
def load_models(language):
    """Return the speech-recognition model for ``language``.

    Entries of ``model_info`` whose location starts with "http" are downloaded
    to their local file name (unless already present) and opened with
    ``Model``. All other entries are loaded with ``from_pretrained``.

    Args:
        language: A key of ``model_info``.

    Returns:
        A ``Model`` instance for downloaded files, otherwise a
        ``(processor, model)`` pair.

    Raises:
        ValueError: If ``language`` has no entry in ``model_info``.
        requests.HTTPError: If the download answers with an error status.
    """
    # STT_MODELS is itself built by calling this function, so the global may
    # not be bound yet during start-up.
    cached = globals().get("STT_MODELS", {})
    if language in cached:
        return cached[language]

    # Look up the argument itself; the literal key "language" never matches.
    if language not in model_info:
        raise ValueError(f"Unknown language: {language!r}")
    model_path, file_name = model_info[language]

    if model_path.startswith('http'):
        if not exists(file_name):
            print(f"Downloading {model_path}")
            r = requests.get(model_path, allow_redirects=True)
            # Do not store an HTTP error page as if it were the model file.
            r.raise_for_status()
            with open(file_name, 'wb') as file:
                file.write(r.content)
        else:
            print(f"Found {file_name}. Skipping download...")
        return Model(file_name)

    # NOTE(review): Wav2Vec2Processor and AutoModelForCTC are not imported in
    # the visible part of the file - confirm they are imported at the top.
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = AutoModelForCTC.from_pretrained(model_path)
    return processor, model
79
+
80
+
81
+
82
def stt(default_lang: str, audio: Tuple[int, np.array]):
    """Gradio callback: transcribe one recording.

    Args:
        default_lang: Language chosen in the UI; used for audio that is not
            handled by a dedicated model.
        audio: ``(sample_rate, samples)`` pair from the audio widget.

    Returns:
        The string produced by ``client``.
    """
    # The former ``use_scorer`` local was dead code after the signature change.
    sample_rate, samples = audio
    return client(samples, sample_rate, default_lang)
89
 
 
99
  sample_width=2,
100
  frame_rate=sample_rate
101
  )
102
+ wav_file.set_frame_rate(16000).set_channels(1).export(output_audio, "wav", codec="pcm_s16le")
 
103
  output_audio.seek(0)
104
  return output_audio
105
 
 
107
  iface = gr.Interface(
108
  fn=stt,
109
  inputs=[
110
+ gr.inputs.Radio(choices=("chatino", "mixteco", "totonaco"), default="mixteco", label="Lengua principal"),
111
+ gr.inputs.Audio(type="numpy", label="Audio", optional=False),
112
  ],
113
  outputs=gr.outputs.Textbox(label="Output"),
114
  title="Coqui STT Yoloxochitl Mixtec",
 
124
  " This demo is based on the [Ukrainian STT demo](https://huggingface.co/spaces/robinhad/ukrainian-stt).",
125
  )
126
 
127
# Bind the cache before filling it: load_models() reads the STT_MODELS global,
# so building it inside a single comprehension fails with NameError on the
# very first call.
STT_MODELS = {}
for lang in ("inglés", "español"):
    STT_MODELS[lang] = load_models(lang)
128
+
129
  iface.launch()