# whisper-base-belarusian/src/preprocess.py
from typing import Optional

from transformers import AutoFeatureExtractor, AutoTokenizer

pretrained_name = 'openai/whisper-base'

feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_name)
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
def prepare_dataset(
    sample: dict,
    labels_max_len: Optional[int] = None,
) -> dict:
    """Compute Whisper input features and tokenized labels for one sample."""
    # Keep the audio in a separate variable: reassigning `sample` here would
    # lose access to the 'sentence' field below.
    audio = sample['audio']

    # Log-Mel input features expected by the Whisper encoder.
    inputs = feature_extractor(audio['array'], sampling_rate=audio['sampling_rate'])
    sample['input_features'] = inputs.input_features[0]
    sample['input_length'] = len(audio['array'])

    # Tokenize the transcription. labels_length includes special tokens.
    sample['labels'] = tokenizer(sample['sentence']).input_ids
    sample['labels_length'] = len(sample['labels'])
    sample['labels_truncated'] = 0

    # Validation and test labels longer than model.config.max_length must be
    # truncated rather than dropped: dropping such examples would skew the
    # validation and test scores.
    if labels_max_len is not None and len(sample['labels']) > labels_max_len:
        sample['labels'] = sample['labels'][:labels_max_len]
        sample['labels_truncated'] = 1

    return sample
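

# A minimal usage sketch, not part of the original training code: it applies
# prepare_dataset via datasets.map(), assuming the Common Voice 11 Belarusian
# split and a WhisperForConditionalGeneration checkpoint. The dataset name,
# split, and column handling here are illustrative assumptions.
if __name__ == '__main__':
    from datasets import Audio, load_dataset
    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained(pretrained_name)

    # Whisper's feature extractor expects 16 kHz audio.
    ds = load_dataset('mozilla-foundation/common_voice_11_0', 'be', split='validation')
    ds = ds.cast_column('audio', Audio(sampling_rate=16_000))

    # Truncate labels to the model's decoding limit so validation metrics
    # cover every example instead of a filtered subset.
    ds = ds.map(
        prepare_dataset,
        fn_kwargs={'labels_max_len': model.config.max_length},
        remove_columns=ds.column_names,
    )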