|
import gradio as gr |
|
from nemo.collections.asr.models import ASRModel |
|
import torch |
|
import os |
|
import spaces |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Place a one-element tensor on the chosen device so the startup log reports
# the device tensors actually end up on, not just the requested name.
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
|
|
|
|
|
# Hugging Face Hub repository ids hosting the ASR checkpoints. The request
# handlers in this file load the first entry.
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2",
]

# Cache of already-restored models, keyed by repository id and filled lazily
# by get_model() so each checkpoint is loaded at most once per process.
LOADED_MODELS = {}

# File name of the .nemo checkpoint fetched from the repository.
NEMO_FILENAME = "asr-egyptian-nemo-v2.0.nemo"
|
|
|
def get_model(model_name: str):
    """Return the ASR model for ``model_name``, loading it on first use.

    On a cache miss the checkpoint file ``NEMO_FILENAME`` is downloaded from
    the Hugging Face Hub repository ``model_name``, restored with
    ``ASRModel.restore_from``, moved to ``device`` and put in eval mode. The
    instance is then stored in ``LOADED_MODELS`` so later calls reuse it.

    Args:
        model_name: Hub repository id that hosts the ``.nemo`` checkpoint.

    Returns:
        The cached model instance for ``model_name``.
    """
    # Fast path: the model was already restored by an earlier call.
    if model_name in LOADED_MODELS:
        return LOADED_MODELS[model_name]

    print(f"Loading model {model_name} on {device}...")

    nemo_path = hf_hub_download(repo_id=model_name, filename=NEMO_FILENAME)
    model = ASRModel.restore_from(nemo_path)
    model = model.to(device)
    model.eval()  # inference only

    LOADED_MODELS[model_name] = model
    print(f"Model {model_name} loaded successfully!")
    return model
|
|
|
def _as_prediction_list(raw_predictions):
    """Normalise the return value of ``model.transcribe`` to a list.

    A list is returned unchanged and any other single value is wrapped in a
    one-element list, so callers can always iterate or index the result.
    """
    # NOTE(review): some NeMo releases appear to return a
    # ``(best_hypotheses, all_hypotheses)`` tuple for transducer models;
    # keep only the best hypotheses in that case. Confirm against the NeMo
    # version this app is pinned to.
    if (
        isinstance(raw_predictions, tuple)
        and raw_predictions
        and isinstance(raw_predictions[0], list)
    ):
        return raw_predictions[0]
    if isinstance(raw_predictions, list):
        return raw_predictions
    return [raw_predictions]


def _prediction_text(prediction):
    """Return the whitespace-stripped transcript held by one prediction.

    Accepts an object exposing a ``text`` attribute, a dict with a ``"text"``
    key, or a plain string; anything else is converted with ``str()``.
    A ``text`` value of ``None`` yields an empty string.
    """
    if hasattr(prediction, "text"):
        text = prediction.text
    elif isinstance(prediction, dict) and "text" in prediction:
        text = prediction["text"]
    else:
        text = prediction
    return "" if text is None else str(text).strip()


@spaces.GPU(duration=120)
def transcribe_and_score(audio):
    """Transcribe a single audio file.

    Args:
        audio: Path of the audio file to transcribe, or ``None`` when the
            user submitted nothing.

    Returns:
        The stripped transcript, or an empty string when ``audio`` is
        ``None`` or the model returns no prediction.
    """
    if audio is None:
        return ""

    model = get_model(MODEL_NAMES[0])
    predictions = _as_prediction_list(model.transcribe([audio]))

    # Guard the index: an empty result would otherwise raise IndexError.
    if not predictions:
        return ""
    return _prediction_text(predictions[0])


@spaces.GPU(duration=120)
def batch_transcribe(audio_files):
    """Transcribe several audio files in one model call.

    Args:
        audio_files: Audio file paths to transcribe; may be empty or
            ``None``.

    Returns:
        One single-cell row ``[transcript]`` per prediction, in the order
        the model returned them, shaped for the one-column results table.
        An empty list when no files were supplied.
    """
    if not audio_files:
        return []

    model = get_model(MODEL_NAMES[0])
    predictions = _as_prediction_list(model.transcribe(audio_files))
    return [[_prediction_text(prediction)] for prediction in predictions]
|
|
|
# Two-tab UI: one tab transcribes a single recording or upload, the other
# transcribes a batch of uploaded files into a one-column table.
# NOTE(review): the original indentation of this layout was lost. The nesting
# below assumes the Row wraps only the audio input -- confirm against the
# deployed app.
with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    # The Markdown text is kept at column 0 so it does not depend on any
    # dedenting to avoid being rendered as an indented code block.
    gr.Markdown("""

# EgypTalk-ASR-v2

Upload or record an audio file. This app transcribes audio using EgypTalk-ASR-v2.

""")

    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio",
            )
        transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(
            transcribe_and_score,
            inputs=[audio_input],
            outputs=[pred_output],
        )

    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(
            batch_transcribe,
            inputs=[audio_files],
            outputs=[preds_output],
        )


# Start the web server; share=True additionally requests a public share link.
demo.launch(share=True)