#!/home/haroon/python_virtual_envs/whisper_fine_tuning/bin/python
"""Fine-tune openai/whisper-small on Common Voice 11.0 Hindi.

Pipeline: load and trim the dataset, resample audio to 16 kHz, turn audio
into Whisper input features and transcripts into label ids, drop training
clips of 30 s or longer, then train with Seq2SeqTrainer (scoring WER on the
test split) and push the result to the Hugging Face Hub.
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import evaluate
import torch
from datasets import Audio, DatasetDict, load_dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

# --- Data --------------------------------------------------------------------

common_voice = DatasetDict()
# Train on train+validation; keep the official test split for evaluation.
common_voice["train"] = load_dataset(
    "mozilla-foundation/common_voice_11_0", "hi",
    split="train+validation", token=True)
common_voice["test"] = load_dataset(
    "mozilla-foundation/common_voice_11_0", "hi", split="test", token=True)
print(f'YYY1a {common_voice=}')

# Drop metadata columns the pipeline never reads; prepare_dataset only uses
# "audio" and "sentence".
common_voice = common_voice.remove_columns([
    "accent", "age", "client_id", "down_votes", "gender", "locale", "path",
    "segment", "up_votes"])
print(f'YYY1b {common_voice=}')
print(f'YYY2 {type(common_voice)=}')

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small", language="Hindi", task="transcribe")
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="Hindi", task="transcribe")

print(common_voice["train"][0])
# Resample lazily to the 16 kHz rate the feature extractor is called with.
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
print(common_voice["train"][0])

# Optional transcript clean-up applied before tokenisation (both off here).
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()


def prepare_dataset(batch):
    """Convert one example into model inputs.

    Adds ``input_features`` (feature-extractor output for the audio array),
    ``input_length`` (clip duration in seconds) and ``labels`` (token ids of
    the, optionally cleaned, ``sentence`` transcript). Returns the same dict.
    """
    audio = batch["audio"]
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # Duration in seconds, used below to filter over-long training clips.
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    transcription = batch["sentence"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch


common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=common_voice.column_names["train"],
    num_proc=2)

max_input_length = 30.0


def is_audio_in_length_range(length):
    """Return True when a clip is strictly shorter than ``max_input_length`` s."""
    return length < max_input_length


common_voice["train"] = common_voice["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)


# --- Collation ---------------------------------------------------------------

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad input features and labels separately into a training batch.

    processor: the WhisperProcessor whose feature extractor and tokenizer
        perform the padding.
    decoder_start_token_id: id the model prepends to decoder inputs. If every
        label sequence already starts with it, it is stripped so it is not
        duplicated. When None, falls back to ``tokenizer.bos_token_id``
        (previous behaviour). Note that for Whisper that fallback is
        <|endoftext|>, not <|startoftranscript|>, so it normally never matches;
        pass ``model.config.decoder_start_token_id`` instead.
    """

    processor: Any
    decoder_start_token_id: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) \
            -> Dict[str, torch.Tensor]:
        # Audio inputs and token labels have different lengths and padding
        # rules, so they are padded by their own component.
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        start_id = (self.decoder_start_token_id
                    if self.decoder_start_token_id is not None
                    else self.processor.tokenizer.bos_token_id)
        # The tokenizer already prepended the start token; the model will
        # prepend it again when shifting labels, so drop it here.
        if (labels[:, 0] == start_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


# --- Evaluation --------------------------------------------------------------

metric = evaluate.load("wer")
do_normalize_eval = True


def compute_metrics(pred):
    """Return ``{"wer": <word error rate in percent>}`` for an eval batch.

    ``pred`` is the prediction object passed in by Seq2SeqTrainer; with
    ``predict_with_generate=True`` its ``predictions`` hold generated ids.
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # -100 marks ignored label positions; map them to the pad token so the
    # ids can be decoded (pad is then skipped as a special token).
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(hyp) for hyp in pred_str]
        label_str = [normalizer(ref) for ref in label_str]

    # WER is undefined for an empty reference (jiwer, which backs the WER
    # metric, raises on one), and normalisation can reduce a punctuation-only
    # transcript to "". Only score pairs with a non-empty reference.
    kept = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if len(ref) > 0]
    pred_str = [hyp for hyp, _ in kept]
    label_str = [ref for _, ref in kept]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}


# --- Model -------------------------------------------------------------------

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Pin language and task for generation instead of relying on forced ids.
model.generation_config.language = "hi"
model.generation_config.task = "transcribe"
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# Transformers disables (and warns about) the KV cache under gradient
# checkpointing; turn it off explicitly for training.
model.config.use_cache = False

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# --- Training ----------------------------------------------------------------

# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` (and Seq2SeqTrainer's `tokenizer` to `processing_class`);
# confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

# Save the processor next to the checkpoints so output_dir is self-contained.
processor.save_pretrained(training_args.output_dir)
trainer.train()

kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0",  # a 'pretty' name for the training dataset
    "language": "hi",
    "model_name": "Whisper Small Hi - Sanchit Gandhi",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
    "tags": "whisper-event",
}
trainer.push_to_hub(**kwargs)