from typing import Optional

from transformers import AutoFeatureExtractor, AutoTokenizer
# Shared Whisper preprocessing components, loaded once at module import time
# and used by prepare_dataset() below.
pretrained_name = 'openai/whisper-base'
# Converts raw audio arrays into the model's input features (log-Mel spectrograms).
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_name)
# Tokenizes transcript text into label token ids (adds Whisper special tokens).
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
def prepare_dataset(
    sample: dict,
    labels_max_len: Optional[int] = None,
) -> dict:
    """Prepare one dataset example for Whisper training or evaluation.

    Adds the following keys to ``sample`` (mutated in place and returned):
      - ``input_features``: log-Mel features from the module-level
        ``feature_extractor``.
      - ``input_length``: number of raw audio samples.
      - ``labels``: token ids of ``sample['sentence']`` from the module-level
        ``tokenizer`` (includes special tokens).
      - ``labels_length``: length of ``labels`` before any truncation.
      - ``labels_truncated``: 1 if labels were truncated to ``labels_max_len``,
        else 0.

    Args:
        sample: A dataset example with an ``audio`` sub-dict (``array``,
            ``sampling_rate``) and a ``sentence`` transcript field.
        labels_max_len: If given, truncate labels longer than this
            (e.g. ``model.config.max_length``).

    Returns:
        The same ``sample`` dict, augmented with the keys above.
    """
    # Bug fix: the original did `sample = sample['audio']`, which clobbered the
    # example dict with its nested audio dict — `sample['sentence']` below then
    # failed (transcripts live on the example, not inside 'audio'), and all
    # original fields were dropped from the return value.
    audio = sample['audio']
    inputs = feature_extractor(audio['array'], sampling_rate=audio['sampling_rate'])
    sample['input_features'] = inputs.get('input_features')[0]
    sample['input_length'] = len(audio['array'])

    input_str = sample['sentence']
    sample['labels'] = tokenizer(input_str).input_ids
    sample['labels_length'] = len(sample['labels'])  # includes special tokens
    sample['labels_truncated'] = 0

    # Validation/test labels longer than model.config.max_length must be
    # truncated rather than dropped: dropping examples would skew eval scores.
    if labels_max_len is not None and len(sample['labels']) > labels_max_len:
        sample['labels'] = sample['labels'][:labels_max_len]
        sample['labels_truncated'] = 1

    return sample