import os

import gradio as gr
import spaces
import torch
from nemo.core import ModelPT
# Check for GPU support and configure appropriately
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
# Model name on the Hugging Face Hub
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2",
]

# Cache loaded models so each one is downloaded and initialized only once
LOADED_MODELS = {}
def get_model(model_name):
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on {device}...")
        # Load model from the Hugging Face Hub using NeMo's ModelPT
        model = ModelPT.from_pretrained(model_name)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]
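# Optional warm-up (an assumption, not part of the original app): calling
# get_model(MODEL_NAMES[0]) once at import time trades a slower startup for a
# faster first request, since the download and initialization happen eagerly.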
@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces (uses the `spaces` import)
def transcribe_and_score(audio):
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    # transcribe() takes a list of audio file paths
    predictions = model.transcribe([audio])
    pred = predictions[0] if isinstance(predictions, list) else predictions
    # Depending on the NeMo version, predictions may be strings or Hypothesis
    # objects; prefer the .text attribute when it exists
    if not isinstance(pred, str):
        pred = getattr(pred, "text", str(pred))
    return pred.strip()
@spaces.GPU
def batch_transcribe(audio_files):
    if not audio_files:
        return []
    model = get_model(MODEL_NAMES[0])
    # gr.Files may yield plain paths or tempfile wrappers depending on the
    # Gradio version; normalize to filesystem paths before transcribing
    paths = [f.name if hasattr(f, "name") else f for f in audio_files]
    predictions = model.transcribe(paths)
    if isinstance(predictions, list):
        texts = [p if isinstance(p, str) else getattr(p, "text", str(p)) for p in predictions]
    else:
        texts = [str(predictions)]
    # Return as rows for a single-column dataframe
    return [[t.strip()] for t in texts]
with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown("""
    # EgypTalk-ASR-v2
    Upload an audio file. This app transcribes audio using EgypTalk-ASR-v2.
    """)
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Audio File")
            transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(transcribe_and_score, inputs=[audio_input], outputs=[pred_output])
    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(batch_transcribe, inputs=[audio_files], outputs=[preds_output])
# share=True is ignored on Hugging Face Spaces (the app is already publicly
# hosted), so launch without it
demo.launch()
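For a quick check of the model outside the Gradio UI, the same NeMo calls can be exercised directly. This is a minimal sketch, assuming nemo_toolkit[asr] is installed and that ModelPT.from_pretrained can resolve this Hub ID in your NeMo version; "sample.wav" is a placeholder path, not a file shipped with this Space.

import torch
from nemo.core import ModelPT

model = ModelPT.from_pretrained("NAMAA-Space/EgypTalk-ASR-v2")
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# transcribe() takes a list of audio file paths; depending on the NeMo
# version it returns strings or Hypothesis objects
preds = model.transcribe(["sample.wav"])  # placeholder path
first = preds[0] if isinstance(preds, list) else preds
print(first if isinstance(first, str) else getattr(first, "text", str(first)))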