# Convert Whisper into a smaller model by pruning decoder layers

import copy

import torch
from transformers import WhisperProcessor, GenerationConfig, WhisperForConditionalGeneration, WhisperTokenizer

TEACHER_CKPT = "openai/whisper-large-v2"
DECODER_LAYERS = 8
SAVE_DIR = "."
CACHE_DIR = "."

teacher_model = WhisperForConditionalGeneration.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)

teacher_config = teacher_model.config
teacher_layers = teacher_config.decoder_layers
# deep-copy so that shrinking the student config does not also mutate the teacher's config
student_config = copy.deepcopy(teacher_config)
student_config.decoder_layers = DECODER_LAYERS

mapping = [0, 1, 4, 8, 16, 24, 30, 31]  # map 8 of the teacher's 32 decoder layers onto the student's 8
assert len(mapping) == DECODER_LAYERS
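# Sanity check (added here as a sketch, not in the original script): every
# mapped index must refer to an existing teacher decoder layer.
assert all(0 <= i < teacher_layers for i in mapping)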

student_model = WhisperForConditionalGeneration(student_config)

# partially load teacher weights into the student; strict=False tolerates the
# teacher's extra decoder layers, which the student does not have
info = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
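# Optional sanity check (an added sketch): since the student's keys are a
# subset of the teacher's, nothing should be missing; the only unmatched keys
# should be the teacher's pruned decoder layers, reported as unexpected.
assert len(info.missing_keys) == 0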

# make sure the entire encoder was copied verbatim
for s, t in zip(student_model.model.encoder.parameters(), teacher_model.model.encoder.parameters()):
    assert torch.equal(s.data, t.data)

# copy the mapped decoder layers; this load must be a strict match
# (load_state_dict should report <All keys matched successfully>)
layers_to_copy = torch.nn.ModuleList([teacher_model.model.decoder.layers[i] for i in mapping])
student_model.model.decoder.layers.load_state_dict(layers_to_copy.state_dict())
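# Verification sketch (added, not in the original script): each student decoder
# layer should now be an exact copy of its mapped teacher layer.
for s_idx, t_idx in enumerate(mapping):
    for s, t in zip(student_model.model.decoder.layers[s_idx].parameters(),
                    teacher_model.model.decoder.layers[t_idx].parameters()):
        assert torch.equal(s.data, t.data)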

# save model
student_model.save_pretrained(SAVE_DIR)

# also save processor, generation config and tokenizer
processor = WhisperProcessor.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
processor.save_pretrained(SAVE_DIR)
generation_config = GenerationConfig.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
generation_config.save_pretrained(SAVE_DIR)
tokenizer = WhisperTokenizer.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
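
# Quick usage check (illustrative addition; the printed numbers depend on
# DECODER_LAYERS and are not from the original script): reload the pruned
# student from SAVE_DIR and compare parameter counts against the teacher.
reloaded = WhisperForConditionalGeneration.from_pretrained(SAVE_DIR)
n_teacher = sum(p.numel() for p in teacher_model.parameters())
n_student = sum(p.numel() for p in reloaded.parameters())
print(f"teacher: {n_teacher / 1e6:.0f}M params -> student: {n_student / 1e6:.0f}M params")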