from transformers import AutoFeatureExtractor, AutoTokenizer

pretrained_name = 'openai/whisper-base'
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_name)
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)


def prepare_dataset(
    sample: dict,
    labels_max_len: int = None,
) -> dict:
    """Add Whisper input features and tokenized labels to one dataset example.

    Mutates *sample* in place (suitable for ``datasets.Dataset.map``) and
    returns it with these extra keys:

    - ``input_features``: log-mel features from the feature extractor.
    - ``input_length``: number of raw audio samples.
    - ``labels``: token ids of ``sample['sentence']`` (incl. special tokens).
    - ``labels_length``: length of ``labels`` before any truncation.
    - ``labels_truncated``: 1 if labels were cut to ``labels_max_len``, else 0.

    Args:
        sample: one example with an ``audio`` dict (``array``,
            ``sampling_rate``) and a ``sentence`` transcript string.
        labels_max_len: optional cap on label length (e.g.
            ``model.config.max_length``). Validation/test examples must not be
            dropped — that would skew scores — so over-long labels are
            truncated instead and flagged via ``labels_truncated``.

    Returns:
        The same ``sample`` dict, augmented as described above.
    """
    # BUG FIX: the original rebound `sample = sample['audio']`, clobbering the
    # example dict; `sample['sentence']` then raised KeyError (the audio dict
    # holds only path/array/sampling_rate) and the audio sub-dict — not the
    # augmented example — was returned. Keep the audio sub-dict in its own name.
    audio = sample['audio']

    inputs = feature_extractor(audio['array'], sampling_rate=audio['sampling_rate'])
    sample['input_features'] = inputs.get('input_features')[0]
    sample['input_length'] = len(audio['array'])

    input_str = sample['sentence']
    sample['labels'] = tokenizer(input_str).input_ids
    sample['labels_length'] = len(sample['labels'])  # includes special tokens
    sample['labels_truncated'] = 0

    # Need to truncate validation/test labels longer than model.config.max_length.
    # Can't drop such examples because that would affect validation/test scores,
    # so truncate and record the fact instead.
    if labels_max_len is not None and len(sample['labels']) > labels_max_len:
        sample['labels'] = sample['labels'][:labels_max_len]
        sample['labels_truncated'] = 1

    return sample