import gradio as gr
import torch
from transformers import (
    AutoModelForCTC,
    Wav2Vec2Processor,
    AutoProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
import librosa
from gradio_pdf import PDF
import os  # For working with file paths

# Initialize device - will work on CPU if GPU not available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

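# ModelManager lazily loads each fine-tuned ASR checkpoint the first time it is
# requested and keeps it cached, so switching models in the UI never reloads weights.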
class ModelManager:
    def __init__(self):
        self.asr_models = {}
        
    def load_wav2vec2_base(self):
        # Fine-tuned wav2vec 2.0 base CTC model with its matching processor
        model = AutoModelForCTC.from_pretrained("kabir259/w2v2-base_kabir").to(DEVICE)
        processor = Wav2Vec2Processor.from_pretrained("kabir259/w2v2-base_kabir")
        return model, processor

    def load_wav2vec2_bert(self):
        # Fine-tuned Wav2Vec2-BERT CTC model; AutoProcessor resolves the matching feature extractor and tokenizer
        model = AutoModelForCTC.from_pretrained("Kabir259/w2v2-BERT_kabir").to(DEVICE)
        processor = AutoProcessor.from_pretrained("Kabir259/w2v2-BERT_kabir")
        return model, processor

    def load_whisper_small(self):
        # Fine-tuned Whisper-small seq2seq model; pin the task to transcription (not translation)
        model = WhisperForConditionalGeneration.from_pretrained("Kabir259/whisper-small_kabir").to(DEVICE)
        processor = WhisperProcessor.from_pretrained("Kabir259/whisper-small_kabir")
        model.generation_config.task = "transcribe"
        return model, processor

    def get_asr_model(self, model_name):
        if model_name not in self.asr_models:
            if model_name == "wav2vec2-base":
                self.asr_models[model_name] = self.load_wav2vec2_base()
            elif model_name == "wav2vec2-BERT":
                self.asr_models[model_name] = self.load_wav2vec2_bert()
            elif model_name == "whisper-small":
                self.asr_models[model_name] = self.load_whisper_small()
        return self.asr_models[model_name]

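# Transcribe a single audio file with the selected model. The CTC models
# (wav2vec2-base, wav2vec2-BERT) take an argmax over per-frame logits and decode;
# Whisper generates token IDs autoregressively and decodes with special tokens stripped.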
def process_audio(audio_path, asr_model_name, model_manager):
    model, processor = model_manager.get_asr_model(asr_model_name)
    
    # Load the audio and resample to 16 kHz, the rate all three models expect
    audio, _ = librosa.load(audio_path, sr=16000)
    
    if asr_model_name == "wav2vec2-base":
        # wav2vec2-base consumes the raw waveform as input_values
        input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values.to(DEVICE)
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        
    elif asr_model_name == "wav2vec2-BERT":
        # Wav2Vec2-BERT's processor returns log-mel input_features rather than raw input_values
        input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
        with torch.no_grad():
            logits = model(input_features).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
    
    else:  # whisper-small: autoregressive generation instead of CTC decoding
        input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    
    return transcription

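# Gradio callback wrapper: guard against an empty recording before transcribing.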
def process_pipeline(audio, asr_model_choice, model_manager):
    if audio is None:
        return "Please record some audio first."
    transcription = process_audio(audio, asr_model_choice, model_manager)
    return transcription

# Initialize the model manager
model_manager = ModelManager()

# Resolve main.pdf (the benchmarking write-up) relative to this script's directory
path_to_pdf = os.path.join(os.path.dirname(__file__), "main.pdf")

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Medical Speech Recognition System 🥼")
    
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Record Audio",
                type="filepath"
            )
            asr_model_choice = gr.Dropdown(
                choices=["wav2vec2-base", "wav2vec2-BERT", "whisper-small"],
                value="wav2vec2-base",
                label="Select ASR Model"
            )
            submit_btn = gr.Button("Transcribe")
        
        with gr.Column():
            transcription_output = gr.Textbox(
                label="Transcribed Text",
                placeholder="Transcription will appear here..."
            )
    
    with gr.Row():
        gr.Markdown("## Benchmarking Wav2Vec 2.0, Whisper & Qwen2 for my Medical ASR + LLM pipeline!  <br>[PDF](https://github.com/Kabir259/BenchASR-LLM4Med/blob/main/main.pdf), [GitHub](https://github.com/Kabir259/BenchASR-LLM4Med)")

        pdf_display = PDF(path_to_pdf)  # Display the pre-loaded PDF

    submit_btn.click(
        fn=lambda audio, asr_choice: process_pipeline(audio, asr_choice, model_manager),
        inputs=[audio_input, asr_model_choice],
        outputs=transcription_output
    )

if __name__ == "__main__":
    demo.launch(share=True)