# NOTE(review): removed pasted viewer residue (a "File size" header, a commit
# hash, and stray line numbers 1-127 from a file viewer) that preceded the
# actual module and made it unparseable as Python.
import torch
import torch.nn as nn
from openvoice.api import ToneColorConverter
from openvoice.models import SynthesizerTrn
import os

# Resolve relative paths used below (checkpoint dir, output dir) against this
# script's own location instead of the caller's CWD. NOTE: module-level side
# effect — it also fires when this file is merely imported.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

class ToneColorExtractWrapper(nn.Module):
    """Thin export wrapper exposing the converter's reference encoder.

    Feeds the input straight into ``model.ref_enc`` and returns the
    tone-color (speaker) embedding it produces.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio):
        """Extract the tone-color embedding from *audio*.

        Args:
            audio: spectrogram tensor; the export driver traces with shape
                [1, source_audio_len, 513]. NOTE(review): an earlier comment
                mentioned transposing to [1, 513, len], but no transpose is
                performed here — confirm ``ref_enc``'s expected layout.

        Returns:
            The embedding tensor emitted by ``ref_enc`` (no trailing
            dimension added).
        """
        # .contiguous() keeps the traced graph free of stride surprises.
        return self.model.ref_enc(audio.contiguous())

class ToneCloneWrapper(nn.Module):
    """Export wrapper around the converter's voice-conversion path."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio, audio_lengths, src_tone, dest_tone, tau):
        """Convert *audio* from the source tone color to the target one.

        Args:
            audio: source spectrogram (made contiguous for ONNX tracing).
            audio_lengths: valid-frame count per batch item.
            src_tone: source speaker embedding ([1, 256, 1] in the driver).
            dest_tone: target speaker embedding (same shape as src_tone).
            tau: 1-element tensor; only the scalar ``tau[0]`` is forwarded.

        Returns:
            The first element of the model's ``voice_conversion`` output
            (the converted audio); the other two outputs are discarded.
        """
        converted, _, _ = self.model.voice_conversion(
            audio.contiguous(),
            audio_lengths,
            sid_src=src_tone.contiguous(),
            sid_tgt=dest_tone.contiguous(),
            tau=tau[0],
        )
        return converted

def export_models(ckpt_path, output_dir, target_audio_lens, source_audio_lens):
    """Export the tone-extraction and tone-cloning models to ONNX.

    Args:
        ckpt_path: directory holding ``config.json`` and ``checkpoint.pth``.
        output_dir: directory the ``.onnx`` files are written into
            (created if missing).
        target_audio_lens: audio lengths to trace the clone model with.
        source_audio_lens: audio lengths to trace the extract model with.

    Fix: the original loops wrote every iteration to the same fixed
    filename, so passing more than one length silently overwrote all but
    the last export. A ``_<len>`` suffix is now added when multiple
    lengths are supplied; single-length calls keep the original filenames,
    so existing callers are unaffected.
    """
    # Load the converter checkpoint on CPU (ONNX export does not need a GPU).
    device = "cpu"
    converter = ToneColorConverter(f'{ckpt_path}/config.json', device=device)
    converter.load_ckpt(f'{ckpt_path}/checkpoint.pth')

    os.makedirs(output_dir, exist_ok=True)

    # --- tone-color extraction model ---
    extract_wrapper = ToneColorExtractWrapper(converter.model)
    extract_wrapper.eval()

    for source_len in source_audio_lens:
        dummy_input = torch.randn(1, source_len, 513).contiguous()
        # Suffix only when several lengths are exported, so the common
        # single-length call keeps producing the original filename.
        suffix = f"_{source_len}" if len(source_audio_lens) > 1 else ""
        output_path = f"{output_dir}/tone_color_extract_model{suffix}.onnx"

        torch.onnx.export(
            extract_wrapper,
            dummy_input,
            output_path,
            input_names=['input'],
            output_names=['tone_embedding'],
            dynamic_axes={
                'input': {1: 'source_audio_len'},
            },
            # NOTE(review): extract uses opset 11 while clone uses 17 —
            # confirm the mismatch is intentional.
            opset_version=11,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone extract model to {output_path}")

    # --- tone-clone (voice conversion) model ---
    clone_wrapper = ToneCloneWrapper(converter.model)
    clone_wrapper.eval()

    for target_len in target_audio_lens:
        dummy_inputs = (
            torch.randn(1, 513, target_len).contiguous(),  # audio spectrogram
            torch.LongTensor([target_len]),                # audio_lengths
            torch.randn(1, 256, 1).contiguous(),           # src_tone embedding
            torch.randn(1, 256, 1).contiguous(),           # dest_tone embedding
            torch.FloatTensor([0.3])                       # tau
        )

        suffix = f"_{target_len}" if len(target_audio_lens) > 1 else ""
        output_path = f"{output_dir}/tone_clone_model{suffix}.onnx"

        torch.onnx.export(
            clone_wrapper,
            dummy_inputs,
            output_path,
            input_names=['audio', 'audio_length', 'src_tone', 'dest_tone', 'tau'],
            output_names=['converted_audio'],
            dynamic_axes={
                'audio': {2: 'target_audio_len'},
            },
            opset_version=17,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone clone model to {output_path}")

if __name__ == "__main__":
    # Example invocation: trace each model at one fixed length; adjust the
    # length lists as needed for other trace shapes.
    target_lens = [1024]
    source_lens = [1024]

    export_models(
        ckpt_path="checkpoints_v2/converter",
        output_dir="onnx_models",
        target_audio_lens=target_lens,
        source_audio_lens=source_lens,
    )