patrickvonplaten committed on
Commit
3c5a93a
1 Parent(s): f974bb6
Files changed (2)
  1. convert.py +194 -0
  2. whisper-32-2.pt +3 -0
convert.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/env python
+ """Converts a Whisper model in Hugging Face format to OpenAI format.
+ This script is based on the following script, which does the opposite:
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/convert_openai_to_hf.py
+ Requirements:
+ ```bash
+ pip install -U openai-whisper
+ ```
+ Example:
+ ```bash
+ # Converts the model from Hugging Face to OpenAI format:
+ python convert_hf_to_openai.py \
+     --checkpoint openai/whisper-tiny \
+     --whisper_dump_path whisper-tiny-openai.pt
+ ```
+ ```python
+ >>> # Doctest disabled because it requires the openai-whisper package.
+ >> import whisper
+ >> from transformers.models.whisper.convert_hf_to_openai import convert_tfms_to_openai_whisper
+ >> # Converts the model from Hugging Face to OpenAI format:
+ >> convert_tfms_to_openai_whisper(
+ ..     "openai/whisper-tiny", "whisper-tiny-openai.pt"
+ .. )
+ HF model path: openai/whisper-tiny
+ OpenAI model path: whisper-tiny-openai.pt
+ >> # Select an audio file:
+ >> audio_path = "https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav"
+ >> # Load the Whisper model in OpenAI format:
+ >> model = whisper.load_model("whisper-tiny-openai.pt")
+ >> # Transcribe the audio:
+ >> prediction = model.transcribe(audio_path)
+ >> prediction["text"][:70]
+ ' chapter 16. I might have told you of the beginning of this liaison in'
+ ```
+ """
+ # Copyright 2023 Xabier de Zuazo and the Aholab team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+
+ import torch
+ from torch import nn
+
+ from transformers import WhisperConfig, WhisperForConditionalGeneration
+
+
+ # The reverse mapping, adapted from the original `WHISPER_MAPPING` in the
+ # `convert_openai_to_hf.py` script:
+ REVERSE_WHISPER_MAPPING = {
+     "layers": "blocks",
+     "fc1": "mlp.0",
+     "fc2": "mlp.2",
+     "final_layer_norm": "mlp_ln",
+     ".self_attn.q_proj": ".attn.query",
+     ".self_attn.k_proj": ".attn.key",
+     ".self_attn.v_proj": ".attn.value",
+     ".self_attn_layer_norm": ".attn_ln",
+     ".self_attn.out_proj": ".attn.out",
+     ".encoder_attn.q_proj": ".cross_attn.query",
+     ".encoder_attn.k_proj": ".cross_attn.key",
+     ".encoder_attn.v_proj": ".cross_attn.value",
+     ".encoder_attn_layer_norm": ".cross_attn_ln",
+     ".encoder_attn.out_proj": ".cross_attn.out",
+     "decoder.layer_norm.": "decoder.ln.",
+     "encoder.layer_norm.": "encoder.ln_post.",
+     "embed_tokens": "token_embedding",
+     "encoder.embed_positions.weight": "encoder.positional_embedding",
+     "decoder.embed_positions.weight": "decoder.positional_embedding",
+ }
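+ # For example, with this mapping the HF parameter name
+ # "model.decoder.layers.0.self_attn.q_proj.weight" becomes
+ # "decoder.blocks.0.attn.query.weight" after `reverse_rename_keys` below
+ # applies the substitutions and `convert_tfms_to_openai_whisper` strips the
+ # leading "model." prefix.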
+
+
+ def reverse_rename_keys(s_dict: dict) -> dict:
+     """Renames the keys back from Hugging Face to OpenAI Whisper format.
+     Applying this function to an HF model's state_dict yields the parameter
+     names in the format expected by Whisper.
+     Args:
+         s_dict (`dict`): A dictionary with keys in Hugging Face format.
+     Returns:
+         `dict`: The same dictionary but in OpenAI Whisper format.
+     """
+     keys = list(s_dict.keys())
+     for orig_key in keys:
+         new_key = orig_key
+         for key_r, value_r in REVERSE_WHISPER_MAPPING.items():
+             if key_r in orig_key:
+                 new_key = new_key.replace(key_r, value_r)
+
+         # print(f"{orig_key} -> {new_key}")
+
+         s_dict[new_key] = s_dict.pop(orig_key)
+     return s_dict
+
+
+ def make_emb_from_linear(linear: nn.Linear) -> nn.Embedding:
+     """Converts a linear layer's weights into an embedding layer.
+     The linear layer's `out_features` dimension corresponds to the vocabulary
+     size and its `in_features` dimension corresponds to the embedding size.
+     Args:
+         linear (`nn.Linear`): The linear layer to be converted.
+     Returns:
+         `nn.Embedding`:
+             An embedding layer with weights set to those of the input linear layer.
+     """
+     vocab_size, emb_size = linear.weight.data.shape
+     emb_layer = nn.Embedding(vocab_size, emb_size, _weight=linear.weight.data)
+     return emb_layer
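+ # Shape sketch for the function above (hypothetical sizes): for
+ # proj = nn.Linear(in_features=384, out_features=51865, bias=False), the
+ # weight has shape (51865, 384), so make_emb_from_linear(proj) returns an
+ # nn.Embedding(num_embeddings=51865, embedding_dim=384) backed by the same
+ # weight tensor.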
+
+
+ def extract_dims_from_hf(config: WhisperConfig) -> dict:
+     """Extracts the necessary dimensions from Hugging Face's WhisperConfig.
+     Extracts the necessary dimensions and related configuration data from the
+     Hugging Face model and restructures them for the OpenAI Whisper format.
+     Args:
+         config (`WhisperConfig`): Configuration of the Hugging Face model.
+     Returns:
+         `dict`: The `dims` of the OpenAI Whisper model.
+     """
+     dims = {
+         "n_vocab": config.vocab_size,
+         "n_mels": config.num_mel_bins,
+         "n_audio_state": config.d_model,
+         "n_text_ctx": config.max_target_positions,
+         "n_audio_layer": config.encoder_layers,
+         "n_audio_head": config.encoder_attention_heads,
+         "n_text_layer": config.decoder_layers,
+         "n_text_head": config.decoder_attention_heads,
+         "n_text_state": config.d_model,
+         "n_audio_ctx": config.max_source_positions,
+     }
+     return dims
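+ # For openai/whisper-tiny, for example, this yields (values as published for
+ # the multilingual tiny checkpoint): {"n_vocab": 51865, "n_mels": 80,
+ # "n_audio_state": 384, "n_text_ctx": 448, "n_audio_layer": 4,
+ # "n_audio_head": 6, "n_text_layer": 4, "n_text_head": 6,
+ # "n_text_state": 384, "n_audio_ctx": 1500}.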
+
+
+ def convert_tfms_to_openai_whisper(hf_model_path: str, whisper_dump_path: str):
+     """Converts a Whisper model from the Hugging Face to the OpenAI format.
+     Takes the path to a Hugging Face Whisper model, extracts its state_dict,
+     renames the keys as needed, and saves the model in OpenAI's format.
+     Args:
+         hf_model_path (`str`):
+             Path to the pretrained Whisper model in Hugging Face format.
+         whisper_dump_path (`str`):
+             Destination path where the converted model in Whisper/OpenAI format will be saved.
+     Returns:
+         `None`
+     """
+     print("HF model path:", hf_model_path)
+     print("OpenAI model path:", whisper_dump_path)
+
+     # Load the HF model and its state_dict
+     model = WhisperForConditionalGeneration.from_pretrained(hf_model_path)
+     state_dict = model.state_dict()
+
+     # Use the reverse mapping to rename the state_dict keys
+     state_dict = reverse_rename_keys(state_dict)
+
+     # Extract the configuration and other necessary metadata
+     dims = extract_dims_from_hf(model.config)
+
+     # Remove the proj_out weights from the state dictionary: in the OpenAI
+     # format the output projection is tied to the token embedding, so it is
+     # not stored as a separate tensor.
+     del state_dict["proj_out.weight"]
+
+     # Construct the Whisper checkpoint structure
+     state_dict = {k.replace("model.", "", 1): v for k, v in state_dict.items()}
+     whisper_checkpoint = {"dims": dims, "model_state_dict": state_dict}
+
+     # Save in Whisper's format
+     torch.save(whisper_checkpoint, whisper_dump_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Required parameters
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         help="Path or name of the Hugging Face checkpoint.",
+     )
+     parser.add_argument(
+         "--whisper_dump_path",
+         type=str,
+         help="Path to the output Whisper model.",
+     )
+     args = parser.parse_args()
+
+     convert_tfms_to_openai_whisper(args.checkpoint, args.whisper_dump_path)
whisper-32-2.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2eadb7893a248cab031b175a7cdf09a3e19f525fa3fb9ae1f33a9e036bd985c
+ size 3025049815
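The second file, `whisper-32-2.pt`, is tracked with Git LFS, so only the pointer is shown above. A minimal sketch for sanity-checking a checkpoint written by `convert.py`, assuming the `.pt` file has been downloaded locally and was produced by the script in this commit:

```python
import torch

# Load the converted checkpoint on CPU and inspect its structure.
ckpt = torch.load("whisper-32-2.pt", map_location="cpu")

# convert.py stores exactly two top-level entries.
print(sorted(ckpt))  # ['dims', 'model_state_dict']
print(ckpt["dims"])  # n_vocab, n_mels, n_audio_state, ...

# Parameter names should use OpenAI Whisper naming, e.g.
# "decoder.token_embedding.weight" rather than "model.decoder.embed_tokens.weight".
print(sorted(ckpt["model_state_dict"])[:5])
```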