File size: 4,027 Bytes
1e4a2ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import io
import sys
import onnx
import json
import torch
import onnxsim
import warnings

# Add the current working directory to the module search path. Presumably this
# lets the project-local `main` package imported below resolve when the script
# is launched from the repository root (verify against how it is invoked).
sys.path.append(os.getcwd())

from main.library.algorithm.synthesizers import SynthesizerONNX

# Install a global filter that ignores every Python warning from here on.
warnings.filterwarnings("ignore")

def onnx_exporter(input_path, output_path, is_half=False, device="cpu"):
    """Export a synthesizer checkpoint to a simplified ONNX model file.

    The checkpoint is loaded with ``torch.load``, wrapped in ``SynthesizerONNX``,
    traced with dummy inputs through ``torch.onnx.export``, simplified with
    ``onnxsim`` and written to ``output_path``. Checkpoint metadata is stored
    in the ONNX file as a JSON string under the ``model_info`` metadata key.

    Args:
        input_path: Path of the checkpoint file to convert.
        output_path: Destination path of the ONNX file.
        is_half: When true, the network is cast to half precision before the
            export and the resulting graph is converted to float16 (keeping
            the I/O types) with ``onnxconverter_common``.
        device: Device the network and the dummy inputs are moved to.

    Returns:
        ``output_path`` on success, or ``None`` when the export,
        simplification, float16 conversion or save step raised; the traceback
        is printed in that case.

    Raises:
        FileNotFoundError: If ``input_path`` is not an existing file.
        Exception: Errors raised while loading the checkpoint or building the
            network are not caught here and propagate to the caller.
    """
    # Fail early with a clear error instead of subscripting ``None`` below.
    if not os.path.isfile(input_path):
        raise FileNotFoundError(f"Checkpoint file not found: {input_path}")

    cpt = torch.load(input_path, map_location="cpu", weights_only=True)
    # Sync the third-from-last config entry with the first dimension of the
    # stored ``emb_g.weight`` tensor (presumably the speaker count - confirm
    # against the SynthesizerONNX signature).
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]

    # Checkpoint metadata, with the same fallbacks for missing keys as before.
    model_name = cpt.get("model_name", None)
    model_author = cpt.get("author", None)
    epochs = cpt.get("epoch", None)
    steps = cpt.get("step", None)
    version = cpt.get("version", "v1")
    f0 = cpt.get("f0", 1)
    model_hash = cpt.get("model_hash", None)
    vocoder = cpt.get("vocoder", "Default")
    creation_date = cpt.get("creation_date", None)
    energy_use = cpt.get("energy", False)

    text_enc_hidden_dim = 768 if version == "v2" else 256
    tgt_sr = cpt["config"][-1]

    net_g = SynthesizerONNX(
        *cpt["config"],
        use_f0=f0,
        text_enc_hidden_dim=text_enc_hidden_dim,
        vocoder=vocoder,
        checkpointing=False,
        energy=energy_use,
    )
    # strict=False: keys that do not match between checkpoint and module are
    # tolerated rather than raising.
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g.eval().to(device)
    # NOTE(review): the dummy float inputs below stay float32 even when the
    # network is cast to half - confirm SynthesizerONNX copes with that.
    net_g = net_g.half() if is_half else net_g.float()

    # Length of the dummy sequence used for tracing; the axes of this size
    # are declared dynamic below.
    n_frames = 200

    phone = torch.rand(1, n_frames, text_enc_hidden_dim).to(device)
    phone_length = torch.tensor([n_frames]).long().to(device)
    ds = torch.LongTensor([0]).to(device)
    rnd = torch.rand(1, 192, n_frames).to(device)

    args = [phone, phone_length, ds, rnd]
    input_names = ["phone", "phone_lengths", "ds", "rnd"]
    dynamic_axes = {"phone": [1], "rnd": [2]}

    if f0:
        pitch = torch.randint(size=(1, n_frames), low=5, high=255).to(device)
        pitchf = torch.rand(1, n_frames).to(device)
        args += [pitch, pitchf]
        input_names += ["pitch", "pitchf"]
        dynamic_axes.update({"pitch": [1], "pitchf": [1]})

    if energy_use:
        energy = torch.rand(1, n_frames).to(device)
        args.append(energy)
        input_names.append("energy")
        dynamic_axes.update({"energy": [1]})

    model_info = {
        "model_name": model_name,
        "author": model_author,
        "epoch": epochs,
        "step": steps,
        "version": version,
        "sr": tgt_sr,
        "f0": f0,
        "model_hash": model_hash,
        "creation_date": creation_date,
        "vocoder": vocoder,
        "text_enc_hidden_dim": text_enc_hidden_dim,
        "energy": energy_use,
    }

    try:
        # Export into memory first so the graph can be simplified before
        # anything is written to disk.
        with io.BytesIO() as buffer:
            torch.onnx.export(
                net_g,
                tuple(args),
                buffer,
                do_constant_folding=True,
                opset_version=17,
                verbose=False,
                input_names=input_names,
                output_names=["audio"],
                dynamic_axes=dynamic_axes,
            )
            exported = onnx.load_model_from_string(buffer.getvalue())

        # The second value returned by onnxsim.simplify is ignored, as before.
        model, _ = onnxsim.simplify(exported)
        model.metadata_props.append(
            onnx.StringStringEntryProto(
                key="model_info",
                value=json.dumps(model_info),
            )
        )

        if is_half:
            try:
                import onnxconverter_common
            except ImportError:
                # Best-effort install of the optional dependency. An argument
                # list is used (no shell) so interpreter paths containing
                # spaces keep working.
                import subprocess

                subprocess.run(
                    [sys.executable, "-m", "pip", "install", "onnxconverter_common"],
                    check=True,
                )
                import onnxconverter_common

            model = onnxconverter_common.convert_float_to_float16(model, keep_io_types=True)

        onnx.save(model, output_path)
        return output_path
    except Exception:
        import traceback

        # print_exc() already writes the traceback; wrapping it in print()
        # only appended a stray "None".
        traceback.print_exc()
        return None