Wav2Vec-XLS-R-300M fine-tuned for Arabic
Wav2Vec2-XLS-R-300M fine-tuned for Arabic on Common Voice 11. When using the model, make sure your audio is sampled at 16 kHz.
Evaluation
The model can be used directly (without a language model). The following script evaluates it on the Common Voice 11 Arabic test split:
import re
import string

import torch
import torchaudio
from datasets import Audio, load_dataset
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
# Characters deleted from transcripts before scoring: Python's ASCII
# punctuation plus Arabic and typographic marks (Arabic comma, semicolon and
# question mark, tatweel, curly quotes, ellipsis, en dash, ...).
punctuation = '''`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…“–ـ''' + string.punctuation

# Matches any single Arabic short-vowel diacritic (tashkeel) or the
# tatweel/kashida elongation character. Code points are spelled out so the
# pattern does not depend on invisible combining characters in this file.
arabic_diacritics = re.compile(
    "["
    "\u0651"  # Shadda
    "\u064E"  # Fatha
    "\u064B"  # Tanwin Fath
    "\u064F"  # Damma
    "\u064C"  # Tanwin Damm
    "\u0650"  # Kasra
    "\u064D"  # Tanwin Kasr
    "\u0652"  # Sukun
    "\u0640"  # Tatwil/Kashida
    "]"
)

# Built once at import rather than on every call: deletes all punctuation and
# ASCII digits in a single pass.
_STRIP_TABLE = str.maketrans("", "", punctuation + "0123456789")

# Letter-variant normalization, applied after diacritics are removed.
_LETTER_TABLE = str.maketrans({
    "\u0625": "\u0627",  # alef with hamza below -> bare alef
    "\u0623": "\u0627",  # alef with hamza above -> bare alef
    "\u0622": "\u0627",  # alef with madda -> bare alef
    "\u0649": "\u064A",  # alef maqsura -> ya
    "\u0624": "\u0621",  # hamza on waw -> bare hamza
    "\u0626": "\u0621",  # hamza on ya -> bare hamza
    "\u0629": "\u0647",  # ta marbuta -> ha
    "\u06AF": "\u0643",  # Persian gaf -> kaf
})


def process_text(text):
    """Normalize an Arabic transcript before computing WER.

    Steps, in order:
      1. delete punctuation (see ``punctuation``) and ASCII digits;
      2. delete diacritics (tashkeel) and tatweel (see ``arabic_diacritics``);
      3. unify letter variants: alef forms -> bare alef, alef maqsura -> ya,
         hamza seats -> bare hamza, ta marbuta -> ha, Persian gaf -> kaf;
      4. collapse runs of whitespace to single spaces and trim the ends.

    Args:
        text: Raw transcript string.

    Returns:
        The normalized transcript string.
    """
    text = text.translate(_STRIP_TABLE)
    text = arabic_diacritics.sub("", text)
    text = text.translate(_LETTER_TABLE)
    return " ".join(text.split())
# Feature extractor for 1-channel audio at 16 kHz; do_normalize=True and
# return_attention_mask=True make it normalize each waveform and emit an
# attention mask alongside input_values.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Character-level CTC tokenizer shipped with the fine-tuned checkpoint;
# "|" is the word delimiter in its vocabulary.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
    "aitor-alvarez/wav2vec2-xls-r-300m-ar",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Wav2Vec2ForCTC.from_pretrained("aitor-alvarez/wav2vec2-xls-r-300m-ar").to(device)
def prepare_dataset(batch):
    """Turn one dataset row into model input values and CTC label ids.

    Intended for ``Dataset.map`` without ``batched=True``, so ``batch`` is a
    single example.

    Args:
        batch: A row with an ``"audio"`` dict (``"array"``,
            ``"sampling_rate"``) and a normalized ``"sentence"`` string.

    Returns:
        The same row with ``"input_values"`` (feature-extracted waveform),
        ``"input_length"`` (its length) and ``"labels"`` (tokenizer ids) added.
    """
    audio = batch["audio"]
    batch["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    # Tokenize the transcript directly. This is what the deprecated
    # ``processor.as_target_processor()`` context used to route to.
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    return batch
def remove_ar_special_characters(batch):
    """Normalize the ``"sentence"`` field of one dataset row in place.

    Runs the transcript through ``process_text`` and then lower-cases it.

    Args:
        batch: A dataset row containing a ``"sentence"`` string.

    Returns:
        The same row, with ``"sentence"`` replaced by its normalized form.
    """
    cleaned = process_text(batch["sentence"])
    batch["sentence"] = cleaned.lower()
    return batch
# Arabic test split of Common Voice 11, resampled to 16 kHz on access and
# with transcripts normalized.
speech_test = (
    load_dataset("mozilla-foundation/common_voice_11_0", "ar", split="test")
    .cast_column("audio", Audio(sampling_rate=16_000))
    .map(remove_ar_special_characters)
)
# Replace the raw Common Voice columns with the model-ready fields produced
# by prepare_dataset.
speech_test = speech_test.map(
    prepare_dataset,
    remove_columns=speech_test.column_names,
)
def get_predictions(batch):
    """Greedy-decode one preprocessed example and decode its reference.

    Args:
        batch: A single row produced by ``prepare_dataset`` with
            ``"input_values"`` and ``"labels"``.

    Returns:
        The row with ``"pred_txt"`` (model transcription) and ``"txt"``
        (reference transcription) added.
    """
    with torch.no_grad():
        inputs = processor(
            batch["input_values"],
            sampling_rate=processor.feature_extractor.sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        # Send tensors to wherever the model lives rather than relying on a
        # separately maintained global device name.
        logits = model(
            inputs.input_values.to(model.device),
            attention_mask=inputs.attention_mask.to(model.device),
        ).logits
        pred_ids = torch.argmax(logits, dim=-1)
    # Predictions are per-frame CTC ids, so the default token grouping applies.
    batch["pred_txt"] = processor.batch_decode(pred_ids)[0]
    # Labels are plain character ids, not CTC frames. Grouping would merge
    # genuine double letters in the reference and skew the WER.
    batch["txt"] = processor.decode(batch["labels"], group_tokens=False)
    return batch
def _word_edit_distance(hyp_words, ref_words):
    """Return the word-level Levenshtein distance (sub + del + ins)."""
    previous = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, 1):
        current = [i]
        for j, hyp_word in enumerate(hyp_words, 1):
            current.append(min(
                previous[j] + 1,                            # deletion
                current[j - 1] + 1,                         # insertion
                previous[j - 1] + (ref_word != hyp_word),   # substitution/match
            ))
        previous = current
    return previous[-1]


def _word_error_rate(predictions, references):
    """Return corpus-level WER: total word edits / total reference words.

    Edits are summed over all pairs before dividing, so longer utterances
    weigh proportionally.

    Raises:
        ValueError: if the sequences differ in length or the references
            contain no words (WER is undefined).
    """
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    edits = 0
    reference_words = 0
    for hypothesis, reference in zip(predictions, references):
        ref_tokens = reference.split()
        edits += _word_edit_distance(hypothesis.split(), ref_tokens)
        reference_words += len(ref_tokens)
    if reference_words == 0:
        raise ValueError("references contain no words; WER is undefined")
    return edits / reference_words


results = speech_test.map(get_predictions)
test_wer = _word_error_rate(results["pred_txt"], results["txt"])
# Reported as a percentage, matching the figure quoted on the model card.
print("Test WER: {:.2f} %".format(100 * test_wer))
Test WER: 20.71%
Downloads last month: 11