|
import numpy as np |
|
import random |
|
import pandas as pd |
|
import re |
|
import json |
|
import torch |
|
import argparse |
|
|
|
|
|
from datasets import load_dataset, load_metric, Audio |
|
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer |
|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Union |
|
from IPython.display import display, HTML |
|
|
|
|
|
|
|
class dataset_gen:
    """Maps raw dataset rows into Wav2Vec2 model inputs using a shared processor."""

    def __init__(self, processor):
        # Processor bundles the feature extractor (audio) and the CTC tokenizer (text).
        self.processor = processor

    def prepare_dataset(self, batch):
        """Convert one example: raw audio -> input_values, transcript -> label ids."""
        audio = batch["audio"]

        extracted = self.processor(audio["array"], sampling_rate=audio["sampling_rate"])
        batch["input_values"] = extracted.input_values[0]
        batch["input_length"] = len(batch["input_values"])

        # Switch the processor into tokenizer mode to encode the target transcript.
        with self.processor.as_target_processor():
            batch["labels"] = self.processor(batch["sentence"]).input_ids
        return batch
|
|
|
@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pads audio inputs and label ids to equal length within a batch.

    Audio features and text labels have different lengths, so each group is
    padded separately via the processor's two sub-components.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        audio_inputs = [{"input_values": f["input_values"]} for f in features]
        text_labels = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.pad(
            audio_inputs,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                text_labels,
                padding=self.padding,
                return_tensors="pt",
            )

        # Padded label positions become -100 so the CTC loss ignores them.
        padded_labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = padded_labels

        return batch
|
|
|
class metrics:
    """Computes word error rate (WER) from Trainer prediction outputs."""

    def __init__(self, processor, wer_metric):
        self.processor = processor
        self.wer_metric = wer_metric

    def compute_metrics(self, pred):
        """Greedy-decode logits, restore padding in labels, and return {'wer': ...}."""
        predicted_ids = np.argmax(pred.predictions, axis=-1)

        # -100 marks positions ignored by the loss; map back to the pad id
        # so the tokenizer can decode the references.
        pred.label_ids[pred.label_ids == -100] = self.processor.tokenizer.pad_token_id

        predictions = self.processor.batch_decode(predicted_ids)
        # group_tokens=False keeps repeated characters in the reference text.
        references = self.processor.batch_decode(pred.label_ids, group_tokens=False)

        return {"wer": self.wer_metric.compute(predictions=predictions, references=references)}
|
|
|
def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct random rows of `dataset` as an HTML table.

    Intended for notebook use (relies on IPython.display).
    """
    assert num_examples <= len(dataset)
    # random.sample draws distinct indices in one pass, replacing the original
    # rejection-sampling loop that degraded badly as num_examples approached
    # len(dataset).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))
|
|
|
|
|
# Patterns are compiled once at import time instead of being rebuilt per example.
# Raw strings replace the original non-raw pattern, whose invalid escape
# sequences (e.g. "\,") trigger SyntaxWarning/DeprecationWarning on modern Python.
_PUNCT_RE = re.compile(r"[,?.!;:\"“%‘”�'&/\d_\\-]")
_ZWNJ_RE = re.compile("\u200c")
_LATIN_RE = re.compile(r"[a-z]")


def remove_special_characters(batch):
    """Clean `batch["sentence"]` in place: strip punctuation/digits/backslash,
    lowercase, then remove zero-width non-joiners and any latin letters.

    Removing [a-z] after lower() also strips characters that were uppercase.
    """
    sentence = _PUNCT_RE.sub('', batch["sentence"]).lower()
    sentence = _ZWNJ_RE.sub('', sentence)
    batch["sentence"] = _LATIN_RE.sub('', sentence)
    return batch
|
|
|
|
|
def extract_all_chars(batch):
    """Collect the unique characters across all sentences in a batch.

    Returns single-element lists so datasets.map(batched=True, batch_size=-1)
    reduces an entire split to one row.
    """
    joined = " ".join(batch["sentence"])
    unique_chars = list(set(joined))
    return {"vocab": [unique_chars], "all_text": [joined]}
|
|
|
|
|
def preprocess_labels(telugu_train, telugu_test):
    """Clean transcripts on both splits, build a character vocabulary, and
    write it to vocab.json (consumed later by the CTC tokenizer).

    Returns the cleaned (train, test) splits.
    """
    telugu_train = telugu_train.map(remove_special_characters)
    telugu_test = telugu_test.map(remove_special_characters)

    # batch_size=-1 feeds the whole split to extract_all_chars in one call,
    # so each result contains a single row holding the split's character set.
    vocab_train = telugu_train.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        keep_in_memory=True,
        remove_columns=telugu_train.column_names,
    )
    vocab_test = telugu_test.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        keep_in_memory=True,
        remove_columns=telugu_test.column_names,
    )

    all_chars = set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0])
    vocab_dict = {char: idx for idx, char in enumerate(sorted(all_chars))}

    # Wav2Vec2's CTC tokenizer uses "|" as the word delimiter instead of space.
    vocab_dict["|"] = vocab_dict[" "]
    del vocab_dict[" "]

    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)

    with open('vocab.json', 'w') as vocab_file:
        json.dump(vocab_dict, vocab_file)

    return telugu_train, telugu_test
|
|
|
|
|
def preprocess_audio(telugu_train, telugu_test, processor):
    """Resample both splits to 16 kHz and convert rows to model inputs/labels."""
    # XLS-R style checkpoints expect 16 kHz audio; cast_column resamples lazily.
    telugu_train = telugu_train.cast_column("audio", Audio(sampling_rate=16_000))
    telugu_test = telugu_test.cast_column("audio", Audio(sampling_rate=16_000))

    generator = dataset_gen(processor)

    telugu_train = telugu_train.map(generator.prepare_dataset, remove_columns=telugu_train.column_names)
    telugu_test = telugu_test.map(generator.prepare_dataset, remove_columns=telugu_test.column_names)

    return telugu_train, telugu_test
|
|
|
def main(args):
    """Fine-tune a Wav2Vec2 CTC model on a Telugu speech dataset and push it to the Hub.

    Steps: load and split the dataset, clean transcripts and build the vocab,
    create tokenizer/feature-extractor/processor, preprocess audio, then train.
    """
    repo_name = args.repo_name

    telugu_dataset = load_dataset(args.dataset, args.config)
    train_testvalid = telugu_dataset['train'].train_test_split(test_size=args.test_split_size)
    telugu_train = train_testvalid["train"]
    telugu_test = train_testvalid["test"]

    # Side effect: writes vocab.json to the working directory; the tokenizer
    # below is loaded from "./" for exactly that file.
    telugu_train, telugu_test = preprocess_labels(telugu_train, telugu_test)

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
    tokenizer.push_to_hub(repo_name)

    feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    telugu_train, telugu_test = preprocess_audio(telugu_train, telugu_test, processor)

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    wer_metric = load_metric("wer")
    metric = metrics(processor, wer_metric)

    model = Wav2Vec2ForCTC.from_pretrained(
        args.model_id,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        feat_proj_dropout=0.0,
        mask_time_prob=0.05,
        layerdrop=0.0,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(processor.tokenizer),
    )

    # Keep the convolutional feature encoder frozen during fine-tuning.
    model.freeze_feature_extractor()

    training_args = TrainingArguments(
        output_dir=repo_name,
        group_by_length=True,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=2,
        evaluation_strategy="steps",
        # Bug fix: the CLI flag is --num_epochs, so the parsed attribute is
        # args.num_epochs; args.epochs raised AttributeError here.
        num_train_epochs=args.num_epochs,
        gradient_checkpointing=True,
        fp16=True,
        save_steps=400,
        eval_steps=400,
        logging_steps=400,
        learning_rate=3e-4,
        warmup_steps=500,
        save_total_limit=2,
        push_to_hub=True,
    )

    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=metric.compute_metrics,
        train_dataset=telugu_train,
        eval_dataset=telugu_test,
        # Passing the feature extractor lets Trainer save it with checkpoints.
        tokenizer=processor.feature_extractor,
    )

    output = trainer.train()
    print(output)
    trainer.push_to_hub()
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id", type=str, required=True, default="facebook/wav2vec2-large-xlsr-53", help="Model identifier. Should be loadable with 🤗 Transformers"
    )
    parser.add_argument(
        "--dataset", type=str, required=True, default="openslr", help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"
    )
    parser.add_argument(
        "--config", type=str, required=True, default="SLR66", help="Config of the dataset. *E.g.* `'en'` for Common Voice"
    )
    # Default added so TrainingArguments never receives num_train_epochs=None
    # when the flag is omitted.
    parser.add_argument(
        "--num_epochs", type=int, required=False, default=30, help="Number of epochs for training"
    )
    parser.add_argument(
        "--repo_name", type=str, help="Name of the repo for storing files"
    )
    # Bug fix: the split size is a fraction (default 0.25), so the type must be
    # float — type=int rejected any value actually passed on the command line.
    parser.add_argument(
        "--test_split_size", type=float, default=0.25, required=False, help="split size for test set from dataset"
    )

    args = parser.parse_args()
    main(args)