#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# References:
# https://gradio.app/docs/#dropdown

import os
import time
from datetime import datetime

import gradio as gr
import torchaudio

from model import (
    get_gigaspeech_pre_trained_model,
    sample_rate,
    get_wenetspeech_pre_trained_model,
)
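
# Both pre-trained models are constructed once at import time and shared
# across requests, so each call to process() below only pays for decoding.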
models = {
    "Chinese": get_wenetspeech_pre_trained_model(),
    "English": get_gigaspeech_pre_trained_model(),
}


def convert_to_wav(in_filename: str) -> str:
    """Convert the input audio file to a wave file using ffmpeg."""
    out_filename = in_filename + ".wav"
    print(f"Converting '{in_filename}' to '{out_filename}'")
    _ = os.system(f"ffmpeg -hide_banner -i '{in_filename}' '{out_filename}'")
    return out_filename
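
# Note that os.system() above discards ffmpeg's exit status. A stricter
# variant (a sketch, not part of the original app) could use subprocess:
#
#   import subprocess
#   subprocess.run(
#       ["ffmpeg", "-hide_banner", "-i", in_filename, out_filename],
#       check=True,  # raise CalledProcessError if ffmpeg fails
#   )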


demo = gr.Blocks()


def process(in_filename: str, language: str) -> str:
    print("in_filename", in_filename)
    print("language", language)
    filename = convert_to_wav(in_filename)

    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
    print(f"Started at {date_time}")

    start = time.time()
    wave, wave_sample_rate = torchaudio.load(filename)

    if wave_sample_rate != sample_rate:
        print(
            f"Expected sample rate: {sample_rate}. Given: {wave_sample_rate}. "
            f"Resampling to {sample_rate}."
        )
        wave = torchaudio.functional.resample(
            wave,
            orig_freq=wave_sample_rate,
            new_freq=sample_rate,
        )

    wave = wave[0]  # use only the first channel
    hyp = models[language].decode_waves([wave])[0]

    date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    end = time.time()
    duration = wave.shape[0] / sample_rate
    rtf = (end - start) / duration  # real-time factor

    print(f"Finished at {date_time}. Elapsed: {end - start: .3f} s")
    print(f"Duration {duration: .3f} s")
    print(f"RTF {rtf: .3f}")
    print("hyp")
    print(hyp)

    return hyp
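
# For reference, the RTF printed above is (processing time) / (audio duration):
# decoding a 10 s clip in 2 s gives RTF = 2 / 10 = 0.2, i.e. 5x faster than
# real time. RTF < 1.0 means decoding keeps up with live audio.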


with demo:
    gr.Markdown("# Automatic Speech Recognition with Next-gen Kaldi")
    language_choices = list(models.keys())
    language = gr.inputs.Radio(
        label="Language",
        choices=language_choices,
    )
    with gr.Tabs():
        with gr.TabItem("Upload from disk"):
            uploaded_file = gr.inputs.Audio(
                source="upload",  # Choose between "microphone", "upload"
                type="filepath",
                optional=False,
                label="Upload from disk",
            )
            upload_button = gr.Button("Submit for recognition")
            uploaded_output = gr.outputs.Textbox(
                label="Recognized speech from uploaded file"
            )
        with gr.TabItem("Record from microphone"):
            microphone = gr.inputs.Audio(
                source="microphone",  # Choose between "microphone", "upload"
                type="filepath",
                optional=False,
                label="Record from microphone",
            )
            recorded_output = gr.outputs.Textbox(
                label="Recognized speech from recordings"
            )
            record_button = gr.Button("Submit for recognition")

    upload_button.click(
        process,
        inputs=[uploaded_file, language],
        outputs=uploaded_output,
    )
    record_button.click(
        process,
        inputs=[microphone, language],
        outputs=recorded_output,
    )


if __name__ == "__main__":
    demo.launch()
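
# Usage (a sketch; assumes this script is saved as app.py and that ffmpeg
# is on PATH):
#
#   python3 app.py
#
# Gradio then serves the demo locally (http://127.0.0.1:7860 by default).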