|
|
import os |
|
|
import torchaudio |
|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
from transformers import AutoProcessor, AutoModelForCTC |
|
|
|
|
|
# Prefer a CUDA GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")
|
|
|
|
|
|
|
|
# Gather bundled audio clips (if any) to populate the Gradio examples widget.
examples = []

examples_dir = "examples"
if os.path.exists(examples_dir):
    audio_suffixes = (".wav", ".mp3", ".ogg")
    examples = [
        [os.path.join(examples_dir, name)]
        for name in os.listdir(examples_dir)
        if name.endswith(audio_suffixes)
    ]
|
|
|
|
|
|
|
|
# Hugging Face Hub ID of a wav2vec2-BERT 2.0 model fine-tuned for
# Kinyarwanda ASR with a CTC head.
MODEL_PATH = "badrex/w2v-bert-2.0-kinyarwanda-asr"

# Processor bundles the feature extractor (audio -> input features) and the
# tokenizer used to decode CTC output ids back into text.
processor = AutoProcessor.from_pretrained(MODEL_PATH)

model = AutoModelForCTC.from_pretrained(MODEL_PATH)

# Move the model to the selected device once at startup so per-request
# inference only has to move the input tensors.
model = model.to(device)
|
|
|
|
|
@spaces.GPU()
def process_audio(audio_path):
    """Transcribe an audio file and return the recognized text.

    Args:
        audio_path: Path to the audio file to be transcribed.

    Returns:
        String containing the transcribed text from the audio file, or an
        error message if the audio file is missing.
    """
    if not audio_path:
        return "Please upload an audio file."

    # torchaudio.load returns a (channels, num_samples) float tensor
    # together with the file's native sample rate.
    audio_array, sample_rate = torchaudio.load(audio_path)

    # Downmix multi-channel recordings to mono. Without this, each channel
    # would be treated as a separate batch item by the feature extractor and
    # only the first channel's transcription would be returned.
    if audio_array.shape[0] > 1:
        audio_array = audio_array.mean(dim=0, keepdim=True)

    # The model was trained on 16 kHz audio; resample anything else.
    if sample_rate != 16000:
        audio_array = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio_array)

    # The feature extractor expects a single 1-D waveform, so drop the
    # channel dimension before preprocessing.
    inputs = processor(audio_array.squeeze(0), sampling_rate=16000, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy CTC decoding: take the most likely token at each frame; the
    # tokenizer collapses repeats and removes blanks during batch_decode.
    predicted_ids = torch.argmax(logits, dim=-1)

    decoded_outputs = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )

    return decoded_outputs[0].strip()
|
|
|
|
|
|
|
|
|
|
|
# Build the Gradio UI: audio input + transcribe button on the left,
# transcription textbox on the right.
with gr.Blocks(title="ASRwanda") as demo:
    gr.Markdown("# ASRwanda ποΈ Speech Recognition for Kinyarwanda Language π·πΌ")
    gr.Markdown("""
    <div class="centered-content">
        <div>
            <p>
            Developed with β€ by <a href="https://badrex.github.io/" style="color: #2563eb;">Badr al-Absi</a> β
            </p>
            <br>
            <p style="font-size: 15px; line-height: 1.8;">
                Muraho ππΌ
                <br><br>
                This is a demo for ASRwanda, a Transformer-based automatic speech recognition (ASR) system for Kinyarwanda language.
                The underlying ASR model was trained on 1000 hours of transcribed speech provided by
                <a href="https://digitalumuganda.com/" style="color: #2563eb;">Digital Umuganda</a> as part of the Kinyarwanda
                <a href="https://www.kaggle.com/competitions/kinyarwanda-automatic-speech-recognition-track-b" style="color: #2563eb;"> ASR hackathon</a> on Kaggle.
                <br><br>
                Simply <strong>upload an audio file</strong> π€ or <strong>record yourself speaking</strong> ποΈβΊοΈ to try out the model!
            </p>
        </div>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            # type="filepath" hands process_audio a path on disk rather than
            # raw audio data.
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            submit_btn = gr.Button("Transcribe Audio", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Text Transcription", lines=10)

    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=output_text
    )

    # Only render the examples widget when example files were found:
    # gr.Examples raises a ValueError if its `examples` argument is None.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[audio_input],
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue (serializes concurrent inference requests)
    # and start the Gradio server.
    demo.queue().launch()
|
|
|