# Convert a Multi-Decoder-DPRNN training checkpoint (best-model.ckpt) into a
# standalone `pytorch_model.bin` that also carries model name, sample rate and
# constructor arguments.
#
# Place this file under egs/wsj0-mix-var/Multi-Decoder-DPRNN and execute it from
# that directory: it imports the recipe-local `model` module and may read
# `local/conf.yml`.
import glob  # noqa: F401
import os
from collections import OrderedDict

import requests
import torch
import yaml
from asteroid import torch_utils

from model import make_model_and_optimizer

CHECKPOINT_URL = (
    "https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt"
)
OUTPUT_PATH = "pytorch_model.bin"

exp_dir = "exp/tmp"
checkpoint_dir = os.path.join(exp_dir, "checkpoints")
best_model_path = os.path.join(checkpoint_dir, "best-model.ckpt")

# Create the experiment and checkpoints folders if they do not exist yet.
os.makedirs(checkpoint_dir, exist_ok=True)

# Download the pretrained checkpoint unless best-model.ckpt is already present.
# Checking for this exact file (rather than any *.ckpt) guarantees that the
# torch.load below finds the file it expects.
if not os.path.exists(best_model_path):
    response = requests.get(CHECKPOINT_URL, timeout=60)
    # Fail loudly on HTTP errors instead of saving an error page as a checkpoint.
    response.raise_for_status()
    # Write to a temporary file and rename it into place, so an interrupted run
    # never leaves a truncated best-model.ckpt that later runs would accept.
    tmp_path = best_model_path + ".part"
    with open(tmp_path, "wb") as handle:
        handle.write(response.content)
    os.replace(tmp_path, best_model_path)

# Use the experiment's own config if present; otherwise read the recipe's
# default config in place (nothing is copied into exp_dir).
conf_path = os.path.join(exp_dir, "conf.yml")
if not os.path.exists(conf_path):
    conf_path = "local/conf.yml"

# Load training config.
with open(conf_path, encoding="utf-8") as f:
    train_conf = yaml.safe_load(f)
sample_rate = train_conf["data"]["sample_rate"]

# Rebuild the model from the config and load the checkpoint weights into it.
# The loaded model is not used afterwards; presumably this serves as a sanity
# check that the weights match the configured architecture.
model, _ = make_model_and_optimizer(train_conf, sample_rate=sample_rate)
model.eval()
# NOTE(review): recent torch releases default torch.load to weights_only=True,
# which may reject non-tensor objects stored in a training checkpoint. Confirm
# against the installed torch version.
checkpoint = torch.load(best_model_path, map_location="cpu")
model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)

# Constructor arguments saved alongside the weights: the "masknet" and
# "filterbank" sections of the training config merged into one dict
# (filterbank entries win on key collisions).
model_args = {}
model_args.update(train_conf["masknet"])
model_args.update(train_conf["filterbank"])

# Strip everything up to and including the first "." of each key (presumably
# the training wrapper's attribute prefix, e.g. "model."). Keys without a "."
# are kept unchanged.
new_state_dict = OrderedDict(
    (k.split(".", 1)[-1], v) for k, v in checkpoint["state_dict"].items()
)

checkpoint["state_dict"] = new_state_dict
checkpoint["model_name"] = "MultiDecoderDPRNN"
checkpoint["sample_rate"] = sample_rate
checkpoint["model_args"] = model_args

torch.save(checkpoint, OUTPUT_PATH)
print(f"saved checkpoint to {OUTPUT_PATH}")