metadata
language:
  - en
tags:
  - automatic-speech-recognition
datasets:
  - LIUM/tedlium
license: cc-by-4.0
metrics:
  - name: Dev WER
    type: wer
    value: 9
  - name: Test WER
    type: wer
    value: 6.4

Wav2Vec2-2-Bart-Large-Tedlium

This model is a sequence-to-sequence (seq2seq) model fine-tuned on the TEDLIUM corpus (release 3).

It combines a speech encoder with a text decoder to perform automatic speech recognition. The encoder weights are initialised from the Wav2Vec2 LV-60k checkpoint from @facebook, and the decoder weights from the BART large checkpoint from @facebook.
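As a rough illustration of how the speech encoder shortens 16 kHz audio before the decoder sees it, the sketch below applies the standard 1-D convolution output-length formula, `(length - kernel) // stride + 1`, layer by layer. It assumes the default `Wav2Vec2Config` values `conv_kernel=(10, 3, 3, 3, 3, 2, 2)` and `conv_stride=(5, 2, 2, 2, 2, 2, 2)`, which we take to match the LV-60k checkpoint; `encoder_frames` is a hypothetical helper, not part of `transformers`.

```python
# Assumed: default Wav2Vec2Config convolutional feature-encoder settings
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def encoder_frames(num_samples: int) -> int:
    """Number of frames produced by the conv feature encoder (hypothetical helper)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # output length of an unpadded 1-D convolution
        length = (length - kernel) // stride + 1
    return length


print(encoder_frames(16_000))  # 1 second of 16 kHz audio
print(encoder_frames(32_000))  # 2 seconds
```

Under these assumptions one second of audio becomes 49 frames and two seconds become 99, i.e. roughly one frame per 20 ms (a total stride of 5 × 2⁶ = 320 samples), assuming no further downsampling layers sit between encoder and decoder.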

When using the model, make sure that your speech input is sampled at 16 kHz.
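If your recordings use a different sampling rate, they need to be resampled first. The snippet below is a minimal sketch using linear interpolation in NumPy; `resample_linear` is a hypothetical helper and applies no anti-aliasing filter, so for real use a proper resampler (for example casting a `datasets` audio column with `Audio(sampling_rate=16_000)`) is the safer choice.

```python
import numpy as np


def resample_linear(audio: np.ndarray, orig_sr: int, target_sr: int = 16_000) -> np.ndarray:
    """Naive linear-interpolation resampler (illustrative only, no low-pass filter)."""
    if orig_sr == target_sr:
        return audio
    n_out = int(round(len(audio) * target_sr / orig_sr))
    t_in = np.arange(len(audio)) / orig_sr   # timestamps of input samples (s)
    t_out = np.arange(n_out) / target_sr     # timestamps of output samples (s)
    return np.interp(t_out, t_in, audio).astype(audio.dtype)


# One second of a 44.1 kHz signal becomes 16,000 samples
clip_44k = np.ones(44_100, dtype=np.float32)
clip_16k = resample_linear(clip_44k, orig_sr=44_100)
print(clip_16k.shape, clip_16k.dtype)
```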

The model achieves a word error rate (WER) of 9.0% on the dev set and 6.4% on the test set. Training logs document the training and evaluation progress over 50k steps of fine-tuning.
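WER is the word-level edit distance (substitutions + deletions + insertions) between the transcription and the reference, divided by the number of reference words. A minimal pure-Python sketch with hypothetical helper names is shown below (the evaluation script in this card uses `jiwer` instead). Over a corpus, errors and reference words are summed across all utterances before dividing, rather than averaging per-utterance rates.

```python
from typing import List, Tuple


def word_errors(reference: str, hypothesis: str) -> Tuple[int, int]:
    """Return (word-level edit distance, number of reference words)."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev is the previous row of the DP table: distances from ref[:i-1] to every hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution (or match)
            )
        prev = curr
    return prev[-1], len(ref)


def corpus_wer(references: List[str], hypotheses: List[str]) -> float:
    """Total errors divided by total reference words."""
    errors = total = 0
    for ref, hyp in zip(references, hypotheses):
        e, n = word_errors(ref, hyp)
        errors += e
        total += n
    return errors / total


refs = ["a b c d", "the cat sat on the mat"]
hyps = ["a x c d e", "the cat sat on mat"]
print(corpus_wer(refs, hyps))
```

Here the first pair has one substitution and one insertion and the second has one deletion, giving 3 errors over 10 reference words, a WER of 0.3.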

Usage

To transcribe audio files, the model can be used as a standalone acoustic model as follows:

 from transformers import AutoProcessor, SpeechEncoderDecoderModel
 from datasets import load_dataset
 import torch
 
 # load model and processor
 processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
 model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
     
 # load dummy dataset
 ds = load_dataset("sanchit-gandhi/tedlium_dummy", split="validation")
 
 # process audio inputs
 input_values = processor(ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt", padding="longest").input_values  # Batch size 1
 
 # run inference (greedy search)
 generated = model.generate(input_values)
 
 # decode
 decoded = processor.batch_decode(generated, skip_special_tokens=True)
 print("Target: ", ds["text"][0])
 print("Transcription: ", decoded[0])
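The greedy search used above can be pictured as the simplified loop below: at each step the decoder, conditioned on the encoder output and the tokens emitted so far, scores every vocabulary entry, and the single highest-scoring token is appended until an end-of-sequence token or a length limit is reached. `greedy_decode` and the toy `toy_step` "decoder" are hypothetical stand-ins; the real `transformers` implementation additionally handles batching, key/value caching and special-token logic.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    step: Callable[[List[int]], Sequence[float]],
    start_id: int,
    eos_id: int,
    max_length: int,
) -> List[int]:
    """Append the argmax token at every step until EOS or max_length."""
    tokens = [start_id]
    while len(tokens) < max_length:
        scores = step(tokens)  # one score per vocabulary entry
        next_id = max(range(len(scores)), key=lambda k: scores[k])  # argmax
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens


# Toy decoder over a 4-token vocabulary that always prefers the next token of a fixed script
SCRIPT = [2, 3, 1]  # token 1 plays the role of end-of-sequence


def toy_step(tokens: List[int]) -> List[float]:
    target = SCRIPT[len(tokens) - 1]
    return [1.0 if k == target else 0.0 for k in range(4)]


print(greedy_decode(toy_step, start_id=0, eos_id=1, max_length=10))
```

This prints `[0, 2, 3, 1]`; with `max_length=3` the loop stops early at `[0, 2, 3]`.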

Evaluation

This code snippet shows how to evaluate Wav2Vec2-2-Bart-Large-Tedlium on the TEDLIUM test data.

from datasets import load_dataset
from transformers import AutoProcessor, SpeechEncoderDecoderModel
import torch
from jiwer import wer

tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")

def filter_ds(text):
    return text != "ignore_time_segment_in_scoring"

# remove samples ignored from scoring
tedlium_eval = tedlium_eval.filter(filter_ds, input_columns=["text"])

model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium").to("cuda")
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")

# beam-search settings used for evaluation
gen_kwargs = {
    "max_length": 200,
    "num_beams": 5,
    "length_penalty": 1.2,
}

def map_to_pred(batch):
    # non-batched map: batch is a single example, so batch["audio"] is one dict
    audio = batch["audio"]
    input_values = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        generated = model.generate(input_values.to("cuda"), **gen_kwargs)
    decoded = processor.batch_decode(generated, skip_special_tokens=True)
    batch["transcription"] = decoded[0]
    return batch

# transcribe one example at a time and drop the raw audio afterwards
result = tedlium_eval.map(map_to_pred, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
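The evaluation uses beam search (`num_beams=5`) with `length_penalty=1.2`. To our understanding, Hugging Face beam search scores each finished hypothesis as its summed token log-probability divided by `length ** length_penalty`, so positive values counteract the tendency of raw log-probabilities to favour short outputs. The worked example below uses made-up log-probabilities and a hypothetical `beam_score` helper.

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # Normalise the summed log-probability by length ** length_penalty
    return sum_logprobs / (length ** length_penalty)


# Two hypothetical finished hypotheses: (summed log-prob, length in tokens)
hypotheses = {"short": (-5.0, 10), "long": (-5.8, 12)}


def best(length_penalty: float) -> str:
    return max(hypotheses, key=lambda name: beam_score(*hypotheses[name], length_penalty))


print(best(0.0))  # no length normalisation
print(best(1.2))  # the setting used in gen_kwargs
```

Without normalisation the shorter hypothesis wins (-5.0 > -5.8); with `length_penalty=1.2` the scores become about -0.315 and -0.294, so the longer hypothesis is preferred.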