File size: 3,241 Bytes
bd7af00
 
dcae561
 
 
 
 
 
 
229460f
dcae561
 
 
 
 
229460f
 
 
 
 
 
 
 
dcae561
2f94472
 
 
 
 
dcae561
182ac63
 
 
 
 
 
2f94472
 
 
dcae561
 
 
 
 
 
 
 
 
 
 
 
 
182ac63
 
dcae561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229460f
dcae561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182ac63
 
dcae561
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os 
import time
import torch
import torchaudio
import gradio as gr

from torchaudio.transforms import Resample
from torchaudio.models.decoder import download_pretrained_files, ctc_decoder



# Constants for decoding.
# Both are forwarded to torchaudio's ``ctc_decoder`` in ``decode_emission``.
LM_WEIGHT = 1.23  # passed as the decoder's ``lm_weight`` argument
WORD_SCORE = -0.26  # passed as the decoder's ``word_score`` argument

def get_featurizer():
    """Build the mel-spectrogram transform used to featurize audio.

    Returns:
        A ``torchaudio.transforms.MelSpectrogram`` configured with a
        16000 Hz sample rate, a 400-sample FFT and window, a 160-sample
        hop, and 80 mel bins.
    """
    mel_settings = {
        "sample_rate": 16000,
        "n_fft": 400,
        "win_length": 400,
        "hop_length": 160,
        "n_mels": 80,
    }
    return torchaudio.transforms.MelSpectrogram(**mel_settings)


def preprocess_audio(audio_file, featurizer, target_sample_rate=16000, wait_timeout=3.0):
    """
    Preprocess the audio: load, downmix to mono, resample, and extract features.

    Args:
        audio_file: Path of the audio file to load.
        featurizer: Callable applied to the waveform tensor (for example the
            transform returned by ``get_featurizer``).
        target_sample_rate: Sample rate the waveform is resampled to when the
            file's own rate differs.
        wait_timeout: Maximum number of seconds to wait for ``audio_file`` to
            appear on disk before attempting to load it.

    Returns:
        The featurizer output with its last two dimensions swapped via
        ``permute(0, 2, 1)``.

    Raises:
        ValueError: If no audio path is given, or if loading, resampling or
            feature extraction fails; the original exception is attached as
            ``__cause__``.
    """
    try:
        # A missing recording arrives as None/""; fail fast instead of
        # polling the filesystem for a path that can never exist.
        if not audio_file:
            raise ValueError("No audio file was provided")

        # Wait for file to be saved (bounded by a monotonic deadline so the
        # limit does not drift with accumulated float error).
        deadline = time.monotonic() + wait_timeout
        while not os.path.exists(audio_file) and time.monotonic() < deadline:
            time.sleep(0.1)

        waveform, sample_rate = torchaudio.load(audio_file)

        # torchaudio.load returns (channels, frames) by default. Average the
        # channels so a stereo recording is treated as one utterance rather
        # than as a batch of two.
        if waveform.dim() == 2 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sample_rate != target_sample_rate:
            waveform = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(waveform)
        return featurizer(waveform).permute(0, 2, 1)
    except Exception as e:
        raise ValueError(f"Error in preprocessing audio: {e}") from e


def decode_emission(emission, tokens, files):
    """Decode model emissions into text with a lexicon/LM beam search.

    Args:
        emission: Model output handed directly to the decoder.
        tokens: Token set given to ``ctc_decoder`` (``launch_app`` passes the
            lines of the token file).
        files: Object exposing ``lexicon`` and ``lm`` attributes, such as the
            result of ``download_pretrained_files``.

    Returns:
        The words of ``beam_search_result[0][0]`` joined by single spaces and
        stripped. NOTE(review): per the torchaudio docs this is presumably
        the best hypothesis for the first item in the batch; confirm.

    Raises:
        ValueError: If building the decoder or decoding fails; the original
            exception is attached as ``__cause__``.
    """
    try:
        # NOTE: the decoder is constructed on every call.
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=tokens,
            lm=files.lm,
            nbest=1,
            beam_size=100,
            beam_threshold=50,
            beam_size_token=25,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )
        beam_search_result = beam_search_decoder(emission)
        return " ".join(beam_search_result[0][0].words).strip()
    except Exception as e:
        # Chain explicitly so the traceback names the root cause.
        raise ValueError(f"Error in decoding: {e}") from e


def transcribe(audio_file, model, featurizer, tokens, files):
    """Run the full pipeline: preprocess audio, run the model, decode text.

    Args:
        audio_file: Path of the audio file to transcribe.
        model: Callable that maps the preprocessed features to emissions.
        featurizer: Feature extractor forwarded to ``preprocess_audio``.
        tokens: Token set forwarded to ``decode_emission``.
        files: Decoder resources forwarded to ``decode_emission``.

    Returns:
        The decoded transcript, or a string starting with
        ``"Error processing audio: "`` if any stage raises. This function
        reports failures through its return value instead of raising.
    """
    try:
        features = preprocess_audio(audio_file, featurizer)
        # Inference only: skip autograd bookkeeping during the forward pass.
        with torch.no_grad():
            emission = model(features)
        return decode_emission(emission, tokens, files)
    except Exception as e:
        return f"Error processing audio: {e}"


def launch_app(model_path, token_path="tokens.txt", share=False):
    """Load the model and decoder resources, then start the Gradio demo.

    Args:
        model_path: Path to a TorchScript model loadable by ``torch.jit.load``.
        token_path: Path to a text file with one token per line.
        share: Forwarded to ``gr.Interface.launch(share=...)``.
    """
    # map_location lets a model that was saved on another device load on a
    # CPU-only host; without it the load itself can fail before .to('cpu').
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval().to('cpu')

    # Explicit encoding keeps token parsing independent of the host locale.
    with open(token_path, 'r', encoding="utf-8") as f:
        tokens = f.read().splitlines()

    files = download_pretrained_files("librispeech-4-gram")
    featurizer = get_featurizer()

    def gradio_transcribe(audio_file):
        # Bind the loaded resources so Gradio only has to supply the audio path.
        return transcribe(audio_file, model, featurizer, tokens, files)

    interface = gr.Interface(
        fn=gradio_transcribe,
        inputs=gr.Audio(sources="microphone", type="filepath", label="Speak into the microphone"),
        outputs="text",
        title="Conformer-Small ASR Model",
        description="""<b>Trained on:</b> Mozilla Corpus, Personal Recordings, and LibriSpeech — 2900 hrs of audio data.<br>
                       <b>Training Script and Experiment Results</b> available <a href="https://github.com/LuluW8071/Conformer" target="_blank">here</a>""",
    )

    interface.launch(share=share)


if __name__ == "__main__":
    # Script entry point: launch the demo with the default model and token files.
    try:
        model_path = "optimized_model.pt"
        token_path = "tokens.txt"
        share = False
        launch_app(model_path, token_path, share)
    except Exception as e:
        # Chain explicitly so the traceback shows the original failure.
        raise ValueError(f"Fatal error: {e}") from e