aaron committed
Commit 2fd89a2 · 0 Parent(s)
feat: unified TTS + voice conversion app (clean version)
- Unified app combining OpenVoice V2 and Seed-VC
- Fixed the DAC model initialization issue
- Implemented lazy loading for fast startup (see the sketch below)
- Hardened error handling and logging
- Clean commit to work around Git LFS issues
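
Note on the lazy-loading item above: the actual implementation lives in app.py and is not reproduced here; the snippet below is only a minimal sketch of the pattern, and every name in it (get_models, the checkpoint paths) is hypothetical.

import threading

_models = None
_lock = threading.Lock()

def get_models():
    # Load the heavy TTS / voice-conversion models on first use instead of at import
    # time, so the Space process starts quickly.
    global _models
    if _models is None:
        with _lock:
            if _models is None:  # double-checked locking: concurrent first requests load once
                from openvoice.api import ToneColorConverter  # deferred import keeps startup fast
                converter = ToneColorConverter("checkpoints/converter/config.json", device="cpu")
                converter.load_ckpt("checkpoints/converter/checkpoint.pth")
                _models = {"converter": converter}
    return _models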
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +39 -0
- OpenVoice/openvoice/__init__.py +0 -0
- OpenVoice/openvoice/__pycache__/__init__.cpython-312.pyc +0 -0
- OpenVoice/openvoice/__pycache__/se_extractor.cpython-312.pyc +0 -0
- OpenVoice/openvoice/api.py +202 -0
- OpenVoice/openvoice/attentions.py +465 -0
- OpenVoice/openvoice/commons.py +160 -0
- OpenVoice/openvoice/mel_processing.py +183 -0
- OpenVoice/openvoice/models.py +499 -0
- OpenVoice/openvoice/modules.py +598 -0
- OpenVoice/openvoice/openvoice_app.py +275 -0
- OpenVoice/openvoice/se_extractor.py +154 -0
- OpenVoice/openvoice/text/__init__.py +79 -0
- OpenVoice/openvoice/text/cleaners.py +16 -0
- OpenVoice/openvoice/text/english.py +188 -0
- OpenVoice/openvoice/text/mandarin.py +326 -0
- OpenVoice/openvoice/text/symbols.py +88 -0
- OpenVoice/openvoice/transforms.py +209 -0
- OpenVoice/openvoice/utils.py +194 -0
- README.md +119 -0
- app.py +843 -0
- hf_utils.py +12 -0
- modules/__pycache__/audio.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-38.pyc +0 -0
- modules/__pycache__/diffusion_transformer.cpython-310.pyc +0 -0
- modules/__pycache__/encodec.cpython-310.pyc +0 -0
- modules/__pycache__/flow_matching.cpython-310.pyc +0 -0
- modules/__pycache__/length_regulator.cpython-310.pyc +0 -0
- modules/__pycache__/rmvpe.cpython-310.pyc +0 -0
- modules/__pycache__/wavenet.cpython-310.pyc +0 -0
- modules/alias_free_torch/__init__.py +5 -0
- modules/alias_free_torch/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/alias_free_torch/__pycache__/act.cpython-310.pyc +0 -0
- modules/alias_free_torch/__pycache__/filter.cpython-310.pyc +0 -0
- modules/alias_free_torch/__pycache__/resample.cpython-310.pyc +0 -0
- modules/alias_free_torch/act.py +29 -0
- modules/alias_free_torch/filter.py +96 -0
- modules/alias_free_torch/resample.py +57 -0
- modules/astral_quantization/__pycache__/bsq.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/convnext.cpython-310.pyc +0 -0
- modules/astral_quantization/__pycache__/default_model.cpython-310.pyc +0 -0
- modules/astral_quantization/bsq.py +569 -0
- modules/astral_quantization/convnext.py +209 -0
- modules/astral_quantization/default_model.py +73 -0
- modules/astral_quantization/transformer.py +254 -0
- modules/audio.py +82 -0
- modules/bigvgan/__pycache__/activations.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc +0 -0
- modules/bigvgan/__pycache__/env.cpython-310.pyc +0 -0
.gitattributes
ADDED
@@ -0,0 +1,39 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.o filter=lfs diff=lfs merge=lfs -text
*.pyd filter=lfs diff=lfs merge=lfs -text
*.ninja* filter=lfs diff=lfs merge=lfs -text
*.deps filter=lfs diff=lfs merge=lfs -text
OpenVoice/openvoice/__init__.py
ADDED
File without changes
OpenVoice/openvoice/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (160 Bytes). View file
OpenVoice/openvoice/__pycache__/se_extractor.cpython-312.pyc
ADDED
Binary file (6.92 kB). View file
OpenVoice/openvoice/api.py
ADDED
@@ -0,0 +1,202 @@
import torch
import numpy as np
import re
import soundfile
from openvoice import utils
from openvoice import commons
import os
import librosa
from openvoice.text import text_to_sequence
from openvoice.mel_processing import spectrogram_torch
from openvoice.models import SynthesizerTrn


class OpenVoiceBaseClass(object):
    def __init__(self,
                 config_path,
                 device='cuda:0'):
        if 'cuda' in device:
            assert torch.cuda.is_available()

        hps = utils.get_hparams_from_file(config_path)

        model = SynthesizerTrn(
            len(getattr(hps, 'symbols', [])),
            hps.data.filter_length // 2 + 1,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        ).to(device)

        model.eval()
        self.model = model
        self.hps = hps
        self.device = device

    def load_ckpt(self, ckpt_path):
        checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
        a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
        print("Loaded checkpoint '{}'".format(ckpt_path))
        print('missing/unexpected keys:', a, b)


class BaseSpeakerTTS(OpenVoiceBaseClass):
    language_marks = {
        "english": "EN",
        "chinese": "ZH",
    }

    @staticmethod
    def get_text(text, hps, is_symbol):
        text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    @staticmethod
    def audio_numpy_concat(segment_data_list, sr, speed=1.):
        audio_segments = []
        for segment_data in segment_data_list:
            audio_segments += segment_data.reshape(-1).tolist()
            audio_segments += [0] * int((sr * 0.05)/speed)
        audio_segments = np.array(audio_segments).astype(np.float32)
        return audio_segments

    @staticmethod
    def split_sentences_into_pieces(text, language_str):
        texts = utils.split_sentence(text, language_str=language_str)
        print(" > Text splitted to sentences.")
        print('\n'.join(texts))
        print(" > ===========================")
        return texts

    def tts(self, text, output_path, speaker, language='English', speed=1.0):
        mark = self.language_marks.get(language.lower(), None)
        assert mark is not None, f"language {language} is not supported"

        texts = self.split_sentences_into_pieces(text, mark)

        audio_list = []
        for t in texts:
            t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
            t = f'[{mark}]{t}[{mark}]'
            stn_tst = self.get_text(t, self.hps, False)
            device = self.device
            speaker_id = self.hps.speakers[speaker]
            with torch.no_grad():
                x_tst = stn_tst.unsqueeze(0).to(device)
                x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
                sid = torch.LongTensor([speaker_id]).to(device)
                audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
                                         length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
            audio_list.append(audio)
        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)

        if output_path is None:
            return audio
        else:
            soundfile.write(output_path, audio, self.hps.data.sampling_rate)


class ToneColorConverter(OpenVoiceBaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if kwargs.get('enable_watermark', True):
            import wavmark
            self.watermark_model = wavmark.load_model().to(self.device)
        else:
            self.watermark_model = None
        self.version = getattr(self.hps, '_version_', "v1")


    def extract_se(self, ref_wav_list, se_save_path=None):
        if isinstance(ref_wav_list, str):
            ref_wav_list = [ref_wav_list]

        device = self.device
        hps = self.hps
        gs = []

        for fname in ref_wav_list:
            audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
            y = torch.FloatTensor(audio_ref)
            y = y.to(device)
            y = y.unsqueeze(0)
            y = spectrogram_torch(y, hps.data.filter_length,
                                  hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
                                  center=False).to(device)
            with torch.no_grad():
                g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
                gs.append(g.detach())
        gs = torch.stack(gs).mean(0)

        if se_save_path is not None:
            os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
            torch.save(gs.cpu(), se_save_path)

        return gs

    def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
        hps = self.hps
        # load audio
        audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
        audio = torch.tensor(audio).float()

        with torch.no_grad():
            y = torch.FloatTensor(audio).to(self.device)
            y = y.unsqueeze(0)
            spec = spectrogram_torch(y, hps.data.filter_length,
                                     hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
                                     center=False).to(self.device)
            spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
            audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
                0, 0].data.cpu().float().numpy()
            audio = self.add_watermark(audio, message)
            if output_path is None:
                return audio
            else:
                soundfile.write(output_path, audio, hps.data.sampling_rate)

    def add_watermark(self, audio, message):
        if self.watermark_model is None:
            return audio
        device = self.device
        bits = utils.string_to_bits(message).reshape(-1)
        n_repeat = len(bits) // 32

        K = 16000
        coeff = 2
        for n in range(n_repeat):
            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
            if len(trunck) != K:
                print('Audio too short, fail to add watermark')
                break
            message_npy = bits[n * 32: (n + 1) * 32]

            with torch.no_grad():
                signal = torch.FloatTensor(trunck).to(device)[None]
                message_tensor = torch.FloatTensor(message_npy).to(device)[None]
                signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
                signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
            audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
        return audio

    def detect_watermark(self, audio, n_repeat):
        bits = []
        K = 16000
        coeff = 2
        for n in range(n_repeat):
            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
            if len(trunck) != K:
                print('Audio too short, fail to detect watermark')
                return 'Fail'
            with torch.no_grad():
                signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
                message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
            bits.append(message_decoded_npy)
        bits = np.stack(bits).reshape(-1, 8)
        message = utils.bits_to_string(bits)
        return message
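
For orientation only (not part of the commit): one way the classes above can be driven, based solely on the signatures shown in this file; the checkpoint and audio paths are placeholders.

import torch
from openvoice.api import ToneColorConverter

device = "cuda:0" if torch.cuda.is_available() else "cpu"
converter = ToneColorConverter("checkpoints/converter/config.json", device=device)
converter.load_ckpt("checkpoints/converter/checkpoint.pth")

# Tone-colour embeddings for the source and target voices (paths are placeholders).
src_se = converter.extract_se("source_speaker.wav")
tgt_se = converter.extract_se("reference_speaker.wav")

# Convert the source recording towards the target timbre; tau is the sampling temperature.
converter.convert("source_speaker.wav", src_se, tgt_se, output_path="converted.wav", tau=0.3)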
OpenVoice/openvoice/attentions.py
ADDED
@@ -0,0 +1,465 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from openvoice import commons
import logging

logger = logging.getLogger(__name__)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        # if isflow:
        #     cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
        #     self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
        #     self.cond_layer = weight_norm(cond_layer, name='weight')
        #     self.gin_channels = 256
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = (
                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
                )
                # logging.debug(self.gin_channels, self.cond_layer_idx)
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()

        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            if i == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
OpenVoice/openvoice/commons.py
ADDED
@@ -0,0 +1,160 @@
import math
import torch
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
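
A tiny illustration (not part of the commit) of two helpers defined above, with made-up values:

import torch
from openvoice.commons import sequence_mask, intersperse

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths, max_length=5)  # [2, 5] bool mask; row 0 is [True, True, True, False, False]
tokens = intersperse([7, 8, 9], 0)           # [0, 7, 0, 8, 0, 9, 0]: a blank token between every symbol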
OpenVoice/openvoice/mel_processing.py
ADDED
@@ -0,0 +1,183 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.1:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.1:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    # if torch.min(y) < -1.:
    #     print('min value is ', torch.min(y))
    # if torch.max(y) > 1.:
    #     print('max value is ', torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')

    # ******************** original ************************#
    # y = y.squeeze(1)
    # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
    #                    center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)

    # ******************** ConvSTFT ************************#
    freq_cutoff = n_fft // 2 + 1
    fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
    forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
    forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()

    import torch.nn.functional as F

    # if center:
    #     signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
    assert center is False

    forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
    spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)


    # ******************** Verification ************************#
    spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
                       center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
    assert torch.allclose(spec1, spec2, atol=1e-4)

    spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=y.dtype, device=y.device
        )
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
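
As a quick sanity check of spectrogram_torch above (not part of the commit), with illustrative VITS-style STFT settings that are not taken from this repository's configs:

import torch
from openvoice.mel_processing import spectrogram_torch

wav = torch.rand(1, 22050) * 2 - 1  # one second of fake audio in [-1, 1], batch size 1
spec = spectrogram_torch(wav, n_fft=1024, sampling_rate=22050, hop_size=256, win_size=1024, center=False)
print(spec.shape)  # (1, n_fft // 2 + 1, n_frames)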
OpenVoice/openvoice/models.py
ADDED
@@ -0,0 +1,499 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from openvoice import commons
from openvoice import modules
from openvoice import attentions

from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from openvoice.commons import init_weights, get_padding


class TextEncoder(nn.Module):
    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask


class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class StochasticDurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
            logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None, tau=1.0):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()


class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r]  mels
outputs --- [N, ref_enc_gru_size]
|
305 |
+
"""
|
306 |
+
|
307 |
+
def __init__(self, spec_channels, gin_channels=0, layernorm=True):
|
308 |
+
super().__init__()
|
309 |
+
self.spec_channels = spec_channels
|
310 |
+
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
311 |
+
K = len(ref_enc_filters)
|
312 |
+
filters = [1] + ref_enc_filters
|
313 |
+
convs = [
|
314 |
+
weight_norm(
|
315 |
+
nn.Conv2d(
|
316 |
+
in_channels=filters[i],
|
317 |
+
out_channels=filters[i + 1],
|
318 |
+
kernel_size=(3, 3),
|
319 |
+
stride=(2, 2),
|
320 |
+
padding=(1, 1),
|
321 |
+
)
|
322 |
+
)
|
323 |
+
for i in range(K)
|
324 |
+
]
|
325 |
+
self.convs = nn.ModuleList(convs)
|
326 |
+
|
327 |
+
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
328 |
+
self.gru = nn.GRU(
|
329 |
+
input_size=ref_enc_filters[-1] * out_channels,
|
330 |
+
hidden_size=256 // 2,
|
331 |
+
batch_first=True,
|
332 |
+
)
|
333 |
+
self.proj = nn.Linear(128, gin_channels)
|
334 |
+
if layernorm:
|
335 |
+
self.layernorm = nn.LayerNorm(self.spec_channels)
|
336 |
+
else:
|
337 |
+
self.layernorm = None
|
338 |
+
|
339 |
+
def forward(self, inputs, mask=None):
|
340 |
+
N = inputs.size(0)
|
341 |
+
|
342 |
+
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
343 |
+
if self.layernorm is not None:
|
344 |
+
out = self.layernorm(out)
|
345 |
+
|
346 |
+
for conv in self.convs:
|
347 |
+
out = conv(out)
|
348 |
+
# out = wn(out)
|
349 |
+
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
350 |
+
|
351 |
+
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
352 |
+
T = out.size(1)
|
353 |
+
N = out.size(0)
|
354 |
+
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
355 |
+
|
356 |
+
self.gru.flatten_parameters()
|
357 |
+
memory, out = self.gru(out) # out --- [1, N, 128]
|
358 |
+
|
359 |
+
return self.proj(out.squeeze(0))
|
360 |
+
|
361 |
+
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
362 |
+
for i in range(n_convs):
|
363 |
+
L = (L - kernel_size + 2 * pad) // stride + 1
|
364 |
+
return L
|
365 |
+
|
366 |
+
|
367 |
+
class ResidualCouplingBlock(nn.Module):
|
368 |
+
def __init__(self,
|
369 |
+
channels,
|
370 |
+
hidden_channels,
|
371 |
+
kernel_size,
|
372 |
+
dilation_rate,
|
373 |
+
n_layers,
|
374 |
+
n_flows=4,
|
375 |
+
gin_channels=0):
|
376 |
+
super().__init__()
|
377 |
+
self.channels = channels
|
378 |
+
self.hidden_channels = hidden_channels
|
379 |
+
self.kernel_size = kernel_size
|
380 |
+
self.dilation_rate = dilation_rate
|
381 |
+
self.n_layers = n_layers
|
382 |
+
self.n_flows = n_flows
|
383 |
+
self.gin_channels = gin_channels
|
384 |
+
|
385 |
+
self.flows = nn.ModuleList()
|
386 |
+
for i in range(n_flows):
|
387 |
+
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
388 |
+
self.flows.append(modules.Flip())
|
389 |
+
|
390 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
391 |
+
if not reverse:
|
392 |
+
for flow in self.flows:
|
393 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
394 |
+
else:
|
395 |
+
for flow in reversed(self.flows):
|
396 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
397 |
+
return x
|
398 |
+
|
399 |
+
class SynthesizerTrn(nn.Module):
|
400 |
+
"""
|
401 |
+
Synthesizer for Training
|
402 |
+
"""
|
403 |
+
|
404 |
+
def __init__(
|
405 |
+
self,
|
406 |
+
n_vocab,
|
407 |
+
spec_channels,
|
408 |
+
inter_channels,
|
409 |
+
hidden_channels,
|
410 |
+
filter_channels,
|
411 |
+
n_heads,
|
412 |
+
n_layers,
|
413 |
+
kernel_size,
|
414 |
+
p_dropout,
|
415 |
+
resblock,
|
416 |
+
resblock_kernel_sizes,
|
417 |
+
resblock_dilation_sizes,
|
418 |
+
upsample_rates,
|
419 |
+
upsample_initial_channel,
|
420 |
+
upsample_kernel_sizes,
|
421 |
+
n_speakers=256,
|
422 |
+
gin_channels=256,
|
423 |
+
zero_g=False,
|
424 |
+
**kwargs
|
425 |
+
):
|
426 |
+
super().__init__()
|
427 |
+
|
428 |
+
self.dec = Generator(
|
429 |
+
inter_channels,
|
430 |
+
resblock,
|
431 |
+
resblock_kernel_sizes,
|
432 |
+
resblock_dilation_sizes,
|
433 |
+
upsample_rates,
|
434 |
+
upsample_initial_channel,
|
435 |
+
upsample_kernel_sizes,
|
436 |
+
gin_channels=gin_channels,
|
437 |
+
)
|
438 |
+
self.enc_q = PosteriorEncoder(
|
439 |
+
spec_channels,
|
440 |
+
inter_channels,
|
441 |
+
hidden_channels,
|
442 |
+
5,
|
443 |
+
1,
|
444 |
+
16,
|
445 |
+
gin_channels=gin_channels,
|
446 |
+
)
|
447 |
+
|
448 |
+
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
449 |
+
|
450 |
+
self.n_speakers = n_speakers
|
451 |
+
if n_speakers == 0:
|
452 |
+
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
|
453 |
+
else:
|
454 |
+
self.enc_p = TextEncoder(n_vocab,
|
455 |
+
inter_channels,
|
456 |
+
hidden_channels,
|
457 |
+
filter_channels,
|
458 |
+
n_heads,
|
459 |
+
n_layers,
|
460 |
+
kernel_size,
|
461 |
+
p_dropout)
|
462 |
+
self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
|
463 |
+
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
|
464 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
465 |
+
self.zero_g = zero_g
|
466 |
+
|
467 |
+
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
|
468 |
+
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
|
469 |
+
if self.n_speakers > 0:
|
470 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
471 |
+
else:
|
472 |
+
g = None
|
473 |
+
|
474 |
+
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \
|
475 |
+
+ self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
|
476 |
+
|
477 |
+
w = torch.exp(logw) * x_mask * length_scale
|
478 |
+
w_ceil = torch.ceil(w)
|
479 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
480 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
|
481 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
482 |
+
attn = commons.generate_path(w_ceil, attn_mask)
|
483 |
+
|
484 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
485 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
486 |
+
|
487 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
488 |
+
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
489 |
+
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
|
490 |
+
return o, attn, y_mask, (z, z_p, m_p, logs_p)
|
491 |
+
|
492 |
+
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
|
493 |
+
g_src = sid_src
|
494 |
+
g_tgt = sid_tgt
|
495 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau)
|
496 |
+
z_p = self.flow(z, y_mask, g=g_src)
|
497 |
+
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
498 |
+
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
|
499 |
+
return o_hat, y_mask, (z, z_p, z_hat)
|
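Note: the SynthesizerTrn above exposes the core voice-conversion path used by the converter — the PosteriorEncoder maps a source spectrogram to latents, the ResidualCouplingBlock flow transports them between speaker embeddings, and the Generator vocodes the result. The snippet below is a minimal sketch of that path only; the hyperparameter values and the openvoice.models import path are illustrative assumptions (real values come from the converter's config.json), and the random tensors stand in for a real spectrogram and extracted speaker embeddings.

import torch
from openvoice.models import SynthesizerTrn  # assumed import path for the file above

# Hypothetical hyperparameters for illustration; real values come from checkpoints/converter/config.json.
net = SynthesizerTrn(
    n_vocab=0,
    spec_channels=513,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    n_speakers=0,        # 0 selects the ReferenceEncoder branch, so no text modules are built
    gin_channels=256,
).eval()

spec = torch.randn(1, 513, 120)        # linear spectrogram of the source utterance
spec_lengths = torch.LongTensor([120])
g_src = torch.randn(1, 256, 1)         # source speaker embedding (normally from extract_se)
g_tgt = torch.randn(1, 256, 1)         # target speaker embedding

with torch.no_grad():
    audio, y_mask, _ = net.voice_conversion(spec, spec_lengths, sid_src=g_src, sid_tgt=g_tgt, tau=0.3)
print(audio.shape)                     # [1, 1, 120 * prod(upsample_rates)]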
OpenVoice/openvoice/modules.py
ADDED
@@ -0,0 +1,598 @@

1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from torch.nn import Conv1d
|
7 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
8 |
+
|
9 |
+
from openvoice import commons
|
10 |
+
from openvoice.commons import init_weights, get_padding
|
11 |
+
from openvoice.transforms import piecewise_rational_quadratic_transform
|
12 |
+
from openvoice.attentions import Encoder
|
13 |
+
|
14 |
+
LRELU_SLOPE = 0.1
|
15 |
+
|
16 |
+
|
17 |
+
class LayerNorm(nn.Module):
|
18 |
+
def __init__(self, channels, eps=1e-5):
|
19 |
+
super().__init__()
|
20 |
+
self.channels = channels
|
21 |
+
self.eps = eps
|
22 |
+
|
23 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
24 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = x.transpose(1, -1)
|
28 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
29 |
+
return x.transpose(1, -1)
|
30 |
+
|
31 |
+
|
32 |
+
class ConvReluNorm(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
in_channels,
|
36 |
+
hidden_channels,
|
37 |
+
out_channels,
|
38 |
+
kernel_size,
|
39 |
+
n_layers,
|
40 |
+
p_dropout,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
self.in_channels = in_channels
|
44 |
+
self.hidden_channels = hidden_channels
|
45 |
+
self.out_channels = out_channels
|
46 |
+
self.kernel_size = kernel_size
|
47 |
+
self.n_layers = n_layers
|
48 |
+
self.p_dropout = p_dropout
|
49 |
+
assert n_layers > 1, "Number of layers should be larger than 1."
|
50 |
+
|
51 |
+
self.conv_layers = nn.ModuleList()
|
52 |
+
self.norm_layers = nn.ModuleList()
|
53 |
+
self.conv_layers.append(
|
54 |
+
nn.Conv1d(
|
55 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
56 |
+
)
|
57 |
+
)
|
58 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
59 |
+
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
60 |
+
for _ in range(n_layers - 1):
|
61 |
+
self.conv_layers.append(
|
62 |
+
nn.Conv1d(
|
63 |
+
hidden_channels,
|
64 |
+
hidden_channels,
|
65 |
+
kernel_size,
|
66 |
+
padding=kernel_size // 2,
|
67 |
+
)
|
68 |
+
)
|
69 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
70 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
71 |
+
self.proj.weight.data.zero_()
|
72 |
+
self.proj.bias.data.zero_()
|
73 |
+
|
74 |
+
def forward(self, x, x_mask):
|
75 |
+
x_org = x
|
76 |
+
for i in range(self.n_layers):
|
77 |
+
x = self.conv_layers[i](x * x_mask)
|
78 |
+
x = self.norm_layers[i](x)
|
79 |
+
x = self.relu_drop(x)
|
80 |
+
x = x_org + self.proj(x)
|
81 |
+
return x * x_mask
|
82 |
+
|
83 |
+
|
84 |
+
class DDSConv(nn.Module):
|
85 |
+
"""
|
86 |
+
Dilated and Depth-Separable Convolution
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
90 |
+
super().__init__()
|
91 |
+
self.channels = channels
|
92 |
+
self.kernel_size = kernel_size
|
93 |
+
self.n_layers = n_layers
|
94 |
+
self.p_dropout = p_dropout
|
95 |
+
|
96 |
+
self.drop = nn.Dropout(p_dropout)
|
97 |
+
self.convs_sep = nn.ModuleList()
|
98 |
+
self.convs_1x1 = nn.ModuleList()
|
99 |
+
self.norms_1 = nn.ModuleList()
|
100 |
+
self.norms_2 = nn.ModuleList()
|
101 |
+
for i in range(n_layers):
|
102 |
+
dilation = kernel_size**i
|
103 |
+
padding = (kernel_size * dilation - dilation) // 2
|
104 |
+
self.convs_sep.append(
|
105 |
+
nn.Conv1d(
|
106 |
+
channels,
|
107 |
+
channels,
|
108 |
+
kernel_size,
|
109 |
+
groups=channels,
|
110 |
+
dilation=dilation,
|
111 |
+
padding=padding,
|
112 |
+
)
|
113 |
+
)
|
114 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
115 |
+
self.norms_1.append(LayerNorm(channels))
|
116 |
+
self.norms_2.append(LayerNorm(channels))
|
117 |
+
|
118 |
+
def forward(self, x, x_mask, g=None):
|
119 |
+
if g is not None:
|
120 |
+
x = x + g
|
121 |
+
for i in range(self.n_layers):
|
122 |
+
y = self.convs_sep[i](x * x_mask)
|
123 |
+
y = self.norms_1[i](y)
|
124 |
+
y = F.gelu(y)
|
125 |
+
y = self.convs_1x1[i](y)
|
126 |
+
y = self.norms_2[i](y)
|
127 |
+
y = F.gelu(y)
|
128 |
+
y = self.drop(y)
|
129 |
+
x = x + y
|
130 |
+
return x * x_mask
|
131 |
+
|
132 |
+
|
133 |
+
class WN(torch.nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
hidden_channels,
|
137 |
+
kernel_size,
|
138 |
+
dilation_rate,
|
139 |
+
n_layers,
|
140 |
+
gin_channels=0,
|
141 |
+
p_dropout=0,
|
142 |
+
):
|
143 |
+
super(WN, self).__init__()
|
144 |
+
assert kernel_size % 2 == 1
|
145 |
+
self.hidden_channels = hidden_channels
|
146 |
+
self.kernel_size = (kernel_size,)
|
147 |
+
self.dilation_rate = dilation_rate
|
148 |
+
self.n_layers = n_layers
|
149 |
+
self.gin_channels = gin_channels
|
150 |
+
self.p_dropout = p_dropout
|
151 |
+
|
152 |
+
self.in_layers = torch.nn.ModuleList()
|
153 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
154 |
+
self.drop = nn.Dropout(p_dropout)
|
155 |
+
|
156 |
+
if gin_channels != 0:
|
157 |
+
cond_layer = torch.nn.Conv1d(
|
158 |
+
gin_channels, 2 * hidden_channels * n_layers, 1
|
159 |
+
)
|
160 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
161 |
+
|
162 |
+
for i in range(n_layers):
|
163 |
+
dilation = dilation_rate**i
|
164 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
165 |
+
in_layer = torch.nn.Conv1d(
|
166 |
+
hidden_channels,
|
167 |
+
2 * hidden_channels,
|
168 |
+
kernel_size,
|
169 |
+
dilation=dilation,
|
170 |
+
padding=padding,
|
171 |
+
)
|
172 |
+
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
173 |
+
self.in_layers.append(in_layer)
|
174 |
+
|
175 |
+
# last one is not necessary
|
176 |
+
if i < n_layers - 1:
|
177 |
+
res_skip_channels = 2 * hidden_channels
|
178 |
+
else:
|
179 |
+
res_skip_channels = hidden_channels
|
180 |
+
|
181 |
+
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
182 |
+
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
183 |
+
self.res_skip_layers.append(res_skip_layer)
|
184 |
+
|
185 |
+
def forward(self, x, x_mask, g=None, **kwargs):
|
186 |
+
output = torch.zeros_like(x)
|
187 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
188 |
+
|
189 |
+
if g is not None:
|
190 |
+
g = self.cond_layer(g)
|
191 |
+
|
192 |
+
for i in range(self.n_layers):
|
193 |
+
x_in = self.in_layers[i](x)
|
194 |
+
if g is not None:
|
195 |
+
cond_offset = i * 2 * self.hidden_channels
|
196 |
+
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
197 |
+
else:
|
198 |
+
g_l = torch.zeros_like(x_in)
|
199 |
+
|
200 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
201 |
+
acts = self.drop(acts)
|
202 |
+
|
203 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
204 |
+
if i < self.n_layers - 1:
|
205 |
+
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
206 |
+
x = (x + res_acts) * x_mask
|
207 |
+
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
208 |
+
else:
|
209 |
+
output = output + res_skip_acts
|
210 |
+
return output * x_mask
|
211 |
+
|
212 |
+
def remove_weight_norm(self):
|
213 |
+
if self.gin_channels != 0:
|
214 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
215 |
+
for l in self.in_layers:
|
216 |
+
torch.nn.utils.remove_weight_norm(l)
|
217 |
+
for l in self.res_skip_layers:
|
218 |
+
torch.nn.utils.remove_weight_norm(l)
|
219 |
+
|
220 |
+
|
221 |
+
class ResBlock1(torch.nn.Module):
|
222 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
223 |
+
super(ResBlock1, self).__init__()
|
224 |
+
self.convs1 = nn.ModuleList(
|
225 |
+
[
|
226 |
+
weight_norm(
|
227 |
+
Conv1d(
|
228 |
+
channels,
|
229 |
+
channels,
|
230 |
+
kernel_size,
|
231 |
+
1,
|
232 |
+
dilation=dilation[0],
|
233 |
+
padding=get_padding(kernel_size, dilation[0]),
|
234 |
+
)
|
235 |
+
),
|
236 |
+
weight_norm(
|
237 |
+
Conv1d(
|
238 |
+
channels,
|
239 |
+
channels,
|
240 |
+
kernel_size,
|
241 |
+
1,
|
242 |
+
dilation=dilation[1],
|
243 |
+
padding=get_padding(kernel_size, dilation[1]),
|
244 |
+
)
|
245 |
+
),
|
246 |
+
weight_norm(
|
247 |
+
Conv1d(
|
248 |
+
channels,
|
249 |
+
channels,
|
250 |
+
kernel_size,
|
251 |
+
1,
|
252 |
+
dilation=dilation[2],
|
253 |
+
padding=get_padding(kernel_size, dilation[2]),
|
254 |
+
)
|
255 |
+
),
|
256 |
+
]
|
257 |
+
)
|
258 |
+
self.convs1.apply(init_weights)
|
259 |
+
|
260 |
+
self.convs2 = nn.ModuleList(
|
261 |
+
[
|
262 |
+
weight_norm(
|
263 |
+
Conv1d(
|
264 |
+
channels,
|
265 |
+
channels,
|
266 |
+
kernel_size,
|
267 |
+
1,
|
268 |
+
dilation=1,
|
269 |
+
padding=get_padding(kernel_size, 1),
|
270 |
+
)
|
271 |
+
),
|
272 |
+
weight_norm(
|
273 |
+
Conv1d(
|
274 |
+
channels,
|
275 |
+
channels,
|
276 |
+
kernel_size,
|
277 |
+
1,
|
278 |
+
dilation=1,
|
279 |
+
padding=get_padding(kernel_size, 1),
|
280 |
+
)
|
281 |
+
),
|
282 |
+
weight_norm(
|
283 |
+
Conv1d(
|
284 |
+
channels,
|
285 |
+
channels,
|
286 |
+
kernel_size,
|
287 |
+
1,
|
288 |
+
dilation=1,
|
289 |
+
padding=get_padding(kernel_size, 1),
|
290 |
+
)
|
291 |
+
),
|
292 |
+
]
|
293 |
+
)
|
294 |
+
self.convs2.apply(init_weights)
|
295 |
+
|
296 |
+
def forward(self, x, x_mask=None):
|
297 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
298 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
299 |
+
if x_mask is not None:
|
300 |
+
xt = xt * x_mask
|
301 |
+
xt = c1(xt)
|
302 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
303 |
+
if x_mask is not None:
|
304 |
+
xt = xt * x_mask
|
305 |
+
xt = c2(xt)
|
306 |
+
x = xt + x
|
307 |
+
if x_mask is not None:
|
308 |
+
x = x * x_mask
|
309 |
+
return x
|
310 |
+
|
311 |
+
def remove_weight_norm(self):
|
312 |
+
for l in self.convs1:
|
313 |
+
remove_weight_norm(l)
|
314 |
+
for l in self.convs2:
|
315 |
+
remove_weight_norm(l)
|
316 |
+
|
317 |
+
|
318 |
+
class ResBlock2(torch.nn.Module):
|
319 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
320 |
+
super(ResBlock2, self).__init__()
|
321 |
+
self.convs = nn.ModuleList(
|
322 |
+
[
|
323 |
+
weight_norm(
|
324 |
+
Conv1d(
|
325 |
+
channels,
|
326 |
+
channels,
|
327 |
+
kernel_size,
|
328 |
+
1,
|
329 |
+
dilation=dilation[0],
|
330 |
+
padding=get_padding(kernel_size, dilation[0]),
|
331 |
+
)
|
332 |
+
),
|
333 |
+
weight_norm(
|
334 |
+
Conv1d(
|
335 |
+
channels,
|
336 |
+
channels,
|
337 |
+
kernel_size,
|
338 |
+
1,
|
339 |
+
dilation=dilation[1],
|
340 |
+
padding=get_padding(kernel_size, dilation[1]),
|
341 |
+
)
|
342 |
+
),
|
343 |
+
]
|
344 |
+
)
|
345 |
+
self.convs.apply(init_weights)
|
346 |
+
|
347 |
+
def forward(self, x, x_mask=None):
|
348 |
+
for c in self.convs:
|
349 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
350 |
+
if x_mask is not None:
|
351 |
+
xt = xt * x_mask
|
352 |
+
xt = c(xt)
|
353 |
+
x = xt + x
|
354 |
+
if x_mask is not None:
|
355 |
+
x = x * x_mask
|
356 |
+
return x
|
357 |
+
|
358 |
+
def remove_weight_norm(self):
|
359 |
+
for l in self.convs:
|
360 |
+
remove_weight_norm(l)
|
361 |
+
|
362 |
+
|
363 |
+
class Log(nn.Module):
|
364 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
365 |
+
if not reverse:
|
366 |
+
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
367 |
+
logdet = torch.sum(-y, [1, 2])
|
368 |
+
return y, logdet
|
369 |
+
else:
|
370 |
+
x = torch.exp(x) * x_mask
|
371 |
+
return x
|
372 |
+
|
373 |
+
|
374 |
+
class Flip(nn.Module):
|
375 |
+
def forward(self, x, *args, reverse=False, **kwargs):
|
376 |
+
x = torch.flip(x, [1])
|
377 |
+
if not reverse:
|
378 |
+
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
379 |
+
return x, logdet
|
380 |
+
else:
|
381 |
+
return x
|
382 |
+
|
383 |
+
|
384 |
+
class ElementwiseAffine(nn.Module):
|
385 |
+
def __init__(self, channels):
|
386 |
+
super().__init__()
|
387 |
+
self.channels = channels
|
388 |
+
self.m = nn.Parameter(torch.zeros(channels, 1))
|
389 |
+
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
390 |
+
|
391 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
392 |
+
if not reverse:
|
393 |
+
y = self.m + torch.exp(self.logs) * x
|
394 |
+
y = y * x_mask
|
395 |
+
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
396 |
+
return y, logdet
|
397 |
+
else:
|
398 |
+
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
399 |
+
return x
|
400 |
+
|
401 |
+
|
402 |
+
class ResidualCouplingLayer(nn.Module):
|
403 |
+
def __init__(
|
404 |
+
self,
|
405 |
+
channels,
|
406 |
+
hidden_channels,
|
407 |
+
kernel_size,
|
408 |
+
dilation_rate,
|
409 |
+
n_layers,
|
410 |
+
p_dropout=0,
|
411 |
+
gin_channels=0,
|
412 |
+
mean_only=False,
|
413 |
+
):
|
414 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
415 |
+
super().__init__()
|
416 |
+
self.channels = channels
|
417 |
+
self.hidden_channels = hidden_channels
|
418 |
+
self.kernel_size = kernel_size
|
419 |
+
self.dilation_rate = dilation_rate
|
420 |
+
self.n_layers = n_layers
|
421 |
+
self.half_channels = channels // 2
|
422 |
+
self.mean_only = mean_only
|
423 |
+
|
424 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
425 |
+
self.enc = WN(
|
426 |
+
hidden_channels,
|
427 |
+
kernel_size,
|
428 |
+
dilation_rate,
|
429 |
+
n_layers,
|
430 |
+
p_dropout=p_dropout,
|
431 |
+
gin_channels=gin_channels,
|
432 |
+
)
|
433 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
434 |
+
self.post.weight.data.zero_()
|
435 |
+
self.post.bias.data.zero_()
|
436 |
+
|
437 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
438 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
439 |
+
h = self.pre(x0) * x_mask
|
440 |
+
h = self.enc(h, x_mask, g=g)
|
441 |
+
stats = self.post(h) * x_mask
|
442 |
+
if not self.mean_only:
|
443 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
444 |
+
else:
|
445 |
+
m = stats
|
446 |
+
logs = torch.zeros_like(m)
|
447 |
+
|
448 |
+
if not reverse:
|
449 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
450 |
+
x = torch.cat([x0, x1], 1)
|
451 |
+
logdet = torch.sum(logs, [1, 2])
|
452 |
+
return x, logdet
|
453 |
+
else:
|
454 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
455 |
+
x = torch.cat([x0, x1], 1)
|
456 |
+
return x
|
457 |
+
|
458 |
+
|
459 |
+
class ConvFlow(nn.Module):
|
460 |
+
def __init__(
|
461 |
+
self,
|
462 |
+
in_channels,
|
463 |
+
filter_channels,
|
464 |
+
kernel_size,
|
465 |
+
n_layers,
|
466 |
+
num_bins=10,
|
467 |
+
tail_bound=5.0,
|
468 |
+
):
|
469 |
+
super().__init__()
|
470 |
+
self.in_channels = in_channels
|
471 |
+
self.filter_channels = filter_channels
|
472 |
+
self.kernel_size = kernel_size
|
473 |
+
self.n_layers = n_layers
|
474 |
+
self.num_bins = num_bins
|
475 |
+
self.tail_bound = tail_bound
|
476 |
+
self.half_channels = in_channels // 2
|
477 |
+
|
478 |
+
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
479 |
+
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
|
480 |
+
self.proj = nn.Conv1d(
|
481 |
+
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
|
482 |
+
)
|
483 |
+
self.proj.weight.data.zero_()
|
484 |
+
self.proj.bias.data.zero_()
|
485 |
+
|
486 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
487 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
488 |
+
h = self.pre(x0)
|
489 |
+
h = self.convs(h, x_mask, g=g)
|
490 |
+
h = self.proj(h) * x_mask
|
491 |
+
|
492 |
+
b, c, t = x0.shape
|
493 |
+
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
494 |
+
|
495 |
+
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
|
496 |
+
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
|
497 |
+
self.filter_channels
|
498 |
+
)
|
499 |
+
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
500 |
+
|
501 |
+
x1, logabsdet = piecewise_rational_quadratic_transform(
|
502 |
+
x1,
|
503 |
+
unnormalized_widths,
|
504 |
+
unnormalized_heights,
|
505 |
+
unnormalized_derivatives,
|
506 |
+
inverse=reverse,
|
507 |
+
tails="linear",
|
508 |
+
tail_bound=self.tail_bound,
|
509 |
+
)
|
510 |
+
|
511 |
+
x = torch.cat([x0, x1], 1) * x_mask
|
512 |
+
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
513 |
+
if not reverse:
|
514 |
+
return x, logdet
|
515 |
+
else:
|
516 |
+
return x
|
517 |
+
|
518 |
+
|
519 |
+
class TransformerCouplingLayer(nn.Module):
|
520 |
+
def __init__(
|
521 |
+
self,
|
522 |
+
channels,
|
523 |
+
hidden_channels,
|
524 |
+
kernel_size,
|
525 |
+
n_layers,
|
526 |
+
n_heads,
|
527 |
+
p_dropout=0,
|
528 |
+
filter_channels=0,
|
529 |
+
mean_only=False,
|
530 |
+
wn_sharing_parameter=None,
|
531 |
+
gin_channels=0,
|
532 |
+
):
|
533 |
+
assert n_layers == 3, n_layers
|
534 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
535 |
+
super().__init__()
|
536 |
+
self.channels = channels
|
537 |
+
self.hidden_channels = hidden_channels
|
538 |
+
self.kernel_size = kernel_size
|
539 |
+
self.n_layers = n_layers
|
540 |
+
self.half_channels = channels // 2
|
541 |
+
self.mean_only = mean_only
|
542 |
+
|
543 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
544 |
+
self.enc = (
|
545 |
+
Encoder(
|
546 |
+
hidden_channels,
|
547 |
+
filter_channels,
|
548 |
+
n_heads,
|
549 |
+
n_layers,
|
550 |
+
kernel_size,
|
551 |
+
p_dropout,
|
552 |
+
isflow=True,
|
553 |
+
gin_channels=gin_channels,
|
554 |
+
)
|
555 |
+
if wn_sharing_parameter is None
|
556 |
+
else wn_sharing_parameter
|
557 |
+
)
|
558 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
559 |
+
self.post.weight.data.zero_()
|
560 |
+
self.post.bias.data.zero_()
|
561 |
+
|
562 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
563 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
564 |
+
h = self.pre(x0) * x_mask
|
565 |
+
h = self.enc(h, x_mask, g=g)
|
566 |
+
stats = self.post(h) * x_mask
|
567 |
+
if not self.mean_only:
|
568 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
569 |
+
else:
|
570 |
+
m = stats
|
571 |
+
logs = torch.zeros_like(m)
|
572 |
+
|
573 |
+
if not reverse:
|
574 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
575 |
+
x = torch.cat([x0, x1], 1)
|
576 |
+
logdet = torch.sum(logs, [1, 2])
|
577 |
+
return x, logdet
|
578 |
+
else:
|
579 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
580 |
+
x = torch.cat([x0, x1], 1)
|
581 |
+
return x
|
582 |
+
|
583 |
+
x1, logabsdet = piecewise_rational_quadratic_transform(
|
584 |
+
x1,
|
585 |
+
unnormalized_widths,
|
586 |
+
unnormalized_heights,
|
587 |
+
unnormalized_derivatives,
|
588 |
+
inverse=reverse,
|
589 |
+
tails="linear",
|
590 |
+
tail_bound=self.tail_bound,
|
591 |
+
)
|
592 |
+
|
593 |
+
x = torch.cat([x0, x1], 1) * x_mask
|
594 |
+
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
595 |
+
if not reverse:
|
596 |
+
return x, logdet
|
597 |
+
else:
|
598 |
+
return x
|
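Note: each flow module above is designed to be exactly invertible — the forward call returns the transformed tensor plus a log-determinant, and the same call with reverse=True undoes it. The block below is a minimal round-trip sketch for ResidualCouplingLayer; the openvoice.modules import path is assumed and the shapes are arbitrary.

import torch
from openvoice.modules import ResidualCouplingLayer  # assumed import path for the file above

layer = ResidualCouplingLayer(
    channels=192, hidden_channels=192, kernel_size=5,
    dilation_rate=1, n_layers=4, mean_only=True,
).eval()

x = torch.randn(1, 192, 50)
x_mask = torch.ones(1, 1, 50)

with torch.no_grad():
    y, logdet = layer(x, x_mask)               # forward direction: returns output and log|det J|
    x_rec = layer(y, x_mask, reverse=True)     # reverse direction: inverts the coupling

print(torch.allclose(x, x_rec, atol=1e-5))     # True up to numerical error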
OpenVoice/openvoice/openvoice_app.py
ADDED
@@ -0,0 +1,275 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from zipfile import ZipFile
|
6 |
+
import langid
|
7 |
+
from openvoice import se_extractor
|
8 |
+
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
|
9 |
+
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument("--share", action='store_true', default=False, help="make link public")
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
en_ckpt_base = 'checkpoints/base_speakers/EN'
|
15 |
+
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
|
16 |
+
ckpt_converter = 'checkpoints/converter'
|
17 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
18 |
+
output_dir = 'outputs'
|
19 |
+
os.makedirs(output_dir, exist_ok=True)
|
20 |
+
|
21 |
+
# load models
|
22 |
+
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
|
23 |
+
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
|
24 |
+
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
|
25 |
+
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
|
26 |
+
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
|
27 |
+
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
|
28 |
+
|
29 |
+
# load speaker embeddings
|
30 |
+
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
|
31 |
+
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
|
32 |
+
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
|
33 |
+
|
34 |
+
# This online demo mainly supports English and Chinese
|
35 |
+
supported_languages = ['zh', 'en']
|
36 |
+
|
37 |
+
def predict(prompt, style, audio_file_pth, agree):
|
38 |
+
# initialize an empty info string
|
39 |
+
text_hint = ''
|
40 |
+
# agree with the terms
|
41 |
+
if agree == False:
|
42 |
+
text_hint += '[ERROR] Please accept the Terms & Condition!\n'
|
43 |
+
gr.Warning("Please accept the Terms & Condition!")
|
44 |
+
return (
|
45 |
+
text_hint,
|
46 |
+
None,
|
47 |
+
None,
|
48 |
+
)
|
49 |
+
|
50 |
+
# first detect the input language
|
51 |
+
language_predicted = langid.classify(prompt)[0].strip()
|
52 |
+
print(f"Detected language:{language_predicted}")
|
53 |
+
|
54 |
+
if language_predicted not in supported_languages:
|
55 |
+
text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n"
|
56 |
+
gr.Warning(
|
57 |
+
f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}"
|
58 |
+
)
|
59 |
+
|
60 |
+
return (
|
61 |
+
text_hint,
|
62 |
+
None,
|
63 |
+
None,
|
64 |
+
)
|
65 |
+
|
66 |
+
if language_predicted == "zh":
|
67 |
+
tts_model = zh_base_speaker_tts
|
68 |
+
source_se = zh_source_se
|
69 |
+
language = 'Chinese'
|
70 |
+
if style not in ['default']:
|
71 |
+
text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n"
|
72 |
+
gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']")
|
73 |
+
return (
|
74 |
+
text_hint,
|
75 |
+
None,
|
76 |
+
None,
|
77 |
+
)
|
78 |
+
|
79 |
+
else:
|
80 |
+
tts_model = en_base_speaker_tts
|
81 |
+
if style == 'default':
|
82 |
+
source_se = en_source_default_se
|
83 |
+
else:
|
84 |
+
source_se = en_source_style_se
|
85 |
+
language = 'English'
|
86 |
+
if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']:
|
87 |
+
text_hint += f"[ERROR] The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n"
|
88 |
+
gr.Warning(f"The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']")
|
89 |
+
return (
|
90 |
+
text_hint,
|
91 |
+
None,
|
92 |
+
None,
|
93 |
+
)
|
94 |
+
|
95 |
+
speaker_wav = audio_file_pth
|
96 |
+
|
97 |
+
if len(prompt) < 2:
|
98 |
+
text_hint += f"[ERROR] Please give a longer prompt text \n"
|
99 |
+
gr.Warning("Please give a longer prompt text")
|
100 |
+
return (
|
101 |
+
text_hint,
|
102 |
+
None,
|
103 |
+
None,
|
104 |
+
)
|
105 |
+
if len(prompt) > 200:
|
106 |
+
text_hint += f"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n"
|
107 |
+
gr.Warning(
|
108 |
+
"Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage"
|
109 |
+
)
|
110 |
+
return (
|
111 |
+
text_hint,
|
112 |
+
None,
|
113 |
+
None,
|
114 |
+
)
|
115 |
+
|
116 |
+
# note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference
|
117 |
+
try:
|
118 |
+
target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir='processed', vad=True)
|
119 |
+
except Exception as e:
|
120 |
+
text_hint += f"[ERROR] Get target tone color error {str(e)} \n"
|
121 |
+
gr.Warning(
|
122 |
+
"[ERROR] Get target tone color error {str(e)} \n"
|
123 |
+
)
|
124 |
+
return (
|
125 |
+
text_hint,
|
126 |
+
None,
|
127 |
+
None,
|
128 |
+
)
|
129 |
+
|
130 |
+
src_path = f'{output_dir}/tmp.wav'
|
131 |
+
tts_model.tts(prompt, src_path, speaker=style, language=language)
|
132 |
+
|
133 |
+
save_path = f'{output_dir}/output.wav'
|
134 |
+
# Run the tone color converter
|
135 |
+
encode_message = "@MyShell"
|
136 |
+
tone_color_converter.convert(
|
137 |
+
audio_src_path=src_path,
|
138 |
+
src_se=source_se,
|
139 |
+
tgt_se=target_se,
|
140 |
+
output_path=save_path,
|
141 |
+
message=encode_message)
|
142 |
+
|
143 |
+
text_hint += f'''Get response successfully \n'''
|
144 |
+
|
145 |
+
return (
|
146 |
+
text_hint,
|
147 |
+
save_path,
|
148 |
+
speaker_wav,
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
title = "MyShell OpenVoice"
|
154 |
+
|
155 |
+
description = """
|
156 |
+
We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set.
|
157 |
+
"""
|
158 |
+
|
159 |
+
markdown_table = """
|
160 |
+
<div align="center" style="margin-bottom: 10px;">
|
161 |
+
|
162 |
+
| | | |
|
163 |
+
| :-----------: | :-----------: | :-----------: |
|
164 |
+
| **OpenSource Repo** | **Project Page** | **Join the Community** |
|
165 |
+
| <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | [OpenVoice](https://research.myshell.ai/open-voice) | [](https://discord.gg/myshell) |
|
166 |
+
|
167 |
+
</div>
|
168 |
+
"""
|
169 |
+
|
170 |
+
markdown_table_v2 = """
|
171 |
+
<div align="center" style="margin-bottom: 2px;">
|
172 |
+
|
173 |
+
| | | | |
|
174 |
+
| :-----------: | :-----------: | :-----------: | :-----------: |
|
175 |
+
| **OpenSource Repo** | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) |
|
176 |
+
|
177 |
+
| | |
|
178 |
+
| :-----------: | :-----------: |
|
179 |
+
**Join the Community** | [](https://discord.gg/myshell) |
|
180 |
+
|
181 |
+
</div>
|
182 |
+
"""
|
183 |
+
content = """
|
184 |
+
<div>
|
185 |
+
<strong>If the generated voice does not sound like the reference voice, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/docs/QA.md'>this QnA</a>.</strong> <strong>For multi-lingual & cross-lingual examples, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/demo_part2.ipynb'>this jupyter notebook</a>.</strong>
|
186 |
+
This online demo mainly supports <strong>English</strong>. The <em>default</em> style also supports <strong>Chinese</strong>. But OpenVoice can adapt to any other language as long as a base speaker is provided.
|
187 |
+
</div>
|
188 |
+
"""
|
189 |
+
wrapped_markdown_content = f"<div style='border: 1px solid #000; padding: 10px;'>{content}</div>"
|
190 |
+
|
191 |
+
|
192 |
+
examples = [
|
193 |
+
[
|
194 |
+
"今天天气真好,我们一起出去吃饭吧。",
|
195 |
+
'default',
|
196 |
+
"resources/demo_speaker1.mp3",
|
197 |
+
True,
|
198 |
+
],[
|
199 |
+
"This audio is generated by open voice with a half-performance model.",
|
200 |
+
'whispering',
|
201 |
+
"resources/demo_speaker2.mp3",
|
202 |
+
True,
|
203 |
+
],
|
204 |
+
[
|
205 |
+
"He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
|
206 |
+
'sad',
|
207 |
+
"resources/demo_speaker0.mp3",
|
208 |
+
True,
|
209 |
+
],
|
210 |
+
]
|
211 |
+
|
212 |
+
with gr.Blocks(analytics_enabled=False) as demo:
|
213 |
+
|
214 |
+
with gr.Row():
|
215 |
+
with gr.Column():
|
216 |
+
with gr.Row():
|
217 |
+
gr.Markdown(
|
218 |
+
"""
|
219 |
+
## <img src="https://huggingface.co/spaces/myshell-ai/OpenVoice/raw/main/logo.jpg" height="40"/>
|
220 |
+
"""
|
221 |
+
)
|
222 |
+
with gr.Row():
|
223 |
+
gr.Markdown(markdown_table_v2)
|
224 |
+
with gr.Row():
|
225 |
+
gr.Markdown(description)
|
226 |
+
with gr.Column():
|
227 |
+
gr.Video('https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f', autoplay=True)
|
228 |
+
|
229 |
+
with gr.Row():
|
230 |
+
gr.HTML(wrapped_markdown_content)
|
231 |
+
|
232 |
+
with gr.Row():
|
233 |
+
with gr.Column():
|
234 |
+
input_text_gr = gr.Textbox(
|
235 |
+
label="Text Prompt",
|
236 |
+
info="One or two sentences at a time is better. Up to 200 text characters.",
|
237 |
+
value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
|
238 |
+
)
|
239 |
+
style_gr = gr.Dropdown(
|
240 |
+
label="Style",
|
241 |
+
info="Select a style of output audio for the synthesised speech. (Chinese only support 'default' now)",
|
242 |
+
choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'],
|
243 |
+
max_choices=1,
|
244 |
+
value="default",
|
245 |
+
)
|
246 |
+
ref_gr = gr.Audio(
|
247 |
+
label="Reference Audio",
|
248 |
+
info="Click on the ✎ button to upload your own target speaker audio",
|
249 |
+
type="filepath",
|
250 |
+
value="resources/demo_speaker2.mp3",
|
251 |
+
)
|
252 |
+
tos_gr = gr.Checkbox(
|
253 |
+
label="Agree",
|
254 |
+
value=False,
|
255 |
+
info="I agree to the terms of the cc-by-nc-4.0 license-: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE",
|
256 |
+
)
|
257 |
+
|
258 |
+
tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
|
259 |
+
|
260 |
+
|
261 |
+
with gr.Column():
|
262 |
+
out_text_gr = gr.Text(label="Info")
|
263 |
+
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
|
264 |
+
ref_audio_gr = gr.Audio(label="Reference Audio Used")
|
265 |
+
|
266 |
+
gr.Examples(examples,
|
267 |
+
label="Examples",
|
268 |
+
inputs=[input_text_gr, style_gr, ref_gr, tos_gr],
|
269 |
+
outputs=[out_text_gr, audio_gr, ref_audio_gr],
|
270 |
+
fn=predict,
|
271 |
+
cache_examples=False,)
|
272 |
+
tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_gr, ref_audio_gr])
|
273 |
+
|
274 |
+
demo.queue()
|
275 |
+
demo.launch(debug=True, show_api=True, share=args.share)
|
OpenVoice/openvoice/se_extractor.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import torch
|
4 |
+
import hashlib
|
5 |
+
import librosa
|
6 |
+
import base64
|
7 |
+
from glob import glob
|
8 |
+
import numpy as np
|
9 |
+
from pydub import AudioSegment
|
10 |
+
import hashlib
|
11 |
+
import base64
|
12 |
+
import librosa
|
13 |
+
from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments
|
14 |
+
|
15 |
+
model_size = "medium"
|
16 |
+
# Run on GPU with FP16
|
17 |
+
model = None
|
18 |
+
def split_audio_whisper(audio_path, audio_name, target_dir='processed'):
|
19 |
+
global model
|
20 |
+
if model is None:
|
21 |
+
# Lazy import to avoid loading faster-whisper on environments where it is unsupported
|
22 |
+
from faster_whisper import WhisperModel # type: ignore
|
23 |
+
model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
24 |
+
audio = AudioSegment.from_file(audio_path)
|
25 |
+
max_len = len(audio)
|
26 |
+
|
27 |
+
target_folder = os.path.join(target_dir, audio_name)
|
28 |
+
|
29 |
+
segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
|
30 |
+
segments = list(segments)
|
31 |
+
|
32 |
+
# create directory
|
33 |
+
os.makedirs(target_folder, exist_ok=True)
|
34 |
+
wavs_folder = os.path.join(target_folder, 'wavs')
|
35 |
+
os.makedirs(wavs_folder, exist_ok=True)
|
36 |
+
|
37 |
+
# segments
|
38 |
+
s_ind = 0
|
39 |
+
start_time = None
|
40 |
+
|
41 |
+
for k, w in enumerate(segments):
|
42 |
+
# process with the time
|
43 |
+
if k == 0:
|
44 |
+
start_time = max(0, w.start)
|
45 |
+
|
46 |
+
end_time = w.end
|
47 |
+
|
48 |
+
# calculate confidence
|
49 |
+
if len(w.words) > 0:
|
50 |
+
confidence = sum([s.probability for s in w.words]) / len(w.words)
|
51 |
+
else:
|
52 |
+
confidence = 0.
|
53 |
+
# clean text
|
54 |
+
text = w.text.replace('...', '')
|
55 |
+
|
56 |
+
# leave 0.08 s of padding for each audio segment
|
57 |
+
audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]
|
58 |
+
|
59 |
+
# segment file name
|
60 |
+
fname = f"{audio_name}_seg{s_ind}.wav"
|
61 |
+
|
62 |
+
# filter out the segment shorter than 1.5s and longer than 20s
|
63 |
+
save = audio_seg.duration_seconds > 1.5 and \
|
64 |
+
audio_seg.duration_seconds < 20. and \
|
65 |
+
len(text) >= 2 and len(text) < 200
|
66 |
+
|
67 |
+
if save:
|
68 |
+
output_file = os.path.join(wavs_folder, fname)
|
69 |
+
audio_seg.export(output_file, format='wav')
|
70 |
+
|
71 |
+
if k < len(segments) - 1:
|
72 |
+
start_time = max(0, segments[k+1].start - 0.08)
|
73 |
+
|
74 |
+
s_ind = s_ind + 1
|
75 |
+
return wavs_folder
|
76 |
+
|
77 |
+
|
78 |
+
def split_audio_vad(audio_path, audio_name, target_dir, split_seconds=10.0):
|
79 |
+
SAMPLE_RATE = 16000
|
80 |
+
audio_vad = get_audio_tensor(audio_path)
|
81 |
+
segments = get_vad_segments(
|
82 |
+
audio_vad,
|
83 |
+
output_sample=True,
|
84 |
+
min_speech_duration=0.1,
|
85 |
+
min_silence_duration=1,
|
86 |
+
method="silero",
|
87 |
+
)
|
88 |
+
segments = [(seg["start"], seg["end"]) for seg in segments]
|
89 |
+
segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments]
|
90 |
+
print(segments)
|
91 |
+
audio_active = AudioSegment.silent(duration=0)
|
92 |
+
audio = AudioSegment.from_file(audio_path)
|
93 |
+
|
94 |
+
for start_time, end_time in segments:
|
95 |
+
audio_active += audio[int( start_time * 1000) : int(end_time * 1000)]
|
96 |
+
|
97 |
+
audio_dur = audio_active.duration_seconds
|
98 |
+
print(f'after vad: dur = {audio_dur}')
|
99 |
+
target_folder = os.path.join(target_dir, audio_name)
|
100 |
+
wavs_folder = os.path.join(target_folder, 'wavs')
|
101 |
+
os.makedirs(wavs_folder, exist_ok=True)
|
102 |
+
start_time = 0.
|
103 |
+
count = 0
|
104 |
+
num_splits = int(np.round(audio_dur / split_seconds))
|
105 |
+
assert num_splits > 0, 'input audio is too short'
|
106 |
+
interval = audio_dur / num_splits
|
107 |
+
|
108 |
+
for i in range(num_splits):
|
109 |
+
end_time = min(start_time + interval, audio_dur)
|
110 |
+
if i == num_splits - 1:
|
111 |
+
end_time = audio_dur
|
112 |
+
output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
|
113 |
+
audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
|
114 |
+
audio_seg.export(output_file, format='wav')
|
115 |
+
start_time = end_time
|
116 |
+
count += 1
|
117 |
+
return wavs_folder
|
118 |
+
|
119 |
+
def hash_numpy_array(audio_path):
|
120 |
+
array, _ = librosa.load(audio_path, sr=None, mono=True)
|
121 |
+
# Convert the array to bytes
|
122 |
+
array_bytes = array.tobytes()
|
123 |
+
# Calculate the hash of the array bytes
|
124 |
+
hash_object = hashlib.sha256(array_bytes)
|
125 |
+
hash_value = hash_object.digest()
|
126 |
+
# Convert the hash value to base64
|
127 |
+
base64_value = base64.b64encode(hash_value)
|
128 |
+
return base64_value.decode('utf-8')[:16].replace('/', '_^')
|
129 |
+
|
130 |
+
def get_se(audio_path, vc_model, target_dir='processed', vad=True):
|
131 |
+
device = vc_model.device
|
132 |
+
version = vc_model.version
|
133 |
+
print("OpenVoice version:", version)
|
134 |
+
|
135 |
+
audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}"
|
136 |
+
se_path = os.path.join(target_dir, audio_name, 'se.pth')
|
137 |
+
|
138 |
+
# if os.path.isfile(se_path):
|
139 |
+
# se = torch.load(se_path).to(device)
|
140 |
+
# return se, audio_name
|
141 |
+
# if os.path.isdir(audio_path):
|
142 |
+
# wavs_folder = audio_path
|
143 |
+
|
144 |
+
if vad:
|
145 |
+
wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
|
146 |
+
else:
|
147 |
+
wavs_folder = split_audio_whisper(audio_path, target_dir=target_dir, audio_name=audio_name)
|
148 |
+
|
149 |
+
audio_segs = glob(f'{wavs_folder}/*.wav')
|
150 |
+
if len(audio_segs) == 0:
|
151 |
+
raise NotImplementedError('No audio segments found!')
|
152 |
+
|
153 |
+
return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
|
154 |
+
|
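Note: get_se above is the entry point the apps call to turn a raw reference clip into a tone-color embedding. A minimal usage sketch, reusing the converter checkpoint paths that appear elsewhere in this commit and assuming pydub/ffmpeg and the silero VAD dependencies are installed:

import torch
from openvoice import se_extractor
from openvoice.api import ToneColorConverter

device = 'cuda' if torch.cuda.is_available() else 'cpu'

converter = ToneColorConverter('checkpoints/converter/config.json', device=device)
converter.load_ckpt('checkpoints/converter/checkpoint.pth')

# vad=True uses split_audio_vad (silero); vad=False falls back to split_audio_whisper (faster-whisper on GPU).
target_se, audio_name = se_extractor.get_se(
    'resources/demo_speaker0.mp3', converter, target_dir='processed', vad=True,
)
print(target_se.shape, audio_name)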
OpenVoice/openvoice/text/__init__.py
ADDED
@@ -0,0 +1,79 @@
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
from openvoice.text import cleaners
|
3 |
+
from openvoice.text.symbols import symbols
|
4 |
+
|
5 |
+
|
6 |
+
# Mappings from symbol to numeric ID and vice versa:
|
7 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
8 |
+
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
9 |
+
|
10 |
+
|
11 |
+
def text_to_sequence(text, symbols, cleaner_names):
|
12 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
13 |
+
Args:
|
14 |
+
text: string to convert to a sequence
|
15 |
+
cleaner_names: names of the cleaner functions to run the text through
|
16 |
+
Returns:
|
17 |
+
List of integers corresponding to the symbols in the text
|
18 |
+
'''
|
19 |
+
sequence = []
|
20 |
+
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
21 |
+
clean_text = _clean_text(text, cleaner_names)
|
22 |
+
print(clean_text)
|
23 |
+
print(f" length:{len(clean_text)}")
|
24 |
+
for symbol in clean_text:
|
25 |
+
if symbol not in symbol_to_id.keys():
|
26 |
+
continue
|
27 |
+
symbol_id = symbol_to_id[symbol]
|
28 |
+
sequence += [symbol_id]
|
29 |
+
print(f" length:{len(sequence)}")
|
30 |
+
return sequence
|
31 |
+
|
32 |
+
|
33 |
+
def cleaned_text_to_sequence(cleaned_text, symbols):
|
34 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
35 |
+
Args:
|
36 |
+
text: string to convert to a sequence
|
37 |
+
Returns:
|
38 |
+
List of integers corresponding to the symbols in the text
|
39 |
+
'''
|
40 |
+
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
41 |
+
sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()]
|
42 |
+
return sequence
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
from openvoice.text.symbols import language_tone_start_map
|
47 |
+
def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages):
|
48 |
+
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
49 |
+
Args:
|
50 |
+
text: string to convert to a sequence
|
51 |
+
Returns:
|
52 |
+
List of integers corresponding to the symbols in the text
|
53 |
+
"""
|
54 |
+
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
55 |
+
language_id_map = {s: i for i, s in enumerate(languages)}
|
56 |
+
phones = [symbol_to_id[symbol] for symbol in cleaned_text]
|
57 |
+
tone_start = language_tone_start_map[language]
|
58 |
+
tones = [i + tone_start for i in tones]
|
59 |
+
lang_id = language_id_map[language]
|
60 |
+
lang_ids = [lang_id for i in phones]
|
61 |
+
return phones, tones, lang_ids
|
62 |
+
|
63 |
+
|
64 |
+
def sequence_to_text(sequence):
|
65 |
+
'''Converts a sequence of IDs back to a string'''
|
66 |
+
result = ''
|
67 |
+
for symbol_id in sequence:
|
68 |
+
s = _id_to_symbol[symbol_id]
|
69 |
+
result += s
|
70 |
+
return result
|
71 |
+
|
72 |
+
|
73 |
+
def _clean_text(text, cleaner_names):
|
74 |
+
for name in cleaner_names:
|
75 |
+
cleaner = getattr(cleaners, name)
|
76 |
+
if not cleaner:
|
77 |
+
raise Exception('Unknown cleaner: %s' % name)
|
78 |
+
text = cleaner(text)
|
79 |
+
return text
|
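Note: text_to_sequence above is the glue between cleaned text and symbol IDs. A minimal sketch, assuming the symbol inventory added in symbols.py and the cjke_cleaners2 cleaner defined in cleaners.py below (which expects language-tagged input and needs the English/Chinese text dependencies such as eng_to_ipa, unidecode, and inflect installed):

from openvoice.text import text_to_sequence
from openvoice.text.symbols import symbols

# cjke_cleaners2 converts the tagged English span to IPA before the symbol lookup.
sequence = text_to_sequence('[EN]Hello, world.[EN]', symbols, ['cjke_cleaners2'])
print(sequence)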
OpenVoice/openvoice/text/cleaners.py
ADDED
@@ -0,0 +1,16 @@
1 |
+
import re
|
2 |
+
from openvoice.text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
|
3 |
+
from openvoice.text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
|
4 |
+
|
5 |
+
def cjke_cleaners2(text):
|
6 |
+
text = re.sub(r'\[ZH\](.*?)\[ZH\]',
|
7 |
+
lambda x: chinese_to_ipa(x.group(1))+' ', text)
|
8 |
+
text = re.sub(r'\[JA\](.*?)\[JA\]',
|
9 |
+
lambda x: japanese_to_ipa2(x.group(1))+' ', text)
|
10 |
+
text = re.sub(r'\[KO\](.*?)\[KO\]',
|
11 |
+
lambda x: korean_to_ipa(x.group(1))+' ', text)
|
12 |
+
text = re.sub(r'\[EN\](.*?)\[EN\]',
|
13 |
+
lambda x: english_to_ipa2(x.group(1))+' ', text)
|
14 |
+
text = re.sub(r'\s+$', '', text)
|
15 |
+
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
|
16 |
+
return text
|
OpenVoice/openvoice/text/english.py
ADDED
@@ -0,0 +1,188 @@
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
|
3 |
+
'''
|
4 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
5 |
+
|
6 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
7 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
8 |
+
1. "english_cleaners" for English text
|
9 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
10 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
11 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
12 |
+
the symbols in symbols.py to match your data).
|
13 |
+
'''
|
14 |
+
|
15 |
+
|
16 |
+
# Regular expression matching whitespace:
|
17 |
+
|
18 |
+
|
19 |
+
import re
|
20 |
+
import inflect
|
21 |
+
from unidecode import unidecode
|
22 |
+
import eng_to_ipa as ipa
|
23 |
+
_inflect = inflect.engine()
|
24 |
+
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
25 |
+
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
26 |
+
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
27 |
+
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
28 |
+
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
29 |
+
_number_re = re.compile(r'[0-9]+')
|
30 |
+
|
31 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
32 |
+
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
33 |
+
('mrs', 'misess'),
|
34 |
+
('mr', 'mister'),
|
35 |
+
('dr', 'doctor'),
|
36 |
+
('st', 'saint'),
|
37 |
+
('co', 'company'),
|
38 |
+
('jr', 'junior'),
|
39 |
+
('maj', 'major'),
|
40 |
+
('gen', 'general'),
|
41 |
+
('drs', 'doctors'),
|
42 |
+
('rev', 'reverend'),
|
43 |
+
('lt', 'lieutenant'),
|
44 |
+
('hon', 'honorable'),
|
45 |
+
('sgt', 'sergeant'),
|
46 |
+
('capt', 'captain'),
|
47 |
+
('esq', 'esquire'),
|
48 |
+
('ltd', 'limited'),
|
49 |
+
('col', 'colonel'),
|
50 |
+
('ft', 'fort'),
|
51 |
+
]]
|
52 |
+
|
53 |
+
|
54 |
+
# List of (ipa, lazy ipa) pairs:
|
55 |
+
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
56 |
+
('r', 'ɹ'),
|
57 |
+
('æ', 'e'),
|
58 |
+
('ɑ', 'a'),
|
59 |
+
('ɔ', 'o'),
|
60 |
+
('ð', 'z'),
|
61 |
+
('θ', 's'),
|
62 |
+
('ɛ', 'e'),
|
63 |
+
('ɪ', 'i'),
|
64 |
+
('ʊ', 'u'),
|
65 |
+
('ʒ', 'ʥ'),
|
66 |
+
('ʤ', 'ʥ'),
|
67 |
+
('ˈ', '↓'),
|
68 |
+
]]
|
69 |
+
|
70 |
+
# List of (ipa, lazy ipa2) pairs:
|
71 |
+
_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
72 |
+
('r', 'ɹ'),
|
73 |
+
('ð', 'z'),
|
74 |
+
('θ', 's'),
|
75 |
+
('ʒ', 'ʑ'),
|
76 |
+
('ʤ', 'dʑ'),
|
77 |
+
('ˈ', '↓'),
|
78 |
+
]]
|
79 |
+
|
80 |
+
# List of (ipa, ipa2) pairs
|
81 |
+
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
82 |
+
('r', 'ɹ'),
|
83 |
+
('ʤ', 'dʒ'),
|
84 |
+
('ʧ', 'tʃ')
|
85 |
+
]]
|
86 |
+
|
87 |
+
|
88 |
+
def expand_abbreviations(text):
|
89 |
+
for regex, replacement in _abbreviations:
|
90 |
+
text = re.sub(regex, replacement, text)
|
91 |
+
return text
|
92 |
+
|
93 |
+
|
94 |
+
def collapse_whitespace(text):
|
95 |
+
return re.sub(r'\s+', ' ', text)
|
96 |
+
|
97 |
+
|
98 |
+
def _remove_commas(m):
|
99 |
+
return m.group(1).replace(',', '')
|
100 |
+
|
101 |
+
|
102 |
+
def _expand_decimal_point(m):
|
103 |
+
return m.group(1).replace('.', ' point ')
|
104 |
+
|
105 |
+
|
106 |
+
def _expand_dollars(m):
|
107 |
+
match = m.group(1)
|
108 |
+
parts = match.split('.')
|
109 |
+
if len(parts) > 2:
|
110 |
+
return match + ' dollars' # Unexpected format
|
111 |
+
dollars = int(parts[0]) if parts[0] else 0
|
112 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
113 |
+
if dollars and cents:
|
114 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
115 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
116 |
+
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
117 |
+
elif dollars:
|
118 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
119 |
+
return '%s %s' % (dollars, dollar_unit)
|
120 |
+
elif cents:
|
121 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
122 |
+
return '%s %s' % (cents, cent_unit)
|
123 |
+
else:
|
124 |
+
return 'zero dollars'
|
125 |
+
|
126 |
+
|
127 |
+
def _expand_ordinal(m):
|
128 |
+
return _inflect.number_to_words(m.group(0))
|
129 |
+
|
130 |
+
|
131 |
+
def _expand_number(m):
|
132 |
+
num = int(m.group(0))
|
133 |
+
if num > 1000 and num < 3000:
|
134 |
+
if num == 2000:
|
135 |
+
return 'two thousand'
|
136 |
+
elif num > 2000 and num < 2010:
|
137 |
+
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
138 |
+
elif num % 100 == 0:
|
139 |
+
return _inflect.number_to_words(num // 100) + ' hundred'
|
140 |
+
else:
|
141 |
+
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
142 |
+
else:
|
143 |
+
return _inflect.number_to_words(num, andword='')
|
144 |
+
|
145 |
+
|
146 |
+
def normalize_numbers(text):
|
147 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
148 |
+
text = re.sub(_pounds_re, r'\1 pounds', text)
|
149 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
150 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
151 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
152 |
+
text = re.sub(_number_re, _expand_number, text)
|
153 |
+
return text
|
154 |
+
|
155 |
+
|
156 |
+
def mark_dark_l(text):
|
157 |
+
return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
|
158 |
+
|
159 |
+
|
160 |
+
def english_to_ipa(text):
|
161 |
+
text = unidecode(text).lower()
|
162 |
+
text = expand_abbreviations(text)
|
163 |
+
text = normalize_numbers(text)
|
164 |
+
phonemes = ipa.convert(text)
|
165 |
+
phonemes = collapse_whitespace(phonemes)
|
166 |
+
return phonemes
|
167 |
+
|
168 |
+
|
169 |
+
def english_to_lazy_ipa(text):
|
170 |
+
text = english_to_ipa(text)
|
171 |
+
for regex, replacement in _lazy_ipa:
|
172 |
+
text = re.sub(regex, replacement, text)
|
173 |
+
return text
|
174 |
+
|
175 |
+
|
176 |
+
def english_to_ipa2(text):
|
177 |
+
text = english_to_ipa(text)
|
178 |
+
text = mark_dark_l(text)
|
179 |
+
for regex, replacement in _ipa_to_ipa2:
|
180 |
+
text = re.sub(regex, replacement, text)
|
181 |
+
return text.replace('...', '…')
|
182 |
+
|
183 |
+
|
184 |
+
def english_to_lazy_ipa2(text):
|
185 |
+
text = english_to_ipa(text)
|
186 |
+
for regex, replacement in _lazy_ipa2:
|
187 |
+
text = re.sub(regex, replacement, text)
|
188 |
+
return text
|
OpenVoice/openvoice/text/mandarin.py
ADDED
@@ -0,0 +1,326 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
from pypinyin import lazy_pinyin, BOPOMOFO
|
5 |
+
import jieba
|
6 |
+
import cn2an
|
7 |
+
import logging
|
8 |
+
|
9 |
+
|
10 |
+
# List of (Latin alphabet, bopomofo) pairs:
|
11 |
+
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
12 |
+
('a', 'ㄟˉ'),
|
13 |
+
('b', 'ㄅㄧˋ'),
|
14 |
+
('c', 'ㄙㄧˉ'),
|
15 |
+
('d', 'ㄉㄧˋ'),
|
16 |
+
('e', 'ㄧˋ'),
|
17 |
+
('f', 'ㄝˊㄈㄨˋ'),
|
18 |
+
('g', 'ㄐㄧˋ'),
|
19 |
+
('h', 'ㄝˇㄑㄩˋ'),
|
20 |
+
('i', 'ㄞˋ'),
|
21 |
+
('j', 'ㄐㄟˋ'),
|
22 |
+
('k', 'ㄎㄟˋ'),
|
23 |
+
('l', 'ㄝˊㄛˋ'),
|
24 |
+
('m', 'ㄝˊㄇㄨˋ'),
|
25 |
+
('n', 'ㄣˉ'),
|
26 |
+
('o', 'ㄡˉ'),
|
27 |
+
('p', 'ㄆㄧˉ'),
|
28 |
+
('q', 'ㄎㄧㄡˉ'),
|
29 |
+
('r', 'ㄚˋ'),
|
30 |
+
('s', 'ㄝˊㄙˋ'),
|
31 |
+
('t', 'ㄊㄧˋ'),
|
32 |
+
('u', 'ㄧㄡˉ'),
|
33 |
+
('v', 'ㄨㄧˉ'),
|
34 |
+
('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
|
35 |
+
('x', 'ㄝˉㄎㄨˋㄙˋ'),
|
36 |
+
('y', 'ㄨㄞˋ'),
|
37 |
+
('z', 'ㄗㄟˋ')
|
38 |
+
]]
|
39 |
+
|
40 |
+
# List of (bopomofo, romaji) pairs:
|
41 |
+
_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
|
42 |
+
('ㄅㄛ', 'p⁼wo'),
|
43 |
+
('ㄆㄛ', 'pʰwo'),
|
44 |
+
('ㄇㄛ', 'mwo'),
|
45 |
+
('ㄈㄛ', 'fwo'),
|
46 |
+
('ㄅ', 'p⁼'),
|
47 |
+
('ㄆ', 'pʰ'),
|
48 |
+
('ㄇ', 'm'),
|
49 |
+
('ㄈ', 'f'),
|
50 |
+
('ㄉ', 't⁼'),
|
51 |
+
('ㄊ', 'tʰ'),
|
52 |
+
('ㄋ', 'n'),
|
53 |
+
('ㄌ', 'l'),
|
54 |
+
('ㄍ', 'k⁼'),
|
55 |
+
('ㄎ', 'kʰ'),
|
56 |
+
('ㄏ', 'h'),
|
57 |
+
('ㄐ', 'ʧ⁼'),
|
58 |
+
('ㄑ', 'ʧʰ'),
|
59 |
+
('ㄒ', 'ʃ'),
|
60 |
+
('ㄓ', 'ʦ`⁼'),
|
61 |
+
('ㄔ', 'ʦ`ʰ'),
|
62 |
+
('ㄕ', 's`'),
|
63 |
+
('ㄖ', 'ɹ`'),
|
64 |
+
('ㄗ', 'ʦ⁼'),
|
65 |
+
('ㄘ', 'ʦʰ'),
|
66 |
+
('ㄙ', 's'),
|
67 |
+
('ㄚ', 'a'),
|
68 |
+
('ㄛ', 'o'),
|
69 |
+
('ㄜ', 'ə'),
|
70 |
+
('ㄝ', 'e'),
|
71 |
+
('ㄞ', 'ai'),
|
72 |
+
('ㄟ', 'ei'),
|
73 |
+
('ㄠ', 'au'),
|
74 |
+
('ㄡ', 'ou'),
|
75 |
+
('ㄧㄢ', 'yeNN'),
|
76 |
+
('ㄢ', 'aNN'),
|
77 |
+
('ㄧㄣ', 'iNN'),
|
78 |
+
('ㄣ', 'əNN'),
|
79 |
+
('ㄤ', 'aNg'),
|
80 |
+
('ㄧㄥ', 'iNg'),
|
81 |
+
('ㄨㄥ', 'uNg'),
|
82 |
+
('ㄩㄥ', 'yuNg'),
|
83 |
+
('ㄥ', 'əNg'),
|
84 |
+
('ㄦ', 'əɻ'),
|
85 |
+
('ㄧ', 'i'),
|
86 |
+
('ㄨ', 'u'),
|
87 |
+
('ㄩ', 'ɥ'),
|
88 |
+
('ˉ', '→'),
|
89 |
+
('ˊ', '↑'),
|
90 |
+
('ˇ', '↓↑'),
|
91 |
+
('ˋ', '↓'),
|
92 |
+
('˙', ''),
|
93 |
+
(',', ','),
|
94 |
+
('。', '.'),
|
95 |
+
('!', '!'),
|
96 |
+
('?', '?'),
|
97 |
+
('—', '-')
|
98 |
+
]]
|
99 |
+
|
100 |
+
# List of (romaji, ipa) pairs:
|
101 |
+
_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
102 |
+
('ʃy', 'ʃ'),
|
103 |
+
('ʧʰy', 'ʧʰ'),
|
104 |
+
('ʧ⁼y', 'ʧ⁼'),
|
105 |
+
('NN', 'n'),
|
106 |
+
('Ng', 'ŋ'),
|
107 |
+
('y', 'j'),
|
108 |
+
('h', 'x')
|
109 |
+
]]
|
110 |
+
|
111 |
+
# List of (bopomofo, ipa) pairs:
|
112 |
+
_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
|
113 |
+
('ㄅㄛ', 'p⁼wo'),
|
114 |
+
('ㄆㄛ', 'pʰwo'),
|
115 |
+
('ㄇㄛ', 'mwo'),
|
116 |
+
('ㄈㄛ', 'fwo'),
|
117 |
+
('ㄅ', 'p⁼'),
|
118 |
+
('ㄆ', 'pʰ'),
|
119 |
+
('ㄇ', 'm'),
|
120 |
+
('ㄈ', 'f'),
|
121 |
+
('ㄉ', 't⁼'),
|
122 |
+
('ㄊ', 'tʰ'),
|
123 |
+
('ㄋ', 'n'),
|
124 |
+
('ㄌ', 'l'),
|
125 |
+
('ㄍ', 'k⁼'),
|
126 |
+
('ㄎ', 'kʰ'),
|
127 |
+
('ㄏ', 'x'),
|
128 |
+
('ㄐ', 'tʃ⁼'),
|
129 |
+
('ㄑ', 'tʃʰ'),
|
130 |
+
('ㄒ', 'ʃ'),
|
131 |
+
('ㄓ', 'ts`⁼'),
|
132 |
+
('ㄔ', 'ts`ʰ'),
|
133 |
+
('ㄕ', 's`'),
|
134 |
+
('ㄖ', 'ɹ`'),
|
135 |
+
('ㄗ', 'ts⁼'),
|
136 |
+
('ㄘ', 'tsʰ'),
|
137 |
+
('ㄙ', 's'),
|
138 |
+
('ㄚ', 'a'),
|
139 |
+
('ㄛ', 'o'),
|
140 |
+
('ㄜ', 'ə'),
|
141 |
+
('ㄝ', 'ɛ'),
|
142 |
+
('ㄞ', 'aɪ'),
|
143 |
+
('ㄟ', 'eɪ'),
|
144 |
+
('ㄠ', 'ɑʊ'),
|
145 |
+
('ㄡ', 'oʊ'),
|
146 |
+
('ㄧㄢ', 'jɛn'),
|
147 |
+
('ㄩㄢ', 'ɥæn'),
|
148 |
+
('ㄢ', 'an'),
|
149 |
+
('ㄧㄣ', 'in'),
|
150 |
+
('ㄩㄣ', 'ɥn'),
|
151 |
+
('ㄣ', 'ən'),
|
152 |
+
('ㄤ', 'ɑŋ'),
|
153 |
+
('ㄧㄥ', 'iŋ'),
|
154 |
+
('ㄨㄥ', 'ʊŋ'),
|
155 |
+
('ㄩㄥ', 'jʊŋ'),
|
156 |
+
('ㄥ', 'əŋ'),
|
157 |
+
('ㄦ', 'əɻ'),
|
158 |
+
('ㄧ', 'i'),
|
159 |
+
('ㄨ', 'u'),
|
160 |
+
('ㄩ', 'ɥ'),
|
161 |
+
('ˉ', '→'),
|
162 |
+
('ˊ', '↑'),
|
163 |
+
('ˇ', '↓↑'),
|
164 |
+
('ˋ', '↓'),
|
165 |
+
('˙', ''),
|
166 |
+
(',', ','),
|
167 |
+
('。', '.'),
|
168 |
+
('!', '!'),
|
169 |
+
('?', '?'),
|
170 |
+
('—', '-')
|
171 |
+
]]
|
172 |
+
|
173 |
+
# List of (bopomofo, ipa2) pairs:
|
174 |
+
_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
|
175 |
+
('ㄅㄛ', 'pwo'),
|
176 |
+
('ㄆㄛ', 'pʰwo'),
|
177 |
+
('ㄇㄛ', 'mwo'),
|
178 |
+
('ㄈㄛ', 'fwo'),
|
179 |
+
('ㄅ', 'p'),
|
180 |
+
('ㄆ', 'pʰ'),
|
181 |
+
('ㄇ', 'm'),
|
182 |
+
('ㄈ', 'f'),
|
183 |
+
('ㄉ', 't'),
|
184 |
+
('ㄊ', 'tʰ'),
|
185 |
+
('ㄋ', 'n'),
|
186 |
+
('ㄌ', 'l'),
|
187 |
+
('ㄍ', 'k'),
|
188 |
+
('ㄎ', 'kʰ'),
|
189 |
+
('ㄏ', 'h'),
|
190 |
+
('ㄐ', 'tɕ'),
|
191 |
+
('ㄑ', 'tɕʰ'),
|
192 |
+
('ㄒ', 'ɕ'),
|
193 |
+
('ㄓ', 'tʂ'),
|
194 |
+
('ㄔ', 'tʂʰ'),
|
195 |
+
('ㄕ', 'ʂ'),
|
196 |
+
('ㄖ', 'ɻ'),
|
197 |
+
('ㄗ', 'ts'),
|
198 |
+
('ㄘ', 'tsʰ'),
|
199 |
+
('ㄙ', 's'),
|
200 |
+
('ㄚ', 'a'),
|
201 |
+
('ㄛ', 'o'),
|
202 |
+
('ㄜ', 'ɤ'),
|
203 |
+
('ㄝ', 'ɛ'),
|
204 |
+
('ㄞ', 'aɪ'),
|
205 |
+
('ㄟ', 'eɪ'),
|
206 |
+
('ㄠ', 'ɑʊ'),
|
207 |
+
('ㄡ', 'oʊ'),
|
208 |
+
('ㄧㄢ', 'jɛn'),
|
209 |
+
('ㄩㄢ', 'yæn'),
|
210 |
+
('ㄢ', 'an'),
|
211 |
+
('ㄧㄣ', 'in'),
|
212 |
+
('ㄩㄣ', 'yn'),
|
213 |
+
('ㄣ', 'ən'),
|
214 |
+
('ㄤ', 'ɑŋ'),
|
215 |
+
('ㄧㄥ', 'iŋ'),
|
216 |
+
('ㄨㄥ', 'ʊŋ'),
|
217 |
+
('ㄩㄥ', 'jʊŋ'),
|
218 |
+
('ㄥ', 'ɤŋ'),
|
219 |
+
('ㄦ', 'əɻ'),
|
220 |
+
('ㄧ', 'i'),
|
221 |
+
('ㄨ', 'u'),
|
222 |
+
('ㄩ', 'y'),
|
223 |
+
('ˉ', '˥'),
|
224 |
+
('ˊ', '˧˥'),
|
225 |
+
('ˇ', '˨˩˦'),
|
226 |
+
('ˋ', '˥˩'),
|
227 |
+
('˙', ''),
|
228 |
+
(',', ','),
|
229 |
+
('。', '.'),
|
230 |
+
('!', '!'),
|
231 |
+
('?', '?'),
|
232 |
+
('—', '-')
|
233 |
+
]]
|
234 |
+
|
235 |
+
|
236 |
+
def number_to_chinese(text):
|
237 |
+
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
|
238 |
+
for number in numbers:
|
239 |
+
text = text.replace(number, cn2an.an2cn(number), 1)
|
240 |
+
return text
|
241 |
+
|
242 |
+
|
243 |
+
def chinese_to_bopomofo(text):
|
244 |
+
text = text.replace('、', ',').replace(';', ',').replace(':', ',')
|
245 |
+
words = jieba.lcut(text, cut_all=False)
|
246 |
+
text = ''
|
247 |
+
for word in words:
|
248 |
+
bopomofos = lazy_pinyin(word, BOPOMOFO)
|
249 |
+
if not re.search('[\u4e00-\u9fff]', word):
|
250 |
+
text += word
|
251 |
+
continue
|
252 |
+
for i in range(len(bopomofos)):
|
253 |
+
bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
|
254 |
+
if text != '':
|
255 |
+
text += ' '
|
256 |
+
text += ''.join(bopomofos)
|
257 |
+
return text
|
258 |
+
|
259 |
+
|
260 |
+
def latin_to_bopomofo(text):
|
261 |
+
for regex, replacement in _latin_to_bopomofo:
|
262 |
+
text = re.sub(regex, replacement, text)
|
263 |
+
return text
|
264 |
+
|
265 |
+
|
266 |
+
def bopomofo_to_romaji(text):
|
267 |
+
for regex, replacement in _bopomofo_to_romaji:
|
268 |
+
text = re.sub(regex, replacement, text)
|
269 |
+
return text
|
270 |
+
|
271 |
+
|
272 |
+
def bopomofo_to_ipa(text):
|
273 |
+
for regex, replacement in _bopomofo_to_ipa:
|
274 |
+
text = re.sub(regex, replacement, text)
|
275 |
+
return text
|
276 |
+
|
277 |
+
|
278 |
+
def bopomofo_to_ipa2(text):
|
279 |
+
for regex, replacement in _bopomofo_to_ipa2:
|
280 |
+
text = re.sub(regex, replacement, text)
|
281 |
+
return text
|
282 |
+
|
283 |
+
|
284 |
+
def chinese_to_romaji(text):
|
285 |
+
text = number_to_chinese(text)
|
286 |
+
text = chinese_to_bopomofo(text)
|
287 |
+
text = latin_to_bopomofo(text)
|
288 |
+
text = bopomofo_to_romaji(text)
|
289 |
+
text = re.sub('i([aoe])', r'y\1', text)
|
290 |
+
text = re.sub('u([aoəe])', r'w\1', text)
|
291 |
+
text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
|
292 |
+
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
|
293 |
+
text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
|
294 |
+
return text
|
295 |
+
|
296 |
+
|
297 |
+
def chinese_to_lazy_ipa(text):
|
298 |
+
text = chinese_to_romaji(text)
|
299 |
+
for regex, replacement in _romaji_to_ipa:
|
300 |
+
text = re.sub(regex, replacement, text)
|
301 |
+
return text
|
302 |
+
|
303 |
+
|
304 |
+
def chinese_to_ipa(text):
|
305 |
+
text = number_to_chinese(text)
|
306 |
+
text = chinese_to_bopomofo(text)
|
307 |
+
text = latin_to_bopomofo(text)
|
308 |
+
text = bopomofo_to_ipa(text)
|
309 |
+
text = re.sub('i([aoe])', r'j\1', text)
|
310 |
+
text = re.sub('u([aoəe])', r'w\1', text)
|
311 |
+
text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
|
312 |
+
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
|
313 |
+
text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
|
314 |
+
return text
|
315 |
+
|
316 |
+
|
317 |
+
def chinese_to_ipa2(text):
|
318 |
+
text = number_to_chinese(text)
|
319 |
+
text = chinese_to_bopomofo(text)
|
320 |
+
text = latin_to_bopomofo(text)
|
321 |
+
text = bopomofo_to_ipa2(text)
|
322 |
+
text = re.sub(r'i([aoe])', r'j\1', text)
|
323 |
+
text = re.sub(r'u([aoəe])', r'w\1', text)
|
324 |
+
text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
|
325 |
+
text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
|
326 |
+
return text
|
OpenVoice/openvoice/text/symbols.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
'''
|
2 |
+
Defines the set of symbols used in text input to the model.
|
3 |
+
'''
|
4 |
+
|
5 |
+
# japanese_cleaners
|
6 |
+
# _pad = '_'
|
7 |
+
# _punctuation = ',.!?-'
|
8 |
+
# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
|
9 |
+
|
10 |
+
|
11 |
+
'''# japanese_cleaners2
|
12 |
+
_pad = '_'
|
13 |
+
_punctuation = ',.!?-~…'
|
14 |
+
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
|
15 |
+
'''
|
16 |
+
|
17 |
+
|
18 |
+
'''# korean_cleaners
|
19 |
+
_pad = '_'
|
20 |
+
_punctuation = ',.!?…~'
|
21 |
+
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
|
22 |
+
'''
|
23 |
+
|
24 |
+
'''# chinese_cleaners
|
25 |
+
_pad = '_'
|
26 |
+
_punctuation = ',。!?—…'
|
27 |
+
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
|
28 |
+
'''
|
29 |
+
|
30 |
+
# # zh_ja_mixture_cleaners
|
31 |
+
# _pad = '_'
|
32 |
+
# _punctuation = ',.!?-~…'
|
33 |
+
# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
|
34 |
+
|
35 |
+
|
36 |
+
'''# sanskrit_cleaners
|
37 |
+
_pad = '_'
|
38 |
+
_punctuation = '।'
|
39 |
+
_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
|
40 |
+
'''
|
41 |
+
|
42 |
+
'''# cjks_cleaners
|
43 |
+
_pad = '_'
|
44 |
+
_punctuation = ',.!?-~…'
|
45 |
+
_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
|
46 |
+
'''
|
47 |
+
|
48 |
+
'''# thai_cleaners
|
49 |
+
_pad = '_'
|
50 |
+
_punctuation = '.!? '
|
51 |
+
_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
|
52 |
+
'''
|
53 |
+
|
54 |
+
# # cjke_cleaners2
|
55 |
+
_pad = '_'
|
56 |
+
_punctuation = ',.!?-~…'
|
57 |
+
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
|
58 |
+
|
59 |
+
|
60 |
+
'''# shanghainese_cleaners
|
61 |
+
_pad = '_'
|
62 |
+
_punctuation = ',.!?…'
|
63 |
+
_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
|
64 |
+
'''
|
65 |
+
|
66 |
+
'''# chinese_dialect_cleaners
|
67 |
+
_pad = '_'
|
68 |
+
_punctuation = ',.!?~…─'
|
69 |
+
_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
|
70 |
+
'''
|
71 |
+
|
72 |
+
# Export all symbols:
|
73 |
+
symbols = [_pad] + list(_punctuation) + list(_letters)
|
74 |
+
|
75 |
+
# Special symbol ids
|
76 |
+
SPACE_ID = symbols.index(" ")
|
77 |
+
|
78 |
+
num_ja_tones = 1
|
79 |
+
num_kr_tones = 1
|
80 |
+
num_zh_tones = 6
|
81 |
+
num_en_tones = 4
|
82 |
+
|
83 |
+
language_tone_start_map = {
|
84 |
+
"ZH": 0,
|
85 |
+
"JP": num_zh_tones,
|
86 |
+
"EN": num_zh_tones + num_ja_tones,
|
87 |
+
'KR': num_zh_tones + num_ja_tones + num_en_tones,
|
88 |
+
}
|
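The tone counts and `language_tone_start_map` above give each language a disjoint block of tone IDs; the offsets follow directly from the constants defined in this file:

```python
# Offsets implied by symbols.py: ZH has 6 tones, JP 1, EN 4, KR 1.
num_zh_tones, num_ja_tones, num_en_tones, num_kr_tones = 6, 1, 4, 1
language_tone_start_map = {
    "ZH": 0,                                           # tone IDs 0-5
    "JP": num_zh_tones,                                # tone ID 6
    "EN": num_zh_tones + num_ja_tones,                 # tone IDs 7-10
    "KR": num_zh_tones + num_ja_tones + num_en_tones,  # tone ID 11
}
print(language_tone_start_map)  # {'ZH': 0, 'JP': 6, 'EN': 7, 'KR': 11}
```

`cleaned_text_to_sequence_vits2` in `text/__init__.py` adds the appropriate offset to every tone, so tone IDs from different languages never collide.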
OpenVoice/openvoice/transforms.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
8 |
+
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
9 |
+
DEFAULT_MIN_DERIVATIVE = 1e-3
|
10 |
+
|
11 |
+
|
12 |
+
def piecewise_rational_quadratic_transform(
|
13 |
+
inputs,
|
14 |
+
unnormalized_widths,
|
15 |
+
unnormalized_heights,
|
16 |
+
unnormalized_derivatives,
|
17 |
+
inverse=False,
|
18 |
+
tails=None,
|
19 |
+
tail_bound=1.0,
|
20 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
21 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
22 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
23 |
+
):
|
24 |
+
if tails is None:
|
25 |
+
spline_fn = rational_quadratic_spline
|
26 |
+
spline_kwargs = {}
|
27 |
+
else:
|
28 |
+
spline_fn = unconstrained_rational_quadratic_spline
|
29 |
+
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
30 |
+
|
31 |
+
outputs, logabsdet = spline_fn(
|
32 |
+
inputs=inputs,
|
33 |
+
unnormalized_widths=unnormalized_widths,
|
34 |
+
unnormalized_heights=unnormalized_heights,
|
35 |
+
unnormalized_derivatives=unnormalized_derivatives,
|
36 |
+
inverse=inverse,
|
37 |
+
min_bin_width=min_bin_width,
|
38 |
+
min_bin_height=min_bin_height,
|
39 |
+
min_derivative=min_derivative,
|
40 |
+
**spline_kwargs
|
41 |
+
)
|
42 |
+
return outputs, logabsdet
|
43 |
+
|
44 |
+
|
45 |
+
def searchsorted(bin_locations, inputs, eps=1e-6):
|
46 |
+
bin_locations[..., -1] += eps
|
47 |
+
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
48 |
+
|
49 |
+
|
50 |
+
def unconstrained_rational_quadratic_spline(
|
51 |
+
inputs,
|
52 |
+
unnormalized_widths,
|
53 |
+
unnormalized_heights,
|
54 |
+
unnormalized_derivatives,
|
55 |
+
inverse=False,
|
56 |
+
tails="linear",
|
57 |
+
tail_bound=1.0,
|
58 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
59 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
60 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
61 |
+
):
|
62 |
+
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
63 |
+
outside_interval_mask = ~inside_interval_mask
|
64 |
+
|
65 |
+
outputs = torch.zeros_like(inputs)
|
66 |
+
logabsdet = torch.zeros_like(inputs)
|
67 |
+
|
68 |
+
if tails == "linear":
|
69 |
+
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
70 |
+
constant = np.log(np.exp(1 - min_derivative) - 1)
|
71 |
+
unnormalized_derivatives[..., 0] = constant
|
72 |
+
unnormalized_derivatives[..., -1] = constant
|
73 |
+
|
74 |
+
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
75 |
+
logabsdet[outside_interval_mask] = 0
|
76 |
+
else:
|
77 |
+
raise RuntimeError("{} tails are not implemented.".format(tails))
|
78 |
+
|
79 |
+
(
|
80 |
+
outputs[inside_interval_mask],
|
81 |
+
logabsdet[inside_interval_mask],
|
82 |
+
) = rational_quadratic_spline(
|
83 |
+
inputs=inputs[inside_interval_mask],
|
84 |
+
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
85 |
+
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
86 |
+
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
87 |
+
inverse=inverse,
|
88 |
+
left=-tail_bound,
|
89 |
+
right=tail_bound,
|
90 |
+
bottom=-tail_bound,
|
91 |
+
top=tail_bound,
|
92 |
+
min_bin_width=min_bin_width,
|
93 |
+
min_bin_height=min_bin_height,
|
94 |
+
min_derivative=min_derivative,
|
95 |
+
)
|
96 |
+
|
97 |
+
return outputs, logabsdet
|
98 |
+
|
99 |
+
|
100 |
+
def rational_quadratic_spline(
|
101 |
+
inputs,
|
102 |
+
unnormalized_widths,
|
103 |
+
unnormalized_heights,
|
104 |
+
unnormalized_derivatives,
|
105 |
+
inverse=False,
|
106 |
+
left=0.0,
|
107 |
+
right=1.0,
|
108 |
+
bottom=0.0,
|
109 |
+
top=1.0,
|
110 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
111 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
112 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
113 |
+
):
|
114 |
+
if torch.min(inputs) < left or torch.max(inputs) > right:
|
115 |
+
raise ValueError("Input to a transform is not within its domain")
|
116 |
+
|
117 |
+
num_bins = unnormalized_widths.shape[-1]
|
118 |
+
|
119 |
+
if min_bin_width * num_bins > 1.0:
|
120 |
+
raise ValueError("Minimal bin width too large for the number of bins")
|
121 |
+
if min_bin_height * num_bins > 1.0:
|
122 |
+
raise ValueError("Minimal bin height too large for the number of bins")
|
123 |
+
|
124 |
+
widths = F.softmax(unnormalized_widths, dim=-1)
|
125 |
+
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
126 |
+
cumwidths = torch.cumsum(widths, dim=-1)
|
127 |
+
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
128 |
+
cumwidths = (right - left) * cumwidths + left
|
129 |
+
cumwidths[..., 0] = left
|
130 |
+
cumwidths[..., -1] = right
|
131 |
+
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
132 |
+
|
133 |
+
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
134 |
+
|
135 |
+
heights = F.softmax(unnormalized_heights, dim=-1)
|
136 |
+
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
137 |
+
cumheights = torch.cumsum(heights, dim=-1)
|
138 |
+
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
139 |
+
cumheights = (top - bottom) * cumheights + bottom
|
140 |
+
cumheights[..., 0] = bottom
|
141 |
+
cumheights[..., -1] = top
|
142 |
+
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
143 |
+
|
144 |
+
if inverse:
|
145 |
+
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
146 |
+
else:
|
147 |
+
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
148 |
+
|
149 |
+
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
150 |
+
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
151 |
+
|
152 |
+
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
153 |
+
delta = heights / widths
|
154 |
+
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
155 |
+
|
156 |
+
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
157 |
+
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
158 |
+
|
159 |
+
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
160 |
+
|
161 |
+
if inverse:
|
162 |
+
a = (inputs - input_cumheights) * (
|
163 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
164 |
+
) + input_heights * (input_delta - input_derivatives)
|
165 |
+
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
166 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
167 |
+
)
|
168 |
+
c = -input_delta * (inputs - input_cumheights)
|
169 |
+
|
170 |
+
discriminant = b.pow(2) - 4 * a * c
|
171 |
+
assert (discriminant >= 0).all()
|
172 |
+
|
173 |
+
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
174 |
+
outputs = root * input_bin_widths + input_cumwidths
|
175 |
+
|
176 |
+
theta_one_minus_theta = root * (1 - root)
|
177 |
+
denominator = input_delta + (
|
178 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
179 |
+
* theta_one_minus_theta
|
180 |
+
)
|
181 |
+
derivative_numerator = input_delta.pow(2) * (
|
182 |
+
input_derivatives_plus_one * root.pow(2)
|
183 |
+
+ 2 * input_delta * theta_one_minus_theta
|
184 |
+
+ input_derivatives * (1 - root).pow(2)
|
185 |
+
)
|
186 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
187 |
+
|
188 |
+
return outputs, -logabsdet
|
189 |
+
else:
|
190 |
+
theta = (inputs - input_cumwidths) / input_bin_widths
|
191 |
+
theta_one_minus_theta = theta * (1 - theta)
|
192 |
+
|
193 |
+
numerator = input_heights * (
|
194 |
+
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
195 |
+
)
|
196 |
+
denominator = input_delta + (
|
197 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
198 |
+
* theta_one_minus_theta
|
199 |
+
)
|
200 |
+
outputs = input_cumheights + numerator / denominator
|
201 |
+
|
202 |
+
derivative_numerator = input_delta.pow(2) * (
|
203 |
+
input_derivatives_plus_one * theta.pow(2)
|
204 |
+
+ 2 * input_delta * theta_one_minus_theta
|
205 |
+
+ input_derivatives * (1 - theta).pow(2)
|
206 |
+
)
|
207 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
208 |
+
|
209 |
+
return outputs, logabsdet
|
OpenVoice/openvoice/utils.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
import re
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def get_hparams_from_file(config_path):
|
7 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
8 |
+
data = f.read()
|
9 |
+
config = json.loads(data)
|
10 |
+
|
11 |
+
hparams = HParams(**config)
|
12 |
+
return hparams
|
13 |
+
|
14 |
+
class HParams:
|
15 |
+
def __init__(self, **kwargs):
|
16 |
+
for k, v in kwargs.items():
|
17 |
+
if type(v) == dict:
|
18 |
+
v = HParams(**v)
|
19 |
+
self[k] = v
|
20 |
+
|
21 |
+
def keys(self):
|
22 |
+
return self.__dict__.keys()
|
23 |
+
|
24 |
+
def items(self):
|
25 |
+
return self.__dict__.items()
|
26 |
+
|
27 |
+
def values(self):
|
28 |
+
return self.__dict__.values()
|
29 |
+
|
30 |
+
def __len__(self):
|
31 |
+
return len(self.__dict__)
|
32 |
+
|
33 |
+
def __getitem__(self, key):
|
34 |
+
return getattr(self, key)
|
35 |
+
|
36 |
+
def __setitem__(self, key, value):
|
37 |
+
return setattr(self, key, value)
|
38 |
+
|
39 |
+
def __contains__(self, key):
|
40 |
+
return key in self.__dict__
|
41 |
+
|
42 |
+
def __repr__(self):
|
43 |
+
return self.__dict__.__repr__()
|
44 |
+
|
45 |
+
|
46 |
+
def string_to_bits(string, pad_len=8):
|
47 |
+
# Convert each character to its ASCII value
|
48 |
+
ascii_values = [ord(char) for char in string]
|
49 |
+
|
50 |
+
# Convert ASCII values to binary representation
|
51 |
+
binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]
|
52 |
+
|
53 |
+
# Convert binary strings to integer arrays
|
54 |
+
bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]
|
55 |
+
|
56 |
+
# Convert list of arrays to NumPy array
|
57 |
+
numpy_array = np.array(bit_arrays)
|
58 |
+
numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
|
59 |
+
numpy_array_full[:, 2] = 1
|
60 |
+
max_len = min(pad_len, len(numpy_array))
|
61 |
+
numpy_array_full[:max_len] = numpy_array[:max_len]
|
62 |
+
return numpy_array_full
|
63 |
+
|
64 |
+
|
65 |
+
def bits_to_string(bits_array):
|
66 |
+
# Convert each row of the array to a binary string
|
67 |
+
binary_values = [''.join(str(bit) for bit in row) for row in bits_array]
|
68 |
+
|
69 |
+
# Convert binary strings to ASCII values
|
70 |
+
ascii_values = [int(binary, 2) for binary in binary_values]
|
71 |
+
|
72 |
+
# Convert ASCII values to characters
|
73 |
+
output_string = ''.join(chr(value) for value in ascii_values)
|
74 |
+
|
75 |
+
return output_string
|
76 |
+
|
77 |
+
|
78 |
+
def split_sentence(text, min_len=10, language_str='[EN]'):
|
79 |
+
if language_str in ['EN']:
|
80 |
+
sentences = split_sentences_latin(text, min_len=min_len)
|
81 |
+
else:
|
82 |
+
sentences = split_sentences_zh(text, min_len=min_len)
|
83 |
+
return sentences
|
84 |
+
|
85 |
+
def split_sentences_latin(text, min_len=10):
|
86 |
+
"""Split Long sentences into list of short ones
|
87 |
+
|
88 |
+
Args:
|
89 |
+
str: Input sentences.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
List[str]: list of output sentences.
|
93 |
+
"""
|
94 |
+
# deal with dirty sentences
|
95 |
+
text = re.sub('[。!?;]', '.', text)
|
96 |
+
text = re.sub('[,]', ',', text)
|
97 |
+
text = re.sub('[“”]', '"', text)
|
98 |
+
text = re.sub('[‘’]', "'", text)
|
99 |
+
text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
|
100 |
+
text = re.sub('[\n\t ]+', ' ', text)
|
101 |
+
text = re.sub('([,.!?;])', r'\1 $#!', text)
|
102 |
+
# split
|
103 |
+
sentences = [s.strip() for s in text.split('$#!')]
|
104 |
+
if len(sentences[-1]) == 0: del sentences[-1]
|
105 |
+
|
106 |
+
new_sentences = []
|
107 |
+
new_sent = []
|
108 |
+
count_len = 0
|
109 |
+
for ind, sent in enumerate(sentences):
|
110 |
+
# print(sent)
|
111 |
+
new_sent.append(sent)
|
112 |
+
count_len += len(sent.split(" "))
|
113 |
+
if count_len > min_len or ind == len(sentences) - 1:
|
114 |
+
count_len = 0
|
115 |
+
new_sentences.append(' '.join(new_sent))
|
116 |
+
new_sent = []
|
117 |
+
return merge_short_sentences_latin(new_sentences)
|
118 |
+
|
119 |
+
|
120 |
+
def merge_short_sentences_latin(sens):
|
121 |
+
"""Avoid short sentences by merging them with the following sentence.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
List[str]: list of input sentences.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
List[str]: list of output sentences.
|
128 |
+
"""
|
129 |
+
sens_out = []
|
130 |
+
for s in sens:
|
131 |
+
# If the previous sentence is too short, merge it with
|
132 |
+
# the current sentence.
|
133 |
+
if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
|
134 |
+
sens_out[-1] = sens_out[-1] + " " + s
|
135 |
+
else:
|
136 |
+
sens_out.append(s)
|
137 |
+
try:
|
138 |
+
if len(sens_out[-1].split(" ")) <= 2:
|
139 |
+
sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
|
140 |
+
sens_out.pop(-1)
|
141 |
+
except:
|
142 |
+
pass
|
143 |
+
return sens_out
|
144 |
+
|
145 |
+
def split_sentences_zh(text, min_len=10):
|
146 |
+
text = re.sub('[。!?;]', '.', text)
|
147 |
+
text = re.sub('[,]', ',', text)
|
148 |
+
# Replace newlines, tabs and spaces in the text with single spaces
|
149 |
+
text = re.sub('[\n\t ]+', ' ', text)
|
150 |
+
# Insert a split marker (' $#!') after punctuation marks
|
151 |
+
text = re.sub('([,.!?;])', r'\1 $#!', text)
|
152 |
+
# Split into sentences and strip leading/trailing whitespace
|
153 |
+
# sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
|
154 |
+
sentences = [s.strip() for s in text.split('$#!')]
|
155 |
+
if len(sentences[-1]) == 0: del sentences[-1]
|
156 |
+
|
157 |
+
new_sentences = []
|
158 |
+
new_sent = []
|
159 |
+
count_len = 0
|
160 |
+
for ind, sent in enumerate(sentences):
|
161 |
+
new_sent.append(sent)
|
162 |
+
count_len += len(sent)
|
163 |
+
if count_len > min_len or ind == len(sentences) - 1:
|
164 |
+
count_len = 0
|
165 |
+
new_sentences.append(' '.join(new_sent))
|
166 |
+
new_sent = []
|
167 |
+
return merge_short_sentences_zh(new_sentences)
|
168 |
+
|
169 |
+
|
170 |
+
def merge_short_sentences_zh(sens):
|
171 |
+
# return sens
|
172 |
+
"""Avoid short sentences by merging them with the following sentence.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
List[str]: list of input sentences.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
List[str]: list of output sentences.
|
179 |
+
"""
|
180 |
+
sens_out = []
|
181 |
+
for s in sens:
|
182 |
+
# If the previous sentence is too short, merge it with
|
183 |
+
# the current sentence.
|
184 |
+
if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
|
185 |
+
sens_out[-1] = sens_out[-1] + " " + s
|
186 |
+
else:
|
187 |
+
sens_out.append(s)
|
188 |
+
try:
|
189 |
+
if len(sens_out[-1]) <= 2:
|
190 |
+
sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
|
191 |
+
sens_out.pop(-1)
|
192 |
+
except:
|
193 |
+
pass
|
194 |
+
return sens_out
|
README.md
ADDED
@@ -0,0 +1,119 @@
---
title: Integrated TTS + Voice Conversion App
emoji: 🎤
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
license: mit
short_description: Text-to-speech and voice conversion app combining OpenVoice V2 and Seed-VC
---

# Integrated TTS + Voice Conversion App

This app is an integrated solution that combines OpenVoice V2 and Seed-VC: it first converts text to speech, then converts that speech into the style of a reference voice.

## 🚀 Key Features

1. **Text-to-Speech (TTS)**: convert text to speech with OpenVoice V2
2. **Voice Cloning**: clone the tone and style of a reference voice
3. **Voice Conversion**: convert the generated speech into the reference voice's style with Seed-VC

## 📋 Processing Pipeline

1. **Input**: text + reference voice + conversion parameters
2. **Stage 1**: text → speech generated in the reference voice's tone (OpenVoice V2)
3. **Stage 2**: generated speech → converted into the reference voice's style (Seed-VC)
4. **Output**: final converted audio

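The sketch below shows how the two stages fit together. It is a conceptual outline only: the stage callables `tts_fn` and `vc_fn` stand in for the actual OpenVoice V2 and Seed-VC calls in `app.py`, and the keyword defaults mirror the parameter defaults listed later in this README.

```python
# Conceptual sketch of the integrated pipeline (not the exact code in app.py).
def integrated_conversion(text, reference_wav, tts_fn, vc_fn,
                          diffusion_steps=25, length_adjust=1.0, cfg_rate=0.7,
                          pitch_shift=0, use_f0_model=False, auto_f0=True):
    # Stage 1: text -> speech generated in the reference voice's tone (OpenVoice V2)
    tts_wav = tts_fn(text, reference_wav)
    # Stage 2: generated speech -> reference voice's style (Seed-VC)
    return vc_fn(tts_wav, reference_wav,
                 diffusion_steps=diffusion_steps,
                 length_adjust=length_adjust,
                 cfg_rate=cfg_rate,
                 pitch_shift=pitch_shift,
                 use_f0_model=use_f0_model,
                 auto_f0=auto_f0)
```
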
## 🛠️ Installation and Running

### Requirements
- Python 3.8+
- CUDA-capable GPU (recommended)

### Installation
```bash
pip install -r requirements.txt
```

### Run
```bash
python app.py
```

## 📝 Usage

1. **Enter text**: type the text you want to convert
2. **Upload a reference voice**: upload a clean 3-10 second reference recording
3. **Adjust the basic settings**:
   - choose a base voice style
   - adjust the speech speed (0.6x ~ 1.4x)
4. **Adjust the voice conversion parameters**:
   - diffusion steps (1-200, default: 25)
   - length adjustment (0.5-2.0, default: 1.0)
   - CFG rate (0.0-1.0, default: 0.7)
   - pitch shift (-24~24 semitones, default: 0)
   - whether to use the F0-conditioned model
   - whether to auto-adjust F0
5. **Click the "Integrated Conversion" button**

## ⚙️ Parameter Reference

### TTS Parameters
- **Base voice style**: default voice style per supported language (EN, ES, FR, ZH, JP, KR)
- **Speech speed**: speed of the generated speech (1.0 is normal speed)

### Voice Conversion Parameters
- **Diffusion steps**: affects conversion quality (higher is better; 50-100 recommended)
- **Length adjustment**: adjusts audio length (<1.0: faster, >1.0: slower)
- **CFG rate**: controls conversion strength
- **Pitch shift**: pitch adjustment in semitones
- **F0-conditioned model**: required for singing-voice conversion
- **Automatic F0 adjustment**: automatically matches F0 to the target voice

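As a concrete reference, the presets below group these parameters for two typical use cases. The values follow the defaults and the tips in this README; they are illustrative suggestions, not constants defined in `app.py`.

```python
# Illustrative parameter presets (values taken from the defaults/tips in this README).
PRESETS = {
    "speech_default": {
        "diffusion_steps": 25,   # default; raise to 50-100 for higher quality
        "length_adjust": 1.0,
        "cfg_rate": 0.7,
        "pitch_shift": 0,
        "use_f0_model": False,
        "auto_f0": True,
    },
    "singing_high_quality": {
        "diffusion_steps": 80,   # within the recommended 50-100 range
        "length_adjust": 1.0,
        "cfg_rate": 0.7,
        "pitch_shift": 0,
        "use_f0_model": True,    # the F0-conditioned model is needed for singing
        "auto_f0": True,
    },
}
```
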
## 🔧 Technology Stack

- **OpenVoice V2**: text-to-speech and voice cloning
- **Seed-VC**: high-quality voice conversion
- **MeloTTS**: multilingual TTS engine
- **Whisper**: speech encoding
- **BigVGAN**: high-quality vocoder
- **Gradio**: web interface

## 📁 Project Structure

```
Lucy_5/
├── app.py              # main application
├── requirements.txt    # dependencies
├── README.md           # project description
├── OpenVoice/          # OpenVoice V2 modules
├── modules/            # Seed-VC modules
└── hf_utils.py         # Hugging Face utilities
```

## 🚀 Deploying to Hugging Face Spaces

This app can be deployed on Hugging Face Spaces:

1. Create a new Space on Hugging Face Spaces
2. Clone this repository
3. Install the dependencies listed in `requirements.txt`
4. Configure it as a GPU Space (recommended)

## 💡 Tips

- Use a clean 3-10 second recording as the reference voice
- Check "Use F0-conditioned model" if you want singing-voice conversion
- Set the diffusion steps to 50-100 for higher quality
- Long texts can take a long time to process

## 📄 License

This project follows the licenses of the original libraries:
- OpenVoice V2
- Seed-VC
- MeloTTS
app.py
ADDED
@@ -0,0 +1,843 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import tempfile
|
4 |
+
import zipfile
|
5 |
+
import shutil
|
6 |
+
import subprocess
|
7 |
+
import importlib
|
8 |
+
import traceback
|
9 |
+
# from typing import Optional, Dict, Any
|
10 |
+
|
11 |
+
# Force flush stdout and stderr for better logging in Hugging Face Spaces
|
12 |
+
def log_print(*args, **kwargs):
|
13 |
+
print(*args, **kwargs)
|
14 |
+
sys.stdout.flush()
|
15 |
+
|
16 |
+
def log_error(*args, **kwargs):
|
17 |
+
print(*args, file=sys.stderr, **kwargs)
|
18 |
+
sys.stderr.flush()
|
19 |
+
|
20 |
+
try:
|
21 |
+
import gradio as gr
|
22 |
+
import spaces
|
23 |
+
import torch
|
24 |
+
import torchaudio
|
25 |
+
import librosa
|
26 |
+
import yaml
|
27 |
+
import numpy as np
|
28 |
+
import nltk
|
29 |
+
import requests
|
30 |
+
from pydub import AudioSegment
|
31 |
+
import soundfile as sf
|
32 |
+
except ImportError as e:
|
33 |
+
log_error(f"Import error: {e}")
|
34 |
+
log_error("Please install required packages using: pip install -r requirements.txt")
|
35 |
+
sys.exit(1)
|
36 |
+
|
37 |
+
# Ensure module import works regardless of working directory location
|
38 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
39 |
+
OPENVOICE_DIR = os.path.join(ROOT_DIR, "OpenVoice")
|
40 |
+
if os.path.isdir(OPENVOICE_DIR) and OPENVOICE_DIR not in sys.path:
|
41 |
+
sys.path.insert(0, OPENVOICE_DIR)
|
42 |
+
|
43 |
+
# Also work inside the OpenVoice project subdir so relative ckpt paths resolve
|
44 |
+
PROJECT_SUBDIR = "OpenVoice"
|
45 |
+
if os.path.isdir(os.path.join(ROOT_DIR, PROJECT_SUBDIR)):
|
46 |
+
os.chdir(os.path.join(ROOT_DIR, PROJECT_SUBDIR))
|
47 |
+
|
48 |
+
# Import OpenVoice modules
|
49 |
+
from openvoice import se_extractor
|
50 |
+
from openvoice.api import ToneColorConverter
|
51 |
+
TTS = None # will import lazily after ensuring MeCab/Unidic
|
52 |
+
|
53 |
+
# Import Seed-VC modules
|
54 |
+
from modules.commons import build_model, load_checkpoint, recursive_munch
|
55 |
+
from hf_utils import load_custom_model_from_hf
|
56 |
+
|
57 |
+
# OpenVoice configuration
|
58 |
+
CKPT_CONVERTER_DIR = 'checkpoints_v2/converter'
|
59 |
+
BASE_SPEAKER_SE_DIR = 'checkpoints_v2/base_speakers/ses'
|
60 |
+
OUTPUT_DIR = 'outputs_v2'
|
61 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
62 |
+
|
63 |
+
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
64 |
+
|
65 |
+
# Lazy singletons
|
66 |
+
_tone_color_converter = None
|
67 |
+
_melo_models = {}
|
68 |
+
|
69 |
+
# Seed-VC models (will be initialized at startup)
|
70 |
+
_seed_vc_models = None
|
71 |
+
|
72 |
+
def get_tone_color_converter():
|
73 |
+
global _tone_color_converter
|
74 |
+
if _tone_color_converter is None:
|
75 |
+
ensure_checkpoints()
|
76 |
+
converter = ToneColorConverter(f'{CKPT_CONVERTER_DIR}/config.json', device=DEVICE)
|
77 |
+
converter.load_ckpt(f'{CKPT_CONVERTER_DIR}/checkpoint.pth')
|
78 |
+
_tone_color_converter = converter
|
79 |
+
return _tone_color_converter
|
80 |
+
|
81 |
+
def ensure_unidic_available() -> None:
|
82 |
+
try:
|
83 |
+
import unidic # type: ignore
|
84 |
+
dicdir = getattr(unidic, 'DICDIR', None)
|
85 |
+
if not dicdir or not os.path.exists(os.path.join(dicdir, 'mecabrc')):
|
86 |
+
subprocess.run([sys.executable, '-m', 'unidic', 'download'], check=False)
|
87 |
+
# Reload to get DICDIR
|
88 |
+
unidic = importlib.reload(unidic)
|
89 |
+
dicdir = getattr(unidic, 'DICDIR', None)
|
90 |
+
if dicdir and os.path.exists(os.path.join(dicdir, 'mecabrc')):
|
91 |
+
os.environ['MECABRC'] = os.path.join(dicdir, 'mecabrc')
|
92 |
+
except Exception:
|
93 |
+
# Best-effort; MeloTTS may still work for non-Japanese
|
94 |
+
pass
|
95 |
+
|
96 |
+
def ensure_nltk_resources() -> None:
|
97 |
+
try:
|
98 |
+
nltk.data.find('taggers/averaged_perceptron_tagger_eng')
|
99 |
+
except LookupError:
|
100 |
+
try:
|
101 |
+
nltk.download('averaged_perceptron_tagger_eng', quiet=True)
|
102 |
+
except Exception:
|
103 |
+
pass
|
104 |
+
try:
|
105 |
+
nltk.data.find('corpora/cmudict')
|
106 |
+
except LookupError:
|
107 |
+
try:
|
108 |
+
nltk.download('cmudict', quiet=True)
|
109 |
+
except Exception:
|
110 |
+
pass
|
111 |
+
|
112 |
+
def get_melo_model(language: str):
|
113 |
+
# Normalize a couple of aliases from demo_part3
|
114 |
+
if language.lower() in {"en_us", "en-newest", "en_newest"}:
|
115 |
+
language = "EN_NEWEST"
|
116 |
+
language = language.upper()
|
117 |
+
if language not in _melo_models:
|
118 |
+
global TTS
|
119 |
+
if TTS is None:
|
120 |
+
ensure_unidic_available()
|
121 |
+
ensure_nltk_resources()
|
122 |
+
from melo.api import TTS as _TTS # type: ignore
|
123 |
+
TTS = _TTS
|
124 |
+
_melo_models[language] = TTS(language=language, device=DEVICE)
|
125 |
+
return _melo_models[language]
|
126 |
+
|
127 |
+
def list_supported_styles():
|
128 |
+
# Map speaker .pth names we have to user choices
|
129 |
+
style_list = []
|
130 |
+
if not os.path.isdir(BASE_SPEAKER_SE_DIR):
|
131 |
+
return style_list
|
132 |
+
for name in sorted(os.listdir(BASE_SPEAKER_SE_DIR)):
|
133 |
+
if name.endswith('.pth'):
|
134 |
+
style_list.append(os.path.splitext(name)[0])
|
135 |
+
return style_list
|
136 |
+
|
137 |
+
def ensure_checkpoints():
|
138 |
+
# Download and place checkpoints at exact expected paths.
|
139 |
+
if os.path.exists(CKPT_CONVERTER_DIR) and os.path.isdir(BASE_SPEAKER_SE_DIR):
|
140 |
+
return
|
141 |
+
url = 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip'
|
142 |
+
tmp_zip = os.path.join(tempfile.gettempdir(), 'checkpoints_v2_0417.zip')
|
143 |
+
tmp_extract_root = tempfile.mkdtemp(prefix='ov_ckpt_')
|
144 |
+
try:
|
145 |
+
with requests.get(url, stream=True, timeout=300) as r:
|
146 |
+
r.raise_for_status()
|
147 |
+
with open(tmp_zip, 'wb') as f:
|
148 |
+
for chunk in r.iter_content(chunk_size=1 << 20):
|
149 |
+
if chunk:
|
150 |
+
f.write(chunk)
|
151 |
+
with zipfile.ZipFile(tmp_zip, 'r') as zf:
|
152 |
+
zf.extractall(tmp_extract_root)
|
153 |
+
|
154 |
+
# Find the folder that contains 'converter/config.json' directly under it
|
155 |
+
candidate = None
|
156 |
+
for root, _, _ in os.walk(tmp_extract_root):
|
157 |
+
cfg = os.path.join(root, 'converter', 'config.json')
|
158 |
+
ses_dir = os.path.join(root, 'base_speakers', 'ses')
|
159 |
+
if os.path.isfile(cfg) and os.path.isdir(ses_dir):
|
160 |
+
candidate = root
|
161 |
+
break
|
162 |
+
if candidate is None:
|
163 |
+
raise RuntimeError('Could not locate converter/config.json inside the downloaded archive.')
|
164 |
+
|
165 |
+
# Place contents into 'checkpoints_v2'
|
166 |
+
target_root = 'checkpoints_v2'
|
167 |
+
if os.path.exists(target_root):
|
168 |
+
shutil.rmtree(target_root, ignore_errors=True)
|
169 |
+
os.makedirs(target_root, exist_ok=True)
|
170 |
+
|
171 |
+
# Copy only required subfolders
|
172 |
+
for name in ['converter', 'base_speakers']:
|
173 |
+
src_path = os.path.join(candidate, name)
|
174 |
+
dst_path = os.path.join(target_root, name)
|
175 |
+
if os.path.exists(dst_path):
|
176 |
+
shutil.rmtree(dst_path, ignore_errors=True)
|
177 |
+
shutil.copytree(src_path, dst_path)
|
178 |
+
except Exception as e:
|
179 |
+
raise gr.Error(f"Failed to prepare checkpoints: {e}")
|
180 |
+
finally:
|
181 |
+
try:
|
182 |
+
if os.path.isdir(tmp_extract_root):
|
183 |
+
shutil.rmtree(tmp_extract_root, ignore_errors=True)
|
184 |
+
except Exception:
|
185 |
+
pass
|
186 |
+
|
187 |
+
def initialize_seed_vc_models():
|
188 |
+
"""Initialize Seed-VC models with memory optimization"""
|
189 |
+
global _seed_vc_models
|
190 |
+
|
191 |
+
if _seed_vc_models is not None:
|
192 |
+
return _seed_vc_models
|
193 |
+
|
194 |
+
log_print("Loading Seed-VC models...")
|
195 |
+
# Clear GPU cache before loading models
|
196 |
+
if torch.cuda.is_available():
|
197 |
+
torch.cuda.empty_cache()
|
198 |
+
|
199 |
+
# Load DiT model
|
200 |
+
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
|
201 |
+
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
|
202 |
+
"config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
|
203 |
+
|
204 |
+
with open(dit_config_path, 'r', encoding='utf-8') as f:
|
205 |
+
config = yaml.safe_load(f)
|
206 |
+
model_params = recursive_munch(config['model_params'])
|
207 |
+
model = build_model(model_params, stage='DiT')
|
208 |
+
hop_length = config['preprocess_params']['spect_params']['hop_length']
|
209 |
+
sr = config['preprocess_params']['sr']
|
210 |
+
|
211 |
+
# Load checkpoints with memory optimization
|
212 |
+
model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
|
213 |
+
load_only_params=True, ignore_modules=[], is_distributed=False)
|
214 |
+
for key in model:
|
215 |
+
model[key].eval()
|
216 |
+
model[key].to(DEVICE)
|
217 |
+
model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=4096) # Reduced from 8192
|
218 |
+
|
219 |
+
# Load CAMPPlus
|
220 |
+
from modules.campplus.DTDNN import CAMPPlus
|
221 |
+
campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
|
222 |
+
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
|
223 |
+
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
|
224 |
+
campplus_model.eval()
|
225 |
+
campplus_model.to(DEVICE)
|
226 |
+
|
227 |
+
# Load BigVGAN
|
228 |
+
from modules.bigvgan import bigvgan
|
229 |
+
bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
|
230 |
+
bigvgan_model.remove_weight_norm()
|
231 |
+
bigvgan_model = bigvgan_model.eval().to(DEVICE)
|
232 |
+
|
233 |
+
# Load FAcodec with error handling
|
234 |
+
try:
|
235 |
+
ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
|
236 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
237 |
+
codec_config = yaml.safe_load(f)
|
238 |
+
codec_model_params = recursive_munch(codec_config['model_params'])
|
239 |
+
|
240 |
+
# Remove problematic 'causal' parameter if it exists
|
241 |
+
if hasattr(codec_model_params, 'dac_params') and hasattr(codec_model_params.dac_params, 'causal'):
|
242 |
+
delattr(codec_model_params.dac_params, 'causal')
|
243 |
+
log_print("Removed 'causal' parameter from DAC config")
|
244 |
+
|
245 |
+
# Also check for other problematic parameters
|
246 |
+
if hasattr(codec_model_params, 'dac_params'):
|
247 |
+
dac_params = codec_model_params.dac_params
|
248 |
+
# Remove any parameters that might cause issues
|
249 |
+
problematic_params = ['causal', 'causal_conv', 'causal_attention']
|
250 |
+
for param in problematic_params:
|
251 |
+
if hasattr(dac_params, param):
|
252 |
+
delattr(dac_params, param)
|
253 |
+
log_print(f"Removed '{param}' parameter from DAC config")
|
254 |
+
|
255 |
+
codec_encoder = build_model(codec_model_params, stage="codec")
|
256 |
+
log_print("✓ FAcodec loaded successfully")
|
257 |
+
except Exception as e:
|
258 |
+
log_error(f"Warning: Failed to load FAcodec: {e}")
|
259 |
+
log_error(f"FAcodec error traceback: {traceback.format_exc()}")
|
260 |
+
# Create a minimal dummy codec encoder
|
261 |
+
log_print("Creating minimal codec encoder as fallback...")
|
262 |
+
try:
|
263 |
+
# Try to create a basic DAC model without problematic parameters
|
264 |
+
from descript_audio_codec import DAC
|
265 |
+
codec_encoder = {'codec': DAC()}
|
266 |
+
log_print("✓ Created minimal DAC fallback")
|
267 |
+
except Exception as e2:
|
268 |
+
log_error(f"Failed to create DAC fallback: {e2}")
|
269 |
+
# Create a completely dummy encoder
|
270 |
+
class DummyCodec:
|
271 |
+
def __getitem__(self, key):
|
272 |
+
return self
|
273 |
+
def eval(self):
|
274 |
+
return self
|
275 |
+
def to(self, device):
|
276 |
+
return self
|
277 |
+
codec_encoder = {'codec': DummyCodec()}
|
278 |
+
log_print("✓ Created dummy codec encoder")
|
279 |
+
|
280 |
+
# Load codec checkpoint with error handling
|
281 |
+
try:
|
282 |
+
ckpt_params = torch.load(ckpt_path, map_location="cpu")
|
283 |
+
if 'codec' in ckpt_params:
|
284 |
+
codec_encoder.codec.load_state_dict(ckpt_params['codec'], strict=False)
|
285 |
+
elif 'model' in ckpt_params:
|
286 |
+
codec_encoder.codec.load_state_dict(ckpt_params['model'], strict=False)
|
287 |
+
else:
|
288 |
+
codec_encoder.codec.load_state_dict(ckpt_params, strict=False)
|
289 |
+
except Exception as e:
|
290 |
+
log_error(f"Warning: Could not load codec state dict: {e}")
|
291 |
+
log_error(f"Codec state dict error traceback: {traceback.format_exc()}")
|
292 |
+
log_error("Codec will use default parameters")
|
293 |
+
|
294 |
+
_ = [codec_encoder[key].eval() for key in codec_encoder]
|
295 |
+
_ = [codec_encoder[key].to(DEVICE) for key in codec_encoder]
|
296 |
+
|
297 |
+
# Load Whisper
|
298 |
+
from transformers import AutoFeatureExtractor, WhisperModel
|
299 |
+
whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer, 'whisper_name') else "openai/whisper-small"
|
300 |
+
whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(DEVICE)
|
301 |
+
del whisper_model.decoder
|
302 |
+
whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
|
303 |
+
|
304 |
+
# Mel spectrogram function
|
305 |
+
mel_fn_args = {
|
306 |
+
"n_fft": config['preprocess_params']['spect_params']['n_fft'],
|
307 |
+
"win_size": config['preprocess_params']['spect_params']['win_length'],
|
308 |
+
"hop_size": config['preprocess_params']['spect_params']['hop_length'],
|
309 |
+
"num_mels": config['preprocess_params']['spect_params']['n_mels'],
|
310 |
+
"sampling_rate": sr,
|
311 |
+
"fmin": 0,
|
312 |
+
"fmax": None,
|
313 |
+
"center": False
|
314 |
+
}
|
315 |
+
from modules.audio import mel_spectrogram
|
316 |
+
to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
|
317 |
+
|
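# Sketch (assumption about shapes): `to_mel` maps a [B, T] waveform tensor at
# `sr` Hz to a [B, n_mels, frames] mel spectrogram with roughly T // hop_size
# frames, e.g.:
#
#     wav = torch.randn(1, sr)            # one second of dummy audio
#     mel = to_mel(wav.to(DEVICE))        # -> [1, n_mels, ~sr // hop_size]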
318 |
+
# Load F0 conditioned model
|
319 |
+
dit_checkpoint_path_f0, dit_config_path_f0 = load_custom_model_from_hf("Plachta/Seed-VC",
|
320 |
+
"DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
|
321 |
+
"config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
|
322 |
+
|
323 |
+
with open(dit_config_path_f0, 'r', encoding='utf-8') as f:
|
324 |
+
config_f0 = yaml.safe_load(f)
|
325 |
+
model_params_f0 = recursive_munch(config_f0['model_params'])
|
326 |
+
model_f0 = build_model(model_params_f0, stage='DiT')
|
327 |
+
hop_length_f0 = config_f0['preprocess_params']['spect_params']['hop_length']
|
328 |
+
sr_f0 = config_f0['preprocess_params']['sr']
|
329 |
+
|
330 |
+
# Load checkpoints
|
331 |
+
model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path_f0,
|
332 |
+
load_only_params=True, ignore_modules=[], is_distributed=False)
|
333 |
+
for key in model_f0:
|
334 |
+
model_f0[key].eval()
|
335 |
+
model_f0[key].to(DEVICE)
|
336 |
+
model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=4096) # Reduced from 8192
|
337 |
+
|
338 |
+
# Load RMVPE
|
339 |
+
from modules.rmvpe import RMVPE
|
340 |
+
model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
|
341 |
+
rmvpe = RMVPE(model_path, is_half=False, device=DEVICE)
|
342 |
+
|
343 |
+
mel_fn_args_f0 = {
|
344 |
+
"n_fft": config_f0['preprocess_params']['spect_params']['n_fft'],
|
345 |
+
"win_size": config_f0['preprocess_params']['spect_params']['win_length'],
|
346 |
+
"hop_size": config_f0['preprocess_params']['spect_params']['hop_length'],
|
347 |
+
"num_mels": config_f0['preprocess_params']['spect_params']['n_mels'],
|
348 |
+
"sampling_rate": sr_f0,
|
349 |
+
"fmin": 0,
|
350 |
+
"fmax": None,
|
351 |
+
"center": False
|
352 |
+
}
|
353 |
+
to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
|
354 |
+
|
355 |
+
bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
|
356 |
+
bigvgan_44k_model.remove_weight_norm()
|
357 |
+
bigvgan_44k_model = bigvgan_44k_model.eval().to(DEVICE)
|
358 |
+
|
359 |
+
_seed_vc_models = {
|
360 |
+
'model': model,
|
361 |
+
'model_f0': model_f0,
|
362 |
+
'campplus_model': campplus_model,
|
363 |
+
'bigvgan_model': bigvgan_model,
|
364 |
+
'bigvgan_44k_model': bigvgan_44k_model,
|
365 |
+
'codec_encoder': codec_encoder,
|
366 |
+
'whisper_model': whisper_model,
|
367 |
+
'whisper_feature_extractor': whisper_feature_extractor,
|
368 |
+
'to_mel': to_mel,
|
369 |
+
'to_mel_f0': to_mel_f0,
|
370 |
+
'rmvpe': rmvpe,
|
371 |
+
'config': config,
|
372 |
+
'config_f0': config_f0,
|
373 |
+
'hop_length': hop_length,
|
374 |
+
'sr': sr,
|
375 |
+
'hop_length_f0': hop_length_f0,
|
376 |
+
'sr_f0': sr_f0
|
377 |
+
}
|
378 |
+
|
379 |
+
return _seed_vc_models


def adjust_f0_semitones(f0_sequence, n_semitones):
    factor = 2 ** (n_semitones / 12)
    return f0_sequence * factor


def crossfade(chunk1, chunk2, overlap):
    # cos^2 fade-out on the tail of chunk1 and the complementary fade-in on the
    # head of chunk2, so the overlapping region sums without a level jump.
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
    chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
    return chunk2

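# Sketch: a shift of n semitones scales F0 by 2 ** (n / 12), so +12 semitones
# doubles the pitch (440 Hz -> 880 Hz) and -12 halves it. The crossfade helper
# is applied to overlapping waveform chunks, e.g. (arrays hypothetical):
#
#     a = np.random.randn(44100)
#     b = np.random.randn(44100)
#     merged_head = crossfade(a, b.copy(), overlap=4096)   # fades a's tail into b's head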
# Step 1: OpenVoice TTS + Voice Cloning
def run_openvoice_inference(text: str, style_key: str, speed: float, reference_audio_path: str) -> str:
    if not text or not reference_audio_path:
        raise gr.Error("Please provide text and a reference audio.")

    # Re-evaluate device at call time for ZeroGPU
    global DEVICE
    DEVICE = "cuda:0" if torch.cuda.is_available() else DEVICE
    converter = get_tone_color_converter()

    # Extract target speaker embedding from the uploaded reference audio
    target_se, _ = se_extractor.get_se(reference_audio_path, converter, vad=True)

    # Prepare base speech with Melo
    language_from_style = "EN_NEWEST" if style_key.startswith("en-") else None
    if style_key.startswith("es"):
        language_from_style = "ES"
    elif style_key.startswith("fr"):
        language_from_style = "FR"
    elif style_key.startswith("zh"):
        language_from_style = "ZH"
    elif style_key.startswith("jp"):
        language_from_style = "JP"
    elif style_key.startswith("kr"):
        language_from_style = "KR"

    melo = get_melo_model(language_from_style or "EN_NEWEST")
    speaker_ids = melo.hps.data.spk2id

    # Pick the first available speaker id for that language
    speaker_id = next(iter(speaker_ids.values()))

    # Disable MPS quirk similar to demo_part3
    if torch.backends.mps.is_available() and DEVICE == 'cpu':
        torch.backends.mps.is_available = lambda: False

    tmp_wav = os.path.join(OUTPUT_DIR, 'tmp.wav')
    melo.tts_to_file(text, speaker_id, tmp_wav, speed=speed)

    # Source speaker embedding from the selected base style
    source_se_path = os.path.join(BASE_SPEAKER_SE_DIR, f'{style_key}.pth')
    if not os.path.exists(source_se_path):
        raise gr.Error(f"Missing base speaker embedding: {source_se_path}")
    source_se = torch.load(source_se_path, map_location=DEVICE)

    out_path = os.path.join(OUTPUT_DIR, f'openvoice_output_{style_key}.wav')

    # Convert tone color from the base style to the reference speaker
    converter.convert(
        audio_src_path=tmp_wav,
        src_se=source_se,
        tgt_se=target_se,
        output_path=out_path,
        message='@MyShell',
    )

    return out_path

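# Sketch (paths hypothetical): the TTS stage alone can be exercised as
#
#     out = run_openvoice_inference(
#         text="Hello there!", style_key="en-newest",
#         speed=1.0, reference_audio_path="samples/reference.wav",
#     )
#     # `out` is a wav path whose tone color matches the reference speaker.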
449 |
+
# Step 2: Seed-VC Voice Conversion
|
450 |
+
@torch.no_grad()
|
451 |
+
@torch.inference_mode()
|
452 |
+
def run_seed_vc_inference(source_audio_path: str, target_audio_path: str, vc_diffusion_steps: int,
|
453 |
+
vc_length_adjust: float, vc_inference_cfg_rate: float, vc_f0_condition: bool,
|
454 |
+
vc_auto_f0_adjust: bool, vc_pitch_shift: int) -> str:
|
455 |
+
|
456 |
+
log_print("Initializing Seed-VC models...")
|
457 |
+
models = initialize_seed_vc_models()
|
458 |
+
log_print("✓ Seed-VC models ready")
|
459 |
+
|
460 |
+
inference_module = models['model_f0'] if vc_f0_condition else models['model']
|
461 |
+
mel_fn = models['to_mel_f0'] if vc_f0_condition else models['to_mel']
|
462 |
+
bigvgan_fn = models['bigvgan_44k_model'] if vc_f0_condition else models['bigvgan_model']
|
463 |
+
sr = models['sr_f0'] if vc_f0_condition else models['sr']
|
464 |
+
hop_length = models['hop_length_f0'] if vc_f0_condition else models['hop_length']
|
465 |
+
|
466 |
+
max_context_window = sr // hop_length * 30
|
467 |
+
overlap_frame_len = 16
|
468 |
+
overlap_wave_len = overlap_frame_len * hop_length
|
469 |
+
bitrate = "320k"
|
470 |
+
|
471 |
+
# Load audio
|
472 |
+
source_audio = librosa.load(source_audio_path, sr=sr)[0]
|
473 |
+
ref_audio = librosa.load(target_audio_path, sr=sr)[0]
|
474 |
+
|
475 |
+
# Process audio
|
476 |
+
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(DEVICE)
|
477 |
+
ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(DEVICE)
|
478 |
+
|
479 |
+
# Resample
|
480 |
+
ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
|
481 |
+
converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
|
482 |
+
|
483 |
+
# Whisper processing
|
484 |
+
if converted_waves_16k.size(-1) <= 16000 * 30:
|
485 |
+
alt_inputs = models['whisper_feature_extractor']([converted_waves_16k.squeeze(0).cpu().numpy()],
|
486 |
+
return_tensors="pt",
|
487 |
+
return_attention_mask=True,
|
488 |
+
sampling_rate=16000)
|
489 |
+
alt_input_features = models['whisper_model']._mask_input_features(
|
490 |
+
alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(DEVICE)
|
491 |
+
alt_outputs = models['whisper_model'].encoder(
|
492 |
+
alt_input_features.to(models['whisper_model'].encoder.dtype),
|
493 |
+
head_mask=None,
|
494 |
+
output_attentions=False,
|
495 |
+
output_hidden_states=False,
|
496 |
+
return_dict=True,
|
497 |
+
)
|
498 |
+
S_alt = alt_outputs.last_hidden_state.to(torch.float32)
|
499 |
+
S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
|
500 |
+
else:
|
501 |
+
overlapping_time = 5 # 5 seconds
|
502 |
+
S_alt_list = []
|
503 |
+
buffer = None
|
504 |
+
traversed_time = 0
|
505 |
+
while traversed_time < converted_waves_16k.size(-1):
|
506 |
+
if buffer is None: # first chunk
|
507 |
+
chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
|
508 |
+
else:
|
509 |
+
chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
|
510 |
+
alt_inputs = models['whisper_feature_extractor']([chunk.squeeze(0).cpu().numpy()],
|
511 |
+
return_tensors="pt",
|
512 |
+
return_attention_mask=True,
|
513 |
+
sampling_rate=16000)
|
514 |
+
alt_input_features = models['whisper_model']._mask_input_features(
|
515 |
+
alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(DEVICE)
|
516 |
+
alt_outputs = models['whisper_model'].encoder(
|
517 |
+
alt_input_features.to(models['whisper_model'].encoder.dtype),
|
518 |
+
head_mask=None,
|
519 |
+
output_attentions=False,
|
520 |
+
output_hidden_states=False,
|
521 |
+
return_dict=True,
|
522 |
+
)
|
523 |
+
S_alt = alt_outputs.last_hidden_state.to(torch.float32)
|
524 |
+
S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
|
525 |
+
if traversed_time == 0:
|
526 |
+
S_alt_list.append(S_alt)
|
527 |
+
else:
|
528 |
+
S_alt_list.append(S_alt[:, 50 * overlapping_time:])
|
529 |
+
buffer = chunk[:, -16000 * overlapping_time:]
|
530 |
+
traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
|
531 |
+
S_alt = torch.cat(S_alt_list, dim=1)
|
532 |
+
|
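# Sketch of the chunking scheme above: inputs longer than Whisper's 30 s window
# are encoded in 30 s pieces that overlap by 5 s, and the first 5 s of every
# chunk after the first (50 encoder frames per second at 16 kHz) is dropped
# before concatenation:
#
#     chunk_samples   = 16000 * 30
#     overlap_samples = 16000 * 5
#     step            = chunk_samples - overlap_samples   # 25 s of new audio per chunk
#     frames_dropped  = 50 * 5                            # overlap frames discarded per chunk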
533 |
+
ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
|
534 |
+
ori_inputs = models['whisper_feature_extractor']([ori_waves_16k.squeeze(0).cpu().numpy()],
|
535 |
+
return_tensors="pt",
|
536 |
+
return_attention_mask=True)
|
537 |
+
ori_input_features = models['whisper_model']._mask_input_features(
|
538 |
+
ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(DEVICE)
|
539 |
+
with torch.no_grad():
|
540 |
+
ori_outputs = models['whisper_model'].encoder(
|
541 |
+
ori_input_features.to(models['whisper_model'].encoder.dtype),
|
542 |
+
head_mask=None,
|
543 |
+
output_attentions=False,
|
544 |
+
output_hidden_states=False,
|
545 |
+
return_dict=True,
|
546 |
+
)
|
547 |
+
S_ori = ori_outputs.last_hidden_state.to(torch.float32)
|
548 |
+
S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
|
549 |
+
|
550 |
+
mel = mel_fn(source_audio.to(DEVICE).float())
|
551 |
+
mel2 = mel_fn(ref_audio.to(DEVICE).float())
|
552 |
+
|
553 |
+
target_lengths = torch.LongTensor([int(mel.size(2) * vc_length_adjust)]).to(mel.device)
|
554 |
+
target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
|
555 |
+
|
556 |
+
feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
|
557 |
+
num_mel_bins=80,
|
558 |
+
dither=0,
|
559 |
+
sample_frequency=16000)
|
560 |
+
feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
|
561 |
+
style2 = models['campplus_model'](feat2.unsqueeze(0))
|
562 |
+
|
563 |
+
if vc_f0_condition:
|
564 |
+
F0_ori = models['rmvpe'].infer_from_audio(ref_waves_16k[0], thred=0.5)
|
565 |
+
F0_alt = models['rmvpe'].infer_from_audio(converted_waves_16k[0], thred=0.5)
|
566 |
+
|
567 |
+
F0_ori = torch.from_numpy(F0_ori).to(DEVICE)[None]
|
568 |
+
F0_alt = torch.from_numpy(F0_alt).to(DEVICE)[None]
|
569 |
+
|
570 |
+
voiced_F0_ori = F0_ori[F0_ori > 1]
|
571 |
+
voiced_F0_alt = F0_alt[F0_alt > 1]
|
572 |
+
|
573 |
+
log_f0_alt = torch.log(F0_alt + 1e-5)
|
574 |
+
voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
|
575 |
+
voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
|
576 |
+
median_log_f0_ori = torch.median(voiced_log_f0_ori)
|
577 |
+
median_log_f0_alt = torch.median(voiced_log_f0_alt)
|
578 |
+
|
579 |
+
# shift alt log f0 level to ori log f0 level
|
580 |
+
shifted_log_f0_alt = log_f0_alt.clone()
|
581 |
+
if vc_auto_f0_adjust:
|
582 |
+
shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
|
583 |
+
shifted_f0_alt = torch.exp(shifted_log_f0_alt)
|
584 |
+
if vc_pitch_shift != 0:
|
585 |
+
shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], vc_pitch_shift)
|
586 |
+
else:
|
587 |
+
F0_ori = None
|
588 |
+
F0_alt = None
|
589 |
+
shifted_f0_alt = None
|
590 |
+
|
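# Sketch of the auto F0 adjustment above: voiced frames of the source are
# shifted in the log domain so that their median matches the reference's
# median, transposing the whole contour without flattening it:
#
#     # f0_shifted = exp(log(f0_src) - median(log f0_src_voiced) + median(log f0_ref_voiced))
#     # e.g. a source with median 150 Hz and a reference with median 300 Hz is
#     # raised by a factor of 2 (one octave) on every voiced frame.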
591 |
+
# Length regulation
|
592 |
+
cond, _, _, _, _ = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
|
593 |
+
prompt_condition, _, _, _, _ = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
|
594 |
+
|
595 |
+
max_source_window = max_context_window - mel2.size(2)
|
596 |
+
# split source condition (cond) into chunks
|
597 |
+
processed_frames = 0
|
598 |
+
generated_wave_chunks = []
|
599 |
+
# generate chunk by chunk and stream the output
|
600 |
+
while processed_frames < cond.size(1):
|
601 |
+
chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
|
602 |
+
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
|
603 |
+
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
|
604 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
605 |
+
# Voice Conversion
|
606 |
+
vc_target = inference_module.cfm.inference(cat_condition,
|
607 |
+
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
|
608 |
+
mel2, style2, None, vc_diffusion_steps,
|
609 |
+
inference_cfg_rate=vc_inference_cfg_rate)
|
610 |
+
vc_target = vc_target[:, :, mel2.size(-1):]
|
611 |
+
vc_wave = bigvgan_fn(vc_target.float())[0]
|
612 |
+
if processed_frames == 0:
|
613 |
+
if is_last_chunk:
|
614 |
+
output_wave = vc_wave[0].cpu().numpy()
|
615 |
+
generated_wave_chunks.append(output_wave)
|
616 |
+
output_wave = (output_wave * 32768.0).astype(np.int16)
|
617 |
+
mp3_bytes = AudioSegment(
|
618 |
+
output_wave.tobytes(), frame_rate=sr,
|
619 |
+
sample_width=output_wave.dtype.itemsize, channels=1
|
620 |
+
).export(format="mp3", bitrate=bitrate).read()
|
621 |
+
yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
|
622 |
+
break
|
623 |
+
output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
|
624 |
+
generated_wave_chunks.append(output_wave)
|
625 |
+
previous_chunk = vc_wave[0, -overlap_wave_len:]
|
626 |
+
processed_frames += vc_target.size(2) - overlap_frame_len
|
627 |
+
output_wave = (output_wave * 32768.0).astype(np.int16)
|
628 |
+
mp3_bytes = AudioSegment(
|
629 |
+
output_wave.tobytes(), frame_rate=sr,
|
630 |
+
sample_width=output_wave.dtype.itemsize, channels=1
|
631 |
+
).export(format="mp3", bitrate=bitrate).read()
|
632 |
+
yield mp3_bytes, None
|
633 |
+
elif is_last_chunk:
|
634 |
+
output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
|
635 |
+
generated_wave_chunks.append(output_wave)
|
636 |
+
processed_frames += vc_target.size(2) - overlap_frame_len
|
637 |
+
output_wave = (output_wave * 32768.0).astype(np.int16)
|
638 |
+
mp3_bytes = AudioSegment(
|
639 |
+
output_wave.tobytes(), frame_rate=sr,
|
640 |
+
sample_width=output_wave.dtype.itemsize, channels=1
|
641 |
+
).export(format="mp3", bitrate=bitrate).read()
|
642 |
+
yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
|
643 |
+
break
|
644 |
+
else:
|
645 |
+
output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
|
646 |
+
generated_wave_chunks.append(output_wave)
|
647 |
+
previous_chunk = vc_wave[0, -overlap_wave_len:]
|
648 |
+
processed_frames += vc_target.size(2) - overlap_frame_len
|
649 |
+
output_wave = (output_wave * 32768.0).astype(np.int16)
|
650 |
+
mp3_bytes = AudioSegment(
|
651 |
+
output_wave.tobytes(), frame_rate=sr,
|
652 |
+
sample_width=output_wave.dtype.itemsize, channels=1
|
653 |
+
).export(format="mp3", bitrate=bitrate).read()
|
654 |
+
yield mp3_bytes, None
|
655 |
+
|
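# Sketch of the streaming encode used above: each float chunk in [-1, 1] is
# scaled to int16 and wrapped in a pydub AudioSegment, whose export() yields
# the MP3 bytes streamed back to the client (names hypothetical):
#
#     pcm = (chunk * 32768.0).astype(np.int16)
#     seg = AudioSegment(pcm.tobytes(), frame_rate=sr,
#                        sample_width=pcm.dtype.itemsize, channels=1)
#     mp3_bytes = seg.export(format="mp3", bitrate="320k").read()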
656 |
+
# Main integrated function
|
657 |
+
@spaces.GPU
|
658 |
+
def process_integrated_tts_vc(text, style, speed, reference_audio, vc_diffusion_steps, vc_length_adjust,
|
659 |
+
vc_inference_cfg_rate, vc_f0_condition, vc_auto_f0_adjust, vc_pitch_shift):
|
660 |
+
"""Integrated TTS + Voice Conversion pipeline"""
|
661 |
+
|
662 |
+
log_print("=" * 50)
|
663 |
+
log_print("STARTING PROCESSING...")
|
664 |
+
log_print(f"Text: {text[:50]}...")
|
665 |
+
log_print(f"Style: {style}, Speed: {speed}")
|
666 |
+
log_print(f"VC params: steps={vc_diffusion_steps}, length={vc_length_adjust}")
|
667 |
+
log_print("=" * 50)
|
668 |
+
|
669 |
+
# Handle Gradio audio input format
|
670 |
+
ref_path = None
|
671 |
+
if isinstance(reference_audio, tuple):
|
672 |
+
sr, data = reference_audio
|
673 |
+
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
674 |
+
sf.write(f.name, data, sr)
|
675 |
+
ref_path = f.name
|
676 |
+
elif isinstance(reference_audio, str):
|
677 |
+
ref_path = reference_audio
|
678 |
+
|
679 |
+
if not ref_path:
|
680 |
+
log_error("ERROR: No reference audio provided")
|
681 |
+
raise gr.Error("Please provide a reference audio.")
|
682 |
+
|
683 |
+
try:
|
684 |
+
# Step 1: OpenVoice TTS + Voice Cloning
|
685 |
+
log_print("Step 1: Running OpenVoice TTS...")
|
686 |
+
intermediate_audio = run_openvoice_inference(text, style, speed, ref_path)
|
687 |
+
log_print(f"✓ OpenVoice completed. Intermediate audio: {intermediate_audio}")
|
688 |
+
|
689 |
+
# Step 2: Seed-VC Voice Conversion
|
690 |
+
log_print("Step 2: Running Seed-VC Voice Conversion...")
|
691 |
+
# Call the actual voice conversion function and collect all results
|
692 |
+
results = list(run_seed_vc_inference(intermediate_audio, ref_path, vc_diffusion_steps, vc_length_adjust,
|
693 |
+
vc_inference_cfg_rate, vc_f0_condition, vc_auto_f0_adjust, vc_pitch_shift))
|
694 |
+
log_print(f"✓ Seed-VC completed. Results count: {len(results)}")
|
695 |
+
|
696 |
+
except Exception as e:
|
697 |
+
log_error(f"CRITICAL ERROR in processing: {str(e)}")
|
698 |
+
log_error(f"Error type: {type(e).__name__}")
|
699 |
+
log_error("Full traceback:")
|
700 |
+
log_error(traceback.format_exc())
|
701 |
+
# Re-raise the error to see it in Gradio
|
702 |
+
raise
|
703 |
+
|
704 |
+
# Find the final result (the one with the complete audio data)
|
705 |
+
final_result = None
|
706 |
+
for result in results:
|
707 |
+
if isinstance(result, tuple) and len(result) == 2 and result[1] is not None:
|
708 |
+
# This is the final result with complete audio data
|
709 |
+
final_result = result[1]
|
710 |
+
break
|
711 |
+
|
712 |
+
if final_result is not None:
|
713 |
+
# Save the final audio to a temporary file
|
714 |
+
sr, audio_data = final_result
|
715 |
+
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
716 |
+
sf.write(f.name, audio_data, sr)
|
717 |
+
return f.name
|
718 |
+
|
719 |
+
return None
|
720 |
+
|
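# Sketch of the reference-audio normalization above: gr.Audio may hand back
# either a filepath or an (sr, ndarray) tuple, so the tuple case is written
# out to a temporary wav before both stages run (values hypothetical):
#
#     sr, data = 44100, np.zeros(44100, dtype=np.float32)
#     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
#         sf.write(f.name, data, sr)
#         ref_path = f.name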
# Get supported styles
styles = list_supported_styles() or [
    'en-newest', 'en-default', 'en-us', 'en-br', 'en-au', 'en-india',
    'es', 'fr', 'zh', 'jp', 'kr'
]

# Skip model pre-loading for faster startup
log_print("=" * 50)
log_print("SKIPPING MODEL PRE-LOADING FOR FASTER STARTUP")
log_print("Models will be loaded on first use")
log_print("=" * 50)

733 |
+
# Create Gradio interface
|
734 |
+
with gr.Blocks(title="Integrated TTS + Voice Conversion", analytics_enabled=False) as demo:
|
735 |
+
gr.Markdown("""
|
736 |
+
# **Integrated TTS + Voice Conversion** — 텍스트를 음성으로 변환 후 음성 변환
|
737 |
+
|
738 |
+
텍스트를 입력하고 참조 음성을 업로드하면, 먼저 텍스트가 음성으로 변환된 후 참조 음성의 스타일로 변환됩니다.
|
739 |
+
|
740 |
+
**사용법:**
|
741 |
+
1. 변환할 텍스트를 입력하세요
|
742 |
+
2. 참조 음성을 업로드하세요 (3-10초 권장)
|
743 |
+
3. 기본 음성 스타일과 속도를 선택하세요
|
744 |
+
4. 음성 변환 파라미터를 조정하세요
|
745 |
+
5. "통합 변환" 버튼을 클릭하세요
|
746 |
+
""")
|
747 |
+
|
748 |
+
with gr.Row():
|
749 |
+
with gr.Column(scale=6):
|
750 |
+
# TTS Parameters
|
751 |
+
gr.Markdown("### 🎤 텍스트-음성 변환 설정")
|
752 |
+
text_input = gr.Textbox(
|
753 |
+
label="변환할 텍스트",
|
754 |
+
value="안녕하세요! 이것은 통합 TTS와 음성 변환 데모입니다.",
|
755 |
+
lines=3
|
756 |
+
)
|
757 |
+
style_input = gr.Dropdown(
|
758 |
+
label="기본 음성 스타일",
|
759 |
+
choices=styles,
|
760 |
+
value=styles[0]
|
761 |
+
)
|
762 |
+
speed_input = gr.Slider(
|
763 |
+
0.6, 1.4, value=1.0, step=0.05,
|
764 |
+
label="음성 속도 (×)"
|
765 |
+
)
|
766 |
+
reference_audio_input = gr.Audio(
|
767 |
+
label="참조 음성",
|
768 |
+
sources=["upload", "microphone"],
|
769 |
+
type="filepath"
|
770 |
+
)
|
771 |
+
|
772 |
+
# Voice Conversion Parameters
|
773 |
+
gr.Markdown("### 🔄 음성 변환 설정")
|
774 |
+
with gr.Row():
|
775 |
+
vc_diffusion_steps = gr.Slider(
|
776 |
+
minimum=1, maximum=200, value=25, step=1,
|
777 |
+
label="확산 단계",
|
778 |
+
info="25 기본값, 50~100 최고 품질"
|
779 |
+
)
|
780 |
+
vc_length_adjust = gr.Slider(
|
781 |
+
minimum=0.5, maximum=2.0, step=0.1, value=1.0,
|
782 |
+
label="길이 조정",
|
783 |
+
info="<1.0 빠르게, >1.0 느리게"
|
784 |
+
)
|
785 |
+
|
786 |
+
with gr.Row():
|
787 |
+
vc_inference_cfg_rate = gr.Slider(
|
788 |
+
minimum=0.0, maximum=1.0, step=0.1, value=0.7,
|
789 |
+
label="CFG 비율",
|
790 |
+
info="미묘한 영향"
|
791 |
+
)
|
792 |
+
vc_pitch_shift = gr.Slider(
|
793 |
+
minimum=-24, maximum=24, step=1, value=0,
|
794 |
+
label="피치 시프트",
|
795 |
+
info="반음 단위"
|
796 |
+
)
|
797 |
+
|
798 |
+
with gr.Row():
|
799 |
+
vc_f0_condition = gr.Checkbox(
|
800 |
+
label="F0 조건부 모델 사용",
|
801 |
+
value=False,
|
802 |
+
info="노래 음성 변환에 필요"
|
803 |
+
)
|
804 |
+
vc_auto_f0_adjust = gr.Checkbox(
|
805 |
+
label="자동 F0 조정",
|
806 |
+
value=True,
|
807 |
+
info="대상 음성에 맞게 F0 조정"
|
808 |
+
)
|
809 |
+
|
810 |
+
convert_btn = gr.Button("통합 변환", variant="primary", size="lg")
|
811 |
+
|
812 |
+
with gr.Column(scale=6):
|
813 |
+
output_audio = gr.Audio(
|
814 |
+
label="최종 변환된 음성",
|
815 |
+
autoplay=True,
|
816 |
+
format="wav"
|
817 |
+
)
|
818 |
+
|
819 |
+
gr.Markdown("""
|
820 |
+
### 📋 처리 과정:
|
821 |
+
1. **텍스트 → 음성**: 입력된 텍스트가 참조 음성의 톤으로 변환됩니다
|
822 |
+
2. **음성 변환**: 생성된 음성이 참조 음성의 스타일로 최종 변환됩니다
|
823 |
+
|
824 |
+
### 💡 팁:
|
825 |
+
- 참조 음성은 3-10초 길이의 깨끗한 음성을 사용하세요
|
826 |
+
- 노래 음성 변환을 원한다면 "F0 조건부 모델 사용"을 체크하세요
|
827 |
+
- 품질을 높이려면 확산 단계를 50-100으로 설정하세요
|
828 |
+
""")
|
829 |
+
|
830 |
+
# Connect the button click to the processing function
|
831 |
+
convert_btn.click(
|
832 |
+
fn=process_integrated_tts_vc,
|
833 |
+
inputs=[
|
834 |
+
text_input, style_input, speed_input, reference_audio_input,
|
835 |
+
vc_diffusion_steps, vc_length_adjust, vc_inference_cfg_rate,
|
836 |
+
vc_f0_condition, vc_auto_f0_adjust, vc_pitch_shift
|
837 |
+
],
|
838 |
+
outputs=[output_audio],
|
839 |
+
concurrency_limit=1
|
840 |
+
)
|
841 |
+
|
842 |
+
demo.queue()
|
843 |
+
demo.launch()
|
hf_utils.py
ADDED
@@ -0,0 +1,12 @@
import os
from huggingface_hub import hf_hub_download


def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
    os.makedirs("./checkpoints", exist_ok=True)
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
    if config_filename is None:
        return model_path
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints")

    return model_path, config_path
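# Sketch (repo and filenames taken from the app above): the helper returns
# local cache paths, one or two depending on whether a config file is requested.
#
#     rmvpe_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
#     ckpt, cfg = load_custom_model_from_hf("Plachta/FAcodec", "pytorch_model.bin", "config.yml")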
modules/__pycache__/audio.cpython-310.pyc
ADDED
Binary file (2.5 kB). View file
|
|
modules/__pycache__/commons.cpython-310.pyc
ADDED
Binary file (13.3 kB). View file
|
|
modules/__pycache__/commons.cpython-38.pyc
ADDED
Binary file (14.2 kB). View file
|
|
modules/__pycache__/diffusion_transformer.cpython-310.pyc
ADDED
Binary file (17.4 kB). View file
|
|
modules/__pycache__/encodec.cpython-310.pyc
ADDED
Binary file (10.8 kB). View file
|
|
modules/__pycache__/flow_matching.cpython-310.pyc
ADDED
Binary file (5.41 kB). View file
|
|
modules/__pycache__/length_regulator.cpython-310.pyc
ADDED
Binary file (4.23 kB). View file
|
|
modules/__pycache__/rmvpe.cpython-310.pyc
ADDED
Binary file (17.6 kB). View file
|
|
modules/__pycache__/wavenet.cpython-310.pyc
ADDED
Binary file (5.15 kB). View file
|
|
modules/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

from .filter import *
from .resample import *
from .act import *
modules/alias_free_torch/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (198 Bytes). View file
|
|
modules/alias_free_torch/__pycache__/act.cpython-310.pyc
ADDED
Binary file (1.03 kB). View file
|
|
modules/alias_free_torch/__pycache__/filter.cpython-310.pyc
ADDED
Binary file (2.61 kB). View file
|
|
modules/alias_free_torch/__pycache__/resample.cpython-310.pyc
ADDED
Binary file (1.89 kB). View file
|
|
modules/alias_free_torch/act.py
ADDED
@@ -0,0 +1,29 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
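# Sketch (activation choice is an assumption; BigVGAN pairs this wrapper with
# its Snake activation): the module upsamples 2x, applies the nonlinearity at
# the higher rate to limit aliasing, then low-pass filters and downsamples 2x.
#
#     import torch, torch.nn as nn
#     act = Activation1d(activation=nn.SiLU())
#     y = act(torch.randn(1, 64, 16000))    # [B, C, T] in, [B, C, T] out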
modules/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import math
|
7 |
+
|
8 |
+
if "sinc" in dir(torch):
|
9 |
+
sinc = torch.sinc
|
10 |
+
else:
|
11 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
12 |
+
# https://adefossez.github.io/julius/julius/core.html
|
13 |
+
def sinc(x: torch.Tensor):
|
14 |
+
"""
|
15 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
16 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
17 |
+
"""
|
18 |
+
return torch.where(
|
19 |
+
x == 0,
|
20 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
21 |
+
torch.sin(math.pi * x) / math.pi / x,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
26 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
27 |
+
def kaiser_sinc_filter1d(
|
28 |
+
cutoff, half_width, kernel_size
|
29 |
+
): # return filter [1,1,kernel_size]
|
30 |
+
even = kernel_size % 2 == 0
|
31 |
+
half_size = kernel_size // 2
|
32 |
+
|
33 |
+
# For kaiser window
|
34 |
+
delta_f = 4 * half_width
|
35 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
36 |
+
if A > 50.0:
|
37 |
+
beta = 0.1102 * (A - 8.7)
|
38 |
+
elif A >= 21.0:
|
39 |
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
40 |
+
else:
|
41 |
+
beta = 0.0
|
42 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
43 |
+
|
44 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
45 |
+
if even:
|
46 |
+
time = torch.arange(-half_size, half_size) + 0.5
|
47 |
+
else:
|
48 |
+
time = torch.arange(kernel_size) - half_size
|
49 |
+
if cutoff == 0:
|
50 |
+
filter_ = torch.zeros_like(time)
|
51 |
+
else:
|
52 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
53 |
+
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
54 |
+
# of the constant component in the input signal.
|
55 |
+
filter_ /= filter_.sum()
|
56 |
+
filter = filter_.view(1, 1, kernel_size)
|
57 |
+
|
58 |
+
return filter
|
59 |
+
|
60 |
+
|
61 |
+
class LowPassFilter1d(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
cutoff=0.5,
|
65 |
+
half_width=0.6,
|
66 |
+
stride: int = 1,
|
67 |
+
padding: bool = True,
|
68 |
+
padding_mode: str = "replicate",
|
69 |
+
kernel_size: int = 12,
|
70 |
+
):
|
71 |
+
# kernel_size should be even number for stylegan3 setup,
|
72 |
+
# in this implementation, odd number is also possible.
|
73 |
+
super().__init__()
|
74 |
+
if cutoff < -0.0:
|
75 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
76 |
+
if cutoff > 0.5:
|
77 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
78 |
+
self.kernel_size = kernel_size
|
79 |
+
self.even = kernel_size % 2 == 0
|
80 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
81 |
+
self.pad_right = kernel_size // 2
|
82 |
+
self.stride = stride
|
83 |
+
self.padding = padding
|
84 |
+
self.padding_mode = padding_mode
|
85 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
86 |
+
self.register_buffer("filter", filter)
|
87 |
+
|
88 |
+
# input [B, C, T]
|
89 |
+
def forward(self, x):
|
90 |
+
_, C, _ = x.shape
|
91 |
+
|
92 |
+
if self.padding:
|
93 |
+
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
94 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
95 |
+
|
96 |
+
return out
|
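# Sketch of what kaiser_sinc_filter1d produces: a [1, 1, kernel_size] low-pass
# kernel (Kaiser-windowed sinc) normalized to sum to 1 so the DC component
# passes unchanged; the cutoff/half_width values below mirror the 2x resamplers.
#
#     f = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
#     # f.shape == (1, 1, 12) and float(f.sum()) == 1.0 (up to rounding)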
modules/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from .filter import LowPassFilter1d
|
6 |
+
from .filter import kaiser_sinc_filter1d
|
7 |
+
|
8 |
+
|
9 |
+
class UpSample1d(nn.Module):
|
10 |
+
def __init__(self, ratio=2, kernel_size=None):
|
11 |
+
super().__init__()
|
12 |
+
self.ratio = ratio
|
13 |
+
self.kernel_size = (
|
14 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
15 |
+
)
|
16 |
+
self.stride = ratio
|
17 |
+
self.pad = self.kernel_size // ratio - 1
|
18 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
19 |
+
self.pad_right = (
|
20 |
+
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
21 |
+
)
|
22 |
+
filter = kaiser_sinc_filter1d(
|
23 |
+
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
24 |
+
)
|
25 |
+
self.register_buffer("filter", filter)
|
26 |
+
|
27 |
+
# x: [B, C, T]
|
28 |
+
def forward(self, x):
|
29 |
+
_, C, _ = x.shape
|
30 |
+
|
31 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
32 |
+
x = self.ratio * F.conv_transpose1d(
|
33 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
34 |
+
)
|
35 |
+
x = x[..., self.pad_left : -self.pad_right]
|
36 |
+
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
class DownSample1d(nn.Module):
|
41 |
+
def __init__(self, ratio=2, kernel_size=None):
|
42 |
+
super().__init__()
|
43 |
+
self.ratio = ratio
|
44 |
+
self.kernel_size = (
|
45 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
46 |
+
)
|
47 |
+
self.lowpass = LowPassFilter1d(
|
48 |
+
cutoff=0.5 / ratio,
|
49 |
+
half_width=0.6 / ratio,
|
50 |
+
stride=ratio,
|
51 |
+
kernel_size=self.kernel_size,
|
52 |
+
)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
xx = self.lowpass(x)
|
56 |
+
|
57 |
+
return xx
|
modules/astral_quantization/__pycache__/bsq.cpython-310.pyc
ADDED
Binary file (12.7 kB). View file
|
|
modules/astral_quantization/__pycache__/convnext.cpython-310.pyc
ADDED
Binary file (6.89 kB). View file
|
|
modules/astral_quantization/__pycache__/default_model.cpython-310.pyc
ADDED
Binary file (2.82 kB). View file
|
|
modules/astral_quantization/bsq.py
ADDED
@@ -0,0 +1,569 @@
1 |
+
"""
|
2 |
+
Lookup Free Quantization
|
3 |
+
Proposed in https://arxiv.org/abs/2310.05737
|
4 |
+
|
5 |
+
In the simplest setup, each dimension is quantized into {-1, 1}.
|
6 |
+
An entropy penalty is used to encourage utilization.
|
7 |
+
"""
|
8 |
+
|
9 |
+
from math import log2, ceil
|
10 |
+
from functools import partial, cache
|
11 |
+
from collections import namedtuple
|
12 |
+
from contextlib import nullcontext
|
13 |
+
|
14 |
+
import torch.distributed as dist
|
15 |
+
from torch.distributed import nn as dist_nn
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn, einsum
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch.nn import Module
|
21 |
+
from torch.amp import autocast
|
22 |
+
|
23 |
+
from einops import rearrange, reduce, pack, unpack
|
24 |
+
|
25 |
+
# constants
|
26 |
+
|
27 |
+
Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
|
28 |
+
|
29 |
+
LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
|
30 |
+
|
31 |
+
# distributed helpers
|
32 |
+
|
33 |
+
@cache
|
34 |
+
def is_distributed():
|
35 |
+
return dist.is_initialized() and dist.get_world_size() > 1
|
36 |
+
|
37 |
+
def maybe_distributed_mean(t):
|
38 |
+
if not is_distributed():
|
39 |
+
return t
|
40 |
+
|
41 |
+
dist_nn.all_reduce(t)
|
42 |
+
t = t / dist.get_world_size()
|
43 |
+
return t
|
44 |
+
|
45 |
+
# helper functions
|
46 |
+
|
47 |
+
def exists(v):
|
48 |
+
return v is not None
|
49 |
+
|
50 |
+
def identity(t):
|
51 |
+
return t
|
52 |
+
|
53 |
+
def default(*args):
|
54 |
+
for arg in args:
|
55 |
+
if exists(arg):
|
56 |
+
return arg() if callable(arg) else arg
|
57 |
+
return None
|
58 |
+
|
59 |
+
def pack_one(t, pattern):
|
60 |
+
return pack([t], pattern)
|
61 |
+
|
62 |
+
def unpack_one(t, ps, pattern):
|
63 |
+
return unpack(t, ps, pattern)[0]
|
64 |
+
|
65 |
+
def l2norm(t):
|
66 |
+
return F.normalize(t, dim = -1)
|
67 |
+
|
68 |
+
# entropy
|
69 |
+
|
70 |
+
def log(t, eps = 1e-5):
|
71 |
+
return t.clamp(min = eps).log()
|
72 |
+
|
73 |
+
def entropy(prob):
|
74 |
+
return (-prob * log(prob)).sum(dim=-1)
|
75 |
+
|
76 |
+
# cosine sim linear
|
77 |
+
|
78 |
+
class CosineSimLinear(Module):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
dim_in,
|
82 |
+
dim_out,
|
83 |
+
scale = 1.
|
84 |
+
):
|
85 |
+
super().__init__()
|
86 |
+
self.scale = scale
|
87 |
+
self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
x = F.normalize(x, dim = -1)
|
91 |
+
w = F.normalize(self.weight, dim = 0)
|
92 |
+
return (x @ w) * self.scale
|
93 |
+
|
94 |
+
def soft_entropy_loss(u, tau=1.0, gamma=1.0):
|
95 |
+
"""
|
96 |
+
Compute the soft entropy loss for Binary Spherical Quantization (BSQ).
|
97 |
+
|
98 |
+
Args:
|
99 |
+
u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
|
100 |
+
tau (float): Temperature scaling factor.
|
101 |
+
gamma (float): Weight for the second entropy term.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
torch.Tensor: Soft entropy loss.
|
105 |
+
"""
|
106 |
+
# Binary quantization: Generate implicit codebook corners
|
107 |
+
L = u.size(1) # Dimensionality of codebook
|
108 |
+
corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5)
|
109 |
+
|
110 |
+
# Compute soft quantization probabilities for all dimensions
|
111 |
+
# q_hat(c|u) for each dimension
|
112 |
+
prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2)) # Shape: (batch_size, L, 2)
|
113 |
+
|
114 |
+
# Entropy of q_hat(c|u) (independent along each dimension)
|
115 |
+
entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1) # Shape: (batch_size, L)
|
116 |
+
entropy_term1 = entropy_per_dim.mean()
|
117 |
+
|
118 |
+
# Expected probabilities for dataset entropy (approximation)
|
119 |
+
expected_probs = prob_matrix.mean(dim=0) # Mean across batch, shape: (L, 2)
|
120 |
+
entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean()
|
121 |
+
|
122 |
+
# Final entropy loss
|
123 |
+
loss = entropy_term1 - gamma * entropy_term2
|
124 |
+
return loss
|
125 |
+
|
126 |
+
# class
|
127 |
+
|
128 |
+
class BinarySphericalQuantize(Module):
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
*,
|
132 |
+
dim = None,
|
133 |
+
codebook_size = None,
|
134 |
+
entropy_loss_weight = 0.1,
|
135 |
+
commitment_loss_weight = 0.,
|
136 |
+
diversity_gamma = 1.,
|
137 |
+
straight_through_activation = nn.Identity(),
|
138 |
+
num_codebooks = 1,
|
139 |
+
keep_num_codebooks_dim = None,
|
140 |
+
codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
|
141 |
+
frac_per_sample_entropy = 0.25, # make less than 1. to only use a random fraction of the probs for per sample entropy
|
142 |
+
has_projections = None,
|
143 |
+
projection_has_bias = True,
|
144 |
+
soft_clamp_input_value = None,
|
145 |
+
cosine_sim_project_in = False,
|
146 |
+
cosine_sim_project_in_scale = None,
|
147 |
+
channel_first = None,
|
148 |
+
experimental_softplus_entropy_loss = False,
|
149 |
+
entropy_loss_offset = 5., # how much to shift the loss before softplus
|
150 |
+
spherical = True, # from https://arxiv.org/abs/2406.07548
|
151 |
+
force_quantization_f32 = True, # will force the quantization step to be full precision
|
152 |
+
enable_entropy_loss = True,
|
153 |
+
soft_entropy_loss = True,
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
# some assert validations
|
158 |
+
|
159 |
+
assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
|
160 |
+
assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
|
161 |
+
|
162 |
+
codebook_size = default(codebook_size, lambda: 2 ** dim)
|
163 |
+
self.codebook_size = codebook_size
|
164 |
+
|
165 |
+
codebook_dim = int(log2(codebook_size))
|
166 |
+
codebook_dims = codebook_dim * num_codebooks
|
167 |
+
dim = default(dim, codebook_dims)
|
168 |
+
|
169 |
+
has_projections = default(has_projections, dim != codebook_dims)
|
170 |
+
|
171 |
+
if cosine_sim_project_in:
|
172 |
+
cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
|
173 |
+
project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
|
174 |
+
else:
|
175 |
+
project_in_klass = partial(nn.Linear, bias = projection_has_bias)
|
176 |
+
|
177 |
+
self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
|
178 |
+
self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
|
179 |
+
self.has_projections = has_projections
|
180 |
+
|
181 |
+
self.dim = dim
|
182 |
+
self.codebook_dim = codebook_dim
|
183 |
+
self.num_codebooks = num_codebooks
|
184 |
+
|
185 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
186 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
187 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
188 |
+
|
189 |
+
# channel first
|
190 |
+
|
191 |
+
self.channel_first = channel_first
|
192 |
+
|
193 |
+
# straight through activation
|
194 |
+
|
195 |
+
self.activation = straight_through_activation
|
196 |
+
|
197 |
+
# whether to use BSQ (binary spherical quantization)
|
198 |
+
|
199 |
+
self.spherical = spherical
|
200 |
+
self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
|
201 |
+
|
202 |
+
# entropy aux loss related weights
|
203 |
+
|
204 |
+
assert 0 < frac_per_sample_entropy <= 1.
|
205 |
+
self.frac_per_sample_entropy = frac_per_sample_entropy
|
206 |
+
|
207 |
+
self.diversity_gamma = diversity_gamma
|
208 |
+
self.entropy_loss_weight = entropy_loss_weight
|
209 |
+
|
210 |
+
# codebook scale
|
211 |
+
|
212 |
+
self.codebook_scale = codebook_scale
|
213 |
+
|
214 |
+
# commitment loss
|
215 |
+
|
216 |
+
self.commitment_loss_weight = commitment_loss_weight
|
217 |
+
|
218 |
+
# whether to soft clamp the input value from -value to value
|
219 |
+
|
220 |
+
self.soft_clamp_input_value = soft_clamp_input_value
|
221 |
+
assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
|
222 |
+
|
223 |
+
# whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
|
224 |
+
|
225 |
+
self.entropy_loss_offset = entropy_loss_offset
|
226 |
+
self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
|
227 |
+
|
228 |
+
# for no auxiliary loss, during inference
|
229 |
+
|
230 |
+
self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
|
231 |
+
self.register_buffer('zero', torch.tensor(0.), persistent = False)
|
232 |
+
|
233 |
+
# whether to force quantization step to be f32
|
234 |
+
|
235 |
+
self.force_quantization_f32 = force_quantization_f32
|
236 |
+
|
237 |
+
# codes
|
238 |
+
self.enable_entropy_loss = enable_entropy_loss
|
239 |
+
self.soft_entropy_loss = soft_entropy_loss
|
240 |
+
if codebook_size <= 100000:
|
241 |
+
all_codes = torch.arange(codebook_size)
|
242 |
+
bits = ((all_codes[..., None].int() & self.mask) != 0).float()
|
243 |
+
codebook = self.bits_to_codes(bits)
|
244 |
+
|
245 |
+
self.register_buffer('codebook', codebook.float(), persistent = False)
|
246 |
+
else:
|
247 |
+
all_codes = torch.arange(pow(2, 16))
|
248 |
+
mask = 2 ** torch.arange(16 - 1, -1, -1)
|
249 |
+
bits = ((all_codes[..., None].int() & mask) != 0).float()
|
250 |
+
codebook = self.bits_to_codes(bits)
|
251 |
+
|
252 |
+
self.register_buffer('codebook', codebook.float(), persistent = False)
|
253 |
+
|
254 |
+
def bits_to_codes(self, bits):
|
255 |
+
return bits * self.codebook_scale * 2 - self.codebook_scale
|
256 |
+
|
257 |
+
@property
|
258 |
+
def dtype(self):
|
259 |
+
return self.codebook.dtype
|
260 |
+
|
261 |
+
def indices_to_codes(
|
262 |
+
self,
|
263 |
+
indices,
|
264 |
+
project_out = True
|
265 |
+
):
|
266 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
267 |
+
should_transpose = default(self.channel_first, is_img_or_video)
|
268 |
+
|
269 |
+
if not self.keep_num_codebooks_dim:
|
270 |
+
indices = rearrange(indices, '... -> ... 1')
|
271 |
+
|
272 |
+
# indices to codes, which are bits of either -1 or 1
|
273 |
+
|
274 |
+
bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
|
275 |
+
|
276 |
+
codes = self.bits_to_codes(bits)
|
277 |
+
|
278 |
+
codes = self.maybe_l2norm(codes)
|
279 |
+
|
280 |
+
codes = rearrange(codes, '... c d -> ... (c d)')
|
281 |
+
|
282 |
+
# whether to project codes out to original dimensions
|
283 |
+
# if the input feature dimensions were not log2(codebook size)
|
284 |
+
|
285 |
+
if project_out:
|
286 |
+
codes = self.project_out(codes)
|
287 |
+
|
288 |
+
# rearrange codes back to original shape
|
289 |
+
|
290 |
+
if should_transpose:
|
291 |
+
codes = rearrange(codes, 'b ... d -> b d ...')
|
292 |
+
|
293 |
+
return codes
|
294 |
+
|
295 |
+
def bits_to_z(self, bits):
|
296 |
+
# assert bits must contain only -1 and 1
|
297 |
+
assert torch.all(bits.abs() == 1)
|
298 |
+
quantized = bits.float()
|
299 |
+
quantized = self.maybe_l2norm(quantized)
|
300 |
+
z = self.project_out(quantized)
|
301 |
+
return z
|
302 |
+
|
303 |
+
def forward(
|
304 |
+
self,
|
305 |
+
x,
|
306 |
+
inv_temperature = 100.,
|
307 |
+
return_loss_breakdown = False,
|
308 |
+
mask = None,
|
309 |
+
return_bits = False
|
310 |
+
):
|
311 |
+
"""
|
312 |
+
einstein notation
|
313 |
+
b - batch
|
314 |
+
n - sequence (or flattened spatial dimensions)
|
315 |
+
d - feature dimension, which is also log2(codebook size)
|
316 |
+
c - number of codebook dim
|
317 |
+
"""
|
318 |
+
|
319 |
+
is_img_or_video = x.ndim >= 4
|
320 |
+
should_transpose = default(self.channel_first, is_img_or_video)
|
321 |
+
|
322 |
+
# standardize image or video into (batch, seq, dimension)
|
323 |
+
|
324 |
+
if should_transpose:
|
325 |
+
x = rearrange(x, 'b d ... -> b ... d')
|
326 |
+
x, ps = pack_one(x, 'b * d')
|
327 |
+
|
328 |
+
assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
|
329 |
+
|
330 |
+
x = self.project_in(x)
|
331 |
+
|
332 |
+
# maybe soft clamp
|
333 |
+
|
334 |
+
if exists(self.soft_clamp_input_value):
|
335 |
+
clamp_value = self.soft_clamp_input_value
|
336 |
+
x = (x / clamp_value).tanh() * clamp_value
|
337 |
+
|
338 |
+
# split out number of codebooks
|
339 |
+
|
340 |
+
x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
|
341 |
+
|
342 |
+
# maybe l2norm
|
343 |
+
|
344 |
+
x = self.maybe_l2norm(x)
|
345 |
+
|
346 |
+
# whether to force quantization step to be full precision or not
|
347 |
+
|
348 |
+
force_f32 = self.force_quantization_f32
|
349 |
+
|
350 |
+
quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
|
351 |
+
|
352 |
+
with quantization_context():
|
353 |
+
|
354 |
+
if force_f32:
|
355 |
+
orig_dtype = x.dtype
|
356 |
+
x = x.float()
|
357 |
+
|
358 |
+
# quantize by eq 3.
|
359 |
+
|
360 |
+
original_input = x
|
361 |
+
|
362 |
+
codebook_value = torch.ones_like(x) * self.codebook_scale
|
363 |
+
quantized = torch.where(x > 0, codebook_value, -codebook_value)
|
364 |
+
if return_bits:
|
365 |
+
return quantized
|
366 |
+
|
367 |
+
# calculate indices
|
368 |
+
|
369 |
+
indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
|
370 |
+
|
371 |
+
# maybe l2norm
|
372 |
+
|
373 |
+
quantized = self.maybe_l2norm(quantized)
|
374 |
+
|
375 |
+
# use straight-through gradients (optionally with custom activation fn) if training
|
376 |
+
|
377 |
+
if self.training:
|
378 |
+
x = self.activation(x)
|
379 |
+
x = x + (quantized - x).detach()
|
380 |
+
else:
|
381 |
+
x = quantized
|
382 |
+
|
383 |
+
# entropy aux loss
|
384 |
+
if self.soft_entropy_loss:
|
385 |
+
entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
|
386 |
+
elif self.training and self.enable_entropy_loss:
|
387 |
+
|
388 |
+
if force_f32:
|
389 |
+
codebook = self.codebook.float()
|
390 |
+
|
391 |
+
codebook = self.maybe_l2norm(codebook)
|
392 |
+
|
393 |
+
# whether to only use a fraction of probs, for reducing memory
|
394 |
+
|
395 |
+
if self.frac_per_sample_entropy < 1.:
|
396 |
+
# account for mask
|
397 |
+
if exists(mask):
|
398 |
+
original_input = original_input[mask]
|
399 |
+
original_input = rearrange(original_input, 'b n ... -> (b n) ...')
|
400 |
+
|
401 |
+
rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16
|
402 |
+
|
403 |
+
sampled_input = original_input[..., rand_mask]
|
404 |
+
|
405 |
+
sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
|
406 |
+
|
407 |
+
sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
|
408 |
+
|
409 |
+
per_sample_probs = sampled_prob
|
410 |
+
else:
|
411 |
+
if exists(mask):
|
412 |
+
original_input = original_input[mask]
|
413 |
+
original_input = rearrange(original_input, 'b n ... -> (b n) ...')
|
414 |
+
# the same as euclidean distance up to a constant
|
415 |
+
distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
|
416 |
+
|
417 |
+
prob = (-distance * inv_temperature).softmax(dim = -1)
|
418 |
+
|
419 |
+
per_sample_probs = prob
|
420 |
+
|
421 |
+
# calculate per sample entropy
|
422 |
+
|
423 |
+
per_sample_entropy = entropy(per_sample_probs).mean()
|
424 |
+
|
425 |
+
# distribution over all available tokens in the batch
|
426 |
+
|
427 |
+
avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
|
428 |
+
|
429 |
+
avg_prob = maybe_distributed_mean(avg_prob)
|
430 |
+
|
431 |
+
codebook_entropy = entropy(avg_prob).mean()
|
432 |
+
|
433 |
+
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
|
434 |
+
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
|
435 |
+
|
436 |
+
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
|
437 |
+
else:
|
438 |
+
# if not training, just return dummy 0
|
439 |
+
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
|
440 |
+
|
441 |
+
# whether to make the entropy loss positive or not through a (shifted) softplus
|
442 |
+
|
443 |
+
if self.training and self.experimental_softplus_entropy_loss:
|
444 |
+
entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
|
445 |
+
|
446 |
+
# commit loss
|
447 |
+
|
448 |
+
if self.training and self.commitment_loss_weight > 0.:
|
449 |
+
|
450 |
+
commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
|
451 |
+
|
452 |
+
if exists(mask):
|
453 |
+
commit_loss = commit_loss[mask]
|
454 |
+
|
455 |
+
commit_loss = commit_loss.mean()
|
456 |
+
else:
|
457 |
+
commit_loss = self.zero
|
458 |
+
|
459 |
+
# input back to original dtype if needed
|
460 |
+
|
461 |
+
if force_f32:
|
462 |
+
x = x.type(orig_dtype)
|
463 |
+
|
464 |
+
# merge back codebook dim
|
465 |
+
|
466 |
+
x = rearrange(x, 'b n c d -> b n (c d)')
|
467 |
+
|
468 |
+
# project out to feature dimension if needed
|
469 |
+
|
470 |
+
x = self.project_out(x)
|
471 |
+
|
472 |
+
# reconstitute image or video dimensions
|
473 |
+
|
474 |
+
if should_transpose:
|
475 |
+
x = unpack_one(x, ps, 'b * d')
|
476 |
+
x = rearrange(x, 'b ... d -> b d ...')
|
477 |
+
|
478 |
+
indices = unpack_one(indices, ps, 'b * c')
|
479 |
+
|
480 |
+
# whether to remove single codebook dim
|
481 |
+
|
482 |
+
if not self.keep_num_codebooks_dim:
|
483 |
+
indices = rearrange(indices, '... 1 -> ...')
|
484 |
+
|
485 |
+
# complete aux loss
|
486 |
+
|
487 |
+
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
488 |
+
|
489 |
+
# returns
|
490 |
+
|
491 |
+
ret = Return(x, indices, aux_loss)
|
492 |
+
|
493 |
+
if not return_loss_breakdown:
|
494 |
+
return ret
|
495 |
+
|
496 |
+
return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
|
497 |
+
|
498 |
+
class GroupedResidualBSQ(Module):
|
499 |
+
def __init__(
|
500 |
+
self,
|
501 |
+
*,
|
502 |
+
dim,
|
503 |
+
groups = 1,
|
504 |
+
accept_image_fmap = False,
|
505 |
+
**kwargs
|
506 |
+
):
|
507 |
+
super().__init__()
|
508 |
+
self.dim = dim
|
509 |
+
self.groups = groups
|
510 |
+
assert (dim % groups) == 0
|
511 |
+
dim_per_group = dim // groups
|
512 |
+
|
513 |
+
self.accept_image_fmap = accept_image_fmap
|
514 |
+
|
515 |
+
self.rvqs = nn.ModuleList([])
|
516 |
+
|
517 |
+
for _ in range(groups):
|
518 |
+
self.rvqs.append(LFQ(
|
519 |
+
dim = dim_per_group,
|
520 |
+
**kwargs
|
521 |
+
))
|
522 |
+
|
523 |
+
self.codebook_size = self.rvqs[0].codebook_size
|
524 |
+
|
525 |
+
@property
|
526 |
+
def codebooks(self):
|
527 |
+
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
528 |
+
|
529 |
+
@property
|
530 |
+
def split_dim(self):
|
531 |
+
return 1 if self.accept_image_fmap else -1
|
532 |
+
|
533 |
+
def get_codes_from_indices(self, indices):
|
534 |
+
codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
|
535 |
+
return torch.stack(codes)
|
536 |
+
|
537 |
+
def get_output_from_indices(self, indices):
|
538 |
+
outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
|
539 |
+
return torch.cat(outputs, dim = self.split_dim)
|
540 |
+
|
541 |
+
def forward(
|
542 |
+
self,
|
543 |
+
x,
|
544 |
+
return_all_codes = False
|
545 |
+
):
|
546 |
+
shape, split_dim = x.shape, self.split_dim
|
547 |
+
assert shape[split_dim] == self.dim
|
548 |
+
|
549 |
+
# split the feature dimension into groups
|
550 |
+
|
551 |
+
x = x.chunk(self.groups, dim = split_dim)
|
552 |
+
|
553 |
+
forward_kwargs = dict(
|
554 |
+
)
|
555 |
+
|
556 |
+
# invoke residual vq on each group
|
557 |
+
|
558 |
+
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
559 |
+
out = tuple(zip(*out))
|
560 |
+
|
561 |
+
# otherwise, get all the zipped outputs and combine them
|
562 |
+
|
563 |
+
quantized, all_indices, *maybe_aux_loss = out
|
564 |
+
|
565 |
+
quantized = torch.cat(quantized, dim = split_dim)
|
566 |
+
all_indices = torch.stack(all_indices)
|
567 |
+
|
568 |
+
ret = (quantized, all_indices, *maybe_aux_loss)
|
569 |
+
return ret
|
modules/astral_quantization/convnext.py
ADDED
@@ -0,0 +1,209 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class ConvNextV2LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            x = torch.nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[None, :, None] * x + self.bias[None, :, None]
        return x


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class InterpolationLayer(nn.Module):
    def __init__(self, ):  # this is a default of 1 / 50 * (44100 / 512) / 4
        super().__init__()
        pass

    def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x = F.interpolate(x, size=target_len, mode='linear')
        return x


class ConvNeXtV2Stage(nn.Module):
    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 2048,
        num_blocks: int = 1,
        dilation: int = 1,
        downsample_layer_indices: List[int] = None,
        downsample_factors: List[int] = None,
        upsample_layer_indices: List[int] = None,
        upsample_factors: List[int] = None,
        interpolation_layer_indices: List[int] = None,
        input_dim: int = None,
        output_dim: int = None,
        gin_channels: int = 0,
    ):
        super().__init__()
        # maybe downsample layers
        if downsample_layer_indices is not None:
            assert downsample_factors is not None
            self.downsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.Conv1d(
                            dim, dim, kernel_size=downsample_factor, stride=downsample_factor
                        ),
                    ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
                ]
            )
            self.downsample_layer_indices = downsample_layer_indices
        else:
            self.downsample_blocks = nn.ModuleList()
            self.downsample_layer_indices = []

        # maybe upsample layers
        if upsample_layer_indices is not None:
            assert upsample_factors is not None
            self.upsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.ConvTranspose1d(
                            dim, dim, kernel_size=upsample_factor, stride=upsample_factor
                        ),
                    ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
                ]
            )
            self.upsample_layer_indices = upsample_layer_indices
        else:
            self.upsample_blocks = nn.ModuleList()
            self.upsample_layer_indices = []

        # maybe interpolation layers
        if interpolation_layer_indices is not None:
            self.interpolation_blocks = nn.ModuleList(
                [
                    InterpolationLayer()
                    for _ in interpolation_layer_indices
                ]
            )
            self.interpolation_layer_indices = interpolation_layer_indices
        else:
            self.interpolation_blocks = nn.ModuleList()
            self.interpolation_layer_indices = []

        # main blocks
        self.blocks = nn.ModuleList(
            [
                ConvNeXtV2Block(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    dilation=dilation,
                )
                for _ in range(num_blocks)
            ]
        )
        # maybe input and output projections
        if input_dim is not None and input_dim != dim:
            self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
        else:
            self.input_projection = nn.Identity()
        if output_dim is not None and output_dim != dim:
            self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
        else:
            self.output_projection = nn.Identity()

        if gin_channels > 0:
            self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x = self.input_projection(x)  # B, D, T
        if hasattr(self, 'gin'):
            g = kwargs['g']
            x = x + self.gin(g)
        # pad to a multiple of cumprod(downsample_factors)
        if len(self.downsample_blocks) > 0:
            downsample_factor = 1
            for factor in self.downsample_blocks:
                downsample_factor *= factor[1].stride[0]
            pad_len = downsample_factor - x.size(-1) % downsample_factor
            if pad_len > 0:
                x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)

        # main blocks
        for layer_idx, block in enumerate(self.blocks):
            if layer_idx in self.downsample_layer_indices:
                x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.upsample_layer_indices:
                x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.interpolation_layer_indices:
                x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
            x = block(x)
        x = self.output_projection(x)
        return x

    def setup_caches(self, *args, **kwargs):
        pass


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.transpose(1, 2)  # b n d -> b d n
        return residual + x
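A small shape-check sketch for the ConvNeXtV2Stage added above. The dimensions below are made-up examples for illustration, not the configuration this app actually loads.

# Hedged sketch (not from the commit): run a ConvNeXtV2Stage on a dummy tensor and check the output shape.
import torch

stage = ConvNeXtV2Stage(dim=512, intermediate_dim=2048, num_blocks=4, input_dim=768, output_dim=512)
x = torch.randn(1, 768, 100)   # (B, D_in, T), channels-first as expected by the Conv1d projections
y = stage(x)
print(y.shape)                  # torch.Size([1, 512, 100]) when no down/upsampling is configured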
modules/astral_quantization/default_model.py
ADDED
@@ -0,0 +1,73 @@
import torch
from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor

class AstralQuantizer(torch.nn.Module):
    def __init__(
        self,
        tokenizer_name: str,
        ssl_model_name: str,
        ssl_output_layer: int,
        encoder: torch.nn.Module,
        quantizer: torch.nn.Module,
        skip_ssl: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load SSL model from Huggingface
        self.ssl_model_name = ssl_model_name
        self.ssl_output_layer = ssl_output_layer
        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)

        if skip_ssl:  # in case the same SSL model has been loaded somewhere else
            self.ssl_model = None
        else:
            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
            self.ssl_model.encoder.layer_norm = torch.nn.Identity()

    def load_separate_checkpoint(self, checkpoint_path):
        params = torch.load(checkpoint_path, map_location='cpu')['net']
        for key in params.keys():
            for k in list(params[key].keys()):
                if k.startswith("module."):
                    params[key][k[len("module."):]] = params[key][k]
                    del params[key][k]
        self.encoder.load_state_dict(params['encoder'])
        self.quantizer.load_state_dict(params['vq'])
        if self.decoder is not None:
            self.decoder.load_state_dict(params['decoder'])
        if self.asr_decoder is not None:
            self.asr_decoder.load_state_dict(params['predictor'], strict=False)

    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
        ssl_fn = self.ssl_model if self.ssl_model else ssl_model
        assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
        waves_16k_input_list = [
            waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
            for bib in range(len(waves_16k))
        ]
        alt_inputs = self.ssl_feature_extractor(
            waves_16k_input_list,
            return_tensors='pt',
            return_attention_mask=True,
            padding=True,
            sampling_rate=16000
        ).to(waves_16k.device)
        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320  # frame rate of hubert is 50 Hz

        outputs = ssl_fn(
            alt_inputs.input_values,
            attention_mask=alt_inputs.attention_mask,
        )
        last_hidden_states = outputs.last_hidden_state
        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
        last_hidden_states = last_hidden_states.transpose(1, 2)
        x_hidden = self.encoder(last_hidden_states, feature_lens)
        x_hidden = x_hidden.transpose(1, 2)
        x_quantized, indices = self.quantizer(x_hidden)[:2]
        return x_quantized, indices, feature_lens
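A hedged construction sketch for the AstralQuantizer above. The Hugging Face model ids and the toy encoder/quantizer modules below are placeholders chosen purely for illustration; the real encoder, quantizer, and checkpoint are wired up elsewhere in this repo.

# Hedged sketch (placeholder names, not the app's actual configuration).
import torch

class ToyEncoder(torch.nn.Module):
    def forward(self, x, lens):           # x: (B, D_ssl, T) hidden states from the SSL model
        return x

class ToyQuantizer(torch.nn.Module):
    def forward(self, x):                 # x: (B, T, D); returns (quantized, indices)
        indices = torch.zeros(x.shape[:2], dtype=torch.long)
        return x, indices

quant = AstralQuantizer(
    tokenizer_name="openai/whisper-small",           # placeholder tokenizer id
    ssl_model_name="facebook/hubert-base-ls960",     # placeholder SSL model id
    ssl_output_layer=9,
    encoder=ToyEncoder(),
    quantizer=ToyQuantizer(),
)
waves_16k = torch.randn(1, 16000)                    # one second of 16 kHz audio
lens = torch.tensor([16000])
x_quantized, indices, feature_lens = quant(waves_16k, lens)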
modules/astral_quantization/transformer.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
import time

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if embedding is None:
            return self.norm(input)
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    has_cross_attention: bool = False
    context_dim: int = 0
    is_causal: bool = False
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        # self.head_dim = self.dim // self.n_head

class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        self.max_batch_size = -1
        self.max_seq_length = config.block_size
        freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
                                         self.config.rope_base)
        self.register_buffer("freqs_cis", freqs_cis)

        causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)

    def forward(self,
                x: Tensor,
                c: Tensor,
                input_pos: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                context: Optional[Tensor] = None,
                context_input_pos: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        if mask is None:
            mask = self.causal_mask[:x.size(1), :x.size(1)]
        else:
            mask = mask[..., input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        if context is not None:
            context_freqs_cis = self.freqs_cis[context_input_pos]
        else:
            context_freqs_cis = None
        skip_in_x_list = []
        for i, layer in enumerate(self.layers):
            x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
        x = self.norm(x, c)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))

        if config.has_cross_attention:
            self.has_cross_attention = True
            self.cross_attention = Attention(config, is_cross_attention=True)
            self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
        else:
            self.has_cross_attention = False

    def forward(self,
                x: Tensor,
                c: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                cross_attention_mask: Optional[Tensor] = None,
                ) -> Tensor:
        #time_attn_start = time.time()
        h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
        #print(f"time take for attention of sequence length {x.shape[1]} is {time.time() - time_attn_start}")
        if self.has_cross_attention:
            h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis)
        out = h + self.feed_forward(self.ffn_norm(h, c))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        if is_cross_attention:
            self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
            self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
        else:
            self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate

    def forward(self,
                x: Tensor,
                freqs_cis: Tensor,
                mask: Tensor,
                context: Optional[Tensor] = None,
                context_freqs_cis: Optional[Tensor] = None,
                ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        if context is None:
            q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
            context_seqlen = seqlen
        else:
            q = self.wq(x)
            k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
            context_seqlen = context.shape[1]

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)

        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
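A minimal shape sketch for the conditioned Transformer above. The tiny ModelArgs values are invented for illustration and do not reflect any checkpoint shipped with this commit.

# Hedged sketch: tiny config, random inputs, a single forward pass.
import torch

args = ModelArgs(block_size=256, n_layer=2, n_head=4, dim=256, head_dim=64)
model = Transformer(args)

x = torch.randn(1, 32, 256)            # (batch, seq, dim) token features
c = torch.randn(1, 32, 256)            # conditioning embedding consumed by AdaptiveLayerNorm
input_pos = torch.arange(32)           # positions indexing the rotary cache and causal mask
out = model(x, c, input_pos=input_pos)
print(out.shape)                        # torch.Size([1, 32, 256])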
modules/audio.py
ADDED
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
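A quick hedged example of calling mel_spectrogram above. The parameter values are typical BigVGAN-style settings chosen for illustration, not necessarily the ones configured elsewhere in this commit.

# Hedged sketch: log-mel spectrogram from one second of random audio.
import torch

y = torch.rand(1, 22050) * 2 - 1   # (batch, samples) in [-1, 1] at 22.05 kHz
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000,
)
print(mel.shape)                    # (1, 80, frames), roughly samples / hop_size frames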
modules/bigvgan/__pycache__/activations.cpython-310.pyc
ADDED
Binary file (4.02 kB)
modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc
ADDED
Binary file (11.8 kB)
modules/bigvgan/__pycache__/env.cpython-310.pyc
ADDED
Binary file (817 Bytes)