import argparse
import json
import random
import re
from dataclasses import dataclass, field
from typing import Dict, List, Union

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, load_metric, Audio
from IPython.display import display, HTML
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer,
)

# Punctuation, quote marks, digits, "_", "/" and "\" stripped from transcripts.
# Raw string so the regex escapes are not (invalid) Python string escapes.
_CHARS_TO_REMOVE_RE = re.compile(r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\&\/\d\_\\]')
# Lower-case Latin letters; removed after lower-casing the sentence.
_LATIN_LETTERS_RE = re.compile(r"[a-z]")
# Zero-width non-joiner (U+200C), removed from transcripts.
_ZWNJ = "\u200c"


class dataset_gen:
    """Adapter exposing a Wav2Vec2Processor as a ``Dataset.map`` callback."""

    def __init__(self, processor):
        self.processor = processor

    def prepare_dataset(self, batch):
        """Convert one example into model inputs and CTC label ids.

        Reads ``batch["audio"]`` (with ``"array"`` and ``"sampling_rate"``)
        and ``batch["sentence"]``; adds ``input_values``, ``input_length``
        and ``labels`` and returns the same mapping.
        """
        audio = batch["audio"]
        batch["input_values"] = self.processor(
            audio["array"], sampling_rate=audio["sampling_rate"]
        ).input_values[0]
        # Stored so the Trainer can group batches by length without
        # recomputing it (see length_column_name in main()).
        batch["input_length"] = len(batch["input_values"])
        # Encode the transcript as label ids via the processor's target mode.
        with self.processor.as_target_processor():
            batch["labels"] = self.processor(batch["sentence"]).input_ids
        return batch


@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pad audio inputs and CTC labels for a batch.

    Attributes:
        processor: processor whose ``pad`` is used for inputs and labels.
        padding: forwarded to ``processor.pad`` as its ``padding`` argument.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding,
        # so they are padded separately.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # Replace label padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch


class metrics:
    """Word-error-rate metric callback for ``Trainer(compute_metrics=...)``."""

    def __init__(self, processor, wer_metric):
        self.processor = processor
        self.wer_metric = wer_metric

    def compute_metrics(self, pred):
        """Greedy-decode ``pred.predictions`` and score against the labels.

        Returns:
            ``{"wer": <value returned by wer_metric.compute>}``.
        """
        pred_ids = np.argmax(pred.predictions, axis=-1)
        # -100 marks ignored label positions (set by the collator); map them back
        # to the pad id so they decode cleanly. np.where leaves pred.label_ids
        # untouched instead of mutating the caller's array.
        label_ids = np.where(
            pred.label_ids == -100,
            self.processor.tokenizer.pad_token_id,
            pred.label_ids,
        )

        pred_str = self.processor.batch_decode(pred_ids)
        # Do not group repeated tokens when decoding references.
        label_str = self.processor.batch_decode(label_ids, group_tokens=False)

        wer = self.wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}


def show_random_elements(dataset, num_examples=10):
    """Render ``num_examples`` distinct random rows of ``dataset`` as an HTML table.

    Raises:
        ValueError: if ``num_examples`` exceeds ``len(dataset)``.
    """
    if num_examples > len(dataset):
        raise ValueError(
            f"num_examples ({num_examples}) exceeds dataset size ({len(dataset)})"
        )
    picks = random.sample(range(len(dataset)), num_examples)
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))


def remove_special_characters(batch):
    """Normalise ``batch["sentence"]`` in place and return ``batch``.

    Strips punctuation/digits, lower-cases, removes zero-width non-joiners,
    then removes any lower-case Latin letters.
    """
    sentence = _CHARS_TO_REMOVE_RE.sub("", batch["sentence"]).lower()
    sentence = sentence.replace(_ZWNJ, "")
    sentence = _LATIN_LETTERS_RE.sub("", sentence)
    batch["sentence"] = sentence
    return batch


def extract_all_chars(batch):
    """Collect the set of characters across a batch of sentences.

    Intended for ``Dataset.map(batched=True, batch_size=-1)``, so the whole
    split arrives as one batch and yields one row.
    """
    all_text = " ".join(batch["sentence"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}


def preprocess_labels(telugu_train, telugu_test):
    """Clean transcripts and write the character vocabulary to ``vocab.json``.

    The vocabulary is the union of characters in both splits, sorted and
    numbered from 0. The space character is renamed to the ``|`` word
    delimiter; ``[UNK]`` and ``[PAD]`` are appended at the end.

    Returns:
        ``(telugu_train, telugu_test)`` with cleaned ``sentence`` columns.
    """
    telugu_train = telugu_train.map(remove_special_characters)
    telugu_test = telugu_test.map(remove_special_characters)

    vocab_train = telugu_train.map(
        extract_all_chars, batched=True, batch_size=-1,
        keep_in_memory=True, remove_columns=telugu_train.column_names,
    )
    vocab_test = telugu_test.map(
        extract_all_chars, batched=True, batch_size=-1,
        keep_in_memory=True, remove_columns=telugu_test.column_names,
    )

    vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}

    # "|" replaces " " as the word delimiter, reusing its id. If the corpus
    # had no space at all, give "|" the next free id so ids stay contiguous.
    if " " in vocab_dict:
        vocab_dict["|"] = vocab_dict.pop(" ")
    else:
        vocab_dict["|"] = len(vocab_dict)
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)

    # main() loads the tokenizer from "./", which reads this file.
    with open("vocab.json", "w", encoding="utf-8") as vocab_file:
        json.dump(vocab_dict, vocab_file)

    return telugu_train, telugu_test


def preprocess_audio(telugu_train, telugu_test, processor):
    """Resample audio to 16 kHz and map both splits through ``prepare_dataset``.

    All original columns are dropped; the result holds the columns added by
    ``dataset_gen.prepare_dataset``.
    """
    telugu_train = telugu_train.cast_column("audio", Audio(sampling_rate=16_000))
    telugu_test = telugu_test.cast_column("audio", Audio(sampling_rate=16_000))

    dataset = dataset_gen(processor)
    telugu_train = telugu_train.map(dataset.prepare_dataset, remove_columns=telugu_train.column_names)
    telugu_test = telugu_test.map(dataset.prepare_dataset, remove_columns=telugu_test.column_names)
    return telugu_train, telugu_test


def main(args):
    """Fine-tune ``args.model_id`` with CTC on ``args.dataset`` and push to the Hub."""
    repo_name = args.repo_name

    telugu_dataset = load_dataset(args.dataset, args.config)
    train_testvalid = telugu_dataset["train"].train_test_split(test_size=args.test_split_size)
    telugu_train = train_testvalid["train"]
    telugu_test = train_testvalid["test"]

    telugu_train, telugu_test = preprocess_labels(telugu_train, telugu_test)

    # Reads the vocab.json written by preprocess_labels() in the cwd.
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
        "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
    )
    tokenizer.push_to_hub(repo_name)

    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1, sampling_rate=16000, padding_value=0.0,
        do_normalize=True, return_attention_mask=True,
    )
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    telugu_train, telugu_test = preprocess_audio(telugu_train, telugu_test, processor)

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    wer_metric = load_metric("wer")
    metric = metrics(processor, wer_metric)

    model = Wav2Vec2ForCTC.from_pretrained(
        args.model_id,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        feat_proj_dropout=0.0,
        mask_time_prob=0.05,
        layerdrop=0.0,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(processor.tokenizer),
    )
    # NOTE(review): newer transformers versions deprecate this in favour of
    # freeze_feature_encoder(); kept for compatibility with older releases.
    model.freeze_feature_extractor()

    training_args = TrainingArguments(
        output_dir=repo_name,
        group_by_length=True,
        # Use the lengths precomputed in prepare_dataset() instead of having
        # the Trainer recompute them from input_values.
        length_column_name="input_length",
        per_device_train_batch_size=16,
        gradient_accumulation_steps=2,
        evaluation_strategy="steps",
        num_train_epochs=args.num_epochs,
        gradient_checkpointing=True,
        fp16=True,
        save_steps=400,
        eval_steps=400,
        logging_steps=400,
        learning_rate=3e-4,
        warmup_steps=500,
        save_total_limit=2,
        push_to_hub=True,
    )

    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=metric.compute_metrics,
        train_dataset=telugu_train,
        eval_dataset=telugu_test,
        tokenizer=processor.feature_extractor,
    )

    output = trainer.train()
    print(output)
    trainer.push_to_hub()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Options with a default are optional; previously they were both
    # required=True and defaulted, which made the defaults unreachable.
    parser.add_argument(
        "--model_id", type=str, required=False, default="facebook/wav2vec2-large-xlsr-53",
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset", type=str, required=False, default="openslr",
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--config", type=str, required=False, default="SLR66",
        help="Config of the dataset. *E.g.* `'en'` for Common Voice",
    )
    parser.add_argument(
        "--num_epochs", type=int, required=False, default=30,
        help="Number of epochs for training",
    )
    # Used as output_dir and Hub repo, so it cannot be omitted.
    parser.add_argument(
        "--repo_name", type=str, required=True,
        help="Name of the repo for storing files",
    )
    # float, not int: this is a fraction passed to train_test_split(test_size=...).
    parser.add_argument(
        "--test_split_size", type=float, default=0.25, required=False,
        help="split size for test set from dataset",
    )
    args = parser.parse_args()

    main(args)