# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint.""" | |
import argparse
import os
from functools import reduce

import fairseq
import torch
from datasets import load_dataset

from transformers import Wav2Vec2Processor, logging
from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig

# Copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
from transformers.models.data2vec.data2vec_audio import Data2VecAudioModel as Dummy  # noqa: F401
from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioForCTC, Data2VecAudioModel


logging.set_verbosity_info()
logger = logging.get_logger(__name__)
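
# Maps (fragments of) fairseq parameter names to their transformers counterparts;
# a "*" in the target path is replaced with the encoder layer index at load time.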
MAPPING = {
    "post_extract_proj": "feature_projection.projection",
    "models.0.layer_norm": "feature_projection.layer_norm",
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "encoder.layers.*.layer_norm",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    "encoder.layer_norm": "encoder.layer_norm",
    "w2v_model.layer_norm": "feature_projection.layer_norm",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
}
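
# Keys that live at the top level of the transformers model and must therefore not
# be prefixed with "data2vec_audio." when converting a fine-tuned (CTC) model.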
TOP_LEVEL_KEYS = [
    "lm_head",
]


def set_recursively(hf_pointer, key, value, full_name, weight_type):
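    """Walk the dotted `key` path on `hf_pointer` and copy `value` into the matching
    weight/bias tensor, after checking that the shapes agree."""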
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        hf_shape = hf_pointer.shape

    if hf_shape != value.shape:
        raise ValueError(
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    else:
        hf_pointer.data = value

    logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")


def recursively_load_weights(fairseq_model, hf_model, is_headless):
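    """Copy every tensor of the fairseq model's state dict into `hf_model`: the conv
    feature extractor and positional conv embedding go through dedicated loaders,
    everything else is renamed via MAPPING. Unmatched weights are collected and logged."""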
    unused_weights = []
    fairseq_dict = fairseq_model.state_dict()

    if not is_headless:
        feature_extractor = hf_model.data2vec_audio.feature_extractor
        pos_conv_embedding = hf_model.data2vec_audio.encoder.pos_conv_embed
    else:
        feature_extractor = hf_model.feature_extractor
        pos_conv_embedding = hf_model.encoder.pos_conv_embed

    for name, value in fairseq_dict.items():
        is_used = False
        if "conv_layers" in name:
            load_conv_layer(
                name,
                value,
                feature_extractor,
                unused_weights,
            )
            is_used = True
        elif "pos_conv" in name:
            load_pos_conv_layer(
                name,
                value,
                pos_conv_embedding,
                unused_weights,
            )
            is_used = True
        else:
            for key, mapped_key in MAPPING.items():
                if not is_headless:
                    mapped_key = "data2vec_audio." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
                if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
                    is_used = True
                    if "*" in mapped_key:
                        layer_index = name.split(key)[0].split(".")[-2]
                        mapped_key = mapped_key.replace("*", layer_index)
                    if "weight_g" in name:
                        weight_type = "weight_g"
                    elif "weight_v" in name:
                        weight_type = "weight_v"
                    elif "bias" in name:
                        weight_type = "bias"
                    elif "weight" in name:
                        # TODO: don't match quantizer.weight_proj
                        weight_type = "weight"
                    else:
                        weight_type = None
                    set_recursively(hf_model, mapped_key, value, name, weight_type)
                continue
        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")


def access_by_string(module, path):
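    """Resolve a dotted attribute path, e.g. `layers.0.conv.weight`, on `module`."""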
    names = path.split(".")
    return reduce(getattr, names, module)


def set_weights(full_name, module, fsq_value, hf_weight_path):
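    """Copy `fsq_value` into the parameter of `module` addressed by `hf_weight_path`,
    raising a ValueError if the shapes differ."""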
    hf_weight = access_by_string(module, hf_weight_path)
    hf_value = hf_weight.data

    if fsq_value.shape != hf_value.shape:
        raise ValueError(f"{full_name} has size {fsq_value.shape}, but {hf_value.shape} was found.")

    hf_weight.data = fsq_value
    logger.info(f"{full_name} was correctly initialized from {hf_weight_path}.")


def load_conv_layer(full_name, value, feature_extractor, unused_weights):
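    """Load one feature-extractor tensor. In fairseq's `conv_layers.<layer>.<type>`
    naming, type id 0 is the conv weight/bias and type id 2 is the layer norm;
    anything else is recorded as unused."""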
name = full_name.split("conv_layers.")[-1] | |
items = name.split(".") | |
layer_id = int(items[0]) | |
type_id = int(items[1]) | |
weight_type = name.split(".")[-1] | |
if type_id == 0: | |
layer_type = "conv" | |
elif type_id == 2: | |
layer_type = "layer_norm" | |
else: | |
unused_weights.append(full_name) | |
return | |
set_weights(full_name, feature_extractor, value, f"conv_layers.{layer_id}.{layer_type}.{weight_type}") | |
def load_pos_conv_layer(full_name, value, pos_conv_embeddings, unused_weights): | |
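    """Load one positional conv embedding tensor; only type id 0 (the conv itself)
    exists on the transformers side, everything else is recorded as unused."""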
name = full_name.split("pos_conv.")[-1] | |
items = name.split(".") | |
layer_id = int(items[0]) | |
type_id = int(items[1]) | |
weight_type = name.split(".")[-1] | |
if type_id != 0: | |
unused_weights.append(full_name) | |
return | |
else: | |
layer_type = "conv" | |
set_weights(full_name, pos_conv_embeddings, value, f"layers.{layer_id}.{layer_type}.{weight_type}") | |
def convert_wav2vec2_checkpoint( | |
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True | |
): | |
""" | |
Copy/paste/tweak model's weights to transformers design. | |
""" | |
    if config_path is not None:
        config = Data2VecAudioConfig.from_pretrained(config_path)
    else:
        config = Data2VecAudioConfig()

    if not is_finetuned:
        # Modify final_proj layer name
        hf_wav2vec = Data2VecAudioModel(config)
        data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)

        state_dict = torch.load(checkpoint_path)
        state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
        state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
        converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
        torch.save(state_dict, converted_ckpt)
    else:
        hf_wav2vec = Data2VecAudioForCTC(config)
        converted_ckpt = checkpoint_path

    def load_data2vec(path):
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([path])
        return model[0].eval()

    model = load_data2vec(converted_ckpt)
    recursively_load_weights(model, hf_wav2vec, not is_finetuned)
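
    # Sanity check: run the original fairseq model and the converted transformers
    # model on a few dummy LibriSpeech samples and compare their outputs.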
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")

    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
    input_audio = [x["array"] for x in ds[:4]["audio"]]

    inputs = processor(input_audio, return_tensors="pt", padding=True)

    input_values = inputs.input_values
    attention_mask = inputs.attention_mask
    # input_values = inputs.input_values[:, :-1]
    # attention_mask = inputs.attention_mask[:, :-1]

    hf_wav2vec.eval()
    model.eval()
    if is_finetuned:
        their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
            "encoder_out"
        ].transpose(0, 1)
        our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["logits"]

        pred_ids = torch.argmax(our_output, dim=-1)
        output_string = processor.batch_decode(pred_ids)

        print(f"Expected Output: {ds[:4]['text']}, Pred: {output_string}")
    else:
        their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
            "layer_results"
        ][-1][0].transpose(0, 1)
        our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["last_hidden_state"]

    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

    if is_finetuned:
        processor.save_pretrained(pytorch_dump_folder_path)
    else:
        processor.feature_extractor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
    parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument(
        "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
    )

    args = parser.parse_args()
    convert_wav2vec2_checkpoint(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
    )
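
# Example invocation (paths are hypothetical):
#
#   python convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py \
#       --checkpoint_path /path/to/data2vec_audio.pt \
#       --pytorch_dump_folder_path ./data2vec-audio-converted \
#       --config_path ./config.json
#
# Pass --not_finetuned to convert a pretrained (headless) checkpoint instead.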