print('Loading dependencies...')

import os
import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

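# Note: Wav2Vec2ForCTCWithAdapterSwitching (imported below) and the tokenizer
# method set_code_switched_target_langs (used in main) are not part of stock
# transformers; they are assumed to come from this project's modified
# transformers fork that adds adapter switching for code-switched ASR.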
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForCTCWithAdapterSwitching
from transformers import AutoProcessor, TrainingArguments, Trainer
from datasets import Dataset, Audio
from evaluate import load
from dataclasses import dataclass
from typing import Dict, List, Union
import argparse
import random

import numpy as np
import pandas as pd
import torch

# Fix all random seeds for reproducibility.
torch.manual_seed(69)
np.random.seed(69)
random.seed(69)

import jieba
import editdistance
from itertools import chain


def tokenize_for_mer(text):
    """Tokenize mixed-language text for Mixed Error Rate (MER): jieba segments
    the text, ASCII tokens (e.g. English words) are kept whole, and non-ASCII
    tokens (e.g. Chinese words) are split into individual characters, since
    Chinese has no whitespace word boundaries."""
    tokens = list(filter(lambda tok: len(tok.strip()) > 0, jieba.lcut(text)))
    tokens = [[tok] if tok.isascii() else list(tok) for tok in tokens]
    return list(chain(*tokens))


def tokenize_for_cer(text):
    """Tokenize text into its non-whitespace characters for Character Error Rate (CER)."""
    return list(filter(lambda tok: len(tok.strip()) > 0, list(text)))
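
# Illustrative behaviour on a code-switched string (the exact jieba
# segmentation may vary, but non-ASCII tokens always end up as characters):
#   tokenize_for_mer("我们today開會") -> ['我', '们', 'today', '開', '會']
#   tokenize_for_cer("我们today開會") -> ['我', '们', 't', 'o', 'd', 'a', 'y', '開', '會']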


def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    # Replace the -100 used to mask label padding in the loss with the pad
    # token id so the references decode cleanly.
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_strs = processor.batch_decode(pred_ids)
    # group_tokens=False: the references are ground-truth labels, not CTC
    # output, so repeated tokens must not be collapsed.
    label_strs = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer_eval = wer_metric.compute(predictions=pred_strs, references=label_strs)
    cer_eval = cer_metric.compute(predictions=pred_strs, references=label_strs)

    # Corpus-level MER and CER: accumulate edit distances and reference token
    # counts over all utterances, then divide once at the end.
    mixed_distance, mixed_tokens = 0, 0
    char_distance, char_tokens = 0, 0
    for pred_str, label_str in zip(pred_strs, label_strs):
        m_pred = tokenize_for_mer(pred_str)
        m_ref = tokenize_for_mer(label_str)
        mixed_distance += editdistance.distance(m_pred, m_ref)
        mixed_tokens += len(m_ref)

        c_pred = tokenize_for_cer(pred_str)
        c_ref = tokenize_for_cer(label_str)
        char_distance += editdistance.distance(c_pred, c_ref)
        char_tokens += len(c_ref)

    mer = mixed_distance / mixed_tokens
    cer = char_distance / char_tokens
    metrics = {"mer": mer, "cer": cer, "wer_eval": wer_eval, "cer_eval": cer_eval}
    print(metrics)
    return metrics


@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: AutoProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio inputs and text labels have different lengths and different
        # padding tokens, so they are padded separately.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # Replace label padding with -100 so it is ignored by the CTC loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
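
# Illustrative usage (hypothetical feature dicts): the collator pads
# input_values to the longest utterance in the batch and masks label padding
# with -100 so the CTC loss ignores it:
#   collator = DataCollatorCTCWithPadding(processor=processor)
#   batch = collator([{"input_values": [0.1, 0.2, ...], "labels": [5, 12]},
#                     {"input_values": [0.3, ...], "labels": [7]}])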


def prepare_dataset(batch):
    audio = batch["audio"]

    # Extract input values from the raw waveform.
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # Tokenize the reference transcript into label ids.
    batch["labels"] = processor(text=batch["transcripts"]).input_ids
    return batch


def main():
    # The metrics and processor are globals because compute_metrics and
    # prepare_dataset are called without extra arguments.
    global wer_metric
    global cer_metric

    wer_metric = load('wer')
    cer_metric = load('cer')

    # All of these paths are used unconditionally below, so they are required.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_metadata_csv_path', required=True)
    parser.add_argument('--target_lang_1', required=True)
    parser.add_argument('--target_lang_2', required=True)
    parser.add_argument('--prefix_path', required=True)
    parser.add_argument('--checkpoint_path', required=True)
    parser.add_argument('--outfile_path', required=True)
    parser.add_argument('--batch_size', required=True, type=int)
    args = parser.parse_args()

    df_test = pd.read_csv(args.test_metadata_csv_path, usecols=['file_name', 'transcription'])

    # Prepend the audio root directory to each relative file name.
    df_test['file_name'] = df_test['file_name'].map(lambda x: args.prefix_path + x)

    # MMS expects 16 kHz audio; casting with an explicit sampling rate makes
    # datasets resample on the fly if files are stored at a different rate.
    test_data = Dataset.from_dict(
        {'audio': df_test['file_name'], 'transcripts': df_test['transcription']}
    ).cast_column("audio", Audio(sampling_rate=16_000))

    print('Loading model and processor...')
    global processor
    processor = AutoProcessor.from_pretrained('facebook/mms-1b-all')
    # set_code_switched_target_langs comes from the modified tokenizer and
    # configures decoding for the two code-switched languages.
    processor.tokenizer.set_code_switched_target_langs(args.target_lang_1, args.target_lang_2)

    test_data = test_data.map(prepare_dataset, remove_columns=['audio', 'transcripts'], num_proc=8)

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    model = Wav2Vec2ForCTCWithAdapterSwitching.from_pretrained(
        args.checkpoint_path,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        feat_proj_dropout=0.0,
        layerdrop=0.0,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(processor.tokenizer),
        ignore_mismatched_sizes=True,
        use_safetensors=True,
    )

    # Attach the adapters for both languages, then restore the fine-tuned
    # weights on top of the rebuilt architecture.
    model.load_adapters_for_code_switching(args.target_lang_1, args.target_lang_2)
    model.load_state_dict(torch.load(os.path.join(args.checkpoint_path, 'pytorch_model.bin'), map_location='cpu'))

    print(model)

    print('Setting up trainer...')
    training_args = TrainingArguments(
        output_dir='test/',
        group_by_length=True,
        length_column_name='input_length',
        per_device_eval_batch_size=args.batch_size,
        report_to="none",
        # Move accumulated logits off the GPU every 10 steps to limit memory use.
        eval_accumulation_steps=10,
    )

    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        eval_dataset=test_data,
        tokenizer=processor,
        compute_metrics=compute_metrics,
    )

    print('Starting evaluation...')

    results = trainer.predict(test_data)

    # Greedy CTC decoding: argmax over the vocabulary at each frame;
    # batch_decode then collapses repeats and strips padding.
    pred_ids = torch.argmax(torch.tensor(results.predictions), dim=-1)
    transcription = processor.batch_decode(pred_ids)
    df = df_test.copy()
    df['mms_model_transcription'] = transcription
    df.to_csv(args.outfile_path)


if __name__ == '__main__':
    main()
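
# Example invocation (illustrative script name, paths, and MMS language codes):
#   python evaluate_mms_code_switching.py \
#       --test_metadata_csv_path metadata_test.csv \
#       --target_lang_1 cmn-script_simplified \
#       --target_lang_2 eng \
#       --prefix_path /data/audio/ \
#       --checkpoint_path checkpoints/checkpoint-1000 \
#       --outfile_path predictions.csv \
#       --batch_size 8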