distil-whisper-large-v2-8-ls / make-student-whisper.py
rsonavane's picture
Add config
5734b5e
raw history blame
No virus
1.72 kB
# Convert Whisper into a smaller student model by pruning its decoder layers
import copy

import torch
from transformers import WhisperProcessor, GenerationConfig, WhisperForConditionalGeneration, WhisperTokenizer
# Teacher checkpoint passed to every `from_pretrained` call below.
# NOTE(review): "large-v2" is resolved as a local directory or Hub repo id; the
# public OpenAI checkpoint is "openai/whisper-large-v2" — confirm which is intended.
TEACHER_CKPT = "large-v2"
# Number of decoder layers kept in the student model.
DECODER_LAYERS = 8
# Directory the student model, processor, generation config and tokenizer are written to.
SAVE_DIR = "."
# Download/cache directory used when loading the teacher artifacts.
CACHE_DIR = "."

teacher_model = WhisperForConditionalGeneration.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
teacher_config = teacher_model.config
teacher_layers = teacher_config.decoder_layers

# Deep-copy the config: a plain assignment would alias the teacher's config
# object, so shrinking `decoder_layers` would silently rewrite the teacher too.
student_config = copy.deepcopy(teacher_config)
student_config.decoder_layers = DECODER_LAYERS

# Student decoder layer i is initialised from teacher decoder layer mapping[i].
# Keeps layers 0, 1 and 30, 31 and spreads the remaining picks across the depth.
mapping = [0, 1, 4, 8, 16, 24, 30, 31]
# Validate with explicit raises (asserts are stripped under `python -O`).
if len(mapping) != DECODER_LAYERS:
    raise ValueError(
        f"mapping has {len(mapping)} entries but DECODER_LAYERS is {DECODER_LAYERS}"
    )
# ModuleList accepts negative indices, so an out-of-range entry would otherwise
# silently pick a layer counted from the end instead of failing.
if not all(0 <= i < teacher_layers for i in mapping):
    raise ValueError(
        f"mapping {mapping} references layers outside the teacher's "
        f"{teacher_layers} decoder layers"
    )

student_model = WhisperForConditionalGeneration(student_config)

# Copy every teacher weight whose name also exists in the student: the whole
# encoder, the shared decoder modules, and provisionally decoder layers
# 0..DECODER_LAYERS-1 (overwritten below). strict=False because the teacher's
# extra decoder layers have no slot in the student (they show up as unexpected keys).
info = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
# Student and teacher share one architecture apart from decoder depth, so every
# student weight must have been found in the teacher.
if info.missing_keys:
    raise RuntimeError(
        f"student weights not initialised from the teacher: {info.missing_keys}"
    )

# Sanity check: the encoder is unchanged, so it must be an exact copy.
for s, t in zip(student_model.model.encoder.parameters(), teacher_model.model.encoder.parameters()):
    assert torch.equal(s.data, t.data)

# Copy the selected teacher decoder layers into the student's decoder stack.
# load_state_dict defaults to strict=True, so any key mismatch raises
# (success is "<All keys matched successfully>").
layers_to_copy = torch.nn.ModuleList([teacher_model.model.decoder.layers[i] for i in mapping])
student_model.model.decoder.layers.load_state_dict(layers_to_copy.state_dict())

# Save the student model.
student_model.save_pretrained(SAVE_DIR)

# Also save the teacher's processor, generation config and tokenizer next to
# the student so SAVE_DIR is a self-contained checkpoint.
processor = WhisperProcessor.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
processor.save_pretrained(SAVE_DIR)
generation_config = GenerationConfig.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
generation_config.save_pretrained(SAVE_DIR)
tokenizer = WhisperTokenizer.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
tokenizer.save_pretrained(SAVE_DIR)