from io import BytesIO
from typing import Tuple
import wave
import gradio as gr
import numpy as np
from pydub.audio_segment import AudioSegment
import requests
from os.path import exists
from stt import Model
import torch
import torchaudio
from transformers import AutoModelForCTC, Wav2Vec2Processor
from speechbrain.pretrained import EncoderClassifier
# initialize language ID model
lang_classifier = EncoderClassifier.from_hparams(
source="speechbrain/lang-id-commonlanguage_ecapa",
savedir="pretrained_models/lang-id-commonlanguage_ecapa"
)
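# the CommonLanguage classifier returns labels such as "Spanish" or "English"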
# STT model sources: Coqui .tflite download URLs or Hugging Face model IDs
model_info = {
"mixteco": ("https://coqui.gateway.scarf.sh/mixtec/jemeyer/v1.0.0/model.tflite", "mixtec.tflite"),
"chatino": ("https://coqui.gateway.scarf.sh/chatino/bozden/v1.0.0/model.tflite", "chatino.tflite"),
"totonaco": ("https://coqui.gateway.scarf.sh/totonac/bozden/v1.0.0/model.tflite", "totonac.tflite"),
"español": ("jonatasgrosman/wav2vec2-large-xlsr-53-spanish", "spanish_xlsr"),
"inglés": ("facebook/wav2vec2-large-robust-ft-swbd-300h", "english_xlsr"),
}
# cache of loaded STT models, keyed by language name
STT_MODELS = {}
def client(audio_data: np.array, sample_rate: int, default_lang: str):
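    # convert the raw numpy audio to 16 kHz mono WAV, run language ID on it,
    # then route the request to the matching STT backend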
output_audio = _convert_audio(audio_data, sample_rate)
waveform, _ = torchaudio.load(output_audio)
out_prob, score, index, text_lab = lang_classifier.classify_batch(waveform)
output_audio.seek(0)
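    # re-read the converted WAV as raw int16 samples for the Coqui STT models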
fin = wave.open(output_audio, 'rb')
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
fin.close()
    # classify_batch returns a list of labels; use the first prediction
    detected_lang = text_lab[0]
    if detected_lang == "Spanish":
        processor, model = STT_MODELS["español"]
        inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        result = processor.batch_decode(torch.argmax(logits, dim=-1))[0]
    else:
        # fall back to the Coqui STT model for the selected indigenous language
        ds = load_models(default_lang)
        result = ds.stt(audio)
    return f"{detected_lang}: {result}"
def load_models(language):
    # return a cached model if one was already loaded for this language
    if language in STT_MODELS:
        return STT_MODELS[language]
    model_path, file_name = model_info.get(language, ("", ""))
    if model_path.startswith("http"):
        # Coqui STT models are distributed as .tflite files
        if not exists(file_name):
            print(f"Downloading {model_path}")
            r = requests.get(model_path, allow_redirects=True)
            with open(file_name, "wb") as file:
                file.write(r.content)
        else:
            print(f"Found {file_name}. Skipping download...")
        STT_MODELS[language] = Model(file_name)
    else:
        # Hugging Face checkpoints: cache the (processor, model) pair
        processor = Wav2Vec2Processor.from_pretrained(model_path)
        model = AutoModelForCTC.from_pretrained(model_path)
        STT_MODELS[language] = (processor, model)
    return STT_MODELS[language]
def stt(default_lang: str, audio: Tuple[int, np.array]):
    # Gradio passes recorded audio as a (sample_rate, samples) tuple
    sample_rate, audio = audio
    recognized_result = client(audio, sample_rate, default_lang)
    return recognized_result
def _convert_audio(audio_data: np.array, sample_rate: int):
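    # wrap the raw samples in a WAV container, downmixed to mono and resampled to 16 kHz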
source_audio = BytesIO()
source_audio.write(audio_data)
source_audio.seek(0)
output_audio = BytesIO()
wav_file = AudioSegment.from_raw(
source_audio,
channels=1,
sample_width=2,
frame_rate=sample_rate
)
wav_file.set_frame_rate(16000).set_channels(1).export(output_audio, "wav", codec="pcm_s16le")
output_audio.seek(0)
return output_audio
iface = gr.Interface(
fn=stt,
inputs=[
gr.inputs.Radio(choices=("chatino", "mixteco", "totonaco"), default="mixteco", label="Lengua principal"),
gr.inputs.Audio(type="numpy", label="Audio", optional=False),
],
outputs=gr.outputs.Textbox(label="Output"),
title="Coqui STT Yoloxochitl Mixtec",
theme="huggingface",
description="Prueba de dictado a texto para el mixteco de Yoloxochitl,"
" usando [el modelo entrenado por Josh Meyer](https://coqui.ai/mixtec/jemeyer/v1.0.0/)"
" con [los datos recopilados por Rey Castillo y sus colaboradores](https://www.openslr.org/89)."
            " Esta prueba está basada en la del [ucraniano](https://huggingface.co/spaces/robinhad/ukrainian-stt)."
" \n\n"
"Speech-to-text demo for Yoloxochitl Mixtec,"
" using [the model trained by Josh Meyer](https://coqui.ai/mixtec/jemeyer/v1.0.0/)"
" on [the corpus compiled by Rey Castillo and collaborators](https://www.openslr.org/89)."
" This demo is based on the [Ukrainian STT demo](https://huggingface.co/spaces/robinhad/ukrainian-stt).",
)
for lang in ("inglés", "español"):  # preload the wav2vec2 models before launching
    load_models(lang)
iface.launch()