
Wav2Vec2-XLS-R-300M fine-tuned for Arabic

Wav2Vec2-XLS-R-300M fine-tuned for Arabic on the Common Voice 11 dataset. When using the model, make sure the audio input is sampled at 16 kHz.
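If an audio file has a different sampling rate, it can be resampled first, for example with torchaudio (a minimal sketch; example.wav is a placeholder path):

import torchaudio

waveform, sr = torchaudio.load("example.wav")  # placeholder path
if sr != 16000:
    # resample to the 16 kHz rate the model expects
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)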

Evaluation

The model can be used directly (without a language model) as follows:

import torch
import torchaudio
import evaluate
from datasets import load_dataset, Audio
from transformers import (Wav2Vec2ForCTC, Wav2Vec2Processor,
                          Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer)
import string, re

wer_metric = evaluate.load("wer")  # word error rate metric used for evaluation below

punctuation = '''`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ''' + string.punctuation
arabic_diacritics = re.compile("""
                             ّ    | # Shadda
                             َ    | # Fatha
                             ً    | # Tanwin Fath
                             ُ    | # Damma
                             ٌ    | # Tanwin Damm
                             ِ    | # Kasra
                             ٍ    | # Tanwin Kasr
                             ْ    | # Sukun
                             ـ     # Tatwil/Kashid
                         """, re.VERBOSE)


def process_text(text):
    translator = str.maketrans('', '', punctuation)
    text = text.translate(translator)
    text = re.sub("[0123456789]", '', text)

    # remove Tashkeel
    text = re.sub(arabic_diacritics, '', text)

    # normalize common Arabic letter variants
    text = re.sub("[إأآا]", "ا", text)
    text = re.sub("ى", "ي", text)
    text = re.sub("ؤ", "ء", text)
    text = re.sub("ئ", "ء", text)
    text = re.sub("ة", "ه", text)
    text = re.sub("گ", "ك", text)

    # collapse extra whitespace and return the normalized text
    return ' '.join(text.split())
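For example (illustrative input):

print(process_text("الْعَرَبِيَّة!"))  # -> "العربيه": punctuation and diacritics stripped, ة mapped to ه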


device = "cuda" if torch.cuda.is_available() else "cpu"

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                             do_normalize=True, return_attention_mask=True)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("aitor-alvarez/wav2vec2-xls-r-300m-ar", unk_token="[UNK]",
                                                 pad_token="[PAD]", word_delimiter_token="|")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = Wav2Vec2ForCTC.from_pretrained("aitor-alvarez/wav2vec2-xls-r-300m-ar").to(device)


def prepare_dataset(batch):
  audio = batch["audio"]
  
  batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
  batch["input_length"] = len(batch["input_values"])
  
  # encode the reference transcript into label ids
  with processor.as_target_processor():
    batch["labels"] = processor(batch["sentence"]).input_ids
  return batch
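Note: processor.as_target_processor() is deprecated in recent versions of transformers. If it warns or errors, the same labels can be produced directly (a sketch, assuming a recent transformers release):

# inside prepare_dataset, replacing the as_target_processor block:
batch["labels"] = processor(text=batch["sentence"]).input_ids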


def remove_ar_special_characters(batch):
    batch["sentence"] = process_text(batch["sentence"]).lower()
    return batch

speech_test = load_dataset("mozilla-foundation/common_voice_11_0", "ar", split="test")
speech_test = speech_test.cast_column("audio", Audio(sampling_rate=16_000))
speech_test = speech_test.map(remove_ar_special_characters)
speech_test = speech_test.map(prepare_dataset, remove_columns=speech_test.column_names)


def get_predictions(batch):
  with torch.no_grad():
    input_dict = processor(batch["input_values"], sampling_rate=16000, return_tensors="pt", padding=True)
    logits = model(input_dict.input_values.to(device), attention_mask=input_dict.attention_mask.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_txt"] = processor.batch_decode(pred_ids)[0]
    # group_tokens=False stops CTC decoding from collapsing repeated characters in the reference
    batch["txt"] = processor.decode(batch["labels"], group_tokens=False)
    return batch

results = speech_test.map(get_predictions)
print("Test WER: {:.2f} %".format(100 * wer_metric.compute(predictions=results["pred_txt"], references=results["txt"])))

Test WER: 20.71 %
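To transcribe a single audio file instead of a dataset split, a minimal sketch follows (audio.wav is a placeholder path, assumed to be 16 kHz mono; resample as shown above if needed). Loading the processor with from_pretrained assumes the processor files are present in the repository; otherwise build it from the feature extractor and tokenizer as above:

import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("aitor-alvarez/wav2vec2-xls-r-300m-ar")
model = Wav2Vec2ForCTC.from_pretrained("aitor-alvarez/wav2vec2-xls-r-300m-ar")

waveform, sr = torchaudio.load("audio.wav")  # placeholder path, assumed 16 kHz mono
inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits
pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids)[0])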
