# enclap/port_weights.py
import argparse
from math import ceil
from pathlib import Path

import torch

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", "-c", type=str)
    args = parser.parse_args()

    # Map old parameter-name fragments to new ones; a value of None means that
    # parameters whose name contains the fragment are dropped entirely.
    weight_name_map = {
        "model.encodec_embeddings": None,
        "encodec_embeddings": "embed_encodec",
        "encodec_mlm_head": "mcm_heads",
    }

    ckpt_path = Path(args.ckpt_path)
    weight_file = ckpt_path / "pytorch_model.bin"
    state_dict = torch.load(weight_file, map_location="cpu")

    # Rename parameters according to weight_name_map, dropping those whose
    # target is None.
    new_state_dict = {}
    for key in state_dict:
        new_key = key
        for orig, repl in weight_name_map.items():
            if repl is None:
                if orig in new_key:
                    new_key = None
                    break
                continue
            new_key = new_key.replace(orig, repl)
        if new_key:
            new_state_dict[new_key] = state_dict[key]
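
    # For example, a (hypothetical) checkpoint key such as
    # "model.encoder.encodec_embeddings.0.weight" would be renamed to
    # "model.encoder.embed_encodec.0.weight", while any key containing
    # "model.encodec_embeddings" is dropped entirely.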

    # Pad the EnCodec embedding tables so their row count is a multiple of 64;
    # the added rows are filled with standard-normal values.
    for key in new_state_dict:
        if "model.encoder.embed_encodec" in key:
            dim = new_state_dict[key].shape[0]
            new_weight = torch.normal(
                0, 1, (ceil(dim / 64) * 64, new_state_dict[key].shape[1])
            )
            new_weight[:dim] = new_state_dict[key]
            new_state_dict[key] = new_weight
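
    # As a worked example (hypothetical size): an embedding table with 1025 rows
    # would be padded to ceil(1025 / 64) * 64 = 1088 rows, keeping the original
    # 1025 rows and leaving the remaining 63 rows randomly initialized.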

    # Back up the original checkpoint, then write the ported state dict in its place.
    weight_file.rename(weight_file.with_suffix(".bin.bak"))
    torch.save(new_state_dict, weight_file)
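
# Example invocation (a sketch; the checkpoint directory path is a placeholder and
# is assumed to already contain pytorch_model.bin):
#
#   python port_weights.py --ckpt_path /path/to/checkpoint_dir
#
# The original weights remain available afterwards as pytorch_model.bin.bak.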