# wav2vec2-telugu_150 / telugu_xlmr.py
# Author: krishnateja
# Training script (rev e5deaf1)
import numpy as np
import random
import pandas as pd
import re
import json
import torch
import argparse
from datasets import load_dataset, load_metric, Audio
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
from dataclasses import dataclass, field
from typing import Dict, List, Union
from IPython.display import display, HTML
class dataset_gen:
    """Wraps a Wav2Vec2Processor so it can be applied with `datasets.map`."""

    def __init__(self, processor):
        # Processor bundles the feature extractor (audio) and tokenizer (text).
        self.processor = processor

    def prepare_dataset(self, batch):
        """Convert one example's raw audio and transcript into model inputs/labels."""
        speech = batch["audio"]
        extracted = self.processor(speech["array"], sampling_rate=speech["sampling_rate"])
        batch["input_values"] = extracted.input_values[0]
        batch["input_length"] = len(batch["input_values"])
        # Switch the processor to its tokenizer half to encode the transcript.
        with self.processor.as_target_processor():
            batch["labels"] = self.processor(batch["sentence"]).input_ids
        return batch
@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pads audio inputs and CTC labels to per-batch lengths.

    Audio features and label token ids require different padding values and
    strategies, so each group is padded separately through the processor.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split the audio inputs from the label sequences up front — they must
        # be padded with different methods.
        audio_inputs = [{"input_values": item["input_values"]} for item in features]
        text_labels = [{"input_ids": item["labels"]} for item in features]

        batch = self.processor.pad(
            audio_inputs,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                text_labels,
                padding=self.padding,
                return_tensors="pt",
            )

        # CTC loss ignores index -100, so mask the padded label positions.
        mask = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(mask, -100)
        return batch
class metrics:
    """Computes word error rate (WER) for Trainer evaluation callbacks."""

    def __init__(self, processor, wer_metric):
        self.processor = processor
        self.wer_metric = wer_metric

    def compute_metrics(self, pred):
        """Greedy-decode logits, decode references, and return {"wer": score}."""
        predicted_ids = np.argmax(pred.predictions, axis=-1)
        # -100 marks ignored (padded) label positions; map them back to the pad
        # token id so the tokenizer can decode the reference strings.
        pred.label_ids[pred.label_ids == -100] = self.processor.tokenizer.pad_token_id
        predictions = self.processor.batch_decode(predicted_ids)
        # group_tokens=False keeps repeated characters intact in the references.
        references = self.processor.batch_decode(pred.label_ids, group_tokens=False)
        score = self.wer_metric.compute(predictions=predictions, references=references)
        return {"wer": score}
def show_random_elements(dataset, num_examples=10):
    """Render `num_examples` distinct random rows of `dataset` as an HTML table.

    Raises ValueError if num_examples exceeds the dataset size (the original
    used `assert`, which is stripped under `python -O`).
    """
    if num_examples > len(dataset):
        raise ValueError("num_examples cannot exceed the dataset size")
    # random.sample draws without replacement in one call, replacing the
    # original pick-until-unique rejection loop.
    picks = random.sample(range(len(dataset)), num_examples)
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))
# Compiled once at module level: these patterns run on every training example.
# Raw strings fix the invalid escape sequences (e.g. '\,') in the original
# literals, which emit SyntaxWarnings on modern Python. The character set is
# unchanged: punctuation, quotes, replacement char, digits, underscore, backslash.
_CHARS_TO_REMOVE_RE = re.compile(r'[,?.!\-;:"“%‘”�\'&/\d_\\]')
_ZWNJ_RE = re.compile('\u200c')  # zero-width non-joiner
_LATIN_RE = re.compile(r'[a-z]')

def remove_special_characters(batch):
    """Strip punctuation/digits, lowercase, and drop ZWNJ and Latin letters
    from batch["sentence"], mutating and returning the batch."""
    batch["sentence"] = _CHARS_TO_REMOVE_RE.sub('', batch["sentence"]).lower()
    batch["sentence"] = _ZWNJ_RE.sub('', batch["sentence"])
    # Remove stray romanized (a-z) text so only Telugu script remains.
    batch["sentence"] = _LATIN_RE.sub('', batch["sentence"])
    return batch
def extract_all_chars(batch):
    """Collect the set of unique characters across a batch of sentences."""
    joined = " ".join(batch["sentence"])
    unique_chars = list(set(joined))
    # Single-element lists so datasets.map(batched=True, batch_size=-1)
    # reduces an entire split to one row.
    return {"vocab": [unique_chars], "all_text": [joined]}
def preprocess_labels(telugu_train, telugu_test):
    """Clean transcripts, build a character vocabulary, and write vocab.json.

    Returns the cleaned (train, test) splits; the vocabulary file is a side
    effect consumed later by Wav2Vec2CTCTokenizer.from_pretrained("./").
    """
    telugu_train = telugu_train.map(remove_special_characters)
    telugu_test = telugu_test.map(remove_special_characters)
    # Reduce each split to its unique character set in one batched pass.
    train_chars = telugu_train.map(extract_all_chars, batched=True, batch_size=-1,
                                   keep_in_memory=True, remove_columns=telugu_train.column_names)
    test_chars = telugu_test.map(extract_all_chars, batched=True, batch_size=-1,
                                 keep_in_memory=True, remove_columns=telugu_test.column_names)
    all_chars = set(train_chars["vocab"][0]) | set(test_chars["vocab"][0])
    vocab_dict = {char: idx for idx, char in enumerate(sorted(all_chars))}
    # CTC convention: "|" replaces the space as the word delimiter token.
    vocab_dict["|"] = vocab_dict[" "]
    del vocab_dict[" "]
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)
    with open('vocab.json', 'w') as vocab_file:
        json.dump(vocab_dict, vocab_file)
    return telugu_train, telugu_test
def preprocess_audio(telugu_train, telugu_test, processor):
    """Resample both splits to 16 kHz and map them to model-ready features.

    Original columns are dropped; each split keeps only input_values,
    input_length and labels produced by dataset_gen.prepare_dataset.
    """
    telugu_train = telugu_train.cast_column("audio", Audio(sampling_rate=16_000))
    telugu_test = telugu_test.cast_column("audio", Audio(sampling_rate=16_000))
    mapper = dataset_gen(processor)
    telugu_train = telugu_train.map(mapper.prepare_dataset, remove_columns=telugu_train.column_names)
    telugu_test = telugu_test.map(mapper.prepare_dataset, remove_columns=telugu_test.column_names)
    return telugu_train, telugu_test
def main(args):
    """End-to-end fine-tuning of a wav2vec2 CTC model on a Telugu ASR dataset.

    Loads the dataset, builds the vocabulary/tokenizer/processor, preprocesses
    audio, trains with HF Trainer, and pushes artifacts to the Hub.
    """
    repo_name = args.repo_name
    telugu_dataset = load_dataset(args.dataset, args.config)
    # Carve a held-out test split off the single 'train' split.
    train_testvalid = telugu_dataset['train'].train_test_split(test_size=args.test_split_size)
    telugu_train = train_testvalid["train"]
    telugu_test = train_testvalid["test"]
    telugu_train,telugu_test = preprocess_labels(telugu_train,telugu_test)
    # preprocess_labels just wrote vocab.json into the working directory.
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
    tokenizer.push_to_hub(repo_name)
    feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
    telugu_train,telugu_test = preprocess_audio(telugu_train,telugu_test,processor)
    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    wer_metric = load_metric("wer")
    metric = metrics(processor,wer_metric)
    model = Wav2Vec2ForCTC.from_pretrained(
        args.model_id,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        feat_proj_dropout=0.0,
        mask_time_prob=0.05,
        layerdrop=0.0,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(processor.tokenizer),
    )
    # Keep the convolutional feature encoder frozen during fine-tuning.
    model.freeze_feature_extractor()
    training_args = TrainingArguments(
        output_dir=repo_name,
        group_by_length=True,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=2,
        evaluation_strategy="steps",
        # BUG FIX: the CLI flag is --num_epochs, so the attribute is
        # args.num_epochs; the original read the nonexistent args.epochs and
        # crashed with AttributeError before training started.
        num_train_epochs=args.num_epochs,
        gradient_checkpointing=True,
        fp16=True,
        save_steps=400,
        eval_steps=400,
        logging_steps=400,
        learning_rate=3e-4,
        warmup_steps=500,
        save_total_limit=2,
        push_to_hub=True,
    )
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=metric.compute_metrics,
        train_dataset=telugu_train,
        eval_dataset=telugu_test,
        # The feature extractor is saved/pushed alongside checkpoints.
        tokenizer=processor.feature_extractor,
    )
    output = trainer.train()
    print(output)
    trainer.push_to_hub()
if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id", type=str, required=True, default="facebook/wav2vec2-large-xlsr-53", help="Model identifier. Should be loadable with 🤗 Transformers"
    )
    parser.add_argument(
        "--dataset", type=str, required=True, default="openslr", help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"
    )
    parser.add_argument(
        "--config", type=str, required=True, default="SLR66", help="Config of the dataset. *E.g.* `'en'` for Common Voice"
    )
    parser.add_argument(
        # Default added: omitting the flag previously left num_epochs as None,
        # which TrainingArguments cannot use as num_train_epochs.
        "--num_epochs", type=int, required=False, default=30, help="Number of epochs for training"
    )
    parser.add_argument(
        # required: main() passes this to push_to_hub, which needs a real name.
        "--repo_name", type=str, required=True, help="Name of the repo for storing files"
    )
    parser.add_argument(
        # BUG FIX: the split size is fractional; type=int rejected CLI values
        # such as the 0.25 default ("invalid int value: '0.25'").
        "--test_split_size", type=float, default=0.25, required=False, help="split size for test set from dataset"
    )
    args = parser.parse_args()
    main(args)