"""Gradio demo app for ASRwanda, a Kinyarwanda speech recognition model."""

import os

import gradio as gr
import spaces
import torch
import torchaudio
from transformers import AutoModelForCTC, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Sample rate handed to the processor; audio at any other rate is resampled.
TARGET_SAMPLE_RATE = 16000

# File extensions picked up from the examples directory.
AUDIO_EXTENSIONS = (".wav", ".mp3", ".ogg")

# Load examples. Each entry is a one-element list because the examples widget
# is bound to a single input component (the audio input).
examples = []
examples_dir = "examples"
if os.path.isdir(examples_dir):
    # Sorted so the examples appear in a stable order across restarts.
    examples = [
        [os.path.join(examples_dir, filename)]
        for filename in sorted(os.listdir(examples_dir))
        if filename.endswith(AUDIO_EXTENSIONS)
    ]

# Load model and processor
MODEL_PATH = "badrex/w2v-bert-2.0-kinyarwanda-asr"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForCTC.from_pretrained(MODEL_PATH)

# Move the model to the selected device (the processor's outputs are moved
# per request inside process_audio).
model = model.to(device)

# Introductory text shown under the page title. Kept at column 0 so Markdown
# does not render it as an indented code block.
INTRO_MARKDOWN = """

Developed with ❤ by Badr al-Absi ☕


Muraho 👋🏼

This is a demo for ASRwanda, a Transformer-based automatic speech recognition (ASR) system for Kinyarwanda language. The underlying ASR model was trained on 1000 hours of transcribed speech provided by Digital Umuganda as part of the Kinyarwanda ASR hackathon on Kaggle.

Simply upload an audio file 📤 or record yourself speaking 🎙️⏺️ to try out the model!

"""


@spaces.GPU()
def process_audio(audio_path):
    """Transcribe an audio file and return the recognized text.

    Args:
        audio_path: Path to the audio file to be transcribed. May be empty or
            None when the user has not provided any audio.

    Returns:
        String containing the transcribed text from the audio file, or an
        error message if the audio file is missing.
    """
    if not audio_path:
        return "Please upload an audio file."

    # Load the waveform together with its native sample rate.
    audio_array, sample_rate = torchaudio.load(audio_path)

    # Resample only when the file is not already at the target rate.
    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate,
            new_freq=TARGET_SAMPLE_RATE,
        )
        audio_array = resampler(audio_array)

    inputs = processor(
        audio_array,
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
    )
    # The model lives on `device`, so its inputs must be moved there as well.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: pick the highest-scoring token at every frame.
    outputs = torch.argmax(logits, dim=-1)

    decoded_outputs = processor.batch_decode(
        outputs,
        skip_special_tokens=True,
    )

    return decoded_outputs[0].strip()


# Define Gradio interface
with gr.Blocks(title="ASRwanda") as demo:
    gr.Markdown("# ASRwanda 🎙️ Speech Recognition for Kinyarwanda Language 🇷🇼")
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            submit_btn = gr.Button("Transcribe Audio", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Text Transcription", lines=10)

    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=output_text,
    )

    # Only render the examples panel when example files were found:
    # gr.Examples does not accept None for `examples`, so building it
    # unconditionally would break startup when the directory is missing
    # or contains no audio files.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[audio_input],
        )

# Launch the app
if __name__ == "__main__":
    demo.queue().launch()