# gbn2 / app.py — Hugging Face Space file by nimool, commit 6daeff1 ("Update app.py"), 3.64 kB.
import contextlib
import logging
import os
import subprocess

import gradio as gr
import soundfile as sf
import sox
import torch
from google_spell_checker import GoogleSpellChecker
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Persian ("fa") spell checker, created once at import time and used by
# ``corrector`` to post-process every transcription.
spell_checker = GoogleSpellChecker(lang="fa")
def read_file_and_process(wav_file):
    """Resample an audio file to 16 kHz mono and turn it into model inputs.

    The recording is converted by ``resampler`` into a sibling file named
    ``<stem>16k.wav``. That file is read back with soundfile and then deleted.
    The samples are passed through the module-level Wav2Vec2 ``processor``.

    Args:
        wav_file: Path to the recorded audio file.

    Returns:
        The processor output, built with ``return_tensors="pt"``.
    """
    # splitext strips only the final extension. The previous
    # ``split('.')[0]`` cut at the first dot anywhere in the path, e.g.
    # "./tmp/rec.wav" -> "" and "/tmp/a.b/rec.wav" -> "/tmp/a".
    stem, _ = os.path.splitext(wav_file)
    filename_16k = stem + "16k.wav"
    resampler(wav_file, filename_16k)
    try:
        speech, _ = sf.read(filename_16k)
    finally:
        # The resampled copy is only needed long enough to read it back.
        # Removing it stops intermediate files from piling up. The name always
        # carries an extra "16k" suffix, so it can never be the input itself.
        with contextlib.suppress(FileNotFoundError):
            os.remove(filename_16k)
    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
    return inputs
def resampler(input_file_path, output_file_path):
    """Convert an audio file to a 16 kHz, mono, 16-bit WAV using ffmpeg.

    Args:
        input_file_path: Source audio file (any format ffmpeg can decode).
        output_file_path: Destination path. It is overwritten if it exists.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    command = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel", "panic",
        # Overwrite instead of prompting. Without this, an existing output
        # file is left in place and the caller reads stale audio.
        "-y",
        "-i", input_file_path,
        "-ar", "16000",
        "-ac", "1",
        "-bits_per_raw_sample", "16",
        "-vn",
        output_file_path,
    ]
    # An argument list with no shell passes paths containing spaces or shell
    # metacharacters through verbatim, with no word splitting or injection.
    # check=True surfaces ffmpeg failures here instead of silently continuing.
    subprocess.run(command, check=True, stdin=subprocess.DEVNULL)
def parse_transcription(logits):
    """Greedy-decode CTC logits into text.

    Args:
        logits: Model output logits. The argmax is taken over the last axis,
            and only batch element 0 is decoded.

    Returns:
        The decoded transcription with special tokens removed.
    """
    predicted_ids = torch.argmax(logits, dim=-1)
    # ``parse`` feeds this the logits for a single file, so decode item 0.
    # (The former ``del(logits)`` only unbound this local name; the caller
    # still holds the tensor, so it freed nothing and has been dropped.)
    return processor.decode(predicted_ids[0], skip_special_tokens=True)
def corrector(sentence):
    """Return the spell-corrected form of *sentence*, or *sentence* unchanged.

    Uses the module-level ``spell_checker``. Its ``check`` result is indexed
    as a pair: item 0 is ``False`` when a correction is offered, and item 1 is
    the corrected text.

    Empty or whitespace-only input is returned as-is without calling the
    checker. If the checker raises, the error is logged and the uncorrected
    sentence is returned. ``GoogleSpellChecker`` presumably calls a remote
    service (verify), so this keeps a network or service failure from
    discarding an otherwise finished transcription.
    """
    if not sentence or not sentence.strip():
        return sentence
    try:
        check_spell = spell_checker.check(sentence)
    except Exception:  # boundary with an external checker: degrade gracefully
        logging.getLogger(__name__).exception(
            "Spell check failed; returning uncorrected text"
        )
        return sentence
    # Keep the original identity test: only an explicit False means
    # "a correction is available".
    if check_spell[0] is False:
        return check_spell[1]
    return sentence
def parse(wav_file):
    """Transcribe a recorded audio file and spell-correct the result.

    This is the function wired into the Gradio interface.

    Args:
        wav_file: Filepath from the ``gr.Audio(type="filepath")`` input.
            It is presumably ``None`` when the user submits without
            recording.

    Returns:
        The corrected transcription, or ``""`` when no audio was provided.
    """
    if not wav_file:
        # Nothing was recorded. Without this guard the resampling step
        # crashes on ``None`` instead of showing an empty result.
        return ""
    input_values = read_file_and_process(wav_file)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**input_values).logits
    sentence = parse_transcription(logits)
    return corrector(sentence)
# def parse(wav_file):
# check_spell = ''
# input_values = read_file_and_process(wav_file)
# with torch.no_grad():
# logits = model(**input_values).logits
# # sentence = parse_transcription(logits)
# check_spell = spell_checker.check(parse_transcription(logits))
# # if check_spell[0] is False:
# # corrected = check_spell[1]
# # else:
# # corrected = sentence
# return spell_checker.check(parse_transcription(logits))[1] if spell_checker.check(parse_transcription(logits))[0] is False else parse_transcription(logits)
# --- Model and UI setup -------------------------------------------------
# Persian fine-tuned wav2vec2 (XLSR-53) checkpoint. The processor and model
# are loaded once at import time and read as module-level globals by
# read_file_and_process, parse_transcription and parse.
model_id = "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

# Microphone recorder. type="filepath" hands ``parse`` a path on disk.
# Label (fa): "Please press the record button and start speaking, and once
# you have finished speaking press the record button again."
input_ = gr.Audio(
    source="microphone",
    type="filepath",
    label="لطفا دکمه ضبط صدا را بزنید و شروع به صحبت کنید و بعد از اتمام صحبت دوباره دکمه ضبط را فشار دهید.",
    show_download_button=True,
    show_edit_button=True,
)

# Right-aligned output box for the right-to-left Persian transcription.
# Label (fa): "Your speech text:"
txtbox = gr.Textbox(
    label="متن گفتار شما: ",
    lines=5,
    text_align="right",
    show_label=True,
    show_copy_button=True,
)

title = "Speech-to-Text (persian)"
# Description (fa): "Click the record button, then give your browser access
# to the device microphone, then start speaking, and click the button again
# to stop recording. Note that the more deliberately you speak, the better
# the output quality."
description = "روی دکمه ضبط صدا کلیک کنید و سپس دسترسی مرورگر خود را به میکروفون دستگاه بدهید، سپس شروع به صحبت کنید و برای اتمام ضبط دوباره روی دکمه کلیک کنید. توجه داشته باشید که هرچه گفتار شما شمرده تر باشد خروجی با کیفیت تری دارید."
# NOTE(review): the link text names a paper, but the href points at a GitHub
# profile. Confirm which one was intended.
article = "<p style='text-align: center'><a href='https://github.com/nimaprgrmr'>Large-Scale Self- and Semi-Supervised Learning for Speech Translation</a></p>"

# NOTE(review): streaming / interactive / show_tips / enable_queue are not
# regular gr.Interface parameters in current Gradio releases. Depending on
# the installed version they may be ignored with a warning or rejected.
# They are kept unchanged here; confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=parse,
    inputs=input_,
    outputs=txtbox,
    title=title,
    description=description,
    article=article,
    streaming=True,
    interactive=True,
    analytics_enabled=False,
    show_tips=False,
    enable_queue=True,
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not block inside launch().
if __name__ == "__main__":
    demo.launch(share=True)