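# content-vec-best / convert.py
# Last commit: 16b7417 "First model version" (by lengyue233)
#
# Converts the fairseq ContentVec checkpoint into a Hugging Face HubertModel
# (plus its final_proj head) and saves it to the current directory.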
import torch
from torch import nn
from transformers import HubertConfig, HubertModel
import logging
# Quiet fairseq: raise these loggers to WARNING so their INFO/DEBUG output is
# hidden. This is done before importing fairseq so the level is already in
# place for anything logged while fairseq is being imported.
logging.getLogger("fairseq").setLevel(logging.WARNING)
logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)
from fairseq import checkpoint_utils  # noqa: E402 -- must follow the logger setup

# Load the original fairseq ContentVec checkpoint (relative path, so it is
# expected in the current working directory). The loader returns a 3-tuple;
# only the first element, the list of loaded models, is used here.
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
    ["content-vec-best-legacy-500.pt"], suffix=""
)
model = models[0]
# Inference mode (disables dropout) so the output comparison is deterministic.
model.eval()
class HubertModelWithFinalProj(HubertModel):
    """``HubertModel`` with an extra ``final_proj`` linear head.

    The fairseq checkpoint carries a ``final_proj`` layer that the stock
    Hugging Face ``HubertModel`` does not have; this subclass adds one
    (``hidden_size -> classifier_proj_size``) so those weights can be loaded
    and saved with the model.

    ``forward`` is not overridden, so ``final_proj`` is never applied
    automatically; callers apply it explicitly to a hidden state.
    """

    def __init__(self, config):
        super().__init__(config)
        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
# Build the Hugging Face model from the library's default HubertConfig; its
# randomly initialised weights are overwritten with the fairseq ones later on.
default_config = HubertConfig()
hubert = HubertModelWithFinalProj(default_config)
# Parameter-name mapping, Hugging Face HubertModel name -> fairseq name
# (huggingface: fairseq). Used to copy every fairseq tensor into `hubert`.
NUM_ENCODER_LAYERS = 12  # transformer layers converted (HubertConfig default)
NUM_CONV_LAYERS = 7  # feature-extractor conv layers converted

mapping = {
    "masked_spec_embed": "mask_emb",
    "encoder.layer_norm.bias": "encoder.layer_norm.bias",
    "encoder.layer_norm.weight": "encoder.layer_norm.weight",
    "encoder.pos_conv_embed.conv.bias": "encoder.pos_conv.0.bias",
    "encoder.pos_conv_embed.conv.weight_g": "encoder.pos_conv.0.weight_g",
    "encoder.pos_conv_embed.conv.weight_v": "encoder.pos_conv.0.weight_v",
    "feature_projection.layer_norm.bias": "layer_norm.bias",
    "feature_projection.layer_norm.weight": "layer_norm.weight",
    "feature_projection.projection.bias": "post_extract_proj.bias",
    "feature_projection.projection.weight": "post_extract_proj.weight",
    "final_proj.bias": "final_proj.bias",
    "final_proj.weight": "final_proj.weight",
}

# Per-transformer-layer parameters: HF suffix -> fairseq suffix, both relative
# to "encoder.layers.{layer}.". Note HF's per-layer "layer_norm" is fairseq's
# "self_attn_layer_norm", and the feed-forward denses are fairseq's fc1/fc2.
_ENCODER_LAYER_SUFFIXES = {
    **{
        f"attention.{proj}_proj.{param}": f"self_attn.{proj}_proj.{param}"
        for proj in ("q", "k", "v")
        for param in ("weight", "bias")
    },
    "final_layer_norm.bias": "final_layer_norm.bias",
    "final_layer_norm.weight": "final_layer_norm.weight",
    "layer_norm.bias": "self_attn_layer_norm.bias",
    "layer_norm.weight": "self_attn_layer_norm.weight",
    "attention.out_proj.bias": "self_attn.out_proj.bias",
    "attention.out_proj.weight": "self_attn.out_proj.weight",
    "feed_forward.intermediate_dense.bias": "fc1.bias",
    "feed_forward.intermediate_dense.weight": "fc1.weight",
    "feed_forward.output_dense.bias": "fc2.bias",
    "feed_forward.output_dense.weight": "fc2.weight",
}

# Convert encoder
for layer in range(NUM_ENCODER_LAYERS):
    prefix = f"encoder.layers.{layer}"
    for hf_suffix, fs_suffix in _ENCODER_LAYER_SUFFIXES.items():
        mapping[f"{prefix}.{hf_suffix}"] = f"{prefix}.{fs_suffix}"

# Convert conv layers. Only the conv weight is mapped for every layer; only
# layer 0 also has a norm layer, which fairseq stores at sub-module index 2
# (the conv itself is at index 0).
for layer in range(NUM_CONV_LAYERS):
    prefix = f"feature_extractor.conv_layers.{layer}"
    mapping[f"{prefix}.conv.weight"] = f"{prefix}.0.weight"
    if layer == 0:
        mapping[f"{prefix}.layer_norm.weight"] = f"{prefix}.2.weight"
        mapping[f"{prefix}.layer_norm.bias"] = f"{prefix}.2.bias"
# Diagnostics: collect the parameter names that the mapping leaves uncovered
# on each side, then print them (sorted and paired up, as raw sets, and as
# counts) so missing mapping entries are easy to spot.
hf_keys = set(hubert.state_dict())
hf_keys.difference_update(mapping)
fair_keys = set(model.state_dict())
fair_keys.difference_update(mapping.values())

unmatched_pairs = zip(sorted(hf_keys), sorted(fair_keys))
for hf_name, fs_name in unmatched_pairs:
    print(hf_name, fs_name)
print(hf_keys, fair_keys)
print(len(hf_keys), len(fair_keys))
# Try loading the weights: build the HF state dict by pulling each mapped
# tensor out of the fairseq model. state_dict() assembles a fresh dict of all
# parameters on every call, so take it once rather than once per key.
fairseq_state = model.state_dict()
new_state_dict = {
    hf_name: fairseq_state[fs_name] for hf_name, fs_name in mapping.items()
}
# strict=False: keys the mapping does not cover are reported in the returned
# `x` (missing_keys / unexpected_keys) instead of raising.
x = hubert.load_state_dict(new_state_dict, strict=False)
print(x)
# Sanity check: run the same random waveform through both models and compare
# the final_proj output of layer 9 (HF hidden_states[9] vs fairseq
# output_layer=9) before writing anything to disk.
hubert.eval()
num_samples = 16384
with torch.no_grad():
    new_input = torch.randn(1, num_samples)

    result1 = hubert(new_input, output_hidden_states=True)["hidden_states"][9]
    result1 = hubert.final_proj(result1)

    result2 = model.extract_features(
        source=new_input,
        padding_mask=torch.zeros(1, num_samples, dtype=torch.bool),
        output_layer=9,
    )[0]
    result2 = model.final_proj(result2)

    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let a mismatched model be saved.
    if not torch.allclose(result1, result2, atol=1e-3):
        max_diff = (result1 - result2).abs().max().item()
        raise RuntimeError(
            f"Sanity check failed: max abs difference {max_diff:.3e} exceeds atol=1e-3"
        )
print("Sanity check passed")

# Save huggingface model
hubert.save_pretrained(".")
print("Saved model")