|
from collections import OrderedDict |
|
|
|
import torch |
|
|
|
from .layers.synthesizers import SynthesizerTrnMsNSFsid |
|
from .jit import load_inputs, export_jit_model, save_pickle |
|
|
|
|
|
def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")):
    """Build an inference-ready synthesizer from a loaded checkpoint dict.

    Args:
        cpt: Checkpoint mapping. Reads ``cpt["config"]`` (positional
            constructor arguments), ``cpt["weight"]`` (state dict), and the
            optional keys ``"f0"`` (default ``1``) and ``"version"``
            (default ``"v1"``). ``cpt["config"][-3]`` is overwritten in
            place with the first dimension of ``emb_g.weight``.
        device: Device the constructed model is moved to.

    Returns:
        A ``(net_g, cpt)`` tuple: the model in eval mode, cast to float32,
        with weight norm removed, and the (mutated) checkpoint dict.

    Raises:
        ValueError: If ``cpt["version"]`` is neither ``"v1"`` nor ``"v2"``.
    """
    encoder_dims = {"v1": 256, "v2": 768}
    version = cpt.get("version", "v1")
    # Validate before touching ``cpt`` so an unsupported checkpoint is
    # rejected with a clear error and without a half-applied mutation.
    if version not in encoder_dims:
        raise ValueError(
            f"Unsupported synthesizer version {version!r}; "
            f"expected one of {sorted(encoder_dims)}"
        )
    encoder_dim = encoder_dims[version]

    # Keep the config in sync with the embedding table actually stored in
    # the weights (presumably the speaker count -- confirm against the
    # SynthesizerTrnMsNSFsid signature).
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
    if_f0 = cpt.get("f0", 1)

    net_g = SynthesizerTrnMsNSFsid(
        *cpt["config"],
        encoder_dim=encoder_dim,
        use_f0=if_f0 == 1,
    )
    # enc_q is dropped before loading; strict=False tolerates the resulting
    # key mismatch between the checkpoint and the module.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g = net_g.float()
    net_g.eval().to(device)
    net_g.remove_weight_norm()
    return net_g, cpt
|
|
|
|
|
def load_synthesizer(
    pth_path: torch.serialization.FILE_LIKE, device=torch.device("cpu")
):
    """Load a checkpoint from ``pth_path`` and build a synthesizer from it.

    The checkpoint is always deserialized onto the CPU; ``device`` is only
    forwarded to ``get_synthesizer``.

    Args:
        pth_path: Path or file-like object accepted by ``torch.load``.
        device: Device passed through to ``get_synthesizer``.

    Returns:
        Whatever ``get_synthesizer`` returns for the loaded checkpoint.
    """
    # NOTE: torch.load unpickles its input; only load trusted checkpoints.
    cpu = torch.device("cpu")
    checkpoint = torch.load(pth_path, map_location=cpu)
    return get_synthesizer(checkpoint, device)
|
|
|
|
|
def synthesizer_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Export the synthesizer stored at ``model_path`` as a JIT checkpoint.

    Args:
        model_path: Path of the source checkpoint.
        mode: Export mode handed to ``export_jit_model``. When it equals
            ``"trace"``, example inputs are loaded from ``inputs_path``.
        inputs_path: Location of the example inputs; only read in
            ``"trace"`` mode.
        save_path: Output path. When falsy it is derived from
            ``model_path`` by dropping a trailing ``".pth"`` (if present)
            and appending ``".half.jit"`` or ``".jit"`` depending on
            ``is_half``.
        device: Target device. A CUDA device without an index is replaced
            by ``cuda:0``.
        is_half: Forwarded to ``load_inputs`` and ``export_jit_model``;
            also selects the default output suffix.

    Returns:
        The checkpoint dict that was saved: the loaded checkpoint with its
        ``"weight"`` entry removed and ``"model"`` / ``"device"`` added.
    """
    if not save_path:
        # Remove the literal ".pth" suffix only. str.rstrip(".pth") would
        # strip any trailing run of the characters '.', 'p', 't', 'h' and
        # mangle names such as "G_depth.pth" into "G_de".
        suffix = ".pth"
        if model_path.endswith(suffix):
            save_path = model_path[: -len(suffix)]
        else:
            save_path = model_path
        save_path += ".half.jit" if is_half else ".jit"

    device_name = str(device)
    if "cuda" in device_name and ":" not in device_name:
        # Pin an index-less CUDA device to the first GPU.
        device = torch.device("cuda:0")

    from rvc.synthesizer import load_synthesizer

    model, cpt = load_synthesizer(model_path, device)
    assert isinstance(cpt, dict)

    # The exported module's forward pass is its inference routine.
    model.forward = model.infer

    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export_jit_model(model, mode, inputs, device, is_half)

    # Swap the raw weights for the exported model before saving.
    cpt.pop("weight")
    cpt["model"] = ckpt["model"]
    cpt["device"] = device
    save_pickle(cpt, save_path)
    return cpt
|
|