---
language:
  - fr
license: apache-2.0
tags:
  - automatic-speech-recognition
  - mozilla-foundation/common_voice_9_0
  - generated_from_trainer
  - hf-asr-leaderboard
  - robust-speech-event
datasets:
  - common_voice
  - mozilla-foundation/common_voice_9_0
model-index:
  - name: Fine-tuned Wav2Vec2 XLS-R 1B model for ASR in French
    results:
      - task:
          name: Automatic Speech Recognition
          type: automatic-speech-recognition
        dataset:
          name: Common Voice 9
          type: mozilla-foundation/common_voice_9_0
          args: fr
        metrics:
          - name: Test WER
            type: wer
            value: 12.72
          - name: Test CER
            type: cer
            value: 3.78
          - name: Test WER (+LM)
            type: wer
            value: 10.6
          - name: Test CER (+LM)
            type: cer
            value: 3.41
      - task:
          name: Automatic Speech Recognition
          type: automatic-speech-recognition
        dataset:
          name: Robust Speech Event - Dev Data
          type: speech-recognition-community-v2/dev_data
          args: fr
        metrics:
          - name: Test WER
            type: wer
            value: 24.28
          - name: Test CER
            type: cer
            value: 11.46
          - name: Test WER (+LM)
            type: wer
            value: 20.85
          - name: Test CER (+LM)
            type: cer
            value: 11.09
---

# Fine-tuned Wav2Vec2 XLS-R 1B model for ASR in French

This model is a fine-tuned version of [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) on the French subset of the mozilla-foundation/common_voice_9_0 dataset.
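
For a quick test before the detailed snippets below, the checkpoint can also be run through the `transformers` pipeline API. This is a minimal sketch that is not part of the original card; it assumes a CUDA GPU and that ffmpeg is available to decode the audio file:

```python
from transformers import pipeline

# wrap the checkpoint in an ASR pipeline (device=0 selects the first GPU, use device=-1 for CPU)
asr = pipeline(
    "automatic-speech-recognition",
    model="bhuang/wav2vec2-xls-r-1b-cv9-fr",
    device=0,
)

# the pipeline decodes and resamples the file internally before running the model
print(asr("example.wav")["text"])
```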

## Usage

1. To use on a local audio file without the language model

```python
import torch
import torchaudio

from transformers import AutoModelForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("bhuang/wav2vec2-xls-r-1b-cv9-fr")
model = AutoModelForCTC.from_pretrained("bhuang/wav2vec2-xls-r-1b-cv9-fr").cuda()

# path to your audio file
wav_path = "example.wav"
waveform, sample_rate = torchaudio.load(wav_path)
waveform = waveform.squeeze(dim=0)  # drop the channel dimension (assumes a mono file)

# resample
if sample_rate != 16_000:
    resampler = torchaudio.transforms.Resample(sample_rate, 16_000)
    waveform = resampler(waveform)

# normalize
input_dict = processor(waveform, sampling_rate=16_000, return_tensors="pt")

with torch.inference_mode():
    logits = model(input_dict.input_values.to("cuda")).logits

# decode
predicted_ids = torch.argmax(logits, dim=-1)
predicted_sentence = processor.batch_decode(predicted_ids)[0]
```
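
The `argmax` over the logits performs plain greedy CTC decoding; as the metrics above show, decoding with the n-gram language model (next snippet) lowers the Common Voice 9 test WER from 12.72 to 10.6.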

2. To use on a local audio file with the language model

```python
import torch
import torchaudio

from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM

processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained("bhuang/wav2vec2-xls-r-1b-cv9-fr")
model = AutoModelForCTC.from_pretrained("bhuang/wav2vec2-xls-r-1b-cv9-fr").cuda()

model_sampling_rate = processor_with_lm.feature_extractor.sampling_rate

# path to your audio file
wav_path = "example.wav"
waveform, sample_rate = torchaudio.load(wav_path)
waveform = waveform.squeeze(dim=0)  # drop the channel dimension (assumes a mono file)

# resample
if sample_rate != model_sampling_rate:
    resampler = torchaudio.transforms.Resample(sample_rate, model_sampling_rate)
    waveform = resampler(waveform)

# normalize
input_dict = processor_with_lm(waveform, sampling_rate=model_sampling_rate, return_tensors="pt")

with torch.inference_mode():
    logits = model(input_dict.input_values.to("cuda")).logits

predicted_sentence = processor_with_lm.batch_decode(logits.cpu().numpy()).text[0]
```
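
The language-model decoding parameters can also be overridden when calling `batch_decode`. The values below are purely illustrative, not the settings used for the reported results:

```python
# hypothetical tuning of the beam-search decoder behind batch_decode
decoded = processor_with_lm.batch_decode(
    logits.cpu().numpy(),
    beam_width=100,  # wider beams are slower but can be more accurate
    alpha=0.5,       # language-model weight
    beta=1.5,        # word-insertion bonus
)
predicted_sentence = decoded.text[0]
```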

## Evaluation

1. To evaluate on mozilla-foundation/common_voice_9_0

```bash
python eval.py \
  --model_id "bhuang/wav2vec2-xls-r-1b-cv9-fr" \
  --dataset "mozilla-foundation/common_voice_9_0" \
  --config "fr" \
  --split "test" \
  --log_outputs \
  --outdir "outputs/results_mozilla-foundatio_common_voice_9_0_with_lm"

2. To evaluate on speech-recognition-community-v2/dev_data

```bash
python eval.py \
  --model_id "bhuang/wav2vec2-xls-r-1b-cv9-fr" \
  --dataset "speech-recognition-community-v2/dev_data" \
  --config "fr" \
  --split "validation" \
  --chunk_length_s 5.0 \
  --stride_length_s 1.0 \
  --log_outputs \
  --outdir "outputs/results_speech-recognition-community-v2_dev_data_with_lm"