tonyshark committed on
Commit 271cff2 · verified · 1 Parent(s): 5c6b467

Upload 26 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+speakers/example_female.wav filter=lfs diff=lfs merge=lfs -text
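
This is the kind of rule that "git lfs track speakers/example_female.wav" appends, so the new reference WAV is stored through Git LFS as a pointer rather than as a regular blob.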
app.py ADDED
@@ -0,0 +1,334 @@
+import os
+import re
+import time
+import numpy as np
+import soundfile as sf
+import matplotlib.pyplot as plt
+import librosa
+import gradio as gr
+from scipy.signal import fftconvolve
+from model import StyleTTModel
+
+SPEAKER_WAV_PATH = "speakers/example_female.wav"
+OUTPUT_FILENAME = "output.wav"
+SAMPLE_RATE = 24000
+
+# Global model variable
+model = None
+
+def initialize_model():
+    """Initialize the StyleTTS model with error handling"""
+    global model
+    try:
+        # Check if speaker reference file exists
+        if not os.path.exists(SPEAKER_WAV_PATH):
+            raise FileNotFoundError(f"Không tìm thấy file giọng nói tham chiếu tại: {SPEAKER_WAV_PATH}. "
+                                    "Vui lòng tạo thư mục và đặt file .wav của bạn vào đó.")
+
+        print("Bắt đầu khởi tạo StyleTTS2 Model...")
+        model = StyleTTModel(speaker_wav=SPEAKER_WAV_PATH)
+        print("Đang tải model StyleTTS2. Quá trình này có thể mất vài phút...")
+        start_time = time.time()
+        model.load()
+        end_time = time.time()
+        print(f"Model đã được tải thành công sau {end_time - start_time:.2f} giây.")
+        return True
+    except Exception as e:
+        print(f"Lỗi khi khởi tạo model: {e}")
+        model = None
+        return False
+
+# Initialize model on startup
+model_loaded = initialize_model()
+
+# ---------------------------
+# Load HF TTS model (hexgrad/styletts2)
+# ---------------------------
+SR_OUT = 24000
+# tts_pipe = pipeline("text-to-speech", model="hexgrad/styletts2")
+
+# ---------------------------
+# Audio helpers
+# ---------------------------
+def load_wav(path, sr_target=SR_OUT):
+    wav, sr = sf.read(path)
+    if wav.ndim > 1:
+        wav = wav.mean(axis=1)
+    if sr != sr_target:
+        wav = librosa.resample(wav.astype(np.float32), orig_sr=sr, target_sr=sr_target)
+        sr = sr_target
+    return wav.astype(np.float32), sr
+
+def apply_reverb(wav, ir_path):
+    """Apply reverb effect using impulse response"""
+    try:
+        if not os.path.exists(ir_path):
+            print(f"Cảnh báo: Không tìm thấy file impulse response: {ir_path}")
+            return wav
+        ir, _ = load_wav(ir_path, sr_target=SR_OUT)
+        return fftconvolve(wav, ir, mode="full")
+    except Exception as e:
+        print(f"Lỗi khi áp dụng reverb: {e}")
+        return wav
+
+def add_noise(wav, noise_path, snr_db=10):
+    """Add background noise to audio"""
+    try:
+        if not os.path.exists(noise_path):
+            print(f"Cảnh báo: Không tìm thấy file noise: {noise_path}")
+            return wav
+        noise, _ = load_wav(noise_path, sr_target=SR_OUT)
+        if len(noise) < len(wav):
+            noise = np.tile(noise, int(len(wav)/len(noise)) + 1)
+        noise = noise[:len(wav)]
+        sig_power = np.mean(wav**2)
+        noise_power = np.mean(noise**2)
+        if noise_power == 0:
+            return wav
+        scale = np.sqrt(sig_power / (10**(snr_db/10) * noise_power))
+        return wav + noise * scale
+    except Exception as e:
+        print(f"Lỗi khi thêm noise: {e}")
+        return wav
+
+def bandlimit_phone(wav, sr=SR_OUT):
+    """Apply phone-like band limiting"""
+    try:
+        return librosa.effects.preemphasis(wav)
+    except Exception as e:
+        print(f"Lỗi khi áp dụng band limiting: {e}")
+        return wav
+
+def plot_waveforms(clean, processed, sr=SR_OUT):
+    """Create waveform comparison plot"""
+    try:
+        fig, axes = plt.subplots(2, 1, figsize=(10, 4), sharex=True)
+        t_clean = np.arange(len(clean)) / sr
+        t_proc = np.arange(len(processed)) / sr
+
+        axes[0].plot(t_clean, clean, color="blue", linewidth=0.8)
+        axes[0].set_title("🎤 Waveform gốc (StyleTTS2)")
+        axes[0].set_ylabel("Amplitude")
+        axes[0].grid(True, alpha=0.3)
+
+        axes[1].plot(t_proc, processed, color="red", linewidth=0.8)
+        axes[1].set_title("🎵 Waveform có hiệu ứng môi trường")
+        axes[1].set_xlabel("Thời gian (s)")
+        axes[1].set_ylabel("Amplitude")
+        axes[1].grid(True, alpha=0.3)
+
+        fig.tight_layout()
+        return fig
+    except Exception as e:
+        print(f"Lỗi khi tạo biểu đồ: {e}")
+        # Return a simple error plot
+        fig, ax = plt.subplots(1, 1, figsize=(10, 2))
+        ax.text(0.5, 0.5, "Không thể tạo biểu đồ", ha='center', va='center', transform=ax.transAxes)
+        ax.set_title("Lỗi tạo biểu đồ")
+        return fig
+
+# ---------------------------
+# Tag list
+# ---------------------------
+TAG_LIST = {
+    "laugh": "😆 Cười thoải mái",
+    "whisper": "🤫 Thì thầm",
+    "naughty": "😏 Tinh nghịch",
+    "giggle": "😂 Cười rúc rích",
+    "tease": "😉 Trêu chọc",
+    "smirk": "😼 Đắc ý",
+    "surprise": "😲 Ngạc nhiên",
+    "shock": "😱 Hoảng hốt",
+    "romantic": "❤️ Lãng mạn",
+    "shy": "😳 Bẽn lẽn",
+    "excited": "🤩 Phấn khích",
+    "curious": "🧐 Tò mò",
+    "discover": "✨ Phát hiện",
+    "blush": "🌸 Ngượng ngùng",
+    "angry": "😡 Giận dữ",
+    "sad": "😢 Buồn",
+    "happy": "😊 Vui vẻ",
+    "fear": "😨 Sợ hãi",
+    "confident": "😎 Tự tin",
+    "serious": "😐 Nghiêm túc",
+    "tired": "🥱 Mệt mỏi",
+    "cry": "😭 Khóc",
+    "love": "😍 Yêu thương",
+    "disgust": "🤢 Ghê tởm",
+}
+TAG_PATTERN = r"(<\/?(?:" + "|".join(TAG_LIST.keys()) + ")>)"
+
+# ---------------------------
+# Core synthesis
+# ---------------------------
+def synthesize(text, env, snr_db=10, speed=1.0):
+    """Synthesize text to speech with environment effects"""
+    try:
+        # Check if model is loaded
+        if model is None:
+            print("Lỗi: Model chưa được tải. Vui lòng khởi động lại ứng dụng.")
+            return None, None, None
+
+        # Parse text and extract segments
+        tokens = re.split(TAG_PATTERN, text)
+        clean_segments = []
+
+        for tok in tokens:
+            if not tok or tok.isspace():
+                continue
+            if tok.startswith("<") and tok.endswith(">"):
+                # Skip tags for now - they're just for text segmentation
+                continue
+            else:
+                # Synthesize each text segment
+                try:
+                    audio_array = model.synthesize(tok, speed=speed)
+                    clean_segments.append(audio_array)
+                except Exception as e:
+                    print(f"Lỗi khi tổng hợp đoạn '{tok}': {e}")
+                    continue
+
+        if not clean_segments:
+            return None, None, None
+
+        # Concatenate all audio segments
+        clean_audio = np.concatenate(clean_segments, axis=0)
+        processed = clean_audio.copy()
+
+        # Apply environment effects
+        try:
+            if env == "Church":
+                processed = apply_reverb(processed, "ir_church.wav")
+            elif env == "Hall":
+                processed = apply_reverb(processed, "ir_hall.wav")
+            elif env == "Cafe":
+                processed = add_noise(processed, "noise_cafe.wav", snr_db=snr_db)
+            elif env == "Street":
+                processed = add_noise(processed, "noise_street.wav", snr_db=snr_db)
+            elif env == "Office":
+                processed = add_noise(processed, "noise_office.wav", snr_db=snr_db)
+            elif env == "Supermarket":
+                processed = add_noise(processed, "noise_supermarket.wav", snr_db=snr_db)
+            elif env == "Phone":
+                processed = bandlimit_phone(processed, sr=SR_OUT)
+        except Exception as e:
+            print(f"Cảnh báo: Không thể áp dụng hiệu ứng môi trường '{env}': {e}")
+            # Continue with clean audio if environment effects fail
+
+        # Create waveform comparison plot
+        fig = plot_waveforms(clean_audio, processed, sr=SR_OUT)
+
+        return (SR_OUT, processed), fig, (SR_OUT, clean_audio)
+
+    except Exception as e:
+        print(f"Lỗi trong quá trình tổng hợp: {e}")
+        return None, None, None
+
+# ---------------------------
+# Examples
+# ---------------------------
+EXAMPLES = [
+    "Xin chào <whisper> tôi nói nhỏ </whisper> rồi <laugh> bật cười </laugh>.",
+    "Tôi cảm thấy <happy> vui </happy> nhưng cũng <sad> buồn </sad>.",
+    "Khi <surprise> bất ngờ </surprise> tôi <shock> hoảng hốt </shock>.",
+]
+
+# ---------------------------
+# Gradio UI
+# ---------------------------
+with gr.Blocks(title="StyleTTS2 Text-to-Speech", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🎙️ StyleTTS2 Text-to-Speech với Hiệu ứng Môi trường")
+
+    # Model status indicator
+    if model_loaded:
+        gr.Markdown("✅ **Model đã sẵn sàng** - Bạn có thể bắt đầu tạo giọng nói!")
+    else:
+        gr.Markdown("❌ **Lỗi tải model** - Vui lòng kiểm tra file giọng nói tham chiếu và khởi động lại.")
+
+    gr.Markdown("Sử dụng StyleTTS2 với khả năng thêm hiệu ứng môi trường và điều chỉnh tốc độ nói.")
+
+    with gr.Accordion("📑 Danh sách Tags + Emoji", open=False):
+        md = "| Tag | Ý nghĩa |\n|-----|----------|\n"
+        for k, v in TAG_LIST.items():
+            md += f"| `<{k}>...</{k}>` | {v} |\n"
+        gr.Markdown(md)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("### ⚙️ Cài đặt")
+
+            text_in = gr.Textbox(
+                value=EXAMPLES[0],
+                label="📝 Văn bản cần chuyển đổi",
+                lines=4,
+                placeholder="Nhập văn bản của bạn ở đây. Sử dụng tags để tạo cảm xúc..."
+            )
+
+            with gr.Row():
+                env_in = gr.Dropdown(
+                    choices=["Neutral", "Church", "Hall", "Cafe", "Street", "Phone", "Office", "Supermarket"],
+                    value="Neutral",
+                    label="🌍 Môi trường âm thanh",
+                    info="Chọn môi trường để áp dụng hiệu ứng"
+                )
+            with gr.Row():
+                speed_slider = gr.Slider(
+                    minimum=0.5,
+                    maximum=2.0,
+                    value=1.0,
+                    step=0.1,
+                    label="⚡ Tốc độ nói",
+                    info="1.0 = bình thường, < 1.0 = chậm, > 1.0 = nhanh"
+                )
+            with gr.Row():
+                snr_slider = gr.Slider(
+                    0, 30,
+                    value=10,
+                    step=1,
+                    label="🔊 Mức độ nhiễu (SNR dB)",
+                    info="Chỉ áp dụng cho môi trường có tiếng ồn. Cao hơn = ít nhiễu hơn"
+                )
+
+            btn = gr.Button("🎵 Tạo giọng nói", variant="primary", size="lg")
+
+            gr.Examples(
+                examples=[[ex] for ex in EXAMPLES],
+                inputs=[text_in],
+                label="💡 Ví dụ nhanh"
+            )
+
+        with gr.Column(scale=1):
+            gr.Markdown("### 🎧 Kết quả")
+
+            audio_out = gr.Audio(
+                label="🎵 Âm thanh có hiệu ứng",
+                type="numpy",
+                info="Phiên bản có áp dụng hiệu ứng môi trường"
+            )
+            clean_out = gr.Audio(
+                label="🎤 Âm thanh gốc",
+                type="numpy",
+                info="Phiên bản gốc không có hiệu ứng"
+            )
+            wave_plot = gr.Plot(
+                label="📊 So sánh dạng sóng",
+                info="Biểu đồ so sánh âm thanh gốc và có hiệu ứng"
+            )
+
+    btn.click(fn=synthesize,
+              inputs=[text_in, env_in, snr_slider, speed_slider],
+              outputs=[audio_out, wave_plot, clean_out])
+
+# Launch the application
+if __name__ == "__main__":
+    try:
+        print("🚀 Đang khởi động ứng dụng StyleTTS2...")
+        demo.launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            share=False,
+            show_error=True
+        )
+    except Exception as e:
+        print(f"❌ Lỗi khi khởi động ứng dụng: {e}")
+        print("Vui lòng kiểm tra lại cấu hình và thử lại.")
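
A few of the behaviours above, sketched as a standalone snippet (illustrative only: it assumes the repository root as working directory, that importing app succeeds and loads the model, and an arbitrary output path):

import re
import numpy as np
import soundfile as sf
from app import TAG_PATTERN, synthesize

# 1) TAG_PATTERN keeps the tags in the split result because of its capturing group,
#    which is what synthesize() relies on to separate tags from speakable text.
print(re.split(TAG_PATTERN, "Xin chào <whisper> tôi nói nhỏ </whisper> rồi."))
# ['Xin chào ', '<whisper>', ' tôi nói nhỏ ', '</whisper>', ' rồi.']

# 2) The scale factor in add_noise() follows from the SNR definition: the scaled
#    noise power equals sig_power / 10**(snr_db / 10), so the mix hits the target SNR.
rng = np.random.default_rng(0)
sig, noise = rng.normal(size=24000), rng.normal(size=24000)
snr_db = 10
scale = np.sqrt(np.mean(sig**2) / (10**(snr_db / 10) * np.mean(noise**2)))
print(10 * np.log10(np.mean(sig**2) / np.mean((noise * scale) ** 2)))  # ~10.0 dB

# 3) synthesize() returns Gradio-ready values: two (sample_rate, array) tuples for the
#    gr.Audio(type="numpy") outputs plus a matplotlib figure. Note that the reverb path
#    uses fftconvolve(..., mode="full"), so a reverberant clip ends up len(ir) - 1
#    samples longer than the clean one.
processed, fig, clean = synthesize("Xin chào <happy> mọi người </happy>", "Cafe", snr_db=15)
if processed is not None:
    sr, audio = processed
    sf.write("cafe_demo.wav", audio, sr)  # hypothetical output path
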
libs/.DS_Store ADDED
Binary file (6.15 kB).
 
libs/Modules/.DS_Store ADDED
Binary file (6.15 kB).
 
libs/Modules/ASR/__init__.py ADDED
@@ -0,0 +1 @@
+
libs/Modules/ASR/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes).
 
libs/Modules/ASR/__pycache__/layers.cpython-310.pyc ADDED
Binary file (11.1 kB).
 
libs/Modules/ASR/__pycache__/models.cpython-310.pyc ADDED
Binary file (6.15 kB).
 
libs/Modules/ASR/layers.py ADDED
@@ -0,0 +1,354 @@
+import math
+import torch
+from torch import nn
+from typing import Optional, Any
+from torch import Tensor
+import torch.nn.functional as F
+import torchaudio
+import torchaudio.functional as audio_F
+
+import random
+random.seed(0)
+
+
+def _get_activation_fn(activ):
+    if activ == 'relu':
+        return nn.ReLU()
+    elif activ == 'lrelu':
+        return nn.LeakyReLU(0.2)
+    elif activ == 'swish':
+        return lambda x: x*torch.sigmoid(x)
+    else:
+        raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
+
+class LinearNorm(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+        super(LinearNorm, self).__init__()
+        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(w_init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
+        super(ConvNorm, self).__init__()
+        if padding is None:
+            assert(kernel_size % 2 == 1)
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                    kernel_size=kernel_size, stride=stride,
+                                    padding=padding, dilation=dilation,
+                                    bias=bias)
+
+        torch.nn.init.xavier_uniform_(
+            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
+
+    def forward(self, signal):
+        conv_signal = self.conv(signal)
+        return conv_signal
+
+class CausualConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
+        super(CausualConv, self).__init__()
+        if padding is None:
+            assert(kernel_size % 2 == 1)
+            self.padding = int(dilation * (kernel_size - 1) / 2) * 2  # stored on self: forward() and the Conv1d below rely on it
+        else:
+            self.padding = padding * 2
+        self.conv = nn.Conv1d(in_channels, out_channels,
+                              kernel_size=kernel_size, stride=stride,
+                              padding=self.padding,
+                              dilation=dilation,
+                              bias=bias)
+
+        torch.nn.init.xavier_uniform_(
+            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = x[:, :, :-self.padding]
+        return x
+
+class CausualBlock(nn.Module):
+    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
+        super(CausualBlock, self).__init__()
+        self.blocks = nn.ModuleList([
+            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
+            for i in range(n_conv)])
+
+    def forward(self, x):
+        for block in self.blocks:
+            res = x
+            x = block(x)
+            x += res
+        return x
+
+    def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
+        layers = [
+            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
+            _get_activation_fn(activ),
+            nn.BatchNorm1d(hidden_dim),
+            nn.Dropout(p=dropout_p),
+            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
+            _get_activation_fn(activ),
+            nn.Dropout(p=dropout_p)
+        ]
+        return nn.Sequential(*layers)
+
+class ConvBlock(nn.Module):
+    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
+        super().__init__()
+        self._n_groups = 8
+        self.blocks = nn.ModuleList([
+            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
+            for i in range(n_conv)])
+
+
+    def forward(self, x):
+        for block in self.blocks:
+            res = x
+            x = block(x)
+            x += res
+        return x
+
+    def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
+        layers = [
+            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
+            _get_activation_fn(activ),
+            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
+            nn.Dropout(p=dropout_p),
+            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
+            _get_activation_fn(activ),
+            nn.Dropout(p=dropout_p)
+        ]
+        return nn.Sequential(*layers)
+
+class LocationLayer(nn.Module):
+    def __init__(self, attention_n_filters, attention_kernel_size,
+                 attention_dim):
+        super(LocationLayer, self).__init__()
+        padding = int((attention_kernel_size - 1) / 2)
+        self.location_conv = ConvNorm(2, attention_n_filters,
+                                      kernel_size=attention_kernel_size,
+                                      padding=padding, bias=False, stride=1,
+                                      dilation=1)
+        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
+                                         bias=False, w_init_gain='tanh')
+
+    def forward(self, attention_weights_cat):
+        processed_attention = self.location_conv(attention_weights_cat)
+        processed_attention = processed_attention.transpose(1, 2)
+        processed_attention = self.location_dense(processed_attention)
+        return processed_attention
+
+
+class Attention(nn.Module):
+    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+                 attention_location_n_filters, attention_location_kernel_size):
+        super(Attention, self).__init__()
+        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
+                                      bias=False, w_init_gain='tanh')
+        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
+                                       w_init_gain='tanh')
+        self.v = LinearNorm(attention_dim, 1, bias=False)
+        self.location_layer = LocationLayer(attention_location_n_filters,
+                                            attention_location_kernel_size,
+                                            attention_dim)
+        self.score_mask_value = -float("inf")
+
+    def get_alignment_energies(self, query, processed_memory,
+                               attention_weights_cat):
+        """
+        PARAMS
+        ------
+        query: decoder output (batch, n_mel_channels * n_frames_per_step)
+        processed_memory: processed encoder outputs (B, T_in, attention_dim)
+        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
+        RETURNS
+        -------
+        alignment (batch, max_time)
+        """
+
+        processed_query = self.query_layer(query.unsqueeze(1))
+        processed_attention_weights = self.location_layer(attention_weights_cat)
+        energies = self.v(torch.tanh(
+            processed_query + processed_attention_weights + processed_memory))
+
+        energies = energies.squeeze(-1)
+        return energies
+
+    def forward(self, attention_hidden_state, memory, processed_memory,
+                attention_weights_cat, mask):
+        """
+        PARAMS
+        ------
+        attention_hidden_state: attention rnn last output
+        memory: encoder outputs
+        processed_memory: processed encoder outputs
+        attention_weights_cat: previous and cumulative attention weights
+        mask: binary mask for padded data
+        """
+        alignment = self.get_alignment_energies(
+            attention_hidden_state, processed_memory, attention_weights_cat)
+
+        if mask is not None:
+            alignment.data.masked_fill_(mask, self.score_mask_value)
+
+        attention_weights = F.softmax(alignment, dim=1)
+        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+        attention_context = attention_context.squeeze(1)
+
+        return attention_context, attention_weights
+
+
+class ForwardAttentionV2(nn.Module):
+    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+                 attention_location_n_filters, attention_location_kernel_size):
+        super(ForwardAttentionV2, self).__init__()
+        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
+                                      bias=False, w_init_gain='tanh')
+        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
+                                       w_init_gain='tanh')
+        self.v = LinearNorm(attention_dim, 1, bias=False)
+        self.location_layer = LocationLayer(attention_location_n_filters,
+                                            attention_location_kernel_size,
+                                            attention_dim)
+        self.score_mask_value = -float(1e20)
+
+    def get_alignment_energies(self, query, processed_memory,
+                               attention_weights_cat):
+        """
+        PARAMS
+        ------
+        query: decoder output (batch, n_mel_channels * n_frames_per_step)
+        processed_memory: processed encoder outputs (B, T_in, attention_dim)
+        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
+        RETURNS
+        -------
+        alignment (batch, max_time)
+        """
+
+        processed_query = self.query_layer(query.unsqueeze(1))
+        processed_attention_weights = self.location_layer(attention_weights_cat)
+        energies = self.v(torch.tanh(
+            processed_query + processed_attention_weights + processed_memory))
+
+        energies = energies.squeeze(-1)
+        return energies
+
+    def forward(self, attention_hidden_state, memory, processed_memory,
+                attention_weights_cat, mask, log_alpha):
+        """
+        PARAMS
+        ------
+        attention_hidden_state: attention rnn last output
+        memory: encoder outputs
+        processed_memory: processed encoder outputs
+        attention_weights_cat: previous and cumulative attention weights
+        mask: binary mask for padded data
+        """
+        log_energy = self.get_alignment_energies(
+            attention_hidden_state, processed_memory, attention_weights_cat)
+
+        #log_energy =
+
+        if mask is not None:
+            log_energy.data.masked_fill_(mask, self.score_mask_value)
+
+        #attention_weights = F.softmax(alignment, dim=1)
+
+        #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
+        #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
+
+        #log_total_score = log_alpha + content_score
+
+        #previous_attention_weights = attention_weights_cat[:,0,:]
+
+        log_alpha_shift_padded = []
+        max_time = log_energy.size(1)
+        for sft in range(2):
+            shifted = log_alpha[:,:max_time-sft]
+            shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
+            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
+
+        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
+
+        log_alpha_new = biased + log_energy
+
+        attention_weights = F.softmax(log_alpha_new, dim=1)
+
+        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+        attention_context = attention_context.squeeze(1)
+
+        return attention_context, attention_weights, log_alpha_new
+
+
+class PhaseShuffle2d(nn.Module):
+    def __init__(self, n=2):
+        super(PhaseShuffle2d, self).__init__()
+        self.n = n
+        self.random = random.Random(1)
+
+    def forward(self, x, move=None):
+        # x.size = (B, C, M, L)
+        if move is None:
+            move = self.random.randint(-self.n, self.n)
+
+        if move == 0:
+            return x
+        else:
+            left = x[:, :, :, :move]
+            right = x[:, :, :, move:]
+            shuffled = torch.cat([right, left], dim=3)
+            return shuffled
+
+class PhaseShuffle1d(nn.Module):
+    def __init__(self, n=2):
+        super(PhaseShuffle1d, self).__init__()
+        self.n = n
+        self.random = random.Random(1)
+
+    def forward(self, x, move=None):
+        # x.size = (B, C, M, L)
+        if move is None:
+            move = self.random.randint(-self.n, self.n)
+
+        if move == 0:
+            return x
+        else:
+            left = x[:, :, :move]
+            right = x[:, :, move:]
+            shuffled = torch.cat([right, left], dim=2)
+
+            return shuffled
+
+class MFCC(nn.Module):
+    def __init__(self, n_mfcc=40, n_mels=80):
+        super(MFCC, self).__init__()
+        self.n_mfcc = n_mfcc
+        self.n_mels = n_mels
+        self.norm = 'ortho'
+        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
+        self.register_buffer('dct_mat', dct_mat)
+
+    def forward(self, mel_specgram):
+        if len(mel_specgram.shape) == 2:
+            mel_specgram = mel_specgram.unsqueeze(0)
+            unsqueezed = True
+        else:
+            unsqueezed = False
+        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
+        # -> (channel, time, n_mfcc).transpose(...)
+        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
+
+        # unpack batch
+        if unsqueezed:
+            mfcc = mfcc.squeeze(0)
+        return mfcc
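
A quick sanity check of the location-sensitive Attention defined above (the dimensions are illustrative, not the ones used elsewhere in this repository):

import torch
from layers import Attention   # assumed import path when run inside libs/Modules/ASR/

att = Attention(attention_rnn_dim=256, embedding_dim=512, attention_dim=128,
                attention_location_n_filters=32, attention_location_kernel_size=31)
B, T = 2, 40
query = torch.randn(B, 256)            # decoder RNN state
memory = torch.randn(B, T, 512)        # encoder outputs
processed = att.memory_layer(memory)   # (B, T, 128)
weights_cat = torch.zeros(B, 2, T)     # previous + cumulative attention weights
ctx, w = att(query, memory, processed, weights_cat, mask=None)
print(ctx.shape, w.shape)              # torch.Size([2, 512]) torch.Size([2, 40])
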
libs/Modules/ASR/models.py ADDED
@@ -0,0 +1,186 @@
+import math
+import torch
+from torch import nn
+from torch.nn import TransformerEncoder
+import torch.nn.functional as F
+from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
+
+class ASRCNN(nn.Module):
+    def __init__(self,
+                 input_dim=80,
+                 hidden_dim=256,
+                 n_token=35,
+                 n_layers=6,
+                 token_embedding_dim=256,
+
+    ):
+        super().__init__()
+        self.n_token = n_token
+        self.n_down = 1
+        self.to_mfcc = MFCC()
+        self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
+        self.cnns = nn.Sequential(
+            *[nn.Sequential(
+                ConvBlock(hidden_dim),
+                nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
+            ) for n in range(n_layers)])
+        self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
+        self.ctc_linear = nn.Sequential(
+            LinearNorm(hidden_dim//2, hidden_dim),
+            nn.ReLU(),
+            LinearNorm(hidden_dim, n_token))
+        self.asr_s2s = ASRS2S(
+            embedding_dim=token_embedding_dim,
+            hidden_dim=hidden_dim//2,
+            n_token=n_token)
+
+    def forward(self, x, src_key_padding_mask=None, text_input=None):
+        x = self.to_mfcc(x)
+        x = self.init_cnn(x)
+        x = self.cnns(x)
+        x = self.projection(x)
+        x = x.transpose(1, 2)
+        ctc_logit = self.ctc_linear(x)
+        if text_input is not None:
+            _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
+            return ctc_logit, s2s_logit, s2s_attn
+        else:
+            return ctc_logit
+
+    def get_feature(self, x):
+        x = self.to_mfcc(x.squeeze(1))
+        x = self.init_cnn(x)
+        x = self.cnns(x)
+        x = self.projection(x)
+        return x
+
+    def length_to_mask(self, lengths):
+        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+        mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
+        return mask
+
+    def get_future_mask(self, out_length, unmask_future_steps=0):
+        """
+        Args:
+            out_length (int): returned mask shape is (out_length, out_length).
+            unmask_future_steps (int): unmasking future step size.
+        Return:
+            mask (torch.BoolTensor): mask future timesteps, mask[i, j] = True if j > i + unmask_future_steps else False
+        """
+        index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
+        mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
+        return mask
+
75
+ def __init__(self,
76
+ embedding_dim=256,
77
+ hidden_dim=512,
78
+ n_location_filters=32,
79
+ location_kernel_size=63,
80
+ n_token=40):
81
+ super(ASRS2S, self).__init__()
82
+ self.embedding = nn.Embedding(n_token, embedding_dim)
83
+ val_range = math.sqrt(6 / hidden_dim)
84
+ self.embedding.weight.data.uniform_(-val_range, val_range)
85
+
86
+ self.decoder_rnn_dim = hidden_dim
87
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
88
+ self.attention_layer = Attention(
89
+ self.decoder_rnn_dim,
90
+ hidden_dim,
91
+ hidden_dim,
92
+ n_location_filters,
93
+ location_kernel_size
94
+ )
95
+ self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
96
+ self.project_to_hidden = nn.Sequential(
97
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
98
+ nn.Tanh())
99
+ self.sos = 1
100
+ self.eos = 2
101
+
102
+ def initialize_decoder_states(self, memory, mask):
103
+ """
104
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
105
+ """
106
+ B, L, H = memory.shape
107
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
108
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
109
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
110
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
111
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
112
+ self.memory = memory
113
+ self.processed_memory = self.attention_layer.memory_layer(memory)
114
+ self.mask = mask
115
+ self.unk_index = 3
116
+ self.random_mask = 0.1
117
+
118
+ def forward(self, memory, memory_mask, text_input):
119
+ """
120
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
121
+ moemory_mask.shape = (B, L, )
122
+ texts_input.shape = (B, T)
123
+ """
124
+ self.initialize_decoder_states(memory, memory_mask)
125
+ # text random mask
126
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
127
+ _text_input = text_input.clone()
128
+ _text_input.masked_fill_(random_mask, self.unk_index)
129
+ decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
130
+ start_embedding = self.embedding(
131
+ torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
132
+ decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
133
+
134
+ hidden_outputs, logit_outputs, alignments = [], [], []
135
+ while len(hidden_outputs) < decoder_inputs.size(0):
136
+
137
+ decoder_input = decoder_inputs[len(hidden_outputs)]
138
+ hidden, logit, attention_weights = self.decode(decoder_input)
139
+ hidden_outputs += [hidden]
140
+ logit_outputs += [logit]
141
+ alignments += [attention_weights]
142
+
143
+ hidden_outputs, logit_outputs, alignments = \
144
+ self.parse_decoder_outputs(
145
+ hidden_outputs, logit_outputs, alignments)
146
+
147
+ return hidden_outputs, logit_outputs, alignments
148
+
149
+
150
+ def decode(self, decoder_input):
151
+
152
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
153
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
154
+ cell_input,
155
+ (self.decoder_hidden, self.decoder_cell))
156
+
157
+ attention_weights_cat = torch.cat(
158
+ (self.attention_weights.unsqueeze(1),
159
+ self.attention_weights_cum.unsqueeze(1)),dim=1)
160
+
161
+ self.attention_context, self.attention_weights = self.attention_layer(
162
+ self.decoder_hidden,
163
+ self.memory,
164
+ self.processed_memory,
165
+ attention_weights_cat,
166
+ self.mask)
167
+
168
+ self.attention_weights_cum += self.attention_weights
169
+
170
+ hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
171
+ hidden = self.project_to_hidden(hidden_and_context)
172
+
173
+ # dropout to increasing g
174
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
175
+
176
+ return hidden, logit, self.attention_weights
177
+
178
+ def parse_decoder_outputs(self, hidden, logit, alignments):
179
+
180
+ # -> [B, T_out + 1, max_time]
181
+ alignments = torch.stack(alignments).transpose(0,1)
182
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
183
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
184
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
185
+
186
+ return hidden, logit, alignments
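
get_future_mask above builds a standard causal mask; mirroring its two lines for out_length=4 and unmask_future_steps=0 shows which future positions get masked:

import torch

out_length, unmask_future_steps = 4, 0
index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
print(mask.int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
# i.e. position j is masked for query i whenever j > i + unmask_future_steps.
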
libs/Modules/JDC/__init__.py ADDED
@@ -0,0 +1 @@
+
libs/Modules/JDC/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes).
 
libs/Modules/JDC/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.81 kB).
 
libs/Modules/JDC/model.py ADDED
@@ -0,0 +1,190 @@
+"""
+Implementation of model from:
+Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
+Convolutional Recurrent Neural Networks" (2019)
+Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
+"""
+import torch
+from torch import nn
+
+class JDCNet(nn.Module):
+    """
+    Joint Detection and Classification Network model for singing voice melody.
+    """
+    def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
+        super().__init__()
+        self.num_class = num_class
+
+        # input = (b, 1, 31, 513), b = batch size
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False),  # out: (b, 64, 31, 513)
+            nn.BatchNorm2d(num_features=64),
+            nn.LeakyReLU(leaky_relu_slope, inplace=True),
+            nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
+        )
+
+        # res blocks
+        self.res_block1 = ResBlock(in_channels=64, out_channels=128)  # (b, 128, 31, 128)
+        self.res_block2 = ResBlock(in_channels=128, out_channels=192)  # (b, 192, 31, 32)
+        self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)
+
+        # pool block
+        self.pool_block = nn.Sequential(
+            nn.BatchNorm2d(num_features=256),
+            nn.LeakyReLU(leaky_relu_slope, inplace=True),
+            nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
+            nn.Dropout(p=0.2),
+        )
+
+        # maxpool layers (for auxiliary network inputs)
+        # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
+        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
+        # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
+        # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
+        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
+
+        # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
+        self.detector_conv = nn.Sequential(
+            nn.Conv2d(640, 256, 1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(leaky_relu_slope, inplace=True),
+            nn.Dropout(p=0.2),
+        )
+
+        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+        self.bilstm_classifier = nn.LSTM(
+            input_size=512, hidden_size=256,
+            batch_first=True, bidirectional=True)  # (b, 31, 512)
+
+        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+        self.bilstm_detector = nn.LSTM(
+            input_size=512, hidden_size=256,
+            batch_first=True, bidirectional=True)  # (b, 31, 512)
+
+        # input: (b * 31, 512)
+        self.classifier = nn.Linear(in_features=512, out_features=self.num_class)  # (b * 31, num_class)
+
+        # input: (b * 31, 512)
+        self.detector = nn.Linear(in_features=512, out_features=2)  # (b * 31, 2) - binary classifier
+
+        # initialize weights
+        self.apply(self.init_weights)
+
+    def get_feature_GAN(self, x):
+        seq_len = x.shape[-2]
+        x = x.float().transpose(-1, -2)
+
+        convblock_out = self.conv_block(x)
+
+        resblock1_out = self.res_block1(convblock_out)
+        resblock2_out = self.res_block2(resblock1_out)
+        resblock3_out = self.res_block3(resblock2_out)
+        poolblock_out = self.pool_block[0](resblock3_out)
+        poolblock_out = self.pool_block[1](poolblock_out)
+
+        return poolblock_out.transpose(-1, -2)
+
+    def get_feature(self, x):
+        seq_len = x.shape[-2]
+        x = x.float().transpose(-1, -2)
+
+        convblock_out = self.conv_block(x)
+
+        resblock1_out = self.res_block1(convblock_out)
+        resblock2_out = self.res_block2(resblock1_out)
+        resblock3_out = self.res_block3(resblock2_out)
+        poolblock_out = self.pool_block[0](resblock3_out)
+        poolblock_out = self.pool_block[1](poolblock_out)
+
+        return self.pool_block[2](poolblock_out)
+
+    def forward(self, x):
+        """
+        Returns:
+            classification_prediction, detection_prediction
+            sizes: (b, 31, 722), (b, 31, 2)
+        """
+        ###############################
+        # forward pass for classifier #
+        ###############################
+        seq_len = x.shape[-1]
+        x = x.float().transpose(-1, -2)
+
+        convblock_out = self.conv_block(x)
+
+        resblock1_out = self.res_block1(convblock_out)
+        resblock2_out = self.res_block2(resblock1_out)
+        resblock3_out = self.res_block3(resblock2_out)
+
+
+        poolblock_out = self.pool_block[0](resblock3_out)
+        poolblock_out = self.pool_block[1](poolblock_out)
+        GAN_feature = poolblock_out.transpose(-1, -2)
+        poolblock_out = self.pool_block[2](poolblock_out)
+
+        # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+        classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
+        classifier_out, _ = self.bilstm_classifier(classifier_out)  # ignore the hidden states
+
+        classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
+        classifier_out = self.classifier(classifier_out)
+        classifier_out = classifier_out.view((-1, seq_len, self.num_class))  # (b, 31, num_class)
+
+        # sizes: (b, 31, 722), (b, 31, 2)
+        # classifier output consists of predicted pitch classes per frame
+        # detector output consists of: (isvoice, notvoice) estimates per frame
+        return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
+
+    @staticmethod
+    def init_weights(m):
+        if isinstance(m, nn.Linear):
+            nn.init.kaiming_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.Conv2d):
+            nn.init.xavier_normal_(m.weight)
+        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
+            for p in m.parameters():
+                if p.data is None:
+                    continue
+
+                if len(p.shape) >= 2:
+                    nn.init.orthogonal_(p.data)
+                else:
+                    nn.init.normal_(p.data)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
+        super().__init__()
+        self.downsample = in_channels != out_channels
+
+        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
+        self.pre_conv = nn.Sequential(
+            nn.BatchNorm2d(num_features=in_channels),
+            nn.LeakyReLU(leaky_relu_slope, inplace=True),
+            nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
+        )
+
+        # conv layers
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+                      kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(leaky_relu_slope, inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+        )
+
+        # 1 x 1 convolution layer to match the feature dimensions
+        self.conv1by1 = None
+        if self.downsample:
+            self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+
+    def forward(self, x):
+        x = self.pre_conv(x)
+        if self.downsample:
+            x = self.conv(x) + self.conv1by1(x)
+        else:
+            x = self.conv(x) + x
+        return x
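
A rough shape walk-through of JDCNet. The (b, 1, 31, 513) comments above are inherited from the original JDC repository; the layout assumed here, an 80-bin mel spectrogram shaped (batch, 1, n_mels, frames), is only an assumption about how the model is fed:

import torch
from model import JDCNet   # assumed import path when run inside libs/Modules/JDC/

net = JDCNet(num_class=722)
mel = torch.randn(1, 1, 80, 192)   # assumed layout: (batch, 1, n_mels, frames)
pitch, gan_feature, pooled = net(mel)
print(pitch.shape)        # torch.Size([192, 722]) - batch dim removed by the final squeeze()
print(gan_feature.shape)  # torch.Size([1, 256, 10, 192])
print(pooled.shape)       # torch.Size([1, 256, 192, 2])
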
libs/Modules/__init__.py ADDED
@@ -0,0 +1 @@
+
libs/Modules/discriminators.py ADDED
@@ -0,0 +1,188 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, spectral_norm
+
+from .utils import get_padding
+
+LRELU_SLOPE = 0.1
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
+                        return_complex=True)
+    real = x_stft[..., 0]
+    imag = x_stft[..., 1]
+
+    return torch.abs(x_stft).transpose(2, 1)
+
+class SpecDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+        super(SpecDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.discriminators = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
+        ])
+
+        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+    def forward(self, y):
+
+        fmap = []
+        y = y.squeeze(1)
+        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.get_device()))
+        y = y.unsqueeze(1)
+        for i, d in enumerate(self.discriminators):
+            y = d(y)
+            y = F.leaky_relu(y, LRELU_SLOPE)
+            fmap.append(y)
+
+        y = self.out(y)
+        fmap.append(y)
+
+        return torch.flatten(y, 1, -1), fmap
+
+class MultiResSpecDiscriminator(torch.nn.Module):
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window"):
+
+        super(MultiResSpecDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+            SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+            SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+ class MultiPeriodDiscriminator(torch.nn.Module):
133
+ def __init__(self):
134
+ super(MultiPeriodDiscriminator, self).__init__()
135
+ self.discriminators = nn.ModuleList([
136
+ DiscriminatorP(2),
137
+ DiscriminatorP(3),
138
+ DiscriminatorP(5),
139
+ DiscriminatorP(7),
140
+ DiscriminatorP(11),
141
+ ])
142
+
143
+ def forward(self, y, y_hat):
144
+ y_d_rs = []
145
+ y_d_gs = []
146
+ fmap_rs = []
147
+ fmap_gs = []
148
+ for i, d in enumerate(self.discriminators):
149
+ y_d_r, fmap_r = d(y)
150
+ y_d_g, fmap_g = d(y_hat)
151
+ y_d_rs.append(y_d_r)
152
+ fmap_rs.append(fmap_r)
153
+ y_d_gs.append(y_d_g)
154
+ fmap_gs.append(fmap_g)
155
+
156
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
157
+
158
+ class WavLMDiscriminator(nn.Module):
159
+ """docstring for Discriminator."""
160
+
161
+ def __init__(self, slm_hidden=768,
162
+ slm_layers=13,
163
+ initial_channel=64,
164
+ use_spectral_norm=False):
165
+ super(WavLMDiscriminator, self).__init__()
166
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
167
+ self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))
168
+
169
+ self.convs = nn.ModuleList([
170
+ norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
171
+ norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
172
+ norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
173
+ ])
174
+
175
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
176
+
177
+ def forward(self, x):
178
+ x = self.pre(x)
179
+
180
+ fmap = []
181
+ for l in self.convs:
182
+ x = l(x)
183
+ x = F.leaky_relu(x, LRELU_SLOPE)
184
+ fmap.append(x)
185
+ x = self.conv_post(x)
186
+ x = torch.flatten(x, 1, -1)
187
+
188
+ return x
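
Two small notes on the discriminators above: SpecDiscriminator scores STFT magnitudes (the real/imag assignments in stft() are unused leftovers), and DiscriminatorP folds the waveform into a (t/period, period) grid before its (k, 1) convolutions, padding first so the fold is exact. The fold on its own (values illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 101)                # (batch, channels, samples)
period = 2
n_pad = period - (x.shape[-1] % period)   # pad to a multiple of the period (reflect)
x = F.pad(x, (0, n_pad), "reflect")
x = x.view(1, 1, x.shape[-1] // period, period)
print(x.shape)                            # torch.Size([1, 1, 51, 2])
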
libs/Modules/hifigan.py ADDED
@@ -0,0 +1,477 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from .utils import init_weights, get_padding
+
+import math
+import random
+import numpy as np
+
+LRELU_SLOPE = 0.1
+
+class AdaIN1d(nn.Module):
+    def __init__(self, style_dim, num_features):
+        super().__init__()
+        self.norm = nn.InstanceNorm1d(num_features, affine=False)
+        self.fc = nn.Linear(style_dim, num_features*2)
+
+    def forward(self, x, s):
+        h = self.fc(s)
+        h = h.view(h.size(0), h.size(1), 1)
+        gamma, beta = torch.chunk(h, chunks=2, dim=1)
+        return (1 + gamma) * self.norm(x) + beta
+
+ class AdaINResBlock1(torch.nn.Module):
27
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
28
+ super(AdaINResBlock1, self).__init__()
29
+ self.convs1 = nn.ModuleList([
30
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
31
+ padding=get_padding(kernel_size, dilation[0]))),
32
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
33
+ padding=get_padding(kernel_size, dilation[1]))),
34
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
35
+ padding=get_padding(kernel_size, dilation[2])))
36
+ ])
37
+ self.convs1.apply(init_weights)
38
+
39
+ self.convs2 = nn.ModuleList([
40
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
41
+ padding=get_padding(kernel_size, 1))),
42
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
43
+ padding=get_padding(kernel_size, 1))),
44
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
45
+ padding=get_padding(kernel_size, 1)))
46
+ ])
47
+ self.convs2.apply(init_weights)
48
+
49
+ self.adain1 = nn.ModuleList([
50
+ AdaIN1d(style_dim, channels),
51
+ AdaIN1d(style_dim, channels),
52
+ AdaIN1d(style_dim, channels),
53
+ ])
54
+
55
+ self.adain2 = nn.ModuleList([
56
+ AdaIN1d(style_dim, channels),
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ ])
60
+
61
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
62
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
63
+
64
+
65
+ def forward(self, x, s):
66
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
67
+ xt = n1(x, s)
68
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
69
+ xt = c1(xt)
70
+ xt = n2(xt, s)
71
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
72
+ xt = c2(xt)
73
+ x = xt + x
74
+ return x
75
+
76
+ def remove_weight_norm(self):
77
+ for l in self.convs1:
78
+ remove_weight_norm(l)
79
+ for l in self.convs2:
80
+ remove_weight_norm(l)
81
+
82
+ class SineGen(torch.nn.Module):
83
+ """ Definition of sine generator
84
+ SineGen(samp_rate, harmonic_num = 0,
85
+ sine_amp = 0.1, noise_std = 0.003,
86
+ voiced_threshold = 0,
87
+ flag_for_pulse=False)
88
+ samp_rate: sampling rate in Hz
89
+ harmonic_num: number of harmonic overtones (default 0)
90
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
91
+ noise_std: std of Gaussian noise (default 0.003)
92
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
93
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
94
+ Note: when flag_for_pulse is True, the first time step of a voiced
95
+ segment is always sin(np.pi) or cos(0)
96
+ """
97
+
98
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
99
+ sine_amp=0.1, noise_std=0.003,
100
+ voiced_threshold=0,
101
+ flag_for_pulse=False):
102
+ super(SineGen, self).__init__()
103
+ self.sine_amp = sine_amp
104
+ self.noise_std = noise_std
105
+ self.harmonic_num = harmonic_num
106
+ self.dim = self.harmonic_num + 1
107
+ self.sampling_rate = samp_rate
108
+ self.voiced_threshold = voiced_threshold
109
+ self.flag_for_pulse = flag_for_pulse
110
+ self.upsample_scale = upsample_scale
111
+
112
+ def _f02uv(self, f0):
113
+ # generate uv signal
114
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
115
+ return uv
116
+
117
+ def _f02sine(self, f0_values):
118
+ """ f0_values: (batchsize, length, dim)
119
+ where dim indicates fundamental tone and overtones
120
+ """
121
+ # convert to F0 in rad. The interger part n can be ignored
122
+ # because 2 * np.pi * n doesn't affect phase
123
+ rad_values = (f0_values / self.sampling_rate) % 1
124
+
125
+ # initial phase noise (no noise for fundamental component)
126
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
127
+ device=f0_values.device)
128
+ rand_ini[:, 0] = 0
129
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
130
+
131
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
132
+ if not self.flag_for_pulse:
133
+ # # for normal case
134
+
135
+ # # To prevent torch.cumsum numerical overflow,
136
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
137
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
138
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
139
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
140
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
141
+ # cumsum_shift = torch.zeros_like(rad_values)
142
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
143
+
144
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
145
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
146
+ scale_factor=1/self.upsample_scale,
147
+ mode="linear").transpose(1, 2)
148
+
149
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
150
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
151
+ # cumsum_shift = torch.zeros_like(rad_values)
152
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
153
+
154
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
155
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
156
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
157
+ sines = torch.sin(phase)
158
+
159
+ else:
160
+ # If necessary, make sure that the first time step of every
161
+ # voiced segments is sin(pi) or cos(0)
162
+ # This is used for pulse-train generation
163
+
164
+ # identify the last time step in unvoiced segments
165
+ uv = self._f02uv(f0_values)
166
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
167
+ uv_1[:, -1, :] = 1
168
+ u_loc = (uv < 1) * (uv_1 > 0)
169
+
170
+ # get the instantanouse phase
171
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
172
+ # different batch needs to be processed differently
173
+ for idx in range(f0_values.shape[0]):
174
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
175
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
176
+ # stores the accumulation of i.phase within
177
+ # each voiced segments
178
+ tmp_cumsum[idx, :, :] = 0
179
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
180
+
181
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
182
+ # within the previous voiced segment.
183
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
184
+
185
+ # get the sines
186
+ sines = torch.cos(i_phase * 2 * np.pi)
187
+ return sines
188
+
189
+ def forward(self, f0):
190
+ """ sine_tensor, uv = forward(f0)
191
+ input F0: tensor(batchsize=1, length, dim=1)
192
+ f0 for unvoiced steps should be 0
193
+ output sine_tensor: tensor(batchsize=1, length, dim)
194
+ output uv: tensor(batchsize=1, length, 1)
195
+ """
196
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
197
+ device=f0.device)
198
+ # fundamental component
199
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
200
+
201
+ # generate sine waveforms
202
+ sine_waves = self._f02sine(fn) * self.sine_amp
203
+
204
+ # generate uv signal
205
+ # uv = torch.ones(f0.shape)
206
+ # uv = uv * (f0 > self.voiced_threshold)
207
+ uv = self._f02uv(f0)
208
+
209
+ # noise: for unvoiced should be similar to sine_amp
210
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
211
+ # for voiced regions the std is self.noise_std
212
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
213
+ noise = noise_amp * torch.randn_like(sine_waves)
214
+
215
+ # first: set the unvoiced part to 0 by uv
216
+ # then: additive noise
217
+ sine_waves = sine_waves * uv + noise
218
+ return sine_waves, uv, noise
219
+
220
+
221
+ class SourceModuleHnNSF(torch.nn.Module):
222
+ """ SourceModule for hn-nsf
223
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
224
+ add_noise_std=0.003, voiced_threshod=0)
225
+ sampling_rate: sampling_rate in Hz
226
+ harmonic_num: number of harmonic above F0 (default: 0)
227
+ sine_amp: amplitude of sine source signal (default: 0.1)
228
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
229
+ note that amplitude of noise in unvoiced is decided
230
+ by sine_amp
231
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
232
+ Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
233
+ F0_sampled (batchsize, length, 1)
234
+ Sine_source (batchsize, length, 1)
235
+ noise_source (batchsize, length, 1)
236
+ uv (batchsize, length, 1)
237
+ """
238
+
239
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
240
+ add_noise_std=0.003, voiced_threshod=0):
241
+ super(SourceModuleHnNSF, self).__init__()
242
+
243
+ self.sine_amp = sine_amp
244
+ self.noise_std = add_noise_std
245
+
246
+ # to produce sine waveforms
247
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
248
+ sine_amp, add_noise_std, voiced_threshod)
249
+
250
+ # to merge source harmonics into a single excitation
251
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
252
+ self.l_tanh = torch.nn.Tanh()
253
+
254
+ def forward(self, x):
255
+ """
256
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
257
+ F0_sampled (batchsize, length, 1)
258
+ Sine_source (batchsize, length, 1)
259
+ noise_source (batchsize, length, 1)
260
+ """
261
+ # source for harmonic branch
262
+ with torch.no_grad():
263
+ sine_wavs, uv, _ = self.l_sin_gen(x)
264
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
265
+
266
+ # source for noise branch, in the same shape as uv
267
+ noise = torch.randn_like(uv) * self.sine_amp / 3
268
+ return sine_merge, noise, uv
269
+ def padDiff(x):
270
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
271
+
272
+ class Generator(torch.nn.Module):
273
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
274
+ super(Generator, self).__init__()
275
+ self.num_kernels = len(resblock_kernel_sizes)
276
+ self.num_upsamples = len(upsample_rates)
277
+ resblock = AdaINResBlock1
278
+
279
+ self.m_source = SourceModuleHnNSF(
280
+ sampling_rate=24000,
281
+ upsample_scale=np.prod(upsample_rates),
282
+ harmonic_num=8, voiced_threshod=10)
283
+
284
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
285
+ self.noise_convs = nn.ModuleList()
286
+ self.ups = nn.ModuleList()
287
+ self.noise_res = nn.ModuleList()
288
+
289
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
290
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
291
+
292
+ self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel//(2**i),
293
+ upsample_initial_channel//(2**(i+1)),
294
+ k, u, padding=(u//2 + u%2), output_padding=u%2)))
295
+
296
+ if i + 1 < len(upsample_rates): #
297
+ stride_f0 = np.prod(upsample_rates[i + 1:])
298
+ self.noise_convs.append(Conv1d(
299
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
300
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
301
+ else:
302
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
303
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
304
+
305
+ self.resblocks = nn.ModuleList()
306
+
307
+ self.alphas = nn.ParameterList()
308
+ self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
309
+
310
+ for i in range(len(self.ups)):
311
+ ch = upsample_initial_channel//(2**(i+1))
312
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
313
+
314
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
315
+ self.resblocks.append(resblock(ch, k, d, style_dim))
316
+
317
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
318
+ self.ups.apply(init_weights)
319
+ self.conv_post.apply(init_weights)
320
+
321
+ def forward(self, x, s, f0):
322
+
323
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
324
+
325
+ har_source, noi_source, uv = self.m_source(f0)
326
+ har_source = har_source.transpose(1, 2)
327
+
328
+ for i in range(self.num_upsamples):
329
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
330
+ x_source = self.noise_convs[i](har_source)
331
+ x_source = self.noise_res[i](x_source, s)
332
+
333
+ x = self.ups[i](x)
334
+ x = x + x_source
335
+
336
+ xs = None
337
+ for j in range(self.num_kernels):
338
+ if xs is None:
339
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
340
+ else:
341
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
342
+ x = xs / self.num_kernels
343
+ x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
344
+ x = self.conv_post(x)
345
+ x = torch.tanh(x)
346
+
347
+ return x
348
+
349
+ def remove_weight_norm(self):
350
+ print('Removing weight norm...')
351
+ for l in self.ups:
352
+ remove_weight_norm(l)
353
+ for l in self.resblocks:
354
+ l.remove_weight_norm()
355
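+ # NOTE: self.conv_pre is never defined in this Generator, so the next line would raise AttributeError if remove_weight_norm() were called.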
+ remove_weight_norm(self.conv_pre)
356
+ remove_weight_norm(self.conv_post)
357
+
358
+
359
+ class AdainResBlk1d(nn.Module):
360
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
361
+ upsample='none', dropout_p=0.0):
362
+ super().__init__()
363
+ self.actv = actv
364
+ self.upsample_type = upsample
365
+ self.upsample = UpSample1d(upsample)
366
+ self.learned_sc = dim_in != dim_out
367
+ self._build_weights(dim_in, dim_out, style_dim)
368
+ self.dropout = nn.Dropout(dropout_p)
369
+
370
+ if upsample == 'none':
371
+ self.pool = nn.Identity()
372
+ else:
373
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
374
+
375
+
376
+ def _build_weights(self, dim_in, dim_out, style_dim):
377
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
378
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
379
+ self.norm1 = AdaIN1d(style_dim, dim_in)
380
+ self.norm2 = AdaIN1d(style_dim, dim_out)
381
+ if self.learned_sc:
382
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
383
+
384
+ def _shortcut(self, x):
385
+ x = self.upsample(x)
386
+ if self.learned_sc:
387
+ x = self.conv1x1(x)
388
+ return x
389
+
390
+ def _residual(self, x, s):
391
+ x = self.norm1(x, s)
392
+ x = self.actv(x)
393
+ x = self.pool(x)
394
+ x = self.conv1(self.dropout(x))
395
+ x = self.norm2(x, s)
396
+ x = self.actv(x)
397
+ x = self.conv2(self.dropout(x))
398
+ return x
399
+
400
+ def forward(self, x, s):
401
+ out = self._residual(x, s)
402
+ out = (out + self._shortcut(x)) / math.sqrt(2)
403
+ return out
404
+
405
+ class UpSample1d(nn.Module):
406
+ def __init__(self, layer_type):
407
+ super().__init__()
408
+ self.layer_type = layer_type
409
+
410
+ def forward(self, x):
411
+ if self.layer_type == 'none':
412
+ return x
413
+ else:
414
+ return F.interpolate(x, scale_factor=2, mode='nearest')
415
+
416
+ class Decoder(nn.Module):
417
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
418
+ resblock_kernel_sizes = [3,7,11],
419
+ upsample_rates = [10,5,3,2],
420
+ upsample_initial_channel=512,
421
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
422
+ upsample_kernel_sizes=[20,10,6,4]):
423
+ super().__init__()
424
+
425
+ self.decode = nn.ModuleList()
426
+
427
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
428
+
429
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
430
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
431
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
432
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
433
+
434
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
435
+
436
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
437
+
438
+ self.asr_res = nn.Sequential(
439
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
440
+ )
441
+
442
+
443
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
444
+
445
+
446
+ def forward(self, asr, F0_curve, N, s):
447
+ if self.training:
448
+ downlist = [0, 3, 7]
449
+ F0_down = downlist[random.randint(0, 2)]
450
+ downlist = [0, 3, 7, 15]
451
+ N_down = downlist[random.randint(0, 3)]
452
+ if F0_down:
453
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to(asr.device), padding=F0_down//2).squeeze(1) / F0_down
454
+ if N_down:
455
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to(asr.device), padding=N_down//2).squeeze(1) / N_down
456
+
457
+
458
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
459
+ N = self.N_conv(N.unsqueeze(1))
460
+
461
+ x = torch.cat([asr, F0, N], axis=1)
462
+ x = self.encode(x, s)
463
+
464
+ asr_res = self.asr_res(asr)
465
+
466
+ res = True
467
+ for block in self.decode:
468
+ if res:
469
+ x = torch.cat([x, asr_res, F0, N], axis=1)
470
+ x = block(x, s)
471
+ if block.upsample_type != "none":
472
+ res = False
473
+
474
+ x = self.generator(x, s, F0_curve)
475
+ return x
476
+
477
+
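The Decoder above fuses the aligned text features (asr), the predicted F0 and energy (N) curves, and the style vector s, then hands the result to the HiFi-GAN-style Generator with its harmonic-plus-noise source. A minimal sketch of how it could be driven, assuming the default constructor arguments; the import path, tensor sizes, and variable names are illustrative only:

    import torch
    # from libs.Modules.hifigan import Decoder  # adjust the import to wherever this module lives

    decoder = Decoder(dim_in=512, style_dim=64, dim_out=80)
    decoder.eval()                  # skip the training-only F0/N smoothing branch

    asr = torch.randn(1, 512, 80)   # aligned text-encoder features (B, dim_in, frames)
    F0  = torch.randn(1, 160)       # F0 curve at twice the frame rate (B, 2 * frames)
    N   = torch.randn(1, 160)       # energy curve, same length as F0
    s   = torch.randn(1, 64)        # style vector (B, style_dim)

    with torch.no_grad():
        wav = decoder(asr, F0, N, s)  # (B, 1, samples) at 24 kHz
    print(wav.shape)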
libs/Modules/istftnet.py ADDED
@@ -0,0 +1,723 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+ from scipy.signal import get_window
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+ class AdaIN1d(nn.Module):
16
+ def __init__(self, style_dim, num_features):
17
+ super().__init__()
18
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
19
+ self.fc = nn.Linear(style_dim, num_features*2)
20
+
21
+ def forward(self, x, s):
22
+ h = self.fc(s)
23
+ h = h.view(h.size(0), h.size(1), 1)
24
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
25
+ return (1 + gamma) * self.norm(x) + beta
26
+
27
+ class AdaINResBlock1(torch.nn.Module):
28
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
29
+ super(AdaINResBlock1, self).__init__()
30
+ self.convs1 = nn.ModuleList([
31
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
32
+ padding=get_padding(kernel_size, dilation[0]))),
33
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
34
+ padding=get_padding(kernel_size, dilation[1]))),
35
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
36
+ padding=get_padding(kernel_size, dilation[2])))
37
+ ])
38
+ self.convs1.apply(init_weights)
39
+
40
+ self.convs2 = nn.ModuleList([
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
42
+ padding=get_padding(kernel_size, 1))),
43
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
44
+ padding=get_padding(kernel_size, 1))),
45
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46
+ padding=get_padding(kernel_size, 1)))
47
+ ])
48
+ self.convs2.apply(init_weights)
49
+
50
+ self.adain1 = nn.ModuleList([
51
+ AdaIN1d(style_dim, channels),
52
+ AdaIN1d(style_dim, channels),
53
+ AdaIN1d(style_dim, channels),
54
+ ])
55
+
56
+ self.adain2 = nn.ModuleList([
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ AdaIN1d(style_dim, channels),
60
+ ])
61
+
62
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
63
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
64
+
65
+
66
+ def forward(self, x, s):
67
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
68
+ xt = n1(x, s)
69
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
70
+ xt = c1(xt)
71
+ xt = n2(xt, s)
72
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
73
+ xt = c2(xt)
74
+ x = xt + x
75
+ return x
76
+
77
+ def remove_weight_norm(self):
78
+ for l in self.convs1:
79
+ remove_weight_norm(l)
80
+ for l in self.convs2:
81
+ remove_weight_norm(l)
82
+
83
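The two `# Snake1D` lines in AdaINResBlock1 apply the Snake periodic activation, snake(x) = x + (1/alpha) * sin(alpha * x)^2, with a learnable per-channel alpha. A tiny self-contained sketch of that activation (illustrative only, not part of this file):

    import torch

    def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T); alpha: (1, C, 1) learnable per-channel frequency
        return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2

    x = torch.randn(2, 4, 16)
    alpha = torch.ones(1, 4, 1)
    print(snake(x, alpha).shape)  # torch.Size([2, 4, 16])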
+ class TorchSTFT(torch.nn.Module):
84
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
85
+ super().__init__()
86
+ self.filter_length = filter_length
87
+ self.hop_length = hop_length
88
+ self.win_length = win_length
89
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
90
+
91
+ def transform(self, input_data):
92
+ forward_transform = torch.stft(
93
+ input_data,
94
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
95
+ return_complex=True)
96
+
97
+ return torch.abs(forward_transform), torch.angle(forward_transform)
98
+
99
+ def inverse(self, magnitude, phase):
100
+ inverse_transform = torch.istft(
101
+ magnitude * torch.exp(phase * 1j),
102
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
103
+
104
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
105
+
106
+ def forward(self, input_data):
107
+ self.magnitude, self.phase = self.transform(input_data)
108
+ reconstruction = self.inverse(self.magnitude, self.phase)
109
+ return reconstruction
110
+
111
+ class CustomSTFT(nn.Module):
112
+ """
113
+ STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.
114
+
115
+ - forward STFT => Real-part conv1d + Imag-part conv1d
116
+ - inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum
117
+ - avoids F.unfold, so easier to export to ONNX
118
+ - uses replicate or constant padding for 'center=True' to approximate 'reflect'
119
+ (reflect is not supported for dynamic shapes in ONNX)
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ filter_length=800,
125
+ hop_length=200,
126
+ win_length=800,
127
+ window="hann",
128
+ center=True,
129
+ pad_mode="replicate", # or 'constant'
130
+ ):
131
+ super().__init__()
132
+ self.filter_length = filter_length
133
+ self.hop_length = hop_length
134
+ self.win_length = win_length
135
+ self.n_fft = filter_length
136
+ self.center = center
137
+ self.pad_mode = pad_mode
138
+
139
+ # Number of frequency bins for real-valued STFT with onesided=True
140
+ self.freq_bins = self.n_fft // 2 + 1
141
+
142
+ # Build window
143
+ assert window == 'hann', window
144
+ window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
145
+ if self.win_length < self.n_fft:
146
+ # Zero-pad up to n_fft
147
+ extra = self.n_fft - self.win_length
148
+ window_tensor = F.pad(window_tensor, (0, extra))
149
+ elif self.win_length > self.n_fft:
150
+ window_tensor = window_tensor[: self.n_fft]
151
+ self.register_buffer("window", window_tensor)
152
+
153
+ # Precompute forward DFT (real, imag)
154
+ # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...)
155
+ n = np.arange(self.n_fft)
156
+ k = np.arange(self.freq_bins)
157
+ angle = 2 * np.pi * np.outer(k, n) / self.n_fft # shape (freq_bins, n_fft)
158
+ dft_real = np.cos(angle)
159
+ dft_imag = -np.sin(angle) # note negative sign
160
+
161
+ # Combine window and dft => shape (freq_bins, filter_length)
162
+ # We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length).
163
+ forward_window = window_tensor.numpy() # shape (n_fft,)
164
+ forward_real = dft_real * forward_window # (freq_bins, n_fft)
165
+ forward_imag = dft_imag * forward_window
166
+
167
+ # Convert to PyTorch
168
+ forward_real_torch = torch.from_numpy(forward_real).float()
169
+ forward_imag_torch = torch.from_numpy(forward_imag).float()
170
+
171
+ # Register as Conv1d weight => (out_channels, in_channels, kernel_size)
172
+ # out_channels = freq_bins, in_channels=1, kernel_size=n_fft
173
+ self.register_buffer(
174
+ "weight_forward_real", forward_real_torch.unsqueeze(1)
175
+ )
176
+ self.register_buffer(
177
+ "weight_forward_imag", forward_imag_torch.unsqueeze(1)
178
+ )
179
+
180
+ # Precompute inverse DFT
181
+ # Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc.
182
+ # For simplicity, we won't do the "DC/nyquist not doubled" approach here.
183
+ # If you want perfect real iSTFT, you can add that logic.
184
+ # This version just yields good approximate reconstruction with Hann + typical overlap.
185
+ inv_scale = 1.0 / self.n_fft
186
+ n = np.arange(self.n_fft)
187
+ angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft # shape (n_fft, freq_bins)
188
+ idft_cos = np.cos(angle_t).T # => (freq_bins, n_fft)
189
+ idft_sin = np.sin(angle_t).T # => (freq_bins, n_fft)
190
+
191
+ # Multiply by window again for typical overlap-add
192
+ # We also incorporate the scale factor 1/n_fft
193
+ inv_window = window_tensor.numpy() * inv_scale
194
+ backward_real = idft_cos * inv_window # (freq_bins, n_fft)
195
+ backward_imag = idft_sin * inv_window
196
+
197
+ # We'll implement iSTFT as real+imag conv_transpose with stride=hop.
198
+ self.register_buffer(
199
+ "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1)
200
+ )
201
+ self.register_buffer(
202
+ "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1)
203
+ )
204
+
205
+
206
+
207
+ def transform(self, waveform: torch.Tensor):
208
+ """
209
+ Forward STFT => returns magnitude, phase
210
+ Output shape => (batch, freq_bins, frames)
211
+ """
212
+ # waveform shape => (B, T). conv1d expects (B, 1, T).
213
+ # Optional center pad
214
+ if self.center:
215
+ pad_len = self.n_fft // 2
216
+ waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)
217
+
218
+ x = waveform.unsqueeze(1) # => (B, 1, T)
219
+ # Convolution to get real part => shape (B, freq_bins, frames)
220
+ real_out = F.conv1d(
221
+ x,
222
+ self.weight_forward_real,
223
+ bias=None,
224
+ stride=self.hop_length,
225
+ padding=0,
226
+ )
227
+ # Imag part
228
+ imag_out = F.conv1d(
229
+ x,
230
+ self.weight_forward_imag,
231
+ bias=None,
232
+ stride=self.hop_length,
233
+ padding=0,
234
+ )
235
+
236
+ # magnitude, phase
237
+ magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
238
+ phase = torch.atan2(imag_out, real_out)
239
+ # Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch
240
+ # In this case, PyTorch returns pi, ONNX returns -pi
241
+ correction_mask = (imag_out == 0) & (real_out < 0)
242
+ phase[correction_mask] = torch.pi
243
+ return magnitude, phase
244
+
245
+
246
+ def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
247
+ """
248
+ Inverse STFT => returns waveform shape (B, T).
249
+ """
250
+ # magnitude, phase => (B, freq_bins, frames)
251
+ # Re-create real/imag => shape (B, freq_bins, frames)
252
+ real_part = magnitude * torch.cos(phase)
253
+ imag_part = magnitude * torch.sin(phase)
254
+
255
+ # conv_transpose wants shape (B, freq_bins, frames). We'll treat "frames" as time dimension
256
+ # so we do (B, freq_bins, frames) => (B, freq_bins, frames)
257
+ # But PyTorch conv_transpose1d expects (B, in_channels, input_length)
258
+ real_part = real_part # (B, freq_bins, frames)
259
+ imag_part = imag_part
260
+
261
+ # real iSTFT => convolve with "backward_real", "backward_imag", and sum
262
+ # We'll do 2 conv_transpose calls, each giving (B, 1, time),
263
+ # then add them => (B, 1, time).
264
+ real_rec = F.conv_transpose1d(
265
+ real_part,
266
+ self.weight_backward_real, # shape (freq_bins, 1, filter_length)
267
+ bias=None,
268
+ stride=self.hop_length,
269
+ padding=0,
270
+ )
271
+ imag_rec = F.conv_transpose1d(
272
+ imag_part,
273
+ self.weight_backward_imag,
274
+ bias=None,
275
+ stride=self.hop_length,
276
+ padding=0,
277
+ )
278
+ # sum => (B, 1, time)
279
+ waveform = real_rec - imag_rec # typical real iFFT has minus for imaginary part
280
+
281
+ # If we used "center=True" in forward, we should remove pad
282
+ if self.center:
283
+ pad_len = self.n_fft // 2
284
+ # Because of transposed convolution, total length might have extra samples
285
+ # We remove `pad_len` from start & end if possible
286
+ waveform = waveform[..., pad_len:-pad_len]
287
+
288
+ # If a specific length is desired, clamp
289
+ if length is not None:
290
+ waveform = waveform[..., :length]
291
+
292
+ # shape => (B, T)
293
+ return waveform
294
+
295
+ def forward(self, x: torch.Tensor):
296
+ """
297
+ Full STFT -> iSTFT pass: returns time-domain reconstruction.
298
+ Same interface as your original code.
299
+ """
300
+ mag, phase = self.transform(x)
301
+ return self.inverse(mag, phase, length=x.shape[-1])
302
+
303
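As a quick sanity check of the convolution-based STFT above, transform followed by inverse should approximately reproduce the input, as the docstring promises for Hann windows with typical overlap. A rough sketch with arbitrary settings (shapes and values are illustrative; edge samples are only approximate):

    import torch

    stft = CustomSTFT(filter_length=20, hop_length=5, win_length=20)
    x = torch.randn(1, 2400)            # (B, T) dummy waveform
    mag, phase = stft.transform(x)      # (B, n_fft // 2 + 1, frames)
    y = stft.inverse(mag, phase, length=x.shape[-1])
    print(mag.shape, y.shape)           # y roughly matches x away from the edges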
+ class SineGen(torch.nn.Module):
304
+ """ Definition of sine generator
305
+ SineGen(samp_rate, harmonic_num = 0,
306
+ sine_amp = 0.1, noise_std = 0.003,
307
+ voiced_threshold = 0,
308
+ flag_for_pulse=False)
309
+ samp_rate: sampling rate in Hz
310
+ harmonic_num: number of harmonic overtones (default 0)
311
+ sine_amp: amplitude of sine-waveform (default 0.1)
312
+ noise_std: std of Gaussian noise (default 0.003)
313
+ voiced_threshold: F0 threshold for U/V classification (default 0)
314
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
315
+ Note: when flag_for_pulse is True, the first time step of a voiced
316
+ segment is always sin(np.pi) or cos(0)
317
+ """
318
+
319
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
320
+ sine_amp=0.1, noise_std=0.003,
321
+ voiced_threshold=0,
322
+ flag_for_pulse=False):
323
+ super(SineGen, self).__init__()
324
+ self.sine_amp = sine_amp
325
+ self.noise_std = noise_std
326
+ self.harmonic_num = harmonic_num
327
+ self.dim = self.harmonic_num + 1
328
+ self.sampling_rate = samp_rate
329
+ self.voiced_threshold = voiced_threshold
330
+ self.flag_for_pulse = flag_for_pulse
331
+ self.upsample_scale = upsample_scale
332
+
333
+ def _f02uv(self, f0):
334
+ # generate uv signal
335
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
336
+ return uv
337
+
338
+ def _f02sine(self, f0_values):
339
+ """ f0_values: (batchsize, length, dim)
340
+ where dim indicates fundamental tone and overtones
341
+ """
342
+ # convert to F0 in rad. The integer part n can be ignored
343
+ # because 2 * np.pi * n doesn't affect phase
344
+ rad_values = (f0_values / self.sampling_rate) % 1
345
+
346
+ # initial phase noise (no noise for fundamental component)
347
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
348
+ device=f0_values.device)
349
+ rand_ini[:, 0] = 0
350
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
351
+
352
+ # instantaneous phase: sine[t] = sin(2*pi * \sum_{i=1}^{t} rad_i)
353
+ if not self.flag_for_pulse:
354
+ # # for normal case
355
+
356
+ # # To prevent torch.cumsum numerical overflow,
357
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
358
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
359
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
360
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
361
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
362
+ # cumsum_shift = torch.zeros_like(rad_values)
363
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
364
+
365
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
366
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
367
+ scale_factor=1/self.upsample_scale,
368
+ mode="linear").transpose(1, 2)
369
+
370
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
371
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
372
+ # cumsum_shift = torch.zeros_like(rad_values)
373
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
374
+
375
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
376
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
377
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
378
+ sines = torch.sin(phase)
379
+
380
+ else:
381
+ # If necessary, make sure that the first time step of every
382
+ # voiced segment is sin(pi) or cos(0)
383
+ # This is used for pulse-train generation
384
+
385
+ # identify the last time step in unvoiced segments
386
+ uv = self._f02uv(f0_values)
387
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
388
+ uv_1[:, -1, :] = 1
389
+ u_loc = (uv < 1) * (uv_1 > 0)
390
+
391
+ # get the instantaneous phase
392
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
393
+ # different batch needs to be processed differently
394
+ for idx in range(f0_values.shape[0]):
395
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
396
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
397
+ # stores the accumulation of i.phase within
398
+ # each voiced segment
399
+ tmp_cumsum[idx, :, :] = 0
400
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
401
+
402
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
403
+ # within the previous voiced segment.
404
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
405
+
406
+ # get the sines
407
+ sines = torch.cos(i_phase * 2 * np.pi)
408
+ return sines
409
+
410
+ def forward(self, f0):
411
+ """ sine_tensor, uv = forward(f0)
412
+ input F0: tensor(batchsize=1, length, dim=1)
413
+ f0 for unvoiced steps should be 0
414
+ output sine_tensor: tensor(batchsize=1, length, dim)
415
+ output uv: tensor(batchsize=1, length, 1)
416
+ """
417
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
418
+ device=f0.device)
419
+ # fundamental frequency and harmonic overtones
420
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
421
+
422
+ # generate sine waveforms
423
+ sine_waves = self._f02sine(fn) * self.sine_amp
424
+
425
+ # generate uv signal
426
+ # uv = torch.ones(f0.shape)
427
+ # uv = uv * (f0 > self.voiced_threshold)
428
+ uv = self._f02uv(f0)
429
+
430
+ # noise: for unvoiced should be similar to sine_amp
431
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
432
+ # for voiced regions the std is self.noise_std
433
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
434
+ noise = noise_amp * torch.randn_like(sine_waves)
435
+
436
+ # first: set the unvoiced part to 0 by uv
437
+ # then: additive noise
438
+ sine_waves = sine_waves * uv + noise
439
+ return sine_waves, uv, noise
440
+
441
+
442
+ class SourceModuleHnNSF(torch.nn.Module):
443
+ """ SourceModule for hn-nsf
444
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
445
+ add_noise_std=0.003, voiced_threshod=0)
446
+ sampling_rate: sampling_rate in Hz
447
+ harmonic_num: number of harmonic above F0 (default: 0)
448
+ sine_amp: amplitude of sine source signal (default: 0.1)
449
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
450
+ note that amplitude of noise in unvoiced is decided
451
+ by sine_amp
452
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
453
+ Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
454
+ F0_sampled (batchsize, length, 1)
455
+ Sine_source (batchsize, length, 1)
456
+ noise_source (batchsize, length, 1)
457
+ uv (batchsize, length, 1)
458
+ """
459
+
460
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
461
+ add_noise_std=0.003, voiced_threshod=0):
462
+ super(SourceModuleHnNSF, self).__init__()
463
+
464
+ self.sine_amp = sine_amp
465
+ self.noise_std = add_noise_std
466
+
467
+ # to produce sine waveforms
468
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
469
+ sine_amp, add_noise_std, voiced_threshod)
470
+
471
+ # to merge source harmonics into a single excitation
472
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
473
+ self.l_tanh = torch.nn.Tanh()
474
+
475
+ def forward(self, x):
476
+ """
477
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
478
+ F0_sampled (batchsize, length, 1)
479
+ Sine_source (batchsize, length, 1)
480
+ noise_source (batchsize, length, 1)
481
+ """
482
+ # source for harmonic branch
483
+ with torch.no_grad():
484
+ sine_wavs, uv, _ = self.l_sin_gen(x)
485
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
486
+
487
+ # source for noise branch, in the same shape as uv
488
+ noise = torch.randn_like(uv) * self.sine_amp / 3
489
+ return sine_merge, noise, uv
490
+ def padDiff(x):
491
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
492
+
493
+
494
+ class Generator(torch.nn.Module):
495
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
496
+ super(Generator, self).__init__()
497
+
498
+ self.num_kernels = len(resblock_kernel_sizes)
499
+ self.num_upsamples = len(upsample_rates)
500
+ resblock = AdaINResBlock1
501
+
502
+ self.m_source = SourceModuleHnNSF(
503
+ sampling_rate=24000,
504
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
505
+ harmonic_num=8, voiced_threshod=10)
506
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
507
+ self.noise_convs = nn.ModuleList()
508
+ self.noise_res = nn.ModuleList()
509
+
510
+ self.ups = nn.ModuleList()
511
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
512
+ self.ups.append(weight_norm(
513
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
514
+ k, u, padding=(k-u)//2)))
515
+
516
+ self.resblocks = nn.ModuleList()
517
+ for i in range(len(self.ups)):
518
+ ch = upsample_initial_channel//(2**(i+1))
519
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
520
+ self.resblocks.append(resblock(ch, k, d, style_dim))
521
+
522
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
523
+
524
+ if i + 1 < len(upsample_rates): #
525
+ stride_f0 = np.prod(upsample_rates[i + 1:])
526
+ self.noise_convs.append(Conv1d(
527
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
528
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
529
+ else:
530
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
531
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
532
+
533
+
534
+ self.post_n_fft = gen_istft_n_fft
535
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
536
+ self.ups.apply(init_weights)
537
+ self.conv_post.apply(init_weights)
538
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
539
+ #self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
540
+ self.stft = CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
541
+
542
+
543
+ def forward(self, x, s, f0):
544
+ with torch.no_grad():
545
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
546
+
547
+ har_source, noi_source, uv = self.m_source(f0)
548
+ har_source = har_source.transpose(1, 2).squeeze(1)
549
+ har_spec, har_phase = self.stft.transform(har_source)
550
+ har = torch.cat([har_spec, har_phase], dim=1)
551
+
552
+ for i in range(self.num_upsamples):
553
+ x = F.leaky_relu(x, LRELU_SLOPE)
554
+ x_source = self.noise_convs[i](har)
555
+ x_source = self.noise_res[i](x_source, s)
556
+
557
+ x = self.ups[i](x)
558
+ if i == self.num_upsamples - 1:
559
+ x = self.reflection_pad(x)
560
+
561
+ x = x + x_source
562
+ xs = None
563
+ for j in range(self.num_kernels):
564
+ if xs is None:
565
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
566
+ else:
567
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
568
+ x = xs / self.num_kernels
569
+ x = F.leaky_relu(x)
570
+ x = self.conv_post(x)
571
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
572
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
573
+ return self.stft.inverse(spec, phase)
574
+
575
+ def fw_phase(self, x, s):
576
+ for i in range(self.num_upsamples):
577
+ x = F.leaky_relu(x, LRELU_SLOPE)
578
+ x = self.ups[i](x)
579
+ xs = None
580
+ for j in range(self.num_kernels):
581
+ if xs is None:
582
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
583
+ else:
584
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
585
+ x = xs / self.num_kernels
586
+ x = F.leaky_relu(x)
587
+ x = self.reflection_pad(x)
588
+ x = self.conv_post(x)
589
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
590
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
591
+ return spec, phase
592
+
593
+ def remove_weight_norm(self):
594
+ print('Removing weight norm...')
595
+ for l in self.ups:
596
+ remove_weight_norm(l)
597
+ for l in self.resblocks:
598
+ l.remove_weight_norm()
599
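+ # NOTE: self.conv_pre is never defined in this Generator, so the next line would raise AttributeError if remove_weight_norm() were called.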
+ remove_weight_norm(self.conv_pre)
600
+ remove_weight_norm(self.conv_post)
601
+
602
+
603
+ class AdainResBlk1d(nn.Module):
604
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
605
+ upsample='none', dropout_p=0.0):
606
+ super().__init__()
607
+ self.actv = actv
608
+ self.upsample_type = upsample
609
+ self.upsample = UpSample1d(upsample)
610
+ self.learned_sc = dim_in != dim_out
611
+ self._build_weights(dim_in, dim_out, style_dim)
612
+ self.dropout = nn.Dropout(dropout_p)
613
+
614
+ if upsample == 'none':
615
+ self.pool = nn.Identity()
616
+ else:
617
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
618
+
619
+
620
+ def _build_weights(self, dim_in, dim_out, style_dim):
621
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
622
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
623
+ self.norm1 = AdaIN1d(style_dim, dim_in)
624
+ self.norm2 = AdaIN1d(style_dim, dim_out)
625
+ if self.learned_sc:
626
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
627
+
628
+ def _shortcut(self, x):
629
+ x = self.upsample(x)
630
+ if self.learned_sc:
631
+ x = self.conv1x1(x)
632
+ return x
633
+
634
+ def _residual(self, x, s):
635
+ x = self.norm1(x, s)
636
+ x = self.actv(x)
637
+ x = self.pool(x)
638
+ x = self.conv1(self.dropout(x))
639
+ x = self.norm2(x, s)
640
+ x = self.actv(x)
641
+ x = self.conv2(self.dropout(x))
642
+ return x
643
+
644
+ def forward(self, x, s):
645
+ out = self._residual(x, s)
646
+ out = (out + self._shortcut(x)) / math.sqrt(2)
647
+ return out
648
+
649
+ class UpSample1d(nn.Module):
650
+ def __init__(self, layer_type):
651
+ super().__init__()
652
+ self.layer_type = layer_type
653
+
654
+ def forward(self, x):
655
+ if self.layer_type == 'none':
656
+ return x
657
+ else:
658
+ return F.interpolate(x, scale_factor=2, mode='nearest')
659
+
660
+ class Decoder(nn.Module):
661
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
662
+ resblock_kernel_sizes = [3,7,11],
663
+ upsample_rates = [10, 6],
664
+ upsample_initial_channel=512,
665
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
666
+ upsample_kernel_sizes=[20, 12],
667
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
668
+ super().__init__()
669
+
670
+ self.decode = nn.ModuleList()
671
+
672
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
673
+
674
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
675
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
676
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
677
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
678
+
679
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
680
+
681
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
682
+
683
+ self.asr_res = nn.Sequential(
684
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
685
+ )
686
+
687
+
688
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
689
+ upsample_initial_channel, resblock_dilation_sizes,
690
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
691
+
692
+ def forward(self, asr, F0_curve, N, s):
693
+ if self.training:
694
+ downlist = [0, 3, 7]
695
+ F0_down = downlist[random.randint(0, 2)]
696
+ downlist = [0, 3, 7, 15]
697
+ N_down = downlist[random.randint(0, 3)]
698
+ if F0_down:
699
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
700
+ if N_down:
701
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
702
+
703
+
704
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
705
+ N = self.N_conv(N.unsqueeze(1))
706
+
707
+ x = torch.cat([asr, F0, N], axis=1)
708
+ x = self.encode(x, s)
709
+
710
+ asr_res = self.asr_res(asr)
711
+
712
+ res = True
713
+ for block in self.decode:
714
+ if res:
715
+ x = torch.cat([x, asr_res, F0, N], axis=1)
716
+ x = block(x, s)
717
+ if block.upsample_type != "none":
718
+ res = False
719
+
720
+ x = self.generator(x, s, F0_curve)
721
+ return x
722
+
723
+
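In this iSTFT variant the Generator stops at a coarse temporal resolution and lets conv_post predict gen_istft_n_fft + 2 channels per frame, which are split into a log-magnitude half and a phase half before the inverse STFT reassembles the waveform. A rough illustration of that split, using the Decoder default gen_istft_n_fft=20 (the frame count and values are made up):

    import torch

    n_fft = 20
    frames = torch.randn(1, n_fft + 2, 100)             # conv_post output: (B, n_fft + 2, frames)
    spec  = torch.exp(frames[:, : n_fft // 2 + 1, :])   # 11 magnitude bins
    phase = torch.sin(frames[:, n_fft // 2 + 1 :, :])   # 11 "phase" values squashed into [-1, 1]
    print(spec.shape, phase.shape)
    # stft.inverse(spec, phase) then overlap-adds the frames back into audio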
libs/Modules/slmadv.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+ class SLMAdversarialLoss(torch.nn.Module):
6
+
7
+ def __init__(self, model, wl, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
8
+ super(SLMAdversarialLoss, self).__init__()
9
+ self.model = model
10
+ self.wl = wl
11
+
12
+ self.min_len = min_len
13
+ self.max_len = max_len
14
+ self.batch_percentage = batch_percentage
15
+
16
+ self.sig = sig
17
+ self.skip_update = skip_update
18
+
19
+ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, ref_s):
20
+ text_mask = length_to_mask(ref_lengths).to(ref_text.device)
21
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
22
+
23
+ s_dur = ref_s[:, 128:]
24
+ #s = ref_s[:, :128] #Not used
25
+
26
+ d, _ = self.model.predictor(t_en, s_dur,
27
+ ref_lengths,
28
+ torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
29
+ text_mask)
30
+
31
+ bib = 0
32
+
33
+ output_lengths = []
34
+ attn_preds = []
35
+
36
+ # differentiable duration modeling
37
+ for _s2s_pred, _text_length in zip(d, ref_lengths):
38
+
39
+ _s2s_pred_org = _s2s_pred[:_text_length, :]
40
+
41
+ _s2s_pred = torch.sigmoid(_s2s_pred_org)
42
+ _dur_pred = _s2s_pred.sum(axis=-1)
43
+
44
+ l = int(torch.round(_s2s_pred.sum()).item())
45
+ t = torch.arange(0, l).expand(l)
46
+
47
+ t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
48
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
49
+
50
+ h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
51
+
52
+ out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
53
+ h.unsqueeze(1),
54
+ padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
55
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
56
+
57
+ output_lengths.append(l)
58
+
59
+ max_len = max(output_lengths)
60
+
61
+ with torch.no_grad():
62
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
63
+
64
+ s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device)
65
+ for bib in range(len(output_lengths)):
66
+ s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]
67
+
68
+ asr_pred = t_en @ s2s_attn
69
+
70
+ _, p_pred = self.model.predictor(t_en, s_dur,
71
+ ref_lengths,
72
+ s2s_attn,
73
+ text_mask)
74
+
75
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
76
+ mel_len = min(mel_len, self.max_len // 2)
77
+
78
+ # get clips
79
+
80
+ en = []
81
+ p_en = []
82
+ sp = []
83
+
84
+ F0_fakes = []
85
+ N_fakes = []
86
+
87
+ wav = []
88
+
89
+ for bib in range(len(output_lengths)):
90
+ mel_length_pred = output_lengths[bib]
91
+ mel_length_gt = int(mel_input_length[bib].item() / 2)
92
+ if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
93
+ continue
94
+
95
+ sp.append(ref_s[bib])
96
+
97
+ random_start = np.random.randint(0, mel_length_pred - mel_len)
98
+ en.append(asr_pred[bib, :, random_start:random_start+mel_len])
99
+ p_en.append(p_pred[bib, :, random_start:random_start+mel_len])
100
+
101
+ # get ground truth clips
102
+ random_start = np.random.randint(0, mel_length_gt - mel_len)
103
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
104
+ wav.append(torch.from_numpy(y).to(ref_text.device))
105
+
106
+ if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths
107
+ break
108
+
109
+ if len(sp) <= 1:
110
+ return None
111
+
112
+ sp = torch.stack(sp)
113
+ wav = torch.stack(wav).float()
114
+ en = torch.stack(en)
115
+ p_en = torch.stack(p_en)
116
+
117
+ F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
118
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
119
+
120
+ # discriminator loss
121
+ if (iters + 1) % self.skip_update == 0:
122
+ if np.random.randint(0, 2) == 0:
123
+ wav = y_rec_gt_pred
124
+ use_rec = True
125
+ else:
126
+ use_rec = False
127
+
128
+ crop_size = min(wav.size(-1), y_pred.size(-1))
129
+ if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
130
+ if wav.size(-1) > y_pred.size(-1):
131
+ real_GP = wav[:, : , :crop_size]
132
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
133
+ out_org = self.wl.discriminator_forward(wav.detach().squeeze())
134
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
135
+
136
+ if np.random.randint(0, 2) == 0:
137
+ d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
138
+ else:
139
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
140
+ else:
141
+ real_GP = y_pred[:, : , :crop_size]
142
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
143
+ out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
144
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
145
+
146
+ if np.random.randint(0, 2) == 0:
147
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
148
+ else:
149
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
150
+
151
+ # regularization (ignore length variation)
152
+ d_loss += loss_reg
153
+
154
+ out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
155
+ out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())
156
+
157
+ # regularization (ignore reconstruction artifacts)
158
+ d_loss += F.l1_loss(out_gt, out_rec)
159
+
160
+ else:
161
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
162
+ else:
163
+ d_loss = 0
164
+
165
+ # generator loss
166
+ gen_loss = self.wl.generator(y_pred.squeeze())
167
+
168
+ gen_loss = gen_loss.mean()
169
+
170
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
171
+
172
+ def length_to_mask(lengths):
173
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
174
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
175
+ return mask
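length_to_mask builds a boolean padding mask that is True on padded positions and False on valid ones. A small illustration (the lengths are arbitrary):

    import torch

    lengths = torch.tensor([3, 5])
    print(length_to_mask(lengths))
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])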
libs/Modules/utils.py ADDED
@@ -0,0 +1,16 @@
+ from torch.nn.utils import weight_norm
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size*dilation - dilation)/2)
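get_padding returns the "same" padding for a stride-1 dilated convolution, so the output length equals the input length. For example, with the kernel/dilation pairs used by AdaINResBlock1 above:

    print(get_padding(3, 1), get_padding(3, 5), get_padding(7, 1))  # 1 5 3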
libs/Modules/vocos.py ADDED
@@ -0,0 +1,422 @@
1
+ from typing import Optional
2
+
3
+ import math
4
+ import random
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.utils.parametrizations import weight_norm
11
+
12
+ from typing import Optional, Tuple
13
+ from scipy.signal import get_window
14
+
15
+ class AdaIN1d(nn.Module):
16
+ def __init__(self, style_dim, num_features):
17
+ super().__init__()
18
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
19
+ self.fc = nn.Linear(style_dim, num_features*2)
20
+
21
+ def forward(self, x, s):
22
+ h = self.fc(s)
23
+ h = h.view(h.size(0), h.size(1), 1)
24
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
25
+ return (1 + gamma) * self.norm(x) + beta
26
+
27
+ class ConvNeXtBlock(nn.Module):
28
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
29
+
30
+ Args:
31
+ dim (int): Number of input channels.
32
+ intermediate_dim (int): Dimensionality of the intermediate layer.
33
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
34
+ Defaults to None.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ intermediate_dim: int,
41
+ layer_scale_init_value: float,
42
+ style_dim: int,
43
+ ):
44
+ super().__init__()
45
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
46
+ self.norm = AdaIN1d(style_dim, dim)
47
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
48
+ self.act = nn.GELU()
49
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
50
+ self.gamma = (
51
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
52
+ if layer_scale_init_value > 0
53
+ else None
54
+ )
55
+
56
+ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
57
+ residual = x
58
+ x = self.dwconv(x)
59
+ x = self.norm(x, s)
60
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
61
+ x = self.pwconv1(x)
62
+ x = self.act(x)
63
+ x = self.pwconv2(x)
64
+ if self.gamma is not None:
65
+ x = self.gamma * x
66
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
67
+
68
+ x = residual + x
69
+ return x
70
+
71
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
72
+ """
73
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
74
+
75
+ Args:
76
+ x (Tensor): Input tensor.
77
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
78
+
79
+ Returns:
80
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
81
+ """
82
+ return torch.log(torch.clip(x, min=clip_val))
83
+
84
+
85
+ def symlog(x: torch.Tensor) -> torch.Tensor:
86
+ return torch.sign(x) * torch.log1p(x.abs())
87
+
88
+
89
+ def symexp(x: torch.Tensor) -> torch.Tensor:
90
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
91
+
92
+ class Backbone(nn.Module):
93
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
94
+
95
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
96
+ """
97
+ Args:
98
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
99
+ C denotes output features, and L is the sequence length.
100
+
101
+ Returns:
102
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
103
+ and H denotes the model dimension.
104
+ """
105
+ raise NotImplementedError("Subclasses must implement the forward method.")
106
+
107
+
108
+ class Generator(Backbone):
109
+ """
110
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with AdaIN (adaptive instance normalization).
111
+
112
+ Args:
113
+ input_channels (int): Number of input features channels.
114
+ dim (int): Hidden dimension of the model.
115
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
116
+ num_layers (int): Number of ConvNeXtBlock layers.
117
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ input_channels: int,
123
+ dim: int,
124
+ style_dim: int,
125
+ intermediate_dim: int,
126
+ num_layers: int,
127
+ gen_istft_n_fft: int,
128
+ gen_istft_hop_size: int,
129
+ layer_scale_init_value: Optional[float] = None,
130
+ ):
131
+ super().__init__()
132
+ self.input_channels = input_channels
133
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
134
+
135
+ self.convnext = nn.ModuleList()
136
+
137
+ for i in range(num_layers):
138
+ self.convnext.append(
139
+ ConvNeXtBlock(
140
+ dim=dim,
141
+ intermediate_dim=intermediate_dim,
142
+ layer_scale_init_value=layer_scale_init_value,
143
+ style_dim=style_dim,
144
+ )
145
+ )
146
+
147
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
148
+ self.apply(self._init_weights)
149
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
150
+ self.stft = ISTFTHead(dim=dim, n_fft=gen_istft_n_fft, hop_length=gen_istft_hop_size, padding="same")
151
+
152
+ def _init_weights(self, m):
153
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
154
+ nn.init.trunc_normal_(m.weight, std=0.02)
155
+ nn.init.constant_(m.bias, 0)
156
+
157
+ def forward(self, x, s) -> torch.Tensor:
158
+ for i, conv_block in enumerate(self.convnext):
159
+ x = conv_block(x, s)
160
+ x = self.final_layer_norm(x.transpose(1, 2))
161
+ x = self.stft(x)
162
+ return x
163
+
164
+ class ISTFT(nn.Module):
165
+ """
166
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
167
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
168
+ See issue: https://github.com/pytorch/pytorch/issues/62323
169
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
170
+ The NOLA constraint is met as we trim padded samples anyway.
171
+
172
+ Args:
173
+ n_fft (int): Size of Fourier transform.
174
+ hop_length (int): The distance between neighboring sliding window frames.
175
+ win_length (int): The size of window frame and STFT filter.
176
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
177
+ """
178
+
179
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
180
+ super().__init__()
181
+ if padding not in ["center", "same"]:
182
+ raise ValueError("Padding must be 'center' or 'same'.")
183
+ self.padding = padding
184
+ self.n_fft = n_fft
185
+ self.hop_length = hop_length
186
+ self.win_length = win_length
187
+ window = torch.hann_window(win_length)
188
+ self.register_buffer("window", window)
189
+
190
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
191
+ """
192
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
193
+
194
+ Args:
195
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
196
+ N is the number of frequency bins, and T is the number of time frames.
197
+
198
+ Returns:
199
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
200
+ """
201
+ if self.padding == "center":
202
+ # Fallback to pytorch native implementation
203
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
204
+ elif self.padding == "same":
205
+ pad = (self.win_length - self.hop_length) // 2
206
+ else:
207
+ raise ValueError("Padding must be 'center' or 'same'.")
208
+
209
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
210
+ B, N, T = spec.shape
211
+
212
+ # Inverse FFT
213
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
214
+ ifft = ifft * self.window[None, :, None]
215
+
216
+ # Overlap and Add
217
+ output_size = (T - 1) * self.hop_length + self.win_length
218
+ y = torch.nn.functional.fold(
219
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
220
+ )[:, 0, 0, pad:-pad]
221
+
222
+ # Window envelope
223
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
224
+ window_envelope = torch.nn.functional.fold(
225
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
226
+ ).squeeze()[pad:-pad]
227
+
228
+ # Normalize
229
+ assert (window_envelope > 1e-11).all()
230
+ y = y / window_envelope
231
+
232
+ return y
233
+
234
+ class FourierHead(nn.Module):
235
+ """Base class for inverse fourier modules."""
236
+
237
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
238
+ """
239
+ Args:
240
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
241
+ L is the sequence length, and H denotes the model dimension.
242
+
243
+ Returns:
244
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
245
+ """
246
+ raise NotImplementedError("Subclasses must implement the forward method.")
247
+
248
+ class ISTFTHead(FourierHead):
249
+ """
250
+ ISTFT Head module for predicting STFT complex coefficients.
251
+
252
+ Args:
253
+ dim (int): Hidden dimension of the model.
254
+ n_fft (int): Size of Fourier transform.
255
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
256
+ the resolution of the input features.
257
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
258
+ """
259
+
260
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
261
+ super().__init__()
262
+ self.filter_length = n_fft
263
+ self.win_length = n_fft
264
+ self.hop_length = hop_length
265
+ self.window = torch.from_numpy(get_window("hann", self.win_length, fftbins=True).astype(np.float32))
266
+
267
+ out_dim = n_fft + 2
268
+ self.out = torch.nn.Linear(dim, out_dim)
269
+ self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
270
+
271
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
272
+ """
273
+ Forward pass of the ISTFTHead module.
274
+
275
+ Args:
276
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
277
+ L is the sequence length, and H denotes the model dimension.
278
+
279
+ Returns:
280
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
281
+ """
282
+ x = self.out(x).transpose(1, 2)
283
+ mag, p = x.chunk(2, dim=1)
284
+ mag = torch.exp(mag)
285
+ mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
286
+ # wrapping happens here. These two lines produce real and imaginary value
287
+ x = torch.cos(p)
288
+ y = torch.sin(p)
289
+ # recalculating phase here does not produce anything new
290
+ # only costs time
291
+ # phase = torch.atan2(y, x)
292
+ # S = mag * torch.exp(phase * 1j)
293
+ # better directly produce the complex value
294
+ S = mag * (x + 1j * y)
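+ # Shape sketch (illustrative, e.g. dim=512, n_fft=20, hop_length=5): self.out maps (B, L, 512) -> (B, L, 22);
+ # after the transpose and chunk, mag and p are each (B, n_fft // 2 + 1, L) = (B, 11, L),
+ # and the ISTFT below returns roughly L * hop_length time-domain samples per item.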
295
+ audio = self.istft(S)
296
+ return audio
297
+
298
+ def transform(self, input_data):
299
+ forward_transform = torch.stft(
300
+ input_data,
301
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
302
+ return_complex=True)
303
+
304
+ return torch.abs(forward_transform), torch.angle(forward_transform)
305
+
306
+
307
+ class AdainResBlk1d(nn.Module):
308
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
309
+ upsample='none', dropout_p=0.0):
310
+ super().__init__()
311
+ self.actv = actv
312
+ self.upsample_type = upsample
313
+ self.upsample = UpSample1d(upsample)
314
+ self.learned_sc = dim_in != dim_out
315
+ self._build_weights(dim_in, dim_out, style_dim)
316
+ self.dropout = nn.Dropout(dropout_p)
317
+
318
+ if upsample == 'none':
319
+ self.pool = nn.Identity()
320
+ else:
321
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
322
+
323
+
324
+ def _build_weights(self, dim_in, dim_out, style_dim):
325
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
326
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
327
+ self.norm1 = AdaIN1d(style_dim, dim_in)
328
+ self.norm2 = AdaIN1d(style_dim, dim_out)
329
+ if self.learned_sc:
330
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
331
+
332
+ def _shortcut(self, x):
333
+ x = self.upsample(x)
334
+ if self.learned_sc:
335
+ x = self.conv1x1(x)
336
+ return x
337
+
338
+ def _residual(self, x, s):
339
+ x = self.norm1(x, s)
340
+ x = self.actv(x)
341
+ x = self.pool(x)
342
+ x = self.conv1(self.dropout(x))
343
+ x = self.norm2(x, s)
344
+ x = self.actv(x)
345
+ x = self.conv2(self.dropout(x))
346
+ return x
347
+
348
+ def forward(self, x, s):
349
+ out = self._residual(x, s)
350
+ out = (out + self._shortcut(x)) / math.sqrt(2)
351
+ return out
352
+
353
+ class UpSample1d(nn.Module):
354
+ def __init__(self, layer_type):
355
+ super().__init__()
356
+ self.layer_type = layer_type
357
+
358
+ def forward(self, x):
359
+ if self.layer_type == 'none':
360
+ return x
361
+ else:
362
+ return F.interpolate(x, scale_factor=2, mode='nearest')
363
+
364
+ class Decoder(nn.Module):
365
+ def __init__(self, dim_in=512, style_dim=64, dim_out=80,
366
+ intermediate_dim=1536,
367
+ num_layers=8,
368
+ gen_istft_n_fft=1024, gen_istft_hop_size=256):
369
+ super().__init__()
370
+
371
+ self.decode = nn.ModuleList()
372
+
373
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
374
+
375
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
376
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
377
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
378
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
379
+
380
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
381
+
382
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
383
+
384
+ self.asr_res = nn.Sequential(
385
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
386
+ )
387
+
388
+ self.generator = Generator(input_channels=dim_out, dim=dim_in, style_dim=style_dim,
389
+ intermediate_dim=intermediate_dim, num_layers=num_layers,
390
+ gen_istft_n_fft=gen_istft_n_fft, gen_istft_hop_size=gen_istft_hop_size)
391
+
392
+ def forward(self, asr, F0_curve, N, s):
393
+ if self.training:
394
+ downlist = [0, 3, 7]
395
+ F0_down = downlist[random.randint(0, 2)]
396
+ downlist = [0, 3, 7, 15]
397
+ N_down = downlist[random.randint(0, 3)]
398
+ if F0_down:
399
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
400
+ if N_down:
401
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
402
+
403
+
404
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
405
+ N = self.N_conv(N.unsqueeze(1))
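+ # F0_curve and N arrive at twice the frame rate of `asr` (the prosody heads upsample 2x),
+ # so the stride-2 convs above bring them back to the same length before concatenation.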
406
+
407
+ x = torch.cat([asr, F0, N], axis=1)
408
+ x = self.encode(x, s)
409
+
410
+ asr_res = self.asr_res(asr)
411
+
412
+ res = True
413
+ for block in self.decode:
414
+ if res:
415
+ x = torch.cat([x, asr_res, F0, N], axis=1)
416
+ x = block(x, s)
417
+ if block.upsample_type != "none":
418
+ res = False
419
+
420
+ x = self.generator(x, s)
421
+ x = x.unsqueeze(1)
422
+ return x
libs/inference.py ADDED
@@ -0,0 +1,319 @@
1
+ import re
2
+ import yaml
3
+ from munch import Munch
4
+ import numpy as np
5
+ import librosa
6
+ import noisereduce as nr
7
+ from .meldataset import TextCleaner
8
+ import torch
9
+ import torchaudio
10
+ from nltk.tokenize import word_tokenize
11
+ import nltk
12
+ nltk.download('punkt_tab')
13
+
14
+ from .models import ProsodyPredictor, TextEncoder, StyleEncoder
15
+
16
+ class Preprocess:
17
+ def __text_normalize(self, text):
18
+ punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
19
+ map_to = "."
20
+ punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
21
+ #replace punctuation that acts like a comma or period
22
+ text = punctuation_pattern.sub(map_to, text)
23
+ #replace consecutive whitespace chars with a single space and strip leading/trailing spaces
24
+ text = re.sub(r'\s+', ' ', text).strip()
25
+ return text
26
+ def __merge_fragments(self, texts, n):
27
+ merged = []
28
+ i = 0
29
+ while i < len(texts):
30
+ fragment = texts[i]
31
+ j = i + 1
32
+ while len(fragment.split()) < n and j < len(texts):
33
+ fragment += ", " + texts[j]
34
+ j += 1
35
+ merged.append(fragment)
36
+ i = j
37
+ if len(merged[-1].split()) < n and len(merged) > 1: #handle last sentence
38
+ merged[-2] = merged[-2] + ", " + merged[-1]
39
+ del merged[-1]
40
+ else:
41
+ merged[-1] = merged[-1]
42
+ return merged
43
+ def wave_preprocess(self, wave):
44
+ to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
45
+ mean, std = -4, 4
46
+ wave_tensor = torch.from_numpy(wave).float()
47
+ mel_tensor = to_mel(wave_tensor)
48
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
49
+ return mel_tensor
50
+ def text_preprocess(self, text, n_merge=12):
51
+ text_norm = self.__text_normalize(text).split(".")#split by sentences.
52
+ text_norm = [s.strip() for s in text_norm]
53
+ text_norm = list(filter(lambda x: x != '', text_norm)) #filter empty index
54
+ text_norm = self.__merge_fragments(text_norm, n=n_merge) #merge fragments shorter than n words
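+ # e.g. (illustrative) "Hi. How are you? Fine, thanks!" -> ["Hi, How are you, Fine, thanks"]
+ # with the default n_merge=12, since every fragment is shorter than 12 words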
55
+ return text_norm
56
+ def length_to_mask(self, lengths):
57
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
58
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
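+ # e.g. lengths=[3, 5] -> [[False, False, False, True, True], [False, False, False, False, False]];
+ # True marks padded positions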
59
+ return mask
60
+
61
+ #For inference only
62
+ class StyleTTS2(torch.nn.Module):
63
+ def __init__(self, config_path, models_path):
64
+ super().__init__()
65
+ self.register_buffer("get_device", torch.empty(0))
66
+ self.preprocess = Preprocess()
67
+ self.ref_s = None
68
+ config = yaml.safe_load(open(config_path, "r", encoding="utf-8"))
69
+
70
+ try:
71
+ symbols = (
72
+ list(config['symbol']['pad']) +
73
+ list(config['symbol']['punctuation']) +
74
+ list(config['symbol']['letters']) +
75
+ list(config['symbol']['letters_ipa']) +
76
+ list(config['symbol']['extend'])
77
+ )
78
+ symbol_dict = {}
79
+ for i in range(len((symbols))):
80
+ symbol_dict[symbols[i]] = i
81
+
82
+ n_token = len(symbol_dict) + 1
83
+ print("\nFound:", n_token, "symbols")
84
+ except Exception as e:
85
+ print(f"\nERROR: Cannot find {e} in config file!\nYour config file is likely outdated, please download updated version from the repository.")
86
+ raise SystemExit(1)
87
+
88
+ args = self.__recursive_munch(config['model_params'])
89
+ args['n_token'] = n_token
90
+
91
+ self.cleaner = TextCleaner(symbol_dict, debug=False)
92
+
93
+ assert args.decoder.type in ['istftnet', 'hifigan', 'vocos'], 'Decoder type unknown'
94
+
95
+ if args.decoder.type == "istftnet":
96
+ from .Modules.istftnet import Decoder
97
+ self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
98
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
99
+ upsample_rates = args.decoder.upsample_rates,
100
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
101
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
102
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
103
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
104
+ elif args.decoder.type == "hifigan":
105
+ from .Modules.hifigan import Decoder
106
+ self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
107
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
108
+ upsample_rates = args.decoder.upsample_rates,
109
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
110
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
111
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
112
+ elif args.decoder.type == "vocos":
113
+ from .Modules.vocos import Decoder
114
+ self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
115
+ intermediate_dim=args.decoder.intermediate_dim,
116
+ num_layers=args.decoder.num_layers,
117
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft,
118
+ gen_istft_hop_size=args.decoder.gen_istft_hop_size)
119
+
120
+ self.predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
121
+ self.text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
122
+ self.style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)# acoustic style encoder
123
+
124
+ self.__load_models(models_path)
125
+
126
+ def __recursive_munch(self, d):
127
+ if isinstance(d, dict):
128
+ return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
129
+ elif isinstance(d, list):
130
+ return [self.__recursive_munch(v) for v in d]
131
+ else:
132
+ return d
133
+
134
+ def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
135
+ mean = tensor.mean()
136
+ std = tensor.std()
137
+ z = (tensor - mean) / std
138
+
139
+ # Identify outliers
140
+ outlier_mask = torch.abs(z) > threshold
141
+ # Compute replacement value, respecting sign
142
+ sign = torch.sign(tensor - mean)
143
+ replacement = mean + sign * (threshold * std * factor)
144
+
145
+ result = tensor.clone()
146
+ result[outlier_mask] = replacement[outlier_mask]
147
+
148
+ return result
149
+
150
+ def __load_models(self, models_path):
151
+ module_params = []
152
+ model = {'decoder':self.decoder, 'predictor':self.predictor, 'text_encoder':self.text_encoder, 'style_encoder':self.style_encoder}
153
+
154
+ params_whole = torch.load(models_path, map_location='cpu')
155
+ params = params_whole['net']
156
+ params = {key: value for key, value in params.items() if key in model.keys()}
157
+
158
+ for key in model:
159
+ try:
160
+ model[key].load_state_dict(params[key])
161
+ except:
162
+ from collections import OrderedDict
163
+ state_dict = params[key]
164
+ new_state_dict = OrderedDict()
165
+ for k, v in state_dict.items():
166
+ name = k[7:] # remove `module.`
167
+ new_state_dict[name] = v
168
+ model[key].load_state_dict(new_state_dict, strict=False)
169
+
170
+ total_params = sum(p.numel() for p in model[key].parameters())
171
+ print(key,":",total_params)
172
+ module_params.append(total_params)
173
+
174
+ print('\nTotal',":",sum(module_params))
175
+
176
+ def __compute_style(self, path, denoise, split_dur):
177
+ device = self.get_device.device
178
+ denoise = min(denoise, 1)
179
+ if split_dur != 0: split_dur = max(int(split_dur), 1)
180
+ max_samples = 24000*20 #max 20 seconds ref audio
181
+ print("Computing the style for:", path)
182
+
183
+ wave, sr = librosa.load(path, sr=24000)
184
+ audio, index = librosa.effects.trim(wave, top_db=30)
185
+ if sr != 24000:
186
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
187
+ if len(audio) > max_samples:
188
+ audio = audio[:max_samples]
189
+
190
+ if denoise > 0.0:
191
+ audio_denoise = nr.reduce_noise(y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300)
192
+ audio = audio*(1-denoise) + audio_denoise*denoise
193
+
194
+ with torch.no_grad():
195
+ if split_dur>0 and len(audio)/sr>=4: #Only effective if audio length is >= 4s
196
+ #This option will split the ref audio to multiple parts, calculate styles and average them
197
+ count = 0
198
+ ref_s = None
199
+ jump = sr*split_dur
200
+ total_len = len(audio)
201
+
202
+ #Need to init before the loop
203
+ mel_tensor = self.preprocess.wave_preprocess(audio[0:jump]).to(device)
204
+ ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
205
+ count += 1
206
+ for i in range(jump, total_len, jump):
207
+ if i+jump >= total_len:
208
+ left_dur = (total_len-i)/sr
209
+ if left_dur >= 1: #Still count if left over dur is >= 1s
210
+ mel_tensor = self.preprocess.wave_preprocess(audio[i:total_len]).to(device)
211
+ ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
212
+ count += 1
213
+ continue
214
+ mel_tensor = self.preprocess.wave_preprocess(audio[i:i+jump]).to(device)
215
+ ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
216
+ count += 1
217
+ ref_s /= count
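+ # averaging the per-window style vectors smooths out local prosody in long reference clips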
218
+ else:
219
+ mel_tensor = self.preprocess.wave_preprocess(audio).to(device)
220
+ ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
221
+
222
+ return ref_s
223
+
224
+ def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
225
+ device = self.get_device.device
226
+ speed = min(max(speed, 0.0001), 2) #speed range [0, 2]
227
+
228
+ phonem = ' '.join(word_tokenize(phonem))
229
+ tokens = self.cleaner(phonem)
230
+ tokens.insert(0, 0)
231
+ tokens.append(0)
232
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
233
+
234
+ with torch.no_grad():
235
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
236
+ text_mask = self.preprocess.length_to_mask(input_lengths).to(device)
237
+
238
+ # encode
239
+ t_en = self.text_encoder(tokens, input_lengths, text_mask)
240
+ s = ref_s.to(device)
241
+
242
+ # cal alignment
243
+ d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
244
+ x, _ = self.predictor.lstm(d)
245
+ duration = self.predictor.duration_proj(x)
246
+ duration = torch.sigmoid(duration).sum(axis=-1)
247
+
248
+ if prev_d_mean != 0:#Stabilize speaking speed between splits
249
+ dur_stats = torch.empty(duration.shape).normal_(mean=prev_d_mean, std=duration.std()).to(device)
250
+ else:
251
+ dur_stats = torch.empty(duration.shape).normal_(mean=duration.mean(), std=duration.std()).to(device)
252
+ duration = duration*(1-t) + dur_stats*t
253
+ duration[:,1:-2] = self.__replace_outliers_zscore(duration[:,1:-2]) #Normalize outlier
254
+
255
+ duration /= speed
256
+
257
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
258
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
259
+ c_frame = 0
260
+ for i in range(pred_aln_trg.size(0)):
261
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
262
+ c_frame += int(pred_dur[i].data)
263
+ alignment = pred_aln_trg.unsqueeze(0).to(device)
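+ # pred_aln_trg is a hard monotonic alignment: row i holds ones over the pred_dur[i] frames
+ # assigned to token i, so `encoding @ alignment` repeats each token vector for its predicted duration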
264
+
265
+ # encode prosody
266
+ en = (d.transpose(-1, -2) @ alignment)
267
+ F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
268
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
269
+
270
+ out = self.decoder(asr, F0_pred, N_pred, s)
271
+
272
+ return out.squeeze().cpu().numpy(), duration.mean()
273
+
274
+ def get_styles(self, speaker, denoise=0.3, avg_style=True, load_styles=False):
275
+ if not load_styles:
276
+ if avg_style: split_dur = 3
277
+ else: split_dur = 0
278
+ self.ref_s = self.__compute_style(speaker['path'], denoise=denoise, split_dur=split_dur)
279
+ else:
280
+ if self.ref_s is None:
281
+ raise Exception("Have to compute or load the styles first!")
282
+ style = {
283
+ 'style': self.ref_s,
284
+ 'path': speaker['path'],
285
+ 'speed': speaker['speed'],
286
+ }
287
+ return style
288
+
289
+ def save_styles(self, save_dir):
290
+ if self.ref_s is not None:
291
+ torch.save(self.ref_s, save_dir)
292
+ print("Saved styles!")
293
+ else:
294
+ raise Exception("Have to compute the styles before saving it.")
295
+
296
+ def load_styles(self, save_dir):
297
+ try:
298
+ self.ref_s = torch.load(save_dir)
299
+ print("Loaded styles!")
300
+ except Exception as e:
301
+ print(e)
302
+
303
+ def generate(self, phonem, style, stabilize=True, n_merge=16):
304
+ if stabilize: smooth_value=0.2
305
+ else: smooth_value=0
306
+
307
+ list_wav = []
308
+ prev_d_mean = 0
309
+
310
+ print("Generating Audio...")
311
+ text_norm = self.preprocess.text_preprocess(phonem, n_merge=n_merge)
312
+ for sentence in text_norm:
313
+ wav, prev_d_mean = self.__inference(sentence, style['style'], speed=style['speed'], prev_d_mean=prev_d_mean, t=smooth_value)
314
+ wav = wav[4000:-4000] #Remove weird pulse and silent tokens
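+ # 4000 samples ≈ 0.17 s at 24 kHz; trimming both ends drops the boundary artifacts before concatenation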
315
+ list_wav.append(wav)
316
+
317
+ final_wav = np.concatenate(list_wav)
318
+ final_wav = np.concatenate([np.zeros([4000]), final_wav, np.zeros([4000])], axis=0) # add padding
319
+ return final_wav
libs/meldataset.py ADDED
@@ -0,0 +1,307 @@
1
+ #coding: utf-8
2
+ import os.path as osp
3
+ import random
4
+ import numpy as np
5
+ import random
6
+ import soundfile as sf
7
+ import librosa
8
+
9
+ import torch
10
+ import torchaudio
11
+ import torch.utils.data
12
+ import torch.distributed as dist
13
+ from multiprocessing import Pool
14
+
15
+ import logging
16
+ logger = logging.getLogger(__name__)
17
+ logger.setLevel(logging.DEBUG)
18
+
19
+ import pandas as pd
20
+
21
+ class TextCleaner:
22
+ def __init__(self, symbol_dict, debug=True):
23
+ self.word_index_dictionary = symbol_dict
24
+ self.debug = debug
25
+ def __call__(self, text):
26
+ indexes = []
27
+ for char in text:
28
+ try:
29
+ indexes.append(self.word_index_dictionary[char])
30
+ except KeyError as e:
31
+ if self.debug:
32
+ print("\nWARNING UNKNOWN IPA CHARACTERS/LETTERS: ", char)
33
+ print("To ignore set 'debug' to false in the config")
34
+ continue
35
+ return indexes
36
+
37
+ np.random.seed(1)
38
+ random.seed(1)
39
+ SPECT_PARAMS = {
40
+ "n_fft": 2048,
41
+ "win_length": 1200,
42
+ "hop_length": 300
43
+ }
44
+ MEL_PARAMS = {
45
+ "n_mels": 80,
46
+ }
47
+
48
+ to_mel = torchaudio.transforms.MelSpectrogram(
49
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
50
+ mean, std = -4, 4
51
+
52
+ def preprocess(wave):
53
+ wave_tensor = torch.from_numpy(wave).float()
54
+ mel_tensor = to_mel(wave_tensor)
55
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
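+ # log-compression with fixed normalization constants (mean -4, std 4);
+ # the same constants are used at inference time in Preprocess.wave_preprocess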
56
+ return mel_tensor
57
+
58
+ class FilePathDataset(torch.utils.data.Dataset):
59
+ def __init__(self,
60
+ data_list,
61
+ root_path,
62
+ symbol_dict,
63
+ sr=24000,
64
+ data_augmentation=False,
65
+ validation=False,
66
+ debug=True
67
+ ):
68
+
69
+ _data_list = [l.strip().split('|') for l in data_list]
70
+ self.data_list = _data_list #[data if len(data) == 3 else (*data, 0) for data in _data_list] #append speakerid=0 for all
71
+ self.text_cleaner = TextCleaner(symbol_dict, debug)
72
+ self.sr = sr
73
+
74
+ self.df = pd.DataFrame(self.data_list)
75
+
76
+ self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
77
+
78
+ self.mean, self.std = -4, 4
79
+ self.data_augmentation = data_augmentation and (not validation)
80
+ self.max_mel_length = 192
81
+
82
+ self.root_path = root_path
83
+
84
+ def __len__(self):
85
+ return len(self.data_list)
86
+
87
+ def __getitem__(self, idx):
88
+ data = self.data_list[idx]
89
+ path = data[0]
90
+
91
+ wave, text_tensor = self._load_tensor(data)
92
+
93
+ mel_tensor = preprocess(wave).squeeze()
94
+
95
+ acoustic_feature = mel_tensor.squeeze()
96
+ length_feature = acoustic_feature.size(1)
97
+ acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
98
+
99
+ return acoustic_feature, text_tensor, path, wave
100
+
101
+ def _load_tensor(self, data):
102
+ wave_path, text = data
103
+ wave, sr = sf.read(osp.join(self.root_path, wave_path))
104
+ if wave.shape[-1] == 2:
105
+ wave = wave[:, 0].squeeze()
106
+ if sr != 24000:
107
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
108
+ print(wave_path, sr)
109
+
110
+ # Adding half a second padding.
111
+ wave = np.concatenate([np.zeros([12000]), wave, np.zeros([12000])], axis=0)
112
+
113
+ text = self.text_cleaner(text)
114
+
115
+ text.insert(0, 0)
116
+ text.append(0)
117
+
118
+ text = torch.LongTensor(text)
119
+
120
+ return wave, text
121
+
122
+ def _load_data(self, data):
123
+ wave, text_tensor = self._load_tensor(data)
124
+ mel_tensor = preprocess(wave).squeeze()
125
+
126
+ mel_length = mel_tensor.size(1)
127
+ if mel_length > self.max_mel_length:
128
+ random_start = np.random.randint(0, mel_length - self.max_mel_length)
129
+ mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
130
+
131
+ return mel_tensor
132
+
133
+
134
+ class Collater(object):
135
+ """
136
+ Args:
137
+ adaptive_batch_size (bool): if true, decrease batch size when long data comes.
138
+ """
139
+
140
+ def __init__(self, return_wave=False):
141
+ self.text_pad_index = 0
142
+ self.min_mel_length = 192
143
+ self.max_mel_length = 192
144
+ self.return_wave = return_wave
145
+
146
+
147
+ def __call__(self, batch):
148
+ batch_size = len(batch)
149
+
150
+ # sort by mel length
151
+ lengths = [b[0].shape[1] for b in batch]
152
+ batch_indexes = np.argsort(lengths)[::-1]
153
+ batch = [batch[bid] for bid in batch_indexes]
154
+
155
+ nmels = batch[0][0].size(0)
156
+ max_mel_length = max([b[0].shape[1] for b in batch])
157
+ max_text_length = max([b[1].shape[0] for b in batch])
158
+
159
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
160
+ texts = torch.zeros((batch_size, max_text_length)).long()
161
+
162
+ input_lengths = torch.zeros(batch_size).long()
163
+ output_lengths = torch.zeros(batch_size).long()
164
+ paths = ['' for _ in range(batch_size)]
165
+ waves = [None for _ in range(batch_size)]
166
+
167
+ for bid, (mel, text, path, wave) in enumerate(batch):
168
+ mel_size = mel.size(1)
169
+ text_size = text.size(0)
170
+ mels[bid, :, :mel_size] = mel
171
+ texts[bid, :text_size] = text
172
+ input_lengths[bid] = text_size
173
+ output_lengths[bid] = mel_size
174
+ paths[bid] = path
175
+
176
+ waves[bid] = wave
177
+
178
+ return waves, texts, input_lengths, mels, output_lengths
179
+
180
+
181
+ def get_length(wave_path, root_path):
182
+ info = sf.info(osp.join(root_path, wave_path))
183
+ return info.frames * (24000 / info.samplerate)
184
+
185
+ def build_dataloader(path_list,
186
+ root_path,
187
+ symbol_dict,
188
+ validation=False,
189
+ batch_size=4,
190
+ num_workers=1,
191
+ device='cpu',
192
+ collate_config={},
193
+ dataset_config={}):
194
+
195
+ dataset = FilePathDataset(path_list, root_path, symbol_dict, validation=validation, **dataset_config)
196
+ collate_fn = Collater(**collate_config)
197
+
198
+ print("Getting sample lengths...")
199
+
200
+ num_processes = num_workers * 2
201
+ if num_processes != 0:
202
+ list_of_tuples = [(d[0], root_path) for d in dataset.data_list]
203
+ with Pool(processes=num_processes) as pool:
204
+ sample_lengths = pool.starmap(get_length, list_of_tuples, chunksize=16)
205
+ else:
206
+ sample_lengths = []
207
+ for d in dataset.data_list:
208
+ sample_lengths.append(get_length(d[0], root_path))
209
+
210
+ data_loader = torch.utils.data.DataLoader(
211
+ dataset,
212
+ num_workers=num_workers,
213
+ batch_sampler=BatchSampler(
214
+ sample_lengths,
215
+ batch_size,
216
+ shuffle=(not validation),
217
+ drop_last=(not validation),
218
+ num_replicas=1,
219
+ rank=0,
220
+ ),
221
+ collate_fn=collate_fn,
222
+ pin_memory=(device != "cpu"),
223
+ )
224
+
225
+ return data_loader
226
+
227
+ #https://github.com/duerig/StyleTTS2/
228
+ class BatchSampler(torch.utils.data.Sampler):
229
+ def __init__(
230
+ self,
231
+ sample_lengths,
232
+ batch_sizes,
233
+ num_replicas=None,
234
+ rank=None,
235
+ shuffle=True,
236
+ drop_last=False,
237
+ ):
238
+ self.batch_sizes = batch_sizes
239
+ if num_replicas is None:
240
+ self.num_replicas = dist.get_world_size()
241
+ else:
242
+ self.num_replicas = num_replicas
243
+ if rank is None:
244
+ self.rank = dist.get_rank()
245
+ else:
246
+ self.rank = rank
247
+ self.shuffle = shuffle
248
+ self.drop_last = drop_last
249
+
250
+ self.time_bins = {}
251
+ self.epoch = 0
252
+ self.total_len = 0
253
+ self.last_bin = None
254
+
255
+ for i in range(len(sample_lengths)):
256
+ bin_num = self.get_time_bin(sample_lengths[i])
257
+ if bin_num != -1:
258
+ if bin_num not in self.time_bins:
259
+ self.time_bins[bin_num] = []
260
+ self.time_bins[bin_num].append(i)
261
+
262
+ for key in self.time_bins.keys():
263
+ val = self.time_bins[key]
264
+ total_batch = self.batch_sizes * num_replicas
265
+ self.total_len += len(val) // total_batch
266
+ if not self.drop_last and len(val) % total_batch != 0:
267
+ self.total_len += 1
268
+
269
+ def __iter__(self):
270
+ sampler_order = list(self.time_bins.keys())
271
+ sampler_indices = []
272
+
273
+ if self.shuffle:
274
+ sampler_indices = torch.randperm(len(sampler_order)).tolist()
275
+ else:
276
+ sampler_indices = list(range(len(sampler_order)))
277
+
278
+ for index in sampler_indices:
279
+ key = sampler_order[index]
280
+ current_bin = self.time_bins[key]
281
+ dist = torch.utils.data.distributed.DistributedSampler(
282
+ current_bin,
283
+ num_replicas=self.num_replicas,
284
+ rank=self.rank,
285
+ shuffle=self.shuffle,
286
+ drop_last=self.drop_last,
287
+ )
288
+ dist.set_epoch(self.epoch)
289
+ sampler = torch.utils.data.sampler.BatchSampler(
290
+ dist, self.batch_sizes, self.drop_last
291
+ )
292
+ for item_list in sampler:
293
+ self.last_bin = key
294
+ yield [current_bin[i] for i in item_list]
295
+
296
+ def __len__(self):
297
+ return self.total_len
298
+
299
+ def set_epoch(self, epoch):
300
+ self.epoch = epoch
301
+
302
+ def get_time_bin(self, sample_count):
303
+ result = -1
304
+ frames = sample_count // 300
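+ # 300 is the mel hop length, so `frames` is the mel-frame count;
+ # e.g. a 1 s clip at 24 kHz gives 80 frames -> bin (80 - 20) // 20 = 3,
+ # while clips shorter than 20 frames (0.25 s) are skipped (bin -1)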
305
+ if frames >= 20:
306
+ result = (frames - 20) // 20
307
+ return result
libs/models.py ADDED
@@ -0,0 +1,633 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils import weight_norm
5
+
6
+ from .Modules.ASR.models import ASRCNN
7
+ from .Modules.JDC.model import JDCNet
8
+ from .Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator
9
+
10
+ import math
11
+ from munch import Munch
12
+
13
+ class LearnedDownSample(nn.Module):
14
+ def __init__(self, layer_type, dim_in):
15
+ super().__init__()
16
+ self.layer_type = layer_type
17
+
18
+ if self.layer_type == 'none':
19
+ self.conv = nn.Identity()
20
+ elif self.layer_type == 'timepreserve':
21
+ self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0))
22
+ elif self.layer_type == 'half':
23
+ self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1)
24
+ else:
25
+ raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
26
+
27
+ def forward(self, x):
28
+ return self.conv(x)
29
+
30
+ class LearnedUpSample(nn.Module):
31
+ def __init__(self, layer_type, dim_in):
32
+ super().__init__()
33
+ self.layer_type = layer_type
34
+
35
+ if self.layer_type == 'none':
36
+ self.conv = nn.Identity()
37
+ elif self.layer_type == 'timepreserve':
38
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
39
+ elif self.layer_type == 'half':
40
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
41
+ else:
42
+ raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
43
+
44
+
45
+ def forward(self, x):
46
+ return self.conv(x)
47
+
48
+ class DownSample(nn.Module):
49
+ def __init__(self, layer_type):
50
+ super().__init__()
51
+ self.layer_type = layer_type
52
+
53
+ def forward(self, x):
54
+ if self.layer_type == 'none':
55
+ return x
56
+ elif self.layer_type == 'timepreserve':
57
+ return F.avg_pool2d(x, (2, 1))
58
+ elif self.layer_type == 'half':
59
+ if x.shape[-1] % 2 != 0:
60
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
61
+ return F.avg_pool2d(x, 2)
62
+ else:
63
+ raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
64
+
65
+
66
+ class UpSample(nn.Module):
67
+ def __init__(self, layer_type):
68
+ super().__init__()
69
+ self.layer_type = layer_type
70
+
71
+ def forward(self, x):
72
+ if self.layer_type == 'none':
73
+ return x
74
+ elif self.layer_type == 'timepreserve':
75
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
76
+ elif self.layer_type == 'half':
77
+ return F.interpolate(x, scale_factor=2, mode='nearest')
78
+ else:
79
+ raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
80
+
81
+
82
+ class ResBlk(nn.Module):
83
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
84
+ normalize=False, downsample='none'):
85
+ super().__init__()
86
+ self.actv = actv
87
+ self.normalize = normalize
88
+ self.downsample = DownSample(downsample)
89
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
90
+ self.learned_sc = dim_in != dim_out
91
+ self._build_weights(dim_in, dim_out)
92
+
93
+ def _build_weights(self, dim_in, dim_out):
94
+ self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
95
+ self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
96
+ if self.normalize:
97
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
98
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
99
+ if self.learned_sc:
100
+ self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
101
+
102
+ def _shortcut(self, x):
103
+ if self.learned_sc:
104
+ x = self.conv1x1(x)
105
+ if self.downsample:
106
+ x = self.downsample(x)
107
+ return x
108
+
109
+ def _residual(self, x):
110
+ if self.normalize:
111
+ x = self.norm1(x)
112
+ x = self.actv(x)
113
+ x = self.conv1(x)
114
+ x = self.downsample_res(x)
115
+ if self.normalize:
116
+ x = self.norm2(x)
117
+ x = self.actv(x)
118
+ x = self.conv2(x)
119
+ return x
120
+
121
+ def forward(self, x):
122
+ x = self._shortcut(x) + self._residual(x)
123
+ return x / math.sqrt(2) # unit variance
124
+
125
+ class StyleEncoder(nn.Module):
126
+ def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
127
+ super().__init__()
128
+ blocks = []
129
+ blocks += [nn.Conv2d(1, dim_in, 3, 1, 1)]
130
+
131
+ repeat_num = 4
132
+ for _ in range(repeat_num):
133
+ dim_out = min(dim_in*2, max_conv_dim)
134
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
135
+ dim_in = dim_out
136
+
137
+ blocks += [nn.LeakyReLU(0.2)]
138
+ blocks += [nn.Conv2d(dim_out, dim_out, 5, 1, 0)]
139
+ blocks += [nn.AdaptiveAvgPool2d(1)]
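+ # global average pooling collapses the remaining time and frequency axes, so one fixed-size
+ # style vector is produced per reference mel regardless of its length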
140
+ blocks += [nn.LeakyReLU(0.2)]
141
+ self.shared = nn.Sequential(*blocks)
142
+
143
+ self.unshared = nn.Linear(dim_out, style_dim)
144
+
145
+ def forward(self, x):
146
+ h = self.shared(x)
147
+ h = h.view(h.size(0), -1)
148
+ s = self.unshared(h)
149
+
150
+ return s
151
+
152
+ class LinearNorm(torch.nn.Module):
153
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
154
+ super(LinearNorm, self).__init__()
155
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
156
+
157
+ torch.nn.init.xavier_uniform_(
158
+ self.linear_layer.weight,
159
+ gain=torch.nn.init.calculate_gain(w_init_gain))
160
+
161
+ def forward(self, x):
162
+ return self.linear_layer(x)
163
+
164
+ class ResBlk1d(nn.Module):
165
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
166
+ normalize=False, downsample='none', dropout_p=0.2):
167
+ super().__init__()
168
+ self.actv = actv
169
+ self.normalize = normalize
170
+ self.downsample_type = downsample
171
+ self.learned_sc = dim_in != dim_out
172
+ self._build_weights(dim_in, dim_out)
173
+ self.dropout_p = dropout_p
174
+
175
+ if self.downsample_type == 'none':
176
+ self.pool = nn.Identity()
177
+ else:
178
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
179
+
180
+ def _build_weights(self, dim_in, dim_out):
181
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
182
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
183
+ if self.normalize:
184
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
185
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
186
+ if self.learned_sc:
187
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
188
+
189
+ def downsample(self, x):
190
+ if self.downsample_type == 'none':
191
+ return x
192
+ else:
193
+ if x.shape[-1] % 2 != 0:
194
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
195
+ return F.avg_pool1d(x, 2)
196
+
197
+ def _shortcut(self, x):
198
+ if self.learned_sc:
199
+ x = self.conv1x1(x)
200
+ x = self.downsample(x)
201
+ return x
202
+
203
+ def _residual(self, x):
204
+ if self.normalize:
205
+ x = self.norm1(x)
206
+ x = self.actv(x)
207
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
208
+
209
+ x = self.conv1(x)
210
+ x = self.pool(x)
211
+ if self.normalize:
212
+ x = self.norm2(x)
213
+
214
+ x = self.actv(x)
215
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
216
+
217
+ x = self.conv2(x)
218
+ return x
219
+
220
+ def forward(self, x):
221
+ x = self._shortcut(x) + self._residual(x)
222
+ return x / math.sqrt(2) # unit variance
223
+
224
+ class LayerNorm(nn.Module):
225
+ def __init__(self, channels, eps=1e-5):
226
+ super().__init__()
227
+ self.channels = channels
228
+ self.eps = eps
229
+
230
+ self.gamma = nn.Parameter(torch.ones(channels))
231
+ self.beta = nn.Parameter(torch.zeros(channels))
232
+
233
+ def forward(self, x):
234
+ x = x.transpose(1, -1)
235
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
236
+ return x.transpose(1, -1)
237
+
238
+ class TextEncoder(nn.Module):
239
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
240
+ super().__init__()
241
+ self.embedding = nn.Embedding(n_symbols, channels)
242
+
243
+ padding = (kernel_size - 1) // 2
244
+ self.cnn = nn.ModuleList()
245
+ for _ in range(depth):
246
+ self.cnn.append(nn.Sequential(
247
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
248
+ LayerNorm(channels),
249
+ actv,
250
+ nn.Dropout(0.2),
251
+ ))
252
+ # self.cnn = nn.Sequential(*self.cnn)
253
+
254
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
255
+
256
+ def forward(self, x, input_lengths, m):
257
+ x = self.embedding(x) # [B, T, emb]
258
+ x = x.transpose(1, 2) # [B, emb, T]
259
+ m = m.to(input_lengths.device).unsqueeze(1)
260
+ x.masked_fill_(m, 0.0)
261
+
262
+ for c in self.cnn:
263
+ x = c(x)
264
+ x.masked_fill_(m, 0.0)
265
+
266
+ x = x.transpose(1, 2) # [B, T, chn]
267
+
268
+ input_lengths = input_lengths.cpu()
269
+ x = nn.utils.rnn.pack_padded_sequence(
270
+ x, input_lengths, batch_first=True, enforce_sorted=False)
271
+
272
+ self.lstm.flatten_parameters()
273
+ x, _ = self.lstm(x)
274
+ x, _ = nn.utils.rnn.pad_packed_sequence(
275
+ x, batch_first=True)
276
+
277
+ x = x.transpose(-1, -2)
278
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
279
+
280
+ x_pad[:, :, :x.shape[-1]] = x
281
+ x = x_pad.to(x.device)
282
+
283
+ x.masked_fill_(m, 0.0)
284
+
285
+ return x
286
+
287
+ def inference(self, x):
288
+ x = self.embedding(x)
289
+ x = x.transpose(1, 2)
290
+ x = self.cnn(x)
291
+ x = x.transpose(1, 2)
292
+ self.lstm.flatten_parameters()
293
+ x, _ = self.lstm(x)
294
+ return x
295
+
296
+ def length_to_mask(self, lengths):
297
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
298
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
299
+ return mask
300
+
301
+
302
+
303
+ class AdaIN1d(nn.Module):
304
+ def __init__(self, style_dim, num_features):
305
+ super().__init__()
306
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
307
+ self.fc = nn.Linear(style_dim, num_features*2)
308
+
309
+ def forward(self, x, s):
310
+ h = self.fc(s)
311
+ h = h.view(h.size(0), h.size(1), 1)
312
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
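+ # adaptive instance norm: the style vector becomes a per-channel scale (1 + gamma) and shift beta
+ # applied to the instance-normalized features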
313
+ return (1 + gamma) * self.norm(x) + beta
314
+
315
+ class UpSample1d(nn.Module):
316
+ def __init__(self, layer_type):
317
+ super().__init__()
318
+ self.layer_type = layer_type
319
+
320
+ def forward(self, x):
321
+ if self.layer_type == 'none':
322
+ return x
323
+ else:
324
+ return F.interpolate(x, scale_factor=2, mode='nearest')
325
+
326
+ class AdainResBlk1d(nn.Module):
327
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
328
+ upsample='none', dropout_p=0.0):
329
+ super().__init__()
330
+ self.actv = actv
331
+ self.upsample_type = upsample
332
+ self.upsample = UpSample1d(upsample)
333
+ self.learned_sc = dim_in != dim_out
334
+ self._build_weights(dim_in, dim_out, style_dim)
335
+ self.dropout = nn.Dropout(dropout_p)
336
+
337
+ if upsample == 'none':
338
+ self.pool = nn.Identity()
339
+ else:
340
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
341
+
342
+
343
+ def _build_weights(self, dim_in, dim_out, style_dim):
344
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
345
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
346
+ self.norm1 = AdaIN1d(style_dim, dim_in)
347
+ self.norm2 = AdaIN1d(style_dim, dim_out)
348
+ if self.learned_sc:
349
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
350
+
351
+ def _shortcut(self, x):
352
+ x = self.upsample(x)
353
+ if self.learned_sc:
354
+ x = self.conv1x1(x)
355
+ return x
356
+
357
+ def _residual(self, x, s):
358
+ x = self.norm1(x, s)
359
+ x = self.actv(x)
360
+ x = self.pool(x)
361
+ x = self.conv1(self.dropout(x))
362
+ x = self.norm2(x, s)
363
+ x = self.actv(x)
364
+ x = self.conv2(self.dropout(x))
365
+ return x
366
+
367
+ def forward(self, x, s):
368
+ out = self._residual(x, s)
369
+ out = (out + self._shortcut(x)) / math.sqrt(2)
370
+ return out
371
+
372
+ class AdaLayerNorm(nn.Module):
373
+ def __init__(self, style_dim, channels, eps=1e-5):
374
+ super().__init__()
375
+ self.channels = channels
376
+ self.eps = eps
377
+
378
+ self.fc = nn.Linear(style_dim, channels*2)
379
+
380
+ def forward(self, x, s):
381
+ x = x.transpose(-1, -2)
382
+ x = x.transpose(1, -1)
383
+
384
+ h = self.fc(s)
385
+ h = h.view(h.size(0), h.size(1), 1)
386
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
387
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
388
+
389
+
390
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
391
+ x = (1 + gamma) * x + beta
392
+ return x.transpose(1, -1).transpose(-1, -2)
393
+
394
+ class ProsodyPredictor(nn.Module):
395
+
396
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
397
+ super().__init__()
398
+
399
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
400
+ d_model=d_hid,
401
+ nlayers=nlayers,
402
+ dropout=dropout)
403
+
404
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
405
+ self.duration_proj = LinearNorm(d_hid, max_dur)
406
+
407
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
408
+ self.F0 = nn.ModuleList()
409
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
410
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
411
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
412
+
413
+ self.N = nn.ModuleList()
414
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
415
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
416
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
417
+
418
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
419
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
420
+
421
+
422
+ def forward(self, texts, style, text_lengths, alignment, m):
423
+ d = self.text_encoder(texts, style, text_lengths, m)
424
+
425
+ # predict duration
426
+ input_lengths = text_lengths.cpu()
427
+ x = nn.utils.rnn.pack_padded_sequence(
428
+ d, input_lengths, batch_first=True, enforce_sorted=False)
429
+
430
+ m = m.to(text_lengths.device).unsqueeze(1)
431
+
432
+ self.lstm.flatten_parameters()
433
+ x, _ = self.lstm(x)
434
+ x, _ = nn.utils.rnn.pad_packed_sequence(
435
+ x, batch_first=True)
436
+
437
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
438
+
439
+ x_pad[:, :x.shape[1], :] = x
440
+ x = x_pad.to(x.device)
441
+
442
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
443
+
444
+ en = (d.transpose(-1, -2) @ alignment)
445
+
446
+ return duration.squeeze(-1), en
447
+
448
+ def F0Ntrain(self, x, s):
449
+ x, _ = self.shared(x.transpose(-1, -2))
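+ # the F0 and energy (N) heads share these LSTM features; each head contains one 2x upsampling block,
+ # so both curves come out at twice the input frame rate (the decoder's stride-2 F0_conv/N_conv undo this)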
450
+
451
+ F0 = x.transpose(-1, -2)
452
+ for block in self.F0:
453
+ F0 = block(F0, s)
454
+ F0 = self.F0_proj(F0)
455
+
456
+ N = x.transpose(-1, -2)
457
+ for block in self.N:
458
+ N = block(N, s)
459
+ N = self.N_proj(N)
460
+
461
+ return F0.squeeze(1), N.squeeze(1)
462
+
463
+ def length_to_mask(self, lengths):
464
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
465
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
466
+ return mask
467
+
468
+ class DurationEncoder(nn.Module):
469
+
470
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
471
+ super().__init__()
472
+ self.lstms = nn.ModuleList()
473
+ for _ in range(nlayers):
474
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
475
+ d_model // 2,
476
+ num_layers=1,
477
+ batch_first=True,
478
+ bidirectional=True,
479
+ dropout=dropout))
480
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
481
+
482
+
483
+ self.dropout = dropout
484
+ self.d_model = d_model
485
+ self.sty_dim = sty_dim
486
+
487
+ def forward(self, x, style, text_lengths, m):
488
+ masks = m.to(text_lengths.device)
489
+
490
+ x = x.permute(2, 0, 1)
491
+ s = style.expand(x.shape[0], x.shape[1], -1)
492
+ x = torch.cat([x, s], axis=-1)
493
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
494
+
495
+ x = x.transpose(0, 1)
496
+ input_lengths = text_lengths.cpu()
497
+ x = x.transpose(-1, -2)
498
+
499
+ for block in self.lstms:
500
+ if isinstance(block, AdaLayerNorm):
501
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
502
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
503
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
504
+ else:
505
+ x = x.transpose(-1, -2)
506
+ x = nn.utils.rnn.pack_padded_sequence(
507
+ x, input_lengths, batch_first=True, enforce_sorted=False)
508
+ block.flatten_parameters()
509
+ x, _ = block(x)
510
+ x, _ = nn.utils.rnn.pad_packed_sequence(
511
+ x, batch_first=True)
512
+ x = F.dropout(x, p=self.dropout, training=self.training)
513
+ x = x.transpose(-1, -2)
514
+
515
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
516
+
517
+ x_pad[:, :, :x.shape[-1]] = x
518
+ x = x_pad.to(x.device)
519
+
520
+ return x.transpose(-1, -2)
521
+
522
+ def inference(self, x, style):
523
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
524
+ style = style.expand(x.shape[0], x.shape[1], -1)
525
+ x = torch.cat([x, style], axis=-1)
526
+ src = self.pos_encoder(x)
527
+ output = self.transformer_encoder(src).transpose(0, 1)
528
+ return output
529
+
530
+ def length_to_mask(self, lengths):
531
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
532
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
533
+ return mask
534
+
535
+ def build_model(args):
536
+ assert args.decoder.type in ['istftnet', 'hifigan', 'vocos'], 'Decoder type unknown'
537
+
538
+ if args.decoder.type == "istftnet":
539
+ from Modules.istftnet import Decoder
540
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
541
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
542
+ upsample_rates = args.decoder.upsample_rates,
543
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
544
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
545
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
546
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
547
+ elif args.decoder.type == "hifigan":
548
+ from Modules.hifigan import Decoder
549
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
550
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
551
+ upsample_rates = args.decoder.upsample_rates,
552
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
553
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
554
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
555
+ elif args.decoder.type == "vocos":
556
+ from Modules.vocos import Decoder
557
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
558
+ intermediate_dim=args.decoder.intermediate_dim,
559
+ num_layers=args.decoder.num_layers,
560
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft,
561
+ gen_istft_hop_size=args.decoder.gen_istft_hop_size)
562
+
563
+ nets = Munch(
564
+ decoder = decoder,
565
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout),
566
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token),
567
+ style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim),# acoustic style encoder
568
+ text_aligner = ASRCNN(input_dim=args.ASR_params.input_dim, hidden_dim=args.ASR_params.hidden_dim, n_token=args.n_token,
569
+ n_layers=args.ASR_params.n_layers, token_embedding_dim=args.ASR_params.token_embedding_dim), #ASR
570
+ pitch_extractor = JDCNet(num_class=args.JDC_params.num_class, seq_len=args.JDC_params.seq_len), #F0
571
+
572
+ mpd = MultiPeriodDiscriminator(),
573
+ msd = MultiResSpecDiscriminator(),
574
+ )
575
+
576
+ return nets
577
+
578
+ def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[], freeze_modules=[]):
579
+ print("\n")
580
+ state = torch.load(path, map_location='cpu')
581
+ params = state['net']
582
+
583
+ for key in model:
584
+ loaded_keys = list(params[key].keys())
585
+ loaded_has_module = loaded_keys[0].startswith('module.')
586
+ model_keys = list(model[key].state_dict().keys())
587
+ model_has_module = model_keys[0].startswith('module.')
588
+
589
+ if key in params and key not in ignore_modules:
590
+ try:
591
+ model[key].load_state_dict(params[key], strict=True)
592
+ except Exception as e:
593
+ from collections import OrderedDict
594
+ state_dict = params[key]
595
+ new_state_dict = OrderedDict()
596
+ if not loaded_has_module and model_has_module:
597
+ print("Loading non-DP weights into DP model")
598
+ #Add module
599
+ for k, v in state_dict.items():
600
+ # If key already has module. leave it otherwise add it
601
+ new_key = k if k.startswith('module.') else 'module.' + k
602
+ new_state_dict[new_key] = v
603
+ model[key].load_state_dict(new_state_dict, strict=True)# load params
604
+ elif loaded_has_module and not model_has_module:
605
+ print("Loading DP weights into non-DP model")
606
+ #Remove module
607
+ for k, v in state_dict.items():
608
+ name = k[7:] # remove `module.`
609
+ new_state_dict[name] = v
610
+ model[key].load_state_dict(new_state_dict, strict=True)# load params
611
+ else:
612
+ print(e)
613
+ print('%s Loaded' % key)
614
+ if key in freeze_modules:
615
+ for param in model[key].parameters():
616
+ param.requires_grad = False
617
+ print('%s Frozen' % key)
618
+ if key in ignore_modules:
619
+ print('%s Ignored' % key)
620
+
621
+ _ = [model[key].eval() for key in model]
622
+
623
+ if not load_only_params:
624
+ print('\nLoading old optimizer')
625
+ epoch = state["epoch"]
626
+ iters = state["iters"]
627
+ optimizer.load_state_dict(state["optimizer"])
628
+ else:
629
+ print('\nNOT Loading old optimizer')
630
+ epoch = 0
631
+ iters = 0
632
+
633
+ return model, optimizer, epoch, iters
model.py ADDED
@@ -0,0 +1,207 @@
1
+ # styletts_plugin.py
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ import yaml
6
+ import torch
7
+ import phonemizer
8
+ from phonemizer.backend.espeak.wrapper import EspeakWrapper
9
+ import soundfile as sf
10
+ import httpx
11
+ import nltk
12
+ import subprocess
13
+ from libs.inference import StyleTTS2
14
+
15
+ try:
16
+ nltk.data.find('tokenizers/punkt_tab')
17
+ except LookupError:
18
+ print("Đang tải NLTK tokenizer 'punkt_tab'...")
19
+ nltk.download('punkt_tab')
20
+ print("Tải thành công.")
21
+
22
+ class StyleTTModel():
23
+ def __init__(self, **kwargs):
24
+ self.model_weights_path = "models/base_model.pth"
25
+ self.model_config_path = "models/config.yaml"
26
+
27
+ self.speaker_wav = kwargs.get("speaker_wav", "speakers/example_female.wav")
28
+ self.language = kwargs.get("language", "en-us")
29
+ self.speed = kwargs.get("speed", 1.0)
30
+ self.denoise = kwargs.get("denoise", 0.2)
31
+ self.avg_style = kwargs.get("avg_style", True)
32
+ self.stabilize = kwargs.get("stabilize", True)
33
+ self.device = self._get_device()
34
+
35
+ self.sample_rate = 24000
36
+ self.model = None
37
+
38
+ def _get_device(self):
39
+ if torch.cuda.is_available():
40
+ return "cuda"
41
+ return "cpu"
42
+
43
+ def _download_file(self, url: str, destination: str):
44
+ print(f"Đang tải file từ {url}...")
45
+ try:
46
+ os.makedirs(os.path.dirname(destination), exist_ok=True)
47
+ with httpx.stream("GET", url, follow_redirects=True, timeout=30) as r:
48
+ r.raise_for_status()
49
+ with open(destination, 'wb') as f:
50
+ for chunk in r.iter_bytes(chunk_size=8192):
51
+ f.write(chunk)
52
+ print(f"Tải thành công và lưu tại: {destination}")
53
+ except Exception as e:
54
+ print(f"Lỗi khi tải file bằng httpx: {e}")
55
+ raise
56
+
57
+ def _phonemize(self, text: str, lang: str) -> str:
58
+ # Create a new phonemizer instance on each call to keep this thread-safe
59
+
60
+ if sys.platform == 'darwin':
61
+ try:
62
+ # Use brew to look up the espeak-ng installation prefix safely
63
+ result = subprocess.run(['brew', '--prefix', 'espeak-ng'], capture_output=True, text=True, check=True)
64
+ espeak_ng_prefix = result.stdout.strip()
65
+
66
+ # Build the path to the dynamic library file (.dylib)
67
+ # This is much more robust than hard-coding the version
68
+ espeak_lib_path = os.path.join(espeak_ng_prefix, 'lib', 'libespeak-ng.dylib')
69
+
70
+ if os.path.exists(espeak_lib_path):
71
+ EspeakWrapper.set_library(espeak_lib_path)
72
+ print(f"✅ Đã tự động tìm và cấu hình eSpeak NG cho macOS tại: {espeak_lib_path}")
73
+ else:
74
+ print(f"⚠️ Không tìm thấy file thư viện tại {espeak_lib_path}. Hãy chắc chắn bạn đã cài espeak-ng qua Homebrew.")
75
+
76
+ except (subprocess.CalledProcessError, FileNotFoundError):
77
+ print("🛑 Lỗi: Không thể chạy lệnh 'brew'. Hãy chắc chắn Homebrew và espeak-ng đã được cài đặt đúng cách.")
78
+ print(" Chạy lệnh 'brew install espeak-ng' trong terminal.")
79
+
80
+ elif sys.platform == 'win32':
81
+ try:
82
+ import espeakng_loader
83
+ EspeakWrapper.set_library(espeakng_loader.get_library_path())
84
+ EspeakWrapper.data_path = espeakng_loader.get_data_path()
85
+ except ImportError:
86
+ print("Cảnh báo: Không tìm thấy espeakng_loader.")
87
+
88
+ phonemizer_instance = phonemizer.backend.EspeakBackend(
89
+ language=lang, preserve_punctuation=True, with_stress=True
90
+ )
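+ # e.g. (illustrative) _phonemize("Hello world", "en-us") returns something like "həlˈoʊ wˈɜːld";
+ # the exact IPA depends on the installed espeak-ng version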
91
+ return phonemizer_instance.phonemize([text])[0]
92
+
93
+ def cache_speaker_style(self, speaker_wav: str):
94
+ """
95
+ Tính toán và cache style của một giọng nói để tái sử dụng.
96
+ Hàm này nên được gọi một lần khi bắt đầu cuộc hội thoại.
97
+ """
98
+ if self.model is None:
99
+ self.load()
100
+
101
+ print(f"-> Đang tính toán và cache style cho giọng nói: {speaker_wav}")
102
+ speaker_info = {"path": speaker_wav, "speed": self.speed} # Tốc độ có thể không cần ở đây
103
+
104
+ # Sử dụng các tham số mặc định của plugin để cache
105
+ with torch.no_grad():
106
+ self.cached_style = self.model.get_styles(
107
+ speaker_info,
108
+ denoise=self.denoise,
109
+ avg_style=self.avg_style
110
+ )
111
+ print("-> Cache style thành công.")
112
+
113
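+     # load() fetches the config and base checkpoint from the StyleTTS2-lite repo
+     # on first run, moves the model to the selected device, caches the default
+     # speaker's style, and warms up the phonemizer.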
+     def load(self):
+         print("Initializing the StyleTTS PyTorch plugin...")
+         if not os.path.exists(self.model_config_path):
+             config_url = "https://huggingface.co/dangtr0408/StyleTTS2-lite/resolve/main/Models/config.yaml"
+             self._download_file(config_url, self.model_config_path)
+         if not os.path.exists(self.model_weights_path):
+             weights_url = "https://huggingface.co/dangtr0408/StyleTTS2-lite/resolve/main/Models/base_model.pth"
+             self._download_file(weights_url, self.model_weights_path)
+
+         print("\nLoading the PyTorch model into memory...")
+         self.model = StyleTTS2(self.model_config_path, self.model_weights_path)
+         self.model.eval()
+         self.model.to(self.device)
+         print(f"StyleTTS PyTorch plugin loaded successfully on device {self.device}.")
+
+         # Automatically cache the style of the default voice
+         print(f"-> Automatically computing and caching style for voice: {self.speaker_wav}")
+         try:
+             speaker_info = {"path": self.speaker_wav, "speed": self.speed}
+             with torch.no_grad():
+                 self.cached_style = self.model.get_styles(
+                     speaker_info,
+                     denoise=self.denoise,
+                     avg_style=self.avg_style
+                 )
+             print("-> Style cached successfully.")
+         except Exception as e:
+             print(f"-> WARNING: could not cache style. Error: {e}")
+             self.cached_style = None
+
+         # Warm up the phonemizer
+         print("-> Warming up the phonemizer...")
+         try:
+             self._phonemize("warm-up", self.language)
+             print("-> Phonemizer warm-up succeeded.")
+         except Exception as e:
+             print(f"-> Warning: phonemizer warm-up failed: {e}")
+
+         return self
+
+     def synthesize(self, text: str, **kwargs) -> np.ndarray:
+         if self.model is None:
+             self.load()
+
+         language = kwargs.get("language", self.language)
+         speed = kwargs.get("speed", self.speed)
+         stabilize = kwargs.get("stabilize", self.stabilize)
+
+         if not hasattr(self, 'cached_style') or self.cached_style is None:
+             print("Warning: style has not been cached. Recomputing...")
+             speaker_wav = kwargs.get("speaker_wav", self.speaker_wav)
+             speaker_info = {"path": speaker_wav, "speed": speed}
+             styles = self.model.get_styles(speaker_info, denoise=kwargs.get("denoise", self.denoise), avg_style=kwargs.get("avg_style", self.avg_style))
+         else:
+             styles = self.cached_style
+             styles['speed'] = speed
+
+         with torch.no_grad():
+             phonemes = self._phonemize(text, language)
+             wav = self.model.generate(phonemes, styles, stabilize=stabilize)
+             wav = wav / np.max(np.abs(wav))
+
+         return wav.astype(np.float32)
+
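+ # Minimal usage sketch (assumes the model files can be downloaded and the
+ # speaker wav exists; "hello.wav" is just an example output path):
+ #     tts = StyleTTModel(speaker_wav="speakers/example_female.wav").load()
+ #     wav = tts.synthesize("Hello there.", speed=1.1)
+ #     sf.write("hello.wav", wav, tts.sample_rate)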
+ if __name__ == "__main__":
+     SPEAKER_WAV_PATH = "speakers/example_female.wav"
+     if not os.path.exists(SPEAKER_WAV_PATH):
+         print(f"Error: sample audio file not found at '{SPEAKER_WAV_PATH}'.")
+     else:
+         # Initialize the plugin
+         styletts_utils = StyleTTModel(speaker_wav=SPEAKER_WAV_PATH)
+         styletts_utils.load()  # Load the model first
+         print("\n" + "="*50)
+         print("🔍 RUNTIME DEVICE CHECK")
+
+         # 1. Can PyTorch "see" a GPU?
+         cuda_available = torch.cuda.is_available()
+         print(f" - Does PyTorch find CUDA?        : {cuda_available}")
+         if styletts_utils.model:
+             model_device = next(styletts_utils.model.parameters()).device
+             print(f" - Which device is the model on? : {model_device}")
+             if "cuda" in str(model_device):
+                 print("\n>>> CONCLUSION: ✅ The model is running on the GPU.")
+             else:
+                 print("\n>>> CONCLUSION: ❌ The model is running on the CPU.")
+         else:
+             print(" - The model has not been loaded.")
+         print("="*50)
+         print("\n--- Synthesis test ---")
+         long_text = "StyleTTS 2 is a text-to-speech model that offers zero-shot speaker adaptation."
+         audio = styletts_utils.synthesize(long_text)
+
+         output_path = "plugin_pytorch_output.wav"
+         # Write the waveform to disk with soundfile at the model's sample rate
+         sf.write(output_path, audio, styletts_utils.sample_rate)
+         print(f"✅ Audio saved successfully to: {output_path}")
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ transformers
+ torch
+ soundfile
+ numpy
+ scipy
+ gradio
+ librosa
+ matplotlib
+ phonemizer
+ nltk
+ httpx
+ pyyaml
speakers/example_female.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:89a4fa9a16b6463f852cf9424f72c3d3c87aa83010e89db534c53fcd1ae12c02
+ size 1002030