# Convert Whisper into a smaller model using layer pruning.
import copy

import torch
from transformers import (
    GenerationConfig,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)

TEACHER_CKPT = "openai/whisper-large-v2"
DECODER_LAYERS = 8
SAVE_DIR = "."
CACHE_DIR = "."

teacher_model = WhisperForConditionalGeneration.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
teacher_config = teacher_model.config
teacher_layers = teacher_config.decoder_layers

# Deep-copy the config so shrinking the student does not silently mutate the
# teacher's config object as well.
student_config = copy.deepcopy(teacher_config)
student_config.decoder_layers = DECODER_LAYERS

# Map 8 of the teacher's decoder layers (large-v2 has 32) onto the student.
mapping = [0, 1, 4, 8, 16, 24, 30, 31]
assert DECODER_LAYERS == len(mapping)

student_model = WhisperForConditionalGeneration(student_config)

# Copy every matching weight (encoder, embeddings, layer norms, ...); teacher
# decoder layers beyond the student's depth show up in `info` as unexpected
# keys and are ignored because strict=False.
info = student_model.load_state_dict(teacher_model.state_dict(), strict=False)

# Make sure the entire encoder was copied.
for s, t in zip(student_model.model.encoder.parameters(), teacher_model.model.encoder.parameters()):
    assert torch.equal(s.data, t.data)

# Copy the selected decoder layers; this load has to be a strict match.
layers_to_copy = torch.nn.ModuleList([teacher_model.model.decoder.layers[i] for i in mapping])
student_model.model.decoder.layers.load_state_dict(layers_to_copy.state_dict())

# Save the pruned model.
student_model.save_pretrained(SAVE_DIR)

# Also save the processor, generation config, and tokenizer.
processor = WhisperProcessor.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
processor.save_pretrained(SAVE_DIR)

generation_config = GenerationConfig.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
generation_config.save_pretrained(SAVE_DIR)

tokenizer = WhisperTokenizer.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
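
# A minimal sanity check (a sketch, assuming the script above has run and
# SAVE_DIR now holds the pruned checkpoint): reload the student from disk and
# run greedy decoding end to end. The one second of silence is a hypothetical
# stand-in for real audio; the pruned model will still need distillation or
# fine-tuning before its transcripts are useful.
import numpy as np

student = WhisperForConditionalGeneration.from_pretrained(SAVE_DIR)
reloaded_processor = WhisperProcessor.from_pretrained(SAVE_DIR)

# Confirm the saved config reflects the shrunken decoder.
assert student.config.decoder_layers == DECODER_LAYERS

# 16 kHz mono waveform; zeros stand in for a real recording here.
dummy_audio = np.zeros(16_000, dtype=np.float32)
inputs = reloaded_processor(dummy_audio, sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    predicted_ids = student.generate(inputs.input_features, max_new_tokens=32)
print(reloaded_processor.batch_decode(predicted_ids, skip_special_tokens=True))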