metadata
license: apache-2.0
base_model: oza75/whisper-bambara-asr-002
tags:
  - generated_from_trainer
metrics:
  - wer
model-index:
  - name: whisper-bambara-asr-002
    results: []
datasets:
  - oza75/bambara-asr
language:
  - bm
library_name: transformers

whisper-bambara-asr-002

This model is a fine-tuned version of oza75/whisper-bambara-asr-002 on a Bambara ASR dataset (oza75/bambara-asr). It achieves the following results on the evaluation set (the evaluation set was kept very small to speed up training):

  • Loss: 1.3643
  • WER (%): 53.2541
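WER is reported as a percentage: the word-level edit distance between reference and hypothesis (substitutions S, deletions D, insertions I) divided by the number of reference words N, times 100. The minimal sketch below computes it with a dynamic-programming Levenshtein distance over whitespace-separated words. The function name `word_error_rate` is hypothetical, and the text normalization used in the actual evaluation is not documented here, so this sketch applies none.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Return WER as a percentage: 100 * (S + D + I) / N, with N = reference word count.

    Assumption: words are separated by whitespace and no normalization is applied.
    """
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between the reference words processed so far
    # and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(
                prev[j] + 1,         # deletion of ref_word
                curr[j - 1] + 1,     # insertion of hyp_word
                prev[j - 1] + cost,  # substitution (or match when cost == 0)
            )
        prev = curr
    return 100.0 * prev[-1] / len(ref)


if __name__ == "__main__":
    # One substitution (sat -> sit) and one deletion (down) over 4 reference words.
    print(word_error_rate("the cat sat down", "the cat sit"))
```

Because insertions count as errors, WER can exceed 100%: a one-word reference against two wrong hypothesis words gives 200%.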

Usage

To use this model, we first need to define a custom tokenizer class, because the default Whisper tokenizer does not support Bambara (it has no <|bm|> language token).

IMPORTANT: The following code also modifies the Whisper tokenizer's module-level TO_LANGUAGE_CODE mapping in place (adding "bambara": "bm"). This is not the ideal approach, but it works. Without this change, generation fails.

from typing import List

from tokenizers import AddedToken
from transformers import WhisperTokenizer, WhisperProcessor
import transformers.models.whisper.tokenization_whisper as whisper_tokenization
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, TASK_IDS

CUSTOM_TO_LANGUAGE_CODE = {**TO_LANGUAGE_CODE, "bambara": "bm"}

# IMPORTANT: add Bambara to the Whisper tokenizer's TO_LANGUAGE_CODE mapping (modified in place). Not ideal, but it works.
whisper_tokenization.TO_LANGUAGE_CODE.update(CUSTOM_TO_LANGUAGE_CODE)


class BambaraWhisperTokenizer(WhisperTokenizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_tokens(AddedToken(content="<|bm|>", lstrip=False, rstrip=False, normalized=False, special=True))

    @property
    def prefix_tokens(self) -> List[int]:
        bos_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
        translate_token_id = self.convert_tokens_to_ids("<|translate|>")
        transcribe_token_id = self.convert_tokens_to_ids("<|transcribe|>")
        notimestamps_token_id = self.convert_tokens_to_ids("<|notimestamps|>")

        if self.language is not None:
            self.language = self.language.lower()
            if self.language in CUSTOM_TO_LANGUAGE_CODE:
                language_id = CUSTOM_TO_LANGUAGE_CODE[self.language]
            elif self.language in CUSTOM_TO_LANGUAGE_CODE.values():
                language_id = self.language
            else:
                is_language_code = len(self.language) == 2
                raise ValueError(
                    f"Unsupported language: {self.language}. Language should be one of:"
                    f" {list(CUSTOM_TO_LANGUAGE_CODE.values()) if is_language_code else list(CUSTOM_TO_LANGUAGE_CODE.keys())}."
                )

        if self.task is not None:
            if self.task not in TASK_IDS:
                raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")

        bos_sequence = [bos_token_id]
        if self.language is not None:
            bos_sequence.append(self.convert_tokens_to_ids(f"<|{language_id}|>"))
        if self.task is not None:
            bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
        if not self.predict_timestamps:
            bos_sequence.append(notimestamps_token_id)
        return bos_sequence
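To see what the `prefix_tokens` override produces, the sketch below mirrors its logic but returns token strings instead of IDs, so it runs without downloading anything. `prefix_token_strings` and the three-entry `LANGUAGE_CODES` mapping are hypothetical stand-ins: the real class uses the full `TO_LANGUAGE_CODE` mapping extended with `"bambara": "bm"` and converts each string with `convert_tokens_to_ids`.

```python
from typing import Dict, List, Optional

# Hypothetical subset of the language mapping used by the tokenizer above.
LANGUAGE_CODES: Dict[str, str] = {"english": "en", "french": "fr", "bambara": "bm"}
TASKS = ("translate", "transcribe")


def prefix_token_strings(
    language: Optional[str], task: Optional[str], predict_timestamps: bool = False
) -> List[str]:
    """Build the decoder prompt as token strings, following the prefix_tokens logic."""
    tokens = ["<|startoftranscript|>"]
    if language is not None:
        lang = language.lower()
        if lang in LANGUAGE_CODES:            # full name, e.g. "bambara"
            code = LANGUAGE_CODES[lang]
        elif lang in LANGUAGE_CODES.values():  # already a code, e.g. "bm"
            code = lang
        else:
            raise ValueError(f"Unsupported language: {language}")
        tokens.append(f"<|{code}|>")
    if task is not None:
        if task not in TASKS:
            raise ValueError(f"Unsupported task: {task}. Task should be in: {TASKS}")
        tokens.append(f"<|{task}|>")
    if not predict_timestamps:
        tokens.append("<|notimestamps|>")
    return tokens


if __name__ == "__main__":
    print("".join(prefix_token_strings("bambara", "transcribe")))
```

With language "bambara" and task "transcribe", the prompt is `<|startoftranscript|><|bm|><|transcribe|><|notimestamps|>`. The subclass's `__init__` adds `<|bm|>` as a special token so that `convert_tokens_to_ids` can resolve it.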

Then, we can define the pipeline:

import torch
from transformers import pipeline

# Determine the appropriate device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the model checkpoint and language
model_checkpoint = "oza75/whisper-bambara-asr-002"
language = "bambara"

# Load the custom tokenizer designed for Bambara and the ASR model
tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language)
pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device)

def transcribe(audio):
    """
    Transcribes the provided audio file into text using the configured ASR pipeline.

    Args:
        audio: The path to the audio file to transcribe.

    Returns:
        A string representing the transcribed text.
    """
    # Use the pipeline to perform transcription
    text = pipe(audio)["text"]
    return text


# Replace the placeholder below with the path to your own audio file
print(transcribe("path/to/audio.wav"))
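The pipeline also accepts raw audio instead of a file path. As a minimal sketch, assuming a 16-bit PCM WAV file, the standard-library helper below (hypothetical name `load_wav_mono_16bit`) returns float samples in [-1, 1] plus the sample rate, down-mixing multi-channel audio to mono. The commented lines show how the result could be passed to `pipe` as a `{"raw": ..., "sampling_rate": ...}` dict; they assume NumPy is installed and that `pipe` is the pipeline defined above. Whisper's feature extractor expects 16 kHz audio; when the rate differs, the Transformers ASR pipeline resamples, which relies on torchaudio being installed.

```python
import array
import sys
import wave
from typing import List, Tuple


def load_wav_mono_16bit(path: str) -> Tuple[List[float], int]:
    """Read a 16-bit PCM WAV file and return (mono samples in [-1, 1], sample_rate)."""
    with wave.open(path, "rb") as wav:
        if wav.getsampwidth() != 2:
            raise ValueError("this sketch only supports 16-bit PCM WAV files")
        channels = wav.getnchannels()
        rate = wav.getframerate()
        frames = wav.readframes(wav.getnframes())

    pcm = array.array("h", frames)  # interleaved signed 16-bit samples
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian

    # Average the channels of each frame, then scale int16 to [-1, 1].
    samples = [
        sum(pcm[i:i + channels]) / channels / 32768.0
        for i in range(0, len(pcm), channels)
    ]
    return samples, rate


# Hypothetical usage with the pipeline defined above (requires NumPy):
# import numpy as np
# samples, rate = load_wav_mono_16bit("path/to/audio.wav")  # placeholder path
# text = pipe({"raw": np.asarray(samples, dtype=np.float32), "sampling_rate": rate})["text"]
```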

Intended uses & limitations

This checkpoint is intended for research purposes only.

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 5e-06
  • train_batch_size: 16
  • eval_batch_size: 16
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 3
  • gradient_accumulation_steps: 4
  • total_train_batch_size: 192
  • total_eval_batch_size: 48
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 8
  • mixed_precision_training: Native AMP
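The two total batch sizes follow from the other values: each optimizer step sees per-device batch × devices × gradient-accumulation steps = 16 × 3 × 4 = 192 training examples, while evaluation uses no accumulation, giving 16 × 3 = 48. The sketch below reproduces that arithmetic and adds a linear learning-rate schedule in the form used by Transformers' `get_linear_schedule_with_warmup`. No warmup is listed above, so `warmup_steps=0` is an assumption, and `total_steps` is a parameter because the card does not state the total number of optimizer steps. Both function names are hypothetical.

```python
from typing import Tuple


def effective_batch_sizes(
    per_device_train_batch: int,
    per_device_eval_batch: int,
    num_devices: int,
    grad_accum_steps: int,
) -> Tuple[int, int]:
    """Return (total_train_batch_size, total_eval_batch_size)."""
    # One optimizer step accumulates gradients over grad_accum_steps passes on every device.
    total_train = per_device_train_batch * num_devices * grad_accum_steps
    # Evaluation has no gradient accumulation, so only the device count multiplies the batch.
    total_eval = per_device_eval_batch * num_devices
    return total_train, total_eval


def linear_lr(step: int, total_steps: int, base_lr: float = 5e-6, warmup_steps: int = 0) -> float:
    """Learning rate at `step`: linear warmup, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


if __name__ == "__main__":
    print(effective_batch_sizes(16, 16, 3, 4))  # (192, 48)
```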

Training results

| Training Loss | Epoch  | Step | Validation Loss | WER (%) |
|---------------|--------|------|-----------------|---------|
| 1.2108        | 0.7752 | 100  | 1.1191          | 77.3885 |
| 0.848         | 1.5504 | 200  | 0.8848          | 64.0640 |
| 0.694         | 2.3256 | 300  | 0.8022          | 62.0272 |
| 0.6062        | 3.1008 | 400  | 0.7607          | 67.0102 |
| 0.5617        | 3.8760 | 500  | 0.7314          | 59.4083 |
| 0.4565        | 4.6512 | 600  | 0.7334          | 69.0713 |
| 0.3455        | 5.4264 | 700  | 0.7656          | 59.8812 |
| 0.2621        | 6.2016 | 800  | 0.8062          | 68.1256 |
| 0.2672        | 6.9767 | 900  | 0.8130          | 56.9593 |
| 0.1916        | 7.7519 | 1000 | 0.8706          | 58.0868 |
| 0.1302        | 8.5271 | 1100 | 0.9390          | 58.2565 |
| 0.0785        | 9.3023 | 1200 | 0.9932          | 55.5286 |
| 0.0785        | 0.7752 | 1300 | 1.0391          | 55.8802 |
| 0.0495        | 1.5504 | 1400 | 1.0820          | 58.4627 |
| 0.032         | 2.3256 | 1500 | 1.1270          | 55.2498 |
| 0.026         | 3.1008 | 1600 | 1.1660          | 57.4321 |
| 0.0241        | 3.8760 | 1700 | 1.1738          | 53.5766 |
| 0.019         | 4.6512 | 1800 | 1.1943          | 53.6736 |
| 0.0149        | 5.4264 | 1900 | 1.2236          | 52.3642 |
| 0.0116        | 6.2016 | 2000 | 1.2549          | 58.8143 |
| 0.014         | 6.9767 | 2100 | 1.25            | 52.1581 |
| 0.0121        | 7.7519 | 2200 | 1.2627          | 51.3094 |
| 0.0106        | 8.5271 | 2300 | 1.2705          | 52.6673 |
| 0.0097        | 9.3023 | 2400 | 1.2744          | 53.0674 |
| 0.0079        | 0.7797 | 2500 | 1.2803          | 58.1741 |
| 0.0071        | 1.5595 | 2600 | 1.2979          | 55.2040 |
| 0.0058        | 2.3392 | 2700 | 1.3174          | 54.9199 |
| 0.0052        | 3.1189 | 2800 | 1.3281          | 56.4954 |
| 0.0056        | 3.8986 | 2900 | 1.3193          | 51.3946 |
| 0.0051        | 4.6784 | 3000 | 1.3291          | 49.9483 |
| 0.0045        | 5.4581 | 3100 | 1.3428          | 52.4019 |
| 0.0036        | 6.2378 | 3200 | 1.3506          | 49.0186 |
| 0.0041        | 7.0175 | 3300 | 1.3623          | 50.2583 |
| 0.0047        | 7.7973 | 3400 | 1.3643          | 53.2541 |
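The final row (step 3400) matches the evaluation results reported at the top of this card. The table also shows validation loss at its lowest at step 500 and rising afterwards, while the lowest WER (49.0186) appears at step 3200, so the two metrics would select different checkpoints. The sketch below (hypothetical helper `best_step`) copies the step, validation loss, and WER columns from the table and returns the step with the lowest value in a chosen column.

```python
from typing import List, Tuple

Row = Tuple[int, float, float]  # (step, validation_loss, wer_percent), copied from the table

ROWS: List[Row] = [
    (100, 1.1191, 77.3885), (200, 0.8848, 64.0640), (300, 0.8022, 62.0272),
    (400, 0.7607, 67.0102), (500, 0.7314, 59.4083), (600, 0.7334, 69.0713),
    (700, 0.7656, 59.8812), (800, 0.8062, 68.1256), (900, 0.8130, 56.9593),
    (1000, 0.8706, 58.0868), (1100, 0.9390, 58.2565), (1200, 0.9932, 55.5286),
    (1300, 1.0391, 55.8802), (1400, 1.0820, 58.4627), (1500, 1.1270, 55.2498),
    (1600, 1.1660, 57.4321), (1700, 1.1738, 53.5766), (1800, 1.1943, 53.6736),
    (1900, 1.2236, 52.3642), (2000, 1.2549, 58.8143), (2100, 1.25, 52.1581),
    (2200, 1.2627, 51.3094), (2300, 1.2705, 52.6673), (2400, 1.2744, 53.0674),
    (2500, 1.2803, 58.1741), (2600, 1.2979, 55.2040), (2700, 1.3174, 54.9199),
    (2800, 1.3281, 56.4954), (2900, 1.3193, 51.3946), (3000, 1.3291, 49.9483),
    (3100, 1.3428, 52.4019), (3200, 1.3506, 49.0186), (3300, 1.3623, 50.2583),
    (3400, 1.3643, 53.2541),
]

COLUMNS = {"validation_loss": 1, "wer": 2}


def best_step(rows: List[Row], column: str) -> int:
    """Return the step with the lowest value in `column` (lower is better for both metrics)."""
    index = COLUMNS[column]
    return min(rows, key=lambda row: row[index])[0]


if __name__ == "__main__":
    print("lowest validation loss at step", best_step(ROWS, "validation_loss"))
    print("lowest WER at step", best_step(ROWS, "wer"))
```

When training with the Hugging Face `Trainer`, the analogous choice is made with `TrainingArguments(load_best_model_at_end=True, metric_for_best_model="wer", greater_is_better=False)`, assuming `compute_metrics` returns a `"wer"` key; this card does not say whether that option was used.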

Framework versions

  • Transformers 4.40.1
  • Pytorch 2.3.0+cu121
  • Datasets 2.19.0
  • Tokenizers 0.19.1
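To compare a local environment against these versions, the standard-library sketch below reads installed distribution versions with `importlib.metadata`. The PyPI distribution names are assumptions (PyTorch is published as `torch`), the helper names are hypothetical, and a missing package is reported as `None`.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions listed in this model card, keyed by (assumed) PyPI distribution name.
CARD_VERSIONS: Dict[str, str] = {
    "transformers": "4.40.1",
    "torch": "2.3.0+cu121",
    "datasets": "2.19.0",
    "tokenizers": "0.19.1",
}


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


def version_report(expected: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each distribution to (version listed in the card, installed version or None)."""
    return {name: (wanted, installed_version(name)) for name, wanted in expected.items()}


if __name__ == "__main__":
    for name, (wanted, found) in version_report(CARD_VERSIONS).items():
        status = "same" if found == wanted else "different"
        print(f"{name}: card {wanted}, installed {found or 'none'} ({status})")
```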