import copy

import torch
from transformers import (
    GenerationConfig,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)

# Hugging Face checkpoint of the teacher and the target student decoder depth.
TEACHER_CKPT = "openai/whisper-large-v2"
DECODER_LAYERS = 8
SAVE_DIR = "."
CACHE_DIR = "."

# Load the full teacher model whose weights will seed the student.
teacher_model = WhisperForConditionalGeneration.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)

teacher_config = teacher_model.config
teacher_layers = teacher_config.decoder_layers

# Deep-copy the config: plain assignment would only alias the teacher's
# config, so shrinking decoder_layers would silently mutate the teacher too.
student_config = copy.deepcopy(teacher_config)
student_config.decoder_layers = DECODER_LAYERS

# Teacher decoder layers to keep; large-v2 has 32, so this spreads the
# 8 student layers across the full decoder depth.
mapping = [0, 1, 4, 8, 16, 24, 30, 31]
assert len(mapping) == DECODER_LAYERS
assert max(mapping) < teacher_layers

# Instantiate a randomly initialised student with the shrunken config.
student_model = WhisperForConditionalGeneration(student_config)

# Copy every teacher weight whose name and shape match; strict=False lets the
# teacher's surplus decoder layers (8-31) pass through as unexpected keys.
info = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
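
# Optional sanity check (a sketch): with strict=False, nothing in the student
# should be left uninitialised; the only surplus keys should be the teacher's
# extra decoder layers, reported in info.unexpected_keys.
assert not info.missing_keys, info.missing_keys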

# The encoder is unchanged, so the copy above must be bit-identical.
for s, t in zip(student_model.model.encoder.parameters(), teacher_model.model.encoder.parameters()):
    assert torch.equal(s.data, t.data)

# Overwrite the student's decoder layers with the selected teacher layers.
layers_to_copy = torch.nn.ModuleList([teacher_model.model.decoder.layers[i] for i in mapping])
student_model.model.decoder.layers.load_state_dict(layers_to_copy.state_dict())
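
# Mirror of the encoder check above (a sketch): each student decoder layer
# should now equal the teacher layer it was copied from.
for s_idx, t_idx in enumerate(mapping):
    for s, t in zip(student_model.model.decoder.layers[s_idx].parameters(),
                    teacher_model.model.decoder.layers[t_idx].parameters()):
        assert torch.equal(s.data, t.data)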

# Persist the distilled student weights and config.
student_model.save_pretrained(SAVE_DIR)

# Save the teacher's processor, generation config, and tokenizer alongside
# the student so the checkpoint directory is self-contained.
processor = WhisperProcessor.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
processor.save_pretrained(SAVE_DIR)
generation_config = GenerationConfig.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
generation_config.save_pretrained(SAVE_DIR)
tokenizer = WhisperTokenizer.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
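
# Quick smoke test (a sketch): reload the saved student and confirm the
# decoder was actually shrunk to DECODER_LAYERS.
reloaded = WhisperForConditionalGeneration.from_pretrained(SAVE_DIR)
assert reloaded.config.decoder_layers == DECODER_LAYERS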