aaron committed
Commit 2fd89a2 · 0 Parent(s)

feat: unified TTS + voice conversion app (clean version)


- Unified app combining OpenVoice V2 and Seed-VC
- Fixed the DAC model initialization issue
- Implemented lazy loading for fast startup (see the sketch after this list)
- Strengthened error handling and logging
- Clean commit to resolve Git LFS issues
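
A minimal sketch of the lazy-loading idea mentioned above, assuming a hypothetical helper name (this is not the actual `app.py` code); the heavy model is only built on first use so the Gradio UI can start quickly:

```python
# Minimal sketch of the lazy-loading idea (hypothetical helper, not the actual app.py code).
# Assumes the OpenVoice directory is on sys.path, as the package's own imports require.
_tts_model = None

def get_tts_model(config_path, ckpt_path, device="cuda:0"):
    """Build the OpenVoice TTS model only when it is first requested."""
    global _tts_model
    if _tts_model is None:
        from openvoice.api import BaseSpeakerTTS  # deferred import keeps startup light
        _tts_model = BaseSpeakerTTS(config_path, device=device)
        _tts_model.load_ckpt(ckpt_path)
    return _tts_model
```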

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +39 -0
  2. OpenVoice/openvoice/__init__.py +0 -0
  3. OpenVoice/openvoice/__pycache__/__init__.cpython-312.pyc +0 -0
  4. OpenVoice/openvoice/__pycache__/se_extractor.cpython-312.pyc +0 -0
  5. OpenVoice/openvoice/api.py +202 -0
  6. OpenVoice/openvoice/attentions.py +465 -0
  7. OpenVoice/openvoice/commons.py +160 -0
  8. OpenVoice/openvoice/mel_processing.py +183 -0
  9. OpenVoice/openvoice/models.py +499 -0
  10. OpenVoice/openvoice/modules.py +598 -0
  11. OpenVoice/openvoice/openvoice_app.py +275 -0
  12. OpenVoice/openvoice/se_extractor.py +154 -0
  13. OpenVoice/openvoice/text/__init__.py +79 -0
  14. OpenVoice/openvoice/text/cleaners.py +16 -0
  15. OpenVoice/openvoice/text/english.py +188 -0
  16. OpenVoice/openvoice/text/mandarin.py +326 -0
  17. OpenVoice/openvoice/text/symbols.py +88 -0
  18. OpenVoice/openvoice/transforms.py +209 -0
  19. OpenVoice/openvoice/utils.py +194 -0
  20. README.md +119 -0
  21. app.py +843 -0
  22. hf_utils.py +12 -0
  23. modules/__pycache__/audio.cpython-310.pyc +0 -0
  24. modules/__pycache__/commons.cpython-310.pyc +0 -0
  25. modules/__pycache__/commons.cpython-38.pyc +0 -0
  26. modules/__pycache__/diffusion_transformer.cpython-310.pyc +0 -0
  27. modules/__pycache__/encodec.cpython-310.pyc +0 -0
  28. modules/__pycache__/flow_matching.cpython-310.pyc +0 -0
  29. modules/__pycache__/length_regulator.cpython-310.pyc +0 -0
  30. modules/__pycache__/rmvpe.cpython-310.pyc +0 -0
  31. modules/__pycache__/wavenet.cpython-310.pyc +0 -0
  32. modules/alias_free_torch/__init__.py +5 -0
  33. modules/alias_free_torch/__pycache__/__init__.cpython-310.pyc +0 -0
  34. modules/alias_free_torch/__pycache__/act.cpython-310.pyc +0 -0
  35. modules/alias_free_torch/__pycache__/filter.cpython-310.pyc +0 -0
  36. modules/alias_free_torch/__pycache__/resample.cpython-310.pyc +0 -0
  37. modules/alias_free_torch/act.py +29 -0
  38. modules/alias_free_torch/filter.py +96 -0
  39. modules/alias_free_torch/resample.py +57 -0
  40. modules/astral_quantization/__pycache__/bsq.cpython-310.pyc +0 -0
  41. modules/astral_quantization/__pycache__/convnext.cpython-310.pyc +0 -0
  42. modules/astral_quantization/__pycache__/default_model.cpython-310.pyc +0 -0
  43. modules/astral_quantization/bsq.py +569 -0
  44. modules/astral_quantization/convnext.py +209 -0
  45. modules/astral_quantization/default_model.py +73 -0
  46. modules/astral_quantization/transformer.py +254 -0
  47. modules/audio.py +82 -0
  48. modules/bigvgan/__pycache__/activations.cpython-310.pyc +0 -0
  49. modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc +0 -0
  50. modules/bigvgan/__pycache__/env.cpython-310.pyc +0 -0
.gitattributes ADDED
@@ -0,0 +1,39 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.o filter=lfs diff=lfs merge=lfs -text
+ *.pyd filter=lfs diff=lfs merge=lfs -text
+ *.ninja* filter=lfs diff=lfs merge=lfs -text
+ *.deps filter=lfs diff=lfs merge=lfs -text
OpenVoice/openvoice/__init__.py ADDED
File without changes
OpenVoice/openvoice/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (160 Bytes).
OpenVoice/openvoice/__pycache__/se_extractor.cpython-312.pyc ADDED
Binary file (6.92 kB).
OpenVoice/openvoice/api.py ADDED
@@ -0,0 +1,202 @@
1
+ import torch
2
+ import numpy as np
3
+ import re
4
+ import soundfile
5
+ from openvoice import utils
6
+ from openvoice import commons
7
+ import os
8
+ import librosa
9
+ from openvoice.text import text_to_sequence
10
+ from openvoice.mel_processing import spectrogram_torch
11
+ from openvoice.models import SynthesizerTrn
12
+
13
+
14
+ class OpenVoiceBaseClass(object):
15
+ def __init__(self,
16
+ config_path,
17
+ device='cuda:0'):
18
+ if 'cuda' in device:
19
+ assert torch.cuda.is_available()
20
+
21
+ hps = utils.get_hparams_from_file(config_path)
22
+
23
+ model = SynthesizerTrn(
24
+ len(getattr(hps, 'symbols', [])),
25
+ hps.data.filter_length // 2 + 1,
26
+ n_speakers=hps.data.n_speakers,
27
+ **hps.model,
28
+ ).to(device)
29
+
30
+ model.eval()
31
+ self.model = model
32
+ self.hps = hps
33
+ self.device = device
34
+
35
+ def load_ckpt(self, ckpt_path):
36
+ checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
37
+ a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
38
+ print("Loaded checkpoint '{}'".format(ckpt_path))
39
+ print('missing/unexpected keys:', a, b)
40
+
41
+
42
+ class BaseSpeakerTTS(OpenVoiceBaseClass):
43
+ language_marks = {
44
+ "english": "EN",
45
+ "chinese": "ZH",
46
+ }
47
+
48
+ @staticmethod
49
+ def get_text(text, hps, is_symbol):
50
+ text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
51
+ if hps.data.add_blank:
52
+ text_norm = commons.intersperse(text_norm, 0)
53
+ text_norm = torch.LongTensor(text_norm)
54
+ return text_norm
55
+
56
+ @staticmethod
57
+ def audio_numpy_concat(segment_data_list, sr, speed=1.):
58
+ audio_segments = []
59
+ for segment_data in segment_data_list:
60
+ audio_segments += segment_data.reshape(-1).tolist()
61
+ audio_segments += [0] * int((sr * 0.05)/speed)
62
+ audio_segments = np.array(audio_segments).astype(np.float32)
63
+ return audio_segments
64
+
65
+ @staticmethod
66
+ def split_sentences_into_pieces(text, language_str):
67
+ texts = utils.split_sentence(text, language_str=language_str)
68
+ print(" > Text splitted to sentences.")
69
+ print('\n'.join(texts))
70
+ print(" > ===========================")
71
+ return texts
72
+
73
+ def tts(self, text, output_path, speaker, language='English', speed=1.0):
74
+ mark = self.language_marks.get(language.lower(), None)
75
+ assert mark is not None, f"language {language} is not supported"
76
+
77
+ texts = self.split_sentences_into_pieces(text, mark)
78
+
79
+ audio_list = []
80
+ for t in texts:
81
+ t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
82
+ t = f'[{mark}]{t}[{mark}]'
83
+ stn_tst = self.get_text(t, self.hps, False)
84
+ device = self.device
85
+ speaker_id = self.hps.speakers[speaker]
86
+ with torch.no_grad():
87
+ x_tst = stn_tst.unsqueeze(0).to(device)
88
+ x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
89
+ sid = torch.LongTensor([speaker_id]).to(device)
90
+ audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
91
+ length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
92
+ audio_list.append(audio)
93
+ audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
94
+
95
+ if output_path is None:
96
+ return audio
97
+ else:
98
+ soundfile.write(output_path, audio, self.hps.data.sampling_rate)
99
+
100
+
101
+ class ToneColorConverter(OpenVoiceBaseClass):
102
+ def __init__(self, *args, **kwargs):
103
+ super().__init__(*args, **kwargs)
104
+
105
+ if kwargs.get('enable_watermark', True):
106
+ import wavmark
107
+ self.watermark_model = wavmark.load_model().to(self.device)
108
+ else:
109
+ self.watermark_model = None
110
+ self.version = getattr(self.hps, '_version_', "v1")
111
+
112
+
113
+
114
+ def extract_se(self, ref_wav_list, se_save_path=None):
115
+ if isinstance(ref_wav_list, str):
116
+ ref_wav_list = [ref_wav_list]
117
+
118
+ device = self.device
119
+ hps = self.hps
120
+ gs = []
121
+
122
+ for fname in ref_wav_list:
123
+ audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
124
+ y = torch.FloatTensor(audio_ref)
125
+ y = y.to(device)
126
+ y = y.unsqueeze(0)
127
+ y = spectrogram_torch(y, hps.data.filter_length,
128
+ hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
129
+ center=False).to(device)
130
+ with torch.no_grad():
131
+ g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
132
+ gs.append(g.detach())
133
+ gs = torch.stack(gs).mean(0)
134
+
135
+ if se_save_path is not None:
136
+ os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
137
+ torch.save(gs.cpu(), se_save_path)
138
+
139
+ return gs
140
+
141
+ def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
142
+ hps = self.hps
143
+ # load audio
144
+ audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
145
+ audio = torch.tensor(audio).float()
146
+
147
+ with torch.no_grad():
148
+ y = torch.FloatTensor(audio).to(self.device)
149
+ y = y.unsqueeze(0)
150
+ spec = spectrogram_torch(y, hps.data.filter_length,
151
+ hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
152
+ center=False).to(self.device)
153
+ spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
154
+ audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
155
+ 0, 0].data.cpu().float().numpy()
156
+ audio = self.add_watermark(audio, message)
157
+ if output_path is None:
158
+ return audio
159
+ else:
160
+ soundfile.write(output_path, audio, hps.data.sampling_rate)
161
+
162
+ def add_watermark(self, audio, message):
163
+ if self.watermark_model is None:
164
+ return audio
165
+ device = self.device
166
+ bits = utils.string_to_bits(message).reshape(-1)
167
+ n_repeat = len(bits) // 32
168
+
169
+ K = 16000
170
+ coeff = 2
171
+ for n in range(n_repeat):
172
+ trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
173
+ if len(trunck) != K:
174
+ print('Audio too short, fail to add watermark')
175
+ break
176
+ message_npy = bits[n * 32: (n + 1) * 32]
177
+
178
+ with torch.no_grad():
179
+ signal = torch.FloatTensor(trunck).to(device)[None]
180
+ message_tensor = torch.FloatTensor(message_npy).to(device)[None]
181
+ signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
182
+ signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
183
+ audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
184
+ return audio
185
+
186
+ def detect_watermark(self, audio, n_repeat):
187
+ bits = []
188
+ K = 16000
189
+ coeff = 2
190
+ for n in range(n_repeat):
191
+ trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
192
+ if len(trunck) != K:
193
+ print('Audio too short, fail to detect watermark')
194
+ return 'Fail'
195
+ with torch.no_grad():
196
+ signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
197
+ message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
198
+ bits.append(message_decoded_npy)
199
+ bits = np.stack(bits).reshape(-1, 8)
200
+ message = utils.bits_to_string(bits)
201
+ return message
202
+
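
For orientation, the classes above are typically driven roughly as follows. This is an illustrative sketch only: the checkpoint paths and the `speaker` key are placeholders, not values from this commit, and it assumes the OpenVoice directory is on sys.path.

```python
# Illustrative use of the API defined above (paths and the speaker key are placeholders).
from openvoice.api import BaseSpeakerTTS, ToneColorConverter

# 1) Synthesize speech with a base speaker.
tts = BaseSpeakerTTS("checkpoints/base_speakers/EN/config.json", device="cuda:0")
tts.load_ckpt("checkpoints/base_speakers/EN/checkpoint.pth")
tts.tts("Hello, this is a test.", "tmp_base.wav", speaker="default", language="English", speed=1.0)

# 2) Re-colour the result with a reference speaker's tone.
converter = ToneColorConverter("checkpoints/converter/config.json", device="cuda:0")
converter.load_ckpt("checkpoints/converter/checkpoint.pth")
src_se = converter.extract_se("tmp_base.wav")           # embedding of the synthesized voice
tgt_se = converter.extract_se("reference_speaker.wav")  # embedding of the voice to clone
converter.convert("tmp_base.wav", src_se, tgt_se, output_path="output.wav", tau=0.3)
```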
OpenVoice/openvoice/attentions.py ADDED
@@ -0,0 +1,465 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from openvoice import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class LayerNorm(nn.Module):
13
+ def __init__(self, channels, eps=1e-5):
14
+ super().__init__()
15
+ self.channels = channels
16
+ self.eps = eps
17
+
18
+ self.gamma = nn.Parameter(torch.ones(channels))
19
+ self.beta = nn.Parameter(torch.zeros(channels))
20
+
21
+ def forward(self, x):
22
+ x = x.transpose(1, -1)
23
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
+ return x.transpose(1, -1)
25
+
26
+
27
+ @torch.jit.script
28
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
+ n_channels_int = n_channels[0]
30
+ in_act = input_a + input_b
31
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
+ acts = t_act * s_act
34
+ return acts
35
+
36
+
37
+ class Encoder(nn.Module):
38
+ def __init__(
39
+ self,
40
+ hidden_channels,
41
+ filter_channels,
42
+ n_heads,
43
+ n_layers,
44
+ kernel_size=1,
45
+ p_dropout=0.0,
46
+ window_size=4,
47
+ isflow=True,
48
+ **kwargs
49
+ ):
50
+ super().__init__()
51
+ self.hidden_channels = hidden_channels
52
+ self.filter_channels = filter_channels
53
+ self.n_heads = n_heads
54
+ self.n_layers = n_layers
55
+ self.kernel_size = kernel_size
56
+ self.p_dropout = p_dropout
57
+ self.window_size = window_size
58
+ # if isflow:
59
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
62
+ # self.gin_channels = 256
63
+ self.cond_layer_idx = self.n_layers
64
+ if "gin_channels" in kwargs:
65
+ self.gin_channels = kwargs["gin_channels"]
66
+ if self.gin_channels != 0:
67
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
+ # vits2 says 3rd block, so idx is 2 by default
69
+ self.cond_layer_idx = (
70
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
+ )
72
+ # logging.debug(self.gin_channels, self.cond_layer_idx)
73
+ assert (
74
+ self.cond_layer_idx < self.n_layers
75
+ ), "cond_layer_idx should be less than n_layers"
76
+ self.drop = nn.Dropout(p_dropout)
77
+ self.attn_layers = nn.ModuleList()
78
+ self.norm_layers_1 = nn.ModuleList()
79
+ self.ffn_layers = nn.ModuleList()
80
+ self.norm_layers_2 = nn.ModuleList()
81
+
82
+ for i in range(self.n_layers):
83
+ self.attn_layers.append(
84
+ MultiHeadAttention(
85
+ hidden_channels,
86
+ hidden_channels,
87
+ n_heads,
88
+ p_dropout=p_dropout,
89
+ window_size=window_size,
90
+ )
91
+ )
92
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
93
+ self.ffn_layers.append(
94
+ FFN(
95
+ hidden_channels,
96
+ hidden_channels,
97
+ filter_channels,
98
+ kernel_size,
99
+ p_dropout=p_dropout,
100
+ )
101
+ )
102
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
103
+
104
+ def forward(self, x, x_mask, g=None):
105
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
106
+ x = x * x_mask
107
+ for i in range(self.n_layers):
108
+ if i == self.cond_layer_idx and g is not None:
109
+ g = self.spk_emb_linear(g.transpose(1, 2))
110
+ g = g.transpose(1, 2)
111
+ x = x + g
112
+ x = x * x_mask
113
+ y = self.attn_layers[i](x, x, attn_mask)
114
+ y = self.drop(y)
115
+ x = self.norm_layers_1[i](x + y)
116
+
117
+ y = self.ffn_layers[i](x, x_mask)
118
+ y = self.drop(y)
119
+ x = self.norm_layers_2[i](x + y)
120
+ x = x * x_mask
121
+ return x
122
+
123
+
124
+ class Decoder(nn.Module):
125
+ def __init__(
126
+ self,
127
+ hidden_channels,
128
+ filter_channels,
129
+ n_heads,
130
+ n_layers,
131
+ kernel_size=1,
132
+ p_dropout=0.0,
133
+ proximal_bias=False,
134
+ proximal_init=True,
135
+ **kwargs
136
+ ):
137
+ super().__init__()
138
+ self.hidden_channels = hidden_channels
139
+ self.filter_channels = filter_channels
140
+ self.n_heads = n_heads
141
+ self.n_layers = n_layers
142
+ self.kernel_size = kernel_size
143
+ self.p_dropout = p_dropout
144
+ self.proximal_bias = proximal_bias
145
+ self.proximal_init = proximal_init
146
+
147
+ self.drop = nn.Dropout(p_dropout)
148
+ self.self_attn_layers = nn.ModuleList()
149
+ self.norm_layers_0 = nn.ModuleList()
150
+ self.encdec_attn_layers = nn.ModuleList()
151
+ self.norm_layers_1 = nn.ModuleList()
152
+ self.ffn_layers = nn.ModuleList()
153
+ self.norm_layers_2 = nn.ModuleList()
154
+ for i in range(self.n_layers):
155
+ self.self_attn_layers.append(
156
+ MultiHeadAttention(
157
+ hidden_channels,
158
+ hidden_channels,
159
+ n_heads,
160
+ p_dropout=p_dropout,
161
+ proximal_bias=proximal_bias,
162
+ proximal_init=proximal_init,
163
+ )
164
+ )
165
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
166
+ self.encdec_attn_layers.append(
167
+ MultiHeadAttention(
168
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
169
+ )
170
+ )
171
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
172
+ self.ffn_layers.append(
173
+ FFN(
174
+ hidden_channels,
175
+ hidden_channels,
176
+ filter_channels,
177
+ kernel_size,
178
+ p_dropout=p_dropout,
179
+ causal=True,
180
+ )
181
+ )
182
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
183
+
184
+ def forward(self, x, x_mask, h, h_mask):
185
+ """
186
+ x: decoder input
187
+ h: encoder output
188
+ """
189
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
190
+ device=x.device, dtype=x.dtype
191
+ )
192
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
193
+ x = x * x_mask
194
+ for i in range(self.n_layers):
195
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
196
+ y = self.drop(y)
197
+ x = self.norm_layers_0[i](x + y)
198
+
199
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
200
+ y = self.drop(y)
201
+ x = self.norm_layers_1[i](x + y)
202
+
203
+ y = self.ffn_layers[i](x, x_mask)
204
+ y = self.drop(y)
205
+ x = self.norm_layers_2[i](x + y)
206
+ x = x * x_mask
207
+ return x
208
+
209
+
210
+ class MultiHeadAttention(nn.Module):
211
+ def __init__(
212
+ self,
213
+ channels,
214
+ out_channels,
215
+ n_heads,
216
+ p_dropout=0.0,
217
+ window_size=None,
218
+ heads_share=True,
219
+ block_length=None,
220
+ proximal_bias=False,
221
+ proximal_init=False,
222
+ ):
223
+ super().__init__()
224
+ assert channels % n_heads == 0
225
+
226
+ self.channels = channels
227
+ self.out_channels = out_channels
228
+ self.n_heads = n_heads
229
+ self.p_dropout = p_dropout
230
+ self.window_size = window_size
231
+ self.heads_share = heads_share
232
+ self.block_length = block_length
233
+ self.proximal_bias = proximal_bias
234
+ self.proximal_init = proximal_init
235
+ self.attn = None
236
+
237
+ self.k_channels = channels // n_heads
238
+ self.conv_q = nn.Conv1d(channels, channels, 1)
239
+ self.conv_k = nn.Conv1d(channels, channels, 1)
240
+ self.conv_v = nn.Conv1d(channels, channels, 1)
241
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
242
+ self.drop = nn.Dropout(p_dropout)
243
+
244
+ if window_size is not None:
245
+ n_heads_rel = 1 if heads_share else n_heads
246
+ rel_stddev = self.k_channels**-0.5
247
+ self.emb_rel_k = nn.Parameter(
248
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
249
+ * rel_stddev
250
+ )
251
+ self.emb_rel_v = nn.Parameter(
252
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
253
+ * rel_stddev
254
+ )
255
+
256
+ nn.init.xavier_uniform_(self.conv_q.weight)
257
+ nn.init.xavier_uniform_(self.conv_k.weight)
258
+ nn.init.xavier_uniform_(self.conv_v.weight)
259
+ if proximal_init:
260
+ with torch.no_grad():
261
+ self.conv_k.weight.copy_(self.conv_q.weight)
262
+ self.conv_k.bias.copy_(self.conv_q.bias)
263
+
264
+ def forward(self, x, c, attn_mask=None):
265
+ q = self.conv_q(x)
266
+ k = self.conv_k(c)
267
+ v = self.conv_v(c)
268
+
269
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
270
+
271
+ x = self.conv_o(x)
272
+ return x
273
+
274
+ def attention(self, query, key, value, mask=None):
275
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
276
+ b, d, t_s, t_t = (*key.size(), query.size(2))
277
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
278
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
279
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
280
+
281
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
282
+ if self.window_size is not None:
283
+ assert (
284
+ t_s == t_t
285
+ ), "Relative attention is only available for self-attention."
286
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
287
+ rel_logits = self._matmul_with_relative_keys(
288
+ query / math.sqrt(self.k_channels), key_relative_embeddings
289
+ )
290
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
291
+ scores = scores + scores_local
292
+ if self.proximal_bias:
293
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
294
+ scores = scores + self._attention_bias_proximal(t_s).to(
295
+ device=scores.device, dtype=scores.dtype
296
+ )
297
+ if mask is not None:
298
+ scores = scores.masked_fill(mask == 0, -1e4)
299
+ if self.block_length is not None:
300
+ assert (
301
+ t_s == t_t
302
+ ), "Local attention is only available for self-attention."
303
+ block_mask = (
304
+ torch.ones_like(scores)
305
+ .triu(-self.block_length)
306
+ .tril(self.block_length)
307
+ )
308
+ scores = scores.masked_fill(block_mask == 0, -1e4)
309
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
310
+ p_attn = self.drop(p_attn)
311
+ output = torch.matmul(p_attn, value)
312
+ if self.window_size is not None:
313
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
314
+ value_relative_embeddings = self._get_relative_embeddings(
315
+ self.emb_rel_v, t_s
316
+ )
317
+ output = output + self._matmul_with_relative_values(
318
+ relative_weights, value_relative_embeddings
319
+ )
320
+ output = (
321
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
322
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
323
+ return output, p_attn
324
+
325
+ def _matmul_with_relative_values(self, x, y):
326
+ """
327
+ x: [b, h, l, m]
328
+ y: [h or 1, m, d]
329
+ ret: [b, h, l, d]
330
+ """
331
+ ret = torch.matmul(x, y.unsqueeze(0))
332
+ return ret
333
+
334
+ def _matmul_with_relative_keys(self, x, y):
335
+ """
336
+ x: [b, h, l, d]
337
+ y: [h or 1, m, d]
338
+ ret: [b, h, l, m]
339
+ """
340
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
341
+ return ret
342
+
343
+ def _get_relative_embeddings(self, relative_embeddings, length):
344
+ 2 * self.window_size + 1
345
+ # Pad first before slice to avoid using cond ops.
346
+ pad_length = max(length - (self.window_size + 1), 0)
347
+ slice_start_position = max((self.window_size + 1) - length, 0)
348
+ slice_end_position = slice_start_position + 2 * length - 1
349
+ if pad_length > 0:
350
+ padded_relative_embeddings = F.pad(
351
+ relative_embeddings,
352
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
353
+ )
354
+ else:
355
+ padded_relative_embeddings = relative_embeddings
356
+ used_relative_embeddings = padded_relative_embeddings[
357
+ :, slice_start_position:slice_end_position
358
+ ]
359
+ return used_relative_embeddings
360
+
361
+ def _relative_position_to_absolute_position(self, x):
362
+ """
363
+ x: [b, h, l, 2*l-1]
364
+ ret: [b, h, l, l]
365
+ """
366
+ batch, heads, length, _ = x.size()
367
+ # Concat columns of pad to shift from relative to absolute indexing.
368
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
369
+
370
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
371
+ x_flat = x.view([batch, heads, length * 2 * length])
372
+ x_flat = F.pad(
373
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
374
+ )
375
+
376
+ # Reshape and slice out the padded elements.
377
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
378
+ :, :, :length, length - 1 :
379
+ ]
380
+ return x_final
381
+
382
+ def _absolute_position_to_relative_position(self, x):
383
+ """
384
+ x: [b, h, l, l]
385
+ ret: [b, h, l, 2*l-1]
386
+ """
387
+ batch, heads, length, _ = x.size()
388
+ # pad along column
389
+ x = F.pad(
390
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
391
+ )
392
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
393
+ # add 0's in the beginning that will skew the elements after reshape
394
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
395
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
396
+ return x_final
397
+
398
+ def _attention_bias_proximal(self, length):
399
+ """Bias for self-attention to encourage attention to close positions.
400
+ Args:
401
+ length: an integer scalar.
402
+ Returns:
403
+ a Tensor with shape [1, 1, length, length]
404
+ """
405
+ r = torch.arange(length, dtype=torch.float32)
406
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
407
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
408
+
409
+
410
+ class FFN(nn.Module):
411
+ def __init__(
412
+ self,
413
+ in_channels,
414
+ out_channels,
415
+ filter_channels,
416
+ kernel_size,
417
+ p_dropout=0.0,
418
+ activation=None,
419
+ causal=False,
420
+ ):
421
+ super().__init__()
422
+ self.in_channels = in_channels
423
+ self.out_channels = out_channels
424
+ self.filter_channels = filter_channels
425
+ self.kernel_size = kernel_size
426
+ self.p_dropout = p_dropout
427
+ self.activation = activation
428
+ self.causal = causal
429
+
430
+ if causal:
431
+ self.padding = self._causal_padding
432
+ else:
433
+ self.padding = self._same_padding
434
+
435
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
436
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
437
+ self.drop = nn.Dropout(p_dropout)
438
+
439
+ def forward(self, x, x_mask):
440
+ x = self.conv_1(self.padding(x * x_mask))
441
+ if self.activation == "gelu":
442
+ x = x * torch.sigmoid(1.702 * x)
443
+ else:
444
+ x = torch.relu(x)
445
+ x = self.drop(x)
446
+ x = self.conv_2(self.padding(x * x_mask))
447
+ return x * x_mask
448
+
449
+ def _causal_padding(self, x):
450
+ if self.kernel_size == 1:
451
+ return x
452
+ pad_l = self.kernel_size - 1
453
+ pad_r = 0
454
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
455
+ x = F.pad(x, commons.convert_pad_shape(padding))
456
+ return x
457
+
458
+ def _same_padding(self, x):
459
+ if self.kernel_size == 1:
460
+ return x
461
+ pad_l = (self.kernel_size - 1) // 2
462
+ pad_r = self.kernel_size // 2
463
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
464
+ x = F.pad(x, commons.convert_pad_shape(padding))
465
+ return x
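
A quick, self-contained shape check of the relative-position `Encoder` defined above; the hyperparameters are arbitrary examples, not the values used by OpenVoice, and the import assumes the OpenVoice directory is on sys.path.

```python
# Quick shape check of the Encoder above (arbitrary hyperparameters, not OpenVoice's config).
import torch
from openvoice.attentions import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
              n_layers=6, kernel_size=3, p_dropout=0.1, window_size=4)
x = torch.randn(1, 192, 50)    # [batch, hidden_channels, frames]
x_mask = torch.ones(1, 1, 50)  # all frames valid
y = enc(x, x_mask)
print(y.shape)                 # torch.Size([1, 192, 50]): same shape in, same shape out
```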
OpenVoice/openvoice/commons.py ADDED
@@ -0,0 +1,160 @@
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def convert_pad_shape(pad_shape):
17
+ layer = pad_shape[::-1]
18
+ pad_shape = [item for sublist in layer for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68
+ position = torch.arange(length, dtype=torch.float)
69
+ num_timescales = channels // 2
70
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71
+ num_timescales - 1
72
+ )
73
+ inv_timescales = min_timescale * torch.exp(
74
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75
+ )
76
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
79
+ signal = signal.view(1, channels, length)
80
+ return signal
81
+
82
+
83
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84
+ b, channels, length = x.size()
85
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86
+ return x + signal.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90
+ b, channels, length = x.size()
91
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93
+
94
+
95
+ def subsequent_mask(length):
96
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
+ return mask
98
+
99
+
100
+ @torch.jit.script
101
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
+ n_channels_int = n_channels[0]
103
+ in_act = input_a + input_b
104
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
+ acts = t_act * s_act
107
+ return acts
108
+
109
+
110
+ def convert_pad_shape(pad_shape):
111
+ layer = pad_shape[::-1]
112
+ pad_shape = [item for sublist in layer for item in sublist]
113
+ return pad_shape
114
+
115
+
116
+ def shift_1d(x):
117
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118
+ return x
119
+
120
+
121
+ def sequence_mask(length, max_length=None):
122
+ if max_length is None:
123
+ max_length = length.max()
124
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125
+ return x.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
+ def generate_path(duration, mask):
129
+ """
130
+ duration: [b, 1, t_x]
131
+ mask: [b, 1, t_y, t_x]
132
+ """
133
+
134
+ b, _, t_y, t_x = mask.shape
135
+ cum_duration = torch.cumsum(duration, -1)
136
+
137
+ cum_duration_flat = cum_duration.view(b * t_x)
138
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
139
+ path = path.view(b, t_x, t_y)
140
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
141
+ path = path.unsqueeze(1).transpose(2, 3) * mask
142
+ return path
143
+
144
+
145
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
146
+ if isinstance(parameters, torch.Tensor):
147
+ parameters = [parameters]
148
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
149
+ norm_type = float(norm_type)
150
+ if clip_value is not None:
151
+ clip_value = float(clip_value)
152
+
153
+ total_norm = 0
154
+ for p in parameters:
155
+ param_norm = p.grad.data.norm(norm_type)
156
+ total_norm += param_norm.item() ** norm_type
157
+ if clip_value is not None:
158
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
159
+ total_norm = total_norm ** (1.0 / norm_type)
160
+ return total_norm
OpenVoice/openvoice/mel_processing.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.utils.data
3
+ from librosa.filters import mel as librosa_mel_fn
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+
8
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
9
+ """
10
+ PARAMS
11
+ ------
12
+ C: compression factor
13
+ """
14
+ return torch.log(torch.clamp(x, min=clip_val) * C)
15
+
16
+
17
+ def dynamic_range_decompression_torch(x, C=1):
18
+ """
19
+ PARAMS
20
+ ------
21
+ C: compression factor used to compress
22
+ """
23
+ return torch.exp(x) / C
24
+
25
+
26
+ def spectral_normalize_torch(magnitudes):
27
+ output = dynamic_range_compression_torch(magnitudes)
28
+ return output
29
+
30
+
31
+ def spectral_de_normalize_torch(magnitudes):
32
+ output = dynamic_range_decompression_torch(magnitudes)
33
+ return output
34
+
35
+
36
+ mel_basis = {}
37
+ hann_window = {}
38
+
39
+
40
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41
+ if torch.min(y) < -1.1:
42
+ print("min value is ", torch.min(y))
43
+ if torch.max(y) > 1.1:
44
+ print("max value is ", torch.max(y))
45
+
46
+ global hann_window
47
+ dtype_device = str(y.dtype) + "_" + str(y.device)
48
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
49
+ if wnsize_dtype_device not in hann_window:
50
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
51
+ dtype=y.dtype, device=y.device
52
+ )
53
+
54
+ y = torch.nn.functional.pad(
55
+ y.unsqueeze(1),
56
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
57
+ mode="reflect",
58
+ )
59
+ y = y.squeeze(1)
60
+
61
+ spec = torch.stft(
62
+ y,
63
+ n_fft,
64
+ hop_length=hop_size,
65
+ win_length=win_size,
66
+ window=hann_window[wnsize_dtype_device],
67
+ center=center,
68
+ pad_mode="reflect",
69
+ normalized=False,
70
+ onesided=True,
71
+ return_complex=False,
72
+ )
73
+
74
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
75
+ return spec
76
+
77
+
78
+ def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
79
+ # if torch.min(y) < -1.:
80
+ # print('min value is ', torch.min(y))
81
+ # if torch.max(y) > 1.:
82
+ # print('max value is ', torch.max(y))
83
+
84
+ global hann_window
85
+ dtype_device = str(y.dtype) + '_' + str(y.device)
86
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
87
+ if wnsize_dtype_device not in hann_window:
88
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
89
+
90
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
91
+
92
+ # ******************** original ************************#
93
+ # y = y.squeeze(1)
94
+ # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
95
+ # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
96
+
97
+ # ******************** ConvSTFT ************************#
98
+ freq_cutoff = n_fft // 2 + 1
99
+ fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
100
+ forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
101
+ forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
102
+
103
+ import torch.nn.functional as F
104
+
105
+ # if center:
106
+ # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
107
+ assert center is False
108
+
109
+ forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
110
+ spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
111
+
112
+
113
+ # ******************** Verification ************************#
114
+ spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
115
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
116
+ assert torch.allclose(spec1, spec2, atol=1e-4)
117
+
118
+ spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
119
+ return spec
120
+
121
+
122
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
123
+ global mel_basis
124
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
125
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
126
+ if fmax_dtype_device not in mel_basis:
127
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
128
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
129
+ dtype=spec.dtype, device=spec.device
130
+ )
131
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
132
+ spec = spectral_normalize_torch(spec)
133
+ return spec
134
+
135
+
136
+ def mel_spectrogram_torch(
137
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
138
+ ):
139
+ if torch.min(y) < -1.0:
140
+ print("min value is ", torch.min(y))
141
+ if torch.max(y) > 1.0:
142
+ print("max value is ", torch.max(y))
143
+
144
+ global mel_basis, hann_window
145
+ dtype_device = str(y.dtype) + "_" + str(y.device)
146
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
147
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
148
+ if fmax_dtype_device not in mel_basis:
149
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
150
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
151
+ dtype=y.dtype, device=y.device
152
+ )
153
+ if wnsize_dtype_device not in hann_window:
154
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
155
+ dtype=y.dtype, device=y.device
156
+ )
157
+
158
+ y = torch.nn.functional.pad(
159
+ y.unsqueeze(1),
160
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
161
+ mode="reflect",
162
+ )
163
+ y = y.squeeze(1)
164
+
165
+ spec = torch.stft(
166
+ y,
167
+ n_fft,
168
+ hop_length=hop_size,
169
+ win_length=win_size,
170
+ window=hann_window[wnsize_dtype_device],
171
+ center=center,
172
+ pad_mode="reflect",
173
+ normalized=False,
174
+ onesided=True,
175
+ return_complex=False,
176
+ )
177
+
178
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
179
+
180
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
181
+ spec = spectral_normalize_torch(spec)
182
+
183
+ return spec
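
For reference, the spectrogram front-end above is typically called as sketched below; the FFT/hop/mel parameters are illustrative only, and the sketch assumes torch/librosa versions compatible with the code above and the OpenVoice directory on sys.path.

```python
# Illustrative call into the front-end above (example parameters; assumes compatible torch/librosa).
import torch
from openvoice.mel_processing import spectrogram_torch, mel_spectrogram_torch

y = torch.rand(1, 22050) * 2 - 1  # one second of audio at 22.05 kHz in [-1, 1]
spec = spectrogram_torch(y, n_fft=1024, sampling_rate=22050,
                         hop_size=256, win_size=1024, center=False)
mel = mel_spectrogram_torch(y, n_fft=1024, num_mels=80, sampling_rate=22050,
                            hop_size=256, win_size=1024, fmin=0, fmax=None, center=False)
print(spec.shape, mel.shape)  # [1, 513, frames] linear magnitudes, [1, 80, frames] log-mels
```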
OpenVoice/openvoice/models.py ADDED
@@ -0,0 +1,499 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from openvoice import commons
7
+ from openvoice import modules
8
+ from openvoice import attentions
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+
13
+ from openvoice.commons import init_weights, get_padding
14
+
15
+
16
+ class TextEncoder(nn.Module):
17
+ def __init__(self,
18
+ n_vocab,
19
+ out_channels,
20
+ hidden_channels,
21
+ filter_channels,
22
+ n_heads,
23
+ n_layers,
24
+ kernel_size,
25
+ p_dropout):
26
+ super().__init__()
27
+ self.n_vocab = n_vocab
28
+ self.out_channels = out_channels
29
+ self.hidden_channels = hidden_channels
30
+ self.filter_channels = filter_channels
31
+ self.n_heads = n_heads
32
+ self.n_layers = n_layers
33
+ self.kernel_size = kernel_size
34
+ self.p_dropout = p_dropout
35
+
36
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
37
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
38
+
39
+ self.encoder = attentions.Encoder(
40
+ hidden_channels,
41
+ filter_channels,
42
+ n_heads,
43
+ n_layers,
44
+ kernel_size,
45
+ p_dropout)
46
+ self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
47
+
48
+ def forward(self, x, x_lengths):
49
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
50
+ x = torch.transpose(x, 1, -1) # [b, h, t]
51
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
52
+
53
+ x = self.encoder(x * x_mask, x_mask)
54
+ stats = self.proj(x) * x_mask
55
+
56
+ m, logs = torch.split(stats, self.out_channels, dim=1)
57
+ return x, m, logs, x_mask
58
+
59
+
60
+ class DurationPredictor(nn.Module):
61
+ def __init__(
62
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
63
+ ):
64
+ super().__init__()
65
+
66
+ self.in_channels = in_channels
67
+ self.filter_channels = filter_channels
68
+ self.kernel_size = kernel_size
69
+ self.p_dropout = p_dropout
70
+ self.gin_channels = gin_channels
71
+
72
+ self.drop = nn.Dropout(p_dropout)
73
+ self.conv_1 = nn.Conv1d(
74
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
75
+ )
76
+ self.norm_1 = modules.LayerNorm(filter_channels)
77
+ self.conv_2 = nn.Conv1d(
78
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
79
+ )
80
+ self.norm_2 = modules.LayerNorm(filter_channels)
81
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
82
+
83
+ if gin_channels != 0:
84
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
85
+
86
+ def forward(self, x, x_mask, g=None):
87
+ x = torch.detach(x)
88
+ if g is not None:
89
+ g = torch.detach(g)
90
+ x = x + self.cond(g)
91
+ x = self.conv_1(x * x_mask)
92
+ x = torch.relu(x)
93
+ x = self.norm_1(x)
94
+ x = self.drop(x)
95
+ x = self.conv_2(x * x_mask)
96
+ x = torch.relu(x)
97
+ x = self.norm_2(x)
98
+ x = self.drop(x)
99
+ x = self.proj(x * x_mask)
100
+ return x * x_mask
101
+
102
+ class StochasticDurationPredictor(nn.Module):
103
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
104
+ super().__init__()
105
+ filter_channels = in_channels # it needs to be removed from future version.
106
+ self.in_channels = in_channels
107
+ self.filter_channels = filter_channels
108
+ self.kernel_size = kernel_size
109
+ self.p_dropout = p_dropout
110
+ self.n_flows = n_flows
111
+ self.gin_channels = gin_channels
112
+
113
+ self.log_flow = modules.Log()
114
+ self.flows = nn.ModuleList()
115
+ self.flows.append(modules.ElementwiseAffine(2))
116
+ for i in range(n_flows):
117
+ self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
118
+ self.flows.append(modules.Flip())
119
+
120
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
121
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
122
+ self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
123
+ self.post_flows = nn.ModuleList()
124
+ self.post_flows.append(modules.ElementwiseAffine(2))
125
+ for i in range(4):
126
+ self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
127
+ self.post_flows.append(modules.Flip())
128
+
129
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
130
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
131
+ self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
132
+ if gin_channels != 0:
133
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
134
+
135
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
136
+ x = torch.detach(x)
137
+ x = self.pre(x)
138
+ if g is not None:
139
+ g = torch.detach(g)
140
+ x = x + self.cond(g)
141
+ x = self.convs(x, x_mask)
142
+ x = self.proj(x) * x_mask
143
+
144
+ if not reverse:
145
+ flows = self.flows
146
+ assert w is not None
147
+
148
+ logdet_tot_q = 0
149
+ h_w = self.post_pre(w)
150
+ h_w = self.post_convs(h_w, x_mask)
151
+ h_w = self.post_proj(h_w) * x_mask
152
+ e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
153
+ z_q = e_q
154
+ for flow in self.post_flows:
155
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
156
+ logdet_tot_q += logdet_q
157
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
158
+ u = torch.sigmoid(z_u) * x_mask
159
+ z0 = (w - u) * x_mask
160
+ logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
161
+ logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
162
+
163
+ logdet_tot = 0
164
+ z0, logdet = self.log_flow(z0, x_mask)
165
+ logdet_tot += logdet
166
+ z = torch.cat([z0, z1], 1)
167
+ for flow in flows:
168
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
169
+ logdet_tot = logdet_tot + logdet
170
+ nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
171
+ return nll + logq # [b]
172
+ else:
173
+ flows = list(reversed(self.flows))
174
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
175
+ z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
176
+ for flow in flows:
177
+ z = flow(z, x_mask, g=x, reverse=reverse)
178
+ z0, z1 = torch.split(z, [1, 1], 1)
179
+ logw = z0
180
+ return logw
181
+
182
+ class PosteriorEncoder(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_channels,
186
+ out_channels,
187
+ hidden_channels,
188
+ kernel_size,
189
+ dilation_rate,
190
+ n_layers,
191
+ gin_channels=0,
192
+ ):
193
+ super().__init__()
194
+ self.in_channels = in_channels
195
+ self.out_channels = out_channels
196
+ self.hidden_channels = hidden_channels
197
+ self.kernel_size = kernel_size
198
+ self.dilation_rate = dilation_rate
199
+ self.n_layers = n_layers
200
+ self.gin_channels = gin_channels
201
+
202
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
203
+ self.enc = modules.WN(
204
+ hidden_channels,
205
+ kernel_size,
206
+ dilation_rate,
207
+ n_layers,
208
+ gin_channels=gin_channels,
209
+ )
210
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
211
+
212
+ def forward(self, x, x_lengths, g=None, tau=1.0):
213
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
214
+ x.dtype
215
+ )
216
+ x = self.pre(x) * x_mask
217
+ x = self.enc(x, x_mask, g=g)
218
+ stats = self.proj(x) * x_mask
219
+ m, logs = torch.split(stats, self.out_channels, dim=1)
220
+ z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
221
+ return z, m, logs, x_mask
222
+
223
+
224
+ class Generator(torch.nn.Module):
225
+ def __init__(
226
+ self,
227
+ initial_channel,
228
+ resblock,
229
+ resblock_kernel_sizes,
230
+ resblock_dilation_sizes,
231
+ upsample_rates,
232
+ upsample_initial_channel,
233
+ upsample_kernel_sizes,
234
+ gin_channels=0,
235
+ ):
236
+ super(Generator, self).__init__()
237
+ self.num_kernels = len(resblock_kernel_sizes)
238
+ self.num_upsamples = len(upsample_rates)
239
+ self.conv_pre = Conv1d(
240
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
241
+ )
242
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
243
+
244
+ self.ups = nn.ModuleList()
245
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
246
+ self.ups.append(
247
+ weight_norm(
248
+ ConvTranspose1d(
249
+ upsample_initial_channel // (2**i),
250
+ upsample_initial_channel // (2 ** (i + 1)),
251
+ k,
252
+ u,
253
+ padding=(k - u) // 2,
254
+ )
255
+ )
256
+ )
257
+
258
+ self.resblocks = nn.ModuleList()
259
+ for i in range(len(self.ups)):
260
+ ch = upsample_initial_channel // (2 ** (i + 1))
261
+ for j, (k, d) in enumerate(
262
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
263
+ ):
264
+ self.resblocks.append(resblock(ch, k, d))
265
+
266
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
267
+ self.ups.apply(init_weights)
268
+
269
+ if gin_channels != 0:
270
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
271
+
272
+ def forward(self, x, g=None):
273
+ x = self.conv_pre(x)
274
+ if g is not None:
275
+ x = x + self.cond(g)
276
+
277
+ for i in range(self.num_upsamples):
278
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
279
+ x = self.ups[i](x)
280
+ xs = None
281
+ for j in range(self.num_kernels):
282
+ if xs is None:
283
+ xs = self.resblocks[i * self.num_kernels + j](x)
284
+ else:
285
+ xs += self.resblocks[i * self.num_kernels + j](x)
286
+ x = xs / self.num_kernels
287
+ x = F.leaky_relu(x)
288
+ x = self.conv_post(x)
289
+ x = torch.tanh(x)
290
+
291
+ return x
292
+
293
+ def remove_weight_norm(self):
294
+ print("Removing weight norm...")
295
+ for layer in self.ups:
296
+ remove_weight_norm(layer)
297
+ for layer in self.resblocks:
298
+ layer.remove_weight_norm()
299
+
300
+
301
+ class ReferenceEncoder(nn.Module):
302
+ """
303
+ inputs --- [N, Ty/r, n_mels*r] mels
304
+ outputs --- [N, ref_enc_gru_size]
305
+ """
306
+
307
+ def __init__(self, spec_channels, gin_channels=0, layernorm=True):
308
+ super().__init__()
309
+ self.spec_channels = spec_channels
310
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
311
+ K = len(ref_enc_filters)
312
+ filters = [1] + ref_enc_filters
313
+ convs = [
314
+ weight_norm(
315
+ nn.Conv2d(
316
+ in_channels=filters[i],
317
+ out_channels=filters[i + 1],
318
+ kernel_size=(3, 3),
319
+ stride=(2, 2),
320
+ padding=(1, 1),
321
+ )
322
+ )
323
+ for i in range(K)
324
+ ]
325
+ self.convs = nn.ModuleList(convs)
326
+
327
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
328
+ self.gru = nn.GRU(
329
+ input_size=ref_enc_filters[-1] * out_channels,
330
+ hidden_size=256 // 2,
331
+ batch_first=True,
332
+ )
333
+ self.proj = nn.Linear(128, gin_channels)
334
+ if layernorm:
335
+ self.layernorm = nn.LayerNorm(self.spec_channels)
336
+ else:
337
+ self.layernorm = None
338
+
339
+ def forward(self, inputs, mask=None):
340
+ N = inputs.size(0)
341
+
342
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
343
+ if self.layernorm is not None:
344
+ out = self.layernorm(out)
345
+
346
+ for conv in self.convs:
347
+ out = conv(out)
348
+ # out = wn(out)
349
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
350
+
351
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
352
+ T = out.size(1)
353
+ N = out.size(0)
354
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
355
+
356
+ self.gru.flatten_parameters()
357
+ memory, out = self.gru(out) # out --- [1, N, 128]
358
+
359
+ return self.proj(out.squeeze(0))
360
+
361
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
362
+ for i in range(n_convs):
363
+ L = (L - kernel_size + 2 * pad) // stride + 1
364
+ return L
365
+
366
+
367
+ class ResidualCouplingBlock(nn.Module):
368
+ def __init__(self,
369
+ channels,
370
+ hidden_channels,
371
+ kernel_size,
372
+ dilation_rate,
373
+ n_layers,
374
+ n_flows=4,
375
+ gin_channels=0):
376
+ super().__init__()
377
+ self.channels = channels
378
+ self.hidden_channels = hidden_channels
379
+ self.kernel_size = kernel_size
380
+ self.dilation_rate = dilation_rate
381
+ self.n_layers = n_layers
382
+ self.n_flows = n_flows
383
+ self.gin_channels = gin_channels
384
+
385
+ self.flows = nn.ModuleList()
386
+ for i in range(n_flows):
387
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
388
+ self.flows.append(modules.Flip())
389
+
390
+ def forward(self, x, x_mask, g=None, reverse=False):
391
+ if not reverse:
392
+ for flow in self.flows:
393
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
394
+ else:
395
+ for flow in reversed(self.flows):
396
+ x = flow(x, x_mask, g=g, reverse=reverse)
397
+ return x
398
+
399
+ class SynthesizerTrn(nn.Module):
400
+ """
401
+ Synthesizer for Training
402
+ """
403
+
404
+ def __init__(
405
+ self,
406
+ n_vocab,
407
+ spec_channels,
408
+ inter_channels,
409
+ hidden_channels,
410
+ filter_channels,
411
+ n_heads,
412
+ n_layers,
413
+ kernel_size,
414
+ p_dropout,
415
+ resblock,
416
+ resblock_kernel_sizes,
417
+ resblock_dilation_sizes,
418
+ upsample_rates,
419
+ upsample_initial_channel,
420
+ upsample_kernel_sizes,
421
+ n_speakers=256,
422
+ gin_channels=256,
423
+ zero_g=False,
424
+ **kwargs
425
+ ):
426
+ super().__init__()
427
+
428
+ self.dec = Generator(
429
+ inter_channels,
430
+ resblock,
431
+ resblock_kernel_sizes,
432
+ resblock_dilation_sizes,
433
+ upsample_rates,
434
+ upsample_initial_channel,
435
+ upsample_kernel_sizes,
436
+ gin_channels=gin_channels,
437
+ )
438
+ self.enc_q = PosteriorEncoder(
439
+ spec_channels,
440
+ inter_channels,
441
+ hidden_channels,
442
+ 5,
443
+ 1,
444
+ 16,
445
+ gin_channels=gin_channels,
446
+ )
447
+
448
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
449
+
450
+ self.n_speakers = n_speakers
451
+ if n_speakers == 0:
452
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
453
+ else:
454
+ self.enc_p = TextEncoder(n_vocab,
455
+ inter_channels,
456
+ hidden_channels,
457
+ filter_channels,
458
+ n_heads,
459
+ n_layers,
460
+ kernel_size,
461
+ p_dropout)
462
+ self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
463
+ self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
464
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
465
+ self.zero_g = zero_g
466
+
467
+ def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
468
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
469
+ if self.n_speakers > 0:
470
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
471
+ else:
472
+ g = None
473
+
474
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \
475
+ + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
476
+
477
+ w = torch.exp(logw) * x_mask * length_scale
478
+ w_ceil = torch.ceil(w)
479
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
480
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
481
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
482
+ attn = commons.generate_path(w_ceil, attn_mask)
483
+
484
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
485
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
486
+
487
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
488
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
489
+ o = self.dec((z * y_mask)[:,:,:max_len], g=g)
490
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
491
+
492
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
493
+ g_src = sid_src
494
+ g_tgt = sid_tgt
495
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau)
496
+ z_p = self.flow(z, y_mask, g=g_src)
497
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
498
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
499
+ return o_hat, y_mask, (z, z_p, z_hat)
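For reference, `voice_conversion` above is the whole tone-color pipeline: the posterior encoder maps the source spectrogram into the latent space under the source embedding, the flow normalizes it, and the reverse flow plus decoder re-synthesize it under the target embedding. The sketch below only exercises tensor shapes; the hyperparameters are illustrative placeholders rather than the values from the released converter config, and random tensors stand in for a real spectrogram and for the speaker embeddings produced by se_extractor.

import torch
from openvoice.models import SynthesizerTrn

# Placeholder hyperparameters (the real ones come from checkpoints/converter/config.json).
model = SynthesizerTrn(
    n_vocab=0, spec_channels=513, inter_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
    resblock="1", resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5]] * 3, upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512, upsample_kernel_sizes=[16, 16, 4, 4],
    n_speakers=0, gin_channels=256,
).eval()

spec = torch.randn(1, 513, 400)        # stand-in linear spectrogram [B, spec_channels, T]
spec_lengths = torch.LongTensor([400])
src_se = torch.randn(1, 256, 1)        # stand-in source tone-color embedding
tgt_se = torch.randn(1, 256, 1)        # stand-in target tone-color embedding

with torch.no_grad():
    audio, y_mask, _ = model.voice_conversion(spec, spec_lengths, src_se, tgt_se, tau=0.3)
print(audio.shape)                     # [1, 1, T * prod(upsample_rates)]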
OpenVoice/openvoice/modules.py ADDED
@@ -0,0 +1,598 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ from openvoice import commons
10
+ from openvoice.commons import init_weights, get_padding
11
+ from openvoice.transforms import piecewise_rational_quadratic_transform
12
+ from openvoice.attentions import Encoder
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
+ class LayerNorm(nn.Module):
18
+ def __init__(self, channels, eps=1e-5):
19
+ super().__init__()
20
+ self.channels = channels
21
+ self.eps = eps
22
+
23
+ self.gamma = nn.Parameter(torch.ones(channels))
24
+ self.beta = nn.Parameter(torch.zeros(channels))
25
+
26
+ def forward(self, x):
27
+ x = x.transpose(1, -1)
28
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
29
+ return x.transpose(1, -1)
30
+
31
+
32
+ class ConvReluNorm(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ hidden_channels,
37
+ out_channels,
38
+ kernel_size,
39
+ n_layers,
40
+ p_dropout,
41
+ ):
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+ self.hidden_channels = hidden_channels
45
+ self.out_channels = out_channels
46
+ self.kernel_size = kernel_size
47
+ self.n_layers = n_layers
48
+ self.p_dropout = p_dropout
49
+ assert n_layers > 1, "Number of layers should be larger than 1."
50
+
51
+ self.conv_layers = nn.ModuleList()
52
+ self.norm_layers = nn.ModuleList()
53
+ self.conv_layers.append(
54
+ nn.Conv1d(
55
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
56
+ )
57
+ )
58
+ self.norm_layers.append(LayerNorm(hidden_channels))
59
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
60
+ for _ in range(n_layers - 1):
61
+ self.conv_layers.append(
62
+ nn.Conv1d(
63
+ hidden_channels,
64
+ hidden_channels,
65
+ kernel_size,
66
+ padding=kernel_size // 2,
67
+ )
68
+ )
69
+ self.norm_layers.append(LayerNorm(hidden_channels))
70
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
71
+ self.proj.weight.data.zero_()
72
+ self.proj.bias.data.zero_()
73
+
74
+ def forward(self, x, x_mask):
75
+ x_org = x
76
+ for i in range(self.n_layers):
77
+ x = self.conv_layers[i](x * x_mask)
78
+ x = self.norm_layers[i](x)
79
+ x = self.relu_drop(x)
80
+ x = x_org + self.proj(x)
81
+ return x * x_mask
82
+
83
+
84
+ class DDSConv(nn.Module):
85
+ """
86
+ Dilated and Depth-Separable Convolution
87
+ """
88
+
89
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
90
+ super().__init__()
91
+ self.channels = channels
92
+ self.kernel_size = kernel_size
93
+ self.n_layers = n_layers
94
+ self.p_dropout = p_dropout
95
+
96
+ self.drop = nn.Dropout(p_dropout)
97
+ self.convs_sep = nn.ModuleList()
98
+ self.convs_1x1 = nn.ModuleList()
99
+ self.norms_1 = nn.ModuleList()
100
+ self.norms_2 = nn.ModuleList()
101
+ for i in range(n_layers):
102
+ dilation = kernel_size**i
103
+ padding = (kernel_size * dilation - dilation) // 2
104
+ self.convs_sep.append(
105
+ nn.Conv1d(
106
+ channels,
107
+ channels,
108
+ kernel_size,
109
+ groups=channels,
110
+ dilation=dilation,
111
+ padding=padding,
112
+ )
113
+ )
114
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
115
+ self.norms_1.append(LayerNorm(channels))
116
+ self.norms_2.append(LayerNorm(channels))
117
+
118
+ def forward(self, x, x_mask, g=None):
119
+ if g is not None:
120
+ x = x + g
121
+ for i in range(self.n_layers):
122
+ y = self.convs_sep[i](x * x_mask)
123
+ y = self.norms_1[i](y)
124
+ y = F.gelu(y)
125
+ y = self.convs_1x1[i](y)
126
+ y = self.norms_2[i](y)
127
+ y = F.gelu(y)
128
+ y = self.drop(y)
129
+ x = x + y
130
+ return x * x_mask
131
+
132
+
133
+ class WN(torch.nn.Module):
134
+ def __init__(
135
+ self,
136
+ hidden_channels,
137
+ kernel_size,
138
+ dilation_rate,
139
+ n_layers,
140
+ gin_channels=0,
141
+ p_dropout=0,
142
+ ):
143
+ super(WN, self).__init__()
144
+ assert kernel_size % 2 == 1
145
+ self.hidden_channels = hidden_channels
146
+ self.kernel_size = (kernel_size,)
147
+ self.dilation_rate = dilation_rate
148
+ self.n_layers = n_layers
149
+ self.gin_channels = gin_channels
150
+ self.p_dropout = p_dropout
151
+
152
+ self.in_layers = torch.nn.ModuleList()
153
+ self.res_skip_layers = torch.nn.ModuleList()
154
+ self.drop = nn.Dropout(p_dropout)
155
+
156
+ if gin_channels != 0:
157
+ cond_layer = torch.nn.Conv1d(
158
+ gin_channels, 2 * hidden_channels * n_layers, 1
159
+ )
160
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
161
+
162
+ for i in range(n_layers):
163
+ dilation = dilation_rate**i
164
+ padding = int((kernel_size * dilation - dilation) / 2)
165
+ in_layer = torch.nn.Conv1d(
166
+ hidden_channels,
167
+ 2 * hidden_channels,
168
+ kernel_size,
169
+ dilation=dilation,
170
+ padding=padding,
171
+ )
172
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
173
+ self.in_layers.append(in_layer)
174
+
175
+ # last one is not necessary
176
+ if i < n_layers - 1:
177
+ res_skip_channels = 2 * hidden_channels
178
+ else:
179
+ res_skip_channels = hidden_channels
180
+
181
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
182
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
183
+ self.res_skip_layers.append(res_skip_layer)
184
+
185
+ def forward(self, x, x_mask, g=None, **kwargs):
186
+ output = torch.zeros_like(x)
187
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
188
+
189
+ if g is not None:
190
+ g = self.cond_layer(g)
191
+
192
+ for i in range(self.n_layers):
193
+ x_in = self.in_layers[i](x)
194
+ if g is not None:
195
+ cond_offset = i * 2 * self.hidden_channels
196
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
197
+ else:
198
+ g_l = torch.zeros_like(x_in)
199
+
200
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
201
+ acts = self.drop(acts)
202
+
203
+ res_skip_acts = self.res_skip_layers[i](acts)
204
+ if i < self.n_layers - 1:
205
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
206
+ x = (x + res_acts) * x_mask
207
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
208
+ else:
209
+ output = output + res_skip_acts
210
+ return output * x_mask
211
+
212
+ def remove_weight_norm(self):
213
+ if self.gin_channels != 0:
214
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
215
+ for l in self.in_layers:
216
+ torch.nn.utils.remove_weight_norm(l)
217
+ for l in self.res_skip_layers:
218
+ torch.nn.utils.remove_weight_norm(l)
219
+
220
+
221
+ class ResBlock1(torch.nn.Module):
222
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
223
+ super(ResBlock1, self).__init__()
224
+ self.convs1 = nn.ModuleList(
225
+ [
226
+ weight_norm(
227
+ Conv1d(
228
+ channels,
229
+ channels,
230
+ kernel_size,
231
+ 1,
232
+ dilation=dilation[0],
233
+ padding=get_padding(kernel_size, dilation[0]),
234
+ )
235
+ ),
236
+ weight_norm(
237
+ Conv1d(
238
+ channels,
239
+ channels,
240
+ kernel_size,
241
+ 1,
242
+ dilation=dilation[1],
243
+ padding=get_padding(kernel_size, dilation[1]),
244
+ )
245
+ ),
246
+ weight_norm(
247
+ Conv1d(
248
+ channels,
249
+ channels,
250
+ kernel_size,
251
+ 1,
252
+ dilation=dilation[2],
253
+ padding=get_padding(kernel_size, dilation[2]),
254
+ )
255
+ ),
256
+ ]
257
+ )
258
+ self.convs1.apply(init_weights)
259
+
260
+ self.convs2 = nn.ModuleList(
261
+ [
262
+ weight_norm(
263
+ Conv1d(
264
+ channels,
265
+ channels,
266
+ kernel_size,
267
+ 1,
268
+ dilation=1,
269
+ padding=get_padding(kernel_size, 1),
270
+ )
271
+ ),
272
+ weight_norm(
273
+ Conv1d(
274
+ channels,
275
+ channels,
276
+ kernel_size,
277
+ 1,
278
+ dilation=1,
279
+ padding=get_padding(kernel_size, 1),
280
+ )
281
+ ),
282
+ weight_norm(
283
+ Conv1d(
284
+ channels,
285
+ channels,
286
+ kernel_size,
287
+ 1,
288
+ dilation=1,
289
+ padding=get_padding(kernel_size, 1),
290
+ )
291
+ ),
292
+ ]
293
+ )
294
+ self.convs2.apply(init_weights)
295
+
296
+ def forward(self, x, x_mask=None):
297
+ for c1, c2 in zip(self.convs1, self.convs2):
298
+ xt = F.leaky_relu(x, LRELU_SLOPE)
299
+ if x_mask is not None:
300
+ xt = xt * x_mask
301
+ xt = c1(xt)
302
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
303
+ if x_mask is not None:
304
+ xt = xt * x_mask
305
+ xt = c2(xt)
306
+ x = xt + x
307
+ if x_mask is not None:
308
+ x = x * x_mask
309
+ return x
310
+
311
+ def remove_weight_norm(self):
312
+ for l in self.convs1:
313
+ remove_weight_norm(l)
314
+ for l in self.convs2:
315
+ remove_weight_norm(l)
316
+
317
+
318
+ class ResBlock2(torch.nn.Module):
319
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
320
+ super(ResBlock2, self).__init__()
321
+ self.convs = nn.ModuleList(
322
+ [
323
+ weight_norm(
324
+ Conv1d(
325
+ channels,
326
+ channels,
327
+ kernel_size,
328
+ 1,
329
+ dilation=dilation[0],
330
+ padding=get_padding(kernel_size, dilation[0]),
331
+ )
332
+ ),
333
+ weight_norm(
334
+ Conv1d(
335
+ channels,
336
+ channels,
337
+ kernel_size,
338
+ 1,
339
+ dilation=dilation[1],
340
+ padding=get_padding(kernel_size, dilation[1]),
341
+ )
342
+ ),
343
+ ]
344
+ )
345
+ self.convs.apply(init_weights)
346
+
347
+ def forward(self, x, x_mask=None):
348
+ for c in self.convs:
349
+ xt = F.leaky_relu(x, LRELU_SLOPE)
350
+ if x_mask is not None:
351
+ xt = xt * x_mask
352
+ xt = c(xt)
353
+ x = xt + x
354
+ if x_mask is not None:
355
+ x = x * x_mask
356
+ return x
357
+
358
+ def remove_weight_norm(self):
359
+ for l in self.convs:
360
+ remove_weight_norm(l)
361
+
362
+
363
+ class Log(nn.Module):
364
+ def forward(self, x, x_mask, reverse=False, **kwargs):
365
+ if not reverse:
366
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
367
+ logdet = torch.sum(-y, [1, 2])
368
+ return y, logdet
369
+ else:
370
+ x = torch.exp(x) * x_mask
371
+ return x
372
+
373
+
374
+ class Flip(nn.Module):
375
+ def forward(self, x, *args, reverse=False, **kwargs):
376
+ x = torch.flip(x, [1])
377
+ if not reverse:
378
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
379
+ return x, logdet
380
+ else:
381
+ return x
382
+
383
+
384
+ class ElementwiseAffine(nn.Module):
385
+ def __init__(self, channels):
386
+ super().__init__()
387
+ self.channels = channels
388
+ self.m = nn.Parameter(torch.zeros(channels, 1))
389
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
390
+
391
+ def forward(self, x, x_mask, reverse=False, **kwargs):
392
+ if not reverse:
393
+ y = self.m + torch.exp(self.logs) * x
394
+ y = y * x_mask
395
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
396
+ return y, logdet
397
+ else:
398
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
399
+ return x
400
+
401
+
402
+ class ResidualCouplingLayer(nn.Module):
403
+ def __init__(
404
+ self,
405
+ channels,
406
+ hidden_channels,
407
+ kernel_size,
408
+ dilation_rate,
409
+ n_layers,
410
+ p_dropout=0,
411
+ gin_channels=0,
412
+ mean_only=False,
413
+ ):
414
+ assert channels % 2 == 0, "channels should be divisible by 2"
415
+ super().__init__()
416
+ self.channels = channels
417
+ self.hidden_channels = hidden_channels
418
+ self.kernel_size = kernel_size
419
+ self.dilation_rate = dilation_rate
420
+ self.n_layers = n_layers
421
+ self.half_channels = channels // 2
422
+ self.mean_only = mean_only
423
+
424
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
425
+ self.enc = WN(
426
+ hidden_channels,
427
+ kernel_size,
428
+ dilation_rate,
429
+ n_layers,
430
+ p_dropout=p_dropout,
431
+ gin_channels=gin_channels,
432
+ )
433
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
434
+ self.post.weight.data.zero_()
435
+ self.post.bias.data.zero_()
436
+
437
+ def forward(self, x, x_mask, g=None, reverse=False):
438
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
439
+ h = self.pre(x0) * x_mask
440
+ h = self.enc(h, x_mask, g=g)
441
+ stats = self.post(h) * x_mask
442
+ if not self.mean_only:
443
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
444
+ else:
445
+ m = stats
446
+ logs = torch.zeros_like(m)
447
+
448
+ if not reverse:
449
+ x1 = m + x1 * torch.exp(logs) * x_mask
450
+ x = torch.cat([x0, x1], 1)
451
+ logdet = torch.sum(logs, [1, 2])
452
+ return x, logdet
453
+ else:
454
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
455
+ x = torch.cat([x0, x1], 1)
456
+ return x
457
+
458
+
459
+ class ConvFlow(nn.Module):
460
+ def __init__(
461
+ self,
462
+ in_channels,
463
+ filter_channels,
464
+ kernel_size,
465
+ n_layers,
466
+ num_bins=10,
467
+ tail_bound=5.0,
468
+ ):
469
+ super().__init__()
470
+ self.in_channels = in_channels
471
+ self.filter_channels = filter_channels
472
+ self.kernel_size = kernel_size
473
+ self.n_layers = n_layers
474
+ self.num_bins = num_bins
475
+ self.tail_bound = tail_bound
476
+ self.half_channels = in_channels // 2
477
+
478
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
479
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
480
+ self.proj = nn.Conv1d(
481
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
482
+ )
483
+ self.proj.weight.data.zero_()
484
+ self.proj.bias.data.zero_()
485
+
486
+ def forward(self, x, x_mask, g=None, reverse=False):
487
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
488
+ h = self.pre(x0)
489
+ h = self.convs(h, x_mask, g=g)
490
+ h = self.proj(h) * x_mask
491
+
492
+ b, c, t = x0.shape
493
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
494
+
495
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
496
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
497
+ self.filter_channels
498
+ )
499
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
500
+
501
+ x1, logabsdet = piecewise_rational_quadratic_transform(
502
+ x1,
503
+ unnormalized_widths,
504
+ unnormalized_heights,
505
+ unnormalized_derivatives,
506
+ inverse=reverse,
507
+ tails="linear",
508
+ tail_bound=self.tail_bound,
509
+ )
510
+
511
+ x = torch.cat([x0, x1], 1) * x_mask
512
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
513
+ if not reverse:
514
+ return x, logdet
515
+ else:
516
+ return x
517
+
518
+
519
+ class TransformerCouplingLayer(nn.Module):
520
+ def __init__(
521
+ self,
522
+ channels,
523
+ hidden_channels,
524
+ kernel_size,
525
+ n_layers,
526
+ n_heads,
527
+ p_dropout=0,
528
+ filter_channels=0,
529
+ mean_only=False,
530
+ wn_sharing_parameter=None,
531
+ gin_channels=0,
532
+ ):
533
+ assert n_layers == 3, n_layers
534
+ assert channels % 2 == 0, "channels should be divisible by 2"
535
+ super().__init__()
536
+ self.channels = channels
537
+ self.hidden_channels = hidden_channels
538
+ self.kernel_size = kernel_size
539
+ self.n_layers = n_layers
540
+ self.half_channels = channels // 2
541
+ self.mean_only = mean_only
542
+
543
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
544
+ self.enc = (
545
+ Encoder(
546
+ hidden_channels,
547
+ filter_channels,
548
+ n_heads,
549
+ n_layers,
550
+ kernel_size,
551
+ p_dropout,
552
+ isflow=True,
553
+ gin_channels=gin_channels,
554
+ )
555
+ if wn_sharing_parameter is None
556
+ else wn_sharing_parameter
557
+ )
558
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
559
+ self.post.weight.data.zero_()
560
+ self.post.bias.data.zero_()
561
+
562
+ def forward(self, x, x_mask, g=None, reverse=False):
563
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
564
+ h = self.pre(x0) * x_mask
565
+ h = self.enc(h, x_mask, g=g)
566
+ stats = self.post(h) * x_mask
567
+ if not self.mean_only:
568
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
569
+ else:
570
+ m = stats
571
+ logs = torch.zeros_like(m)
572
+
573
+ if not reverse:
574
+ x1 = m + x1 * torch.exp(logs) * x_mask
575
+ x = torch.cat([x0, x1], 1)
576
+ logdet = torch.sum(logs, [1, 2])
577
+ return x, logdet
578
+ else:
579
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
580
+ x = torch.cat([x0, x1], 1)
581
+ return x
582
+
583
+ x1, logabsdet = piecewise_rational_quadratic_transform(
584
+ x1,
585
+ unnormalized_widths,
586
+ unnormalized_heights,
587
+ unnormalized_derivatives,
588
+ inverse=reverse,
589
+ tails="linear",
590
+ tail_bound=self.tail_bound,
591
+ )
592
+
593
+ x = torch.cat([x0, x1], 1) * x_mask
594
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
595
+ if not reverse:
596
+ return x, logdet
597
+ else:
598
+ return x
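As a sanity check on the coupling layers defined above: with `mean_only=True` and a full mask, `ResidualCouplingLayer` is exactly invertible, and because `self.post` is zero-initialized it starts out as the identity map. A minimal round-trip sketch, assuming the `openvoice` package is importable:

import torch
from openvoice.modules import ResidualCouplingLayer

layer = ResidualCouplingLayer(channels=192, hidden_channels=192, kernel_size=5,
                              dilation_rate=1, n_layers=4, mean_only=True)
x = torch.randn(2, 192, 50)
x_mask = torch.ones(2, 1, 50)

y, logdet = layer(x, x_mask)                # forward direction returns (y, logdet)
x_rec = layer(y, x_mask, reverse=True)      # reverse direction returns only x
print(torch.allclose(x, x_rec, atol=1e-4))  # True: the transform round-trips
print(logdet)                               # zeros here, since mean_only fixes logs to 0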
OpenVoice/openvoice/openvoice_app.py ADDED
@@ -0,0 +1,275 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import gradio as gr
5
+ from zipfile import ZipFile
6
+ import langid
7
+ from openvoice import se_extractor
8
+ from openvoice.api import BaseSpeakerTTS, ToneColorConverter
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--share", action='store_true', default=False, help="make link public")
12
+ args = parser.parse_args()
13
+
14
+ en_ckpt_base = 'checkpoints/base_speakers/EN'
15
+ zh_ckpt_base = 'checkpoints/base_speakers/ZH'
16
+ ckpt_converter = 'checkpoints/converter'
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+ output_dir = 'outputs'
19
+ os.makedirs(output_dir, exist_ok=True)
20
+
21
+ # load models
22
+ en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
23
+ en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
24
+ zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
25
+ zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
26
+ tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
27
+ tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
28
+
29
+ # load speaker embeddings
30
+ en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
31
+ en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
32
+ zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
33
+
34
+ # This online demo mainly supports English and Chinese
35
+ supported_languages = ['zh', 'en']
36
+
37
+ def predict(prompt, style, audio_file_pth, agree):
38
+ # initialize a empty info
39
+ text_hint = ''
40
+ # agree with the terms
41
+ if agree == False:
42
+ text_hint += '[ERROR] Please accept the Terms & Condition!\n'
43
+ gr.Warning("Please accept the Terms & Condition!")
44
+ return (
45
+ text_hint,
46
+ None,
47
+ None,
48
+ )
49
+
50
+ # first detect the input language
51
+ language_predicted = langid.classify(prompt)[0].strip()
52
+ print(f"Detected language:{language_predicted}")
53
+
54
+ if language_predicted not in supported_languages:
55
+ text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n"
56
+ gr.Warning(
57
+ f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}"
58
+ )
59
+
60
+ return (
61
+ text_hint,
62
+ None,
63
+ None,
64
+ )
65
+
66
+ if language_predicted == "zh":
67
+ tts_model = zh_base_speaker_tts
68
+ source_se = zh_source_se
69
+ language = 'Chinese'
70
+ if style not in ['default']:
71
+ text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n"
72
+ gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']")
73
+ return (
74
+ text_hint,
75
+ None,
76
+ None,
77
+ )
78
+
79
+ else:
80
+ tts_model = en_base_speaker_tts
81
+ if style == 'default':
82
+ source_se = en_source_default_se
83
+ else:
84
+ source_se = en_source_style_se
85
+ language = 'English'
86
+ if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']:
87
+ text_hint += f"[ERROR] The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n"
88
+ gr.Warning(f"The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']")
89
+ return (
90
+ text_hint,
91
+ None,
92
+ None,
93
+ )
94
+
95
+ speaker_wav = audio_file_pth
96
+
97
+ if len(prompt) < 2:
98
+ text_hint += f"[ERROR] Please give a longer prompt text \n"
99
+ gr.Warning("Please give a longer prompt text")
100
+ return (
101
+ text_hint,
102
+ None,
103
+ None,
104
+ )
105
+ if len(prompt) > 200:
106
+ text_hint += f"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n"
107
+ gr.Warning(
108
+ "Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage"
109
+ )
110
+ return (
111
+ text_hint,
112
+ None,
113
+ None,
114
+ )
115
+
116
+ # note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference
117
+ try:
118
+ target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir='processed', vad=True)
119
+ except Exception as e:
120
+ text_hint += f"[ERROR] Get target tone color error {str(e)} \n"
121
+ gr.Warning(
122
+ f"[ERROR] Get target tone color error {str(e)} \n"
123
+ )
124
+ return (
125
+ text_hint,
126
+ None,
127
+ None,
128
+ )
129
+
130
+ src_path = f'{output_dir}/tmp.wav'
131
+ tts_model.tts(prompt, src_path, speaker=style, language=language)
132
+
133
+ save_path = f'{output_dir}/output.wav'
134
+ # Run the tone color converter
135
+ encode_message = "@MyShell"
136
+ tone_color_converter.convert(
137
+ audio_src_path=src_path,
138
+ src_se=source_se,
139
+ tgt_se=target_se,
140
+ output_path=save_path,
141
+ message=encode_message)
142
+
143
+ text_hint += f'''Get response successfully \n'''
144
+
145
+ return (
146
+ text_hint,
147
+ save_path,
148
+ speaker_wav,
149
+ )
150
+
151
+
152
+
153
+ title = "MyShell OpenVoice"
154
+
155
+ description = """
156
+ We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set.
157
+ """
158
+
159
+ markdown_table = """
160
+ <div align="center" style="margin-bottom: 10px;">
161
+
162
+ | | | |
163
+ | :-----------: | :-----------: | :-----------: |
164
+ | **OpenSource Repo** | **Project Page** | **Join the Community** |
165
+ | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | [OpenVoice](https://research.myshell.ai/open-voice) | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) |
166
+
167
+ </div>
168
+ """
169
+
170
+ markdown_table_v2 = """
171
+ <div align="center" style="margin-bottom: 2px;">
172
+
173
+ | | | | |
174
+ | :-----------: | :-----------: | :-----------: | :-----------: |
175
+ | **OpenSource Repo** | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) |
176
+
177
+ | | |
178
+ | :-----------: | :-----------: |
179
+ **Join the Community** | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) |
180
+
181
+ </div>
182
+ """
183
+ content = """
184
+ <div>
185
+ <strong>If the generated voice does not sound like the reference voice, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/docs/QA.md'>this QnA</a>.</strong> <strong>For multi-lingual & cross-lingual examples, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/demo_part2.ipynb'>this jupyter notebook</a>.</strong>
186
+ This online demo mainly supports <strong>English</strong>. The <em>default</em> style also supports <strong>Chinese</strong>. But OpenVoice can adapt to any other language as long as a base speaker is provided.
187
+ </div>
188
+ """
189
+ wrapped_markdown_content = f"<div style='border: 1px solid #000; padding: 10px;'>{content}</div>"
190
+
191
+
192
+ examples = [
193
+ [
194
+ "今天天气真好,我们一起出去吃饭吧。",
195
+ 'default',
196
+ "resources/demo_speaker1.mp3",
197
+ True,
198
+ ],[
199
+ "This audio is generated by open voice with a half-performance model.",
200
+ 'whispering',
201
+ "resources/demo_speaker2.mp3",
202
+ True,
203
+ ],
204
+ [
205
+ "He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
206
+ 'sad',
207
+ "resources/demo_speaker0.mp3",
208
+ True,
209
+ ],
210
+ ]
211
+
212
+ with gr.Blocks(analytics_enabled=False) as demo:
213
+
214
+ with gr.Row():
215
+ with gr.Column():
216
+ with gr.Row():
217
+ gr.Markdown(
218
+ """
219
+ ## <img src="https://huggingface.co/spaces/myshell-ai/OpenVoice/raw/main/logo.jpg" height="40"/>
220
+ """
221
+ )
222
+ with gr.Row():
223
+ gr.Markdown(markdown_table_v2)
224
+ with gr.Row():
225
+ gr.Markdown(description)
226
+ with gr.Column():
227
+ gr.Video('https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f', autoplay=True)
228
+
229
+ with gr.Row():
230
+ gr.HTML(wrapped_markdown_content)
231
+
232
+ with gr.Row():
233
+ with gr.Column():
234
+ input_text_gr = gr.Textbox(
235
+ label="Text Prompt",
236
+ info="One or two sentences at a time is better. Up to 200 text characters.",
237
+ value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
238
+ )
239
+ style_gr = gr.Dropdown(
240
+ label="Style",
241
+ info="Select a style of output audio for the synthesised speech. (Chinese only support 'default' now)",
242
+ choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'],
243
+ max_choices=1,
244
+ value="default",
245
+ )
246
+ ref_gr = gr.Audio(
247
+ label="Reference Audio",
248
+ info="Click on the ✎ button to upload your own target speaker audio",
249
+ type="filepath",
250
+ value="resources/demo_speaker2.mp3",
251
+ )
252
+ tos_gr = gr.Checkbox(
253
+ label="Agree",
254
+ value=False,
255
+ info="I agree to the terms of the cc-by-nc-4.0 license-: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE",
256
+ )
257
+
258
+ tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
259
+
260
+
261
+ with gr.Column():
262
+ out_text_gr = gr.Text(label="Info")
263
+ audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
264
+ ref_audio_gr = gr.Audio(label="Reference Audio Used")
265
+
266
+ gr.Examples(examples,
267
+ label="Examples",
268
+ inputs=[input_text_gr, style_gr, ref_gr, tos_gr],
269
+ outputs=[out_text_gr, audio_gr, ref_audio_gr],
270
+ fn=predict,
271
+ cache_examples=False,)
272
+ tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_gr, ref_audio_gr])
273
+
274
+ demo.queue()
275
+ demo.launch(debug=True, show_api=True, share=args.share)
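The Gradio handler above boils down to a two-stage pipeline: base-speaker TTS into a temporary wav, then tone-color conversion toward the reference speaker. A minimal script-level sketch of the same calls, assuming the English base-speaker and converter checkpoints have been downloaded to the paths used at the top of this file:

import os
import torch
from openvoice import se_extractor
from openvoice.api import BaseSpeakerTTS, ToneColorConverter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
os.makedirs('outputs', exist_ok=True)

tts = BaseSpeakerTTS('checkpoints/base_speakers/EN/config.json', device=device)
tts.load_ckpt('checkpoints/base_speakers/EN/checkpoint.pth')
converter = ToneColorConverter('checkpoints/converter/config.json', device=device)
converter.load_ckpt('checkpoints/converter/checkpoint.pth')

source_se = torch.load('checkpoints/base_speakers/EN/en_default_se.pth').to(device)
target_se, _ = se_extractor.get_se('resources/demo_speaker2.mp3', converter,
                                   target_dir='processed', vad=True)

# Stage 1: synthesize with the base speaker; Stage 2: convert the tone color.
tts.tts('Hello from OpenVoice.', 'outputs/tmp.wav', speaker='default', language='English')
converter.convert(audio_src_path='outputs/tmp.wav', src_se=source_se, tgt_se=target_se,
                  output_path='outputs/output.wav', message='@MyShell')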
OpenVoice/openvoice/se_extractor.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import glob
3
+ import torch
4
+ import hashlib
5
+ import librosa
6
+ import base64
7
+ from glob import glob
8
+ import numpy as np
9
+ from pydub import AudioSegment
10
+ import hashlib
11
+ import base64
12
+ import librosa
13
+ from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments
14
+
15
+ model_size = "medium"
16
+ # Run on GPU with FP16
17
+ model = None
18
+ def split_audio_whisper(audio_path, audio_name, target_dir='processed'):
19
+ global model
20
+ if model is None:
21
+ # Lazy import to avoid loading faster-whisper on environments where it is unsupported
22
+ from faster_whisper import WhisperModel # type: ignore
23
+ model = WhisperModel(model_size, device="cuda", compute_type="float16")
24
+ audio = AudioSegment.from_file(audio_path)
25
+ max_len = len(audio)
26
+
27
+ target_folder = os.path.join(target_dir, audio_name)
28
+
29
+ segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
30
+ segments = list(segments)
31
+
32
+ # create directory
33
+ os.makedirs(target_folder, exist_ok=True)
34
+ wavs_folder = os.path.join(target_folder, 'wavs')
35
+ os.makedirs(wavs_folder, exist_ok=True)
36
+
37
+ # segments
38
+ s_ind = 0
39
+ start_time = None
40
+
41
+ for k, w in enumerate(segments):
42
+ # process with the time
43
+ if k == 0:
44
+ start_time = max(0, w.start)
45
+
46
+ end_time = w.end
47
+
48
+ # calculate confidence
49
+ if len(w.words) > 0:
50
+ confidence = sum([s.probability for s in w.words]) / len(w.words)
51
+ else:
52
+ confidence = 0.
53
+ # clean text
54
+ text = w.text.replace('...', '')
55
+
56
+ # left 0.08s for each audios
57
+ audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]
58
+
59
+ # segment file name
60
+ fname = f"{audio_name}_seg{s_ind}.wav"
61
+
62
+ # filter out the segment shorter than 1.5s and longer than 20s
63
+ save = audio_seg.duration_seconds > 1.5 and \
64
+ audio_seg.duration_seconds < 20. and \
65
+ len(text) >= 2 and len(text) < 200
66
+
67
+ if save:
68
+ output_file = os.path.join(wavs_folder, fname)
69
+ audio_seg.export(output_file, format='wav')
70
+
71
+ if k < len(segments) - 1:
72
+ start_time = max(0, segments[k+1].start - 0.08)
73
+
74
+ s_ind = s_ind + 1
75
+ return wavs_folder
76
+
77
+
78
+ def split_audio_vad(audio_path, audio_name, target_dir, split_seconds=10.0):
79
+ SAMPLE_RATE = 16000
80
+ audio_vad = get_audio_tensor(audio_path)
81
+ segments = get_vad_segments(
82
+ audio_vad,
83
+ output_sample=True,
84
+ min_speech_duration=0.1,
85
+ min_silence_duration=1,
86
+ method="silero",
87
+ )
88
+ segments = [(seg["start"], seg["end"]) for seg in segments]
89
+ segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments]
90
+ print(segments)
91
+ audio_active = AudioSegment.silent(duration=0)
92
+ audio = AudioSegment.from_file(audio_path)
93
+
94
+ for start_time, end_time in segments:
95
+ audio_active += audio[int( start_time * 1000) : int(end_time * 1000)]
96
+
97
+ audio_dur = audio_active.duration_seconds
98
+ print(f'after vad: dur = {audio_dur}')
99
+ target_folder = os.path.join(target_dir, audio_name)
100
+ wavs_folder = os.path.join(target_folder, 'wavs')
101
+ os.makedirs(wavs_folder, exist_ok=True)
102
+ start_time = 0.
103
+ count = 0
104
+ num_splits = int(np.round(audio_dur / split_seconds))
105
+ assert num_splits > 0, 'input audio is too short'
106
+ interval = audio_dur / num_splits
107
+
108
+ for i in range(num_splits):
109
+ end_time = min(start_time + interval, audio_dur)
110
+ if i == num_splits - 1:
111
+ end_time = audio_dur
112
+ output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
113
+ audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
114
+ audio_seg.export(output_file, format='wav')
115
+ start_time = end_time
116
+ count += 1
117
+ return wavs_folder
118
+
119
+ def hash_numpy_array(audio_path):
120
+ array, _ = librosa.load(audio_path, sr=None, mono=True)
121
+ # Convert the array to bytes
122
+ array_bytes = array.tobytes()
123
+ # Calculate the hash of the array bytes
124
+ hash_object = hashlib.sha256(array_bytes)
125
+ hash_value = hash_object.digest()
126
+ # Convert the hash value to base64
127
+ base64_value = base64.b64encode(hash_value)
128
+ return base64_value.decode('utf-8')[:16].replace('/', '_^')
129
+
130
+ def get_se(audio_path, vc_model, target_dir='processed', vad=True):
131
+ device = vc_model.device
132
+ version = vc_model.version
133
+ print("OpenVoice version:", version)
134
+
135
+ audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}"
136
+ se_path = os.path.join(target_dir, audio_name, 'se.pth')
137
+
138
+ # if os.path.isfile(se_path):
139
+ # se = torch.load(se_path).to(device)
140
+ # return se, audio_name
141
+ # if os.path.isdir(audio_path):
142
+ # wavs_folder = audio_path
143
+
144
+ if vad:
145
+ wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
146
+ else:
147
+ wavs_folder = split_audio_whisper(audio_path, target_dir=target_dir, audio_name=audio_name)
148
+
149
+ audio_segs = glob(f'{wavs_folder}/*.wav')
150
+ if len(audio_segs) == 0:
151
+ raise NotImplementedError('No audio segments found!')
152
+
153
+ return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
154
+
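A note on caching: `get_se` derives its working folder name from the clip's filename, the converter version, and a hash of the decoded waveform, so repeated runs on the same reference clip land in the same `processed/<audio_name>` directory. A small sketch of the cache key, assuming `resources/demo_speaker0.mp3` exists locally:

from openvoice.se_extractor import hash_numpy_array

# First 16 characters of the base64-encoded SHA-256 digest of the raw samples,
# with '/' substituted so the result is safe to use in a folder name.
key = hash_numpy_array('resources/demo_speaker0.mp3')
print(key)

Also worth noting: with `vad=False` the whisper-based splitter lazily builds a `WhisperModel` on CUDA with float16, so the VAD path is the only one that works on CPU-only machines.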
OpenVoice/openvoice/text/__init__.py ADDED
@@ -0,0 +1,79 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+ from openvoice.text import cleaners
3
+ from openvoice.text.symbols import symbols
4
+
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+
11
+ def text_to_sequence(text, symbols, cleaner_names):
12
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13
+ Args:
14
+ text: string to convert to a sequence
15
+ cleaner_names: names of the cleaner functions to run the text through
16
+ Returns:
17
+ List of integers corresponding to the symbols in the text
18
+ '''
19
+ sequence = []
20
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
21
+ clean_text = _clean_text(text, cleaner_names)
22
+ print(clean_text)
23
+ print(f" length:{len(clean_text)}")
24
+ for symbol in clean_text:
25
+ if symbol not in symbol_to_id.keys():
26
+ continue
27
+ symbol_id = symbol_to_id[symbol]
28
+ sequence += [symbol_id]
29
+ print(f" length:{len(sequence)}")
30
+ return sequence
31
+
32
+
33
+ def cleaned_text_to_sequence(cleaned_text, symbols):
34
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
35
+ Args:
36
+ text: string to convert to a sequence
37
+ Returns:
38
+ List of integers corresponding to the symbols in the text
39
+ '''
40
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
41
+ sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()]
42
+ return sequence
43
+
44
+
45
+
46
+ from openvoice.text.symbols import language_tone_start_map
47
+ def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages):
48
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
49
+ Args:
50
+ text: string to convert to a sequence
51
+ Returns:
52
+ List of integers corresponding to the symbols in the text
53
+ """
54
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
55
+ language_id_map = {s: i for i, s in enumerate(languages)}
56
+ phones = [symbol_to_id[symbol] for symbol in cleaned_text]
57
+ tone_start = language_tone_start_map[language]
58
+ tones = [i + tone_start for i in tones]
59
+ lang_id = language_id_map[language]
60
+ lang_ids = [lang_id for i in phones]
61
+ return phones, tones, lang_ids
62
+
63
+
64
+ def sequence_to_text(sequence):
65
+ '''Converts a sequence of IDs back to a string'''
66
+ result = ''
67
+ for symbol_id in sequence:
68
+ s = _id_to_symbol[symbol_id]
69
+ result += s
70
+ return result
71
+
72
+
73
+ def _clean_text(text, cleaner_names):
74
+ for name in cleaner_names:
75
+ cleaner = getattr(cleaners, name)
76
+ if not cleaner:
77
+ raise Exception('Unknown cleaner: %s' % name)
78
+ text = cleaner(text)
79
+ return text
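A usage sketch for the mapping above: `text_to_sequence` first runs the named cleaners, then keeps only characters present in the symbol set. With this repo's `cjke_cleaners2` the input is wrapped in language tags (the function also prints the cleaned text and sequence length, as seen in the code above):

from openvoice.text import text_to_sequence, sequence_to_text
from openvoice.text.symbols import symbols

seq = text_to_sequence('[EN]Hello world.[EN]', symbols, ['cjke_cleaners2'])
print(seq[:10])               # integer symbol IDs
print(sequence_to_text(seq))  # maps the IDs back to the cleaned IPA string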
OpenVoice/openvoice/text/cleaners.py ADDED
@@ -0,0 +1,16 @@
1
+ import re
2
+ from openvoice.text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
3
+ from openvoice.text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
4
+
5
+ def cjke_cleaners2(text):
6
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
7
+ lambda x: chinese_to_ipa(x.group(1))+' ', text)
8
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
9
+ lambda x: japanese_to_ipa2(x.group(1))+' ', text)
10
+ text = re.sub(r'\[KO\](.*?)\[KO\]',
11
+ lambda x: korean_to_ipa(x.group(1))+' ', text)
12
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
13
+ lambda x: english_to_ipa2(x.group(1))+' ', text)
14
+ text = re.sub(r'\s+$', '', text)
15
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
16
+ return text
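Note that only the `[ZH]` and `[EN]` branches are usable with the imports in this file; `japanese_to_ipa2` and `korean_to_ipa` are referenced but never imported, so `[JA]`/`[KO]` tags would raise a NameError. A quick sketch with the supported tags, assuming `eng_to_ipa`, `jieba`, `pypinyin`, and `cn2an` are installed (they are required by the modules imported above):

from openvoice.text.cleaners import cjke_cleaners2

print(cjke_cleaners2('[EN]OpenVoice clones voices.[EN]'))
print(cjke_cleaners2('[ZH]今天天气真好。[ZH]'))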
OpenVoice/openvoice/text/english.py ADDED
@@ -0,0 +1,188 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ '''
4
+ Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ 1. "english_cleaners" for English text
9
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ the symbols in symbols.py to match your data).
13
+ '''
14
+
15
+
16
+ # Regular expression matching whitespace:
17
+
18
+
19
+ import re
20
+ import inflect
21
+ from unidecode import unidecode
22
+ import eng_to_ipa as ipa
23
+ _inflect = inflect.engine()
24
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
25
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
26
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
27
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
28
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
29
+ _number_re = re.compile(r'[0-9]+')
30
+
31
+ # List of (regular expression, replacement) pairs for abbreviations:
32
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
33
+ ('mrs', 'misess'),
34
+ ('mr', 'mister'),
35
+ ('dr', 'doctor'),
36
+ ('st', 'saint'),
37
+ ('co', 'company'),
38
+ ('jr', 'junior'),
39
+ ('maj', 'major'),
40
+ ('gen', 'general'),
41
+ ('drs', 'doctors'),
42
+ ('rev', 'reverend'),
43
+ ('lt', 'lieutenant'),
44
+ ('hon', 'honorable'),
45
+ ('sgt', 'sergeant'),
46
+ ('capt', 'captain'),
47
+ ('esq', 'esquire'),
48
+ ('ltd', 'limited'),
49
+ ('col', 'colonel'),
50
+ ('ft', 'fort'),
51
+ ]]
52
+
53
+
54
+ # List of (ipa, lazy ipa) pairs:
55
+ _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
56
+ ('r', 'ɹ'),
57
+ ('æ', 'e'),
58
+ ('ɑ', 'a'),
59
+ ('ɔ', 'o'),
60
+ ('ð', 'z'),
61
+ ('θ', 's'),
62
+ ('ɛ', 'e'),
63
+ ('ɪ', 'i'),
64
+ ('ʊ', 'u'),
65
+ ('ʒ', 'ʥ'),
66
+ ('ʤ', 'ʥ'),
67
+ ('ˈ', '↓'),
68
+ ]]
69
+
70
+ # List of (ipa, lazy ipa2) pairs:
71
+ _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
72
+ ('r', 'ɹ'),
73
+ ('ð', 'z'),
74
+ ('θ', 's'),
75
+ ('ʒ', 'ʑ'),
76
+ ('ʤ', 'dʑ'),
77
+ ('ˈ', '↓'),
78
+ ]]
79
+
80
+ # List of (ipa, ipa2) pairs
81
+ _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
82
+ ('r', 'ɹ'),
83
+ ('ʤ', 'dʒ'),
84
+ ('ʧ', 'tʃ')
85
+ ]]
86
+
87
+
88
+ def expand_abbreviations(text):
89
+ for regex, replacement in _abbreviations:
90
+ text = re.sub(regex, replacement, text)
91
+ return text
92
+
93
+
94
+ def collapse_whitespace(text):
95
+ return re.sub(r'\s+', ' ', text)
96
+
97
+
98
+ def _remove_commas(m):
99
+ return m.group(1).replace(',', '')
100
+
101
+
102
+ def _expand_decimal_point(m):
103
+ return m.group(1).replace('.', ' point ')
104
+
105
+
106
+ def _expand_dollars(m):
107
+ match = m.group(1)
108
+ parts = match.split('.')
109
+ if len(parts) > 2:
110
+ return match + ' dollars' # Unexpected format
111
+ dollars = int(parts[0]) if parts[0] else 0
112
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
113
+ if dollars and cents:
114
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
115
+ cent_unit = 'cent' if cents == 1 else 'cents'
116
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
117
+ elif dollars:
118
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
119
+ return '%s %s' % (dollars, dollar_unit)
120
+ elif cents:
121
+ cent_unit = 'cent' if cents == 1 else 'cents'
122
+ return '%s %s' % (cents, cent_unit)
123
+ else:
124
+ return 'zero dollars'
125
+
126
+
127
+ def _expand_ordinal(m):
128
+ return _inflect.number_to_words(m.group(0))
129
+
130
+
131
+ def _expand_number(m):
132
+ num = int(m.group(0))
133
+ if num > 1000 and num < 3000:
134
+ if num == 2000:
135
+ return 'two thousand'
136
+ elif num > 2000 and num < 2010:
137
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
138
+ elif num % 100 == 0:
139
+ return _inflect.number_to_words(num // 100) + ' hundred'
140
+ else:
141
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
142
+ else:
143
+ return _inflect.number_to_words(num, andword='')
144
+
145
+
146
+ def normalize_numbers(text):
147
+ text = re.sub(_comma_number_re, _remove_commas, text)
148
+ text = re.sub(_pounds_re, r'\1 pounds', text)
149
+ text = re.sub(_dollars_re, _expand_dollars, text)
150
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
151
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
152
+ text = re.sub(_number_re, _expand_number, text)
153
+ return text
154
+
155
+
156
+ def mark_dark_l(text):
157
+ return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
158
+
159
+
160
+ def english_to_ipa(text):
161
+ text = unidecode(text).lower()
162
+ text = expand_abbreviations(text)
163
+ text = normalize_numbers(text)
164
+ phonemes = ipa.convert(text)
165
+ phonemes = collapse_whitespace(phonemes)
166
+ return phonemes
167
+
168
+
169
+ def english_to_lazy_ipa(text):
170
+ text = english_to_ipa(text)
171
+ for regex, replacement in _lazy_ipa:
172
+ text = re.sub(regex, replacement, text)
173
+ return text
174
+
175
+
176
+ def english_to_ipa2(text):
177
+ text = english_to_ipa(text)
178
+ text = mark_dark_l(text)
179
+ for regex, replacement in _ipa_to_ipa2:
180
+ text = re.sub(regex, replacement, text)
181
+ return text.replace('...', '…')
182
+
183
+
184
+ def english_to_lazy_ipa2(text):
185
+ text = english_to_ipa(text)
186
+ for regex, replacement in _lazy_ipa2:
187
+ text = re.sub(regex, replacement, text)
188
+ return text
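For example, abbreviations and numbers are expanded to spoken form before the IPA conversion, and `english_to_ipa2` additionally marks dark l and rewrites a few consonants. A small sketch, assuming `eng_to_ipa`, `inflect`, and `unidecode` are installed as imported at the top of this file:

from openvoice.text.english import english_to_ipa2, english_to_lazy_ipa2

text = 'Dr. Smith paid $5.50 on March 3rd.'
print(english_to_ipa2(text))       # IPA with dark-l marking and the ɹ/dʒ/tʃ substitutions
print(english_to_lazy_ipa2(text))  # simplified 'lazy' IPA variant used by some cleaners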
OpenVoice/openvoice/text/mandarin.py ADDED
@@ -0,0 +1,326 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from pypinyin import lazy_pinyin, BOPOMOFO
5
+ import jieba
6
+ import cn2an
7
+ import logging
8
+
9
+
10
+ # List of (Latin alphabet, bopomofo) pairs:
11
+ _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
12
+ ('a', 'ㄟˉ'),
13
+ ('b', 'ㄅㄧˋ'),
14
+ ('c', 'ㄙㄧˉ'),
15
+ ('d', 'ㄉㄧˋ'),
16
+ ('e', 'ㄧˋ'),
17
+ ('f', 'ㄝˊㄈㄨˋ'),
18
+ ('g', 'ㄐㄧˋ'),
19
+ ('h', 'ㄝˇㄑㄩˋ'),
20
+ ('i', 'ㄞˋ'),
21
+ ('j', 'ㄐㄟˋ'),
22
+ ('k', 'ㄎㄟˋ'),
23
+ ('l', 'ㄝˊㄛˋ'),
24
+ ('m', 'ㄝˊㄇㄨˋ'),
25
+ ('n', 'ㄣˉ'),
26
+ ('o', 'ㄡˉ'),
27
+ ('p', 'ㄆㄧˉ'),
28
+ ('q', 'ㄎㄧㄡˉ'),
29
+ ('r', 'ㄚˋ'),
30
+ ('s', 'ㄝˊㄙˋ'),
31
+ ('t', 'ㄊㄧˋ'),
32
+ ('u', 'ㄧㄡˉ'),
33
+ ('v', 'ㄨㄧˉ'),
34
+ ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
35
+ ('x', 'ㄝˉㄎㄨˋㄙˋ'),
36
+ ('y', 'ㄨㄞˋ'),
37
+ ('z', 'ㄗㄟˋ')
38
+ ]]
39
+
40
+ # List of (bopomofo, romaji) pairs:
41
+ _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
42
+ ('ㄅㄛ', 'p⁼wo'),
43
+ ('ㄆㄛ', 'pʰwo'),
44
+ ('ㄇㄛ', 'mwo'),
45
+ ('ㄈㄛ', 'fwo'),
46
+ ('ㄅ', 'p⁼'),
47
+ ('ㄆ', 'pʰ'),
48
+ ('ㄇ', 'm'),
49
+ ('ㄈ', 'f'),
50
+ ('ㄉ', 't⁼'),
51
+ ('ㄊ', 'tʰ'),
52
+ ('ㄋ', 'n'),
53
+ ('ㄌ', 'l'),
54
+ ('ㄍ', 'k⁼'),
55
+ ('ㄎ', 'kʰ'),
56
+ ('ㄏ', 'h'),
57
+ ('ㄐ', 'ʧ⁼'),
58
+ ('ㄑ', 'ʧʰ'),
59
+ ('ㄒ', 'ʃ'),
60
+ ('ㄓ', 'ʦ`⁼'),
61
+ ('ㄔ', 'ʦ`ʰ'),
62
+ ('ㄕ', 's`'),
63
+ ('ㄖ', 'ɹ`'),
64
+ ('ㄗ', 'ʦ⁼'),
65
+ ('ㄘ', 'ʦʰ'),
66
+ ('ㄙ', 's'),
67
+ ('ㄚ', 'a'),
68
+ ('ㄛ', 'o'),
69
+ ('ㄜ', 'ə'),
70
+ ('ㄝ', 'e'),
71
+ ('ㄞ', 'ai'),
72
+ ('ㄟ', 'ei'),
73
+ ('ㄠ', 'au'),
74
+ ('ㄡ', 'ou'),
75
+ ('ㄧㄢ', 'yeNN'),
76
+ ('ㄢ', 'aNN'),
77
+ ('ㄧㄣ', 'iNN'),
78
+ ('ㄣ', 'əNN'),
79
+ ('ㄤ', 'aNg'),
80
+ ('ㄧㄥ', 'iNg'),
81
+ ('ㄨㄥ', 'uNg'),
82
+ ('ㄩㄥ', 'yuNg'),
83
+ ('ㄥ', 'əNg'),
84
+ ('ㄦ', 'əɻ'),
85
+ ('ㄧ', 'i'),
86
+ ('ㄨ', 'u'),
87
+ ('ㄩ', 'ɥ'),
88
+ ('ˉ', '→'),
89
+ ('ˊ', '↑'),
90
+ ('ˇ', '↓↑'),
91
+ ('ˋ', '↓'),
92
+ ('˙', ''),
93
+ (',', ','),
94
+ ('。', '.'),
95
+ ('!', '!'),
96
+ ('?', '?'),
97
+ ('—', '-')
98
+ ]]
99
+
100
+ # List of (romaji, ipa) pairs:
101
+ _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
102
+ ('ʃy', 'ʃ'),
103
+ ('ʧʰy', 'ʧʰ'),
104
+ ('ʧ⁼y', 'ʧ⁼'),
105
+ ('NN', 'n'),
106
+ ('Ng', 'ŋ'),
107
+ ('y', 'j'),
108
+ ('h', 'x')
109
+ ]]
110
+
111
+ # List of (bopomofo, ipa) pairs:
112
+ _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
113
+ ('ㄅㄛ', 'p⁼wo'),
114
+ ('ㄆㄛ', 'pʰwo'),
115
+ ('ㄇㄛ', 'mwo'),
116
+ ('ㄈㄛ', 'fwo'),
117
+ ('ㄅ', 'p⁼'),
118
+ ('ㄆ', 'pʰ'),
119
+ ('ㄇ', 'm'),
120
+ ('ㄈ', 'f'),
121
+ ('ㄉ', 't⁼'),
122
+ ('ㄊ', 'tʰ'),
123
+ ('ㄋ', 'n'),
124
+ ('ㄌ', 'l'),
125
+ ('ㄍ', 'k⁼'),
126
+ ('ㄎ', 'kʰ'),
127
+ ('ㄏ', 'x'),
128
+ ('ㄐ', 'tʃ⁼'),
129
+ ('ㄑ', 'tʃʰ'),
130
+ ('ㄒ', 'ʃ'),
131
+ ('ㄓ', 'ts`⁼'),
132
+ ('ㄔ', 'ts`ʰ'),
133
+ ('ㄕ', 's`'),
134
+ ('ㄖ', 'ɹ`'),
135
+ ('ㄗ', 'ts⁼'),
136
+ ('ㄘ', 'tsʰ'),
137
+ ('ㄙ', 's'),
138
+ ('ㄚ', 'a'),
139
+ ('ㄛ', 'o'),
140
+ ('ㄜ', 'ə'),
141
+ ('ㄝ', 'ɛ'),
142
+ ('ㄞ', 'aɪ'),
143
+ ('ㄟ', 'eɪ'),
144
+ ('ㄠ', 'ɑʊ'),
145
+ ('ㄡ', 'oʊ'),
146
+ ('ㄧㄢ', 'jɛn'),
147
+ ('ㄩㄢ', 'ɥæn'),
148
+ ('ㄢ', 'an'),
149
+ ('ㄧㄣ', 'in'),
150
+ ('ㄩㄣ', 'ɥn'),
151
+ ('ㄣ', 'ən'),
152
+ ('ㄤ', 'ɑŋ'),
153
+ ('ㄧㄥ', 'iŋ'),
154
+ ('ㄨㄥ', 'ʊŋ'),
155
+ ('ㄩㄥ', 'jʊŋ'),
156
+ ('ㄥ', 'əŋ'),
157
+ ('ㄦ', 'əɻ'),
158
+ ('ㄧ', 'i'),
159
+ ('ㄨ', 'u'),
160
+ ('ㄩ', 'ɥ'),
161
+ ('ˉ', '→'),
162
+ ('ˊ', '↑'),
163
+ ('ˇ', '↓↑'),
164
+ ('ˋ', '↓'),
165
+ ('˙', ''),
166
+ (',', ','),
167
+ ('。', '.'),
168
+ ('!', '!'),
169
+ ('?', '?'),
170
+ ('—', '-')
171
+ ]]
172
+
173
+ # List of (bopomofo, ipa2) pairs:
174
+ _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
175
+ ('ㄅㄛ', 'pwo'),
176
+ ('ㄆㄛ', 'pʰwo'),
177
+ ('ㄇㄛ', 'mwo'),
178
+ ('ㄈㄛ', 'fwo'),
179
+ ('ㄅ', 'p'),
180
+ ('ㄆ', 'pʰ'),
181
+ ('ㄇ', 'm'),
182
+ ('ㄈ', 'f'),
183
+ ('ㄉ', 't'),
184
+ ('ㄊ', 'tʰ'),
185
+ ('ㄋ', 'n'),
186
+ ('ㄌ', 'l'),
187
+ ('ㄍ', 'k'),
188
+ ('ㄎ', 'kʰ'),
189
+ ('ㄏ', 'h'),
190
+ ('ㄐ', 'tɕ'),
191
+ ('ㄑ', 'tɕʰ'),
192
+ ('ㄒ', 'ɕ'),
193
+ ('ㄓ', 'tʂ'),
194
+ ('ㄔ', 'tʂʰ'),
195
+ ('ㄕ', 'ʂ'),
196
+ ('ㄖ', 'ɻ'),
197
+ ('ㄗ', 'ts'),
198
+ ('ㄘ', 'tsʰ'),
199
+ ('ㄙ', 's'),
200
+ ('ㄚ', 'a'),
201
+ ('ㄛ', 'o'),
202
+ ('ㄜ', 'ɤ'),
203
+ ('ㄝ', 'ɛ'),
204
+ ('ㄞ', 'aɪ'),
205
+ ('ㄟ', 'eɪ'),
206
+ ('ㄠ', 'ɑʊ'),
207
+ ('ㄡ', 'oʊ'),
208
+ ('ㄧㄢ', 'jɛn'),
209
+ ('ㄩㄢ', 'yæn'),
210
+ ('ㄢ', 'an'),
211
+ ('ㄧㄣ', 'in'),
212
+ ('ㄩㄣ', 'yn'),
213
+ ('ㄣ', 'ən'),
214
+ ('ㄤ', 'ɑŋ'),
215
+ ('ㄧㄥ', 'iŋ'),
216
+ ('ㄨㄥ', 'ʊŋ'),
217
+ ('ㄩㄥ', 'jʊŋ'),
218
+ ('ㄥ', 'ɤŋ'),
219
+ ('ㄦ', 'əɻ'),
220
+ ('ㄧ', 'i'),
221
+ ('ㄨ', 'u'),
222
+ ('ㄩ', 'y'),
223
+ ('ˉ', '˥'),
224
+ ('ˊ', '˧˥'),
225
+ ('ˇ', '˨˩˦'),
226
+ ('ˋ', '˥˩'),
227
+ ('˙', ''),
228
+ (',', ','),
229
+ ('。', '.'),
230
+ ('!', '!'),
231
+ ('?', '?'),
232
+ ('—', '-')
233
+ ]]
234
+
235
+
236
+ def number_to_chinese(text):
237
+ numbers = re.findall(r'\d+(?:\.?\d+)?', text)
238
+ for number in numbers:
239
+ text = text.replace(number, cn2an.an2cn(number), 1)
240
+ return text
241
+
242
+
243
+ def chinese_to_bopomofo(text):
244
+ text = text.replace('、', ',').replace(';', ',').replace(':', ',')
245
+ words = jieba.lcut(text, cut_all=False)
246
+ text = ''
247
+ for word in words:
248
+ bopomofos = lazy_pinyin(word, BOPOMOFO)
249
+ if not re.search('[\u4e00-\u9fff]', word):
250
+ text += word
251
+ continue
252
+ for i in range(len(bopomofos)):
253
+ bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
254
+ if text != '':
255
+ text += ' '
256
+ text += ''.join(bopomofos)
257
+ return text
258
+
259
+
260
+ def latin_to_bopomofo(text):
261
+ for regex, replacement in _latin_to_bopomofo:
262
+ text = re.sub(regex, replacement, text)
263
+ return text
264
+
265
+
266
+ def bopomofo_to_romaji(text):
267
+ for regex, replacement in _bopomofo_to_romaji:
268
+ text = re.sub(regex, replacement, text)
269
+ return text
270
+
271
+
272
+ def bopomofo_to_ipa(text):
273
+ for regex, replacement in _bopomofo_to_ipa:
274
+ text = re.sub(regex, replacement, text)
275
+ return text
276
+
277
+
278
+ def bopomofo_to_ipa2(text):
279
+ for regex, replacement in _bopomofo_to_ipa2:
280
+ text = re.sub(regex, replacement, text)
281
+ return text
282
+
283
+
284
+ def chinese_to_romaji(text):
285
+ text = number_to_chinese(text)
286
+ text = chinese_to_bopomofo(text)
287
+ text = latin_to_bopomofo(text)
288
+ text = bopomofo_to_romaji(text)
289
+ text = re.sub('i([aoe])', r'y\1', text)
290
+ text = re.sub('u([aoəe])', r'w\1', text)
291
+ text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
292
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
293
+ text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
294
+ return text
295
+
296
+
297
+ def chinese_to_lazy_ipa(text):
298
+ text = chinese_to_romaji(text)
299
+ for regex, replacement in _romaji_to_ipa:
300
+ text = re.sub(regex, replacement, text)
301
+ return text
302
+
303
+
304
+ def chinese_to_ipa(text):
305
+ text = number_to_chinese(text)
306
+ text = chinese_to_bopomofo(text)
307
+ text = latin_to_bopomofo(text)
308
+ text = bopomofo_to_ipa(text)
309
+ text = re.sub('i([aoe])', r'j\1', text)
310
+ text = re.sub('u([aoəe])', r'w\1', text)
311
+ text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
312
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
313
+ text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
314
+ return text
315
+
316
+
317
+ def chinese_to_ipa2(text):
318
+ text = number_to_chinese(text)
319
+ text = chinese_to_bopomofo(text)
320
+ text = latin_to_bopomofo(text)
321
+ text = bopomofo_to_ipa2(text)
322
+ text = re.sub(r'i([aoe])', r'j\1', text)
323
+ text = re.sub(r'u([aoəe])', r'w\1', text)
324
+ text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
325
+ text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
326
+ return text
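End to end, the Mandarin pipeline is: convert digits with cn2an, segment with jieba, transliterate to bopomofo via pypinyin, then apply the bopomofo-to-IPA tables above. A short sketch, assuming `jieba`, `pypinyin`, and `cn2an` are installed:

from openvoice.text.mandarin import chinese_to_ipa, chinese_to_ipa2

print(chinese_to_ipa('今天天气真好'))   # IPA with tone arrows (→ ↑ ↓↑ ↓)
print(chinese_to_ipa2('今天天气真好'))  # IPA with tone letters (˥ ˧˥ ˨˩˦ ˥˩)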
OpenVoice/openvoice/text/symbols.py ADDED
@@ -0,0 +1,88 @@
1
+ '''
2
+ Defines the set of symbols used in text input to the model.
3
+ '''
4
+
5
+ # japanese_cleaners
6
+ # _pad = '_'
7
+ # _punctuation = ',.!?-'
8
+ # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9
+
10
+
11
+ '''# japanese_cleaners2
12
+ _pad = '_'
13
+ _punctuation = ',.!?-~…'
14
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15
+ '''
16
+
17
+
18
+ '''# korean_cleaners
19
+ _pad = '_'
20
+ _punctuation = ',.!?…~'
21
+ _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22
+ '''
23
+
24
+ '''# chinese_cleaners
25
+ _pad = '_'
26
+ _punctuation = ',。!?—…'
27
+ _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28
+ '''
29
+
30
+ # # zh_ja_mixture_cleaners
31
+ # _pad = '_'
32
+ # _punctuation = ',.!?-~…'
33
+ # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34
+
35
+
36
+ '''# sanskrit_cleaners
37
+ _pad = '_'
38
+ _punctuation = '।'
39
+ _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40
+ '''
41
+
42
+ '''# cjks_cleaners
43
+ _pad = '_'
44
+ _punctuation = ',.!?-~…'
45
+ _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46
+ '''
47
+
48
+ '''# thai_cleaners
49
+ _pad = '_'
50
+ _punctuation = '.!? '
51
+ _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52
+ '''
53
+
54
+ # # cjke_cleaners2
55
+ _pad = '_'
56
+ _punctuation = ',.!?-~…'
57
+ _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58
+
59
+
60
+ '''# shanghainese_cleaners
61
+ _pad = '_'
62
+ _punctuation = ',.!?…'
63
+ _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
64
+ '''
65
+
66
+ '''# chinese_dialect_cleaners
67
+ _pad = '_'
68
+ _punctuation = ',.!?~…─'
69
+ _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
70
+ '''
71
+
72
+ # Export all symbols:
73
+ symbols = [_pad] + list(_punctuation) + list(_letters)
74
+
75
+ # Special symbol ids
76
+ SPACE_ID = symbols.index(" ")
77
+
78
+ num_ja_tones = 1
79
+ num_kr_tones = 1
80
+ num_zh_tones = 6
81
+ num_en_tones = 4
82
+
83
+ language_tone_start_map = {
84
+ "ZH": 0,
85
+ "JP": num_zh_tones,
86
+ "EN": num_zh_tones + num_ja_tones,
87
+ 'KR': num_zh_tones + num_ja_tones + num_en_tones,
88
+ }
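
For reference, the exported `symbols` list is what the text frontend indexes into, and the tone offsets stack in the order ZH → JP → EN → KR. A small sketch of how these constants are typically consumed (the `symbol_to_id` lookup table is illustrative, not part of this file):

```python
from openvoice.text.symbols import symbols, SPACE_ID, language_tone_start_map

# Typical lookup table a text frontend builds from the exported symbol set
symbol_to_id = {s: i for i, s in enumerate(symbols)}

print(len(symbols), SPACE_ID)    # size of the symbol inventory and the id of ' '
print(language_tone_start_map)   # {'ZH': 0, 'JP': 6, 'EN': 7, 'KR': 11}
```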
OpenVoice/openvoice/transforms.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
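
Since the spline above is bijective inside its tail bound, a quick way to sanity-check shapes and invertibility is to push a random tensor forward and then back. A minimal sketch, assuming the file imports as `openvoice.transforms`; note that with `tails="linear"` the derivative tensor carries one column fewer than the number of bins, because the two boundary derivatives are padded in:

```python
import torch
from openvoice.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.randn(4, 16)                # arbitrary inputs; values outside the tail bound pass through unchanged
w = torch.randn(4, 16, num_bins)      # unnormalized bin widths
h = torch.randn(4, 16, num_bins)      # unnormalized bin heights
d = torch.randn(4, 16, num_bins - 1)  # unnormalized derivatives (boundaries are padded in for linear tails)

y, logdet = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear", tail_bound=3.0)
x_back, inv_logdet = piecewise_rational_quadratic_transform(y, w, h, d, inverse=True, tails="linear", tail_bound=3.0)

print(torch.allclose(x, x_back, atol=1e-4))   # forward/inverse round trip recovers the input
print(torch.allclose(logdet + inv_logdet, torch.zeros_like(logdet), atol=1e-4))  # log-determinants cancel
```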
OpenVoice/openvoice/utils.py ADDED
@@ -0,0 +1,194 @@
1
+ import re
2
+ import json
3
+ import numpy as np
4
+
5
+
6
+ def get_hparams_from_file(config_path):
7
+ with open(config_path, "r", encoding="utf-8") as f:
8
+ data = f.read()
9
+ config = json.loads(data)
10
+
11
+ hparams = HParams(**config)
12
+ return hparams
13
+
14
+ class HParams:
15
+ def __init__(self, **kwargs):
16
+ for k, v in kwargs.items():
17
+ if type(v) == dict:
18
+ v = HParams(**v)
19
+ self[k] = v
20
+
21
+ def keys(self):
22
+ return self.__dict__.keys()
23
+
24
+ def items(self):
25
+ return self.__dict__.items()
26
+
27
+ def values(self):
28
+ return self.__dict__.values()
29
+
30
+ def __len__(self):
31
+ return len(self.__dict__)
32
+
33
+ def __getitem__(self, key):
34
+ return getattr(self, key)
35
+
36
+ def __setitem__(self, key, value):
37
+ return setattr(self, key, value)
38
+
39
+ def __contains__(self, key):
40
+ return key in self.__dict__
41
+
42
+ def __repr__(self):
43
+ return self.__dict__.__repr__()
44
+
45
+
46
+ def string_to_bits(string, pad_len=8):
47
+ # Convert each character to its ASCII value
48
+ ascii_values = [ord(char) for char in string]
49
+
50
+ # Convert ASCII values to binary representation
51
+ binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]
52
+
53
+ # Convert binary strings to integer arrays
54
+ bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]
55
+
56
+ # Convert list of arrays to NumPy array
57
+ numpy_array = np.array(bit_arrays)
58
+ numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
59
+ numpy_array_full[:, 2] = 1
60
+ max_len = min(pad_len, len(numpy_array))
61
+ numpy_array_full[:max_len] = numpy_array[:max_len]
62
+ return numpy_array_full
63
+
64
+
65
+ def bits_to_string(bits_array):
66
+ # Convert each row of the array to a binary string
67
+ binary_values = [''.join(str(bit) for bit in row) for row in bits_array]
68
+
69
+ # Convert binary strings to ASCII values
70
+ ascii_values = [int(binary, 2) for binary in binary_values]
71
+
72
+ # Convert ASCII values to characters
73
+ output_string = ''.join(chr(value) for value in ascii_values)
74
+
75
+ return output_string
76
+
77
+
78
+ def split_sentence(text, min_len=10, language_str='[EN]'):
79
+ if language_str in ['EN']:
80
+ sentences = split_sentences_latin(text, min_len=min_len)
81
+ else:
82
+ sentences = split_sentences_zh(text, min_len=min_len)
83
+ return sentences
84
+
85
+ def split_sentences_latin(text, min_len=10):
86
+ """Split long sentences into a list of short ones
87
+
88
+ Args:
89
+ str: Input sentences.
90
+
91
+ Returns:
92
+ List[str]: list of output sentences.
93
+ """
94
+ # deal with dirty sentences
95
+ text = re.sub('[。!?;]', '.', text)
96
+ text = re.sub('[,]', ',', text)
97
+ text = re.sub('[“”]', '"', text)
98
+ text = re.sub('[‘’]', "'", text)
99
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
100
+ text = re.sub('[\n\t ]+', ' ', text)
101
+ text = re.sub('([,.!?;])', r'\1 $#!', text)
102
+ # split
103
+ sentences = [s.strip() for s in text.split('$#!')]
104
+ if len(sentences[-1]) == 0: del sentences[-1]
105
+
106
+ new_sentences = []
107
+ new_sent = []
108
+ count_len = 0
109
+ for ind, sent in enumerate(sentences):
110
+ # print(sent)
111
+ new_sent.append(sent)
112
+ count_len += len(sent.split(" "))
113
+ if count_len > min_len or ind == len(sentences) - 1:
114
+ count_len = 0
115
+ new_sentences.append(' '.join(new_sent))
116
+ new_sent = []
117
+ return merge_short_sentences_latin(new_sentences)
118
+
119
+
120
+ def merge_short_sentences_latin(sens):
121
+ """Avoid short sentences by merging them with the following sentence.
122
+
123
+ Args:
124
+ List[str]: list of input sentences.
125
+
126
+ Returns:
127
+ List[str]: list of output sentences.
128
+ """
129
+ sens_out = []
130
+ for s in sens:
131
+ # If the previous sentence is too short, merge it with
132
+ # the current sentence.
133
+ if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
134
+ sens_out[-1] = sens_out[-1] + " " + s
135
+ else:
136
+ sens_out.append(s)
137
+ try:
138
+ if len(sens_out[-1].split(" ")) <= 2:
139
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
140
+ sens_out.pop(-1)
141
+ except:
142
+ pass
143
+ return sens_out
144
+
145
+ def split_sentences_zh(text, min_len=10):
146
+ text = re.sub('[。!?;]', '.', text)
147
+ text = re.sub('[,]', ',', text)
148
+ # Replace newlines, tabs and runs of spaces with a single space
149
+ text = re.sub('[\n\t ]+', ' ', text)
150
+ # Insert a split marker after each punctuation mark
151
+ text = re.sub('([,.!?;])', r'\1 $#!', text)
152
+ # Split into sentences and strip leading/trailing whitespace
153
+ # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
154
+ sentences = [s.strip() for s in text.split('$#!')]
155
+ if len(sentences[-1]) == 0: del sentences[-1]
156
+
157
+ new_sentences = []
158
+ new_sent = []
159
+ count_len = 0
160
+ for ind, sent in enumerate(sentences):
161
+ new_sent.append(sent)
162
+ count_len += len(sent)
163
+ if count_len > min_len or ind == len(sentences) - 1:
164
+ count_len = 0
165
+ new_sentences.append(' '.join(new_sent))
166
+ new_sent = []
167
+ return merge_short_sentences_zh(new_sentences)
168
+
169
+
170
+ def merge_short_sentences_zh(sens):
171
+ # return sens
172
+ """Avoid short sentences by merging them with the following sentence.
173
+
174
+ Args:
175
+ List[str]: list of input sentences.
176
+
177
+ Returns:
178
+ List[str]: list of output sentences.
179
+ """
180
+ sens_out = []
181
+ for s in sens:
182
+ # If the previous sentence is too short, merge it with
183
+ # the current sentence.
184
+ if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
185
+ sens_out[-1] = sens_out[-1] + " " + s
186
+ else:
187
+ sens_out.append(s)
188
+ try:
189
+ if len(sens_out[-1]) <= 2:
190
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
191
+ sens_out.pop(-1)
192
+ except:
193
+ pass
194
+ return sens_out
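
The helpers above cover two unrelated jobs: chunking long input text for TTS and packing a short ASCII tag into a fixed-size bit matrix (used for audio watermarking elsewhere in OpenVoice). A small usage sketch, assuming the module imports as `openvoice.utils`; note that `split_sentence` routes to the Latin splitter only when `language_str` is exactly `'EN'`, so the language code is passed explicitly here:

```python
from openvoice.utils import split_sentence, string_to_bits, bits_to_string

text = ("This is the first sentence. Here is a second one. "
        "And a third, slightly longer sentence to push past the word budget.")
for chunk in split_sentence(text, min_len=10, language_str='EN'):
    print(repr(chunk))                           # chunks of roughly min_len+ words each

bits = string_to_bits("@MyShell", pad_len=8)     # 8x8 matrix of ASCII bits, space-padded
print(bits.shape)                                # (8, 8)
print(bits_to_string(bits))                      # '@MyShell'
```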
README.md ADDED
@@ -0,0 +1,119 @@
1
+ ---
2
+ title: Integrated TTS + Voice Conversion App
3
+ emoji: 🎤
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Text-to-speech and voice conversion app combining OpenVoice V2 and Seed-VC
12
+ ---
13
+
14
+ # Integrated TTS + Voice Conversion App
15
+
16
+ This app combines OpenVoice V2 and Seed-VC into a single pipeline: text is first converted to speech and then re-styled to match a reference voice.
17
+
18
+ ## 🚀 Key Features
19
+
20
+ 1. **Text-to-Speech (TTS)**: converts text to speech with OpenVoice V2
21
+ 2. **Voice Cloning**: clones the tone and style of a reference voice
22
+ 3. **Voice Conversion**: converts the generated speech to the reference voice's style with Seed-VC
23
+
24
+ ## 📋 Processing Pipeline
25
+
26
+ 1. **Input**: text + reference audio + conversion parameters
27
+ 2. **Stage 1**: text → speech generated in the reference voice's tone (OpenVoice V2)
28
+ 3. **Stage 2**: generated speech → converted to the reference voice's style (Seed-VC)
29
+ 4. **Output**: final converted audio
30
+
31
+ ## 🛠️ Installation and Usage
32
+
33
+ ### Requirements
34
+ - Python 3.8+
35
+ - CUDA-capable GPU (recommended)
36
+
37
+ ### Installation
38
+ ```bash
39
+ pip install -r requirements.txt
40
+ ```
41
+
42
+ ### Run
43
+ ```bash
44
+ python app.py
45
+ ```
46
+
47
+ ## 📝 How to Use
48
+
49
+ 1. **Enter text**: type the text you want to convert
50
+ 2. **Upload a reference voice**: upload a clean reference recording 3-10 seconds long
51
+ 3. **Adjust the basic settings**:
52
+ - choose a base voice style
53
+ - adjust the speech speed (0.6x ~ 1.4x)
54
+ 4. **Adjust the voice conversion parameters**:
55
+ - diffusion steps (1-200, default: 25)
56
+ - length adjust (0.5-2.0, default: 1.0)
57
+ - CFG rate (0.0-1.0, default: 0.7)
58
+ - pitch shift (-24 to 24 semitones, default: 0)
59
+ - whether to use the F0-conditioned model
60
+ - whether to auto-adjust F0
61
+ 5. **Click the "통합 변환" (integrated conversion) button**
62
+
63
+ ## ⚙️ Parameter Reference
64
+
65
+ ### TTS parameters
66
+ - **Base voice style**: base style per supported language (EN, ES, FR, ZH, JP, KR)
67
+ - **Speech speed**: speed of the generated speech (1.0 is normal speed)
68
+
69
+ ### Voice conversion parameters
70
+ - **Diffusion steps**: affects conversion quality (higher is better; 50-100 recommended)
71
+ - **Length adjust**: adjusts audio length (<1.0: faster, >1.0: slower)
72
+ - **CFG rate**: controls conversion strength
73
+ - **Pitch shift**: pitch adjustment (in semitones)
74
+ - **F0-conditioned model**: required for singing voice conversion
75
+ - **Auto F0 adjust**: automatically matches F0 to the target voice
76
+
77
+ ## 🔧 Tech Stack
78
+
79
+ - **OpenVoice V2**: text-to-speech and voice cloning
80
+ - **Seed-VC**: high-quality voice conversion
81
+ - **MeloTTS**: multilingual TTS engine
82
+ - **Whisper**: speech encoding
83
+ - **BigVGAN**: high-quality vocoder
84
+ - **Gradio**: web interface
85
+
86
+ ## 📁 Project Structure
87
+
88
+ ```
89
+ Lucy_5/
90
+ ├── app.py # main application
91
+ ├── requirements.txt # dependencies
92
+ ├── README.md # project description
93
+ ├── OpenVoice/ # OpenVoice V2 modules
94
+ ├── modules/ # Seed-VC modules
95
+ └── hf_utils.py # Hugging Face utilities
96
+ ```
97
+
98
+ ## 🚀 Deploying to Hugging Face Spaces
99
+
100
+ This app can be deployed on Hugging Face Spaces:
101
+
102
+ 1. Create a new Space on Hugging Face Spaces
103
+ 2. Clone this repository
104
+ 3. Install the dependencies listed in `requirements.txt`
105
+ 4. Configure the Space with a GPU (recommended)
106
+
107
+ ## 💡 Tips
108
+
109
+ - Use a clean recording 3-10 seconds long as the reference voice
110
+ - Check "F0 조건부 모델 사용" (use the F0-conditioned model) if you want singing voice conversion
111
+ - Set the diffusion steps to 50-100 for higher quality
112
+ - Long texts may take a long time to process
113
+
114
+ ## 📄 License
115
+
116
+ This project follows the licenses of the original libraries:
117
+ - OpenVoice V2
118
+ - Seed-VC
119
+ - MeloTTS
app.py ADDED
@@ -0,0 +1,843 @@
1
+ import os
2
+ import sys
3
+ import tempfile
4
+ import zipfile
5
+ import shutil
6
+ import subprocess
7
+ import importlib
8
+ import traceback
9
+ # from typing import Optional, Dict, Any
10
+
11
+ # Force flush stdout and stderr for better logging in Hugging Face Spaces
12
+ def log_print(*args, **kwargs):
13
+ print(*args, **kwargs)
14
+ sys.stdout.flush()
15
+
16
+ def log_error(*args, **kwargs):
17
+ print(*args, file=sys.stderr, **kwargs)
18
+ sys.stderr.flush()
19
+
20
+ try:
21
+ import gradio as gr
22
+ import spaces
23
+ import torch
24
+ import torchaudio
25
+ import librosa
26
+ import yaml
27
+ import numpy as np
28
+ import nltk
29
+ import requests
30
+ from pydub import AudioSegment
31
+ import soundfile as sf
32
+ except ImportError as e:
33
+ log_error(f"Import error: {e}")
34
+ log_error("Please install required packages using: pip install -r requirements.txt")
35
+ sys.exit(1)
36
+
37
+ # Ensure module import works regardless of working directory location
38
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
39
+ OPENVOICE_DIR = os.path.join(ROOT_DIR, "OpenVoice")
40
+ if os.path.isdir(OPENVOICE_DIR) and OPENVOICE_DIR not in sys.path:
41
+ sys.path.insert(0, OPENVOICE_DIR)
42
+
43
+ # Also work inside the OpenVoice project subdir so relative ckpt paths resolve
44
+ PROJECT_SUBDIR = "OpenVoice"
45
+ if os.path.isdir(os.path.join(ROOT_DIR, PROJECT_SUBDIR)):
46
+ os.chdir(os.path.join(ROOT_DIR, PROJECT_SUBDIR))
47
+
48
+ # Import OpenVoice modules
49
+ from openvoice import se_extractor
50
+ from openvoice.api import ToneColorConverter
51
+ TTS = None # will import lazily after ensuring MeCab/Unidic
52
+
53
+ # Import Seed-VC modules
54
+ from modules.commons import build_model, load_checkpoint, recursive_munch
55
+ from hf_utils import load_custom_model_from_hf
56
+
57
+ # OpenVoice configuration
58
+ CKPT_CONVERTER_DIR = 'checkpoints_v2/converter'
59
+ BASE_SPEAKER_SE_DIR = 'checkpoints_v2/base_speakers/ses'
60
+ OUTPUT_DIR = 'outputs_v2'
61
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
62
+
63
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
64
+
65
+ # Lazy singletons
66
+ _tone_color_converter = None
67
+ _melo_models = {}
68
+
69
+ # Seed-VC models (will be initialized at startup)
70
+ _seed_vc_models = None
71
+
72
+ def get_tone_color_converter():
73
+ global _tone_color_converter
74
+ if _tone_color_converter is None:
75
+ ensure_checkpoints()
76
+ converter = ToneColorConverter(f'{CKPT_CONVERTER_DIR}/config.json', device=DEVICE)
77
+ converter.load_ckpt(f'{CKPT_CONVERTER_DIR}/checkpoint.pth')
78
+ _tone_color_converter = converter
79
+ return _tone_color_converter
80
+
81
+ def ensure_unidic_available() -> None:
82
+ try:
83
+ import unidic # type: ignore
84
+ dicdir = getattr(unidic, 'DICDIR', None)
85
+ if not dicdir or not os.path.exists(os.path.join(dicdir, 'mecabrc')):
86
+ subprocess.run([sys.executable, '-m', 'unidic', 'download'], check=False)
87
+ # Reload to get DICDIR
88
+ unidic = importlib.reload(unidic)
89
+ dicdir = getattr(unidic, 'DICDIR', None)
90
+ if dicdir and os.path.exists(os.path.join(dicdir, 'mecabrc')):
91
+ os.environ['MECABRC'] = os.path.join(dicdir, 'mecabrc')
92
+ except Exception:
93
+ # Best-effort; MeloTTS may still work for non-Japanese
94
+ pass
95
+
96
+ def ensure_nltk_resources() -> None:
97
+ try:
98
+ nltk.data.find('taggers/averaged_perceptron_tagger_eng')
99
+ except LookupError:
100
+ try:
101
+ nltk.download('averaged_perceptron_tagger_eng', quiet=True)
102
+ except Exception:
103
+ pass
104
+ try:
105
+ nltk.data.find('corpora/cmudict')
106
+ except LookupError:
107
+ try:
108
+ nltk.download('cmudict', quiet=True)
109
+ except Exception:
110
+ pass
111
+
112
+ def get_melo_model(language: str):
113
+ # Normalize a couple of aliases from demo_part3
114
+ if language.lower() in {"en_us", "en-newest", "en_newest"}:
115
+ language = "EN_NEWEST"
116
+ language = language.upper()
117
+ if language not in _melo_models:
118
+ global TTS
119
+ if TTS is None:
120
+ ensure_unidic_available()
121
+ ensure_nltk_resources()
122
+ from melo.api import TTS as _TTS # type: ignore
123
+ TTS = _TTS
124
+ _melo_models[language] = TTS(language=language, device=DEVICE)
125
+ return _melo_models[language]
126
+
127
+ def list_supported_styles():
128
+ # Map speaker .pth names we have to user choices
129
+ style_list = []
130
+ if not os.path.isdir(BASE_SPEAKER_SE_DIR):
131
+ return style_list
132
+ for name in sorted(os.listdir(BASE_SPEAKER_SE_DIR)):
133
+ if name.endswith('.pth'):
134
+ style_list.append(os.path.splitext(name)[0])
135
+ return style_list
136
+
137
+ def ensure_checkpoints():
138
+ # Download and place checkpoints at exact expected paths.
139
+ if os.path.exists(CKPT_CONVERTER_DIR) and os.path.isdir(BASE_SPEAKER_SE_DIR):
140
+ return
141
+ url = 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip'
142
+ tmp_zip = os.path.join(tempfile.gettempdir(), 'checkpoints_v2_0417.zip')
143
+ tmp_extract_root = tempfile.mkdtemp(prefix='ov_ckpt_')
144
+ try:
145
+ with requests.get(url, stream=True, timeout=300) as r:
146
+ r.raise_for_status()
147
+ with open(tmp_zip, 'wb') as f:
148
+ for chunk in r.iter_content(chunk_size=1 << 20):
149
+ if chunk:
150
+ f.write(chunk)
151
+ with zipfile.ZipFile(tmp_zip, 'r') as zf:
152
+ zf.extractall(tmp_extract_root)
153
+
154
+ # Find the folder that contains 'converter/config.json' directly under it
155
+ candidate = None
156
+ for root, _, _ in os.walk(tmp_extract_root):
157
+ cfg = os.path.join(root, 'converter', 'config.json')
158
+ ses_dir = os.path.join(root, 'base_speakers', 'ses')
159
+ if os.path.isfile(cfg) and os.path.isdir(ses_dir):
160
+ candidate = root
161
+ break
162
+ if candidate is None:
163
+ raise RuntimeError('Could not locate converter/config.json inside the downloaded archive.')
164
+
165
+ # Place contents into 'checkpoints_v2'
166
+ target_root = 'checkpoints_v2'
167
+ if os.path.exists(target_root):
168
+ shutil.rmtree(target_root, ignore_errors=True)
169
+ os.makedirs(target_root, exist_ok=True)
170
+
171
+ # Copy only required subfolders
172
+ for name in ['converter', 'base_speakers']:
173
+ src_path = os.path.join(candidate, name)
174
+ dst_path = os.path.join(target_root, name)
175
+ if os.path.exists(dst_path):
176
+ shutil.rmtree(dst_path, ignore_errors=True)
177
+ shutil.copytree(src_path, dst_path)
178
+ except Exception as e:
179
+ raise gr.Error(f"Failed to prepare checkpoints: {e}")
180
+ finally:
181
+ try:
182
+ if os.path.isdir(tmp_extract_root):
183
+ shutil.rmtree(tmp_extract_root, ignore_errors=True)
184
+ except Exception:
185
+ pass
186
+
187
+ def initialize_seed_vc_models():
188
+ """Initialize Seed-VC models with memory optimization"""
189
+ global _seed_vc_models
190
+
191
+ if _seed_vc_models is not None:
192
+ return _seed_vc_models
193
+
194
+ log_print("Loading Seed-VC models...")
195
+ # Clear GPU cache before loading models
196
+ if torch.cuda.is_available():
197
+ torch.cuda.empty_cache()
198
+
199
+ # Load DiT model
200
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
201
+ "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
202
+ "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
203
+
204
+ with open(dit_config_path, 'r', encoding='utf-8') as f:
205
+ config = yaml.safe_load(f)
206
+ model_params = recursive_munch(config['model_params'])
207
+ model = build_model(model_params, stage='DiT')
208
+ hop_length = config['preprocess_params']['spect_params']['hop_length']
209
+ sr = config['preprocess_params']['sr']
210
+
211
+ # Load checkpoints with memory optimization
212
+ model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
213
+ load_only_params=True, ignore_modules=[], is_distributed=False)
214
+ for key in model:
215
+ model[key].eval()
216
+ model[key].to(DEVICE)
217
+ model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=4096) # Reduced from 8192
218
+
219
+ # Load CAMPPlus
220
+ from modules.campplus.DTDNN import CAMPPlus
221
+ campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
222
+ campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
223
+ campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
224
+ campplus_model.eval()
225
+ campplus_model.to(DEVICE)
226
+
227
+ # Load BigVGAN
228
+ from modules.bigvgan import bigvgan
229
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
230
+ bigvgan_model.remove_weight_norm()
231
+ bigvgan_model = bigvgan_model.eval().to(DEVICE)
232
+
233
+ # Load FAcodec with error handling
234
+ try:
235
+ ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
236
+ with open(config_path, 'r', encoding='utf-8') as f:
237
+ codec_config = yaml.safe_load(f)
238
+ codec_model_params = recursive_munch(codec_config['model_params'])
239
+
240
+ # Remove problematic 'causal' parameter if it exists
241
+ if hasattr(codec_model_params, 'dac_params') and hasattr(codec_model_params.dac_params, 'causal'):
242
+ delattr(codec_model_params.dac_params, 'causal')
243
+ log_print("Removed 'causal' parameter from DAC config")
244
+
245
+ # Also check for other problematic parameters
246
+ if hasattr(codec_model_params, 'dac_params'):
247
+ dac_params = codec_model_params.dac_params
248
+ # Remove any parameters that might cause issues
249
+ problematic_params = ['causal', 'causal_conv', 'causal_attention']
250
+ for param in problematic_params:
251
+ if hasattr(dac_params, param):
252
+ delattr(dac_params, param)
253
+ log_print(f"Removed '{param}' parameter from DAC config")
254
+
255
+ codec_encoder = build_model(codec_model_params, stage="codec")
256
+ log_print("✓ FAcodec loaded successfully")
257
+ except Exception as e:
258
+ log_error(f"Warning: Failed to load FAcodec: {e}")
259
+ log_error(f"FAcodec error traceback: {traceback.format_exc()}")
260
+ # Create a minimal dummy codec encoder
261
+ log_print("Creating minimal codec encoder as fallback...")
262
+ try:
263
+ # Try to create a basic DAC model without problematic parameters
264
+ from descript_audio_codec import DAC
265
+ codec_encoder = {'codec': DAC()}
266
+ log_print("✓ Created minimal DAC fallback")
267
+ except Exception as e2:
268
+ log_error(f"Failed to create DAC fallback: {e2}")
269
+ # Create a completely dummy encoder
270
+ class DummyCodec:
271
+ def __getitem__(self, key):
272
+ return self
273
+ def eval(self):
274
+ return self
275
+ def to(self, device):
276
+ return self
277
+ codec_encoder = {'codec': DummyCodec()}
278
+ log_print("✓ Created dummy codec encoder")
279
+
280
+ # Load codec checkpoint with error handling
281
+ try:
282
+ ckpt_params = torch.load(ckpt_path, map_location="cpu")
283
+ if 'codec' in ckpt_params:
284
+ codec_encoder.codec.load_state_dict(ckpt_params['codec'], strict=False)
285
+ elif 'model' in ckpt_params:
286
+ codec_encoder.codec.load_state_dict(ckpt_params['model'], strict=False)
287
+ else:
288
+ codec_encoder.codec.load_state_dict(ckpt_params, strict=False)
289
+ except Exception as e:
290
+ log_error(f"Warning: Could not load codec state dict: {e}")
291
+ log_error(f"Codec state dict error traceback: {traceback.format_exc()}")
292
+ log_error("Codec will use default parameters")
293
+
294
+ _ = [codec_encoder[key].eval() for key in codec_encoder]
295
+ _ = [codec_encoder[key].to(DEVICE) for key in codec_encoder]
296
+
297
+ # Load Whisper
298
+ from transformers import AutoFeatureExtractor, WhisperModel
299
+ whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer, 'whisper_name') else "openai/whisper-small"
300
+ whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(DEVICE)
301
+ del whisper_model.decoder
302
+ whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
303
+
304
+ # Mel spectrogram function
305
+ mel_fn_args = {
306
+ "n_fft": config['preprocess_params']['spect_params']['n_fft'],
307
+ "win_size": config['preprocess_params']['spect_params']['win_length'],
308
+ "hop_size": config['preprocess_params']['spect_params']['hop_length'],
309
+ "num_mels": config['preprocess_params']['spect_params']['n_mels'],
310
+ "sampling_rate": sr,
311
+ "fmin": 0,
312
+ "fmax": None,
313
+ "center": False
314
+ }
315
+ from modules.audio import mel_spectrogram
316
+ to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
317
+
318
+ # Load F0 conditioned model
319
+ dit_checkpoint_path_f0, dit_config_path_f0 = load_custom_model_from_hf("Plachta/Seed-VC",
320
+ "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
321
+ "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
322
+
323
+ with open(dit_config_path_f0, 'r', encoding='utf-8') as f:
324
+ config_f0 = yaml.safe_load(f)
325
+ model_params_f0 = recursive_munch(config_f0['model_params'])
326
+ model_f0 = build_model(model_params_f0, stage='DiT')
327
+ hop_length_f0 = config_f0['preprocess_params']['spect_params']['hop_length']
328
+ sr_f0 = config_f0['preprocess_params']['sr']
329
+
330
+ # Load checkpoints
331
+ model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path_f0,
332
+ load_only_params=True, ignore_modules=[], is_distributed=False)
333
+ for key in model_f0:
334
+ model_f0[key].eval()
335
+ model_f0[key].to(DEVICE)
336
+ model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=4096) # Reduced from 8192
337
+
338
+ # Load RMVPE
339
+ from modules.rmvpe import RMVPE
340
+ model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
341
+ rmvpe = RMVPE(model_path, is_half=False, device=DEVICE)
342
+
343
+ mel_fn_args_f0 = {
344
+ "n_fft": config_f0['preprocess_params']['spect_params']['n_fft'],
345
+ "win_size": config_f0['preprocess_params']['spect_params']['win_length'],
346
+ "hop_size": config_f0['preprocess_params']['spect_params']['hop_length'],
347
+ "num_mels": config_f0['preprocess_params']['spect_params']['n_mels'],
348
+ "sampling_rate": sr_f0,
349
+ "fmin": 0,
350
+ "fmax": None,
351
+ "center": False
352
+ }
353
+ to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
354
+
355
+ bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
356
+ bigvgan_44k_model.remove_weight_norm()
357
+ bigvgan_44k_model = bigvgan_44k_model.eval().to(DEVICE)
358
+
359
+ _seed_vc_models = {
360
+ 'model': model,
361
+ 'model_f0': model_f0,
362
+ 'campplus_model': campplus_model,
363
+ 'bigvgan_model': bigvgan_model,
364
+ 'bigvgan_44k_model': bigvgan_44k_model,
365
+ 'codec_encoder': codec_encoder,
366
+ 'whisper_model': whisper_model,
367
+ 'whisper_feature_extractor': whisper_feature_extractor,
368
+ 'to_mel': to_mel,
369
+ 'to_mel_f0': to_mel_f0,
370
+ 'rmvpe': rmvpe,
371
+ 'config': config,
372
+ 'config_f0': config_f0,
373
+ 'hop_length': hop_length,
374
+ 'sr': sr,
375
+ 'hop_length_f0': hop_length_f0,
376
+ 'sr_f0': sr_f0
377
+ }
378
+
379
+ return _seed_vc_models
380
+
381
+ def adjust_f0_semitones(f0_sequence, n_semitones):
382
+ factor = 2 ** (n_semitones / 12)
383
+ return f0_sequence * factor
384
+
385
+ def crossfade(chunk1, chunk2, overlap):
386
+ fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
387
+ fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
388
+ chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
389
+ return chunk2
390
+
391
+ # Step 1: OpenVoice TTS + Voice Cloning
392
+ def run_openvoice_inference(text: str, style_key: str, speed: float, reference_audio_path: str) -> str:
393
+ if not text or not reference_audio_path:
394
+ raise gr.Error("Please provide text and a reference audio.")
395
+
396
+ # Re-evaluate device at call time for ZeroGPU
397
+ global DEVICE
398
+ DEVICE = "cuda:0" if torch.cuda.is_available() else DEVICE
399
+ converter = get_tone_color_converter()
400
+
401
+ # Extract target speaker embedding from uploaded reference audio
402
+ target_se, _ = se_extractor.get_se(reference_audio_path, converter, vad=True)
403
+
404
+ # Prepare base speech with Melo
405
+ language_from_style = "EN_NEWEST" if style_key.startswith("en-") else None
406
+ if style_key.startswith("es"):
407
+ language_from_style = "ES"
408
+ elif style_key.startswith("fr"):
409
+ language_from_style = "FR"
410
+ elif style_key.startswith("zh"):
411
+ language_from_style = "ZH"
412
+ elif style_key.startswith("jp"):
413
+ language_from_style = "JP"
414
+ elif style_key.startswith("kr"):
415
+ language_from_style = "KR"
416
+
417
+ melo = get_melo_model(language_from_style or "EN_NEWEST")
418
+ speaker_ids = melo.hps.data.spk2id
419
+
420
+ # Pick first available speaker id for that language
421
+ speaker_id = next(iter(speaker_ids.values()))
422
+
423
+ # Disable MPS quirk similar to demo_part3
424
+ if torch.backends.mps.is_available() and DEVICE == 'cpu':
425
+ torch.backends.mps.is_available = lambda: False
426
+
427
+ tmp_wav = os.path.join(OUTPUT_DIR, 'tmp.wav')
428
+ melo.tts_to_file(text, speaker_id, tmp_wav, speed=speed)
429
+
430
+ # Source speaker embedding from selected base style
431
+ source_se_path = os.path.join(BASE_SPEAKER_SE_DIR, f'{style_key}.pth')
432
+ if not os.path.exists(source_se_path):
433
+ raise gr.Error(f"Missing base speaker embedding: {source_se_path}")
434
+ source_se = torch.load(source_se_path, map_location=DEVICE)
435
+
436
+ out_path = os.path.join(OUTPUT_DIR, f'openvoice_output_{style_key}.wav')
437
+
438
+ # Convert tone color
439
+ get_tone_color_converter().convert(
440
+ audio_src_path=tmp_wav,
441
+ src_se=source_se,
442
+ tgt_se=target_se,
443
+ output_path=out_path,
444
+ message='@MyShell',
445
+ )
446
+
447
+ return out_path
448
+
449
+ # Step 2: Seed-VC Voice Conversion
450
+ @torch.no_grad()
451
+ @torch.inference_mode()
452
+ def run_seed_vc_inference(source_audio_path: str, target_audio_path: str, vc_diffusion_steps: int,
453
+ vc_length_adjust: float, vc_inference_cfg_rate: float, vc_f0_condition: bool,
454
+ vc_auto_f0_adjust: bool, vc_pitch_shift: int) -> str:
455
+
456
+ log_print("Initializing Seed-VC models...")
457
+ models = initialize_seed_vc_models()
458
+ log_print("✓ Seed-VC models ready")
459
+
460
+ inference_module = models['model_f0'] if vc_f0_condition else models['model']
461
+ mel_fn = models['to_mel_f0'] if vc_f0_condition else models['to_mel']
462
+ bigvgan_fn = models['bigvgan_44k_model'] if vc_f0_condition else models['bigvgan_model']
463
+ sr = models['sr_f0'] if vc_f0_condition else models['sr']
464
+ hop_length = models['hop_length_f0'] if vc_f0_condition else models['hop_length']
465
+
466
+ max_context_window = sr // hop_length * 30
467
+ overlap_frame_len = 16
468
+ overlap_wave_len = overlap_frame_len * hop_length
469
+ bitrate = "320k"
470
+
471
+ # Load audio
472
+ source_audio = librosa.load(source_audio_path, sr=sr)[0]
473
+ ref_audio = librosa.load(target_audio_path, sr=sr)[0]
474
+
475
+ # Process audio
476
+ source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(DEVICE)
477
+ ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(DEVICE)
478
+
479
+ # Resample
480
+ ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
481
+ converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
482
+
483
+ # Whisper processing
484
+ if converted_waves_16k.size(-1) <= 16000 * 30:
485
+ alt_inputs = models['whisper_feature_extractor']([converted_waves_16k.squeeze(0).cpu().numpy()],
486
+ return_tensors="pt",
487
+ return_attention_mask=True,
488
+ sampling_rate=16000)
489
+ alt_input_features = models['whisper_model']._mask_input_features(
490
+ alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(DEVICE)
491
+ alt_outputs = models['whisper_model'].encoder(
492
+ alt_input_features.to(models['whisper_model'].encoder.dtype),
493
+ head_mask=None,
494
+ output_attentions=False,
495
+ output_hidden_states=False,
496
+ return_dict=True,
497
+ )
498
+ S_alt = alt_outputs.last_hidden_state.to(torch.float32)
499
+ S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
500
+ else:
501
+ overlapping_time = 5 # 5 seconds
502
+ S_alt_list = []
503
+ buffer = None
504
+ traversed_time = 0
505
+ while traversed_time < converted_waves_16k.size(-1):
506
+ if buffer is None: # first chunk
507
+ chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
508
+ else:
509
+ chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
510
+ alt_inputs = models['whisper_feature_extractor']([chunk.squeeze(0).cpu().numpy()],
511
+ return_tensors="pt",
512
+ return_attention_mask=True,
513
+ sampling_rate=16000)
514
+ alt_input_features = models['whisper_model']._mask_input_features(
515
+ alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(DEVICE)
516
+ alt_outputs = models['whisper_model'].encoder(
517
+ alt_input_features.to(models['whisper_model'].encoder.dtype),
518
+ head_mask=None,
519
+ output_attentions=False,
520
+ output_hidden_states=False,
521
+ return_dict=True,
522
+ )
523
+ S_alt = alt_outputs.last_hidden_state.to(torch.float32)
524
+ S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
525
+ if traversed_time == 0:
526
+ S_alt_list.append(S_alt)
527
+ else:
528
+ S_alt_list.append(S_alt[:, 50 * overlapping_time:])
529
+ buffer = chunk[:, -16000 * overlapping_time:]
530
+ traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
531
+ S_alt = torch.cat(S_alt_list, dim=1)
532
+
533
+ ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
534
+ ori_inputs = models['whisper_feature_extractor']([ori_waves_16k.squeeze(0).cpu().numpy()],
535
+ return_tensors="pt",
536
+ return_attention_mask=True)
537
+ ori_input_features = models['whisper_model']._mask_input_features(
538
+ ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(DEVICE)
539
+ with torch.no_grad():
540
+ ori_outputs = models['whisper_model'].encoder(
541
+ ori_input_features.to(models['whisper_model'].encoder.dtype),
542
+ head_mask=None,
543
+ output_attentions=False,
544
+ output_hidden_states=False,
545
+ return_dict=True,
546
+ )
547
+ S_ori = ori_outputs.last_hidden_state.to(torch.float32)
548
+ S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
549
+
550
+ mel = mel_fn(source_audio.to(DEVICE).float())
551
+ mel2 = mel_fn(ref_audio.to(DEVICE).float())
552
+
553
+ target_lengths = torch.LongTensor([int(mel.size(2) * vc_length_adjust)]).to(mel.device)
554
+ target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
555
+
556
+ feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
557
+ num_mel_bins=80,
558
+ dither=0,
559
+ sample_frequency=16000)
560
+ feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
561
+ style2 = models['campplus_model'](feat2.unsqueeze(0))
562
+
563
+ if vc_f0_condition:
564
+ F0_ori = models['rmvpe'].infer_from_audio(ref_waves_16k[0], thred=0.5)
565
+ F0_alt = models['rmvpe'].infer_from_audio(converted_waves_16k[0], thred=0.5)
566
+
567
+ F0_ori = torch.from_numpy(F0_ori).to(DEVICE)[None]
568
+ F0_alt = torch.from_numpy(F0_alt).to(DEVICE)[None]
569
+
570
+ voiced_F0_ori = F0_ori[F0_ori > 1]
571
+ voiced_F0_alt = F0_alt[F0_alt > 1]
572
+
573
+ log_f0_alt = torch.log(F0_alt + 1e-5)
574
+ voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
575
+ voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
576
+ median_log_f0_ori = torch.median(voiced_log_f0_ori)
577
+ median_log_f0_alt = torch.median(voiced_log_f0_alt)
578
+
579
+ # shift alt log f0 level to ori log f0 level
580
+ shifted_log_f0_alt = log_f0_alt.clone()
581
+ if vc_auto_f0_adjust:
582
+ shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
583
+ shifted_f0_alt = torch.exp(shifted_log_f0_alt)
584
+ if vc_pitch_shift != 0:
585
+ shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], vc_pitch_shift)
586
+ else:
587
+ F0_ori = None
588
+ F0_alt = None
589
+ shifted_f0_alt = None
590
+
591
+ # Length regulation
592
+ cond, _, _, _, _ = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
593
+ prompt_condition, _, _, _, _ = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
594
+
595
+ max_source_window = max_context_window - mel2.size(2)
596
+ # split source condition (cond) into chunks
597
+ processed_frames = 0
598
+ generated_wave_chunks = []
599
+ # generate chunk by chunk and stream the output
600
+ while processed_frames < cond.size(1):
601
+ chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
602
+ is_last_chunk = processed_frames + max_source_window >= cond.size(1)
603
+ cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
604
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
605
+ # Voice Conversion
606
+ vc_target = inference_module.cfm.inference(cat_condition,
607
+ torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
608
+ mel2, style2, None, vc_diffusion_steps,
609
+ inference_cfg_rate=vc_inference_cfg_rate)
610
+ vc_target = vc_target[:, :, mel2.size(-1):]
611
+ vc_wave = bigvgan_fn(vc_target.float())[0]
612
+ if processed_frames == 0:
613
+ if is_last_chunk:
614
+ output_wave = vc_wave[0].cpu().numpy()
615
+ generated_wave_chunks.append(output_wave)
616
+ output_wave = (output_wave * 32768.0).astype(np.int16)
617
+ mp3_bytes = AudioSegment(
618
+ output_wave.tobytes(), frame_rate=sr,
619
+ sample_width=output_wave.dtype.itemsize, channels=1
620
+ ).export(format="mp3", bitrate=bitrate).read()
621
+ yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
622
+ break
623
+ output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
624
+ generated_wave_chunks.append(output_wave)
625
+ previous_chunk = vc_wave[0, -overlap_wave_len:]
626
+ processed_frames += vc_target.size(2) - overlap_frame_len
627
+ output_wave = (output_wave * 32768.0).astype(np.int16)
628
+ mp3_bytes = AudioSegment(
629
+ output_wave.tobytes(), frame_rate=sr,
630
+ sample_width=output_wave.dtype.itemsize, channels=1
631
+ ).export(format="mp3", bitrate=bitrate).read()
632
+ yield mp3_bytes, None
633
+ elif is_last_chunk:
634
+ output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
635
+ generated_wave_chunks.append(output_wave)
636
+ processed_frames += vc_target.size(2) - overlap_frame_len
637
+ output_wave = (output_wave * 32768.0).astype(np.int16)
638
+ mp3_bytes = AudioSegment(
639
+ output_wave.tobytes(), frame_rate=sr,
640
+ sample_width=output_wave.dtype.itemsize, channels=1
641
+ ).export(format="mp3", bitrate=bitrate).read()
642
+ yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
643
+ break
644
+ else:
645
+ output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
646
+ generated_wave_chunks.append(output_wave)
647
+ previous_chunk = vc_wave[0, -overlap_wave_len:]
648
+ processed_frames += vc_target.size(2) - overlap_frame_len
649
+ output_wave = (output_wave * 32768.0).astype(np.int16)
650
+ mp3_bytes = AudioSegment(
651
+ output_wave.tobytes(), frame_rate=sr,
652
+ sample_width=output_wave.dtype.itemsize, channels=1
653
+ ).export(format="mp3", bitrate=bitrate).read()
654
+ yield mp3_bytes, None
655
+
656
+ # Main integrated function
657
+ @spaces.GPU
658
+ def process_integrated_tts_vc(text, style, speed, reference_audio, vc_diffusion_steps, vc_length_adjust,
659
+ vc_inference_cfg_rate, vc_f0_condition, vc_auto_f0_adjust, vc_pitch_shift):
660
+ """Integrated TTS + Voice Conversion pipeline"""
661
+
662
+ log_print("=" * 50)
663
+ log_print("STARTING PROCESSING...")
664
+ log_print(f"Text: {text[:50]}...")
665
+ log_print(f"Style: {style}, Speed: {speed}")
666
+ log_print(f"VC params: steps={vc_diffusion_steps}, length={vc_length_adjust}")
667
+ log_print("=" * 50)
668
+
669
+ # Handle Gradio audio input format
670
+ ref_path = None
671
+ if isinstance(reference_audio, tuple):
672
+ sr, data = reference_audio
673
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
674
+ sf.write(f.name, data, sr)
675
+ ref_path = f.name
676
+ elif isinstance(reference_audio, str):
677
+ ref_path = reference_audio
678
+
679
+ if not ref_path:
680
+ log_error("ERROR: No reference audio provided")
681
+ raise gr.Error("Please provide a reference audio.")
682
+
683
+ try:
684
+ # Step 1: OpenVoice TTS + Voice Cloning
685
+ log_print("Step 1: Running OpenVoice TTS...")
686
+ intermediate_audio = run_openvoice_inference(text, style, speed, ref_path)
687
+ log_print(f"✓ OpenVoice completed. Intermediate audio: {intermediate_audio}")
688
+
689
+ # Step 2: Seed-VC Voice Conversion
690
+ log_print("Step 2: Running Seed-VC Voice Conversion...")
691
+ # Call the actual voice conversion function and collect all results
692
+ results = list(run_seed_vc_inference(intermediate_audio, ref_path, vc_diffusion_steps, vc_length_adjust,
693
+ vc_inference_cfg_rate, vc_f0_condition, vc_auto_f0_adjust, vc_pitch_shift))
694
+ log_print(f"✓ Seed-VC completed. Results count: {len(results)}")
695
+
696
+ except Exception as e:
697
+ log_error(f"CRITICAL ERROR in processing: {str(e)}")
698
+ log_error(f"Error type: {type(e).__name__}")
699
+ log_error("Full traceback:")
700
+ log_error(traceback.format_exc())
701
+ # Re-raise the error to see it in Gradio
702
+ raise
703
+
704
+ # Find the final result (the one with the complete audio data)
705
+ final_result = None
706
+ for result in results:
707
+ if isinstance(result, tuple) and len(result) == 2 and result[1] is not None:
708
+ # This is the final result with complete audio data
709
+ final_result = result[1]
710
+ break
711
+
712
+ if final_result is not None:
713
+ # Save the final audio to a temporary file
714
+ sr, audio_data = final_result
715
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
716
+ sf.write(f.name, audio_data, sr)
717
+ return f.name
718
+
719
+ return None
720
+
721
+ # Get supported styles
722
+ styles = list_supported_styles() or [
723
+ 'en-newest', 'en-default', 'en-us', 'en-br', 'en-au', 'en-india',
724
+ 'es', 'fr', 'zh', 'jp', 'kr'
725
+ ]
726
+
727
+ # Skip model pre-loading for faster startup
728
+ log_print("=" * 50)
729
+ log_print("SKIPPING MODEL PRE-LOADING FOR FASTER STARTUP")
730
+ log_print("Models will be loaded on first use")
731
+ log_print("=" * 50)
732
+
733
+ # Create Gradio interface
734
+ with gr.Blocks(title="Integrated TTS + Voice Conversion", analytics_enabled=False) as demo:
735
+ gr.Markdown("""
736
+ # **Integrated TTS + Voice Conversion** — 텍스트를 음성으로 변환 후 음성 변환
737
+
738
+ 텍스트를 입력하고 참조 음성을 업로드하면, 먼저 텍스트가 음성으로 변환된 후 참조 음성의 스타일로 변환됩니다.
739
+
740
+ **사용법:**
741
+ 1. 변환할 텍스트를 입력하세요
742
+ 2. 참조 음성을 업로드하세요 (3-10초 권장)
743
+ 3. 기본 음성 스타일과 속도를 선택하세요
744
+ 4. 음성 변환 파라미터를 조정하세요
745
+ 5. "통합 변환" 버튼을 클릭하세요
746
+ """)
747
+
748
+ with gr.Row():
749
+ with gr.Column(scale=6):
750
+ # TTS Parameters
751
+ gr.Markdown("### 🎤 텍스트-음성 변환 설정")
752
+ text_input = gr.Textbox(
753
+ label="변환할 텍스트",
754
+ value="안녕하세요! 이것은 통합 TTS와 음성 변환 데모입니다.",
755
+ lines=3
756
+ )
757
+ style_input = gr.Dropdown(
758
+ label="기본 음성 스타일",
759
+ choices=styles,
760
+ value=styles[0]
761
+ )
762
+ speed_input = gr.Slider(
763
+ 0.6, 1.4, value=1.0, step=0.05,
764
+ label="음성 속도 (×)"
765
+ )
766
+ reference_audio_input = gr.Audio(
767
+ label="참조 음성",
768
+ sources=["upload", "microphone"],
769
+ type="filepath"
770
+ )
771
+
772
+ # Voice Conversion Parameters
773
+ gr.Markdown("### 🔄 음성 변환 설정")
774
+ with gr.Row():
775
+ vc_diffusion_steps = gr.Slider(
776
+ minimum=1, maximum=200, value=25, step=1,
777
+ label="확산 단계",
778
+ info="25 기본값, 50~100 최고 품질"
779
+ )
780
+ vc_length_adjust = gr.Slider(
781
+ minimum=0.5, maximum=2.0, step=0.1, value=1.0,
782
+ label="길이 조정",
783
+ info="<1.0 빠르게, >1.0 느리게"
784
+ )
785
+
786
+ with gr.Row():
787
+ vc_inference_cfg_rate = gr.Slider(
788
+ minimum=0.0, maximum=1.0, step=0.1, value=0.7,
789
+ label="CFG 비율",
790
+ info="미묘한 영향"
791
+ )
792
+ vc_pitch_shift = gr.Slider(
793
+ minimum=-24, maximum=24, step=1, value=0,
794
+ label="피치 시프트",
795
+ info="반음 단위"
796
+ )
797
+
798
+ with gr.Row():
799
+ vc_f0_condition = gr.Checkbox(
800
+ label="F0 조건부 모델 사용",
801
+ value=False,
802
+ info="노래 음성 변환에 필요"
803
+ )
804
+ vc_auto_f0_adjust = gr.Checkbox(
805
+ label="자동 F0 조정",
806
+ value=True,
807
+ info="대상 음성에 맞게 F0 조정"
808
+ )
809
+
810
+ convert_btn = gr.Button("통합 변환", variant="primary", size="lg")
811
+
812
+ with gr.Column(scale=6):
813
+ output_audio = gr.Audio(
814
+ label="최종 변환된 음성",
815
+ autoplay=True,
816
+ format="wav"
817
+ )
818
+
819
+ gr.Markdown("""
820
+ ### 📋 처리 과정:
821
+ 1. **텍스트 → 음성**: 입력된 텍스트가 참조 음성의 톤으로 변환됩니다
822
+ 2. **음성 변환**: 생성된 음성이 참조 음성의 스타일로 최종 변환됩니다
823
+
824
+ ### 💡 팁:
825
+ - 참조 음성은 3-10초 길이의 깨끗한 음성을 사용하세요
826
+ - 노래 음성 변환을 원한다면 "F0 조건부 모델 사용"을 체크하세요
827
+ - 품질을 높이려면 확산 단계를 50-100으로 설정하세요
828
+ """)
829
+
830
+ # Connect the button click to the processing function
831
+ convert_btn.click(
832
+ fn=process_integrated_tts_vc,
833
+ inputs=[
834
+ text_input, style_input, speed_input, reference_audio_input,
835
+ vc_diffusion_steps, vc_length_adjust, vc_inference_cfg_rate,
836
+ vc_f0_condition, vc_auto_f0_adjust, vc_pitch_shift
837
+ ],
838
+ outputs=[output_audio],
839
+ concurrency_limit=1
840
+ )
841
+
842
+ demo.queue()
843
+ demo.launch()
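
The button callback wires its inputs straight into `process_integrated_tts_vc`, so the same two-stage pipeline can be exercised without the UI. A minimal sketch of such a call (e.g. pasted just above `demo.launch()` for debugging); `reference.wav` is a hypothetical 3-10 second clean reference clip, the `@spaces.GPU` wrapper assumes a Spaces-style environment, and the heavy models still download and load on first use:

```python
# Sketch: drive the TTS + voice conversion pipeline directly, bypassing the Gradio UI.
out_path = process_integrated_tts_vc(
    text="Hello, this is a quick end-to-end test.",
    style=styles[0],                  # first available base speaker style (e.g. 'en-newest')
    speed=1.0,
    reference_audio="reference.wav",  # hypothetical 3-10 s clean reference clip
    vc_diffusion_steps=25,
    vc_length_adjust=1.0,
    vc_inference_cfg_rate=0.7,
    vc_f0_condition=False,
    vc_auto_f0_adjust=True,
    vc_pitch_shift=0,
)
print(out_path)                       # path to the final converted wav, or None if nothing was produced
```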
hf_utils.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+ from huggingface_hub import hf_hub_download
3
+
4
+
5
+ def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
6
+ os.makedirs("./checkpoints", exist_ok=True)
7
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
8
+ if config_filename is None:
9
+ return model_path
10
+ config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints")
11
+
12
+ return model_path, config_path
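
The helper mirrors the calls made in app.py above: with a config filename it returns a (model, config) path pair, and with `config_filename=None` it returns only the model path, caching everything under `./checkpoints`. A quick sketch using the same repos and filenames referenced in app.py:

```python
# Download the Seed-VC DiT checkpoint together with its config (same call as in app.py)
ckpt_path, cfg_path = load_custom_model_from_hf(
    "Plachta/Seed-VC",
    "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
    "config_dit_mel_seed_uvit_whisper_small_wavenet.yml",
)

# With config_filename=None only the model path comes back (e.g. the RMVPE pitch extractor)
rmvpe_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
print(ckpt_path, cfg_path, rmvpe_path)
```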
modules/__pycache__/audio.cpython-310.pyc ADDED
Binary file (2.5 kB). View file
 
modules/__pycache__/commons.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
modules/__pycache__/commons.cpython-38.pyc ADDED
Binary file (14.2 kB). View file
 
modules/__pycache__/diffusion_transformer.cpython-310.pyc ADDED
Binary file (17.4 kB). View file
 
modules/__pycache__/encodec.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
modules/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (5.41 kB). View file
 
modules/__pycache__/length_regulator.cpython-310.pyc ADDED
Binary file (4.23 kB). View file
 
modules/__pycache__/rmvpe.cpython-310.pyc ADDED
Binary file (17.6 kB). View file
 
modules/__pycache__/wavenet.cpython-310.pyc ADDED
Binary file (5.15 kB). View file
 
modules/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
modules/alias_free_torch/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
modules/alias_free_torch/__pycache__/act.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
modules/alias_free_torch/__pycache__/filter.cpython-310.pyc ADDED
Binary file (2.61 kB). View file
 
modules/alias_free_torch/__pycache__/resample.cpython-310.pyc ADDED
Binary file (1.89 kB). View file
 
modules/alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
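
Activation1d exists to keep pointwise nonlinearities alias-free: the signal is upsampled, activated, then lowpass-filtered and downsampled. A small usage sketch, assuming the package imports as `modules.alias_free_torch` (as BigVGAN in this repo uses it) and that the matched up/down ratios in resample.py preserve the time axis:

```python
import torch
import torch.nn as nn
from modules.alias_free_torch.act import Activation1d

act = Activation1d(activation=nn.SiLU())   # 2x upsample -> SiLU -> 2x downsample (the defaults above)
x = torch.randn(1, 8, 256)                 # [B, C, T]
y = act(x)
print(y.shape)                             # expected to stay (1, 8, 256): matched ratios preserve T
```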
modules/alias_free_torch/filter.py ADDED
@@ -0,0 +1,96 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ if "sinc" in dir(torch):
9
+ sinc = torch.sinc
10
+ else:
11
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
12
+ # https://adefossez.github.io/julius/julius/core.html
13
+ def sinc(x: torch.Tensor):
14
+ """
15
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
16
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
17
+ """
18
+ return torch.where(
19
+ x == 0,
20
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
21
+ torch.sin(math.pi * x) / math.pi / x,
22
+ )
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ def kaiser_sinc_filter1d(
28
+ cutoff, half_width, kernel_size
29
+ ): # return filter [1,1,kernel_size]
30
+ even = kernel_size % 2 == 0
31
+ half_size = kernel_size // 2
32
+
33
+ # For kaiser window
34
+ delta_f = 4 * half_width
35
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36
+ if A > 50.0:
37
+ beta = 0.1102 * (A - 8.7)
38
+ elif A >= 21.0:
39
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40
+ else:
41
+ beta = 0.0
42
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43
+
44
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45
+ if even:
46
+ time = torch.arange(-half_size, half_size) + 0.5
47
+ else:
48
+ time = torch.arange(kernel_size) - half_size
49
+ if cutoff == 0:
50
+ filter_ = torch.zeros_like(time)
51
+ else:
52
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
54
+ # of the constant component in the input signal.
55
+ filter_ /= filter_.sum()
56
+ filter = filter_.view(1, 1, kernel_size)
57
+
58
+ return filter
59
+
60
+
61
+ class LowPassFilter1d(nn.Module):
62
+ def __init__(
63
+ self,
64
+ cutoff=0.5,
65
+ half_width=0.6,
66
+ stride: int = 1,
67
+ padding: bool = True,
68
+ padding_mode: str = "replicate",
69
+ kernel_size: int = 12,
70
+ ):
71
+ # kernel_size should be even number for stylegan3 setup,
72
+ # in this implementation, odd number is also possible.
73
+ super().__init__()
74
+ if cutoff < -0.0:
75
+ raise ValueError("Minimum cutoff must be larger than zero.")
76
+ if cutoff > 0.5:
77
+ raise ValueError("A cutoff above 0.5 does not make sense.")
78
+ self.kernel_size = kernel_size
79
+ self.even = kernel_size % 2 == 0
80
+ self.pad_left = kernel_size // 2 - int(self.even)
81
+ self.pad_right = kernel_size // 2
82
+ self.stride = stride
83
+ self.padding = padding
84
+ self.padding_mode = padding_mode
85
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86
+ self.register_buffer("filter", filter)
87
+
88
+ # input [B, C, T]
89
+ def forward(self, x):
90
+ _, C, _ = x.shape
91
+
92
+ if self.padding:
93
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95
+
96
+ return out
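
A small illustrative sketch (the cutoff, half_width, and stride values are arbitrary): the cutoff is expressed as a fraction of the sampling rate, so 0.25 keeps roughly the lower half of the band, and stride=2 decimates the filtered signal.

import torch
from modules.alias_free_torch.filter import LowPassFilter1d

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(2, 4, 1000)   # [B, C, T]
y = lpf(x)                    # [2, 4, 500] after filtering and decimation
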
modules/alias_free_torch/resample.py ADDED
@@ -0,0 +1,57 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from .filter import LowPassFilter1d
+ from .filter import kaiser_sinc_filter1d
+
+
+ class UpSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.stride = ratio
+         self.pad = self.kernel_size // ratio - 1
+         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+         self.pad_right = (
+             self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+         )
+         filter = kaiser_sinc_filter1d(
+             cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+         )
+         self.register_buffer("filter", filter)
+
+     # x: [B, C, T]
+     def forward(self, x):
+         _, C, _ = x.shape
+
+         x = F.pad(x, (self.pad, self.pad), mode="replicate")
+         x = self.ratio * F.conv_transpose1d(
+             x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+         )
+         x = x[..., self.pad_left : -self.pad_right]
+
+         return x
+
+
+ class DownSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.lowpass = LowPassFilter1d(
+             cutoff=0.5 / ratio,
+             half_width=0.6 / ratio,
+             stride=ratio,
+             kernel_size=self.kernel_size,
+         )
+
+     def forward(self, x):
+         xx = self.lowpass(x)
+
+         return xx
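
An illustrative round trip with the two resamplers (shapes are examples): UpSample1d produces ratio*T samples and DownSample1d brings the signal back to T after anti-alias filtering.

import torch
from modules.alias_free_torch.resample import UpSample1d, DownSample1d

up = UpSample1d(ratio=2)
down = DownSample1d(ratio=2)

x = torch.randn(1, 4, 128)    # [B, C, T]
x_up = up(x)                  # [1, 4, 256]
x_rt = down(x_up)             # [1, 4, 128], low-pass filtered, not bit-exact
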
modules/astral_quantization/__pycache__/bsq.cpython-310.pyc ADDED
Binary file (12.7 kB). View file
 
modules/astral_quantization/__pycache__/convnext.cpython-310.pyc ADDED
Binary file (6.89 kB). View file
 
modules/astral_quantization/__pycache__/default_model.cpython-310.pyc ADDED
Binary file (2.82 kB). View file
 
modules/astral_quantization/bsq.py ADDED
@@ -0,0 +1,569 @@
1
+ """
2
+ Lookup Free Quantization
3
+ Proposed in https://arxiv.org/abs/2310.05737
4
+
5
+ In the simplest setup, each dimension is quantized into {-1, 1}.
6
+ An entropy penalty is used to encourage utilization.
7
+ """
8
+
9
+ from math import log2, ceil
10
+ from functools import partial, cache
11
+ from collections import namedtuple
12
+ from contextlib import nullcontext
13
+
14
+ import torch.distributed as dist
15
+ from torch.distributed import nn as dist_nn
16
+
17
+ import torch
18
+ from torch import nn, einsum
19
+ import torch.nn.functional as F
20
+ from torch.nn import Module
21
+ from torch.amp import autocast
22
+
23
+ from einops import rearrange, reduce, pack, unpack
24
+
25
+ # constants
26
+
27
+ Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
28
+
29
+ LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
30
+
31
+ # distributed helpers
32
+
33
+ @cache
34
+ def is_distributed():
35
+ return dist.is_initialized() and dist.get_world_size() > 1
36
+
37
+ def maybe_distributed_mean(t):
38
+ if not is_distributed():
39
+ return t
40
+
41
+ dist_nn.all_reduce(t)
42
+ t = t / dist.get_world_size()
43
+ return t
44
+
45
+ # helper functions
46
+
47
+ def exists(v):
48
+ return v is not None
49
+
50
+ def identity(t):
51
+ return t
52
+
53
+ def default(*args):
54
+ for arg in args:
55
+ if exists(arg):
56
+ return arg() if callable(arg) else arg
57
+ return None
58
+
59
+ def pack_one(t, pattern):
60
+ return pack([t], pattern)
61
+
62
+ def unpack_one(t, ps, pattern):
63
+ return unpack(t, ps, pattern)[0]
64
+
65
+ def l2norm(t):
66
+ return F.normalize(t, dim = -1)
67
+
68
+ # entropy
69
+
70
+ def log(t, eps = 1e-5):
71
+ return t.clamp(min = eps).log()
72
+
73
+ def entropy(prob):
74
+ return (-prob * log(prob)).sum(dim=-1)
75
+
76
+ # cosine sim linear
77
+
78
+ class CosineSimLinear(Module):
79
+ def __init__(
80
+ self,
81
+ dim_in,
82
+ dim_out,
83
+ scale = 1.
84
+ ):
85
+ super().__init__()
86
+ self.scale = scale
87
+ self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
88
+
89
+ def forward(self, x):
90
+ x = F.normalize(x, dim = -1)
91
+ w = F.normalize(self.weight, dim = 0)
92
+ return (x @ w) * self.scale
93
+
94
+ def soft_entropy_loss(u, tau=1.0, gamma=1.0):
95
+ """
96
+ Compute the soft entropy loss for Binary Spherical Quantization (BSQ).
97
+
98
+ Args:
99
+ u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
100
+ tau (float): Temperature scaling factor.
101
+ gamma (float): Weight for the second entropy term.
102
+
103
+ Returns:
104
+ torch.Tensor: Soft entropy loss.
105
+ """
106
+ # Binary quantization: Generate implicit codebook corners
107
+ L = u.size(1) # Dimensionality of codebook
108
+ corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5)
109
+
110
+ # Compute soft quantization probabilities for all dimensions
111
+ # q_hat(c|u) for each dimension
112
+ prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2)) # Shape: (batch_size, L, 2)
113
+
114
+ # Entropy of q_hat(c|u) (independent along each dimension)
115
+ entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1) # Shape: (batch_size, L)
116
+ entropy_term1 = entropy_per_dim.mean()
117
+
118
+ # Expected probabilities for dataset entropy (approximation)
119
+ expected_probs = prob_matrix.mean(dim=0) # Mean across batch, shape: (L, 2)
120
+ entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean()
121
+
122
+ # Final entropy loss
123
+ loss = entropy_term1 - gamma * entropy_term2
124
+ return loss
125
+
126
+ # class
127
+
128
+ class BinarySphericalQuantize(Module):
129
+ def __init__(
130
+ self,
131
+ *,
132
+ dim = None,
133
+ codebook_size = None,
134
+ entropy_loss_weight = 0.1,
135
+ commitment_loss_weight = 0.,
136
+ diversity_gamma = 1.,
137
+ straight_through_activation = nn.Identity(),
138
+ num_codebooks = 1,
139
+ keep_num_codebooks_dim = None,
140
+ codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
141
+ frac_per_sample_entropy = 0.25, # make less than 1. to only use a random fraction of the probs for per sample entropy
142
+ has_projections = None,
143
+ projection_has_bias = True,
144
+ soft_clamp_input_value = None,
145
+ cosine_sim_project_in = False,
146
+ cosine_sim_project_in_scale = None,
147
+ channel_first = None,
148
+ experimental_softplus_entropy_loss = False,
149
+ entropy_loss_offset = 5., # how much to shift the loss before softplus
150
+ spherical = True, # from https://arxiv.org/abs/2406.07548
151
+ force_quantization_f32 = True, # will force the quantization step to be full precision
152
+ enable_entropy_loss = True,
153
+ soft_entropy_loss = True,
154
+ ):
155
+ super().__init__()
156
+
157
+ # some assert validations
158
+
159
+ assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
160
+ assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
161
+
162
+ codebook_size = default(codebook_size, lambda: 2 ** dim)
163
+ self.codebook_size = codebook_size
164
+
165
+ codebook_dim = int(log2(codebook_size))
166
+ codebook_dims = codebook_dim * num_codebooks
167
+ dim = default(dim, codebook_dims)
168
+
169
+ has_projections = default(has_projections, dim != codebook_dims)
170
+
171
+ if cosine_sim_project_in:
172
+ cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
173
+ project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
174
+ else:
175
+ project_in_klass = partial(nn.Linear, bias = projection_has_bias)
176
+
177
+ self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
178
+ self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
179
+ self.has_projections = has_projections
180
+
181
+ self.dim = dim
182
+ self.codebook_dim = codebook_dim
183
+ self.num_codebooks = num_codebooks
184
+
185
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
186
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
187
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
188
+
189
+ # channel first
190
+
191
+ self.channel_first = channel_first
192
+
193
+ # straight through activation
194
+
195
+ self.activation = straight_through_activation
196
+
197
+ # whether to use BSQ (binary spherical quantization)
198
+
199
+ self.spherical = spherical
200
+ self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
201
+
202
+ # entropy aux loss related weights
203
+
204
+ assert 0 < frac_per_sample_entropy <= 1.
205
+ self.frac_per_sample_entropy = frac_per_sample_entropy
206
+
207
+ self.diversity_gamma = diversity_gamma
208
+ self.entropy_loss_weight = entropy_loss_weight
209
+
210
+ # codebook scale
211
+
212
+ self.codebook_scale = codebook_scale
213
+
214
+ # commitment loss
215
+
216
+ self.commitment_loss_weight = commitment_loss_weight
217
+
218
+ # whether to soft clamp the input value from -value to value
219
+
220
+ self.soft_clamp_input_value = soft_clamp_input_value
221
+ assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
222
+
223
+ # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
224
+
225
+ self.entropy_loss_offset = entropy_loss_offset
226
+ self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
227
+
228
+ # for no auxiliary loss, during inference
229
+
230
+ self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
231
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
232
+
233
+ # whether to force quantization step to be f32
234
+
235
+ self.force_quantization_f32 = force_quantization_f32
236
+
237
+ # codes
238
+ self.enable_entropy_loss = enable_entropy_loss
239
+ self.soft_entropy_loss = soft_entropy_loss
240
+ if codebook_size <= 100000:
241
+ all_codes = torch.arange(codebook_size)
242
+ bits = ((all_codes[..., None].int() & self.mask) != 0).float()
243
+ codebook = self.bits_to_codes(bits)
244
+
245
+ self.register_buffer('codebook', codebook.float(), persistent = False)
246
+ else:
247
+ all_codes = torch.arange(pow(2, 16))
248
+ mask = 2 ** torch.arange(16 - 1, -1, -1)
249
+ bits = ((all_codes[..., None].int() & mask) != 0).float()
250
+ codebook = self.bits_to_codes(bits)
251
+
252
+ self.register_buffer('codebook', codebook.float(), persistent = False)
253
+
254
+ def bits_to_codes(self, bits):
255
+ return bits * self.codebook_scale * 2 - self.codebook_scale
256
+
257
+ @property
258
+ def dtype(self):
259
+ return self.codebook.dtype
260
+
261
+ def indices_to_codes(
262
+ self,
263
+ indices,
264
+ project_out = True
265
+ ):
266
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
267
+ should_transpose = default(self.channel_first, is_img_or_video)
268
+
269
+ if not self.keep_num_codebooks_dim:
270
+ indices = rearrange(indices, '... -> ... 1')
271
+
272
+ # indices to codes, which are bits of either -1 or 1
273
+
274
+ bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
275
+
276
+ codes = self.bits_to_codes(bits)
277
+
278
+ codes = self.maybe_l2norm(codes)
279
+
280
+ codes = rearrange(codes, '... c d -> ... (c d)')
281
+
282
+ # whether to project codes out to original dimensions
283
+ # if the input feature dimensions were not log2(codebook size)
284
+
285
+ if project_out:
286
+ codes = self.project_out(codes)
287
+
288
+ # rearrange codes back to original shape
289
+
290
+ if should_transpose:
291
+ codes = rearrange(codes, 'b ... d -> b d ...')
292
+
293
+ return codes
294
+
295
+ def bits_to_z(self, bits):
296
+ # assert bits must contain only -1 and 1
297
+ assert torch.all(bits.abs() == 1)
298
+ quantized = bits.float()
299
+ quantized = self.maybe_l2norm(quantized)
300
+ z = self.project_out(quantized)
301
+ return z
302
+
303
+ def forward(
304
+ self,
305
+ x,
306
+ inv_temperature = 100.,
307
+ return_loss_breakdown = False,
308
+ mask = None,
309
+ return_bits = False
310
+ ):
311
+ """
312
+ einstein notation
313
+ b - batch
314
+ n - sequence (or flattened spatial dimensions)
315
+ d - feature dimension, which is also log2(codebook size)
316
+ c - number of codebook dim
317
+ """
318
+
319
+ is_img_or_video = x.ndim >= 4
320
+ should_transpose = default(self.channel_first, is_img_or_video)
321
+
322
+ # standardize image or video into (batch, seq, dimension)
323
+
324
+ if should_transpose:
325
+ x = rearrange(x, 'b d ... -> b ... d')
326
+ x, ps = pack_one(x, 'b * d')
327
+
328
+ assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
329
+
330
+ x = self.project_in(x)
331
+
332
+ # maybe soft clamp
333
+
334
+ if exists(self.soft_clamp_input_value):
335
+ clamp_value = self.soft_clamp_input_value
336
+ x = (x / clamp_value).tanh() * clamp_value
337
+
338
+ # split out number of codebooks
339
+
340
+ x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
341
+
342
+ # maybe l2norm
343
+
344
+ x = self.maybe_l2norm(x)
345
+
346
+ # whether to force quantization step to be full precision or not
347
+
348
+ force_f32 = self.force_quantization_f32
349
+
350
+ quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
351
+
352
+ with quantization_context():
353
+
354
+ if force_f32:
355
+ orig_dtype = x.dtype
356
+ x = x.float()
357
+
358
+ # quantize by eq 3.
359
+
360
+ original_input = x
361
+
362
+ codebook_value = torch.ones_like(x) * self.codebook_scale
363
+ quantized = torch.where(x > 0, codebook_value, -codebook_value)
364
+ if return_bits:
365
+ return quantized
366
+
367
+ # calculate indices
368
+
369
+ indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
370
+
371
+ # maybe l2norm
372
+
373
+ quantized = self.maybe_l2norm(quantized)
374
+
375
+ # use straight-through gradients (optionally with custom activation fn) if training
376
+
377
+ if self.training:
378
+ x = self.activation(x)
379
+ x = x + (quantized - x).detach()
380
+ else:
381
+ x = quantized
382
+
383
+ # entropy aux loss
384
+ if self.soft_entropy_loss:
385
+ entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
386
+ elif self.training and self.enable_entropy_loss:
387
+
388
+ if force_f32:
389
+ codebook = self.codebook.float()
390
+
391
+ codebook = self.maybe_l2norm(codebook)
392
+
393
+ # whether to only use a fraction of probs, for reducing memory
394
+
395
+ if self.frac_per_sample_entropy < 1.:
396
+ # account for mask
397
+ if exists(mask):
398
+ original_input = original_input[mask]
399
+ original_input = rearrange(original_input, 'b n ... -> (b n) ...')
400
+
401
+ rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16
402
+
403
+ sampled_input = original_input[..., rand_mask]
404
+
405
+ sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
406
+
407
+ sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
408
+
409
+ per_sample_probs = sampled_prob
410
+ else:
411
+ if exists(mask):
412
+ original_input = original_input[mask]
413
+ original_input = rearrange(original_input, 'b n ... -> (b n) ...')
414
+ # the same as euclidean distance up to a constant
415
+ distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
416
+
417
+ prob = (-distance * inv_temperature).softmax(dim = -1)
418
+
419
+ per_sample_probs = prob
420
+
421
+ # calculate per sample entropy
422
+
423
+ per_sample_entropy = entropy(per_sample_probs).mean()
424
+
425
+ # distribution over all available tokens in the batch
426
+
427
+ avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
428
+
429
+ avg_prob = maybe_distributed_mean(avg_prob)
430
+
431
+ codebook_entropy = entropy(avg_prob).mean()
432
+
433
+ # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
434
+ # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
435
+
436
+ entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
437
+ else:
438
+ # if not training, just return dummy 0
439
+ entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
440
+
441
+ # whether to make the entropy loss positive or not through a (shifted) softplus
442
+
443
+ if self.training and self.experimental_softplus_entropy_loss:
444
+ entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
445
+
446
+ # commit loss
447
+
448
+ if self.training and self.commitment_loss_weight > 0.:
449
+
450
+ commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
451
+
452
+ if exists(mask):
453
+ commit_loss = commit_loss[mask]
454
+
455
+ commit_loss = commit_loss.mean()
456
+ else:
457
+ commit_loss = self.zero
458
+
459
+ # input back to original dtype if needed
460
+
461
+ if force_f32:
462
+ x = x.type(orig_dtype)
463
+
464
+ # merge back codebook dim
465
+
466
+ x = rearrange(x, 'b n c d -> b n (c d)')
467
+
468
+ # project out to feature dimension if needed
469
+
470
+ x = self.project_out(x)
471
+
472
+ # reconstitute image or video dimensions
473
+
474
+ if should_transpose:
475
+ x = unpack_one(x, ps, 'b * d')
476
+ x = rearrange(x, 'b ... d -> b d ...')
477
+
478
+ indices = unpack_one(indices, ps, 'b * c')
479
+
480
+ # whether to remove single codebook dim
481
+
482
+ if not self.keep_num_codebooks_dim:
483
+ indices = rearrange(indices, '... 1 -> ...')
484
+
485
+ # complete aux loss
486
+
487
+ aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
488
+
489
+ # returns
490
+
491
+ ret = Return(x, indices, aux_loss)
492
+
493
+ if not return_loss_breakdown:
494
+ return ret
495
+
496
+ return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
497
+
498
+ class GroupedResidualBSQ(Module):
499
+ def __init__(
500
+ self,
501
+ *,
502
+ dim,
503
+ groups = 1,
504
+ accept_image_fmap = False,
505
+ **kwargs
506
+ ):
507
+ super().__init__()
508
+ self.dim = dim
509
+ self.groups = groups
510
+ assert (dim % groups) == 0
511
+ dim_per_group = dim // groups
512
+
513
+ self.accept_image_fmap = accept_image_fmap
514
+
515
+ self.rvqs = nn.ModuleList([])
516
+
517
+ for _ in range(groups):
518
+ self.rvqs.append(BinarySphericalQuantize(  # was LFQ, which is undefined in this module
519
+ dim = dim_per_group,
520
+ **kwargs
521
+ ))
522
+
523
+ self.codebook_size = self.rvqs[0].codebook_size
524
+
525
+ @property
526
+ def codebooks(self):
527
+ return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
528
+
529
+ @property
530
+ def split_dim(self):
531
+ return 1 if self.accept_image_fmap else -1
532
+
533
+ def get_codes_from_indices(self, indices):
534
+ codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
535
+ return torch.stack(codes)
536
+
537
+ def get_output_from_indices(self, indices):
538
+ outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
539
+ return torch.cat(outputs, dim = self.split_dim)
540
+
541
+ def forward(
542
+ self,
543
+ x,
544
+ return_all_codes = False
545
+ ):
546
+ shape, split_dim = x.shape, self.split_dim
547
+ assert shape[split_dim] == self.dim
548
+
549
+ # split the feature dimension into groups
550
+
551
+ x = x.chunk(self.groups, dim = split_dim)
552
+
553
+ forward_kwargs = dict(
554
+ )
555
+
556
+ # invoke residual vq on each group
557
+
558
+ out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
559
+ out = tuple(zip(*out))
560
+
561
+ # otherwise, get all the zipped outputs and combine them
562
+
563
+ quantized, all_indices, *maybe_aux_loss = out
564
+
565
+ quantized = torch.cat(quantized, dim = split_dim)
566
+ all_indices = torch.stack(all_indices)
567
+
568
+ ret = (quantized, all_indices, *maybe_aux_loss)
569
+ return ret
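
A minimal sketch of quantizing a feature sequence with BinarySphericalQuantize (the feature dimension and codebook size are illustrative, not the values this app uses):

import torch
from modules.astral_quantization.bsq import BinarySphericalQuantize

quantizer = BinarySphericalQuantize(dim=512, codebook_size=4096)  # 2**12 implicit codes
x = torch.randn(2, 100, 512)               # [batch, seq, feature]
quantized, indices, aux_loss = quantizer(x)
# quantized: [2, 100, 512], indices: [2, 100], aux_loss: scalar entropy term
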
modules/astral_quantization/convnext.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import List
5
+
6
+
7
+ class ConvNextV2LayerNorm(nn.Module):
8
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
9
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
10
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
11
+ """
12
+
13
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
14
+ super().__init__()
15
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
16
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
17
+ self.eps = eps
18
+ self.data_format = data_format
19
+ if self.data_format not in ["channels_last", "channels_first"]:
20
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
21
+ self.normalized_shape = (normalized_shape,)
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+ if self.data_format == "channels_last":
25
+ x = torch.nn.functional.layer_norm(
26
+ x, self.normalized_shape, self.weight, self.bias, self.eps
27
+ )
28
+ elif self.data_format == "channels_first":
29
+ input_dtype = x.dtype
30
+ x = x.float()
31
+ u = x.mean(1, keepdim=True)
32
+ s = (x - u).pow(2).mean(1, keepdim=True)
33
+ x = (x - u) / torch.sqrt(s + self.eps)
34
+ x = x.to(dtype=input_dtype)
35
+ x = self.weight[None, :, None] * x + self.bias[None, :, None]
36
+ return x
37
+
38
+
39
+ class GRN(nn.Module):
40
+ def __init__(self, dim):
41
+ super().__init__()
42
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
43
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
44
+
45
+ def forward(self, x):
46
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
47
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
48
+ return self.gamma * (x * Nx) + self.beta + x
49
+
50
+ class InterpolationLayer(nn.Module):
51
+ def __init__(self, ): # this is a default of 1 / 50 * (44100 / 512) / 4
52
+ super().__init__()
53
+ pass
54
+
55
+ def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
56
+ x = F.interpolate(x, size=target_len, mode='linear')
57
+ return x
58
+
59
+ class ConvNeXtV2Stage(nn.Module):
60
+ def __init__(
61
+ self,
62
+ dim: int = 512,
63
+ intermediate_dim: int = 2048,
64
+ num_blocks: int = 1,
65
+ dilation: int = 1,
66
+ downsample_layer_indices: List[int] = None,
67
+ downsample_factors: List[int] = None,
68
+ upsample_layer_indices: List[int] = None,
69
+ upsample_factors: List[int] = None,
70
+ interpolation_layer_indices: List[int] = None,
71
+ input_dim: int = None,
72
+ output_dim: int = None,
73
+ gin_channels: int = 0,
74
+ ):
75
+ super().__init__()
76
+ # maybe downsample layers
77
+ if downsample_layer_indices is not None:
78
+ assert downsample_factors is not None
79
+ self.downsample_blocks = nn.ModuleList(
80
+ [
81
+ nn.Sequential(
82
+ ConvNextV2LayerNorm(dim, data_format="channels_first"),
83
+ nn.Conv1d(
84
+ dim, dim, kernel_size=downsample_factor, stride=downsample_factor
85
+ ),
86
+ ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
87
+ ]
88
+ )
89
+ self.downsample_layer_indices = downsample_layer_indices
90
+ else:
91
+ self.downsample_blocks = nn.ModuleList()
92
+ self.downsample_layer_indices = []
93
+
94
+ # maybe upsample layers
95
+ if upsample_layer_indices is not None:
96
+ assert upsample_factors is not None
97
+ self.upsample_blocks = nn.ModuleList(
98
+ [
99
+ nn.Sequential(
100
+ ConvNextV2LayerNorm(dim, data_format="channels_first"),
101
+ nn.ConvTranspose1d(
102
+ dim, dim, kernel_size=upsample_factor, stride=upsample_factor
103
+ ),
104
+ ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
105
+ ]
106
+ )
107
+ self.upsample_layer_indices = upsample_layer_indices
108
+ else:
109
+ self.upsample_blocks = nn.ModuleList()
110
+ self.upsample_layer_indices = []
111
+
112
+ # maybe interpolation layers
113
+ if interpolation_layer_indices is not None:
114
+ self.interpolation_blocks = nn.ModuleList(
115
+ [
116
+ InterpolationLayer()
117
+ for _ in interpolation_layer_indices
118
+ ]
119
+ )
120
+ self.interpolation_layer_indices = interpolation_layer_indices
121
+ else:
122
+ self.interpolation_blocks = nn.ModuleList()
123
+ self.interpolation_layer_indices = []
124
+
125
+ # main blocks
126
+ self.blocks = nn.ModuleList(
127
+ [
128
+ ConvNeXtV2Block(
129
+ dim=dim,
130
+ intermediate_dim=intermediate_dim,
131
+ dilation=dilation,
132
+ )
133
+ for _ in range(num_blocks)
134
+ ]
135
+ )
136
+ # maybe input and output projections
137
+ if input_dim is not None and input_dim != dim:
138
+ self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
139
+ else:
140
+ self.input_projection = nn.Identity()
141
+ if output_dim is not None and output_dim != dim:
142
+ self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
143
+ else:
144
+ self.output_projection = nn.Identity()
145
+
146
+ if gin_channels > 0:
147
+ self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)
148
+
149
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
150
+ x = self.input_projection(x) # B, D, T
151
+ if hasattr(self, 'gin'):
152
+ g = kwargs['g']
153
+ x = x + self.gin(g)
154
+ # pad to a multiple of cumprod(downsample_factors)
155
+ if len(self.downsample_blocks) > 0:
156
+ downsample_factor = 1
157
+ for factor in self.downsample_blocks:
158
+ downsample_factor *= factor[1].stride[0]
159
+ pad_len = downsample_factor - x.size(-1) % downsample_factor
160
+ if pad_len > 0:
161
+ x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)
162
+
163
+ # main blocks
164
+ for layer_idx, block in enumerate(self.blocks):
165
+ if layer_idx in self.downsample_layer_indices:
166
+ x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
167
+ if layer_idx in self.upsample_layer_indices:
168
+ x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
169
+ if layer_idx in self.interpolation_layer_indices:
170
+ x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
171
+ x = block(x)
172
+ x = self.output_projection(x)
173
+ return x
174
+
175
+ def setup_caches(self, *args, **kwargs):
176
+ pass
177
+
178
+
179
+ class ConvNeXtV2Block(nn.Module):
180
+ def __init__(
181
+ self,
182
+ dim: int,
183
+ intermediate_dim: int,
184
+ dilation: int = 1,
185
+ ):
186
+ super().__init__()
187
+ padding = (dilation * (7 - 1)) // 2
188
+ self.dwconv = nn.Conv1d(
189
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
190
+ ) # depthwise conv
191
+ self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
192
+ self.pwconv1 = nn.Linear(
193
+ dim, intermediate_dim
194
+ ) # pointwise/1x1 convs, implemented with linear layers
195
+ self.act = nn.GELU()
196
+ self.grn = GRN(intermediate_dim)
197
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
198
+
199
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
200
+ residual = x
201
+ x = self.dwconv(x)
202
+ x = self.norm(x)
203
+ x = x.transpose(1, 2) # b d n -> b n d
204
+ x = self.pwconv1(x)
205
+ x = self.act(x)
206
+ x = self.grn(x)
207
+ x = self.pwconv2(x)
208
+ x = x.transpose(1, 2) # b n d -> b d n
209
+ return residual + x
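
A minimal sketch of running the stage on channels-first features (dimensions are illustrative): with no downsample, upsample, or interpolation indices configured, the blocks leave the temporal length unchanged.

import torch
from modules.astral_quantization.convnext import ConvNeXtV2Stage

stage = ConvNeXtV2Stage(dim=512, intermediate_dim=2048, num_blocks=4,
                        input_dim=768, output_dim=512)
x = torch.randn(2, 768, 100)   # [B, D_in, T]
y = stage(x)                   # [2, 512, 100]
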
modules/astral_quantization/default_model.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor
3
+
4
+ class AstralQuantizer(torch.nn.Module):
5
+ def __init__(
6
+ self,
7
+ tokenizer_name: str,
8
+ ssl_model_name: str,
9
+ ssl_output_layer: int,
10
+ encoder: torch.nn.Module,
11
+ quantizer: torch.nn.Module,
12
+ skip_ssl: bool = False,
13
+ ):
14
+ super().__init__()
15
+ self.encoder = encoder
16
+ self.quantizer = quantizer
17
+ self.tokenizer_name = tokenizer_name
18
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
19
+
20
+ # Load SSL model from Huggingface
21
+ self.ssl_model_name = ssl_model_name
22
+ self.ssl_output_layer = ssl_output_layer
23
+ self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)
24
+
25
+ if skip_ssl: # in case the same SSL model has been loaded somewhere else
26
+ self.ssl_model = None
27
+ else:
28
+ self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
29
+ self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
30
+ self.ssl_model.encoder.layer_norm = torch.nn.Identity()
31
+
32
+ def load_separate_checkpoint(self, checkpoint_path):
33
+ params = torch.load(checkpoint_path, map_location='cpu')['net']
34
+ for key in params.keys():
35
+ for k in list(params[key].keys()):
36
+ if k.startswith("module."):
37
+ params[key][k[len("module."):]] = params[key][k]
38
+ del params[key][k]
39
+ self.encoder.load_state_dict(params['encoder'])
40
+ self.quantizer.load_state_dict(params['vq'])
41
+ if self.decoder is not None:
42
+ self.decoder.load_state_dict(params['decoder'])
43
+ if self.asr_decoder is not None:
44
+ self.asr_decoder.load_state_dict(params['predictor'], strict=False)
45
+
46
+ def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
47
+ ssl_fn = self.ssl_model if self.ssl_model else ssl_model
48
+ assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
49
+ waves_16k_input_list = [
50
+ waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
51
+ for bib in range(len(waves_16k))
52
+ ]
53
+ alt_inputs = self.ssl_feature_extractor(
54
+ waves_16k_input_list,
55
+ return_tensors='pt',
56
+ return_attention_mask=True,
57
+ padding=True,
58
+ sampling_rate=16000
59
+ ).to(waves_16k.device)
60
+ feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320 # frame rate of hubert is 50 Hz
61
+
62
+ outputs = ssl_fn(
63
+ alt_inputs.input_values,
64
+ attention_mask=alt_inputs.attention_mask,
65
+ )
66
+ last_hidden_states = outputs.last_hidden_state
67
+ last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
68
+ feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
69
+ last_hidden_states = last_hidden_states.transpose(1, 2)
70
+ x_hidden = self.encoder(last_hidden_states, feature_lens)
71
+ x_hidden = x_hidden.transpose(1, 2)
72
+ x_quantized, indices = self.quantizer(x_hidden)[:2]
73
+ return x_quantized, indices, feature_lens
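
How the pieces above fit together, as a hypothetical sketch only: the tokenizer and SSL checkpoint names and the module sizes below are placeholders, not necessarily what this Space loads.

import torch
from modules.astral_quantization.convnext import ConvNeXtV2Stage
from modules.astral_quantization.bsq import BinarySphericalQuantize
from modules.astral_quantization.default_model import AstralQuantizer

content_tokenizer = AstralQuantizer(
    tokenizer_name="openai/whisper-small",         # placeholder tokenizer
    ssl_model_name="facebook/hubert-large-ll60k",  # placeholder SSL model
    ssl_output_layer=18,
    encoder=ConvNeXtV2Stage(dim=1024, intermediate_dim=4096,
                            num_blocks=4, input_dim=1024),
    quantizer=BinarySphericalQuantize(dim=1024, codebook_size=65536),
)

waves_16k = torch.randn(1, 32000)                  # 2 s of 16 kHz audio
wave_lens = torch.tensor([waves_16k.size(-1)])
quantized, indices, feature_lens = content_tokenizer(waves_16k, wave_lens)
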
modules/astral_quantization/transformer.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ from torch.nn import functional as F
13
+ import time
14
+
15
+ def find_multiple(n: int, k: int) -> int:
16
+ if n % k == 0:
17
+ return n
18
+ return n + k - (n % k)
19
+
20
+ class AdaptiveLayerNorm(nn.Module):
21
+ r"""Adaptive Layer Normalization"""
22
+
23
+ def __init__(self, d_model, norm) -> None:
24
+ super(AdaptiveLayerNorm, self).__init__()
25
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
26
+ self.norm = norm
27
+ self.d_model = d_model
28
+ self.eps = self.norm.eps
29
+
30
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
31
+ if embedding is None:
32
+ return self.norm(input)
33
+ weight, bias = torch.split(
34
+ self.project_layer(embedding),
35
+ split_size_or_sections=self.d_model,
36
+ dim=-1,
37
+ )
38
+ return weight * self.norm(input) + bias
39
+
40
+
41
+ @dataclass
42
+ class ModelArgs:
43
+ block_size: int = 2048
44
+ vocab_size: int = 32000
45
+ n_layer: int = 32
46
+ n_head: int = 32
47
+ dim: int = 4096
48
+ intermediate_size: int = None
49
+ n_local_heads: int = -1
50
+ head_dim: int = 64
51
+ rope_base: float = 10000
52
+ norm_eps: float = 1e-5
53
+ has_cross_attention: bool = False
54
+ context_dim: int = 0
55
+ is_causal: bool = False
56
+ dropout_rate: float = 0.1
57
+ attn_dropout_rate: float = 0.1
58
+
59
+ def __post_init__(self):
60
+ if self.n_local_heads == -1:
61
+ self.n_local_heads = self.n_head
62
+ if self.intermediate_size is None:
63
+ hidden_dim = 4 * self.dim
64
+ n_hidden = int(2 * hidden_dim / 3)
65
+ self.intermediate_size = find_multiple(n_hidden, 256)
66
+ # self.head_dim = self.dim // self.n_head
67
+
68
+ class Transformer(nn.Module):
69
+ def __init__(self, config: ModelArgs) -> None:
70
+ super().__init__()
71
+ self.config = config
72
+
73
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
74
+ self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
75
+
76
+ self.max_batch_size = -1
77
+ self.max_seq_length = config.block_size
78
+ freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
79
+ self.config.rope_base)
80
+ self.register_buffer("freqs_cis", freqs_cis)
81
+
82
+ causal_mask = torch.tril(
83
+ torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
84
+ )
85
+ self.register_buffer("causal_mask", causal_mask)
86
+
87
+ def forward(self,
88
+ x: Tensor,
89
+ c: Tensor,
90
+ input_pos: Optional[Tensor] = None,
91
+ mask: Optional[Tensor] = None,
92
+ context: Optional[Tensor] = None,
93
+ context_input_pos: Optional[Tensor] = None,
94
+ cross_attention_mask: Optional[Tensor] = None,
95
+ ) -> Tensor:
96
+ if mask is None:
97
+ mask = self.causal_mask[:x.size(1), :x.size(1)]
98
+ else:
99
+ mask = mask[..., input_pos]
100
+ freqs_cis = self.freqs_cis[input_pos]
101
+ if context is not None:
102
+ context_freqs_cis = self.freqs_cis[context_input_pos]
103
+ else:
104
+ context_freqs_cis = None
105
+ skip_in_x_list = []
106
+ for i, layer in enumerate(self.layers):
107
+ x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
108
+ x = self.norm(x, c)
109
+ return x
110
+
111
+
112
+ class TransformerBlock(nn.Module):
113
+ def __init__(self, config: ModelArgs) -> None:
114
+ super().__init__()
115
+ self.attention = Attention(config)
116
+ self.feed_forward = FeedForward(config)
117
+ self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
118
+ self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
119
+
120
+ if config.has_cross_attention:
121
+ self.has_cross_attention = True
122
+ self.cross_attention = Attention(config, is_cross_attention=True)
123
+ self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
124
+ else:
125
+ self.has_cross_attention = False
126
+
127
+ def forward(self,
128
+ x: Tensor,
129
+ c: Tensor,
130
+ freqs_cis: Tensor,
131
+ mask: Tensor,
132
+ context: Optional[Tensor] = None,
133
+ context_freqs_cis: Optional[Tensor] = None,
134
+ cross_attention_mask: Optional[Tensor] = None,
135
+ ) -> Tensor:
136
+ #time_attn_start = time.time()
137
+ h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
138
+ #print(f"time take for attention of sequence length {x.shape[1]} is {time.time() - time_attn_start}")
139
+ if self.has_cross_attention:
140
+ h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis)
141
+ out = h + self.feed_forward(self.ffn_norm(h, c))
142
+ return out
143
+
144
+
145
+ class Attention(nn.Module):
146
+ def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
147
+ super().__init__()
148
+ assert config.dim % config.n_head == 0
149
+
150
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
151
+ # key, query, value projections for all heads, but in a batch
152
+ if is_cross_attention:
153
+ self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
154
+ self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
155
+ else:
156
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
157
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
158
+ self.kv_cache = None
159
+
160
+ self.n_head = config.n_head
161
+ self.head_dim = config.head_dim
162
+ self.n_local_heads = config.n_local_heads
163
+ self.dim = config.dim
164
+ self.attn_dropout_rate = config.attn_dropout_rate
165
+
166
+ def forward(self,
167
+ x: Tensor,
168
+ freqs_cis: Tensor,
169
+ mask: Tensor,
170
+ context: Optional[Tensor] = None,
171
+ context_freqs_cis: Optional[Tensor] = None,
172
+ ) -> Tensor:
173
+ bsz, seqlen, _ = x.shape
174
+
175
+ kv_size = self.n_local_heads * self.head_dim
176
+ if context is None:
177
+ q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
178
+ context_seqlen = seqlen
179
+ else:
180
+ q = self.wq(x)
181
+ k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
182
+ context_seqlen = context.shape[1]
183
+
184
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
185
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
186
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
187
+
188
+ q = apply_rotary_emb(q, freqs_cis)
189
+ k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
190
+
191
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
192
+
193
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
194
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
195
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)
196
+
197
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
198
+
199
+ y = self.wo(y)
200
+ return y
201
+
202
+
203
+ class FeedForward(nn.Module):
204
+ def __init__(self, config: ModelArgs) -> None:
205
+ super().__init__()
206
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
207
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
208
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
209
+ self.dropout = nn.Dropout(config.dropout_rate)
210
+
211
+ def forward(self, x: Tensor) -> Tensor:
212
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
213
+
214
+
215
+ class RMSNorm(nn.Module):
216
+ def __init__(self, dim: int, eps: float = 1e-5):
217
+ super().__init__()
218
+ self.eps = eps
219
+ self.weight = nn.Parameter(torch.ones(dim))
220
+
221
+ def _norm(self, x):
222
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
223
+
224
+ def forward(self, x: Tensor) -> Tensor:
225
+ output = self._norm(x.float()).type_as(x)
226
+ return output * self.weight
227
+
228
+
229
+ def precompute_freqs_cis(
230
+ seq_len: int, n_elem: int, base: int = 10000,
231
+ dtype: torch.dtype = torch.bfloat16
232
+ ) -> Tensor:
233
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
234
+ t = torch.arange(seq_len, device=freqs.device)
235
+ freqs = torch.outer(t, freqs)
236
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
237
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
238
+ return cache.to(dtype=dtype)
239
+
240
+
241
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
242
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
243
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
244
+ x_out2 = torch.stack(
245
+ [
246
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
247
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
248
+ ],
249
+ -1,
250
+ )
251
+
252
+ x_out2 = x_out2.flatten(3)
253
+ return x_out2.type_as(x)
254
+
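
A small, self-contained sketch of driving this transformer (the configuration below is illustrative and far smaller than anything practical): `c` is the conditioning consumed by AdaptiveLayerNorm, `input_pos` selects the rotary embeddings, and a causal mask is applied automatically when no mask is passed.

import torch
from modules.astral_quantization.transformer import ModelArgs, Transformer

cfg = ModelArgs(block_size=256, n_layer=2, n_head=4, n_local_heads=4,
                dim=256, head_dim=64)
model = Transformer(cfg)

x = torch.randn(1, 100, 256)          # [B, T, dim] input sequence
c = torch.randn(1, 100, 256)          # conditioning for the adaptive layer norms
input_pos = torch.arange(100)
y = model(x, c, input_pos=input_pos)  # [1, 100, 256]
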
modules/audio.py ADDED
@@ -0,0 +1,82 @@
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+     if torch.min(y) < -1.0:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.0:
+         print("max value is ", torch.max(y))
+
+     global mel_basis, hann_window  # pylint: disable=global-statement
+     if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+     )
+     y = y.squeeze(1)
+
+     spec = torch.view_as_real(
+         torch.stft(
+             y,
+             n_fft,
+             hop_length=hop_size,
+             win_length=win_size,
+             window=hann_window[str(sampling_rate) + "_" + str(y.device)],
+             center=center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+     spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
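
Typical invocation (these STFT/mel settings are examples; the real values come from the model's config):

import torch
from modules.audio import mel_spectrogram

wav = torch.rand(1, 22050) * 2 - 1   # [B, T] waveform scaled to [-1, 1)
mel = mel_spectrogram(
    wav, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000,
)
print(mel.shape)                     # torch.Size([1, 80, 86])
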
modules/bigvgan/__pycache__/activations.cpython-310.pyc ADDED
Binary file (4.02 kB). View file
 
modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
modules/bigvgan/__pycache__/env.cpython-310.pyc ADDED
Binary file (817 Bytes). View file