import argparse
import os
import re

abs_path = os.path.abspath('.')
# base_dir = os.path.dirname(os.path.dirname(abs_path))
base_dir = os.path.dirname(abs_path)

# Point all Hugging Face caches at local directories before importing the libraries.
os.environ['TRANSFORMERS_CACHE'] = os.path.join(base_dir, 'models_cache')
os.environ['TRANSFORMERS_OFFLINE'] = '0'
os.environ['HF_DATASETS_CACHE'] = os.path.join(base_dir, 'datasets_cache')
os.environ['HF_DATASETS_OFFLINE'] = '0'

from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
from bnunicodenormalizer import Normalizer
import evaluate
import unicodedata

wer_metric = evaluate.load("wer", cache_dir=os.path.join(base_dir, "metrics_cache"))
cer_metric = evaluate.load("cer", cache_dir=os.path.join(base_dir, "metrics_cache"))


def is_target_text_in_range(ref):
    if ref.strip() == "ignore time segment in scoring":
        return False
    else:
        return ref.strip() != ""


def get_text(sample):
    if "text" in sample:
        return sample["text"]
    elif "sentence" in sample:
        return sample["sentence"]
    elif "normalized_text" in sample:
        return sample["normalized_text"]
    elif "transcript" in sample:
        return sample["transcript"]
    elif "transcription" in sample:
        return sample["transcription"]
    else:
        raise ValueError(
            "Expected transcript column of either 'text', 'sentence', 'normalized_text', "
            "'transcript' or 'transcription'. Got sample with columns "
            f"{', '.join(sample.keys())}. Ensure a text column name is present in the dataset."
        )


whisper_norm = BasicTextNormalizer()
bangla_normalizer = Normalizer(allow_english=True)


def normalise(batch):
    batch["norm_text"] = whisper_norm(get_text(batch))
    return batch


def removeOptionalZW(text):
    """
    Removes all optional occurrences of ZWNJ or ZWJ from Bangla text.
    """
    # Regex for matching zero width joiner variations.
    STANDARDIZE_ZW = re.compile(r'(?<=\u09b0)[\u200c\u200d]+(?=\u09cd\u09af)')
    # Regex for removing standardized zero width joiner, except in edge cases.
    DELETE_ZW = re.compile(r'(?