import subprocess

import gradio as gr
import scipy.signal as sps
import soundfile as sf
import sox
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

def convert(inputfile, outfile):
    # Convert arbitrary input audio to 16 kHz mono 16-bit signed WAV via sox.
    sox_tfm = sox.Transformer()
    sox_tfm.set_output_format(
        file_type="wav", channels=1, encoding="signed-integer", rate=16000, bits=16
    )
    sox_tfm.build(inputfile, outfile)

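# Note: convert() is a sox-based alternative to resampler() below; nothing in
# this app calls it, it is kept for reference. A hypothetical invocation:
#   convert("recording.webm", "recording_16k.wav")
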
def read_file(wav):
    # Downmix a (sample_rate, samples) tuple to mono and resample to 16 kHz.
    sample_rate, signal = wav
    if signal.ndim > 1:           # guard: .mean(-1) on a 1-D mono array would
        signal = signal.mean(-1)  # collapse the whole signal to a scalar
    number_of_samples = round(len(signal) * 16000 / sample_rate)
    resampled_signal = sps.resample(signal, number_of_samples)
    return resampled_signal

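# read_file() is likewise unused here (the UI below passes file paths, not
# numpy audio). A minimal sketch of how it would be fed the (sample_rate,
# data) tuple that gr.Audio(type="numpy") yields; the array is dummy data:
#   import numpy as np
#   sr, data = 44100, np.zeros((44100, 2), dtype=np.int16)
#   mono_16k = read_file((sr, data))  # -> 16000 mono samples
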
def resampler(input_file_path, output_file_path):
    # Resample to 16 kHz mono 16-bit with ffmpeg; list-form arguments avoid
    # shell-quoting problems with paths that contain spaces.
    command = [
        "ffmpeg", "-hide_banner", "-loglevel", "panic",
        "-i", input_file_path,
        "-ar", "16000", "-ac", "1",
        "-bits_per_raw_sample", "16", "-vn",
        output_file_path,
    ]
    subprocess.call(command)

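# For reference, the shell command this builds is equivalent to:
#   ffmpeg -hide_banner -loglevel panic -i <in> -ar 16000 -ac 1 \
#          -bits_per_raw_sample 16 -vn <out>
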
def parse_transcription_with_lm(wav_file):
    # Decode with the language-model-backed processor: beam search over the
    # CTC logits instead of plain argmax.
    input_values = read_file_and_process(wav_file)
    with torch.no_grad():
        logits = model(**input_values).logits
    result = processor_with_LM.batch_decode(logits.cpu().numpy())
    transcription = result.text[0]
    return transcription

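# Wav2Vec2ProcessorWithLM.batch_decode runs a pyctcdecode beam search that
# rescores candidate transcriptions with a KenLM language model, which
# typically fixes spelling and word-boundary errors that greedy argmax makes.
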
def read_file_and_process(wav_file):
    # Resample the recorded file to 16 kHz on disk, then featurize it.
    filename = wav_file.split('.')[0]
    resampler(wav_file, filename + "16k.wav")
    speech, _ = sf.read(filename + "16k.wav")
    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
    return inputs

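# The returned BatchFeature carries a float tensor under "input_values" of
# shape (1, num_samples), which is unpacked into the model as model(**inputs).
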
def parse(wav_file, applyLM):
    if applyLM:
        return parse_transcription_with_lm(wav_file)
    else:
        return parse_transcription(wav_file)

def parse_transcription(wav_file):
    # Greedy CTC decoding: take the argmax token id at every frame.
    input_values = read_file_and_process(wav_file)
    with torch.no_grad():
        logits = model(**input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
    return transcription

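# CTC emits one label per ~20 ms frame, so consecutive duplicate ids and the
# blank/pad token must be collapsed; processor.decode handles both, and
# skip_special_tokens drops any remaining special symbols.
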
# Load the Vakyansh Hindi wav2vec2 checkpoint and both processors (plain CTC
# and LM-backed) once at startup.
model_id = "Harveenchadha/vakyansh-wav2vec2-hindi-him-4200"
processor = Wav2Vec2Processor.from_pretrained(model_id)
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

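# Quick smoke test before launching the UI (hypothetical local file):
#   print(parse("sample_hi.wav", applyLM=False))
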
input_ = gr.Audio(source="microphone", type="filepath")
txtbox = gr.Textbox(label="Output from model will appear here:", lines=5)
chkbox = gr.Checkbox(label="Apply LM", value=False)

gr.Interface(
    parse,
    inputs=[input_, chkbox],
    outputs=txtbox,
    streaming=True,
    interactive=True,
    analytics_enabled=False,
    show_tips=False,
    enable_queue=True,
).launch(inline=False)