#!/usr/bin/env python
"""Converts a Whisper model in Hugging Face format to OpenAI format.
This script is based on the following script to do the opposite:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/convert_openai_to_hf.py
Requirements:
```bash
pip install -U openai-whisper
```
Example:
```bash
# Converts the model from Hugging Face to OpenAI format:
python convert_hf_to_openai.py \
    --checkpoint openai/whisper-tiny \
    --whisper_dump_path whisper-tiny-openai.pt
```
```python
>>> # Disabled doctest because it requires the openai-whisper package.
>> import whisper
>> from transformers.models.whisper.convert_hf_to_openai import convert_tfms_to_openai_whisper
>> # Converts the model from Hugging Face to OpenAI format:
>> convert_tfms_to_openai_whisper(
..   "openai/whisper-tiny", "whisper-tiny-openai.pt"
.. )
HF model path: openai/whisper-tiny
OpenAI model path: whisper-tiny-openai.pt
>> # Select an audio file:
>> audio_path = "https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav"
>> # Load the Whisper model in OpenAI format:
>> model = whisper.load_model("whisper-tiny-openai.pt")
>> # Transcribe the audio:
>> prediction = model.transcribe(audio_path)
>> prediction["text"][:70]
' chapter 16. I might have told you of the beginning of this liaison in'
```
"""
# Copyright 2023 Xabier de Zuazo and the Aholab team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
from torch import nn

from transformers import WhisperConfig, WhisperForConditionalGeneration


# Reverse mapping, adapted from the original `WHISPER_MAPPING` in the
# `convert_openai_to_hf.py` script:
REVERSE_WHISPER_MAPPING = {
    "layers": "blocks",
    "fc1": "mlp.0",
    "fc2": "mlp.2",
    "final_layer_norm": "mlp_ln",
    ".self_attn.q_proj": ".attn.query",
    ".self_attn.k_proj": ".attn.key",
    ".self_attn.v_proj": ".attn.value",
    ".self_attn_layer_norm": ".attn_ln",
    ".self_attn.out_proj": ".attn.out",
    ".encoder_attn.q_proj": ".cross_attn.query",
    ".encoder_attn.k_proj": ".cross_attn.key",
    ".encoder_attn.v_proj": ".cross_attn.value",
    ".encoder_attn_layer_norm": ".cross_attn_ln",
    ".encoder_attn.out_proj": ".cross_attn.out",
    "decoder.layer_norm.": "decoder.ln.",
    "encoder.layer_norm.": "encoder.ln_post.",
    "embed_tokens": "token_embedding",
    "encoder.embed_positions.weight": "encoder.positional_embedding",
    "decoder.embed_positions.weight": "decoder.positional_embedding",
}
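# For example, "model.decoder.layers.0.fc1.weight" is renamed to
# "model.decoder.blocks.0.mlp.0.weight" (the leading "model." prefix is then
# stripped separately in `convert_tfms_to_openai_whisper` below).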


def reverse_rename_keys(s_dict: dict) -> dict:
    """Renames the keys back from Hugging Face to OpenAI Whisper format.
    By using this function on an HF model's state_dict, we should get the names in the format expected by Whisper.
    Args:
        s_dict (`dict`): A dictionary with keys in Hugging Face format.
    Returns:
        `dict`: The same dictionary but in OpenAI Whisper format.
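
    Example (illustrative; the key follows the mapping above):

        ```python
        >>> import torch
        >>> sd = {"model.decoder.layers.0.fc1.weight": torch.zeros(1)}
        >>> list(reverse_rename_keys(sd).keys())
        ['model.decoder.blocks.0.mlp.0.weight']
        ```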
    """
    keys = list(s_dict.keys())
    for orig_key in keys:
        new_key = orig_key
        for key_r, value_r in REVERSE_WHISPER_MAPPING.items():
            if key_r in orig_key:
                new_key = new_key.replace(key_r, value_r)

        # print(f"{orig_key} -> {new_key}")

        s_dict[new_key] = s_dict.pop(orig_key)
    return s_dict


def make_emb_from_linear(linear: nn.Linear) -> nn.Embedding:
    """Converts a linear layer's weights into an embedding layer.
    The linear layer's `in_features` dimension corresponds to the vocabulary size and its `out_features` dimension
    corresponds to the embedding size.
    Args:
        linear (`nn.Linear`): The linear layer to be converted.
    Returns:
        `nn.Embedding`:
            An embedding layer with weights set to those of the input linear layer.
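
    Note: this helper is not called in the conversion below; Whisper ties the output projection to the decoder
    token embedding, so `proj_out.weight` is dropped rather than converted.

    Example (illustrative; the dimensions mimic the `proj_out` layer of `openai/whisper-tiny`):

        ```python
        >>> lin = nn.Linear(in_features=384, out_features=51865, bias=False)
        >>> emb = make_emb_from_linear(lin)
        >>> (emb.num_embeddings, emb.embedding_dim)
        (51865, 384)
        ```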
    """
    vocab_size, emb_size = linear.weight.data.shape
    emb_layer = nn.Embedding(vocab_size, emb_size, _weight=linear.weight.data)
    return emb_layer


def extract_dims_from_hf(config: WhisperConfig) -> dict:
    """Extracts necessary dimensions from Hugging Face's WhisperConfig.
    Extracts necessary dimensions and related configuration data from the Hugging Face model and then restructure it
    for the OpenAI Whisper format.
    Args:
        config (`WhisperConfig`): Configuration of the Hugging Face's model.
    Returns:
        `dict`: The `dims` of the OpenAI Whisper model.
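
    Example:

        ```python
        >>> # Disabled doctest because it fetches the configuration from the Hugging Face Hub.
        >> config = WhisperConfig.from_pretrained("openai/whisper-tiny")
        >> extract_dims_from_hf(config)["n_audio_state"]  # `d_model` of whisper-tiny
        384
        ```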
    """
    dims = {
        "n_vocab": config.vocab_size,
        "n_mels": config.num_mel_bins,
        "n_audio_state": config.d_model,
        "n_text_ctx": config.max_target_positions,
        "n_audio_layer": config.encoder_layers,
        "n_audio_head": config.encoder_attention_heads,
        "n_text_layer": config.decoder_layers,
        "n_text_head": config.decoder_attention_heads,
        "n_text_state": config.d_model,
        "n_audio_ctx": config.max_source_positions,
    }
    return dims


def convert_tfms_to_openai_whisper(hf_model_path: str, whisper_dump_path: str):
    """Converts a Whisper model from the Hugging Face to the OpenAI format.
    Takes in the path to a Hugging Face Whisper model, extracts its state_dict, renames keys as needed, and then saves
    the model OpenAI's format.
    Args:
        hf_model_path (`str`):
            Path to the pretrained Whisper model in Hugging Face format.
        whisper_dump_path (`str`):
            Destination path where the converted model in Whisper/OpenAI format will be saved.
    Returns:
        `None`
    """
    print("HF model path:", hf_model_path)
    print("OpenAI model path:", whisper_dump_path)

    # Load the HF model and its state_dict
    model = WhisperForConditionalGeneration.from_pretrained(hf_model_path)
    state_dict = model.state_dict()

    # Use a reverse mapping to rename state_dict keys
    state_dict = reverse_rename_keys(state_dict)

    # Extract configurations and other necessary metadata
    dims = extract_dims_from_hf(model.config)

    # Remove the proj_out weights from state dictionary
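    # (they duplicate the decoder token embedding, which OpenAI Whisper reuses
    # as the output projection)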
    del state_dict["proj_out.weight"]

    # Construct the Whisper checkpoint structure
    state_dict = {k.replace("model.", "", 1): v for k, v in state_dict.items()}
    whisper_checkpoint = {"dims": dims, "model_state_dict": state_dict}

    # Save in Whisper's format
    torch.save(whisper_checkpoint, whisper_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path or name of the Hugging Face checkpoint.",
    )
    parser.add_argument(
        "--whisper_dump_path",
        type=str,
        required=True,
        help="Path to the output Whisper model.",
    )
    args = parser.parse_args()

    convert_tfms_to_openai_whisper(args.checkpoint, args.whisper_dump_path)