Spaces:
Running
on
Zero
Running
on
Zero
zhzluke96
commited on
Commit
·
32b2aaa
1
Parent(s):
2ca1c87
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- modules/Enhancer/ResembleEnhance.py +3 -8
- modules/repos_static/__init__.py +0 -0
- modules/repos_static/readme.md +5 -0
- modules/repos_static/resemble_enhance/__init__.py +0 -0
- modules/repos_static/resemble_enhance/common.py +55 -0
- modules/repos_static/resemble_enhance/data/__init__.py +48 -0
- modules/repos_static/resemble_enhance/data/dataset.py +171 -0
- modules/repos_static/resemble_enhance/data/distorter/__init__.py +1 -0
- modules/repos_static/resemble_enhance/data/distorter/base.py +104 -0
- modules/repos_static/resemble_enhance/data/distorter/custom.py +85 -0
- modules/repos_static/resemble_enhance/data/distorter/distorter.py +32 -0
- modules/repos_static/resemble_enhance/data/distorter/sox.py +176 -0
- modules/repos_static/resemble_enhance/data/utils.py +43 -0
- modules/repos_static/resemble_enhance/denoiser/__init__.py +0 -0
- modules/repos_static/resemble_enhance/denoiser/__main__.py +30 -0
- modules/repos_static/resemble_enhance/denoiser/denoiser.py +181 -0
- modules/repos_static/resemble_enhance/denoiser/hparams.py +9 -0
- modules/repos_static/resemble_enhance/denoiser/inference.py +31 -0
- modules/repos_static/resemble_enhance/denoiser/unet.py +144 -0
- modules/repos_static/resemble_enhance/enhancer/__init__.py +0 -0
- modules/repos_static/resemble_enhance/enhancer/__main__.py +129 -0
- modules/repos_static/resemble_enhance/enhancer/download.py +30 -0
- modules/repos_static/resemble_enhance/enhancer/enhancer.py +185 -0
- modules/repos_static/resemble_enhance/enhancer/hparams.py +23 -0
- modules/repos_static/resemble_enhance/enhancer/inference.py +48 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py +2 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py +372 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py +123 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py +152 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py +147 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py +1 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py +5 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py +95 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py +49 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/amp.py +101 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py +210 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py +281 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py +128 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py +94 -0
- modules/repos_static/resemble_enhance/hparams.py +128 -0
- modules/repos_static/resemble_enhance/inference.py +163 -0
- modules/repos_static/resemble_enhance/melspec.py +61 -0
- modules/repos_static/resemble_enhance/utils/__init__.py +2 -0
- modules/repos_static/resemble_enhance/utils/control.py +26 -0
- modules/repos_static/resemble_enhance/utils/logging.py +38 -0
- modules/repos_static/resemble_enhance/utils/utils.py +73 -0
- modules/speaker.py +4 -0
- modules/webui/speaker/__init__.py +0 -0
- modules/webui/speaker/speaker_creator.py +171 -0
- modules/webui/speaker/speaker_merger.py +255 -0
modules/Enhancer/ResembleEnhance.py
CHANGED
@@ -1,13 +1,8 @@
|
|
1 |
import os
|
2 |
from typing import List
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
from resemble_enhance.enhancer.hparams import HParams
|
7 |
-
from resemble_enhance.inference import inference
|
8 |
-
except:
|
9 |
-
HParams = dict
|
10 |
-
Enhancer = dict
|
11 |
|
12 |
import torch
|
13 |
|
|
|
1 |
import os
|
2 |
from typing import List
|
3 |
+
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
|
4 |
+
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
|
5 |
+
from modules.repos_static.resemble_enhance.inference import inference
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
import torch
|
8 |
|
modules/repos_static/__init__.py
ADDED
File without changes
|
modules/repos_static/readme.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# repos static
|
2 |
+
|
3 |
+
## resemble_enhance
|
4 |
+
|
5 |
+
https://github.com/resemble-ai/resemble-enhance/tree/main
|
modules/repos_static/resemble_enhance/__init__.py
ADDED
File without changes
|
modules/repos_static/resemble_enhance/common.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
class Normalizer(nn.Module):
|
10 |
+
def __init__(self, momentum=0.01, eps=1e-9):
|
11 |
+
super().__init__()
|
12 |
+
self.momentum = momentum
|
13 |
+
self.eps = eps
|
14 |
+
self.running_mean_unsafe: Tensor
|
15 |
+
self.running_var_unsafe: Tensor
|
16 |
+
self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
|
17 |
+
self.register_buffer("running_var_unsafe", torch.full([], torch.nan))
|
18 |
+
|
19 |
+
@property
|
20 |
+
def started(self):
|
21 |
+
return not torch.isnan(self.running_mean_unsafe)
|
22 |
+
|
23 |
+
@property
|
24 |
+
def running_mean(self):
|
25 |
+
if not self.started:
|
26 |
+
return torch.zeros_like(self.running_mean_unsafe)
|
27 |
+
return self.running_mean_unsafe
|
28 |
+
|
29 |
+
@property
|
30 |
+
def running_std(self):
|
31 |
+
if not self.started:
|
32 |
+
return torch.ones_like(self.running_var_unsafe)
|
33 |
+
return (self.running_var_unsafe + self.eps).sqrt()
|
34 |
+
|
35 |
+
@torch.no_grad()
|
36 |
+
def _ema(self, a: Tensor, x: Tensor):
|
37 |
+
return (1 - self.momentum) * a + self.momentum * x
|
38 |
+
|
39 |
+
def update_(self, x):
|
40 |
+
if not self.started:
|
41 |
+
self.running_mean_unsafe = x.mean()
|
42 |
+
self.running_var_unsafe = x.var()
|
43 |
+
else:
|
44 |
+
self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
|
45 |
+
self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
|
46 |
+
|
47 |
+
def forward(self, x: Tensor, update=True):
|
48 |
+
if self.training and update:
|
49 |
+
self.update_(x)
|
50 |
+
self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
|
51 |
+
x = (x - self.running_mean) / self.running_std
|
52 |
+
return x
|
53 |
+
|
54 |
+
def inverse(self, x: Tensor):
|
55 |
+
return x * self.running_std + self.running_mean
|
modules/repos_static/resemble_enhance/data/__init__.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
from ..hparams import HParams
|
7 |
+
from .dataset import Dataset
|
8 |
+
from .utils import mix_fg_bg, rglob_audio_files
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
|
14 |
+
paths = rglob_audio_files(hp.fg_dir)
|
15 |
+
logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")
|
16 |
+
|
17 |
+
random.Random(seed).shuffle(paths)
|
18 |
+
train_paths = paths[:-val_size]
|
19 |
+
val_paths = paths[-val_size:]
|
20 |
+
|
21 |
+
train_ds = Dataset(train_paths, hp, training=True, mode=mode)
|
22 |
+
val_ds = Dataset(val_paths, hp, training=False, mode=mode)
|
23 |
+
|
24 |
+
logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")
|
25 |
+
|
26 |
+
return train_ds, val_ds
|
27 |
+
|
28 |
+
|
29 |
+
def create_dataloaders(hp: HParams, mode):
|
30 |
+
train_ds, val_ds = _create_datasets(hp=hp, mode=mode)
|
31 |
+
|
32 |
+
train_dl = DataLoader(
|
33 |
+
train_ds,
|
34 |
+
batch_size=hp.batch_size_per_gpu,
|
35 |
+
shuffle=True,
|
36 |
+
num_workers=hp.nj,
|
37 |
+
drop_last=True,
|
38 |
+
collate_fn=train_ds.collate_fn,
|
39 |
+
)
|
40 |
+
val_dl = DataLoader(
|
41 |
+
val_ds,
|
42 |
+
batch_size=1,
|
43 |
+
shuffle=False,
|
44 |
+
num_workers=hp.nj,
|
45 |
+
drop_last=False,
|
46 |
+
collate_fn=val_ds.collate_fn,
|
47 |
+
)
|
48 |
+
return train_dl, val_dl
|
modules/repos_static/resemble_enhance/data/dataset.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchaudio
|
8 |
+
import torchaudio.functional as AF
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
+
from torch.utils.data import Dataset as DatasetBase
|
11 |
+
|
12 |
+
from ..hparams import HParams
|
13 |
+
from .distorter import Distorter
|
14 |
+
from .utils import rglob_audio_files
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
def _normalize(x):
|
20 |
+
return x / (np.abs(x).max() + 1e-7)
|
21 |
+
|
22 |
+
|
23 |
+
def _collate(batch, key, tensor=True, pad=True):
|
24 |
+
l = [d[key] for d in batch]
|
25 |
+
if l[0] is None:
|
26 |
+
return None
|
27 |
+
if tensor:
|
28 |
+
l = [torch.from_numpy(x) for x in l]
|
29 |
+
if pad:
|
30 |
+
assert tensor, "Can't pad non-tensor"
|
31 |
+
l = pad_sequence(l, batch_first=True)
|
32 |
+
return l
|
33 |
+
|
34 |
+
|
35 |
+
def praat_augment(wav, sr):
|
36 |
+
try:
|
37 |
+
import parselmouth
|
38 |
+
except ImportError:
|
39 |
+
raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
|
40 |
+
# "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
|
41 |
+
# https://github.com/YannickJadoul/Parselmouth/issues/68
|
42 |
+
# note that this function may hang if the praat version is 0.4.3
|
43 |
+
assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
|
44 |
+
sound = parselmouth.Sound(wav, sr)
|
45 |
+
formant_shift_ratio = random.uniform(1.1, 1.5)
|
46 |
+
pitch_range_factor = random.uniform(0.5, 2.0)
|
47 |
+
sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
|
48 |
+
wav = np.array(sound.values)[0].astype(np.float32)
|
49 |
+
return wav
|
50 |
+
|
51 |
+
|
52 |
+
class Dataset(DatasetBase):
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
fg_paths: list[Path],
|
56 |
+
hp: HParams,
|
57 |
+
training=True,
|
58 |
+
max_retries=100,
|
59 |
+
silent_fg_prob=0.01,
|
60 |
+
mode=False,
|
61 |
+
):
|
62 |
+
super().__init__()
|
63 |
+
|
64 |
+
assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"
|
65 |
+
|
66 |
+
self.hp = hp
|
67 |
+
self.fg_paths = fg_paths
|
68 |
+
self.bg_paths = rglob_audio_files(hp.bg_dir)
|
69 |
+
|
70 |
+
if len(self.fg_paths) == 0:
|
71 |
+
raise ValueError(f"No foreground audio files found in {hp.fg_dir}")
|
72 |
+
|
73 |
+
if len(self.bg_paths) == 0:
|
74 |
+
raise ValueError(f"No background audio files found in {hp.bg_dir}")
|
75 |
+
|
76 |
+
logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
|
77 |
+
|
78 |
+
self.training = training
|
79 |
+
self.max_retries = max_retries
|
80 |
+
self.silent_fg_prob = silent_fg_prob
|
81 |
+
|
82 |
+
self.mode = mode
|
83 |
+
self.distorter = Distorter(hp, training=training, mode=mode)
|
84 |
+
|
85 |
+
def _load_wav(self, path, length=None, random_crop=True):
|
86 |
+
wav, sr = torchaudio.load(path)
|
87 |
+
|
88 |
+
wav = AF.resample(
|
89 |
+
waveform=wav,
|
90 |
+
orig_freq=sr,
|
91 |
+
new_freq=self.hp.wav_rate,
|
92 |
+
lowpass_filter_width=64,
|
93 |
+
rolloff=0.9475937167399596,
|
94 |
+
resampling_method="sinc_interp_kaiser",
|
95 |
+
beta=14.769656459379492,
|
96 |
+
)
|
97 |
+
|
98 |
+
wav = wav.float().numpy()
|
99 |
+
|
100 |
+
if wav.ndim == 2:
|
101 |
+
wav = np.mean(wav, axis=0)
|
102 |
+
|
103 |
+
if length is None and self.training:
|
104 |
+
length = int(self.hp.training_seconds * self.hp.wav_rate)
|
105 |
+
|
106 |
+
if length is not None:
|
107 |
+
if random_crop:
|
108 |
+
start = random.randint(0, max(0, len(wav) - length))
|
109 |
+
wav = wav[start : start + length]
|
110 |
+
else:
|
111 |
+
wav = wav[:length]
|
112 |
+
|
113 |
+
if length is not None and len(wav) < length:
|
114 |
+
wav = np.pad(wav, (0, length - len(wav)))
|
115 |
+
|
116 |
+
wav = _normalize(wav)
|
117 |
+
|
118 |
+
return wav
|
119 |
+
|
120 |
+
def _getitem_unsafe(self, index: int):
|
121 |
+
fg_path = self.fg_paths[index]
|
122 |
+
|
123 |
+
if self.training and random.random() < self.silent_fg_prob:
|
124 |
+
fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
|
125 |
+
else:
|
126 |
+
fg_wav = self._load_wav(fg_path)
|
127 |
+
if random.random() < self.hp.praat_augment_prob and self.training:
|
128 |
+
fg_wav = praat_augment(fg_wav, self.hp.wav_rate)
|
129 |
+
|
130 |
+
if self.hp.load_fg_only:
|
131 |
+
bg_wav = None
|
132 |
+
fg_dwav = None
|
133 |
+
bg_dwav = None
|
134 |
+
else:
|
135 |
+
fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
|
136 |
+
if self.training:
|
137 |
+
bg_path = random.choice(self.bg_paths)
|
138 |
+
else:
|
139 |
+
# Deterministic for validation
|
140 |
+
bg_path = self.bg_paths[index % len(self.bg_paths)]
|
141 |
+
bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
|
142 |
+
bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
|
143 |
+
|
144 |
+
return dict(
|
145 |
+
fg_wav=fg_wav,
|
146 |
+
bg_wav=bg_wav,
|
147 |
+
fg_dwav=fg_dwav,
|
148 |
+
bg_dwav=bg_dwav,
|
149 |
+
)
|
150 |
+
|
151 |
+
def __getitem__(self, index: int):
|
152 |
+
for i in range(self.max_retries):
|
153 |
+
try:
|
154 |
+
return self._getitem_unsafe(index)
|
155 |
+
except Exception as e:
|
156 |
+
if i == self.max_retries - 1:
|
157 |
+
raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
|
158 |
+
logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
|
159 |
+
index = np.random.randint(0, len(self))
|
160 |
+
|
161 |
+
def __len__(self):
|
162 |
+
return len(self.fg_paths)
|
163 |
+
|
164 |
+
@staticmethod
|
165 |
+
def collate_fn(batch):
|
166 |
+
return dict(
|
167 |
+
fg_wavs=_collate(batch, "fg_wav"),
|
168 |
+
bg_wavs=_collate(batch, "bg_wav"),
|
169 |
+
fg_dwavs=_collate(batch, "fg_dwav"),
|
170 |
+
bg_dwavs=_collate(batch, "bg_dwav"),
|
171 |
+
)
|
modules/repos_static/resemble_enhance/data/distorter/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .distorter import Distorter
|
modules/repos_static/resemble_enhance/data/distorter/base.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import time
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
_DEBUG = bool(os.environ.get("DEBUG", False))
|
10 |
+
|
11 |
+
|
12 |
+
class Effect:
|
13 |
+
def apply(self, wav: np.ndarray, sr: int):
|
14 |
+
"""
|
15 |
+
Args:
|
16 |
+
wav: (T)
|
17 |
+
sr: sample rate
|
18 |
+
Returns:
|
19 |
+
wav: (T) with the same sample rate of `sr`
|
20 |
+
"""
|
21 |
+
raise NotImplementedError
|
22 |
+
|
23 |
+
def __call__(self, wav: np.ndarray, sr: int):
|
24 |
+
"""
|
25 |
+
Args:
|
26 |
+
wav: (T)
|
27 |
+
sr: sample rate
|
28 |
+
Returns:
|
29 |
+
wav: (T) with the same sample rate of `sr`
|
30 |
+
"""
|
31 |
+
assert len(wav.shape) == 1, wav.shape
|
32 |
+
|
33 |
+
if _DEBUG:
|
34 |
+
start = time.time()
|
35 |
+
else:
|
36 |
+
start = None
|
37 |
+
|
38 |
+
shape = wav.shape
|
39 |
+
assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
|
40 |
+
wav = self.apply(wav, sr)
|
41 |
+
assert shape == wav.shape, f"{self}: {shape} != {wav.shape}."
|
42 |
+
|
43 |
+
if start is not None:
|
44 |
+
end = time.time()
|
45 |
+
print(f"{self.__class__.__name__}: {end - start:.3f} sec")
|
46 |
+
|
47 |
+
return wav
|
48 |
+
|
49 |
+
|
50 |
+
class Chain(Effect):
|
51 |
+
def __init__(self, *effects):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.effects = effects
|
55 |
+
|
56 |
+
def apply(self, wav, sr):
|
57 |
+
for effect in self.effects:
|
58 |
+
wav = effect(wav, sr)
|
59 |
+
return wav
|
60 |
+
|
61 |
+
|
62 |
+
class Maybe(Effect):
|
63 |
+
def __init__(self, prob, effect):
|
64 |
+
super().__init__()
|
65 |
+
|
66 |
+
self.prob = prob
|
67 |
+
self.effect = effect
|
68 |
+
|
69 |
+
if _DEBUG:
|
70 |
+
warnings.warn("DEBUG mode is on. Maybe -> Must.")
|
71 |
+
self.prob = 1
|
72 |
+
|
73 |
+
def apply(self, wav, sr):
|
74 |
+
if random.random() > self.prob:
|
75 |
+
return wav
|
76 |
+
return self.effect(wav, sr)
|
77 |
+
|
78 |
+
|
79 |
+
class Choice(Effect):
|
80 |
+
def __init__(self, *effects, **kwargs):
|
81 |
+
super().__init__()
|
82 |
+
self.effects = effects
|
83 |
+
self.kwargs = kwargs
|
84 |
+
|
85 |
+
def apply(self, wav, sr):
|
86 |
+
return np.random.choice(self.effects, **self.kwargs)(wav, sr)
|
87 |
+
|
88 |
+
|
89 |
+
class Permutation(Effect):
|
90 |
+
def __init__(self, *effects, n: int | None = None):
|
91 |
+
super().__init__()
|
92 |
+
self.effects = effects
|
93 |
+
self.n = n
|
94 |
+
|
95 |
+
def apply(self, wav, sr):
|
96 |
+
if self.n is None:
|
97 |
+
n = np.random.binomial(len(self.effects), 0.5)
|
98 |
+
else:
|
99 |
+
n = self.n
|
100 |
+
if n == 0:
|
101 |
+
return wav
|
102 |
+
perms = itertools.permutations(self.effects, n)
|
103 |
+
effects = random.choice(list(perms))
|
104 |
+
return Chain(*effects)(wav, sr)
|
modules/repos_static/resemble_enhance/data/distorter/custom.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from functools import cached_property
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
from scipy import signal
|
10 |
+
|
11 |
+
from ..utils import walk_paths
|
12 |
+
from .base import Effect
|
13 |
+
|
14 |
+
_logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class RandomRIR(Effect):
|
19 |
+
rir_dir: Path | None
|
20 |
+
rir_rate: int = 44_000
|
21 |
+
rir_suffix: str = ".npy"
|
22 |
+
deterministic: bool = False
|
23 |
+
|
24 |
+
@cached_property
|
25 |
+
def rir_paths(self):
|
26 |
+
if self.rir_dir is None:
|
27 |
+
return []
|
28 |
+
return list(walk_paths(self.rir_dir, self.rir_suffix))
|
29 |
+
|
30 |
+
def _sample_rir(self):
|
31 |
+
if len(self.rir_paths) == 0:
|
32 |
+
return None
|
33 |
+
|
34 |
+
if self.deterministic:
|
35 |
+
rir_path = self.rir_paths[0]
|
36 |
+
else:
|
37 |
+
rir_path = random.choice(self.rir_paths)
|
38 |
+
|
39 |
+
rir = np.squeeze(np.load(rir_path))
|
40 |
+
assert isinstance(rir, np.ndarray)
|
41 |
+
|
42 |
+
return rir
|
43 |
+
|
44 |
+
def apply(self, wav, sr):
|
45 |
+
# ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158
|
46 |
+
|
47 |
+
if len(self.rir_paths) == 0:
|
48 |
+
return wav
|
49 |
+
|
50 |
+
length = len(wav)
|
51 |
+
|
52 |
+
wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
|
53 |
+
rir = self._sample_rir()
|
54 |
+
|
55 |
+
wav = signal.convolve(wav, rir, mode="same")
|
56 |
+
|
57 |
+
actlev = np.max(np.abs(wav))
|
58 |
+
if actlev > 0.99:
|
59 |
+
wav = (wav / actlev) * 0.98
|
60 |
+
|
61 |
+
wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
|
62 |
+
|
63 |
+
if abs(length - len(wav)) > 10:
|
64 |
+
_logger.warning(f"length mismatch: {length} vs {len(wav)}")
|
65 |
+
|
66 |
+
if length > len(wav):
|
67 |
+
wav = np.pad(wav, (0, length - len(wav)))
|
68 |
+
elif length < len(wav):
|
69 |
+
wav = wav[:length]
|
70 |
+
|
71 |
+
return wav
|
72 |
+
|
73 |
+
|
74 |
+
class RandomGaussianNoise(Effect):
|
75 |
+
def __init__(self, alpha_range=(0.8, 1)):
|
76 |
+
super().__init__()
|
77 |
+
self.alpha_range = alpha_range
|
78 |
+
|
79 |
+
def apply(self, wav, sr):
|
80 |
+
noise = np.random.randn(*wav.shape)
|
81 |
+
noise_energy = np.sum(noise**2)
|
82 |
+
wav_energy = np.sum(wav**2)
|
83 |
+
noise = noise * np.sqrt(wav_energy / noise_energy)
|
84 |
+
alpha = random.uniform(*self.alpha_range)
|
85 |
+
return wav * alpha + noise * (1 - alpha)
|
modules/repos_static/resemble_enhance/data/distorter/distorter.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ...hparams import HParams
|
2 |
+
from .base import Chain, Choice, Permutation
|
3 |
+
from .custom import RandomGaussianNoise, RandomRIR
|
4 |
+
|
5 |
+
|
6 |
+
class Distorter(Chain):
|
7 |
+
def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
|
8 |
+
# Lazy import
|
9 |
+
from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb
|
10 |
+
|
11 |
+
if training:
|
12 |
+
permutation = Permutation(
|
13 |
+
RandomRIR(hp.rir_dir),
|
14 |
+
RandomReverb(),
|
15 |
+
RandomGaussianNoise(),
|
16 |
+
RandomOverdrive(),
|
17 |
+
RandomEqualizer(),
|
18 |
+
Choice(
|
19 |
+
RandomLowpassDistorter(),
|
20 |
+
RandomBandpassDistorter(),
|
21 |
+
),
|
22 |
+
)
|
23 |
+
if mode == "denoiser":
|
24 |
+
super().__init__(permutation)
|
25 |
+
else:
|
26 |
+
# 80%: distortion, 20%: clean
|
27 |
+
super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
|
28 |
+
else:
|
29 |
+
super().__init__(
|
30 |
+
RandomRIR(hp.rir_dir, deterministic=True),
|
31 |
+
RandomReverb(deterministic=True),
|
32 |
+
)
|
modules/repos_static/resemble_enhance/data/distorter/sox.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import warnings
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
try:
|
11 |
+
import augment
|
12 |
+
except ImportError:
|
13 |
+
raise ImportError(
|
14 |
+
"augment is not installed, please install it first using:"
|
15 |
+
"\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
|
16 |
+
)
|
17 |
+
|
18 |
+
from .base import Effect
|
19 |
+
|
20 |
+
_logger = logging.getLogger(__name__)
|
21 |
+
_DEBUG = bool(os.environ.get("DEBUG", False))
|
22 |
+
|
23 |
+
|
24 |
+
class AttachableEffect(Effect):
|
25 |
+
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
def apply(self, wav: np.ndarray, sr: int):
|
29 |
+
chain = augment.EffectChain()
|
30 |
+
chain = self.attach(chain)
|
31 |
+
tensor = torch.from_numpy(wav)[None].float() # (1, T)
|
32 |
+
tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
|
33 |
+
wav = tensor.numpy()[0] # (T,)
|
34 |
+
return wav
|
35 |
+
|
36 |
+
|
37 |
+
class SoxEffect(AttachableEffect):
|
38 |
+
def __init__(self, effect_name: str, *args, **kwargs):
|
39 |
+
self.effect_name = effect_name
|
40 |
+
self.args = args
|
41 |
+
self.kwargs = kwargs
|
42 |
+
|
43 |
+
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
44 |
+
_logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
|
45 |
+
if not hasattr(chain, self.effect_name):
|
46 |
+
raise ValueError(f"EffectChain has no attribute {self.effect_name}")
|
47 |
+
return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
|
48 |
+
|
49 |
+
|
50 |
+
class Maybe(AttachableEffect):
|
51 |
+
"""
|
52 |
+
Attach an effect with a probability.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, prob: float, effect: AttachableEffect):
|
56 |
+
self.prob = prob
|
57 |
+
self.effect = effect
|
58 |
+
if _DEBUG:
|
59 |
+
warnings.warn("DEBUG mode is on. Maybe -> Must.")
|
60 |
+
self.prob = 1
|
61 |
+
|
62 |
+
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
63 |
+
if random.random() > self.prob:
|
64 |
+
return chain
|
65 |
+
return self.effect.attach(chain)
|
66 |
+
|
67 |
+
|
68 |
+
class Chain(AttachableEffect):
|
69 |
+
"""
|
70 |
+
Attach a chain of effects.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, *effects: AttachableEffect):
|
74 |
+
self.effects = effects
|
75 |
+
|
76 |
+
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
77 |
+
for effect in self.effects:
|
78 |
+
chain = effect.attach(chain)
|
79 |
+
return chain
|
80 |
+
|
81 |
+
|
82 |
+
class Choice(AttachableEffect):
|
83 |
+
"""
|
84 |
+
Attach one of the effects randomly.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(self, *effects: AttachableEffect):
|
88 |
+
self.effects = effects
|
89 |
+
|
90 |
+
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
91 |
+
return random.choice(self.effects).attach(chain)
|
92 |
+
|
93 |
+
|
94 |
+
class Generator:
|
95 |
+
def __call__(self) -> str:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
|
99 |
+
class Uniform(Generator):
|
100 |
+
def __init__(self, low, high):
|
101 |
+
self.low = low
|
102 |
+
self.high = high
|
103 |
+
|
104 |
+
def __call__(self) -> str:
|
105 |
+
return str(random.uniform(self.low, self.high))
|
106 |
+
|
107 |
+
|
108 |
+
class Randint(Generator):
|
109 |
+
def __init__(self, low, high):
|
110 |
+
self.low = low
|
111 |
+
self.high = high
|
112 |
+
|
113 |
+
def __call__(self) -> str:
|
114 |
+
return str(random.randint(self.low, self.high))
|
115 |
+
|
116 |
+
|
117 |
+
class Concat(Generator):
|
118 |
+
def __init__(self, *parts: Generator | str):
|
119 |
+
self.parts = parts
|
120 |
+
|
121 |
+
def __call__(self):
|
122 |
+
return "".join([part if isinstance(part, str) else part() for part in self.parts])
|
123 |
+
|
124 |
+
|
125 |
+
class RandomLowpassDistorter(SoxEffect):
|
126 |
+
def __init__(self, low=2000, high=16000):
|
127 |
+
super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
|
128 |
+
|
129 |
+
|
130 |
+
class RandomBandpassDistorter(SoxEffect):
|
131 |
+
def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
|
132 |
+
super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def _fn(low, high, min_width, max_width):
|
136 |
+
start = random.randint(low, high)
|
137 |
+
stop = start + random.randint(min_width, max_width)
|
138 |
+
return f"{start}-{stop}"
|
139 |
+
|
140 |
+
|
141 |
+
class RandomEqualizer(SoxEffect):
|
142 |
+
def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
|
143 |
+
super().__init__(
|
144 |
+
"equalizer",
|
145 |
+
Uniform(low, high),
|
146 |
+
lambda: f"{random.randint(q_low, q_high)}q",
|
147 |
+
lambda: random.randint(db_low, db_high),
|
148 |
+
)
|
149 |
+
|
150 |
+
|
151 |
+
class RandomOverdrive(SoxEffect):
|
152 |
+
def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
|
153 |
+
super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
|
154 |
+
|
155 |
+
|
156 |
+
class RandomReverb(Chain):
|
157 |
+
def __init__(self, deterministic=False):
|
158 |
+
super().__init__(
|
159 |
+
SoxEffect(
|
160 |
+
"reverb",
|
161 |
+
Uniform(50, 50) if deterministic else Uniform(0, 100),
|
162 |
+
Uniform(50, 50) if deterministic else Uniform(0, 100),
|
163 |
+
Uniform(50, 50) if deterministic else Uniform(0, 100),
|
164 |
+
),
|
165 |
+
SoxEffect("channels", 1),
|
166 |
+
)
|
167 |
+
|
168 |
+
|
169 |
+
class Flanger(SoxEffect):
|
170 |
+
def __init__(self):
|
171 |
+
super().__init__("flanger")
|
172 |
+
|
173 |
+
|
174 |
+
class Phaser(SoxEffect):
|
175 |
+
def __init__(self):
|
176 |
+
super().__init__("phaser")
|
modules/repos_static/resemble_enhance/data/utils.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
from torch import Tensor
|
5 |
+
|
6 |
+
|
7 |
+
def walk_paths(root, suffix):
|
8 |
+
for path in Path(root).iterdir():
|
9 |
+
if path.is_dir():
|
10 |
+
yield from walk_paths(path, suffix)
|
11 |
+
elif path.suffix == suffix:
|
12 |
+
yield path
|
13 |
+
|
14 |
+
|
15 |
+
def rglob_audio_files(path: Path):
|
16 |
+
return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
|
17 |
+
|
18 |
+
|
19 |
+
def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
|
20 |
+
"""
|
21 |
+
Args:
|
22 |
+
fg: (b, t)
|
23 |
+
bg: (b, t)
|
24 |
+
"""
|
25 |
+
assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
|
26 |
+
fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
|
27 |
+
bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
|
28 |
+
|
29 |
+
fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
|
30 |
+
bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
|
31 |
+
|
32 |
+
fg = fg / (fg_energy + eps).sqrt()
|
33 |
+
bg = bg / (bg_energy + eps).sqrt()
|
34 |
+
|
35 |
+
if callable(alpha):
|
36 |
+
alpha = alpha()
|
37 |
+
|
38 |
+
assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
|
39 |
+
|
40 |
+
mx = alpha * fg + (1 - alpha) * bg
|
41 |
+
mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
|
42 |
+
|
43 |
+
return mx
|
modules/repos_static/resemble_enhance/denoiser/__init__.py
ADDED
File without changes
|
modules/repos_static/resemble_enhance/denoiser/__main__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
|
7 |
+
from .inference import denoise
|
8 |
+
|
9 |
+
|
10 |
+
@torch.inference_mode()
|
11 |
+
def main():
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
|
14 |
+
parser.add_argument("out_dir", type=Path, help="Output folder")
|
15 |
+
parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
|
16 |
+
parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
|
17 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device")
|
18 |
+
args = parser.parse_args()
|
19 |
+
|
20 |
+
for path in args.in_dir.glob(f"**/*{args.suffix}"):
|
21 |
+
print(f"Processing {path} ..")
|
22 |
+
dwav, sr = torchaudio.load(path)
|
23 |
+
hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device)
|
24 |
+
out_path = args.out_dir / path.relative_to(args.in_dir)
|
25 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
26 |
+
torchaudio.save(out_path, hwav[None], sr)
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
main()
|
modules/repos_static/resemble_enhance/denoiser/denoiser.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import Tensor, nn
|
6 |
+
|
7 |
+
from ..melspec import MelSpectrogram
|
8 |
+
from .hparams import HParams
|
9 |
+
from .unet import UNet
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def _normalize(x: Tensor) -> Tensor:
|
15 |
+
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
16 |
+
|
17 |
+
|
18 |
+
class Denoiser(nn.Module):
|
19 |
+
@property
|
20 |
+
def stft_cfg(self) -> dict:
|
21 |
+
hop_size = self.hp.hop_size
|
22 |
+
return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)
|
23 |
+
|
24 |
+
@property
|
25 |
+
def n_fft(self):
|
26 |
+
return self.stft_cfg["n_fft"]
|
27 |
+
|
28 |
+
@property
|
29 |
+
def eps(self):
|
30 |
+
return 1e-7
|
31 |
+
|
32 |
+
def __init__(self, hp: HParams):
|
33 |
+
super().__init__()
|
34 |
+
self.hp = hp
|
35 |
+
self.net = UNet(input_dim=3, output_dim=3)
|
36 |
+
self.mel_fn = MelSpectrogram(hp)
|
37 |
+
|
38 |
+
self.dummy: Tensor
|
39 |
+
self.register_buffer("dummy", torch.zeros(1), persistent=False)
|
40 |
+
|
41 |
+
def to_mel(self, x: Tensor, drop_last=True):
|
42 |
+
"""
|
43 |
+
Args:
|
44 |
+
x: (b t), wavs
|
45 |
+
Returns:
|
46 |
+
o: (b c t), mels
|
47 |
+
"""
|
48 |
+
if drop_last:
|
49 |
+
return self.mel_fn(x)[..., :-1] # (b d t)
|
50 |
+
return self.mel_fn(x)
|
51 |
+
|
52 |
+
def _stft(self, x):
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
x: (b t)
|
56 |
+
Returns:
|
57 |
+
mag: (b f t) in [0, inf)
|
58 |
+
cos: (b f t) in [-1, 1]
|
59 |
+
sin: (b f t) in [-1, 1]
|
60 |
+
"""
|
61 |
+
dtype = x.dtype
|
62 |
+
device = x.device
|
63 |
+
|
64 |
+
if x.is_mps:
|
65 |
+
x = x.cpu()
|
66 |
+
|
67 |
+
window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
|
68 |
+
s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True) # (b f t+1)
|
69 |
+
|
70 |
+
s = s[..., :-1] # (b f t)
|
71 |
+
|
72 |
+
mag = s.abs() # (b f t)
|
73 |
+
|
74 |
+
phi = s.angle() # (b f t)
|
75 |
+
cos = phi.cos() # (b f t)
|
76 |
+
sin = phi.sin() # (b f t)
|
77 |
+
|
78 |
+
mag = mag.to(dtype=dtype, device=device)
|
79 |
+
cos = cos.to(dtype=dtype, device=device)
|
80 |
+
sin = sin.to(dtype=dtype, device=device)
|
81 |
+
|
82 |
+
return mag, cos, sin
|
83 |
+
|
84 |
+
def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
|
85 |
+
"""
|
86 |
+
Args:
|
87 |
+
mag: (b f t) in [0, inf)
|
88 |
+
cos: (b f t) in [-1, 1]
|
89 |
+
sin: (b f t) in [-1, 1]
|
90 |
+
Returns:
|
91 |
+
x: (b t)
|
92 |
+
"""
|
93 |
+
device = mag.device
|
94 |
+
dtype = mag.dtype
|
95 |
+
|
96 |
+
if mag.is_mps:
|
97 |
+
mag = mag.cpu()
|
98 |
+
cos = cos.cpu()
|
99 |
+
sin = sin.cpu()
|
100 |
+
|
101 |
+
real = mag * cos # (b f t)
|
102 |
+
imag = mag * sin # (b f t)
|
103 |
+
|
104 |
+
s = torch.complex(real, imag) # (b f t)
|
105 |
+
|
106 |
+
if s.isnan().any():
|
107 |
+
logger.warning("NaN detected in ISTFT input.")
|
108 |
+
|
109 |
+
s = F.pad(s, (0, 1), "replicate") # (b f t+1)
|
110 |
+
|
111 |
+
window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
|
112 |
+
x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)
|
113 |
+
|
114 |
+
if x.isnan().any():
|
115 |
+
logger.warning("NaN detected in ISTFT output, set to zero.")
|
116 |
+
x = torch.where(x.isnan(), torch.zeros_like(x), x)
|
117 |
+
|
118 |
+
x = x.to(dtype=dtype, device=device)
|
119 |
+
|
120 |
+
return x
|
121 |
+
|
122 |
+
def _magphase(self, real, imag):
|
123 |
+
mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
|
124 |
+
cos = real / mag
|
125 |
+
sin = imag / mag
|
126 |
+
return mag, cos, sin
|
127 |
+
|
128 |
+
def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
|
129 |
+
"""
|
130 |
+
Args:
|
131 |
+
mag: (b f t)
|
132 |
+
cos: (b f t)
|
133 |
+
sin: (b f t)
|
134 |
+
Returns:
|
135 |
+
mag_mask: (b f t) in [0, 1], magnitude mask
|
136 |
+
cos_res: (b f t) in [-1, 1], phase residual
|
137 |
+
sin_res: (b f t) in [-1, 1], phase residual
|
138 |
+
"""
|
139 |
+
x = torch.stack([mag, cos, sin], dim=1) # (b 3 f t)
|
140 |
+
mag_mask, real, imag = self.net(x).unbind(1) # (b 3 f t)
|
141 |
+
mag_mask = mag_mask.sigmoid() # (b f t)
|
142 |
+
real = real.tanh() # (b f t)
|
143 |
+
imag = imag.tanh() # (b f t)
|
144 |
+
_, cos_res, sin_res = self._magphase(real, imag) # (b f t)
|
145 |
+
return mag_mask, sin_res, cos_res
|
146 |
+
|
147 |
+
def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
|
148 |
+
"""Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
|
149 |
+
sep_mag = F.relu(mag * mag_mask)
|
150 |
+
sep_cos = cos * cos_res - sin * sin_res
|
151 |
+
sep_sin = sin * cos_res + cos * sin_res
|
152 |
+
return sep_mag, sep_cos, sep_sin
|
153 |
+
|
154 |
+
def forward(self, x: Tensor, y: Tensor | None = None):
|
155 |
+
"""
|
156 |
+
Args:
|
157 |
+
x: (b t), a mixed audio
|
158 |
+
y: (b t), a fg audio
|
159 |
+
"""
|
160 |
+
assert x.dim() == 2, f"Expected (b t), got {x.size()}"
|
161 |
+
x = x.to(self.dummy)
|
162 |
+
x = _normalize(x)
|
163 |
+
|
164 |
+
if y is not None:
|
165 |
+
assert y.dim() == 2, f"Expected (b t), got {y.size()}"
|
166 |
+
y = y.to(self.dummy)
|
167 |
+
y = _normalize(y)
|
168 |
+
|
169 |
+
mag, cos, sin = self._stft(x) # (b 2f t)
|
170 |
+
mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
|
171 |
+
sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
|
172 |
+
|
173 |
+
o = self._istft(sep_mag, sep_cos, sep_sin)
|
174 |
+
|
175 |
+
npad = x.shape[-1] - o.shape[-1]
|
176 |
+
o = F.pad(o, (0, npad))
|
177 |
+
|
178 |
+
if y is not None:
|
179 |
+
self.losses = dict(l1=F.l1_loss(o, y))
|
180 |
+
|
181 |
+
return o
|
modules/repos_static/resemble_enhance/denoiser/hparams.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from ..hparams import HParams as HParamsBase
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass(frozen=True)
|
7 |
+
class HParams(HParamsBase):
|
8 |
+
batch_size_per_gpu: int = 128
|
9 |
+
distort_prob: float = 0.5
|
modules/repos_static/resemble_enhance/denoiser/inference.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from functools import cache
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from ..denoiser.denoiser import Denoiser
|
7 |
+
|
8 |
+
from ..inference import inference
|
9 |
+
from .hparams import HParams
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
@cache
|
15 |
+
def load_denoiser(run_dir, device):
|
16 |
+
if run_dir is None:
|
17 |
+
return Denoiser(HParams())
|
18 |
+
hp = HParams.load(run_dir)
|
19 |
+
denoiser = Denoiser(hp)
|
20 |
+
path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
|
21 |
+
state_dict = torch.load(path, map_location="cpu")["module"]
|
22 |
+
denoiser.load_state_dict(state_dict)
|
23 |
+
denoiser.eval()
|
24 |
+
denoiser.to(device)
|
25 |
+
return denoiser
|
26 |
+
|
27 |
+
|
28 |
+
@torch.inference_mode()
|
29 |
+
def denoise(dwav, sr, run_dir, device):
|
30 |
+
denoiser = load_denoiser(run_dir, device)
|
31 |
+
return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
|
modules/repos_static/resemble_enhance/denoiser/unet.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class PreactResBlock(nn.Sequential):
|
6 |
+
def __init__(self, dim):
|
7 |
+
super().__init__(
|
8 |
+
nn.GroupNorm(dim // 16, dim),
|
9 |
+
nn.GELU(),
|
10 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
11 |
+
nn.GroupNorm(dim // 16, dim),
|
12 |
+
nn.GELU(),
|
13 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
14 |
+
)
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
return x + super().forward(x)
|
18 |
+
|
19 |
+
|
20 |
+
class UNetBlock(nn.Module):
|
21 |
+
def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
|
22 |
+
super().__init__()
|
23 |
+
if output_dim is None:
|
24 |
+
output_dim = input_dim
|
25 |
+
self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
|
26 |
+
self.res_block1 = PreactResBlock(output_dim)
|
27 |
+
self.res_block2 = PreactResBlock(output_dim)
|
28 |
+
self.downsample = self.upsample = nn.Identity()
|
29 |
+
if scale_factor > 1:
|
30 |
+
self.upsample = nn.Upsample(scale_factor=scale_factor)
|
31 |
+
elif scale_factor < 1:
|
32 |
+
self.downsample = nn.Upsample(scale_factor=scale_factor)
|
33 |
+
|
34 |
+
def forward(self, x, h=None):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
x: (b c h w), last output
|
38 |
+
h: (b c h w), skip output
|
39 |
+
Returns:
|
40 |
+
o: (b c h w), output
|
41 |
+
s: (b c h w), skip output
|
42 |
+
"""
|
43 |
+
x = self.upsample(x)
|
44 |
+
if h is not None:
|
45 |
+
assert x.shape == h.shape, f"{x.shape} != {h.shape}"
|
46 |
+
x = x + h
|
47 |
+
x = self.pre_conv(x)
|
48 |
+
x = self.res_block1(x)
|
49 |
+
x = self.res_block2(x)
|
50 |
+
return self.downsample(x), x
|
51 |
+
|
52 |
+
|
53 |
+
class UNet(nn.Module):
|
54 |
+
def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
|
55 |
+
super().__init__()
|
56 |
+
self.input_dim = input_dim
|
57 |
+
self.output_dim = output_dim
|
58 |
+
self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
59 |
+
self.encoder_blocks = nn.ModuleList(
|
60 |
+
[
|
61 |
+
UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
|
62 |
+
for i in range(num_blocks)
|
63 |
+
]
|
64 |
+
)
|
65 |
+
self.middle_blocks = nn.ModuleList(
|
66 |
+
[UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
|
67 |
+
)
|
68 |
+
self.decoder_blocks = nn.ModuleList(
|
69 |
+
[
|
70 |
+
UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
|
71 |
+
for i in reversed(range(num_blocks))
|
72 |
+
]
|
73 |
+
)
|
74 |
+
self.head = nn.Sequential(
|
75 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
76 |
+
nn.GELU(),
|
77 |
+
nn.Conv2d(hidden_dim, output_dim, 1),
|
78 |
+
)
|
79 |
+
|
80 |
+
@property
|
81 |
+
def scale_factor(self):
|
82 |
+
return 2 ** len(self.encoder_blocks)
|
83 |
+
|
84 |
+
def pad_to_fit(self, x):
|
85 |
+
"""
|
86 |
+
Args:
|
87 |
+
x: (b c h w), input
|
88 |
+
Returns:
|
89 |
+
x: (b c h' w'), padded input
|
90 |
+
"""
|
91 |
+
hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor
|
92 |
+
wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor
|
93 |
+
return F.pad(x, (0, wpad, 0, hpad))
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
"""
|
97 |
+
Args:
|
98 |
+
x: (b c h w), input
|
99 |
+
Returns:
|
100 |
+
o: (b c h w), output
|
101 |
+
"""
|
102 |
+
shape = x.shape
|
103 |
+
|
104 |
+
x = self.pad_to_fit(x)
|
105 |
+
x = self.input_proj(x)
|
106 |
+
|
107 |
+
s_list = []
|
108 |
+
for block in self.encoder_blocks:
|
109 |
+
x, s = block(x)
|
110 |
+
s_list.append(s)
|
111 |
+
|
112 |
+
for block in self.middle_blocks:
|
113 |
+
x, _ = block(x)
|
114 |
+
|
115 |
+
for block, s in zip(self.decoder_blocks, reversed(s_list)):
|
116 |
+
x, _ = block(x, s)
|
117 |
+
|
118 |
+
x = self.head(x)
|
119 |
+
x = x[..., : shape[2], : shape[3]]
|
120 |
+
|
121 |
+
return x
|
122 |
+
|
123 |
+
def test(self, shape=(3, 512, 256)):
|
124 |
+
import ptflops
|
125 |
+
|
126 |
+
macs, params = ptflops.get_model_complexity_info(
|
127 |
+
self,
|
128 |
+
shape,
|
129 |
+
as_strings=True,
|
130 |
+
print_per_layer_stat=True,
|
131 |
+
verbose=True,
|
132 |
+
)
|
133 |
+
|
134 |
+
print(f"macs: {macs}")
|
135 |
+
print(f"params: {params}")
|
136 |
+
|
137 |
+
|
138 |
+
def main():
|
139 |
+
model = UNet(3, 3)
|
140 |
+
model.test()
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
main()
|
modules/repos_static/resemble_enhance/enhancer/__init__.py
ADDED
File without changes
|
modules/repos_static/resemble_enhance/enhancer/__main__.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchaudio
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from .inference import denoise, enhance
|
11 |
+
|
12 |
+
|
13 |
+
@torch.inference_mode()
|
14 |
+
def main():
|
15 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
16 |
+
parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
|
17 |
+
parser.add_argument("out_dir", type=Path, help="Output folder")
|
18 |
+
parser.add_argument(
|
19 |
+
"--run_dir",
|
20 |
+
type=Path,
|
21 |
+
default=None,
|
22 |
+
help="Path to the enhancer run folder, if None, use the default model",
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"--suffix",
|
26 |
+
type=str,
|
27 |
+
default=".wav",
|
28 |
+
help="Audio file suffix",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--device",
|
32 |
+
type=str,
|
33 |
+
default="cuda",
|
34 |
+
help="Device to use for computation, recommended to use CUDA",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--denoise_only",
|
38 |
+
action="store_true",
|
39 |
+
help="Only apply denoising without enhancement",
|
40 |
+
)
|
41 |
+
parser.add_argument(
|
42 |
+
"--lambd",
|
43 |
+
type=float,
|
44 |
+
default=1.0,
|
45 |
+
help="Denoise strength for enhancement (0.0 to 1.0)",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"--tau",
|
49 |
+
type=float,
|
50 |
+
default=0.5,
|
51 |
+
help="CFM prior temperature (0.0 to 1.0)",
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--solver",
|
55 |
+
type=str,
|
56 |
+
default="midpoint",
|
57 |
+
choices=["midpoint", "rk4", "euler"],
|
58 |
+
help="Numerical solver to use",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--nfe",
|
62 |
+
type=int,
|
63 |
+
default=64,
|
64 |
+
help="Number of function evaluations",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--parallel_mode",
|
68 |
+
action="store_true",
|
69 |
+
help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
|
70 |
+
)
|
71 |
+
|
72 |
+
args = parser.parse_args()
|
73 |
+
|
74 |
+
device = args.device
|
75 |
+
|
76 |
+
if device == "cuda" and not torch.cuda.is_available():
|
77 |
+
print("CUDA is not available but --device is set to cuda, using CPU instead")
|
78 |
+
device = "cpu"
|
79 |
+
|
80 |
+
start_time = time.perf_counter()
|
81 |
+
|
82 |
+
run_dir = args.run_dir
|
83 |
+
|
84 |
+
paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
|
85 |
+
|
86 |
+
if args.parallel_mode:
|
87 |
+
random.shuffle(paths)
|
88 |
+
|
89 |
+
if len(paths) == 0:
|
90 |
+
print(f"No {args.suffix} files found in the following path: {args.in_dir}")
|
91 |
+
return
|
92 |
+
|
93 |
+
pbar = tqdm(paths)
|
94 |
+
|
95 |
+
for path in pbar:
|
96 |
+
out_path = args.out_dir / path.relative_to(args.in_dir)
|
97 |
+
if args.parallel_mode and out_path.exists():
|
98 |
+
continue
|
99 |
+
pbar.set_description(f"Processing {out_path}")
|
100 |
+
dwav, sr = torchaudio.load(path)
|
101 |
+
dwav = dwav.mean(0)
|
102 |
+
if args.denoise_only:
|
103 |
+
hwav, sr = denoise(
|
104 |
+
dwav=dwav,
|
105 |
+
sr=sr,
|
106 |
+
device=device,
|
107 |
+
run_dir=args.run_dir,
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
hwav, sr = enhance(
|
111 |
+
dwav=dwav,
|
112 |
+
sr=sr,
|
113 |
+
device=device,
|
114 |
+
nfe=args.nfe,
|
115 |
+
solver=args.solver,
|
116 |
+
lambd=args.lambd,
|
117 |
+
tau=args.tau,
|
118 |
+
run_dir=run_dir,
|
119 |
+
)
|
120 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
121 |
+
torchaudio.save(out_path, hwav[None], sr)
|
122 |
+
|
123 |
+
# Cool emoji effect saying the job is done
|
124 |
+
elapsed_time = time.perf_counter() - start_time
|
125 |
+
print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s")
|
126 |
+
|
127 |
+
|
128 |
+
if __name__ == "__main__":
|
129 |
+
main()
|
modules/repos_static/resemble_enhance/enhancer/download.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
RUN_NAME = "enhancer_stage2"
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
def get_source_url(relpath):
|
12 |
+
return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
|
13 |
+
|
14 |
+
|
15 |
+
def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
|
16 |
+
if run_dir is None:
|
17 |
+
run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
|
18 |
+
return Path(run_dir) / relpath
|
19 |
+
|
20 |
+
|
21 |
+
def download(run_dir: str | Path | None = None):
|
22 |
+
relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
|
23 |
+
for relpath in relpaths:
|
24 |
+
path = get_target_path(relpath, run_dir=run_dir)
|
25 |
+
if path.exists():
|
26 |
+
continue
|
27 |
+
url = get_source_url(relpath)
|
28 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
29 |
+
torch.hub.download_url_to_file(url, str(path))
|
30 |
+
return get_target_path("", run_dir=run_dir)
|
modules/repos_static/resemble_enhance/enhancer/enhancer.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
from torch import Tensor, nn
|
7 |
+
from torch.distributions import Beta
|
8 |
+
|
9 |
+
from ..common import Normalizer
|
10 |
+
from ..denoiser.inference import load_denoiser
|
11 |
+
from ..melspec import MelSpectrogram
|
12 |
+
from .hparams import HParams
|
13 |
+
from .lcfm import CFM, IRMAE, LCFM
|
14 |
+
from .univnet import UnivNet
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
def _maybe(fn):
|
20 |
+
def _fn(*args):
|
21 |
+
if args[0] is None:
|
22 |
+
return None
|
23 |
+
return fn(*args)
|
24 |
+
|
25 |
+
return _fn
|
26 |
+
|
27 |
+
|
28 |
+
def _normalize_wav(x: Tensor):
|
29 |
+
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
30 |
+
|
31 |
+
|
32 |
+
class Enhancer(nn.Module):
|
33 |
+
def __init__(self, hp: HParams):
|
34 |
+
super().__init__()
|
35 |
+
self.hp = hp
|
36 |
+
|
37 |
+
n_mels = self.hp.num_mels
|
38 |
+
vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
|
39 |
+
latent_dim = self.hp.lcfm_latent_dim
|
40 |
+
|
41 |
+
self.lcfm = LCFM(
|
42 |
+
IRMAE(
|
43 |
+
input_dim=n_mels,
|
44 |
+
output_dim=vocoder_input_dim,
|
45 |
+
latent_dim=latent_dim,
|
46 |
+
),
|
47 |
+
CFM(
|
48 |
+
cond_dim=n_mels,
|
49 |
+
output_dim=self.hp.lcfm_latent_dim,
|
50 |
+
solver_nfe=self.hp.cfm_solver_nfe,
|
51 |
+
solver_method=self.hp.cfm_solver_method,
|
52 |
+
time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
|
53 |
+
),
|
54 |
+
z_scale=self.hp.lcfm_z_scale,
|
55 |
+
)
|
56 |
+
|
57 |
+
self.lcfm.set_mode_(self.hp.lcfm_training_mode)
|
58 |
+
|
59 |
+
self.mel_fn = MelSpectrogram(hp)
|
60 |
+
self.vocoder = UnivNet(self.hp, vocoder_input_dim)
|
61 |
+
self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
|
62 |
+
self.normalizer = Normalizer()
|
63 |
+
|
64 |
+
self._eval_lambd = 0.0
|
65 |
+
|
66 |
+
self.dummy: Tensor
|
67 |
+
self.register_buffer("dummy", torch.zeros(1))
|
68 |
+
|
69 |
+
if self.hp.enhancer_stage1_run_dir is not None:
|
70 |
+
pretrained_path = (
|
71 |
+
self.hp.enhancer_stage1_run_dir
|
72 |
+
/ "ds/G/default/mp_rank_00_model_states.pt"
|
73 |
+
)
|
74 |
+
self._load_pretrained(pretrained_path)
|
75 |
+
|
76 |
+
logger.info(f"{self.__class__.__name__} summary")
|
77 |
+
logger.info(f"{self.summarize()}")
|
78 |
+
|
79 |
+
def _load_pretrained(self, path):
|
80 |
+
# Clone is necessary as otherwise it holds a reference to the original model
|
81 |
+
cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
|
82 |
+
denoiser_state_dict = {
|
83 |
+
k: v.clone() for k, v in self.denoiser.state_dict().items()
|
84 |
+
}
|
85 |
+
state_dict = torch.load(path, map_location="cpu")["module"]
|
86 |
+
self.load_state_dict(state_dict, strict=False)
|
87 |
+
self.lcfm.cfm.load_state_dict(cfm_state_dict) # Reset cfm
|
88 |
+
self.denoiser.load_state_dict(denoiser_state_dict) # Reset denoiser
|
89 |
+
logger.info(f"Loaded pretrained model from {path}")
|
90 |
+
|
91 |
+
def summarize(self):
|
92 |
+
npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
|
93 |
+
npa = lambda m: sum(p.numel() for p in m.parameters())
|
94 |
+
rows = []
|
95 |
+
for name, module in self.named_children():
|
96 |
+
rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
|
97 |
+
rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
|
98 |
+
df = pd.DataFrame(rows)
|
99 |
+
return df.to_markdown(index=False)
|
100 |
+
|
101 |
+
def to_mel(self, x: Tensor, drop_last=True):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
x: (b t), wavs
|
105 |
+
Returns:
|
106 |
+
o: (b c t), mels
|
107 |
+
"""
|
108 |
+
if drop_last:
|
109 |
+
return self.mel_fn(x)[..., :-1] # (b d t)
|
110 |
+
return self.mel_fn(x)
|
111 |
+
|
112 |
+
def _may_denoise(self, x: Tensor, y: Tensor | None = None):
|
113 |
+
if self.hp.lcfm_training_mode == "cfm":
|
114 |
+
return self.denoiser(x, y)
|
115 |
+
return x
|
116 |
+
|
117 |
+
def configurate_(self, nfe, solver, lambd, tau):
|
118 |
+
"""
|
119 |
+
Args:
|
120 |
+
nfe: number of function evaluations
|
121 |
+
solver: solver method
|
122 |
+
lambd: denoiser strength [0, 1]
|
123 |
+
tau: prior temperature [0, 1]
|
124 |
+
"""
|
125 |
+
self.lcfm.cfm.solver.configurate_(nfe, solver)
|
126 |
+
self.lcfm.eval_tau_(tau)
|
127 |
+
self._eval_lambd = lambd
|
128 |
+
|
129 |
+
def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
|
130 |
+
"""
|
131 |
+
Args:
|
132 |
+
x: (b t), mix wavs (fg + bg)
|
133 |
+
y: (b t), fg clean wavs
|
134 |
+
z: (b t), fg distorted wavs
|
135 |
+
Returns:
|
136 |
+
o: (b t), reconstructed wavs
|
137 |
+
"""
|
138 |
+
assert x.dim() == 2, f"Expected (b t), got {x.size()}"
|
139 |
+
assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
|
140 |
+
|
141 |
+
if self.hp.lcfm_training_mode == "cfm":
|
142 |
+
self.normalizer.eval()
|
143 |
+
|
144 |
+
x = _normalize_wav(x)
|
145 |
+
y = _maybe(_normalize_wav)(y)
|
146 |
+
z = _maybe(_normalize_wav)(z)
|
147 |
+
|
148 |
+
x_mel_original = self.normalizer(self.to_mel(x), update=False) # (b d t)
|
149 |
+
|
150 |
+
if self.hp.lcfm_training_mode == "cfm":
|
151 |
+
if self.training:
|
152 |
+
lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
|
153 |
+
lambd = lambd[:, None, None]
|
154 |
+
x_mel_denoised = self.normalizer(
|
155 |
+
self.to_mel(self._may_denoise(x, z)), update=False
|
156 |
+
)
|
157 |
+
x_mel_denoised = x_mel_denoised.detach()
|
158 |
+
x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
|
159 |
+
self._visualize(x_mel_original, x_mel_denoised)
|
160 |
+
else:
|
161 |
+
lambd = self._eval_lambd
|
162 |
+
if lambd == 0:
|
163 |
+
x_mel_denoised = x_mel_original
|
164 |
+
else:
|
165 |
+
x_mel_denoised = self.normalizer(
|
166 |
+
self.to_mel(self._may_denoise(x, z)), update=False
|
167 |
+
)
|
168 |
+
x_mel_denoised = x_mel_denoised.detach()
|
169 |
+
x_mel_denoised = (
|
170 |
+
lambd * x_mel_denoised + (1 - lambd) * x_mel_original
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
x_mel_denoised = x_mel_original
|
174 |
+
|
175 |
+
y_mel = _maybe(self.to_mel)(y) # (b d t)
|
176 |
+
y_mel = _maybe(self.normalizer)(y_mel)
|
177 |
+
|
178 |
+
lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t)
|
179 |
+
|
180 |
+
if lcfm_decoded is None:
|
181 |
+
o = None
|
182 |
+
else:
|
183 |
+
o = self.vocoder(lcfm_decoded, y)
|
184 |
+
|
185 |
+
return o
|
modules/repos_static/resemble_enhance/enhancer/hparams.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from ..hparams import HParams as HParamsBase
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass(frozen=True)
|
8 |
+
class HParams(HParamsBase):
|
9 |
+
cfm_solver_method: str = "midpoint"
|
10 |
+
cfm_solver_nfe: int = 64
|
11 |
+
cfm_time_mapping_divisor: int = 4
|
12 |
+
univnet_nc: int = 96
|
13 |
+
|
14 |
+
lcfm_latent_dim: int = 64
|
15 |
+
lcfm_training_mode: str = "ae"
|
16 |
+
lcfm_z_scale: float = 5
|
17 |
+
|
18 |
+
vocoder_extra_dim: int = 32
|
19 |
+
|
20 |
+
gan_training_start_step: int | None = 5_000
|
21 |
+
enhancer_stage1_run_dir: Path | None = None
|
22 |
+
|
23 |
+
denoiser_run_dir: Path | None = None
|
modules/repos_static/resemble_enhance/enhancer/inference.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from functools import cache
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from ..inference import inference
|
8 |
+
from .download import download
|
9 |
+
from .hparams import HParams
|
10 |
+
from .enhancer import Enhancer
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
@cache
|
16 |
+
def load_enhancer(run_dir: str | Path | None, device):
|
17 |
+
run_dir = download(run_dir)
|
18 |
+
hp = HParams.load(run_dir)
|
19 |
+
enhancer = Enhancer(hp)
|
20 |
+
path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
|
21 |
+
state_dict = torch.load(path, map_location="cpu")["module"]
|
22 |
+
enhancer.load_state_dict(state_dict)
|
23 |
+
enhancer.eval()
|
24 |
+
enhancer.to(device)
|
25 |
+
return enhancer
|
26 |
+
|
27 |
+
|
28 |
+
@torch.inference_mode()
|
29 |
+
def denoise(dwav, sr, device, run_dir=None):
|
30 |
+
enhancer = load_enhancer(run_dir, device)
|
31 |
+
return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
|
32 |
+
|
33 |
+
|
34 |
+
@torch.inference_mode()
|
35 |
+
def enhance(
|
36 |
+
dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None
|
37 |
+
):
|
38 |
+
assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
|
39 |
+
assert solver in (
|
40 |
+
"midpoint",
|
41 |
+
"rk4",
|
42 |
+
"euler",
|
43 |
+
), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
|
44 |
+
assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
|
45 |
+
assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
|
46 |
+
enhancer = load_enhancer(run_dir, device)
|
47 |
+
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
|
48 |
+
return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
|
modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .irmae import IRMAE
|
2 |
+
from .lcfm import CFM, LCFM
|
modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from functools import partial
|
4 |
+
from typing import Protocol
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import scipy
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import Tensor, nn
|
12 |
+
from tqdm import trange
|
13 |
+
|
14 |
+
from .wn import WN
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class VelocityField(Protocol):
|
20 |
+
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
|
21 |
+
...
|
22 |
+
|
23 |
+
|
24 |
+
class Solver:
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
method="midpoint",
|
28 |
+
nfe=32,
|
29 |
+
viz_name="solver",
|
30 |
+
viz_every=100,
|
31 |
+
mel_fn=None,
|
32 |
+
time_mapping_divisor=4,
|
33 |
+
verbose=False,
|
34 |
+
):
|
35 |
+
self.configurate_(nfe=nfe, method=method)
|
36 |
+
|
37 |
+
self.verbose = verbose
|
38 |
+
self.viz_every = viz_every
|
39 |
+
self.viz_name = viz_name
|
40 |
+
|
41 |
+
self._camera = None
|
42 |
+
self._mel_fn = mel_fn
|
43 |
+
self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
|
44 |
+
|
45 |
+
def configurate_(self, nfe=None, method=None):
|
46 |
+
if nfe is None:
|
47 |
+
nfe = self.nfe
|
48 |
+
|
49 |
+
if method is None:
|
50 |
+
method = self.method
|
51 |
+
|
52 |
+
if nfe == 1 and method in ("midpoint", "rk4"):
|
53 |
+
logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
|
54 |
+
method = "euler"
|
55 |
+
|
56 |
+
self.nfe = nfe
|
57 |
+
self.method = method
|
58 |
+
|
59 |
+
@property
|
60 |
+
def time_mapping(self):
|
61 |
+
return self._time_mapping
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def exponential_decay_mapping(t, n=4):
|
65 |
+
"""
|
66 |
+
Args:
|
67 |
+
n: target step
|
68 |
+
"""
|
69 |
+
|
70 |
+
def h(t, a):
|
71 |
+
return (a**t - 1) / (a - 1)
|
72 |
+
|
73 |
+
# Solve h(1/n) = 0.5
|
74 |
+
a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0))
|
75 |
+
|
76 |
+
t = h(t, a=a)
|
77 |
+
|
78 |
+
return t
|
79 |
+
|
80 |
+
@torch.no_grad()
|
81 |
+
def _maybe_camera_snap(self, *, ψt, t):
|
82 |
+
camera = self._camera
|
83 |
+
if camera is not None:
|
84 |
+
if ψt.shape[1] == 1:
|
85 |
+
# Waveform, b 1 t, plot every 100 samples
|
86 |
+
plt.subplot(211)
|
87 |
+
plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
|
88 |
+
if self._mel_fn is not None:
|
89 |
+
plt.subplot(212)
|
90 |
+
mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
|
91 |
+
plt.imshow(mel, origin="lower", interpolation="none")
|
92 |
+
elif ψt.shape[1] == 2:
|
93 |
+
# Complex
|
94 |
+
plt.subplot(121)
|
95 |
+
plt.imshow(
|
96 |
+
ψt.detach().cpu().numpy()[0, 0],
|
97 |
+
origin="lower",
|
98 |
+
interpolation="none",
|
99 |
+
)
|
100 |
+
plt.subplot(122)
|
101 |
+
plt.imshow(
|
102 |
+
ψt.detach().cpu().numpy()[0, 1],
|
103 |
+
origin="lower",
|
104 |
+
interpolation="none",
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
# Spectrogram, b c t
|
108 |
+
plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
|
109 |
+
ax = plt.gca()
|
110 |
+
ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
|
111 |
+
camera.snap()
|
112 |
+
|
113 |
+
@staticmethod
|
114 |
+
def _euler_step(t, ψt, dt, f: VelocityField):
|
115 |
+
return ψt + dt * f(t=t, ψt=ψt, dt=dt)
|
116 |
+
|
117 |
+
@staticmethod
|
118 |
+
def _midpoint_step(t, ψt, dt, f: VelocityField):
|
119 |
+
return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
def _rk4_step(t, ψt, dt, f: VelocityField):
|
123 |
+
k1 = f(t=t, ψt=ψt, dt=dt)
|
124 |
+
k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
|
125 |
+
k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
|
126 |
+
k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
|
127 |
+
return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
|
128 |
+
|
129 |
+
@property
|
130 |
+
def _step(self):
|
131 |
+
if self.method == "euler":
|
132 |
+
return self._euler_step
|
133 |
+
elif self.method == "midpoint":
|
134 |
+
return self._midpoint_step
|
135 |
+
elif self.method == "rk4":
|
136 |
+
return self._rk4_step
|
137 |
+
else:
|
138 |
+
raise ValueError(f"Unknown method: {self.method}")
|
139 |
+
|
140 |
+
def get_running_train_loop(self):
|
141 |
+
try:
|
142 |
+
# Lazy import
|
143 |
+
from ...utils.train_loop import TrainLoop
|
144 |
+
|
145 |
+
return TrainLoop.get_running_loop()
|
146 |
+
except ImportError:
|
147 |
+
return None
|
148 |
+
|
149 |
+
@property
|
150 |
+
def visualizing(self):
|
151 |
+
loop = self.get_running_train_loop()
|
152 |
+
if loop is None:
|
153 |
+
return
|
154 |
+
out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
|
155 |
+
return loop.global_step % self.viz_every == 0 and not out_path.exists()
|
156 |
+
|
157 |
+
def _reset_camera(self):
|
158 |
+
try:
|
159 |
+
from celluloid import Camera
|
160 |
+
|
161 |
+
self._camera = Camera(plt.figure())
|
162 |
+
except:
|
163 |
+
pass
|
164 |
+
|
165 |
+
def _maybe_dump_camera(self):
|
166 |
+
camera = self._camera
|
167 |
+
loop = self.get_running_train_loop()
|
168 |
+
if camera is not None and loop is not None:
|
169 |
+
animation = camera.animate()
|
170 |
+
out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
|
171 |
+
out_path.parent.mkdir(exist_ok=True, parents=True)
|
172 |
+
animation.save(out_path, writer="pillow", fps=4)
|
173 |
+
plt.close()
|
174 |
+
self._camera = None
|
175 |
+
|
176 |
+
@property
|
177 |
+
def n_steps(self):
|
178 |
+
n = self.nfe
|
179 |
+
if self.method == "euler":
|
180 |
+
pass
|
181 |
+
elif self.method == "midpoint":
|
182 |
+
n //= 2
|
183 |
+
elif self.method == "rk4":
|
184 |
+
n //= 4
|
185 |
+
else:
|
186 |
+
raise ValueError(f"Unknown method: {self.method}")
|
187 |
+
return n
|
188 |
+
|
189 |
+
def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
|
190 |
+
ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))
|
191 |
+
|
192 |
+
if self.visualizing:
|
193 |
+
self._reset_camera()
|
194 |
+
|
195 |
+
if self.verbose:
|
196 |
+
steps = trange(self.n_steps, desc="CFM inference")
|
197 |
+
else:
|
198 |
+
steps = range(self.n_steps)
|
199 |
+
|
200 |
+
ψt = ψ0
|
201 |
+
|
202 |
+
for i in steps:
|
203 |
+
dt = ts[i + 1] - ts[i]
|
204 |
+
t = ts[i]
|
205 |
+
self._maybe_camera_snap(ψt=ψt, t=t)
|
206 |
+
ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)
|
207 |
+
|
208 |
+
self._maybe_camera_snap(ψt=ψt, t=ts[-1])
|
209 |
+
|
210 |
+
ψ1 = ψt
|
211 |
+
del ψt
|
212 |
+
|
213 |
+
self._maybe_dump_camera()
|
214 |
+
|
215 |
+
return ψ1
|
216 |
+
|
217 |
+
def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
|
218 |
+
return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
|
219 |
+
|
220 |
+
|
221 |
+
class SinusodialTimeEmbedding(nn.Module):
|
222 |
+
def __init__(self, d_embed):
|
223 |
+
super().__init__()
|
224 |
+
self.d_embed = d_embed
|
225 |
+
assert d_embed % 2 == 0
|
226 |
+
|
227 |
+
def forward(self, t):
|
228 |
+
t = t.unsqueeze(-1) # ... 1
|
229 |
+
p = torch.linspace(0, 4, self.d_embed // 2).to(t)
|
230 |
+
while p.dim() < t.dim():
|
231 |
+
p = p.unsqueeze(0) # ... d/2
|
232 |
+
sin = torch.sin(t * 10**p)
|
233 |
+
cos = torch.cos(t * 10**p)
|
234 |
+
return torch.cat([sin, cos], dim=-1)
|
235 |
+
|
236 |
+
|
237 |
+
@dataclass(eq=False)
|
238 |
+
class CFM(nn.Module):
|
239 |
+
"""
|
240 |
+
This mixin is for general diffusion models.
|
241 |
+
|
242 |
+
ψ0 stands for the gaussian noise, and ψ1 is the data point.
|
243 |
+
|
244 |
+
Here we follow the CFM style:
|
245 |
+
The generation process (reverse process) is from t=0 to t=1.
|
246 |
+
The forward process is from t=1 to t=0.
|
247 |
+
"""
|
248 |
+
|
249 |
+
cond_dim: int
|
250 |
+
output_dim: int
|
251 |
+
time_emb_dim: int = 128
|
252 |
+
viz_name: str = "cfm"
|
253 |
+
solver_nfe: int = 32
|
254 |
+
solver_method: str = "midpoint"
|
255 |
+
time_mapping_divisor: int = 4
|
256 |
+
|
257 |
+
def __post_init__(self):
|
258 |
+
super().__init__()
|
259 |
+
self.solver = Solver(
|
260 |
+
viz_name=self.viz_name,
|
261 |
+
viz_every=1,
|
262 |
+
nfe=self.solver_nfe,
|
263 |
+
method=self.solver_method,
|
264 |
+
time_mapping_divisor=self.time_mapping_divisor,
|
265 |
+
)
|
266 |
+
self.emb = SinusodialTimeEmbedding(self.time_emb_dim)
|
267 |
+
self.net = WN(
|
268 |
+
input_dim=self.output_dim,
|
269 |
+
output_dim=self.output_dim,
|
270 |
+
local_dim=self.cond_dim,
|
271 |
+
global_dim=self.time_emb_dim,
|
272 |
+
)
|
273 |
+
|
274 |
+
def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
|
275 |
+
"""
|
276 |
+
Perturb ψ1 to ψt.
|
277 |
+
"""
|
278 |
+
raise NotImplementedError
|
279 |
+
|
280 |
+
def _sample_ψ0(self, x: Tensor):
|
281 |
+
"""
|
282 |
+
Args:
|
283 |
+
x: (b c t), which implies the shape of ψ0
|
284 |
+
"""
|
285 |
+
shape = list(x.shape)
|
286 |
+
shape[1] = self.output_dim
|
287 |
+
if self.training:
|
288 |
+
g = None
|
289 |
+
else:
|
290 |
+
g = torch.Generator(device=x.device)
|
291 |
+
g.manual_seed(0) # deterministic sampling during eval
|
292 |
+
ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g)
|
293 |
+
return ψ0
|
294 |
+
|
295 |
+
@property
|
296 |
+
def sigma(self):
|
297 |
+
return 1e-4
|
298 |
+
|
299 |
+
def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
|
300 |
+
"""
|
301 |
+
Eq (22)
|
302 |
+
"""
|
303 |
+
while t.dim() < ψ1.dim():
|
304 |
+
t = t.unsqueeze(-1)
|
305 |
+
μ = t * ψ1 + (1 - t) * ψ0
|
306 |
+
return μ + torch.randn_like(μ) * self.sigma
|
307 |
+
|
308 |
+
def _to_u(self, *, ψ1, ψ0: Tensor):
|
309 |
+
"""
|
310 |
+
Eq (21)
|
311 |
+
"""
|
312 |
+
return ψ1 - ψ0
|
313 |
+
|
314 |
+
def _to_v(self, *, ψt, x, t: float | Tensor):
|
315 |
+
"""
|
316 |
+
Args:
|
317 |
+
ψt: (b c t)
|
318 |
+
x: (b c t)
|
319 |
+
t: (b)
|
320 |
+
Returns:
|
321 |
+
v: (b c t)
|
322 |
+
"""
|
323 |
+
if isinstance(t, (float, int)):
|
324 |
+
t = torch.full(ψt.shape[:1], t).to(ψt)
|
325 |
+
t = t.clamp(0, 1) # [0, 1)
|
326 |
+
g = self.emb(t) # (b d)
|
327 |
+
v = self.net(ψt, l=x, g=g)
|
328 |
+
return v
|
329 |
+
|
330 |
+
def compute_losses(self, x, y, ψ0) -> dict:
|
331 |
+
"""
|
332 |
+
Args:
|
333 |
+
x: (b c t)
|
334 |
+
y: (b c t)
|
335 |
+
Returns:
|
336 |
+
losses: dict
|
337 |
+
"""
|
338 |
+
t = torch.rand(len(x), device=x.device, dtype=x.dtype)
|
339 |
+
t = self.solver.time_mapping(t)
|
340 |
+
|
341 |
+
if ψ0 is None:
|
342 |
+
ψ0 = self._sample_ψ0(x)
|
343 |
+
|
344 |
+
ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)
|
345 |
+
|
346 |
+
v = self._to_v(ψt=ψt, t=t, x=x)
|
347 |
+
u = self._to_u(ψ1=y, ψ0=ψ0)
|
348 |
+
|
349 |
+
losses = dict(l1=F.l1_loss(v, u))
|
350 |
+
|
351 |
+
return losses
|
352 |
+
|
353 |
+
@torch.inference_mode()
|
354 |
+
def sample(self, x, ψ0=None, t0=0.0):
|
355 |
+
"""
|
356 |
+
Args:
|
357 |
+
x: (b c t)
|
358 |
+
Returns:
|
359 |
+
y: (b ... t)
|
360 |
+
"""
|
361 |
+
if ψ0 is None:
|
362 |
+
ψ0 = self._sample_ψ0(x)
|
363 |
+
f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
|
364 |
+
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
|
365 |
+
return ψ1
|
366 |
+
|
367 |
+
def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
|
368 |
+
if y is None:
|
369 |
+
y = self.sample(x, ψ0=ψ0, t0=t0)
|
370 |
+
else:
|
371 |
+
self.losses = self.compute_losses(x, y, ψ0=ψ0)
|
372 |
+
return y
|
modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor, nn
|
7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
8 |
+
|
9 |
+
from ...common import Normalizer
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class IRMAEOutput:
|
16 |
+
latent: Tensor # latent vector
|
17 |
+
decoded: Tensor | None # decoder output, include extra dim
|
18 |
+
|
19 |
+
|
20 |
+
class ResBlock(nn.Sequential):
|
21 |
+
def __init__(self, channels, dilations=[1, 2, 4, 8]):
|
22 |
+
wn = weight_norm
|
23 |
+
super().__init__(
|
24 |
+
nn.GroupNorm(32, channels),
|
25 |
+
nn.GELU(),
|
26 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
|
27 |
+
nn.GroupNorm(32, channels),
|
28 |
+
nn.GELU(),
|
29 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
|
30 |
+
nn.GroupNorm(32, channels),
|
31 |
+
nn.GELU(),
|
32 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
|
33 |
+
nn.GroupNorm(32, channels),
|
34 |
+
nn.GELU(),
|
35 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, x: Tensor):
|
39 |
+
return x + super().forward(x)
|
40 |
+
|
41 |
+
|
42 |
+
class IRMAE(nn.Module):
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
input_dim,
|
46 |
+
output_dim,
|
47 |
+
latent_dim,
|
48 |
+
hidden_dim=1024,
|
49 |
+
num_irms=4,
|
50 |
+
):
|
51 |
+
"""
|
52 |
+
Args:
|
53 |
+
input_dim: input dimension
|
54 |
+
output_dim: output dimension
|
55 |
+
latent_dim: latent dimension
|
56 |
+
hidden_dim: hidden layer dimension
|
57 |
+
num_irm_matrics: number of implicit rank minimization matrices
|
58 |
+
norm: normalization layer
|
59 |
+
"""
|
60 |
+
self.input_dim = input_dim
|
61 |
+
super().__init__()
|
62 |
+
|
63 |
+
self.encoder = nn.Sequential(
|
64 |
+
nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
|
65 |
+
*[ResBlock(hidden_dim) for _ in range(4)],
|
66 |
+
# Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
|
67 |
+
*[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
|
68 |
+
nn.Tanh(),
|
69 |
+
)
|
70 |
+
|
71 |
+
self.decoder = nn.Sequential(
|
72 |
+
nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
|
73 |
+
*[ResBlock(hidden_dim) for _ in range(4)],
|
74 |
+
nn.Conv1d(hidden_dim, output_dim, 1),
|
75 |
+
)
|
76 |
+
|
77 |
+
self.head = nn.Sequential(
|
78 |
+
nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
|
79 |
+
nn.GELU(),
|
80 |
+
nn.Conv1d(hidden_dim, input_dim, 1),
|
81 |
+
)
|
82 |
+
|
83 |
+
self.estimator = Normalizer()
|
84 |
+
|
85 |
+
def encode(self, x):
|
86 |
+
"""
|
87 |
+
Args:
|
88 |
+
x: (b c t) tensor
|
89 |
+
"""
|
90 |
+
z = self.encoder(x) # (b c t)
|
91 |
+
_ = self.estimator(z) # Estimate the glboal mean and std of z
|
92 |
+
self.stats = {}
|
93 |
+
self.stats["z_mean"] = z.mean().item()
|
94 |
+
self.stats["z_std"] = z.std().item()
|
95 |
+
self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
|
96 |
+
self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
|
97 |
+
self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
|
98 |
+
return z
|
99 |
+
|
100 |
+
def decode(self, z):
|
101 |
+
"""
|
102 |
+
Args:
|
103 |
+
z: (b c t) tensor
|
104 |
+
"""
|
105 |
+
return self.decoder(z)
|
106 |
+
|
107 |
+
def forward(self, x, skip_decoding=False):
|
108 |
+
"""
|
109 |
+
Args:
|
110 |
+
x: (b c t) tensor
|
111 |
+
skip_decoding: if True, skip the decoding step
|
112 |
+
"""
|
113 |
+
z = self.encode(x) # q(z|x)
|
114 |
+
|
115 |
+
if skip_decoding:
|
116 |
+
# This speeds up the training in cfm only mode
|
117 |
+
decoded = None
|
118 |
+
else:
|
119 |
+
decoded = self.decode(z) # p(x|z)
|
120 |
+
predicted = self.head(decoded)
|
121 |
+
self.losses = dict(mse=F.mse_loss(predicted, x))
|
122 |
+
|
123 |
+
return IRMAEOutput(latent=z, decoded=decoded)
|
modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch import Tensor, nn
|
8 |
+
|
9 |
+
from .cfm import CFM
|
10 |
+
from .irmae import IRMAE, IRMAEOutput
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
def freeze_(module):
|
16 |
+
for p in module.parameters():
|
17 |
+
p.requires_grad_(False)
|
18 |
+
|
19 |
+
|
20 |
+
class LCFM(nn.Module):
|
21 |
+
class Mode(Enum):
|
22 |
+
AE = "ae"
|
23 |
+
CFM = "cfm"
|
24 |
+
|
25 |
+
def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
|
26 |
+
super().__init__()
|
27 |
+
self.ae = ae
|
28 |
+
self.cfm = cfm
|
29 |
+
self.z_scale = z_scale
|
30 |
+
self._mode = None
|
31 |
+
self._eval_tau = 0.5
|
32 |
+
|
33 |
+
@property
|
34 |
+
def mode(self):
|
35 |
+
return self._mode
|
36 |
+
|
37 |
+
def set_mode_(self, mode):
|
38 |
+
mode = self.Mode(mode)
|
39 |
+
self._mode = mode
|
40 |
+
|
41 |
+
if mode == mode.AE:
|
42 |
+
freeze_(self.cfm)
|
43 |
+
logger.info("Freeze cfm")
|
44 |
+
elif mode == mode.CFM:
|
45 |
+
freeze_(self.ae)
|
46 |
+
logger.info("Freeze ae (encoder and decoder)")
|
47 |
+
else:
|
48 |
+
raise ValueError(f"Unknown training mode: {mode}")
|
49 |
+
|
50 |
+
def get_running_train_loop(self):
|
51 |
+
try:
|
52 |
+
# Lazy import
|
53 |
+
from ...utils.train_loop import TrainLoop
|
54 |
+
|
55 |
+
return TrainLoop.get_running_loop()
|
56 |
+
except ImportError:
|
57 |
+
return None
|
58 |
+
|
59 |
+
@property
|
60 |
+
def global_step(self):
|
61 |
+
loop = self.get_running_train_loop()
|
62 |
+
if loop is None:
|
63 |
+
return None
|
64 |
+
return loop.global_step
|
65 |
+
|
66 |
+
@torch.no_grad()
|
67 |
+
def _visualize(self, x, y, y_):
|
68 |
+
loop = self.get_running_train_loop()
|
69 |
+
if loop is None:
|
70 |
+
return
|
71 |
+
|
72 |
+
plt.subplot(221)
|
73 |
+
plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
74 |
+
plt.title("GT")
|
75 |
+
|
76 |
+
plt.subplot(222)
|
77 |
+
y_ = y_[:, : y.shape[1]]
|
78 |
+
plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
79 |
+
plt.title("Posterior")
|
80 |
+
|
81 |
+
plt.subplot(223)
|
82 |
+
z_ = self.cfm(x)
|
83 |
+
y__ = self.ae.decode(z_)
|
84 |
+
y__ = y__[:, : y.shape[1]]
|
85 |
+
plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
86 |
+
plt.title("C-Prior")
|
87 |
+
del y__
|
88 |
+
|
89 |
+
plt.subplot(224)
|
90 |
+
z_ = torch.randn_like(z_)
|
91 |
+
y__ = self.ae.decode(z_)
|
92 |
+
y__ = y__[:, : y.shape[1]]
|
93 |
+
plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
94 |
+
plt.title("Prior")
|
95 |
+
del z_, y__
|
96 |
+
|
97 |
+
path = loop.make_current_step_viz_path("recon", ".png")
|
98 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
99 |
+
plt.tight_layout()
|
100 |
+
plt.savefig(path, dpi=500)
|
101 |
+
plt.close()
|
102 |
+
|
103 |
+
def _scale(self, z: Tensor):
|
104 |
+
return z * self.z_scale
|
105 |
+
|
106 |
+
def _unscale(self, z: Tensor):
|
107 |
+
return z / self.z_scale
|
108 |
+
|
109 |
+
def eval_tau_(self, tau):
|
110 |
+
self._eval_tau = tau
|
111 |
+
|
112 |
+
def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
|
113 |
+
"""
|
114 |
+
Args:
|
115 |
+
x: (b d t), condition mel
|
116 |
+
y: (b d t), target mel
|
117 |
+
ψ0: (b d t), starting mel
|
118 |
+
"""
|
119 |
+
if self.mode == self.Mode.CFM:
|
120 |
+
self.ae.eval() # Always set to eval when training cfm
|
121 |
+
|
122 |
+
if ψ0 is not None:
|
123 |
+
ψ0 = self._scale(self.ae.encode(ψ0))
|
124 |
+
if self.training:
|
125 |
+
tau = torch.rand_like(ψ0[:, :1, :1])
|
126 |
+
else:
|
127 |
+
tau = self._eval_tau
|
128 |
+
ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0
|
129 |
+
|
130 |
+
if y is None:
|
131 |
+
if self.mode == self.Mode.AE:
|
132 |
+
with torch.no_grad():
|
133 |
+
training = self.ae.training
|
134 |
+
self.ae.eval()
|
135 |
+
z = self.ae.encode(x)
|
136 |
+
self.ae.train(training)
|
137 |
+
else:
|
138 |
+
z = self._unscale(self.cfm(x, ψ0=ψ0))
|
139 |
+
|
140 |
+
h = self.ae.decode(z)
|
141 |
+
else:
|
142 |
+
ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
|
143 |
+
|
144 |
+
if self.mode == self.Mode.CFM:
|
145 |
+
_ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
|
146 |
+
|
147 |
+
h = ae_output.decoded
|
148 |
+
|
149 |
+
if h is not None and self.global_step is not None and self.global_step % 100 == 0:
|
150 |
+
self._visualize(x[:1], y[:1], h[:1])
|
151 |
+
|
152 |
+
return h
|
modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
@torch.jit.script
|
11 |
+
def _fused_tanh_sigmoid(h):
|
12 |
+
a, b = h.chunk(2, dim=1)
|
13 |
+
h = a.tanh() * b.sigmoid()
|
14 |
+
return h
|
15 |
+
|
16 |
+
|
17 |
+
class WNLayer(nn.Module):
|
18 |
+
"""
|
19 |
+
A DiffWave-like WN
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
local_output_dim = hidden_dim * 2
|
26 |
+
|
27 |
+
if global_dim is not None:
|
28 |
+
self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
|
29 |
+
|
30 |
+
if local_dim is not None:
|
31 |
+
self.lconv = nn.Conv1d(local_dim, local_output_dim, 1)
|
32 |
+
|
33 |
+
self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")
|
34 |
+
|
35 |
+
self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)
|
36 |
+
|
37 |
+
def forward(self, z, l, g):
|
38 |
+
identity = z
|
39 |
+
|
40 |
+
if g is not None:
|
41 |
+
if g.dim() == 2:
|
42 |
+
g = g.unsqueeze(-1)
|
43 |
+
z = z + self.gconv(g)
|
44 |
+
|
45 |
+
z = self.dconv(z)
|
46 |
+
|
47 |
+
if l is not None:
|
48 |
+
z = z + self.lconv(l)
|
49 |
+
|
50 |
+
z = _fused_tanh_sigmoid(z)
|
51 |
+
|
52 |
+
h = self.out(z)
|
53 |
+
|
54 |
+
z, s = h.chunk(2, dim=1)
|
55 |
+
|
56 |
+
o = (z + identity) / math.sqrt(2)
|
57 |
+
|
58 |
+
return o, s
|
59 |
+
|
60 |
+
|
61 |
+
class WN(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
input_dim,
|
65 |
+
output_dim,
|
66 |
+
local_dim=None,
|
67 |
+
global_dim=None,
|
68 |
+
n_layers=30,
|
69 |
+
kernel_size=3,
|
70 |
+
dilation_cycle=5,
|
71 |
+
hidden_dim=512,
|
72 |
+
):
|
73 |
+
super().__init__()
|
74 |
+
assert kernel_size % 2 == 1
|
75 |
+
assert hidden_dim % 2 == 0
|
76 |
+
|
77 |
+
self.input_dim = input_dim
|
78 |
+
self.hidden_dim = hidden_dim
|
79 |
+
self.local_dim = local_dim
|
80 |
+
self.global_dim = global_dim
|
81 |
+
|
82 |
+
self.start = nn.Conv1d(input_dim, hidden_dim, 1)
|
83 |
+
if local_dim is not None:
|
84 |
+
self.local_norm = nn.InstanceNorm1d(local_dim)
|
85 |
+
|
86 |
+
self.layers = nn.ModuleList(
|
87 |
+
[
|
88 |
+
WNLayer(
|
89 |
+
hidden_dim=hidden_dim,
|
90 |
+
local_dim=local_dim,
|
91 |
+
global_dim=global_dim,
|
92 |
+
kernel_size=kernel_size,
|
93 |
+
dilation=2 ** (i % dilation_cycle),
|
94 |
+
)
|
95 |
+
for i in range(n_layers)
|
96 |
+
]
|
97 |
+
)
|
98 |
+
|
99 |
+
self.end = nn.Conv1d(hidden_dim, output_dim, 1)
|
100 |
+
|
101 |
+
def forward(self, z, l=None, g=None):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
z: input (b c t)
|
105 |
+
l: local condition (b c t)
|
106 |
+
g: global condition (b d)
|
107 |
+
"""
|
108 |
+
z = self.start(z)
|
109 |
+
|
110 |
+
if l is not None:
|
111 |
+
l = self.local_norm(l)
|
112 |
+
|
113 |
+
# Skips
|
114 |
+
s_list = []
|
115 |
+
|
116 |
+
for layer in self.layers:
|
117 |
+
z, s = layer(z, l, g)
|
118 |
+
s_list.append(s)
|
119 |
+
|
120 |
+
s_list = torch.stack(s_list, dim=0).sum(dim=0)
|
121 |
+
s_list = s_list / math.sqrt(len(self.layers))
|
122 |
+
|
123 |
+
o = self.end(s_list)
|
124 |
+
|
125 |
+
return o
|
126 |
+
|
127 |
+
def summarize(self, length=100):
|
128 |
+
from ptflops import get_model_complexity_info
|
129 |
+
|
130 |
+
x = torch.randn(1, self.input_dim, length)
|
131 |
+
|
132 |
+
macs, params = get_model_complexity_info(
|
133 |
+
self,
|
134 |
+
(self.input_dim, length),
|
135 |
+
as_strings=True,
|
136 |
+
print_per_layer_stat=True,
|
137 |
+
verbose=True,
|
138 |
+
)
|
139 |
+
|
140 |
+
print(f"Input shape: {x.shape}")
|
141 |
+
print(f"Computational complexity: {macs}")
|
142 |
+
print(f"Number of parameters: {params}")
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
model = WN(input_dim=64, output_dim=64)
|
147 |
+
model.summarize()
|
modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .univnet import UnivNet
|
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
from .filter import *
|
5 |
+
from .resample import *
|
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import math
|
8 |
+
|
9 |
+
if 'sinc' in dir(torch):
|
10 |
+
sinc = torch.sinc
|
11 |
+
else:
|
12 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
13 |
+
# https://adefossez.github.io/julius/julius/core.html
|
14 |
+
# LICENSE is in incl_licenses directory.
|
15 |
+
def sinc(x: torch.Tensor):
|
16 |
+
"""
|
17 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
18 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
19 |
+
"""
|
20 |
+
return torch.where(x == 0,
|
21 |
+
torch.tensor(1., device=x.device, dtype=x.dtype),
|
22 |
+
torch.sin(math.pi * x) / math.pi / x)
|
23 |
+
|
24 |
+
|
25 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
26 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
27 |
+
# LICENSE is in incl_licenses directory.
|
28 |
+
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
29 |
+
even = (kernel_size % 2 == 0)
|
30 |
+
half_size = kernel_size // 2
|
31 |
+
|
32 |
+
#For kaiser window
|
33 |
+
delta_f = 4 * half_width
|
34 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
35 |
+
if A > 50.:
|
36 |
+
beta = 0.1102 * (A - 8.7)
|
37 |
+
elif A >= 21.:
|
38 |
+
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
39 |
+
else:
|
40 |
+
beta = 0.
|
41 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
42 |
+
|
43 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
44 |
+
if even:
|
45 |
+
time = (torch.arange(-half_size, half_size) + 0.5)
|
46 |
+
else:
|
47 |
+
time = torch.arange(kernel_size) - half_size
|
48 |
+
if cutoff == 0:
|
49 |
+
filter_ = torch.zeros_like(time)
|
50 |
+
else:
|
51 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
52 |
+
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
53 |
+
# of the constant component in the input signal.
|
54 |
+
filter_ /= filter_.sum()
|
55 |
+
filter = filter_.view(1, 1, kernel_size)
|
56 |
+
|
57 |
+
return filter
|
58 |
+
|
59 |
+
|
60 |
+
class LowPassFilter1d(nn.Module):
|
61 |
+
def __init__(self,
|
62 |
+
cutoff=0.5,
|
63 |
+
half_width=0.6,
|
64 |
+
stride: int = 1,
|
65 |
+
padding: bool = True,
|
66 |
+
padding_mode: str = 'replicate',
|
67 |
+
kernel_size: int = 12):
|
68 |
+
# kernel_size should be even number for stylegan3 setup,
|
69 |
+
# in this implementation, odd number is also possible.
|
70 |
+
super().__init__()
|
71 |
+
if cutoff < -0.:
|
72 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
73 |
+
if cutoff > 0.5:
|
74 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
75 |
+
self.kernel_size = kernel_size
|
76 |
+
self.even = (kernel_size % 2 == 0)
|
77 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
78 |
+
self.pad_right = kernel_size // 2
|
79 |
+
self.stride = stride
|
80 |
+
self.padding = padding
|
81 |
+
self.padding_mode = padding_mode
|
82 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
83 |
+
self.register_buffer("filter", filter)
|
84 |
+
|
85 |
+
#input [B, C, T]
|
86 |
+
def forward(self, x):
|
87 |
+
_, C, _ = x.shape
|
88 |
+
|
89 |
+
if self.padding:
|
90 |
+
x = F.pad(x, (self.pad_left, self.pad_right),
|
91 |
+
mode=self.padding_mode)
|
92 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1),
|
93 |
+
stride=self.stride, groups=C)
|
94 |
+
|
95 |
+
return out
|
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from .filter import LowPassFilter1d
|
7 |
+
from .filter import kaiser_sinc_filter1d
|
8 |
+
|
9 |
+
|
10 |
+
class UpSample1d(nn.Module):
|
11 |
+
def __init__(self, ratio=2, kernel_size=None):
|
12 |
+
super().__init__()
|
13 |
+
self.ratio = ratio
|
14 |
+
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
15 |
+
self.stride = ratio
|
16 |
+
self.pad = self.kernel_size // ratio - 1
|
17 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
18 |
+
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
19 |
+
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
20 |
+
half_width=0.6 / ratio,
|
21 |
+
kernel_size=self.kernel_size)
|
22 |
+
self.register_buffer("filter", filter)
|
23 |
+
|
24 |
+
# x: [B, C, T]
|
25 |
+
def forward(self, x):
|
26 |
+
_, C, _ = x.shape
|
27 |
+
|
28 |
+
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
29 |
+
x = self.ratio * F.conv_transpose1d(
|
30 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
31 |
+
x = x[..., self.pad_left:-self.pad_right]
|
32 |
+
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
class DownSample1d(nn.Module):
|
37 |
+
def __init__(self, ratio=2, kernel_size=None):
|
38 |
+
super().__init__()
|
39 |
+
self.ratio = ratio
|
40 |
+
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
41 |
+
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
42 |
+
half_width=0.6 / ratio,
|
43 |
+
stride=ratio,
|
44 |
+
kernel_size=self.kernel_size)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
xx = self.lowpass(x)
|
48 |
+
|
49 |
+
return xx
|
modules/repos_static/resemble_enhance/enhancer/univnet/amp.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Refer from https://github.com/NVIDIA/BigVGAN
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.utils.parametrizations import weight_norm
|
9 |
+
|
10 |
+
from .alias_free_torch import DownSample1d, UpSample1d
|
11 |
+
|
12 |
+
|
13 |
+
class SnakeBeta(nn.Module):
|
14 |
+
"""
|
15 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
16 |
+
Shape:
|
17 |
+
- Input: (B, C, T)
|
18 |
+
- Output: (B, C, T), same shape as the input
|
19 |
+
Parameters:
|
20 |
+
- alpha - trainable parameter that controls frequency
|
21 |
+
- beta - trainable parameter that controls magnitude
|
22 |
+
References:
|
23 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
24 |
+
https://arxiv.org/abs/2006.08195
|
25 |
+
Examples:
|
26 |
+
>>> a1 = snakebeta(256)
|
27 |
+
>>> x = torch.randn(256)
|
28 |
+
>>> x = a1(x)
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
|
32 |
+
"""
|
33 |
+
Initialization.
|
34 |
+
INPUT:
|
35 |
+
- in_features: shape of the input
|
36 |
+
- alpha - trainable parameter that controls frequency
|
37 |
+
- beta - trainable parameter that controls magnitude
|
38 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
39 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
40 |
+
alpha will be trained along with the rest of your model.
|
41 |
+
"""
|
42 |
+
super().__init__()
|
43 |
+
self.in_features = in_features
|
44 |
+
self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
|
45 |
+
self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
|
46 |
+
self.clamp = clamp
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""
|
50 |
+
Forward pass of the function.
|
51 |
+
Applies the function to the input elementwise.
|
52 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
53 |
+
"""
|
54 |
+
alpha = self.log_alpha.exp().clamp(*self.clamp)
|
55 |
+
alpha = alpha[None, :, None]
|
56 |
+
|
57 |
+
beta = self.log_beta.exp().clamp(*self.clamp)
|
58 |
+
beta = beta[None, :, None]
|
59 |
+
|
60 |
+
x = x + (1.0 / beta) * (x * alpha).sin().pow(2)
|
61 |
+
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class UpActDown(nn.Module):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
act,
|
69 |
+
up_ratio: int = 2,
|
70 |
+
down_ratio: int = 2,
|
71 |
+
up_kernel_size: int = 12,
|
72 |
+
down_kernel_size: int = 12,
|
73 |
+
):
|
74 |
+
super().__init__()
|
75 |
+
self.up_ratio = up_ratio
|
76 |
+
self.down_ratio = down_ratio
|
77 |
+
self.act = act
|
78 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
79 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
# x: [B,C,T]
|
83 |
+
x = self.upsample(x)
|
84 |
+
x = self.act(x)
|
85 |
+
x = self.downsample(x)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class AMPBlock(nn.Sequential):
|
90 |
+
def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
|
91 |
+
super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations))
|
92 |
+
|
93 |
+
def _make_layer(self, channels, kernel_size, dilation):
|
94 |
+
return nn.Sequential(
|
95 |
+
weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")),
|
96 |
+
UpActDown(act=SnakeBeta(channels)),
|
97 |
+
weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")),
|
98 |
+
)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
return x + super().forward(x)
|
modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import Tensor, nn
|
6 |
+
from torch.nn.utils.parametrizations import weight_norm
|
7 |
+
|
8 |
+
from ..hparams import HParams
|
9 |
+
from .mrstft import get_stft_cfgs
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class PeriodNetwork(nn.Module):
|
15 |
+
def __init__(self, period):
|
16 |
+
super().__init__()
|
17 |
+
self.period = period
|
18 |
+
wn = weight_norm
|
19 |
+
self.convs = nn.ModuleList(
|
20 |
+
[
|
21 |
+
wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
|
22 |
+
wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
|
23 |
+
wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
|
24 |
+
wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
|
25 |
+
wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
|
26 |
+
]
|
27 |
+
)
|
28 |
+
self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
x: [B, 1, T]
|
34 |
+
"""
|
35 |
+
assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
|
36 |
+
|
37 |
+
# 1d to 2d
|
38 |
+
b, c, t = x.shape
|
39 |
+
if t % self.period != 0: # pad first
|
40 |
+
n_pad = self.period - (t % self.period)
|
41 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
42 |
+
t = t + n_pad
|
43 |
+
x = x.view(b, c, t // self.period, self.period)
|
44 |
+
|
45 |
+
for l in self.convs:
|
46 |
+
x = l(x)
|
47 |
+
x = F.leaky_relu(x, 0.2)
|
48 |
+
x = self.conv_post(x)
|
49 |
+
x = torch.flatten(x, 1, -1)
|
50 |
+
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class SpecNetwork(nn.Module):
|
55 |
+
def __init__(self, stft_cfg: dict):
|
56 |
+
super().__init__()
|
57 |
+
wn = weight_norm
|
58 |
+
self.stft_cfg = stft_cfg
|
59 |
+
self.convs = nn.ModuleList(
|
60 |
+
[
|
61 |
+
wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
62 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
63 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
64 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
65 |
+
wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
66 |
+
]
|
67 |
+
)
|
68 |
+
self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
"""
|
72 |
+
Args:
|
73 |
+
x: [B, 1, T]
|
74 |
+
"""
|
75 |
+
x = self.spectrogram(x)
|
76 |
+
x = x.unsqueeze(1)
|
77 |
+
for l in self.convs:
|
78 |
+
x = l(x)
|
79 |
+
x = F.leaky_relu(x, 0.2)
|
80 |
+
x = self.conv_post(x)
|
81 |
+
x = x.flatten(1, -1)
|
82 |
+
return x
|
83 |
+
|
84 |
+
def spectrogram(self, x):
|
85 |
+
"""
|
86 |
+
Args:
|
87 |
+
x: [B, 1, T]
|
88 |
+
"""
|
89 |
+
x = x.squeeze(1)
|
90 |
+
dtype = x.dtype
|
91 |
+
stft_cfg = dict(self.stft_cfg)
|
92 |
+
x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
|
93 |
+
mag = x.norm(p=2, dim=-1) # [B, F, TT]
|
94 |
+
mag = mag.to(dtype) # [B, F, TT]
|
95 |
+
return mag
|
96 |
+
|
97 |
+
|
98 |
+
class MD(nn.ModuleList):
|
99 |
+
def __init__(self, l: list):
|
100 |
+
super().__init__([self._create_network(x) for x in l])
|
101 |
+
self._loss_type = None
|
102 |
+
|
103 |
+
def loss_type_(self, loss_type):
|
104 |
+
self._loss_type = loss_type
|
105 |
+
|
106 |
+
def _create_network(self, _):
|
107 |
+
raise NotImplementedError
|
108 |
+
|
109 |
+
def _forward_each(self, d, x, y):
|
110 |
+
assert self._loss_type is not None, "loss_type is not set."
|
111 |
+
loss_type = self._loss_type
|
112 |
+
|
113 |
+
if loss_type == "hinge":
|
114 |
+
if y == 0:
|
115 |
+
# d(x) should be small -> -1
|
116 |
+
loss = F.relu(1 + d(x)).mean()
|
117 |
+
elif y == 1:
|
118 |
+
# d(x) should be large -> 1
|
119 |
+
loss = F.relu(1 - d(x)).mean()
|
120 |
+
else:
|
121 |
+
raise ValueError(f"Invalid y: {y}")
|
122 |
+
elif loss_type == "wgan":
|
123 |
+
if y == 0:
|
124 |
+
loss = d(x).mean()
|
125 |
+
elif y == 1:
|
126 |
+
loss = -d(x).mean()
|
127 |
+
else:
|
128 |
+
raise ValueError(f"Invalid y: {y}")
|
129 |
+
else:
|
130 |
+
raise ValueError(f"Invalid loss_type: {loss_type}")
|
131 |
+
|
132 |
+
return loss
|
133 |
+
|
134 |
+
def forward(self, x, y) -> Tensor:
|
135 |
+
losses = [self._forward_each(d, x, y) for d in self]
|
136 |
+
return torch.stack(losses).mean()
|
137 |
+
|
138 |
+
|
139 |
+
class MPD(MD):
|
140 |
+
def __init__(self):
|
141 |
+
super().__init__([2, 3, 7, 13, 17])
|
142 |
+
|
143 |
+
def _create_network(self, period):
|
144 |
+
return PeriodNetwork(period)
|
145 |
+
|
146 |
+
|
147 |
+
class MRD(MD):
|
148 |
+
def __init__(self, stft_cfgs):
|
149 |
+
super().__init__(stft_cfgs)
|
150 |
+
|
151 |
+
def _create_network(self, stft_cfg):
|
152 |
+
return SpecNetwork(stft_cfg)
|
153 |
+
|
154 |
+
|
155 |
+
class Discriminator(nn.Module):
|
156 |
+
@property
|
157 |
+
def wav_rate(self):
|
158 |
+
return self.hp.wav_rate
|
159 |
+
|
160 |
+
def __init__(self, hp: HParams):
|
161 |
+
super().__init__()
|
162 |
+
self.hp = hp
|
163 |
+
self.stft_cfgs = get_stft_cfgs(hp)
|
164 |
+
self.mpd = MPD()
|
165 |
+
self.mrd = MRD(self.stft_cfgs)
|
166 |
+
self.dummy_float: Tensor
|
167 |
+
self.register_buffer("dummy_float", torch.zeros(0), persistent=False)
|
168 |
+
|
169 |
+
def loss_type_(self, loss_type):
|
170 |
+
self.mpd.loss_type_(loss_type)
|
171 |
+
self.mrd.loss_type_(loss_type)
|
172 |
+
|
173 |
+
def forward(self, fake, real=None):
|
174 |
+
"""
|
175 |
+
Args:
|
176 |
+
fake: [B T]
|
177 |
+
real: [B T]
|
178 |
+
"""
|
179 |
+
fake = fake.to(self.dummy_float)
|
180 |
+
|
181 |
+
if real is None:
|
182 |
+
self.loss_type_("wgan")
|
183 |
+
else:
|
184 |
+
length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
|
185 |
+
assert length_difference < 0.05, f"length_difference should be smaller than 5%"
|
186 |
+
|
187 |
+
self.loss_type_("hinge")
|
188 |
+
real = real.to(self.dummy_float)
|
189 |
+
|
190 |
+
fake = fake[..., : real.shape[-1]]
|
191 |
+
real = real[..., : fake.shape[-1]]
|
192 |
+
|
193 |
+
losses = {}
|
194 |
+
|
195 |
+
assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
|
196 |
+
assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."
|
197 |
+
|
198 |
+
fake = fake.unsqueeze(1)
|
199 |
+
|
200 |
+
if real is None:
|
201 |
+
losses["mpd"] = self.mpd(fake, 1)
|
202 |
+
losses["mrd"] = self.mrd(fake, 1)
|
203 |
+
else:
|
204 |
+
real = real.unsqueeze(1)
|
205 |
+
losses["mpd_fake"] = self.mpd(fake, 0)
|
206 |
+
losses["mpd_real"] = self.mpd(real, 1)
|
207 |
+
losses["mrd_fake"] = self.mrd(fake, 0)
|
208 |
+
losses["mrd_real"] = self.mrd(real, 1)
|
209 |
+
|
210 |
+
return losses
|
modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" refer from https://github.com/zceng/LVCNet """
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
8 |
+
|
9 |
+
from .amp import AMPBlock
|
10 |
+
|
11 |
+
|
12 |
+
class KernelPredictor(torch.nn.Module):
|
13 |
+
"""Kernel predictor for the location-variable convolutions"""
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
cond_channels,
|
18 |
+
conv_in_channels,
|
19 |
+
conv_out_channels,
|
20 |
+
conv_layers,
|
21 |
+
conv_kernel_size=3,
|
22 |
+
kpnet_hidden_channels=64,
|
23 |
+
kpnet_conv_size=3,
|
24 |
+
kpnet_dropout=0.0,
|
25 |
+
kpnet_nonlinear_activation="LeakyReLU",
|
26 |
+
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
Args:
|
30 |
+
cond_channels (int): number of channel for the conditioning sequence,
|
31 |
+
conv_in_channels (int): number of channel for the input sequence,
|
32 |
+
conv_out_channels (int): number of channel for the output sequence,
|
33 |
+
conv_layers (int): number of layers
|
34 |
+
"""
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.conv_in_channels = conv_in_channels
|
38 |
+
self.conv_out_channels = conv_out_channels
|
39 |
+
self.conv_kernel_size = conv_kernel_size
|
40 |
+
self.conv_layers = conv_layers
|
41 |
+
|
42 |
+
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
43 |
+
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
44 |
+
|
45 |
+
self.input_conv = nn.Sequential(
|
46 |
+
weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
47 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
48 |
+
)
|
49 |
+
|
50 |
+
self.residual_convs = nn.ModuleList()
|
51 |
+
padding = (kpnet_conv_size - 1) // 2
|
52 |
+
for _ in range(3):
|
53 |
+
self.residual_convs.append(
|
54 |
+
nn.Sequential(
|
55 |
+
nn.Dropout(kpnet_dropout),
|
56 |
+
weight_norm(
|
57 |
+
nn.Conv1d(
|
58 |
+
kpnet_hidden_channels,
|
59 |
+
kpnet_hidden_channels,
|
60 |
+
kpnet_conv_size,
|
61 |
+
padding=padding,
|
62 |
+
bias=True,
|
63 |
+
)
|
64 |
+
),
|
65 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
66 |
+
weight_norm(
|
67 |
+
nn.Conv1d(
|
68 |
+
kpnet_hidden_channels,
|
69 |
+
kpnet_hidden_channels,
|
70 |
+
kpnet_conv_size,
|
71 |
+
padding=padding,
|
72 |
+
bias=True,
|
73 |
+
)
|
74 |
+
),
|
75 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
76 |
+
)
|
77 |
+
)
|
78 |
+
self.kernel_conv = weight_norm(
|
79 |
+
nn.Conv1d(
|
80 |
+
kpnet_hidden_channels,
|
81 |
+
kpnet_kernel_channels,
|
82 |
+
kpnet_conv_size,
|
83 |
+
padding=padding,
|
84 |
+
bias=True,
|
85 |
+
)
|
86 |
+
)
|
87 |
+
self.bias_conv = weight_norm(
|
88 |
+
nn.Conv1d(
|
89 |
+
kpnet_hidden_channels,
|
90 |
+
kpnet_bias_channels,
|
91 |
+
kpnet_conv_size,
|
92 |
+
padding=padding,
|
93 |
+
bias=True,
|
94 |
+
)
|
95 |
+
)
|
96 |
+
|
97 |
+
def forward(self, c):
|
98 |
+
"""
|
99 |
+
Args:
|
100 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
101 |
+
"""
|
102 |
+
batch, _, cond_length = c.shape
|
103 |
+
c = self.input_conv(c)
|
104 |
+
for residual_conv in self.residual_convs:
|
105 |
+
residual_conv.to(c.device)
|
106 |
+
c = c + residual_conv(c)
|
107 |
+
k = self.kernel_conv(c)
|
108 |
+
b = self.bias_conv(c)
|
109 |
+
kernels = k.contiguous().view(
|
110 |
+
batch,
|
111 |
+
self.conv_layers,
|
112 |
+
self.conv_in_channels,
|
113 |
+
self.conv_out_channels,
|
114 |
+
self.conv_kernel_size,
|
115 |
+
cond_length,
|
116 |
+
)
|
117 |
+
bias = b.contiguous().view(
|
118 |
+
batch,
|
119 |
+
self.conv_layers,
|
120 |
+
self.conv_out_channels,
|
121 |
+
cond_length,
|
122 |
+
)
|
123 |
+
|
124 |
+
return kernels, bias
|
125 |
+
|
126 |
+
|
127 |
+
class LVCBlock(torch.nn.Module):
|
128 |
+
"""the location-variable convolutions"""
|
129 |
+
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
in_channels,
|
133 |
+
cond_channels,
|
134 |
+
stride,
|
135 |
+
dilations=[1, 3, 9, 27],
|
136 |
+
lReLU_slope=0.2,
|
137 |
+
conv_kernel_size=3,
|
138 |
+
cond_hop_length=256,
|
139 |
+
kpnet_hidden_channels=64,
|
140 |
+
kpnet_conv_size=3,
|
141 |
+
kpnet_dropout=0.0,
|
142 |
+
add_extra_noise=False,
|
143 |
+
downsampling=False,
|
144 |
+
):
|
145 |
+
super().__init__()
|
146 |
+
|
147 |
+
self.add_extra_noise = add_extra_noise
|
148 |
+
|
149 |
+
self.cond_hop_length = cond_hop_length
|
150 |
+
self.conv_layers = len(dilations)
|
151 |
+
self.conv_kernel_size = conv_kernel_size
|
152 |
+
|
153 |
+
self.kernel_predictor = KernelPredictor(
|
154 |
+
cond_channels=cond_channels,
|
155 |
+
conv_in_channels=in_channels,
|
156 |
+
conv_out_channels=2 * in_channels,
|
157 |
+
conv_layers=len(dilations),
|
158 |
+
conv_kernel_size=conv_kernel_size,
|
159 |
+
kpnet_hidden_channels=kpnet_hidden_channels,
|
160 |
+
kpnet_conv_size=kpnet_conv_size,
|
161 |
+
kpnet_dropout=kpnet_dropout,
|
162 |
+
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
|
163 |
+
)
|
164 |
+
|
165 |
+
if downsampling:
|
166 |
+
self.convt_pre = nn.Sequential(
|
167 |
+
nn.LeakyReLU(lReLU_slope),
|
168 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
|
169 |
+
nn.AvgPool1d(stride, stride),
|
170 |
+
)
|
171 |
+
else:
|
172 |
+
if stride == 1:
|
173 |
+
self.convt_pre = nn.Sequential(
|
174 |
+
nn.LeakyReLU(lReLU_slope),
|
175 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
self.convt_pre = nn.Sequential(
|
179 |
+
nn.LeakyReLU(lReLU_slope),
|
180 |
+
weight_norm(
|
181 |
+
nn.ConvTranspose1d(
|
182 |
+
in_channels,
|
183 |
+
in_channels,
|
184 |
+
2 * stride,
|
185 |
+
stride=stride,
|
186 |
+
padding=stride // 2 + stride % 2,
|
187 |
+
output_padding=stride % 2,
|
188 |
+
)
|
189 |
+
),
|
190 |
+
)
|
191 |
+
|
192 |
+
self.amp_block = AMPBlock(in_channels)
|
193 |
+
|
194 |
+
self.conv_blocks = nn.ModuleList()
|
195 |
+
for d in dilations:
|
196 |
+
self.conv_blocks.append(
|
197 |
+
nn.Sequential(
|
198 |
+
nn.LeakyReLU(lReLU_slope),
|
199 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
|
200 |
+
nn.LeakyReLU(lReLU_slope),
|
201 |
+
)
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, x, c):
|
205 |
+
"""forward propagation of the location-variable convolutions.
|
206 |
+
Args:
|
207 |
+
x (Tensor): the input sequence (batch, in_channels, in_length)
|
208 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
Tensor: the output sequence (batch, in_channels, in_length)
|
212 |
+
"""
|
213 |
+
_, in_channels, _ = x.shape # (B, c_g, L')
|
214 |
+
|
215 |
+
x = self.convt_pre(x) # (B, c_g, stride * L')
|
216 |
+
|
217 |
+
# Add one amp block just after the upsampling
|
218 |
+
x = self.amp_block(x) # (B, c_g, stride * L')
|
219 |
+
|
220 |
+
kernels, bias = self.kernel_predictor(c)
|
221 |
+
|
222 |
+
if self.add_extra_noise:
|
223 |
+
# Add extra noise to part of the feature
|
224 |
+
a, b = x.chunk(2, dim=1)
|
225 |
+
b = b + torch.randn_like(b) * 0.1
|
226 |
+
x = torch.cat([a, b], dim=1)
|
227 |
+
|
228 |
+
for i, conv in enumerate(self.conv_blocks):
|
229 |
+
output = conv(x) # (B, c_g, stride * L')
|
230 |
+
|
231 |
+
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
232 |
+
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
233 |
+
|
234 |
+
output = self.location_variable_convolution(
|
235 |
+
output, k, b, hop_size=self.cond_hop_length
|
236 |
+
) # (B, 2 * c_g, stride * L'): LVC
|
237 |
+
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
238 |
+
output[:, in_channels:, :]
|
239 |
+
) # (B, c_g, stride * L'): GAU
|
240 |
+
|
241 |
+
return x
|
242 |
+
|
243 |
+
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
244 |
+
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
245 |
+
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
246 |
+
Args:
|
247 |
+
x (Tensor): the input sequence (batch, in_channels, in_length).
|
248 |
+
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
249 |
+
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
250 |
+
dilation (int): the dilation of convolution.
|
251 |
+
hop_size (int): the hop_size of the conditioning sequence.
|
252 |
+
Returns:
|
253 |
+
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
254 |
+
"""
|
255 |
+
batch, _, in_length = x.shape
|
256 |
+
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
257 |
+
|
258 |
+
assert in_length == (
|
259 |
+
kernel_length * hop_size
|
260 |
+
), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
|
261 |
+
|
262 |
+
padding = dilation * int((kernel_size - 1) / 2)
|
263 |
+
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
264 |
+
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
265 |
+
|
266 |
+
if hop_size < dilation:
|
267 |
+
x = F.pad(x, (0, dilation), "constant", 0)
|
268 |
+
x = x.unfold(
|
269 |
+
3, dilation, dilation
|
270 |
+
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
271 |
+
x = x[:, :, :, :, :hop_size]
|
272 |
+
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
273 |
+
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
274 |
+
|
275 |
+
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
276 |
+
o = o.to(memory_format=torch.channels_last_3d)
|
277 |
+
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
278 |
+
o = o + bias
|
279 |
+
o = o.contiguous().view(batch, out_channels, -1)
|
280 |
+
|
281 |
+
return o
|
modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Copyright 2019 Tomoki Hayashi
|
4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
+
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from ..hparams import HParams
|
12 |
+
|
13 |
+
|
14 |
+
def _make_stft_cfg(hop_length, win_length=None):
|
15 |
+
if win_length is None:
|
16 |
+
win_length = 4 * hop_length
|
17 |
+
n_fft = 2 ** (win_length - 1).bit_length()
|
18 |
+
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
19 |
+
|
20 |
+
|
21 |
+
def get_stft_cfgs(hp: HParams):
|
22 |
+
assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
|
23 |
+
return [_make_stft_cfg(h) for h in (100, 256, 512)]
|
24 |
+
|
25 |
+
|
26 |
+
def stft(x, n_fft, hop_length, win_length, window):
|
27 |
+
dtype = x.dtype
|
28 |
+
x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
|
29 |
+
x = x.abs().to(dtype)
|
30 |
+
x = x.transpose(2, 1) # (b f t) -> (b t f)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class SpectralConvergengeLoss(nn.Module):
|
35 |
+
def forward(self, x_mag, y_mag):
|
36 |
+
"""Calculate forward propagation.
|
37 |
+
Args:
|
38 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
39 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
40 |
+
Returns:
|
41 |
+
Tensor: Spectral convergence loss value.
|
42 |
+
"""
|
43 |
+
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
|
44 |
+
|
45 |
+
|
46 |
+
class LogSTFTMagnitudeLoss(nn.Module):
|
47 |
+
def forward(self, x_mag, y_mag):
|
48 |
+
"""Calculate forward propagation.
|
49 |
+
Args:
|
50 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
51 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
52 |
+
Returns:
|
53 |
+
Tensor: Log STFT magnitude loss value.
|
54 |
+
"""
|
55 |
+
return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag))
|
56 |
+
|
57 |
+
|
58 |
+
class STFTLoss(nn.Module):
|
59 |
+
def __init__(self, hp, stft_cfg: dict, window="hann_window"):
|
60 |
+
super().__init__()
|
61 |
+
self.hp = hp
|
62 |
+
self.stft_cfg = stft_cfg
|
63 |
+
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
64 |
+
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
65 |
+
self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False)
|
66 |
+
|
67 |
+
def forward(self, x, y):
|
68 |
+
"""Calculate forward propagation.
|
69 |
+
Args:
|
70 |
+
x (Tensor): Predicted signal (B, T).
|
71 |
+
y (Tensor): Groundtruth signal (B, T).
|
72 |
+
Returns:
|
73 |
+
Tensor: Spectral convergence loss value.
|
74 |
+
Tensor: Log STFT magnitude loss value.
|
75 |
+
"""
|
76 |
+
stft_cfg = dict(self.stft_cfg)
|
77 |
+
x_mag = stft(x, **stft_cfg, window=self.window) # (b t) -> (b t f)
|
78 |
+
y_mag = stft(y, **stft_cfg, window=self.window)
|
79 |
+
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
80 |
+
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
81 |
+
return dict(sc=sc_loss, mag=mag_loss)
|
82 |
+
|
83 |
+
|
84 |
+
class MRSTFTLoss(nn.Module):
|
85 |
+
def __init__(self, hp: HParams, window="hann_window"):
|
86 |
+
"""Initialize Multi resolution STFT loss module.
|
87 |
+
Args:
|
88 |
+
resolutions (list): List of (FFT size, hop size, window length).
|
89 |
+
window (str): Window function type.
|
90 |
+
"""
|
91 |
+
super().__init__()
|
92 |
+
stft_cfgs = get_stft_cfgs(hp)
|
93 |
+
self.stft_losses = nn.ModuleList()
|
94 |
+
self.hp = hp
|
95 |
+
for c in stft_cfgs:
|
96 |
+
self.stft_losses += [STFTLoss(hp, c, window=window)]
|
97 |
+
|
98 |
+
def forward(self, x, y):
|
99 |
+
"""Calculate forward propagation.
|
100 |
+
Args:
|
101 |
+
x (Tensor): Predicted signal (b t).
|
102 |
+
y (Tensor): Groundtruth signal (b t).
|
103 |
+
Returns:
|
104 |
+
Tensor: Multi resolution spectral convergence loss value.
|
105 |
+
Tensor: Multi resolution log STFT magnitude loss value.
|
106 |
+
"""
|
107 |
+
assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."
|
108 |
+
|
109 |
+
dtype = x.dtype
|
110 |
+
|
111 |
+
x = x.float()
|
112 |
+
y = y.float()
|
113 |
+
|
114 |
+
# Align length
|
115 |
+
x = x[..., : y.shape[-1]]
|
116 |
+
y = y[..., : x.shape[-1]]
|
117 |
+
|
118 |
+
losses = {}
|
119 |
+
|
120 |
+
for f in self.stft_losses:
|
121 |
+
d = f(x, y)
|
122 |
+
for k, v in d.items():
|
123 |
+
losses.setdefault(k, []).append(v)
|
124 |
+
|
125 |
+
for k, v in losses.items():
|
126 |
+
losses[k] = torch.stack(v, dim=0).mean().to(dtype)
|
127 |
+
|
128 |
+
return losses
|
modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import Tensor, nn
|
5 |
+
from torch.nn.utils.parametrizations import weight_norm
|
6 |
+
|
7 |
+
from ..hparams import HParams
|
8 |
+
from .lvcnet import LVCBlock
|
9 |
+
from .mrstft import MRSTFTLoss
|
10 |
+
|
11 |
+
|
12 |
+
class UnivNet(nn.Module):
|
13 |
+
@property
|
14 |
+
def d_noise(self):
|
15 |
+
return 128
|
16 |
+
|
17 |
+
@property
|
18 |
+
def strides(self):
|
19 |
+
return [7, 5, 4, 3]
|
20 |
+
|
21 |
+
@property
|
22 |
+
def dilations(self):
|
23 |
+
return [1, 3, 9, 27]
|
24 |
+
|
25 |
+
@property
|
26 |
+
def nc(self):
|
27 |
+
return self.hp.univnet_nc
|
28 |
+
|
29 |
+
@property
|
30 |
+
def scale_factor(self) -> int:
|
31 |
+
return self.hp.hop_size
|
32 |
+
|
33 |
+
def __init__(self, hp: HParams, d_input):
|
34 |
+
super().__init__()
|
35 |
+
self.d_input = d_input
|
36 |
+
|
37 |
+
self.hp = hp
|
38 |
+
|
39 |
+
self.blocks = nn.ModuleList(
|
40 |
+
[
|
41 |
+
LVCBlock(
|
42 |
+
self.nc,
|
43 |
+
d_input,
|
44 |
+
stride=stride,
|
45 |
+
dilations=self.dilations,
|
46 |
+
cond_hop_length=hop_length,
|
47 |
+
kpnet_conv_size=3,
|
48 |
+
)
|
49 |
+
for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
|
50 |
+
]
|
51 |
+
)
|
52 |
+
|
53 |
+
self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
|
54 |
+
|
55 |
+
self.conv_post = nn.Sequential(
|
56 |
+
nn.LeakyReLU(0.2),
|
57 |
+
weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
|
58 |
+
nn.Tanh(),
|
59 |
+
)
|
60 |
+
|
61 |
+
self.mrstft = MRSTFTLoss(hp)
|
62 |
+
|
63 |
+
@property
|
64 |
+
def eps(self):
|
65 |
+
return 1e-5
|
66 |
+
|
67 |
+
def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
|
68 |
+
"""
|
69 |
+
Args:
|
70 |
+
x: (b c t), acoustic features
|
71 |
+
y: (b t), waveform
|
72 |
+
Returns:
|
73 |
+
z: (b t), waveform
|
74 |
+
"""
|
75 |
+
assert x.ndim == 3, "x must be 3D tensor"
|
76 |
+
assert y is None or y.ndim == 2, "y must be 2D tensor"
|
77 |
+
assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
|
78 |
+
assert npad >= 0, "npad must be positive or zero"
|
79 |
+
|
80 |
+
x = F.pad(x, (0, npad), "constant", 0)
|
81 |
+
z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
|
82 |
+
z = self.conv_pre(z) # (b c t)
|
83 |
+
|
84 |
+
for block in self.blocks:
|
85 |
+
z = block(z, x) # (b c t)
|
86 |
+
|
87 |
+
z = self.conv_post(z) # (b 1 t)
|
88 |
+
z = z[..., : -self.scale_factor * npad]
|
89 |
+
z = z.squeeze(1) # (b t)
|
90 |
+
|
91 |
+
if y is not None:
|
92 |
+
self.losses = self.mrstft(z, y)
|
93 |
+
|
94 |
+
return z
|
modules/repos_static/resemble_enhance/hparams.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from rich.console import Console
|
7 |
+
from rich.panel import Panel
|
8 |
+
from rich.table import Table
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
console = Console()
|
13 |
+
|
14 |
+
|
15 |
+
def _make_stft_cfg(hop_length, win_length=None):
|
16 |
+
if win_length is None:
|
17 |
+
win_length = 4 * hop_length
|
18 |
+
n_fft = 2 ** (win_length - 1).bit_length()
|
19 |
+
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
20 |
+
|
21 |
+
|
22 |
+
def _build_rich_table(rows, columns, title=None):
|
23 |
+
table = Table(title=title, header_style=None)
|
24 |
+
for column in columns:
|
25 |
+
table.add_column(column.capitalize(), justify="left")
|
26 |
+
for row in rows:
|
27 |
+
table.add_row(*map(str, row))
|
28 |
+
return Panel(table, expand=False)
|
29 |
+
|
30 |
+
|
31 |
+
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
|
32 |
+
console.print(_build_rich_table(d.items(), [key, value], title))
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass(frozen=True)
|
36 |
+
class HParams:
|
37 |
+
# Dataset
|
38 |
+
fg_dir: Path = Path("data/fg")
|
39 |
+
bg_dir: Path = Path("data/bg")
|
40 |
+
rir_dir: Path = Path("data/rir")
|
41 |
+
load_fg_only: bool = False
|
42 |
+
praat_augment_prob: float = 0
|
43 |
+
|
44 |
+
# Audio settings
|
45 |
+
wav_rate: int = 44_100
|
46 |
+
n_fft: int = 2048
|
47 |
+
win_size: int = 2048
|
48 |
+
hop_size: int = 420 # 9.5ms
|
49 |
+
num_mels: int = 128
|
50 |
+
stft_magnitude_min: float = 1e-4
|
51 |
+
preemphasis: float = 0.97
|
52 |
+
mix_alpha_range: tuple[float, float] = (0.2, 0.8)
|
53 |
+
|
54 |
+
# Training
|
55 |
+
nj: int = 64
|
56 |
+
training_seconds: float = 1.0
|
57 |
+
batch_size_per_gpu: int = 16
|
58 |
+
min_lr: float = 1e-5
|
59 |
+
max_lr: float = 1e-4
|
60 |
+
warmup_steps: int = 1000
|
61 |
+
max_steps: int = 1_000_000
|
62 |
+
gradient_clipping: float = 1.0
|
63 |
+
|
64 |
+
@property
|
65 |
+
def deepspeed_config(self):
|
66 |
+
return {
|
67 |
+
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
|
68 |
+
"optimizer": {
|
69 |
+
"type": "Adam",
|
70 |
+
"params": {"lr": float(self.min_lr)},
|
71 |
+
},
|
72 |
+
"scheduler": {
|
73 |
+
"type": "WarmupDecayLR",
|
74 |
+
"params": {
|
75 |
+
"warmup_min_lr": float(self.min_lr),
|
76 |
+
"warmup_max_lr": float(self.max_lr),
|
77 |
+
"warmup_num_steps": self.warmup_steps,
|
78 |
+
"total_num_steps": self.max_steps,
|
79 |
+
"warmup_type": "linear",
|
80 |
+
},
|
81 |
+
},
|
82 |
+
"gradient_clipping": self.gradient_clipping,
|
83 |
+
}
|
84 |
+
|
85 |
+
@property
|
86 |
+
def stft_cfgs(self):
|
87 |
+
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
|
88 |
+
return [_make_stft_cfg(h) for h in (100, 256, 512)]
|
89 |
+
|
90 |
+
@classmethod
|
91 |
+
def from_yaml(cls, path: Path) -> "HParams":
|
92 |
+
logger.info(f"Reading hparams from {path}")
|
93 |
+
# First merge to fix types (e.g., str -> Path)
|
94 |
+
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
|
95 |
+
|
96 |
+
def save_if_not_exists(self, run_dir: Path):
|
97 |
+
path = run_dir / "hparams.yaml"
|
98 |
+
if path.exists():
|
99 |
+
logger.info(f"{path} already exists, not saving")
|
100 |
+
return
|
101 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
102 |
+
OmegaConf.save(asdict(self), str(path))
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def load(cls, run_dir, yaml: Path | None = None):
|
106 |
+
hps = []
|
107 |
+
|
108 |
+
if (run_dir / "hparams.yaml").exists():
|
109 |
+
hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
|
110 |
+
|
111 |
+
if yaml is not None:
|
112 |
+
hps.append(cls.from_yaml(yaml))
|
113 |
+
|
114 |
+
if len(hps) == 0:
|
115 |
+
hps.append(cls())
|
116 |
+
|
117 |
+
for hp in hps[1:]:
|
118 |
+
if hp != hps[0]:
|
119 |
+
errors = {}
|
120 |
+
for k, v in asdict(hp).items():
|
121 |
+
if getattr(hps[0], k) != v:
|
122 |
+
errors[k] = f"{getattr(hps[0], k)} != {v}"
|
123 |
+
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
|
124 |
+
|
125 |
+
return hps[0]
|
126 |
+
|
127 |
+
def print(self):
|
128 |
+
_rich_print_dict(asdict(self), title="HParams")
|
modules/repos_static/resemble_enhance/inference.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
7 |
+
from torchaudio.functional import resample
|
8 |
+
from torchaudio.transforms import MelSpectrogram
|
9 |
+
from tqdm import trange
|
10 |
+
|
11 |
+
from .hparams import HParams
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
@torch.inference_mode()
|
17 |
+
def inference_chunk(model, dwav, sr, device, npad=441):
|
18 |
+
assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
|
19 |
+
del sr
|
20 |
+
|
21 |
+
length = dwav.shape[-1]
|
22 |
+
abs_max = dwav.abs().max().clamp(min=1e-7)
|
23 |
+
|
24 |
+
assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
|
25 |
+
dwav = dwav.to(device)
|
26 |
+
dwav = dwav / abs_max # Normalize
|
27 |
+
dwav = F.pad(dwav, (0, npad))
|
28 |
+
hwav = model(dwav[None])[0].cpu() # (T,)
|
29 |
+
hwav = hwav[:length] # Trim padding
|
30 |
+
hwav = hwav * abs_max # Unnormalize
|
31 |
+
|
32 |
+
return hwav
|
33 |
+
|
34 |
+
|
35 |
+
def compute_corr(x, y):
|
36 |
+
return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()
|
37 |
+
|
38 |
+
|
39 |
+
def compute_offset(chunk1, chunk2, sr=44100):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
chunk1: (T,)
|
43 |
+
chunk2: (T,)
|
44 |
+
Returns:
|
45 |
+
offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
|
46 |
+
"""
|
47 |
+
hop_length = sr // 200 # 5 ms resolution
|
48 |
+
win_length = hop_length * 4
|
49 |
+
n_fft = 2 ** (win_length - 1).bit_length()
|
50 |
+
|
51 |
+
mel_fn = MelSpectrogram(
|
52 |
+
sample_rate=sr,
|
53 |
+
n_fft=n_fft,
|
54 |
+
win_length=win_length,
|
55 |
+
hop_length=hop_length,
|
56 |
+
n_mels=80,
|
57 |
+
f_min=0.0,
|
58 |
+
f_max=sr // 2,
|
59 |
+
)
|
60 |
+
|
61 |
+
spec1 = mel_fn(chunk1).log1p()
|
62 |
+
spec2 = mel_fn(chunk2).log1p()
|
63 |
+
|
64 |
+
corr = compute_corr(spec1, spec2) # (F, T)
|
65 |
+
corr = corr.mean(dim=0) # (T,)
|
66 |
+
|
67 |
+
argmax = corr.argmax().item()
|
68 |
+
|
69 |
+
if argmax > len(corr) // 2:
|
70 |
+
argmax -= len(corr)
|
71 |
+
|
72 |
+
offset = -argmax * hop_length
|
73 |
+
|
74 |
+
return offset
|
75 |
+
|
76 |
+
|
77 |
+
def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
|
78 |
+
signal_length = (len(chunks) - 1) * hop_length + chunk_length
|
79 |
+
overlap_length = chunk_length - hop_length
|
80 |
+
signal = torch.zeros(signal_length, device=chunks[0].device)
|
81 |
+
|
82 |
+
fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
|
83 |
+
fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
|
84 |
+
fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
|
85 |
+
fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])
|
86 |
+
|
87 |
+
for i, chunk in enumerate(chunks):
|
88 |
+
start = i * hop_length
|
89 |
+
end = start + chunk_length
|
90 |
+
|
91 |
+
if len(chunk) < chunk_length:
|
92 |
+
chunk = F.pad(chunk, (0, chunk_length - len(chunk)))
|
93 |
+
|
94 |
+
if i > 0:
|
95 |
+
pre_region = chunks[i - 1][-overlap_length:]
|
96 |
+
cur_region = chunk[:overlap_length]
|
97 |
+
offset = compute_offset(pre_region, cur_region, sr=sr)
|
98 |
+
start -= offset
|
99 |
+
end -= offset
|
100 |
+
|
101 |
+
if i == 0:
|
102 |
+
chunk = chunk * fadeout
|
103 |
+
elif i == len(chunks) - 1:
|
104 |
+
chunk = chunk * fadein
|
105 |
+
else:
|
106 |
+
chunk = chunk * fadein * fadeout
|
107 |
+
|
108 |
+
signal[start:end] += chunk[: len(signal[start:end])]
|
109 |
+
|
110 |
+
signal = signal[:length]
|
111 |
+
|
112 |
+
return signal
|
113 |
+
|
114 |
+
|
115 |
+
def remove_weight_norm_recursively(module):
|
116 |
+
for _, module in module.named_modules():
|
117 |
+
try:
|
118 |
+
remove_parametrizations(module, "weight")
|
119 |
+
except Exception:
|
120 |
+
pass
|
121 |
+
|
122 |
+
|
123 |
+
def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
|
124 |
+
remove_weight_norm_recursively(model)
|
125 |
+
|
126 |
+
hp: HParams = model.hp
|
127 |
+
|
128 |
+
dwav = resample(
|
129 |
+
dwav,
|
130 |
+
orig_freq=sr,
|
131 |
+
new_freq=hp.wav_rate,
|
132 |
+
lowpass_filter_width=64,
|
133 |
+
rolloff=0.9475937167399596,
|
134 |
+
resampling_method="sinc_interp_kaiser",
|
135 |
+
beta=14.769656459379492,
|
136 |
+
)
|
137 |
+
|
138 |
+
del sr # Everything is in hp.wav_rate now
|
139 |
+
|
140 |
+
sr = hp.wav_rate
|
141 |
+
|
142 |
+
if torch.cuda.is_available():
|
143 |
+
torch.cuda.synchronize()
|
144 |
+
|
145 |
+
start_time = time.perf_counter()
|
146 |
+
|
147 |
+
chunk_length = int(sr * chunk_seconds)
|
148 |
+
overlap_length = int(sr * overlap_seconds)
|
149 |
+
hop_length = chunk_length - overlap_length
|
150 |
+
|
151 |
+
chunks = []
|
152 |
+
for start in trange(0, dwav.shape[-1], hop_length):
|
153 |
+
chunks.append(inference_chunk(model, dwav[start : start + chunk_length], sr, device))
|
154 |
+
|
155 |
+
hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
|
156 |
+
|
157 |
+
if torch.cuda.is_available():
|
158 |
+
torch.cuda.synchronize()
|
159 |
+
|
160 |
+
elapsed_time = time.perf_counter() - start_time
|
161 |
+
logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")
|
162 |
+
|
163 |
+
return hwav, sr
|
modules/repos_static/resemble_enhance/melspec.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
|
5 |
+
|
6 |
+
from .hparams import HParams
|
7 |
+
|
8 |
+
|
9 |
+
class MelSpectrogram(nn.Module):
|
10 |
+
def __init__(self, hp: HParams):
|
11 |
+
"""
|
12 |
+
Torch implementation of Resemble's mel extraction.
|
13 |
+
Note that the values are NOT identical to librosa's implementation
|
14 |
+
due to floating point precisions.
|
15 |
+
"""
|
16 |
+
super().__init__()
|
17 |
+
self.hp = hp
|
18 |
+
self.melspec = TorchMelSpectrogram(
|
19 |
+
hp.wav_rate,
|
20 |
+
n_fft=hp.n_fft,
|
21 |
+
win_length=hp.win_size,
|
22 |
+
hop_length=hp.hop_size,
|
23 |
+
f_min=0,
|
24 |
+
f_max=hp.wav_rate // 2,
|
25 |
+
n_mels=hp.num_mels,
|
26 |
+
power=1,
|
27 |
+
normalized=False,
|
28 |
+
# NOTE: Folowing librosa's default.
|
29 |
+
pad_mode="constant",
|
30 |
+
norm="slaney",
|
31 |
+
mel_scale="slaney",
|
32 |
+
)
|
33 |
+
self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
|
34 |
+
self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
|
35 |
+
self.preemphasis = hp.preemphasis
|
36 |
+
self.hop_size = hp.hop_size
|
37 |
+
|
38 |
+
def forward(self, wav, pad=True):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
wav: [B, T]
|
42 |
+
"""
|
43 |
+
device = wav.device
|
44 |
+
if wav.is_mps:
|
45 |
+
wav = wav.cpu()
|
46 |
+
self.to(wav.device)
|
47 |
+
if self.preemphasis > 0:
|
48 |
+
wav = torch.nn.functional.pad(wav, [1, 0], value=0)
|
49 |
+
wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
|
50 |
+
mel = self.melspec(wav)
|
51 |
+
mel = self._amp_to_db(mel)
|
52 |
+
mel_normed = self._normalize(mel)
|
53 |
+
assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size # Sanity check
|
54 |
+
mel_normed = mel_normed.to(device)
|
55 |
+
return mel_normed # (M, T)
|
56 |
+
|
57 |
+
def _normalize(self, s, headroom_db=15):
|
58 |
+
return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
|
59 |
+
|
60 |
+
def _amp_to_db(self, x):
|
61 |
+
return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
|
modules/repos_static/resemble_enhance/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .logging import setup_logging
|
2 |
+
from .utils import save_mels, tree_map
|
modules/repos_static/resemble_enhance/utils/control.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import selectors
|
3 |
+
import sys
|
4 |
+
from functools import cache
|
5 |
+
|
6 |
+
from .distributed import global_leader_only
|
7 |
+
|
8 |
+
_logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
@cache
|
12 |
+
def _get_stdin_selector():
|
13 |
+
selector = selectors.DefaultSelector()
|
14 |
+
selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
|
15 |
+
return selector
|
16 |
+
|
17 |
+
|
18 |
+
@global_leader_only(boardcast_return=True)
|
19 |
+
def non_blocking_input():
|
20 |
+
s = ""
|
21 |
+
selector = _get_stdin_selector()
|
22 |
+
events = selector.select(timeout=0)
|
23 |
+
for key, _ in events:
|
24 |
+
s: str = key.fileobj.readline().strip()
|
25 |
+
_logger.info(f'Get stdin "{s}".')
|
26 |
+
return s
|
modules/repos_static/resemble_enhance/utils/logging.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from rich.logging import RichHandler
|
5 |
+
|
6 |
+
from .distributed import global_leader_only
|
7 |
+
|
8 |
+
|
9 |
+
@global_leader_only
|
10 |
+
def setup_logging(run_dir):
|
11 |
+
handlers = []
|
12 |
+
stdout_handler = RichHandler()
|
13 |
+
stdout_handler.setLevel(logging.INFO)
|
14 |
+
handlers.append(stdout_handler)
|
15 |
+
|
16 |
+
if run_dir is not None:
|
17 |
+
filename = Path(run_dir) / f"log.txt"
|
18 |
+
filename.parent.mkdir(parents=True, exist_ok=True)
|
19 |
+
file_handler = logging.FileHandler(filename, mode="a")
|
20 |
+
file_handler.setLevel(logging.DEBUG)
|
21 |
+
handlers.append(file_handler)
|
22 |
+
|
23 |
+
# Update all existing loggers
|
24 |
+
for name in ["DeepSpeed"]:
|
25 |
+
logger = logging.getLogger(name)
|
26 |
+
if isinstance(logger, logging.Logger):
|
27 |
+
for handler in list(logger.handlers):
|
28 |
+
logger.removeHandler(handler)
|
29 |
+
for handler in handlers:
|
30 |
+
logger.addHandler(handler)
|
31 |
+
|
32 |
+
# Set the default logger
|
33 |
+
logging.basicConfig(
|
34 |
+
level=logging.getLevelName("INFO"),
|
35 |
+
format="%(message)s",
|
36 |
+
datefmt="[%X]",
|
37 |
+
handlers=handlers,
|
38 |
+
)
|
modules/repos_static/resemble_enhance/utils/utils.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, TypeVar, overload
|
2 |
+
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def save_mels(path, *, targ_mel, pred_mel, cond_mel):
|
8 |
+
n = 3 if cond_mel is None else 4
|
9 |
+
|
10 |
+
plt.figure(figsize=(10, n * 4))
|
11 |
+
|
12 |
+
i = 1
|
13 |
+
|
14 |
+
plt.subplot(n, 1, i)
|
15 |
+
plt.imshow(pred_mel, origin="lower", interpolation="none")
|
16 |
+
plt.title(f"Pred mel {pred_mel.shape}")
|
17 |
+
i += 1
|
18 |
+
|
19 |
+
plt.subplot(n, 1, i)
|
20 |
+
plt.imshow(targ_mel, origin="lower", interpolation="none")
|
21 |
+
plt.title(f"GT mel {targ_mel.shape}")
|
22 |
+
i += 1
|
23 |
+
|
24 |
+
plt.subplot(n, 1, i)
|
25 |
+
pred_mel = pred_mel[:, : targ_mel.shape[1]]
|
26 |
+
targ_mel = targ_mel[:, : pred_mel.shape[1]]
|
27 |
+
plt.imshow(np.abs(pred_mel - targ_mel), origin="lower", interpolation="none")
|
28 |
+
plt.title(f"Diff mel {pred_mel.shape}, mse={np.mean((pred_mel - targ_mel)**2):.4f}")
|
29 |
+
i += 1
|
30 |
+
|
31 |
+
if cond_mel is not None:
|
32 |
+
plt.subplot(n, 1, i)
|
33 |
+
plt.imshow(cond_mel, origin="lower", interpolation="none")
|
34 |
+
plt.title(f"Cond mel {cond_mel.shape}")
|
35 |
+
i += 1
|
36 |
+
|
37 |
+
plt.savefig(path, dpi=480)
|
38 |
+
plt.close()
|
39 |
+
|
40 |
+
|
41 |
+
T = TypeVar("T")
|
42 |
+
|
43 |
+
|
44 |
+
@overload
|
45 |
+
def tree_map(fn: Callable, x: list[T]) -> list[T]:
|
46 |
+
...
|
47 |
+
|
48 |
+
|
49 |
+
@overload
|
50 |
+
def tree_map(fn: Callable, x: tuple[T]) -> tuple[T]:
|
51 |
+
...
|
52 |
+
|
53 |
+
|
54 |
+
@overload
|
55 |
+
def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]:
|
56 |
+
...
|
57 |
+
|
58 |
+
|
59 |
+
@overload
|
60 |
+
def tree_map(fn: Callable, x: T) -> T:
|
61 |
+
...
|
62 |
+
|
63 |
+
|
64 |
+
def tree_map(fn: Callable, x):
|
65 |
+
if isinstance(x, list):
|
66 |
+
x = [tree_map(fn, xi) for xi in x]
|
67 |
+
elif isinstance(x, tuple):
|
68 |
+
x = (tree_map(fn, xi) for xi in x)
|
69 |
+
elif isinstance(x, dict):
|
70 |
+
x = {k: tree_map(fn, v) for k, v in x.items()}
|
71 |
+
else:
|
72 |
+
x = fn(x)
|
73 |
+
return x
|
modules/speaker.py
CHANGED
@@ -99,6 +99,10 @@ class SpeakerManager:
|
|
99 |
self.speakers[speaker_file] = Speaker.from_file(
|
100 |
self.speaker_dir + speaker_file
|
101 |
)
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def list_speakers(self):
|
104 |
return list(self.speakers.values())
|
|
|
99 |
self.speakers[speaker_file] = Speaker.from_file(
|
100 |
self.speaker_dir + speaker_file
|
101 |
)
|
102 |
+
# 检查是否有被删除的,同步到 speakers
|
103 |
+
for fname, spk in self.speakers.items():
|
104 |
+
if not os.path.exists(self.speaker_dir + fname):
|
105 |
+
del self.speakers[fname]
|
106 |
|
107 |
def list_speakers(self):
|
108 |
return list(self.speakers.values())
|
modules/webui/speaker/__init__.py
ADDED
File without changes
|
modules/webui/speaker/speaker_creator.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from modules.speaker import Speaker
|
4 |
+
from modules.utils.SeedContext import SeedContext
|
5 |
+
from modules.hf import spaces
|
6 |
+
from modules.models import load_chat_tts
|
7 |
+
from modules.utils.rng import np_rng
|
8 |
+
from modules.webui.webui_utils import get_speakers, tts_generate
|
9 |
+
|
10 |
+
import tempfile
|
11 |
+
|
12 |
+
names_list = [
|
13 |
+
"Alice",
|
14 |
+
"Bob",
|
15 |
+
"Carol",
|
16 |
+
"Carlos",
|
17 |
+
"Charlie",
|
18 |
+
"Chuck",
|
19 |
+
"Chad",
|
20 |
+
"Craig",
|
21 |
+
"Dan",
|
22 |
+
"Dave",
|
23 |
+
"David",
|
24 |
+
"Erin",
|
25 |
+
"Eve",
|
26 |
+
"Yves",
|
27 |
+
"Faythe",
|
28 |
+
"Frank",
|
29 |
+
"Grace",
|
30 |
+
"Heidi",
|
31 |
+
"Ivan",
|
32 |
+
"Judy",
|
33 |
+
"Mallory",
|
34 |
+
"Mallet",
|
35 |
+
"Darth",
|
36 |
+
"Michael",
|
37 |
+
"Mike",
|
38 |
+
"Niaj",
|
39 |
+
"Olivia",
|
40 |
+
"Oscar",
|
41 |
+
"Peggy",
|
42 |
+
"Pat",
|
43 |
+
"Rupert",
|
44 |
+
"Sybil",
|
45 |
+
"Trent",
|
46 |
+
"Ted",
|
47 |
+
"Trudy",
|
48 |
+
"Victor",
|
49 |
+
"Vanna",
|
50 |
+
"Walter",
|
51 |
+
"Wendy",
|
52 |
+
]
|
53 |
+
|
54 |
+
|
55 |
+
@torch.inference_mode()
|
56 |
+
@spaces.GPU
|
57 |
+
def create_spk_from_seed(
|
58 |
+
seed: int,
|
59 |
+
name: str,
|
60 |
+
gender: str,
|
61 |
+
desc: str,
|
62 |
+
):
|
63 |
+
chat_tts = load_chat_tts()
|
64 |
+
with SeedContext(seed):
|
65 |
+
emb = chat_tts.sample_random_speaker()
|
66 |
+
spk = Speaker(seed=-2, name=name, gender=gender, describe=desc)
|
67 |
+
spk.emb = emb
|
68 |
+
|
69 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
|
70 |
+
torch.save(spk, tmp_file)
|
71 |
+
tmp_file_path = tmp_file.name
|
72 |
+
|
73 |
+
return tmp_file_path
|
74 |
+
|
75 |
+
|
76 |
+
@torch.inference_mode()
|
77 |
+
@spaces.GPU
|
78 |
+
def test_spk_voice(seed: int, text: str):
|
79 |
+
return tts_generate(
|
80 |
+
spk=seed,
|
81 |
+
text=text,
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
def random_speaker():
|
86 |
+
seed = np_rng()
|
87 |
+
name = names_list[seed % len(names_list)]
|
88 |
+
return seed, name
|
89 |
+
|
90 |
+
|
91 |
+
creator_ui_desc = """
|
92 |
+
## Speaker Creator
|
93 |
+
使用本面板快捷抽卡生成 speaker.pt 文件。
|
94 |
+
|
95 |
+
1. **生成说话人**:输入种子、名字、性别和描述。点击 "Generate speaker.pt" 按钮,生成的说话人配置会保存为.pt文件。
|
96 |
+
2. **测试说话人声音**:输入测试文本。点击 "Test Voice" 按钮,生成的音频会在 "Output Audio" 中播放。
|
97 |
+
3. **随机生成说话人**:点击 "Random Speaker" 按钮,随机生成一个种子和名字,可以进一步编辑其他信息并测试。
|
98 |
+
"""
|
99 |
+
|
100 |
+
|
101 |
+
def speaker_creator_ui():
|
102 |
+
def on_generate(seed, name, gender, desc):
|
103 |
+
file_path = create_spk_from_seed(seed, name, gender, desc)
|
104 |
+
return file_path
|
105 |
+
|
106 |
+
def create_test_voice_card(seed_input):
|
107 |
+
with gr.Group():
|
108 |
+
gr.Markdown("🎤Test voice")
|
109 |
+
with gr.Row():
|
110 |
+
test_voice_btn = gr.Button("Test Voice", variant="secondary")
|
111 |
+
|
112 |
+
with gr.Column(scale=4):
|
113 |
+
test_text = gr.Textbox(
|
114 |
+
label="Test Text",
|
115 |
+
placeholder="Please input test text",
|
116 |
+
value="说话人测试 123456789 [uv_break] ok, test done [lbreak]",
|
117 |
+
)
|
118 |
+
with gr.Row():
|
119 |
+
current_seed = gr.Label(label="Current Seed", value=-1)
|
120 |
+
with gr.Column(scale=4):
|
121 |
+
output_audio = gr.Audio(label="Output Audio")
|
122 |
+
|
123 |
+
test_voice_btn.click(
|
124 |
+
fn=test_spk_voice,
|
125 |
+
inputs=[seed_input, test_text],
|
126 |
+
outputs=[output_audio],
|
127 |
+
)
|
128 |
+
test_voice_btn.click(
|
129 |
+
fn=lambda x: x,
|
130 |
+
inputs=[seed_input],
|
131 |
+
outputs=[current_seed],
|
132 |
+
)
|
133 |
+
|
134 |
+
gr.Markdown(creator_ui_desc)
|
135 |
+
|
136 |
+
with gr.Row():
|
137 |
+
with gr.Column(scale=2):
|
138 |
+
with gr.Group():
|
139 |
+
gr.Markdown("ℹ️Speaker info")
|
140 |
+
seed_input = gr.Number(label="Seed", value=2)
|
141 |
+
name_input = gr.Textbox(
|
142 |
+
label="Name", placeholder="Enter speaker name", value="Bob"
|
143 |
+
)
|
144 |
+
gender_input = gr.Textbox(
|
145 |
+
label="Gender", placeholder="Enter gender", value="*"
|
146 |
+
)
|
147 |
+
desc_input = gr.Textbox(
|
148 |
+
label="Description",
|
149 |
+
placeholder="Enter description",
|
150 |
+
)
|
151 |
+
random_button = gr.Button("Random Speaker")
|
152 |
+
with gr.Group():
|
153 |
+
gr.Markdown("🔊Generate speaker.pt")
|
154 |
+
generate_button = gr.Button("Save .pt file")
|
155 |
+
output_file = gr.File(label="Save to File")
|
156 |
+
with gr.Column(scale=5):
|
157 |
+
create_test_voice_card(seed_input=seed_input)
|
158 |
+
create_test_voice_card(seed_input=seed_input)
|
159 |
+
create_test_voice_card(seed_input=seed_input)
|
160 |
+
create_test_voice_card(seed_input=seed_input)
|
161 |
+
|
162 |
+
random_button.click(
|
163 |
+
random_speaker,
|
164 |
+
outputs=[seed_input, name_input],
|
165 |
+
)
|
166 |
+
|
167 |
+
generate_button.click(
|
168 |
+
fn=on_generate,
|
169 |
+
inputs=[seed_input, name_input, gender_input, desc_input],
|
170 |
+
outputs=[output_file],
|
171 |
+
)
|
modules/webui/speaker/speaker_merger.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from modules.hf import spaces
|
6 |
+
from modules.webui.webui_utils import get_speakers, tts_generate
|
7 |
+
from modules.speaker import speaker_mgr, Speaker
|
8 |
+
|
9 |
+
import tempfile
|
10 |
+
|
11 |
+
|
12 |
+
def spk_to_tensor(spk):
|
13 |
+
spk = spk.split(" : ")[1].strip() if " : " in spk else spk
|
14 |
+
if spk == "None" or spk == "":
|
15 |
+
return None
|
16 |
+
return speaker_mgr.get_speaker(spk).emb
|
17 |
+
|
18 |
+
|
19 |
+
def get_speaker_show_name(spk):
|
20 |
+
if spk.gender == "*" or spk.gender == "":
|
21 |
+
return spk.name
|
22 |
+
return f"{spk.gender} : {spk.name}"
|
23 |
+
|
24 |
+
|
25 |
+
def merge_spk(
|
26 |
+
spk_a,
|
27 |
+
spk_a_w,
|
28 |
+
spk_b,
|
29 |
+
spk_b_w,
|
30 |
+
spk_c,
|
31 |
+
spk_c_w,
|
32 |
+
spk_d,
|
33 |
+
spk_d_w,
|
34 |
+
):
|
35 |
+
tensor_a = spk_to_tensor(spk_a)
|
36 |
+
tensor_b = spk_to_tensor(spk_b)
|
37 |
+
tensor_c = spk_to_tensor(spk_c)
|
38 |
+
tensor_d = spk_to_tensor(spk_d)
|
39 |
+
|
40 |
+
assert (
|
41 |
+
tensor_a is not None
|
42 |
+
or tensor_b is not None
|
43 |
+
or tensor_c is not None
|
44 |
+
or tensor_d is not None
|
45 |
+
), "At least one speaker should be selected"
|
46 |
+
|
47 |
+
merge_tensor = torch.zeros_like(
|
48 |
+
tensor_a
|
49 |
+
if tensor_a is not None
|
50 |
+
else (
|
51 |
+
tensor_b
|
52 |
+
if tensor_b is not None
|
53 |
+
else tensor_c if tensor_c is not None else tensor_d
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
total_weight = 0
|
58 |
+
if tensor_a is not None:
|
59 |
+
merge_tensor += spk_a_w * tensor_a
|
60 |
+
total_weight += spk_a_w
|
61 |
+
if tensor_b is not None:
|
62 |
+
merge_tensor += spk_b_w * tensor_b
|
63 |
+
total_weight += spk_b_w
|
64 |
+
if tensor_c is not None:
|
65 |
+
merge_tensor += spk_c_w * tensor_c
|
66 |
+
total_weight += spk_c_w
|
67 |
+
if tensor_d is not None:
|
68 |
+
merge_tensor += spk_d_w * tensor_d
|
69 |
+
total_weight += spk_d_w
|
70 |
+
|
71 |
+
if total_weight > 0:
|
72 |
+
merge_tensor /= total_weight
|
73 |
+
|
74 |
+
merged_spk = Speaker.from_tensor(merge_tensor)
|
75 |
+
merged_spk.name = "<MIX>"
|
76 |
+
|
77 |
+
return merged_spk
|
78 |
+
|
79 |
+
|
80 |
+
@torch.inference_mode()
|
81 |
+
@spaces.GPU
|
82 |
+
def merge_and_test_spk_voice(
|
83 |
+
spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
|
84 |
+
):
|
85 |
+
merged_spk = merge_spk(
|
86 |
+
spk_a,
|
87 |
+
spk_a_w,
|
88 |
+
spk_b,
|
89 |
+
spk_b_w,
|
90 |
+
spk_c,
|
91 |
+
spk_c_w,
|
92 |
+
spk_d,
|
93 |
+
spk_d_w,
|
94 |
+
)
|
95 |
+
return tts_generate(
|
96 |
+
spk=merged_spk,
|
97 |
+
text=test_text,
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
@torch.inference_mode()
|
102 |
+
@spaces.GPU
|
103 |
+
def merge_spk_to_file(
|
104 |
+
spk_a,
|
105 |
+
spk_a_w,
|
106 |
+
spk_b,
|
107 |
+
spk_b_w,
|
108 |
+
spk_c,
|
109 |
+
spk_c_w,
|
110 |
+
spk_d,
|
111 |
+
spk_d_w,
|
112 |
+
speaker_name,
|
113 |
+
speaker_gender,
|
114 |
+
speaker_desc,
|
115 |
+
):
|
116 |
+
merged_spk = merge_spk(
|
117 |
+
spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w
|
118 |
+
)
|
119 |
+
merged_spk.name = speaker_name
|
120 |
+
merged_spk.gender = speaker_gender
|
121 |
+
merged_spk.desc = speaker_desc
|
122 |
+
|
123 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
|
124 |
+
torch.save(merged_spk, tmp_file)
|
125 |
+
tmp_file_path = tmp_file.name
|
126 |
+
|
127 |
+
return tmp_file_path
|
128 |
+
|
129 |
+
|
130 |
+
merge_desc = """
|
131 |
+
## Speaker Merger
|
132 |
+
|
133 |
+
在本面板中,您可以选择多个说话人并指定他们的权重,合成新的语音并进行测试。以下是各个功能的详细说明:
|
134 |
+
|
135 |
+
1. 选择说话人: 您可以从下拉菜单中选择最多四个说话人(A、B、C、D),每个说话人都有一个对应的权重滑块,范围从0到10。权重决定了每个说话人在合成语音中的影响程度。
|
136 |
+
2. 合成语音: 在选择好说话人和设置好权重后,您可以在“Test Text”框中输入要测试的文本,然后点击“测试语音”按钮来生成并播放合成的语音。
|
137 |
+
3. 保存说话人: 您还可以在右侧的“说话人信息”部分填写新的说话人的名称、性别和描述,并点击“Save Speaker”按钮来保存合成的说话人。保存后的说话人文件将显示在“Merged Speaker”栏中,供下载使用。
|
138 |
+
"""
|
139 |
+
|
140 |
+
|
141 |
+
def get_spk_choices():
|
142 |
+
speakers = get_speakers()
|
143 |
+
|
144 |
+
speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers]
|
145 |
+
return speaker_names
|
146 |
+
|
147 |
+
|
148 |
+
# 显示 a b c d 四个选择框,选择一个或多个,然后可以试音,并导出
|
149 |
+
def create_speaker_merger():
|
150 |
+
speaker_names = get_spk_choices()
|
151 |
+
|
152 |
+
gr.Markdown(merge_desc)
|
153 |
+
|
154 |
+
def spk_picker(label_tail: str):
|
155 |
+
with gr.Row():
|
156 |
+
spk_a = gr.Dropdown(
|
157 |
+
choices=speaker_names, value="None", label=f"Speaker {label_tail}"
|
158 |
+
)
|
159 |
+
refresh_a_btn = gr.Button("🔄", variant="secondary")
|
160 |
+
|
161 |
+
def refresh_a():
|
162 |
+
speaker_mgr.refresh_speakers()
|
163 |
+
speaker_names = get_spk_choices()
|
164 |
+
return gr.update(choices=speaker_names)
|
165 |
+
|
166 |
+
refresh_a_btn.click(refresh_a, outputs=[spk_a])
|
167 |
+
spk_a_w = gr.Slider(
|
168 |
+
value=1,
|
169 |
+
minimum=0,
|
170 |
+
maximum=10,
|
171 |
+
step=0.1,
|
172 |
+
label=f"Weight {label_tail}",
|
173 |
+
)
|
174 |
+
return spk_a, spk_a_w
|
175 |
+
|
176 |
+
with gr.Row():
|
177 |
+
with gr.Column(scale=5):
|
178 |
+
with gr.Row():
|
179 |
+
with gr.Group():
|
180 |
+
spk_a, spk_a_w = spk_picker("A")
|
181 |
+
|
182 |
+
with gr.Group():
|
183 |
+
spk_b, spk_b_w = spk_picker("B")
|
184 |
+
|
185 |
+
with gr.Group():
|
186 |
+
spk_c, spk_c_w = spk_picker("C")
|
187 |
+
|
188 |
+
with gr.Group():
|
189 |
+
spk_d, spk_d_w = spk_picker("D")
|
190 |
+
|
191 |
+
with gr.Row():
|
192 |
+
with gr.Column(scale=3):
|
193 |
+
with gr.Group():
|
194 |
+
gr.Markdown("🎤Test voice")
|
195 |
+
with gr.Row():
|
196 |
+
test_voice_btn = gr.Button(
|
197 |
+
"Test Voice", variant="secondary"
|
198 |
+
)
|
199 |
+
|
200 |
+
with gr.Column(scale=4):
|
201 |
+
test_text = gr.Textbox(
|
202 |
+
label="Test Text",
|
203 |
+
placeholder="Please input test text",
|
204 |
+
value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
|
205 |
+
)
|
206 |
+
|
207 |
+
output_audio = gr.Audio(label="Output Audio")
|
208 |
+
|
209 |
+
with gr.Column(scale=1):
|
210 |
+
with gr.Group():
|
211 |
+
gr.Markdown("🗃️Save to file")
|
212 |
+
|
213 |
+
speaker_name = gr.Textbox(label="Name", value="forge_speaker_merged")
|
214 |
+
speaker_gender = gr.Textbox(label="Gender", value="*")
|
215 |
+
speaker_desc = gr.Textbox(label="Description", value="merged speaker")
|
216 |
+
|
217 |
+
save_btn = gr.Button("Save Speaker", variant="primary")
|
218 |
+
|
219 |
+
merged_spker = gr.File(
|
220 |
+
label="Merged Speaker", interactive=False, type="binary"
|
221 |
+
)
|
222 |
+
|
223 |
+
test_voice_btn.click(
|
224 |
+
merge_and_test_spk_voice,
|
225 |
+
inputs=[
|
226 |
+
spk_a,
|
227 |
+
spk_a_w,
|
228 |
+
spk_b,
|
229 |
+
spk_b_w,
|
230 |
+
spk_c,
|
231 |
+
spk_c_w,
|
232 |
+
spk_d,
|
233 |
+
spk_d_w,
|
234 |
+
test_text,
|
235 |
+
],
|
236 |
+
outputs=[output_audio],
|
237 |
+
)
|
238 |
+
|
239 |
+
save_btn.click(
|
240 |
+
merge_spk_to_file,
|
241 |
+
inputs=[
|
242 |
+
spk_a,
|
243 |
+
spk_a_w,
|
244 |
+
spk_b,
|
245 |
+
spk_b_w,
|
246 |
+
spk_c,
|
247 |
+
spk_c_w,
|
248 |
+
spk_d,
|
249 |
+
spk_d_w,
|
250 |
+
speaker_name,
|
251 |
+
speaker_gender,
|
252 |
+
speaker_desc,
|
253 |
+
],
|
254 |
+
outputs=[merged_spker],
|
255 |
+
)
|