import gradio as gr
import torch
from transformers import (
    AutoModelForCTC,
    Wav2Vec2Processor,
    AutoProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
import librosa
from gradio_pdf import PDF
import os # For working with file paths
# Initialize device - will work on CPU if GPU not available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class ModelManager:
    def __init__(self):
        self.asr_models = {}

    def load_wav2vec2_base(self):
        model = AutoModelForCTC.from_pretrained("kabir259/w2v2-base_kabir").to(DEVICE)
        processor = Wav2Vec2Processor.from_pretrained("kabir259/w2v2-base_kabir")
        return model, processor

    def load_wav2vec2_bert(self):
        model = AutoModelForCTC.from_pretrained("Kabir259/w2v2-BERT_kabir").to(DEVICE)
        processor = AutoProcessor.from_pretrained("Kabir259/w2v2-BERT_kabir")
        return model, processor

    def load_whisper_small(self):
        model = WhisperForConditionalGeneration.from_pretrained("Kabir259/whisper-small_kabir").to(DEVICE)
        processor = WhisperProcessor.from_pretrained("Kabir259/whisper-small_kabir")
        model.generation_config.task = "transcribe"
        return model, processor

    def get_asr_model(self, model_name):
        if model_name not in self.asr_models:
            if model_name == "wav2vec2-base":
                self.asr_models[model_name] = self.load_wav2vec2_base()
            elif model_name == "wav2vec2-BERT":
                self.asr_models[model_name] = self.load_wav2vec2_bert()
            elif model_name == "whisper-small":
                self.asr_models[model_name] = self.load_whisper_small()
        return self.asr_models[model_name]
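
# ModelManager loads each checkpoint lazily on first request and caches the
# (model, processor) pair, so switching models in the dropdown only pays the
# load cost once per session. A minimal usage sketch (kept as a comment so the
# Space does not load models eagerly at startup; the second call is a cache hit):
#
#     manager = ModelManager()
#     model, processor = manager.get_asr_model("whisper-small")  # downloads/loads
#     model, processor = manager.get_asr_model("whisper-small")  # served from cache
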
def process_audio(audio_path, asr_model_name, model_manager):
    model, processor = model_manager.get_asr_model(asr_model_name)
    # Load the recording and resample it to the 16 kHz rate all three models expect
    audio, sr = librosa.load(audio_path, sr=16000)
    if asr_model_name == "wav2vec2-base":
        # wav2vec2-base consumes the raw waveform as input_values and is decoded with CTC
        input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values.to(DEVICE)
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
    elif asr_model_name == "wav2vec2-BERT":
        # wav2vec2-BERT takes pre-computed input_features (not the raw waveform) but is still decoded with CTC
        input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
        with torch.no_grad():
            logits = model(input_features).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
    else:  # whisper-small generates token ids autoregressively
        input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription
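
# Note: the three branches above differ only in how the processor packages the
# audio (input_values for wav2vec2-base, input_features for wav2vec2-BERT and
# Whisper) and in decoding (argmax over CTC logits vs. Whisper's generate()).
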
def process_pipeline(audio, asr_model_choice, model_manager):
    if audio is None:
        return "Please record some audio first."
    transcription = process_audio(audio, asr_model_choice, model_manager)
    return transcription
# Initialize the model manager
model_manager = ModelManager()
# Path to the bundled PDF (main.pdf sitting next to this script)
path_to_pdf = os.path.join(os.path.dirname(__file__), "main.pdf")
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Medical Speech Recognition System 🥼")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Record Audio",
                type="filepath"
            )
            asr_model_choice = gr.Dropdown(
                choices=["wav2vec2-base", "wav2vec2-BERT", "whisper-small"],
                value="wav2vec2-base",
                label="Select ASR Model"
            )
            submit_btn = gr.Button("Transcribe")
        with gr.Column():
            transcription_output = gr.Textbox(
                label="Transcribed Text",
                placeholder="Transcription will appear here..."
            )
    with gr.Row():
        gr.Markdown("## Benchmarking Wav2Vec 2.0, Whisper & Qwen2 for my Medical ASR + LLM pipeline! <br>[PDF](https://github.com/Kabir259/BenchASR-LLM4Med/blob/main/main.pdf), [GitHub](https://github.com/Kabir259/BenchASR-LLM4Med)")
        pdf_display = PDF(path_to_pdf)  # Display the pre-loaded PDF
    submit_btn.click(
        fn=lambda audio, asr_choice: process_pipeline(audio, asr_choice, model_manager),
        inputs=[audio_input, asr_model_choice],
        outputs=transcription_output
    )
if __name__ == "__main__":
    demo.launch(share=True)
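
# A Space like this also needs a requirements.txt next to this script. A minimal
# sketch, inferred only from the imports above (package set assumed, versions unpinned):
#
#     gradio
#     gradio_pdf
#     torch
#     transformers
#     librosa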