# Spaces: Runtime error
import gradio as gr | |
import soundfile | |
import time | |
import torch | |
import scipy.io.wavfile | |
from espnet2.bin.tts_inference import Text2Speech | |
from espnet2.utils.types import str_or_none | |
from espnet2.bin.asr_inference import Speech2Text | |
from subprocess import call | |
import os | |
# Run the local s3prl setup script. Its text is read from the repository and
# executed through the shell, so it must be trusted content.
with open("s3prl.sh", "rb") as file:
    script = file.read()
rc = call(script, shell=True)
if rc != 0:
    # Stay best-effort (do not abort), but make a failed setup visible in the
    # logs instead of surfacing later as a confusing import error.
    print(f"warning: s3prl.sh exited with status {rc}")

import sys

# Make the s3prl checkout importable in this process (sys.path) and in child
# Python processes (PYTHONPATH). Note: this replaces any existing PYTHONPATH.
s3prl_dir = os.getcwd() + "/s3prl"
sys.path.append(s3prl_dir)
os.environ["PYTHONPATH"] = s3prl_dir

import fairseq

print(fairseq.__version__)
# One spoken-language-understanding model per dataset offered in the UI.
# Decoding parameters (nbest) are not stored in the model files, so they are
# passed explicitly.
speech2text_slurp = Speech2Text.from_pretrained(
    asr_train_config="slurp/config.yaml",
    asr_model_file="slurp/valid.acc.ave_10best.pth",
    nbest=1,
)
speech2text_fsc = Speech2Text.from_pretrained(
    asr_train_config="fsc/config.yaml",
    asr_model_file="fsc/valid.acc.ave_5best.pth",
    nbest=1,
)
# Bound to its own name so it does not replace the FSC model above.
speech2text_catslu = Speech2Text.from_pretrained(
    asr_train_config="catslu/config.yaml",
    asr_model_file="catslu/valid.acc.ave_5best.pth",
    nbest=1,
)


def _read_mono(wav):
    """Read the audio file at ``wav.name`` and keep only its first channel.

    ``soundfile.read`` returns a 2-D array for multi-channel audio; only
    channel 0 is kept so every model receives a 1-D signal.
    """
    speech, _rate = soundfile.read(wav.name)
    # NOTE(review): the sample rate is neither checked nor resampled;
    # presumably the models expect the rate they were trained on -- confirm.
    if speech.ndim == 2:
        speech = speech[:, 0]
    return speech


def _split_fields(intent, count):
    """Split ``intent`` on "_" into exactly ``count`` fields.

    The last field keeps any remaining underscores. Missing fields are
    returned as "" so a malformed model output does not raise IndexError.
    """
    fields = intent.split("_", count - 1)
    return fields + [""] * (count - len(fields))


def _first_token(model, speech):
    """Run ``model`` on ``speech`` and return the first token of the 1-best text."""
    nbests = model(speech)
    text, *_ = nbests[0]
    return text.split(" ")[0]


def inference(wav, data):
    """Run spoken language understanding on ``wav`` with the model for ``data``.

    Args:
        wav: uploaded audio object; only ``wav.name`` (a file path) is read.
        data: one of "english_slurp", "english_fsc" or "chinese".

    Returns:
        For "english_slurp": "{scenario: ..., action: ...}".
        For "english_fsc": "{action: ..., object: ..., location: ...}".
        For "chinese": the first token of the CATSLU model's output.

    Raises:
        ValueError: if ``data`` is not one of the supported datasets.
    """
    with torch.no_grad():
        speech = _read_mono(wav)
        if data == "english_slurp":
            intent = _first_token(speech2text_slurp, speech)
            # SLURP labels are "<scenario>_<action>", and action names can
            # themselves contain "_" (e.g. "iot_hue_lightoff"), so split once.
            scenario, action = _split_fields(intent, 2)
            return "{scenario: " + scenario + ", action: " + action + "}"
        if data == "english_fsc":
            intent = _first_token(speech2text_fsc, speech)
            action, objects, location = _split_fields(intent, 3)
            return (
                "{action: " + action
                + ", object: " + objects
                + ", location: " + location + "}"
            )
        if data == "chinese":
            return _first_token(speech2text_catslu, speech)
    raise ValueError(f"unsupported dataset: {data!r}")
# Text shown around the demo UI.
title = "ESPnet2-SLU"
description = (
    "Gradio demo for ESPnet2-SLU: Advancing Spoken Language Understanding "
    "through ESPnet. To use it, simply record your audio or click one of the "
    "examples to load them. Read more at the links below."
)
article = "<p style='text-align: center'><a href='https://github.com/espnet/espnet' target='_blank'>Github Repo</a></p>"

# Each example pairs an audio file with the dataset value passed to inference.
examples = [
    ["audio_slurp.flac", "english_slurp"],
    ["audio_fsc.wav", "english_fsc"],
    ["audio_catslu.wav", "chinese"],
]

# Inputs map positionally onto inference(wav, data).
audio_input = gr.inputs.Audio(label="input audio", source="microphone", type="file")
dataset_input = gr.inputs.Radio(
    choices=["english_slurp", "english_fsc", "chinese"],
    type="value",
    default="english_slurp",
    label="Dataset",
)

demo = gr.Interface(
    inference,
    [audio_input, dataset_input],
    gr.outputs.Textbox(type="str", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
)
demo.launch(debug=True)