---
language: fa
datasets:
  - common_voice
tags:
  - audio
  - automatic-speech-recognition
  - speech
  - xlsr-fine-tuning-week
license: apache-2.0
widget:
  - label: Common Voice sample 687
    src: >-
      https://huggingface.co/m3hrdadfi/wav2vec2-large-xlsr-persian/resolve/main/sample687.flac
  - label: Common Voice sample 1671
    src: >-
      https://huggingface.co/m3hrdadfi/wav2vec2-large-xlsr-persian/resolve/main/sample1671.flac
model-index:
  - name: XLSR Wav2Vec2 Persian (Farsi) by Mehrdad Farahani
    results:
      - task:
          name: Speech Recognition
          type: automatic-speech-recognition
        dataset:
          name: Common Voice fa
          type: common_voice
          args: fa
        metrics:
          - name: Test WER
            type: wer
            value: 32.09
          - name: Test CER
            type: cer
            value: 8.23
---

# Wav2Vec2-Large-XLSR-53-Persian

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Persian (Farsi) using [Common Voice](https://huggingface.co/datasets/common_voice). When using this model, make sure that your speech input is sampled at 16 kHz.
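If your audio is at a different sampling rate, resample it before inference. A minimal sketch using torchaudio (`speech.wav` is a hypothetical placeholder path):

```python
import torchaudio

# load a file at its native sampling rate ("speech.wav" is a hypothetical path)
speech_array, sampling_rate = torchaudio.load("speech.wav")

# resample to the 16 kHz expected by the model, if needed
if sampling_rate != 16_000:
    resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16_000)
    speech_array = resampler(speech_array)
```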

## Usage

The model can be used directly (without a language model) as follows:

```bash
!pip install git+https://github.com/huggingface/datasets.git
!pip install git+https://github.com/huggingface/transformers.git
!pip install torchaudio
!pip install librosa
!pip install jiwer
!pip install hazm
```

```python
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

import librosa

import pandas as pd
import numpy as np

import hazm

import random
import os
import string
import six
import re

import IPython.display as ipd

# Load the first 2% of the Common Voice Persian test split
dataset = load_dataset("common_voice", "fa", split="test[:2%]")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Wav2Vec2Processor.from_pretrained("m3hrdadfi/wav2vec2-large-xlsr-persian")
model = Wav2Vec2ForCTC.from_pretrained("m3hrdadfi/wav2vec2-large-xlsr-persian").to(device)


# Preprocessing the datasets.
# Normalizing the texts
_normalizer = hazm.Normalizer()
def multiple_replace(mapping, text):
    pattern = "|".join(map(re.escape, mapping.keys()))
    return re.sub(pattern, lambda m: mapping[m.group()], str(text))
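
# Illustrative example (not part of the original script):
# multiple_replace({"ك": "ک", "!": " "}, "كتاب!") -> "کتاب "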

def convert_weirdos(input_str):
    # characters: unify Arabic and presentation-form variants with standard Persian letters
    mapping = {
        'ك': 'ک',
        'دِ': 'د',
        'بِ': 'ب',
        'زِ': 'ز',
        'ذِ': 'ذ',
        'شِ': 'ش',
        'سِ': 'س',
        'ى': 'ی',
        'ي': 'ی',
        'أ': 'ا',
        'ؤ': 'و',
        "ے": "ی",
        "ۀ": "ه",
        "ﭘ": "پ",
        "ﮐ": "ک",
        "ﯽ": "ی",
        "ﺎ": "ا",
        "ﺑ": "ب",
        "ﺘ": "ت",
        "ﺧ": "خ",
        "ﺩ": "د",
        "ﺱ": "س",
        "ﻀ": "ض",
        "ﻌ": "ع",
        "ﻟ": "ل",
        "ﻡ": "م",
        "ﻢ": "م",
        "ﻪ": "ه",
        "ﻮ": "و",
        "ئ": "ی",
        'ﺍ': "ا",
        'ة': "ه",
        'ﯾ': "ی",
        'ﯿ': "ی",
        'ﺒ': "ب",
        'ﺖ': "ت",
        'ﺪ': "د",
        'ﺮ': "ر",
        'ﺴ': "س",
        'ﺷ': "ش",
        'ﺸ': "ش",
        'ﻋ': "ع",
        'ﻤ': "م",
        'ﻥ': "ن",
        'ﻧ': "ن",
        'ﻭ': "و",
        'ﺭ': "ر",
        "ﮔ": "گ",
    }

    # punctuation and symbols: replace with spaces
    mapping.update(**{
        "#": " ",
        "!": " ",
        "؟": " ",
        "?": " ",
        "«": " ",
        "»": " ",
        "ء": " ",
        "،": " ",
        "(": " ",
        ")": " ",
        "؛": " ",
        "'ٔ": " ",
        "٬": " ",
        'ٔ': " ",
        ",": " ",
        ".": " ",
        "-": " ",
        ";": " ",
        ":": " ",
        '"': " ",
        "“": " ",
        "%": " ",
        "‘": " ",
        "”": " ",
        "�": " ",
        "–": " ",
        "…": " ",
        "_": " ",
    })

    return multiple_replace(mapping, input_str)
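
# Illustrative example (not part of the original script):
# convert_weirdos("كيف؟") -> "کیف " (Arabic letters unified, punctuation removed)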


PERSIAN_ALPHA = "\u0621-\u0628\u062A-\u063A\u0641-\u0642\u0644-\u0648\u064E-\u0651\u0655\u067E\u0686\u0698\u06A9\u06AF\u06BE\u06CC"  # noqa: E501
PERSIAN_DIGIT = "\u06F0-\u06F9"

COMMON_ARABIC_ALPHA = "\u0629\u0643\u0649-\u064B\u064D\u06D5"
COMMON_ARABIC_DIGIT = "\u0660-\u0669"

ZWNJ = "\u200c"

ENGLISH = "a-z0-9\&"
PERSIAN = PERSIAN_ALPHA + PERSIAN_DIGIT + COMMON_ARABIC_ALPHA + COMMON_ARABIC_DIGIT + ZWNJ


def normalizer(text, min_ratio=1.1):
    text = text.lower()
    text = _normalizer.normalize(text)
    text = text.replace("\u200c", " ")
    text = text.replace("\u200d", " ")
    text = text.replace("\u200e", " ")
    text = text.replace("\u200f", " ")
    text = text.replace("\ufeff", " ")
    text = convert_weirdos(text)

    words = [word.replace("آ", "ا") if "آ" in word and not word.startswith("آ") else word for word in text.split()]
    text = " ".join(words)

    if not text or not len(text) > 2:
        return None

    en_text = re.sub(r"[^" + ENGLISH + "+]", " ", six.ensure_str(text))
    en_text = re.sub(r"\s+", " ", en_text)
    if len(en_text) > 1:
        return None

    return text
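
# Illustrative behaviour (not part of the original script):
#   normalizer("سلام، دنیا!") -> "سلام دنیا" (variants unified, punctuation stripped)
#   normalizer("hello سلام")  -> None (sentences containing Latin text are rejected)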


chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
def remove_special_characters(batch):
    text = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
    text = normalizer(text)
    batch["sentence"] = text
    return batch
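
# Note: normalizer() can return None for degenerate sentences; such rows
# should be filtered out before computing metrics.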

# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    speech_array = speech_array.squeeze().numpy()
    speech_array = librosa.resample(np.asarray(speech_array), orig_sr=sampling_rate, target_sr=16_000)

    batch["speech"] = speech_array
    return batch
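
# Note: torchaudio.load returns a (channels, time) tensor; squeeze() assumes
# mono audio, which holds for Common Voice clips.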

def predict(batch):
    features = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)

    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    # greedy CTC decoding: take the most probable token at each time step
    pred_ids = torch.argmax(logits, dim=-1)

    batch["predicted"] = processor.batch_decode(pred_ids)[0]
    return batch

dataset = dataset.map(remove_special_characters)
dataset = dataset.map(speech_file_to_array_fn, remove_columns=list(set(dataset.column_names) - set(['sentence', 'path'])))
result = dataset.map(predict)
```
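
For a single local file, the same steps collapse into a few lines. A minimal sketch under the same 16 kHz assumption (`audio.wav` is a hypothetical path):

```python
# transcribe one local file ("audio.wav" is a hypothetical path)
speech_array, sampling_rate = torchaudio.load("audio.wav")
speech = librosa.resample(speech_array.squeeze().numpy(), orig_sr=sampling_rate, target_sr=16_000)

features = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(features.input_values.to(device),
                   attention_mask=features.attention_mask.to(device)).logits

# greedy CTC decoding
print(processor.batch_decode(torch.argmax(logits, dim=-1))[0])
```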

## Prediction

```python
max_items = np.random.randint(0, len(result), 20).tolist()
for i in max_items:
    reference, predicted = result["sentence"][i], result["predicted"][i]
    print("reference:", reference)
    print("predicted:", predicted)
    print('---')
```

Output:

```text
reference: اطلاعات مسری است
predicted: اطلاعات مسری است
---
reference: نه منظورم اینه که وقتی که ساکته چه کاریه خودمونه بندازیم زحمت
predicted: نه منظورم اینه که وقتی که ساکت چی کاریه خودمونو بندازیم زحمت
---
reference: من آب پرتقال می خورم لطفا
predicted: من آپ ارتغال می خورم لطفا
---
reference: وقت آن رسیده آنها را که قدم پیش میگذارند بزرگ بداریم
predicted: وقت آ رسیده آنها را که قدم پیش میگذارند بزرگ بداریم
---
reference: سیم باتری دارید
predicted: سیم باتری دارید
---
reference: این بهتره تا اینکه به بهونه درس و مشق هر روز بره خونه شون
predicted: این بهتره تا اینکه به بهمونه درسومش خرروز بره خونه اشون
---
reference: ژاکت تنگ است
predicted: ژاکت تنگ است
---
reference: آت و اشغال های خیابان
predicted: آت و اشغال های خیابان
---
reference: من به این روند اعتراض دارم
predicted: من به این لوند تراج دارم
---
reference: کرایه این مکان چند است
predicted: کرایه این مکان چند است
---
reference: ولی این فرصت این سهم جوانی اعطا نشده است
predicted: ولی این فرصت این سحم جوانی اتان نشده است
---
reference: متوجه فاجعهای محیطی میشوم
predicted: متوجه فاجایهای محیطی میشوم
---
reference: ترافیک شدیدیم بود و دیدن نور ماشینا و چراغا و لامپهای مراکز تجاری حس خوبی بهم میدادن
predicted: ترافیک شدید ی هم بودا دیدن نور ماشینا و چراغ لامپهای مراکز تجاری حس خولی بهم میدادن
---
reference: این مورد عمل ها مربوط به تخصص شما می شود
predicted: این مورد عملها مربوط به تخصص شما میشود
---
reference: انرژی خیلی کمی دارم
predicted: انرژی خیلی کمی دارم
---
reference: زیادی خوبی کردنم تهش داستانه
predicted: زیادی خوبی کردنم ترش داستانه
---
reference: بردهای که پادشاه شود
predicted: برده ای که پاده شاه شود
---
reference: یونسکو
predicted: یونسکو
---
reference: شما اخراج هستید
predicted: شما اخراج هستید
---
reference: من سفر کردن را دوست دارم
predicted: من سفر کردم را دوست دارم
```

## Evaluation

```bash
!mkdir cer
!wget -O cer/cer.py https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/raw/main/cer.py
```

```python
wer = load_metric("wer")
cer = load_metric("./cer")

print("WER: {:.2f}".format(100 * wer.compute(predictions=result["predicted"], references=result["sentence"])))
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["sentence"])))
```
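
Note that `load_metric` has been removed from recent versions of datasets; on a newer stack, the same metrics are available through the evaluate package (an alternative sketch, not part of the original card):

```python
# alternative metric loading for newer environments (assumes `pip install evaluate jiwer`)
import evaluate

wer = evaluate.load("wer")
cer = evaluate.load("cer")

print("WER: {:.2f}".format(100 * wer.compute(predictions=result["predicted"], references=result["sentence"])))
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["sentence"])))
```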

**Test Result:**

- WER: 32.09%
- CER: 8.23%

## Training

The Common Voice `train` and `validation` splits were used for training. The script used for training can be found here.