Spaces:
Sleeping
Sleeping
import functools
import os
import threading
import time

import gradio as gr
import torch
import torchaudio
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files
from torchaudio.transforms import Resample
# Beam-search decoding hyperparameters, forwarded to torchaudio's ctc_decoder
# as its ``lm_weight`` and ``word_score`` arguments.
LM_WEIGHT: float = 1.23  # weight given to the language-model score
WORD_SCORE: float = -0.26  # per-word insertion score
def get_featurizer():
    """Create the mel-spectrogram front end applied to 16 kHz waveforms.

    Returns:
        A ``torchaudio.transforms.MelSpectrogram`` with a 400-sample window and
        FFT (25 ms at 16 kHz), a 160-sample hop (10 ms) and 80 mel bins.
    """
    spectrogram_config = dict(
        sample_rate=16000,
        n_fft=400,
        win_length=400,
        hop_length=160,
        n_mels=80,
    )
    return torchaudio.transforms.MelSpectrogram(**spectrogram_config)
def preprocess_audio(audio_file, featurizer, target_sample_rate=16000, wait_timeout=3.0):
    """Load an audio file, resample it and turn it into model input features.

    The file may not have been fully saved yet when this is called, so its
    existence is polled for up to ``wait_timeout`` seconds before loading.

    Args:
        audio_file: Path to the audio file to load.
        featurizer: Callable mapping a waveform tensor to features, e.g. the
            transform returned by ``get_featurizer()``.
        target_sample_rate: Sample rate (Hz) the featurizer expects.
        wait_timeout: Maximum number of seconds to wait for ``audio_file``
            to appear on disk.

    Returns:
        The featurizer output with its last two axes swapped
        (``permute(0, 2, 1)``). The leading axis always has size 1 because
        multi-channel audio is down-mixed to mono first.

    Raises:
        ValueError: If the file cannot be loaded, resampled or featurized;
            the original exception is chained as ``__cause__``.
    """
    try:
        # Monotonic deadline instead of summing 0.1 s float steps.
        deadline = time.monotonic() + wait_timeout
        while not os.path.exists(audio_file) and time.monotonic() < deadline:
            time.sleep(0.1)

        waveform, sample_rate = torchaudio.load(audio_file)

        # torchaudio.load returns (channels, frames). The channel axis becomes
        # the leading (batch) axis of the features, and only the first
        # hypothesis is kept when decoding. So average stereo and other
        # multi-channel input into one mono utterance instead of decoding
        # each channel separately.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sample_rate != target_sample_rate:
            waveform = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(waveform)

        # Swap the last two axes (n_mels, time) -> (time, n_mels) so the
        # feature dimension comes last.
        return featurizer(waveform).permute(0, 2, 1)
    except Exception as e:
        raise ValueError(f"Error in preprocessing audio: {e}") from e
# Guards calls on the shared, cached decoder. The decoder presumably keeps
# per-decode internal state (not verified here), so concurrent requests are
# serialized to be safe.
_DECODER_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _build_decoder(lexicon, tokens, lm, lm_weight, word_score):
    """Construct the CTC beam-search decoder once per distinct configuration.

    ``maxsize=1`` because each decoder may hold a whole n-gram language model
    in memory. ``tokens`` is a tuple (hashable cache key) or a token-file path.
    """
    return ctc_decoder(
        lexicon=lexicon,
        tokens=tokens if isinstance(tokens, str) else list(tokens),
        lm=lm,
        nbest=1,
        beam_size=100,
        beam_threshold=50,
        beam_size_token=25,
        lm_weight=lm_weight,
        word_score=word_score,
    )


def decode_emission(emission, tokens, files):
    """Beam-search decode model emissions into a transcript string.

    The decoder, including its lexicon and language model, is built on the
    first call and reused afterwards. Rebuilding it reloads the LM file on
    every request.

    Args:
        emission: Model output, passed unchanged to torchaudio's CTC decoder.
        tokens: List of token strings (or a token-file path) for the decoder.
        files: Object exposing ``lexicon`` and ``lm`` paths, e.g. the result of
            ``download_pretrained_files``.

    Returns:
        The words of the best hypothesis for the first batch item, joined by
        single spaces.

    Raises:
        ValueError: If building the decoder or decoding fails; the original
            exception is chained as ``__cause__``.
    """
    try:
        tokens_key = tokens if isinstance(tokens, str) else tuple(tokens)
        beam_search_decoder = _build_decoder(
            files.lexicon, tokens_key, files.lm, LM_WEIGHT, WORD_SCORE
        )
        with _DECODER_LOCK:
            beam_search_result = beam_search_decoder(emission)
        return " ".join(beam_search_result[0][0].words).strip()
    except Exception as e:
        raise ValueError(f"Error in decoding: {e}") from e
def transcribe(audio_file, model, featurizer, tokens, files):
    """Run preprocessing, the acoustic model and decoding on one audio file.

    Failures are reported as a returned string, not raised, so the text
    output of the UI can show them directly.

    Args:
        audio_file: Path to the recorded audio, or ``None`` when nothing was
            recorded.
        model: Callable acoustic model applied to the preprocessed features.
        featurizer: Feature transform forwarded to ``preprocess_audio``.
        tokens: Token list forwarded to ``decode_emission``.
        files: Decoder resource files forwarded to ``decode_emission``.

    Returns:
        The transcript, or a message starting with
        ``"Error processing audio:"``.
    """
    if audio_file is None:
        # The Gradio Audio input delivers None when no recording was made.
        return "Error processing audio: no audio received. Please record something first."
    try:
        features = preprocess_audio(audio_file, featurizer)
        # Pure inference: skip autograd graph construction to save memory/time.
        with torch.inference_mode():
            emission = model(features)
        return decode_emission(emission, tokens, files)
    except Exception as e:
        return f"Error processing audio: {e}"
def launch_app(model_path, token_path="tokens.txt", share=False):
    """Load the model assets and serve a microphone transcription UI.

    Args:
        model_path: Path to a TorchScript model file.
        token_path: Text file listing the decoder tokens, one per line.
        share: Forwarded to ``interface.launch(share=...)``.
    """
    # map_location lets a model that was serialized on GPU load on a CPU-only host.
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()

    with open(token_path, "r", encoding="utf-8") as f:
        tokens = f.read().splitlines()

    files = download_pretrained_files("librispeech-4-gram")
    featurizer = get_featurizer()

    def gradio_transcribe(audio_file):
        # Close over the loaded assets so Gradio only has to pass the audio path.
        return transcribe(audio_file, model, featurizer, tokens, files)

    interface = gr.Interface(
        fn=gradio_transcribe,
        inputs=gr.Audio(sources="microphone", type="filepath", label="Speak into the microphone"),
        outputs="text",
        title="Conformer-Small ASR Model",
        description="""<b>Trained on:</b> Mozilla Corpus, Personal Recordings, and LibriSpeech — 2900 hrs of audio data.<br>
<b>Training Script and Experiment Results</b> available <a href="https://github.com/LuluW8071/Conformer" target="_blank">here</a>""",
    )
    interface.launch(share=share)
if __name__ == "__main__":
    model_path = "optimized_model.pt"
    token_path = "tokens.txt"
    share = False
    try:
        launch_app(model_path, token_path, share)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Fatal error: {e}") from e