zhzluke96 committed
Commit 32b2aaa
1 Parent(s): 2ca1c87
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. modules/Enhancer/ResembleEnhance.py +3 -8
  2. modules/repos_static/__init__.py +0 -0
  3. modules/repos_static/readme.md +5 -0
  4. modules/repos_static/resemble_enhance/__init__.py +0 -0
  5. modules/repos_static/resemble_enhance/common.py +55 -0
  6. modules/repos_static/resemble_enhance/data/__init__.py +48 -0
  7. modules/repos_static/resemble_enhance/data/dataset.py +171 -0
  8. modules/repos_static/resemble_enhance/data/distorter/__init__.py +1 -0
  9. modules/repos_static/resemble_enhance/data/distorter/base.py +104 -0
  10. modules/repos_static/resemble_enhance/data/distorter/custom.py +85 -0
  11. modules/repos_static/resemble_enhance/data/distorter/distorter.py +32 -0
  12. modules/repos_static/resemble_enhance/data/distorter/sox.py +176 -0
  13. modules/repos_static/resemble_enhance/data/utils.py +43 -0
  14. modules/repos_static/resemble_enhance/denoiser/__init__.py +0 -0
  15. modules/repos_static/resemble_enhance/denoiser/__main__.py +30 -0
  16. modules/repos_static/resemble_enhance/denoiser/denoiser.py +181 -0
  17. modules/repos_static/resemble_enhance/denoiser/hparams.py +9 -0
  18. modules/repos_static/resemble_enhance/denoiser/inference.py +31 -0
  19. modules/repos_static/resemble_enhance/denoiser/unet.py +144 -0
  20. modules/repos_static/resemble_enhance/enhancer/__init__.py +0 -0
  21. modules/repos_static/resemble_enhance/enhancer/__main__.py +129 -0
  22. modules/repos_static/resemble_enhance/enhancer/download.py +30 -0
  23. modules/repos_static/resemble_enhance/enhancer/enhancer.py +185 -0
  24. modules/repos_static/resemble_enhance/enhancer/hparams.py +23 -0
  25. modules/repos_static/resemble_enhance/enhancer/inference.py +48 -0
  26. modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py +2 -0
  27. modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py +372 -0
  28. modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py +123 -0
  29. modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py +152 -0
  30. modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py +147 -0
  31. modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py +1 -0
  32. modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py +5 -0
  33. modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py +95 -0
  34. modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py +49 -0
  35. modules/repos_static/resemble_enhance/enhancer/univnet/amp.py +101 -0
  36. modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py +210 -0
  37. modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py +281 -0
  38. modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py +128 -0
  39. modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py +94 -0
  40. modules/repos_static/resemble_enhance/hparams.py +128 -0
  41. modules/repos_static/resemble_enhance/inference.py +163 -0
  42. modules/repos_static/resemble_enhance/melspec.py +61 -0
  43. modules/repos_static/resemble_enhance/utils/__init__.py +2 -0
  44. modules/repos_static/resemble_enhance/utils/control.py +26 -0
  45. modules/repos_static/resemble_enhance/utils/logging.py +38 -0
  46. modules/repos_static/resemble_enhance/utils/utils.py +73 -0
  47. modules/speaker.py +4 -0
  48. modules/webui/speaker/__init__.py +0 -0
  49. modules/webui/speaker/speaker_creator.py +171 -0
  50. modules/webui/speaker/speaker_merger.py +255 -0
modules/Enhancer/ResembleEnhance.py CHANGED
@@ -1,13 +1,8 @@
  import os
  from typing import List
- 
- try:
-     from resemble_enhance.enhancer.enhancer import Enhancer
-     from resemble_enhance.enhancer.hparams import HParams
-     from resemble_enhance.inference import inference
- except:
-     HParams = dict
-     Enhancer = dict
+ from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
+ from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
+ from modules.repos_static.resemble_enhance.inference import inference
  
  import torch
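The change above replaces a silent fallback (`except: Enhancer = dict`) with unconditional imports from the vendored copy, so a broken vendor tree now fails at import time rather than at first use. A minimal sketch of loading the vendored enhancer (uses `load_enhancer` from `enhancer/inference.py` later in this diff; the `device` value is illustrative):

```python
# Minimal sketch, assuming this repo layout is on sys.path.
from modules.repos_static.resemble_enhance.enhancer.inference import load_enhancer

# run_dir=None downloads the default "enhancer_stage2" weights on first use.
enhancer = load_enhancer(None, "cpu")
```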
modules/repos_static/__init__.py ADDED
File without changes
modules/repos_static/readme.md ADDED
@@ -0,0 +1,5 @@
+ # repos static
+ 
+ ## resemble_enhance
+ 
+ https://github.com/resemble-ai/resemble-enhance/tree/main
modules/repos_static/resemble_enhance/__init__.py ADDED
File without changes
modules/repos_static/resemble_enhance/common.py ADDED
@@ -0,0 +1,55 @@
+ import logging
+ 
+ import torch
+ from torch import Tensor, nn
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class Normalizer(nn.Module):
+     def __init__(self, momentum=0.01, eps=1e-9):
+         super().__init__()
+         self.momentum = momentum
+         self.eps = eps
+         self.running_mean_unsafe: Tensor
+         self.running_var_unsafe: Tensor
+         self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
+         self.register_buffer("running_var_unsafe", torch.full([], torch.nan))
+ 
+     @property
+     def started(self):
+         return not torch.isnan(self.running_mean_unsafe)
+ 
+     @property
+     def running_mean(self):
+         if not self.started:
+             return torch.zeros_like(self.running_mean_unsafe)
+         return self.running_mean_unsafe
+ 
+     @property
+     def running_std(self):
+         if not self.started:
+             return torch.ones_like(self.running_var_unsafe)
+         return (self.running_var_unsafe + self.eps).sqrt()
+ 
+     @torch.no_grad()
+     def _ema(self, a: Tensor, x: Tensor):
+         return (1 - self.momentum) * a + self.momentum * x
+ 
+     def update_(self, x):
+         if not self.started:
+             self.running_mean_unsafe = x.mean()
+             self.running_var_unsafe = x.var()
+         else:
+             self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
+             self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
+ 
+     def forward(self, x: Tensor, update=True):
+         if self.training and update:
+             self.update_(x)
+         self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
+         x = (x - self.running_mean) / self.running_std
+         return x
+ 
+     def inverse(self, x: Tensor):
+         return x * self.running_std + self.running_mean
modules/repos_static/resemble_enhance/data/__init__.py ADDED
@@ -0,0 +1,48 @@
+ import logging
+ import random
+ 
+ from torch.utils.data import DataLoader
+ 
+ from ..hparams import HParams
+ from .dataset import Dataset
+ from .utils import mix_fg_bg, rglob_audio_files
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
+     paths = rglob_audio_files(hp.fg_dir)
+     logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")
+ 
+     random.Random(seed).shuffle(paths)
+     train_paths = paths[:-val_size]
+     val_paths = paths[-val_size:]
+ 
+     train_ds = Dataset(train_paths, hp, training=True, mode=mode)
+     val_ds = Dataset(val_paths, hp, training=False, mode=mode)
+ 
+     logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")
+ 
+     return train_ds, val_ds
+ 
+ 
+ def create_dataloaders(hp: HParams, mode):
+     train_ds, val_ds = _create_datasets(hp=hp, mode=mode)
+ 
+     train_dl = DataLoader(
+         train_ds,
+         batch_size=hp.batch_size_per_gpu,
+         shuffle=True,
+         num_workers=hp.nj,
+         drop_last=True,
+         collate_fn=train_ds.collate_fn,
+     )
+     val_dl = DataLoader(
+         val_ds,
+         batch_size=1,
+         shuffle=False,
+         num_workers=hp.nj,
+         drop_last=False,
+         collate_fn=val_ds.collate_fn,
+     )
+     return train_dl, val_dl
modules/repos_static/resemble_enhance/data/dataset.py ADDED
@@ -0,0 +1,171 @@
+ import logging
+ import random
+ from pathlib import Path
+ 
+ import numpy as np
+ import torch
+ import torchaudio
+ import torchaudio.functional as AF
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import Dataset as DatasetBase
+ 
+ from ..hparams import HParams
+ from .distorter import Distorter
+ from .utils import rglob_audio_files
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def _normalize(x):
+     return x / (np.abs(x).max() + 1e-7)
+ 
+ 
+ def _collate(batch, key, tensor=True, pad=True):
+     l = [d[key] for d in batch]
+     if l[0] is None:
+         return None
+     if tensor:
+         l = [torch.from_numpy(x) for x in l]
+     if pad:
+         assert tensor, "Can't pad non-tensor"
+         l = pad_sequence(l, batch_first=True)
+     return l
+ 
+ 
+ def praat_augment(wav, sr):
+     try:
+         import parselmouth
+     except ImportError:
+         raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
+     # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
+     # https://github.com/YannickJadoul/Parselmouth/issues/68
+     # note that this function may hang if the praat version is 0.4.3
+     assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
+     sound = parselmouth.Sound(wav, sr)
+     formant_shift_ratio = random.uniform(1.1, 1.5)
+     pitch_range_factor = random.uniform(0.5, 2.0)
+     sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
+     wav = np.array(sound.values)[0].astype(np.float32)
+     return wav
+ 
+ 
+ class Dataset(DatasetBase):
+     def __init__(
+         self,
+         fg_paths: list[Path],
+         hp: HParams,
+         training=True,
+         max_retries=100,
+         silent_fg_prob=0.01,
+         mode=False,
+     ):
+         super().__init__()
+ 
+         assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"
+ 
+         self.hp = hp
+         self.fg_paths = fg_paths
+         self.bg_paths = rglob_audio_files(hp.bg_dir)
+ 
+         if len(self.fg_paths) == 0:
+             raise ValueError(f"No foreground audio files found in {hp.fg_dir}")
+ 
+         if len(self.bg_paths) == 0:
+             raise ValueError(f"No background audio files found in {hp.bg_dir}")
+ 
+         logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
+ 
+         self.training = training
+         self.max_retries = max_retries
+         self.silent_fg_prob = silent_fg_prob
+ 
+         self.mode = mode
+         self.distorter = Distorter(hp, training=training, mode=mode)
+ 
+     def _load_wav(self, path, length=None, random_crop=True):
+         wav, sr = torchaudio.load(path)
+ 
+         wav = AF.resample(
+             waveform=wav,
+             orig_freq=sr,
+             new_freq=self.hp.wav_rate,
+             lowpass_filter_width=64,
+             rolloff=0.9475937167399596,
+             resampling_method="sinc_interp_kaiser",
+             beta=14.769656459379492,
+         )
+ 
+         wav = wav.float().numpy()
+ 
+         if wav.ndim == 2:
+             wav = np.mean(wav, axis=0)
+ 
+         if length is None and self.training:
+             length = int(self.hp.training_seconds * self.hp.wav_rate)
+ 
+         if length is not None:
+             if random_crop:
+                 start = random.randint(0, max(0, len(wav) - length))
+                 wav = wav[start : start + length]
+             else:
+                 wav = wav[:length]
+ 
+         if length is not None and len(wav) < length:
+             wav = np.pad(wav, (0, length - len(wav)))
+ 
+         wav = _normalize(wav)
+ 
+         return wav
+ 
+     def _getitem_unsafe(self, index: int):
+         fg_path = self.fg_paths[index]
+ 
+         if self.training and random.random() < self.silent_fg_prob:
+             fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
+         else:
+             fg_wav = self._load_wav(fg_path)
+             if random.random() < self.hp.praat_augment_prob and self.training:
+                 fg_wav = praat_augment(fg_wav, self.hp.wav_rate)
+ 
+         if self.hp.load_fg_only:
+             bg_wav = None
+             fg_dwav = None
+             bg_dwav = None
+         else:
+             fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
+             if self.training:
+                 bg_path = random.choice(self.bg_paths)
+             else:
+                 # Deterministic for validation
+                 bg_path = self.bg_paths[index % len(self.bg_paths)]
+             bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
+             bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
+ 
+         return dict(
+             fg_wav=fg_wav,
+             bg_wav=bg_wav,
+             fg_dwav=fg_dwav,
+             bg_dwav=bg_dwav,
+         )
+ 
+     def __getitem__(self, index: int):
+         for i in range(self.max_retries):
+             try:
+                 return self._getitem_unsafe(index)
+             except Exception as e:
+                 if i == self.max_retries - 1:
+                     raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
+                 logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
+                 index = np.random.randint(0, len(self))
+ 
+     def __len__(self):
+         return len(self.fg_paths)
+ 
+     @staticmethod
+     def collate_fn(batch):
+         return dict(
+             fg_wavs=_collate(batch, "fg_wav"),
+             bg_wavs=_collate(batch, "bg_wav"),
+             fg_dwavs=_collate(batch, "fg_dwav"),
+             bg_dwavs=_collate(batch, "bg_dwav"),
+         )
modules/repos_static/resemble_enhance/data/distorter/__init__.py ADDED
@@ -0,0 +1 @@
+ from .distorter import Distorter
modules/repos_static/resemble_enhance/data/distorter/base.py ADDED
@@ -0,0 +1,104 @@
+ import itertools
+ import os
+ import random
+ import time
+ import warnings
+ 
+ import numpy as np
+ 
+ _DEBUG = bool(os.environ.get("DEBUG", False))
+ 
+ 
+ class Effect:
+     def apply(self, wav: np.ndarray, sr: int):
+         """
+         Args:
+             wav: (T)
+             sr: sample rate
+         Returns:
+             wav: (T) with the same sample rate of `sr`
+         """
+         raise NotImplementedError
+ 
+     def __call__(self, wav: np.ndarray, sr: int):
+         """
+         Args:
+             wav: (T)
+             sr: sample rate
+         Returns:
+             wav: (T) with the same sample rate of `sr`
+         """
+         assert len(wav.shape) == 1, wav.shape
+ 
+         if _DEBUG:
+             start = time.time()
+         else:
+             start = None
+ 
+         shape = wav.shape
+         assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
+         wav = self.apply(wav, sr)
+         assert shape == wav.shape, f"{self}: {shape} != {wav.shape}."
+ 
+         if start is not None:
+             end = time.time()
+             print(f"{self.__class__.__name__}: {end - start:.3f} sec")
+ 
+         return wav
+ 
+ 
+ class Chain(Effect):
+     def __init__(self, *effects):
+         super().__init__()
+ 
+         self.effects = effects
+ 
+     def apply(self, wav, sr):
+         for effect in self.effects:
+             wav = effect(wav, sr)
+         return wav
+ 
+ 
+ class Maybe(Effect):
+     def __init__(self, prob, effect):
+         super().__init__()
+ 
+         self.prob = prob
+         self.effect = effect
+ 
+         if _DEBUG:
+             warnings.warn("DEBUG mode is on. Maybe -> Must.")
+             self.prob = 1
+ 
+     def apply(self, wav, sr):
+         if random.random() > self.prob:
+             return wav
+         return self.effect(wav, sr)
+ 
+ 
+ class Choice(Effect):
+     def __init__(self, *effects, **kwargs):
+         super().__init__()
+         self.effects = effects
+         self.kwargs = kwargs
+ 
+     def apply(self, wav, sr):
+         return np.random.choice(self.effects, **self.kwargs)(wav, sr)
+ 
+ 
+ class Permutation(Effect):
+     def __init__(self, *effects, n: int | None = None):
+         super().__init__()
+         self.effects = effects
+         self.n = n
+ 
+     def apply(self, wav, sr):
+         if self.n is None:
+             n = np.random.binomial(len(self.effects), 0.5)
+         else:
+             n = self.n
+         if n == 0:
+             return wav
+         perms = itertools.permutations(self.effects, n)
+         effects = random.choice(list(perms))
+         return Chain(*effects)(wav, sr)
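These combinators compose into augmentation pipelines: `Chain` applies effects in order, `Maybe` gates one by probability, `Choice` picks one at random, and `Permutation` applies a random subset in random order. A small illustrative composition (the `Gain` effect is hypothetical, defined here only for the demo; import path assumed from this repo layout):

```python
import numpy as np

from modules.repos_static.resemble_enhance.data.distorter.base import Chain, Choice, Effect, Maybe

class Gain(Effect):
    """Hypothetical effect: scales the waveform by a constant."""

    def __init__(self, g):
        self.g = g

    def apply(self, wav, sr):
        return wav * self.g

# 50% chance of halving, then one of two gains picked uniformly at random.
pipeline = Chain(Maybe(0.5, Gain(0.5)), Choice(Gain(0.9), Gain(1.1)))
wav = pipeline(np.random.randn(16_000).astype(np.float32), 44_100)
```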
modules/repos_static/resemble_enhance/data/distorter/custom.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+ import random
+ from dataclasses import dataclass
+ from functools import cached_property
+ from pathlib import Path
+ 
+ import librosa
+ import numpy as np
+ from scipy import signal
+ 
+ from ..utils import walk_paths
+ from .base import Effect
+ 
+ _logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class RandomRIR(Effect):
+     rir_dir: Path | None
+     rir_rate: int = 44_000
+     rir_suffix: str = ".npy"
+     deterministic: bool = False
+ 
+     @cached_property
+     def rir_paths(self):
+         if self.rir_dir is None:
+             return []
+         return list(walk_paths(self.rir_dir, self.rir_suffix))
+ 
+     def _sample_rir(self):
+         if len(self.rir_paths) == 0:
+             return None
+ 
+         if self.deterministic:
+             rir_path = self.rir_paths[0]
+         else:
+             rir_path = random.choice(self.rir_paths)
+ 
+         rir = np.squeeze(np.load(rir_path))
+         assert isinstance(rir, np.ndarray)
+ 
+         return rir
+ 
+     def apply(self, wav, sr):
+         # ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158
+ 
+         if len(self.rir_paths) == 0:
+             return wav
+ 
+         length = len(wav)
+ 
+         wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
+         rir = self._sample_rir()
+ 
+         wav = signal.convolve(wav, rir, mode="same")
+ 
+         actlev = np.max(np.abs(wav))
+         if actlev > 0.99:
+             wav = (wav / actlev) * 0.98
+ 
+         wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
+ 
+         if abs(length - len(wav)) > 10:
+             _logger.warning(f"length mismatch: {length} vs {len(wav)}")
+ 
+         if length > len(wav):
+             wav = np.pad(wav, (0, length - len(wav)))
+         elif length < len(wav):
+             wav = wav[:length]
+ 
+         return wav
+ 
+ 
+ class RandomGaussianNoise(Effect):
+     def __init__(self, alpha_range=(0.8, 1)):
+         super().__init__()
+         self.alpha_range = alpha_range
+ 
+     def apply(self, wav, sr):
+         noise = np.random.randn(*wav.shape)
+         noise_energy = np.sum(noise**2)
+         wav_energy = np.sum(wav**2)
+         noise = noise * np.sqrt(wav_energy / noise_energy)
+         alpha = random.uniform(*self.alpha_range)
+         return wav * alpha + noise * (1 - alpha)
modules/repos_static/resemble_enhance/data/distorter/distorter.py ADDED
@@ -0,0 +1,32 @@
+ from ...hparams import HParams
+ from .base import Chain, Choice, Permutation
+ from .custom import RandomGaussianNoise, RandomRIR
+ 
+ 
+ class Distorter(Chain):
+     def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
+         # Lazy import
+         from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb
+ 
+         if training:
+             permutation = Permutation(
+                 RandomRIR(hp.rir_dir),
+                 RandomReverb(),
+                 RandomGaussianNoise(),
+                 RandomOverdrive(),
+                 RandomEqualizer(),
+                 Choice(
+                     RandomLowpassDistorter(),
+                     RandomBandpassDistorter(),
+                 ),
+             )
+             if mode == "denoiser":
+                 super().__init__(permutation)
+             else:
+                 # 80%: distortion, 20%: clean
+                 super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
+         else:
+             super().__init__(
+                 RandomRIR(hp.rir_dir, deterministic=True),
+                 RandomReverb(deterministic=True),
+             )
modules/repos_static/resemble_enhance/data/distorter/sox.py ADDED
@@ -0,0 +1,176 @@
+ import logging
+ import os
+ import random
+ import warnings
+ from functools import partial
+ 
+ import numpy as np
+ import torch
+ 
+ try:
+     import augment
+ except ImportError:
+     raise ImportError(
+         "augment is not installed, please install it first using:"
+         "\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
+     )
+ 
+ from .base import Effect
+ 
+ _logger = logging.getLogger(__name__)
+ _DEBUG = bool(os.environ.get("DEBUG", False))
+ 
+ 
+ class AttachableEffect(Effect):
+     def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+         raise NotImplementedError
+ 
+     def apply(self, wav: np.ndarray, sr: int):
+         chain = augment.EffectChain()
+         chain = self.attach(chain)
+         tensor = torch.from_numpy(wav)[None].float()  # (1, T)
+         tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
+         wav = tensor.numpy()[0]  # (T,)
+         return wav
+ 
+ 
+ class SoxEffect(AttachableEffect):
+     def __init__(self, effect_name: str, *args, **kwargs):
+         self.effect_name = effect_name
+         self.args = args
+         self.kwargs = kwargs
+ 
+     def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+         _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
+         if not hasattr(chain, self.effect_name):
+             raise ValueError(f"EffectChain has no attribute {self.effect_name}")
+         return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
+ 
+ 
+ class Maybe(AttachableEffect):
+     """
+     Attach an effect with a probability.
+     """
+ 
+     def __init__(self, prob: float, effect: AttachableEffect):
+         self.prob = prob
+         self.effect = effect
+         if _DEBUG:
+             warnings.warn("DEBUG mode is on. Maybe -> Must.")
+             self.prob = 1
+ 
+     def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+         if random.random() > self.prob:
+             return chain
+         return self.effect.attach(chain)
+ 
+ 
+ class Chain(AttachableEffect):
+     """
+     Attach a chain of effects.
+     """
+ 
+     def __init__(self, *effects: AttachableEffect):
+         self.effects = effects
+ 
+     def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+         for effect in self.effects:
+             chain = effect.attach(chain)
+         return chain
+ 
+ 
+ class Choice(AttachableEffect):
+     """
+     Attach one of the effects randomly.
+     """
+ 
+     def __init__(self, *effects: AttachableEffect):
+         self.effects = effects
+ 
+     def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+         return random.choice(self.effects).attach(chain)
+ 
+ 
+ class Generator:
+     def __call__(self) -> str:
+         raise NotImplementedError
+ 
+ 
+ class Uniform(Generator):
+     def __init__(self, low, high):
+         self.low = low
+         self.high = high
+ 
+     def __call__(self) -> str:
+         return str(random.uniform(self.low, self.high))
+ 
+ 
+ class Randint(Generator):
+     def __init__(self, low, high):
+         self.low = low
+         self.high = high
+ 
+     def __call__(self) -> str:
+         return str(random.randint(self.low, self.high))
+ 
+ 
+ class Concat(Generator):
+     def __init__(self, *parts: Generator | str):
+         self.parts = parts
+ 
+     def __call__(self):
+         return "".join([part if isinstance(part, str) else part() for part in self.parts])
+ 
+ 
+ class RandomLowpassDistorter(SoxEffect):
+     def __init__(self, low=2000, high=16000):
+         super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
+ 
+ 
+ class RandomBandpassDistorter(SoxEffect):
+     def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
+         super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
+ 
+     @staticmethod
+     def _fn(low, high, min_width, max_width):
+         start = random.randint(low, high)
+         stop = start + random.randint(min_width, max_width)
+         return f"{start}-{stop}"
+ 
+ 
+ class RandomEqualizer(SoxEffect):
+     def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
+         super().__init__(
+             "equalizer",
+             Uniform(low, high),
+             lambda: f"{random.randint(q_low, q_high)}q",
+             lambda: random.randint(db_low, db_high),
+         )
+ 
+ 
+ class RandomOverdrive(SoxEffect):
+     def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
+         super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
+ 
+ 
+ class RandomReverb(Chain):
+     def __init__(self, deterministic=False):
+         super().__init__(
+             SoxEffect(
+                 "reverb",
+                 Uniform(50, 50) if deterministic else Uniform(0, 100),
+                 Uniform(50, 50) if deterministic else Uniform(0, 100),
+                 Uniform(50, 50) if deterministic else Uniform(0, 100),
+             ),
+             SoxEffect("channels", 1),
+         )
+ 
+ 
+ class Flanger(SoxEffect):
+     def __init__(self):
+         super().__init__("flanger")
+ 
+ 
+ class Phaser(SoxEffect):
+     def __init__(self):
+         super().__init__("phaser")
modules/repos_static/resemble_enhance/data/utils.py ADDED
@@ -0,0 +1,43 @@
+ from pathlib import Path
+ from typing import Callable
+ 
+ from torch import Tensor
+ 
+ 
+ def walk_paths(root, suffix):
+     for path in Path(root).iterdir():
+         if path.is_dir():
+             yield from walk_paths(path, suffix)
+         elif path.suffix == suffix:
+             yield path
+ 
+ 
+ def rglob_audio_files(path: Path):
+     return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
+ 
+ 
+ def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
+     """
+     Args:
+         fg: (b, t)
+         bg: (b, t)
+     """
+     assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
+     fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
+     bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
+ 
+     fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
+     bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
+ 
+     fg = fg / (fg_energy + eps).sqrt()
+     bg = bg / (bg_energy + eps).sqrt()
+ 
+     if callable(alpha):
+         alpha = alpha()
+ 
+     assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
+ 
+     mx = alpha * fg + (1 - alpha) * bg
+     mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
+ 
+     return mx
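Because `mix_fg_bg` first peak-normalizes, then rescales both signals to unit energy before the convex mix, `alpha` directly sets the foreground-to-background energy ratio, and the final division restores a unit peak. A minimal usage sketch (import path assumed from this repo layout; `alpha` may also be a callable sampled per call):

```python
import torch

from modules.repos_static.resemble_enhance.data.utils import mix_fg_bg

fg = torch.randn(2, 16_000)        # (b, t) foreground batch
bg = torch.randn(2, 16_000)        # (b, t) background batch
mx = mix_fg_bg(fg, bg, alpha=0.8)  # 80% fg energy, 20% bg energy
assert mx.shape == fg.shape and mx.abs().max() <= 1.0
```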
modules/repos_static/resemble_enhance/denoiser/__init__.py ADDED
File without changes
modules/repos_static/resemble_enhance/denoiser/__main__.py ADDED
@@ -0,0 +1,30 @@
+ import argparse
+ from pathlib import Path
+ 
+ import torch
+ import torchaudio
+ 
+ from .inference import denoise
+ 
+ 
+ @torch.inference_mode()
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
+     parser.add_argument("out_dir", type=Path, help="Output folder")
+     parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
+     parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
+     parser.add_argument("--device", type=str, default="cuda", help="Device")
+     args = parser.parse_args()
+ 
+     for path in args.in_dir.glob(f"**/*{args.suffix}"):
+         print(f"Processing {path} ..")
+         dwav, sr = torchaudio.load(path)
+         hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device)
+         out_path = args.out_dir / path.relative_to(args.in_dir)
+         out_path.parent.mkdir(parents=True, exist_ok=True)
+         torchaudio.save(out_path, hwav[None], sr)
+ 
+ 
+ if __name__ == "__main__":
+     main()
modules/repos_static/resemble_enhance/denoiser/denoiser.py ADDED
@@ -0,0 +1,181 @@
+ import logging
+ 
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+ 
+ from ..melspec import MelSpectrogram
+ from .hparams import HParams
+ from .unet import UNet
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def _normalize(x: Tensor) -> Tensor:
+     return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
+ 
+ 
+ class Denoiser(nn.Module):
+     @property
+     def stft_cfg(self) -> dict:
+         hop_size = self.hp.hop_size
+         return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)
+ 
+     @property
+     def n_fft(self):
+         return self.stft_cfg["n_fft"]
+ 
+     @property
+     def eps(self):
+         return 1e-7
+ 
+     def __init__(self, hp: HParams):
+         super().__init__()
+         self.hp = hp
+         self.net = UNet(input_dim=3, output_dim=3)
+         self.mel_fn = MelSpectrogram(hp)
+ 
+         self.dummy: Tensor
+         self.register_buffer("dummy", torch.zeros(1), persistent=False)
+ 
+     def to_mel(self, x: Tensor, drop_last=True):
+         """
+         Args:
+             x: (b t), wavs
+         Returns:
+             o: (b c t), mels
+         """
+         if drop_last:
+             return self.mel_fn(x)[..., :-1]  # (b d t)
+         return self.mel_fn(x)
+ 
+     def _stft(self, x):
+         """
+         Args:
+             x: (b t)
+         Returns:
+             mag: (b f t) in [0, inf)
+             cos: (b f t) in [-1, 1]
+             sin: (b f t) in [-1, 1]
+         """
+         dtype = x.dtype
+         device = x.device
+ 
+         if x.is_mps:
+             x = x.cpu()
+ 
+         window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
+         s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True)  # (b f t+1)
+ 
+         s = s[..., :-1]  # (b f t)
+ 
+         mag = s.abs()  # (b f t)
+ 
+         phi = s.angle()  # (b f t)
+         cos = phi.cos()  # (b f t)
+         sin = phi.sin()  # (b f t)
+ 
+         mag = mag.to(dtype=dtype, device=device)
+         cos = cos.to(dtype=dtype, device=device)
+         sin = sin.to(dtype=dtype, device=device)
+ 
+         return mag, cos, sin
+ 
+     def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
+         """
+         Args:
+             mag: (b f t) in [0, inf)
+             cos: (b f t) in [-1, 1]
+             sin: (b f t) in [-1, 1]
+         Returns:
+             x: (b t)
+         """
+         device = mag.device
+         dtype = mag.dtype
+ 
+         if mag.is_mps:
+             mag = mag.cpu()
+             cos = cos.cpu()
+             sin = sin.cpu()
+ 
+         real = mag * cos  # (b f t)
+         imag = mag * sin  # (b f t)
+ 
+         s = torch.complex(real, imag)  # (b f t)
+ 
+         if s.isnan().any():
+             logger.warning("NaN detected in ISTFT input.")
+ 
+         s = F.pad(s, (0, 1), "replicate")  # (b f t+1)
+ 
+         window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
+         x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)
+ 
+         if x.isnan().any():
+             logger.warning("NaN detected in ISTFT output, set to zero.")
+             x = torch.where(x.isnan(), torch.zeros_like(x), x)
+ 
+         x = x.to(dtype=dtype, device=device)
+ 
+         return x
+ 
+     def _magphase(self, real, imag):
+         mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
+         cos = real / mag
+         sin = imag / mag
+         return mag, cos, sin
+ 
+     def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
+         """
+         Args:
+             mag: (b f t)
+             cos: (b f t)
+             sin: (b f t)
+         Returns:
+             mag_mask: (b f t) in [0, 1], magnitude mask
+             cos_res: (b f t) in [-1, 1], phase residual
+             sin_res: (b f t) in [-1, 1], phase residual
+         """
+         x = torch.stack([mag, cos, sin], dim=1)  # (b 3 f t)
+         mag_mask, real, imag = self.net(x).unbind(1)  # (b 3 f t)
+         mag_mask = mag_mask.sigmoid()  # (b f t)
+         real = real.tanh()  # (b f t)
+         imag = imag.tanh()  # (b f t)
+         _, cos_res, sin_res = self._magphase(real, imag)  # (b f t)
+         return mag_mask, sin_res, cos_res
+ 
+     def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
+         """Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
+         sep_mag = F.relu(mag * mag_mask)
+         sep_cos = cos * cos_res - sin * sin_res
+         sep_sin = sin * cos_res + cos * sin_res
+         return sep_mag, sep_cos, sep_sin
+ 
+     def forward(self, x: Tensor, y: Tensor | None = None):
+         """
+         Args:
+             x: (b t), a mixed audio
+             y: (b t), a fg audio
+         """
+         assert x.dim() == 2, f"Expected (b t), got {x.size()}"
+         x = x.to(self.dummy)
+         x = _normalize(x)
+ 
+         if y is not None:
+             assert y.dim() == 2, f"Expected (b t), got {y.size()}"
+             y = y.to(self.dummy)
+             y = _normalize(y)
+ 
+         mag, cos, sin = self._stft(x)  # (b 2f t)
+         mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
+         sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
+ 
+         o = self._istft(sep_mag, sep_cos, sep_sin)
+ 
+         npad = x.shape[-1] - o.shape[-1]
+         o = F.pad(o, (0, npad))
+ 
+         if y is not None:
+             self.losses = dict(l1=F.l1_loss(o, y))
+ 
+         return o
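The cos/sin algebra in `_separate` is the angle-addition identity: the predicted `(cos_res, sin_res)` pair rotates the noisy phase instead of re-estimating it from scratch, while `mag_mask` scales the magnitude. A small numeric check (illustration only):

```python
import torch

# sep_cos/sep_sin below mirror the expressions in Denoiser._separate.
phi, dphi = torch.rand(4), torch.rand(4)  # noisy phase, predicted residual
cos, sin = phi.cos(), phi.sin()
cos_res, sin_res = dphi.cos(), dphi.sin()
sep_cos = cos * cos_res - sin * sin_res
sep_sin = sin * cos_res + cos * sin_res
assert torch.allclose(sep_cos, (phi + dphi).cos(), atol=1e-6)  # cos(phi + dphi)
assert torch.allclose(sep_sin, (phi + dphi).sin(), atol=1e-6)  # sin(phi + dphi)
```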
modules/repos_static/resemble_enhance/denoiser/hparams.py ADDED
@@ -0,0 +1,9 @@
+ from dataclasses import dataclass
+ 
+ from ..hparams import HParams as HParamsBase
+ 
+ 
+ @dataclass(frozen=True)
+ class HParams(HParamsBase):
+     batch_size_per_gpu: int = 128
+     distort_prob: float = 0.5
modules/repos_static/resemble_enhance/denoiser/inference.py ADDED
@@ -0,0 +1,31 @@
+ import logging
+ from functools import cache
+ 
+ import torch
+ 
+ from ..denoiser.denoiser import Denoiser
+ 
+ from ..inference import inference
+ from .hparams import HParams
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @cache
+ def load_denoiser(run_dir, device):
+     if run_dir is None:
+         return Denoiser(HParams())
+     hp = HParams.load(run_dir)
+     denoiser = Denoiser(hp)
+     path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
+     state_dict = torch.load(path, map_location="cpu")["module"]
+     denoiser.load_state_dict(state_dict)
+     denoiser.eval()
+     denoiser.to(device)
+     return denoiser
+ 
+ 
+ @torch.inference_mode()
+ def denoise(dwav, sr, run_dir, device):
+     denoiser = load_denoiser(run_dir, device)
+     return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
modules/repos_static/resemble_enhance/denoiser/unet.py ADDED
@@ -0,0 +1,144 @@
+ import torch.nn.functional as F
+ from torch import nn
+ 
+ 
+ class PreactResBlock(nn.Sequential):
+     def __init__(self, dim):
+         super().__init__(
+             nn.GroupNorm(dim // 16, dim),
+             nn.GELU(),
+             nn.Conv2d(dim, dim, 3, padding=1),
+             nn.GroupNorm(dim // 16, dim),
+             nn.GELU(),
+             nn.Conv2d(dim, dim, 3, padding=1),
+         )
+ 
+     def forward(self, x):
+         return x + super().forward(x)
+ 
+ 
+ class UNetBlock(nn.Module):
+     def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
+         super().__init__()
+         if output_dim is None:
+             output_dim = input_dim
+         self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
+         self.res_block1 = PreactResBlock(output_dim)
+         self.res_block2 = PreactResBlock(output_dim)
+         self.downsample = self.upsample = nn.Identity()
+         if scale_factor > 1:
+             self.upsample = nn.Upsample(scale_factor=scale_factor)
+         elif scale_factor < 1:
+             self.downsample = nn.Upsample(scale_factor=scale_factor)
+ 
+     def forward(self, x, h=None):
+         """
+         Args:
+             x: (b c h w), last output
+             h: (b c h w), skip output
+         Returns:
+             o: (b c h w), output
+             s: (b c h w), skip output
+         """
+         x = self.upsample(x)
+         if h is not None:
+             assert x.shape == h.shape, f"{x.shape} != {h.shape}"
+             x = x + h
+         x = self.pre_conv(x)
+         x = self.res_block1(x)
+         x = self.res_block2(x)
+         return self.downsample(x), x
+ 
+ 
+ class UNet(nn.Module):
+     def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
+         super().__init__()
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+         self.encoder_blocks = nn.ModuleList(
+             [
+                 UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
+                 for i in range(num_blocks)
+             ]
+         )
+         self.middle_blocks = nn.ModuleList(
+             [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
+         )
+         self.decoder_blocks = nn.ModuleList(
+             [
+                 UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
+                 for i in reversed(range(num_blocks))
+             ]
+         )
+         self.head = nn.Sequential(
+             nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
+             nn.GELU(),
+             nn.Conv2d(hidden_dim, output_dim, 1),
+         )
+ 
+     @property
+     def scale_factor(self):
+         return 2 ** len(self.encoder_blocks)
+ 
+     def pad_to_fit(self, x):
+         """
+         Args:
+             x: (b c h w), input
+         Returns:
+             x: (b c h' w'), padded input
+         """
+         hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor
+         wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor
+         return F.pad(x, (0, wpad, 0, hpad))
+ 
+     def forward(self, x):
+         """
+         Args:
+             x: (b c h w), input
+         Returns:
+             o: (b c h w), output
+         """
+         shape = x.shape
+ 
+         x = self.pad_to_fit(x)
+         x = self.input_proj(x)
+ 
+         s_list = []
+         for block in self.encoder_blocks:
+             x, s = block(x)
+             s_list.append(s)
+ 
+         for block in self.middle_blocks:
+             x, _ = block(x)
+ 
+         for block, s in zip(self.decoder_blocks, reversed(s_list)):
+             x, _ = block(x, s)
+ 
+         x = self.head(x)
+         x = x[..., : shape[2], : shape[3]]
+ 
+         return x
+ 
+     def test(self, shape=(3, 512, 256)):
+         import ptflops
+ 
+         macs, params = ptflops.get_model_complexity_info(
+             self,
+             shape,
+             as_strings=True,
+             print_per_layer_stat=True,
+             verbose=True,
+         )
+ 
+         print(f"macs: {macs}")
+         print(f"params: {params}")
+ 
+ 
+ def main():
+     model = UNet(3, 3)
+     model.test()
+ 
+ 
+ if __name__ == "__main__":
+     main()
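Since each encoder block halves the spatial dimensions, `pad_to_fit` pads inputs up to a multiple of `2**num_blocks` and `forward` crops back to the original size, so arbitrary spectrogram shapes round-trip. A quick illustrative shape check (import path assumed from this repo layout):

```python
import torch

from modules.repos_static.resemble_enhance.denoiser.unet import UNet

net = UNet(input_dim=3, output_dim=3)  # defaults: hidden_dim=16, num_blocks=4
x = torch.randn(1, 3, 257, 250)        # neither spatial dim is a multiple of 16
assert net.scale_factor == 16
assert net(x).shape == x.shape         # padded internally, cropped on output
```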
modules/repos_static/resemble_enhance/enhancer/__init__.py ADDED
File without changes
modules/repos_static/resemble_enhance/enhancer/__main__.py ADDED
@@ -0,0 +1,129 @@
+ import argparse
+ import random
+ import time
+ from pathlib import Path
+ 
+ import torch
+ import torchaudio
+ from tqdm import tqdm
+ 
+ from .inference import denoise, enhance
+ 
+ 
+ @torch.inference_mode()
+ def main():
+     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
+     parser.add_argument("out_dir", type=Path, help="Output folder")
+     parser.add_argument(
+         "--run_dir",
+         type=Path,
+         default=None,
+         help="Path to the enhancer run folder, if None, use the default model",
+     )
+     parser.add_argument(
+         "--suffix",
+         type=str,
+         default=".wav",
+         help="Audio file suffix",
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda",
+         help="Device to use for computation, recommended to use CUDA",
+     )
+     parser.add_argument(
+         "--denoise_only",
+         action="store_true",
+         help="Only apply denoising without enhancement",
+     )
+     parser.add_argument(
+         "--lambd",
+         type=float,
+         default=1.0,
+         help="Denoise strength for enhancement (0.0 to 1.0)",
+     )
+     parser.add_argument(
+         "--tau",
+         type=float,
+         default=0.5,
+         help="CFM prior temperature (0.0 to 1.0)",
+     )
+     parser.add_argument(
+         "--solver",
+         type=str,
+         default="midpoint",
+         choices=["midpoint", "rk4", "euler"],
+         help="Numerical solver to use",
+     )
+     parser.add_argument(
+         "--nfe",
+         type=int,
+         default=64,
+         help="Number of function evaluations",
+     )
+     parser.add_argument(
+         "--parallel_mode",
+         action="store_true",
+         help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
+     )
+ 
+     args = parser.parse_args()
+ 
+     device = args.device
+ 
+     if device == "cuda" and not torch.cuda.is_available():
+         print("CUDA is not available but --device is set to cuda, using CPU instead")
+         device = "cpu"
+ 
+     start_time = time.perf_counter()
+ 
+     run_dir = args.run_dir
+ 
+     paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
+ 
+     if args.parallel_mode:
+         random.shuffle(paths)
+ 
+     if len(paths) == 0:
+         print(f"No {args.suffix} files found in the following path: {args.in_dir}")
+         return
+ 
+     pbar = tqdm(paths)
+ 
+     for path in pbar:
+         out_path = args.out_dir / path.relative_to(args.in_dir)
+         if args.parallel_mode and out_path.exists():
+             continue
+         pbar.set_description(f"Processing {out_path}")
+         dwav, sr = torchaudio.load(path)
+         dwav = dwav.mean(0)
+         if args.denoise_only:
+             hwav, sr = denoise(
+                 dwav=dwav,
+                 sr=sr,
+                 device=device,
+                 run_dir=args.run_dir,
+             )
+         else:
+             hwav, sr = enhance(
+                 dwav=dwav,
+                 sr=sr,
+                 device=device,
+                 nfe=args.nfe,
+                 solver=args.solver,
+                 lambd=args.lambd,
+                 tau=args.tau,
+                 run_dir=run_dir,
+             )
+         out_path.parent.mkdir(parents=True, exist_ok=True)
+         torchaudio.save(out_path, hwav[None], sr)
+ 
+     # Cool emoji effect saying the job is done
+     elapsed_time = time.perf_counter() - start_time
+     print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s")
+ 
+ 
+ if __name__ == "__main__":
+     main()
modules/repos_static/resemble_enhance/enhancer/download.py ADDED
@@ -0,0 +1,30 @@
+ import logging
+ from pathlib import Path
+ 
+ import torch
+ 
+ RUN_NAME = "enhancer_stage2"
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def get_source_url(relpath):
+     return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
+ 
+ 
+ def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
+     if run_dir is None:
+         run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
+     return Path(run_dir) / relpath
+ 
+ 
+ def download(run_dir: str | Path | None = None):
+     relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
+     for relpath in relpaths:
+         path = get_target_path(relpath, run_dir=run_dir)
+         if path.exists():
+             continue
+         url = get_source_url(relpath)
+         path.parent.mkdir(parents=True, exist_ok=True)
+         torch.hub.download_url_to_file(url, str(path))
+     return get_target_path("", run_dir=run_dir)
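For reference, `download(None)` resolves to a `model_repo/enhancer_stage2` folder next to the package and fetches only the files that are missing. An illustrative call (import path assumed; triggers a download on first use):

```python
from modules.repos_static.resemble_enhance.enhancer.download import download

run_dir = download(None)  # .../resemble_enhance/model_repo/enhancer_stage2
print(run_dir / "hparams.yaml")
```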
modules/repos_static/resemble_enhance/enhancer/enhancer.py ADDED
@@ -0,0 +1,185 @@
+ import logging
+ 
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import torch
+ from torch import Tensor, nn
+ from torch.distributions import Beta
+ 
+ from ..common import Normalizer
+ from ..denoiser.inference import load_denoiser
+ from ..melspec import MelSpectrogram
+ from .hparams import HParams
+ from .lcfm import CFM, IRMAE, LCFM
+ from .univnet import UnivNet
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def _maybe(fn):
+     def _fn(*args):
+         if args[0] is None:
+             return None
+         return fn(*args)
+ 
+     return _fn
+ 
+ 
+ def _normalize_wav(x: Tensor):
+     return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
+ 
+ 
+ class Enhancer(nn.Module):
+     def __init__(self, hp: HParams):
+         super().__init__()
+         self.hp = hp
+ 
+         n_mels = self.hp.num_mels
+         vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
+         latent_dim = self.hp.lcfm_latent_dim
+ 
+         self.lcfm = LCFM(
+             IRMAE(
+                 input_dim=n_mels,
+                 output_dim=vocoder_input_dim,
+                 latent_dim=latent_dim,
+             ),
+             CFM(
+                 cond_dim=n_mels,
+                 output_dim=self.hp.lcfm_latent_dim,
+                 solver_nfe=self.hp.cfm_solver_nfe,
+                 solver_method=self.hp.cfm_solver_method,
+                 time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
+             ),
+             z_scale=self.hp.lcfm_z_scale,
+         )
+ 
+         self.lcfm.set_mode_(self.hp.lcfm_training_mode)
+ 
+         self.mel_fn = MelSpectrogram(hp)
+         self.vocoder = UnivNet(self.hp, vocoder_input_dim)
+         self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
+         self.normalizer = Normalizer()
+ 
+         self._eval_lambd = 0.0
+ 
+         self.dummy: Tensor
+         self.register_buffer("dummy", torch.zeros(1))
+ 
+         if self.hp.enhancer_stage1_run_dir is not None:
+             pretrained_path = (
+                 self.hp.enhancer_stage1_run_dir
+                 / "ds/G/default/mp_rank_00_model_states.pt"
+             )
+             self._load_pretrained(pretrained_path)
+ 
+         logger.info(f"{self.__class__.__name__} summary")
+         logger.info(f"{self.summarize()}")
+ 
+     def _load_pretrained(self, path):
+         # Clone is necessary as otherwise it holds a reference to the original model
+         cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
+         denoiser_state_dict = {
+             k: v.clone() for k, v in self.denoiser.state_dict().items()
+         }
+         state_dict = torch.load(path, map_location="cpu")["module"]
+         self.load_state_dict(state_dict, strict=False)
+         self.lcfm.cfm.load_state_dict(cfm_state_dict)  # Reset cfm
+         self.denoiser.load_state_dict(denoiser_state_dict)  # Reset denoiser
+         logger.info(f"Loaded pretrained model from {path}")
+ 
+     def summarize(self):
+         npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
+         npa = lambda m: sum(p.numel() for p in m.parameters())
+         rows = []
+         for name, module in self.named_children():
+             rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
+         rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
+         df = pd.DataFrame(rows)
+         return df.to_markdown(index=False)
+ 
+     def to_mel(self, x: Tensor, drop_last=True):
+         """
+         Args:
+             x: (b t), wavs
+         Returns:
+             o: (b c t), mels
+         """
+         if drop_last:
+             return self.mel_fn(x)[..., :-1]  # (b d t)
+         return self.mel_fn(x)
+ 
+     def _may_denoise(self, x: Tensor, y: Tensor | None = None):
+         if self.hp.lcfm_training_mode == "cfm":
+             return self.denoiser(x, y)
+         return x
+ 
+     def configurate_(self, nfe, solver, lambd, tau):
+         """
+         Args:
+             nfe: number of function evaluations
+             solver: solver method
+             lambd: denoiser strength [0, 1]
+             tau: prior temperature [0, 1]
+         """
+         self.lcfm.cfm.solver.configurate_(nfe, solver)
+         self.lcfm.eval_tau_(tau)
+         self._eval_lambd = lambd
+ 
+     def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
+         """
+         Args:
+             x: (b t), mix wavs (fg + bg)
+             y: (b t), fg clean wavs
+             z: (b t), fg distorted wavs
+         Returns:
+             o: (b t), reconstructed wavs
+         """
+         assert x.dim() == 2, f"Expected (b t), got {x.size()}"
+         assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
+ 
+         if self.hp.lcfm_training_mode == "cfm":
+             self.normalizer.eval()
+ 
+         x = _normalize_wav(x)
+         y = _maybe(_normalize_wav)(y)
+         z = _maybe(_normalize_wav)(z)
+ 
+         x_mel_original = self.normalizer(self.to_mel(x), update=False)  # (b d t)
+ 
+         if self.hp.lcfm_training_mode == "cfm":
+             if self.training:
+                 lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
+                 lambd = lambd[:, None, None]
+                 x_mel_denoised = self.normalizer(
+                     self.to_mel(self._may_denoise(x, z)), update=False
+                 )
+                 x_mel_denoised = x_mel_denoised.detach()
+                 x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
+                 self._visualize(x_mel_original, x_mel_denoised)
+             else:
+                 lambd = self._eval_lambd
+                 if lambd == 0:
+                     x_mel_denoised = x_mel_original
+                 else:
+                     x_mel_denoised = self.normalizer(
+                         self.to_mel(self._may_denoise(x, z)), update=False
+                     )
+                     x_mel_denoised = x_mel_denoised.detach()
+                     x_mel_denoised = (
+                         lambd * x_mel_denoised + (1 - lambd) * x_mel_original
+                     )
+         else:
+             x_mel_denoised = x_mel_original
+ 
+         y_mel = _maybe(self.to_mel)(y)  # (b d t)
+         y_mel = _maybe(self.normalizer)(y_mel)
+ 
+         lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original)  # (b d t)
+ 
+         if lcfm_decoded is None:
+             o = None
+         else:
+             o = self.vocoder(lcfm_decoded, y)
+ 
+         return o
modules/repos_static/resemble_enhance/enhancer/hparams.py ADDED
@@ -0,0 +1,23 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ 
+ from ..hparams import HParams as HParamsBase
+ 
+ 
+ @dataclass(frozen=True)
+ class HParams(HParamsBase):
+     cfm_solver_method: str = "midpoint"
+     cfm_solver_nfe: int = 64
+     cfm_time_mapping_divisor: int = 4
+     univnet_nc: int = 96
+ 
+     lcfm_latent_dim: int = 64
+     lcfm_training_mode: str = "ae"
+     lcfm_z_scale: float = 5
+ 
+     vocoder_extra_dim: int = 32
+ 
+     gan_training_start_step: int | None = 5_000
+     enhancer_stage1_run_dir: Path | None = None
+ 
+     denoiser_run_dir: Path | None = None
modules/repos_static/resemble_enhance/enhancer/inference.py ADDED
@@ -0,0 +1,48 @@
+ import logging
+ from functools import cache
+ from pathlib import Path
+ 
+ import torch
+ 
+ from ..inference import inference
+ from .download import download
+ from .hparams import HParams
+ from .enhancer import Enhancer
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @cache
+ def load_enhancer(run_dir: str | Path | None, device):
+     run_dir = download(run_dir)
+     hp = HParams.load(run_dir)
+     enhancer = Enhancer(hp)
+     path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
+     state_dict = torch.load(path, map_location="cpu")["module"]
+     enhancer.load_state_dict(state_dict)
+     enhancer.eval()
+     enhancer.to(device)
+     return enhancer
+ 
+ 
+ @torch.inference_mode()
+ def denoise(dwav, sr, device, run_dir=None):
+     enhancer = load_enhancer(run_dir, device)
+     return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
+ 
+ 
+ @torch.inference_mode()
+ def enhance(
+     dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None
+ ):
+     assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
+     assert solver in (
+         "midpoint",
+         "rk4",
+         "euler",
+     ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
+     assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
+     assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
+     enhancer = load_enhancer(run_dir, device)
+     enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
+     return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
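Putting `enhance` to work end to end, a minimal sketch that mirrors what `enhancer/__main__.py` does per file (file paths illustrative; assumes a CUDA device):

```python
import torchaudio

from modules.repos_static.resemble_enhance.enhancer.inference import enhance

dwav, sr = torchaudio.load("noisy.wav")  # illustrative input path
dwav = dwav.mean(0)                      # (t,) mono, as in __main__.py
hwav, sr = enhance(dwav, sr, device="cuda", nfe=32, solver="midpoint",
                   lambd=0.5, tau=0.5, run_dir=None)
torchaudio.save("enhanced.wav", hwav[None], sr)
```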
modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .irmae import IRMAE
+ from .lcfm import CFM, LCFM
modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ from typing import Protocol
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import scipy
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import Tensor, nn
12
+ from tqdm import trange
13
+
14
+ from .wn import WN
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class VelocityField(Protocol):
20
+ def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
21
+ ...
22
+
23
+
24
+ class Solver:
25
+ def __init__(
26
+ self,
27
+ method="midpoint",
28
+ nfe=32,
29
+ viz_name="solver",
30
+ viz_every=100,
31
+ mel_fn=None,
32
+ time_mapping_divisor=4,
33
+ verbose=False,
34
+ ):
35
+ self.configurate_(nfe=nfe, method=method)
36
+
37
+ self.verbose = verbose
38
+ self.viz_every = viz_every
39
+ self.viz_name = viz_name
40
+
41
+ self._camera = None
42
+ self._mel_fn = mel_fn
43
+ self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
44
+
45
+ def configurate_(self, nfe=None, method=None):
46
+ if nfe is None:
47
+ nfe = self.nfe
48
+
49
+ if method is None:
50
+ method = self.method
51
+
52
+ if nfe == 1 and method in ("midpoint", "rk4"):
53
+ logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
54
+ method = "euler"
55
+
56
+ self.nfe = nfe
57
+ self.method = method
58
+
59
+ @property
60
+ def time_mapping(self):
61
+ return self._time_mapping
62
+
63
+ @staticmethod
64
+ def exponential_decay_mapping(t, n=4):
65
+ """
66
+ Args:
67
+ n: target step
68
+ """
69
+
70
+ def h(t, a):
71
+ return (a**t - 1) / (a - 1)
72
+
73
+ # Solve h(1/n) = 0.5
74
+ a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0))
75
+
76
+ t = h(t, a=a)
77
+
78
+ return t
79
+
80
+ @torch.no_grad()
81
+ def _maybe_camera_snap(self, *, ψt, t):
82
+ camera = self._camera
83
+ if camera is not None:
84
+ if ψt.shape[1] == 1:
85
+ # Waveform, b 1 t, plot every 100 samples
86
+ plt.subplot(211)
87
+ plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
88
+ if self._mel_fn is not None:
89
+ plt.subplot(212)
90
+ mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
91
+ plt.imshow(mel, origin="lower", interpolation="none")
92
+ elif ψt.shape[1] == 2:
93
+ # Complex
94
+ plt.subplot(121)
95
+ plt.imshow(
96
+ ψt.detach().cpu().numpy()[0, 0],
97
+ origin="lower",
98
+ interpolation="none",
99
+ )
100
+ plt.subplot(122)
101
+ plt.imshow(
102
+ ψt.detach().cpu().numpy()[0, 1],
103
+ origin="lower",
104
+ interpolation="none",
105
+ )
106
+ else:
107
+ # Spectrogram, b c t
108
+ plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
109
+ ax = plt.gca()
110
+ ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
111
+ camera.snap()
112
+
113
+ @staticmethod
114
+ def _euler_step(t, ψt, dt, f: VelocityField):
115
+ return ψt + dt * f(t=t, ψt=ψt, dt=dt)
116
+
117
+ @staticmethod
118
+ def _midpoint_step(t, ψt, dt, f: VelocityField):
119
+ return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)
120
+
121
+ @staticmethod
122
+ def _rk4_step(t, ψt, dt, f: VelocityField):
123
+ k1 = f(t=t, ψt=ψt, dt=dt)
124
+ k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
125
+ k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
126
+ k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
127
+ return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
128
+
129
+ @property
130
+ def _step(self):
131
+ if self.method == "euler":
132
+ return self._euler_step
133
+ elif self.method == "midpoint":
134
+ return self._midpoint_step
135
+ elif self.method == "rk4":
136
+ return self._rk4_step
137
+ else:
138
+ raise ValueError(f"Unknown method: {self.method}")
139
+
140
+ def get_running_train_loop(self):
141
+ try:
142
+ # Lazy import
143
+ from ...utils.train_loop import TrainLoop
144
+
145
+ return TrainLoop.get_running_loop()
146
+ except ImportError:
147
+ return None
148
+
149
+ @property
150
+ def visualizing(self):
151
+ loop = self.get_running_train_loop()
152
+ if loop is None:
153
+ return
154
+ out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
155
+ return loop.global_step % self.viz_every == 0 and not out_path.exists()
156
+
157
+ def _reset_camera(self):
158
+ try:
159
+ from celluloid import Camera
160
+
161
+ self._camera = Camera(plt.figure())
162
+ except ImportError:
163
+ pass
164
+
165
+ def _maybe_dump_camera(self):
166
+ camera = self._camera
167
+ loop = self.get_running_train_loop()
168
+ if camera is not None and loop is not None:
169
+ animation = camera.animate()
170
+ out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
171
+ out_path.parent.mkdir(exist_ok=True, parents=True)
172
+ animation.save(out_path, writer="pillow", fps=4)
173
+ plt.close()
174
+ self._camera = None
175
+
176
+ @property
177
+ def n_steps(self):
178
+ n = self.nfe
179
+ if self.method == "euler":
180
+ pass
181
+ elif self.method == "midpoint":
182
+ n //= 2
183
+ elif self.method == "rk4":
184
+ n //= 4
185
+ else:
186
+ raise ValueError(f"Unknown method: {self.method}")
187
+ return n
188
+
189
+ def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
190
+ ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))
191
+
192
+ if self.visualizing:
193
+ self._reset_camera()
194
+
195
+ if self.verbose:
196
+ steps = trange(self.n_steps, desc="CFM inference")
197
+ else:
198
+ steps = range(self.n_steps)
199
+
200
+ ψt = ψ0
201
+
202
+ for i in steps:
203
+ dt = ts[i + 1] - ts[i]
204
+ t = ts[i]
205
+ self._maybe_camera_snap(ψt=ψt, t=t)
206
+ ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)
207
+
208
+ self._maybe_camera_snap(ψt=ψt, t=ts[-1])
209
+
210
+ ψ1 = ψt
211
+ del ψt
212
+
213
+ self._maybe_dump_camera()
214
+
215
+ return ψ1
216
+
217
+ def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
218
+ return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
219
+
220
+
221
+ class SinusodialTimeEmbedding(nn.Module):
222
+ def __init__(self, d_embed):
223
+ super().__init__()
224
+ self.d_embed = d_embed
225
+ assert d_embed % 2 == 0
226
+
227
+ def forward(self, t):
228
+ t = t.unsqueeze(-1) # ... 1
229
+ p = torch.linspace(0, 4, self.d_embed // 2).to(t)
230
+ while p.dim() < t.dim():
231
+ p = p.unsqueeze(0) # ... d/2
232
+ sin = torch.sin(t * 10**p)
233
+ cos = torch.cos(t * 10**p)
234
+ return torch.cat([sin, cos], dim=-1)
235
+
236
+
237
+ @dataclass(eq=False)
238
+ class CFM(nn.Module):
239
+ """
240
+ This mixin is for general diffusion models.
241
+
242
+ ψ0 stands for the gaussian noise, and ψ1 is the data point.
243
+
244
+ Here we follow the CFM style:
245
+ The generation process (reverse process) is from t=0 to t=1.
246
+ The forward process is from t=1 to t=0.
247
+ """
248
+
249
+ cond_dim: int
250
+ output_dim: int
251
+ time_emb_dim: int = 128
252
+ viz_name: str = "cfm"
253
+ solver_nfe: int = 32
254
+ solver_method: str = "midpoint"
255
+ time_mapping_divisor: int = 4
256
+
257
+ def __post_init__(self):
258
+ super().__init__()
259
+ self.solver = Solver(
260
+ viz_name=self.viz_name,
261
+ viz_every=1,
262
+ nfe=self.solver_nfe,
263
+ method=self.solver_method,
264
+ time_mapping_divisor=self.time_mapping_divisor,
265
+ )
266
+ self.emb = SinusodialTimeEmbedding(self.time_emb_dim)
267
+ self.net = WN(
268
+ input_dim=self.output_dim,
269
+ output_dim=self.output_dim,
270
+ local_dim=self.cond_dim,
271
+ global_dim=self.time_emb_dim,
272
+ )
273
+
274
+ def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
275
+ """
276
+ Perturb ψ1 to ψt.
277
+ """
278
+ raise NotImplementedError
279
+
280
+ def _sample_ψ0(self, x: Tensor):
281
+ """
282
+ Args:
283
+ x: (b c t), which implies the shape of ψ0
284
+ """
285
+ shape = list(x.shape)
286
+ shape[1] = self.output_dim
287
+ if self.training:
288
+ g = None
289
+ else:
290
+ g = torch.Generator(device=x.device)
291
+ g.manual_seed(0) # deterministic sampling during eval
292
+ ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g)
293
+ return ψ0
294
+
295
+ @property
296
+ def sigma(self):
297
+ return 1e-4
298
+
299
+ def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
300
+ """
301
+ Eq (22)
302
+ """
303
+ while t.dim() < ψ1.dim():
304
+ t = t.unsqueeze(-1)
305
+ μ = t * ψ1 + (1 - t) * ψ0
306
+ return μ + torch.randn_like(μ) * self.sigma
307
+
308
+ def _to_u(self, *, ψ1, ψ0: Tensor):
309
+ """
310
+ Eq (21)
311
+ """
312
+ return ψ1 - ψ0
313
+
314
+ def _to_v(self, *, ψt, x, t: float | Tensor):
315
+ """
316
+ Args:
317
+ ψt: (b c t)
318
+ x: (b c t)
319
+ t: (b)
320
+ Returns:
321
+ v: (b c t)
322
+ """
323
+ if isinstance(t, (float, int)):
324
+ t = torch.full(ψt.shape[:1], t).to(ψt)
325
+ t = t.clamp(0, 1) # [0, 1]
326
+ g = self.emb(t) # (b d)
327
+ v = self.net(ψt, l=x, g=g)
328
+ return v
329
+
330
+ def compute_losses(self, x, y, ψ0) -> dict:
331
+ """
332
+ Args:
333
+ x: (b c t)
334
+ y: (b c t)
335
+ Returns:
336
+ losses: dict
337
+ """
338
+ t = torch.rand(len(x), device=x.device, dtype=x.dtype)
339
+ t = self.solver.time_mapping(t)
340
+
341
+ if ψ0 is None:
342
+ ψ0 = self._sample_ψ0(x)
343
+
344
+ ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)
345
+
346
+ v = self._to_v(ψt=ψt, t=t, x=x)
347
+ u = self._to_u(ψ1=y, ψ0=ψ0)
348
+
349
+ losses = dict(l1=F.l1_loss(v, u))
350
+
351
+ return losses
352
+
353
+ @torch.inference_mode()
354
+ def sample(self, x, ψ0=None, t0=0.0):
355
+ """
356
+ Args:
357
+ x: (b c t)
358
+ Returns:
359
+ y: (b ... t)
360
+ """
361
+ if ψ0 is None:
362
+ ψ0 = self._sample_ψ0(x)
363
+ f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
364
+ ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
365
+ return ψ1
366
+
367
+ def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
368
+ if y is None:
369
+ y = self.sample(x, ψ0=ψ0, t0=t0)
370
+ else:
371
+ self.losses = self.compute_losses(x, y, ψ0=ψ0)
372
+ return y
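The Solver above is a plain fixed-grid ODE integrator: `nfe` counts velocity-field evaluations, so midpoint halves and RK4 quarters the step count, and the grid is warped by `exponential_decay_mapping`. A minimal sketch on a toy field, assuming the vendored package is importable under `modules.repos_static`:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.lcfm.cfm import Solver

    # Toy autonomous field dψ/dt = -ψ; the exact solution is ψ(1) = ψ(0) * exp(-1).
    def f(*, t, ψt, dt):
        return -ψt

    solver = Solver(method="rk4", nfe=32)  # 32 evaluations -> 8 RK4 steps
    ψ0 = torch.ones(1, 1, 16)
    ψ1 = solver(f=f, ψ0=ψ0)
    print(ψ1.mean().item())  # close to exp(-1) ≈ 0.368 despite the warped time grid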
modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py ADDED
@@ -0,0 +1,123 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor, nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from ...common import Normalizer
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ @dataclass
15
+ class IRMAEOutput:
16
+ latent: Tensor # latent vector
17
+ decoded: Tensor | None # decoder output (includes the extra dim); None when decoding is skipped
18
+
19
+
20
+ class ResBlock(nn.Sequential):
21
+ def __init__(self, channels, dilations=[1, 2, 4, 8]):
22
+ wn = weight_norm
23
+ super().__init__(
24
+ nn.GroupNorm(32, channels),
25
+ nn.GELU(),
26
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
27
+ nn.GroupNorm(32, channels),
28
+ nn.GELU(),
29
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
30
+ nn.GroupNorm(32, channels),
31
+ nn.GELU(),
32
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
33
+ nn.GroupNorm(32, channels),
34
+ nn.GELU(),
35
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
36
+ )
37
+
38
+ def forward(self, x: Tensor):
39
+ return x + super().forward(x)
40
+
41
+
42
+ class IRMAE(nn.Module):
43
+ def __init__(
44
+ self,
45
+ input_dim,
46
+ output_dim,
47
+ latent_dim,
48
+ hidden_dim=1024,
49
+ num_irms=4,
50
+ ):
51
+ """
52
+ Args:
53
+ input_dim: input dimension
54
+ output_dim: output dimension
55
+ latent_dim: latent dimension
56
+ hidden_dim: hidden layer dimension
57
+ num_irms: number of implicit rank minimization (IRM) matrices
58
+ (implemented as the stacked 1x1 convolutions at the end of the encoder)
59
+ """
60
+ self.input_dim = input_dim
61
+ super().__init__()
62
+
63
+ self.encoder = nn.Sequential(
64
+ nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
65
+ *[ResBlock(hidden_dim) for _ in range(4)],
66
+ # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
67
+ *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
68
+ nn.Tanh(),
69
+ )
70
+
71
+ self.decoder = nn.Sequential(
72
+ nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
73
+ *[ResBlock(hidden_dim) for _ in range(4)],
74
+ nn.Conv1d(hidden_dim, output_dim, 1),
75
+ )
76
+
77
+ self.head = nn.Sequential(
78
+ nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
79
+ nn.GELU(),
80
+ nn.Conv1d(hidden_dim, input_dim, 1),
81
+ )
82
+
83
+ self.estimator = Normalizer()
84
+
85
+ def encode(self, x):
86
+ """
87
+ Args:
88
+ x: (b c t) tensor
89
+ """
90
+ z = self.encoder(x) # (b c t)
91
+ _ = self.estimator(z) # Estimate the global mean and std of z
92
+ self.stats = {}
93
+ self.stats["z_mean"] = z.mean().item()
94
+ self.stats["z_std"] = z.std().item()
95
+ self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
96
+ self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
97
+ self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
98
+ return z
99
+
100
+ def decode(self, z):
101
+ """
102
+ Args:
103
+ z: (b c t) tensor
104
+ """
105
+ return self.decoder(z)
106
+
107
+ def forward(self, x, skip_decoding=False):
108
+ """
109
+ Args:
110
+ x: (b c t) tensor
111
+ skip_decoding: if True, skip the decoding step
112
+ """
113
+ z = self.encode(x) # q(z|x)
114
+
115
+ if skip_decoding:
116
+ # This speeds up the training in cfm only mode
117
+ decoded = None
118
+ else:
119
+ decoded = self.decode(z) # p(x|z)
120
+ predicted = self.head(decoded)
121
+ self.losses = dict(mse=F.mse_loss(predicted, x))
122
+
123
+ return IRMAEOutput(latent=z, decoded=decoded)
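A shape-level sketch of the autoencoder (the dimensions here are hypothetical; the enhancer wires the real ones from its hparams). The stacked bias-free 1x1 convolutions implicitly minimize the rank of the latent, hence the name:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.lcfm.irmae import IRMAE

    ae = IRMAE(input_dim=128, output_dim=128, latent_dim=64).eval()
    mel = torch.randn(2, 128, 100)  # (b c t)
    with torch.no_grad():
        out = ae(mel)
    print(out.latent.shape)   # torch.Size([2, 64, 100]): rank-limited latent
    print(out.decoded.shape)  # torch.Size([2, 128, 100])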
modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py ADDED
@@ -0,0 +1,152 @@
1
+ import logging
2
+ from enum import Enum
3
+
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor, nn
8
+
9
+ from .cfm import CFM
10
+ from .irmae import IRMAE, IRMAEOutput
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def freeze_(module):
16
+ for p in module.parameters():
17
+ p.requires_grad_(False)
18
+
19
+
20
+ class LCFM(nn.Module):
21
+ class Mode(Enum):
22
+ AE = "ae"
23
+ CFM = "cfm"
24
+
25
+ def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
26
+ super().__init__()
27
+ self.ae = ae
28
+ self.cfm = cfm
29
+ self.z_scale = z_scale
30
+ self._mode = None
31
+ self._eval_tau = 0.5
32
+
33
+ @property
34
+ def mode(self):
35
+ return self._mode
36
+
37
+ def set_mode_(self, mode):
38
+ mode = self.Mode(mode)
39
+ self._mode = mode
40
+
41
+ if mode == mode.AE:
42
+ freeze_(self.cfm)
43
+ logger.info("Freeze cfm")
44
+ elif mode == mode.CFM:
45
+ freeze_(self.ae)
46
+ logger.info("Freeze ae (encoder and decoder)")
47
+ else:
48
+ raise ValueError(f"Unknown training mode: {mode}")
49
+
50
+ def get_running_train_loop(self):
51
+ try:
52
+ # Lazy import
53
+ from ...utils.train_loop import TrainLoop
54
+
55
+ return TrainLoop.get_running_loop()
56
+ except ImportError:
57
+ return None
58
+
59
+ @property
60
+ def global_step(self):
61
+ loop = self.get_running_train_loop()
62
+ if loop is None:
63
+ return None
64
+ return loop.global_step
65
+
66
+ @torch.no_grad()
67
+ def _visualize(self, x, y, y_):
68
+ loop = self.get_running_train_loop()
69
+ if loop is None:
70
+ return
71
+
72
+ plt.subplot(221)
73
+ plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
74
+ plt.title("GT")
75
+
76
+ plt.subplot(222)
77
+ y_ = y_[:, : y.shape[1]]
78
+ plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
79
+ plt.title("Posterior")
80
+
81
+ plt.subplot(223)
82
+ z_ = self.cfm(x)
83
+ y__ = self.ae.decode(z_)
84
+ y__ = y__[:, : y.shape[1]]
85
+ plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
86
+ plt.title("C-Prior")
87
+ del y__
88
+
89
+ plt.subplot(224)
90
+ z_ = torch.randn_like(z_)
91
+ y__ = self.ae.decode(z_)
92
+ y__ = y__[:, : y.shape[1]]
93
+ plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
94
+ plt.title("Prior")
95
+ del z_, y__
96
+
97
+ path = loop.make_current_step_viz_path("recon", ".png")
98
+ path.parent.mkdir(exist_ok=True, parents=True)
99
+ plt.tight_layout()
100
+ plt.savefig(path, dpi=500)
101
+ plt.close()
102
+
103
+ def _scale(self, z: Tensor):
104
+ return z * self.z_scale
105
+
106
+ def _unscale(self, z: Tensor):
107
+ return z / self.z_scale
108
+
109
+ def eval_tau_(self, tau):
110
+ self._eval_tau = tau
111
+
112
+ def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
113
+ """
114
+ Args:
115
+ x: (b d t), condition mel
116
+ y: (b d t), target mel
117
+ ψ0: (b d t), starting mel
118
+ """
119
+ if self.mode == self.Mode.CFM:
120
+ self.ae.eval() # Always set to eval when training cfm
121
+
122
+ if ψ0 is not None:
123
+ ψ0 = self._scale(self.ae.encode(ψ0))
124
+ if self.training:
125
+ tau = torch.rand_like(ψ0[:, :1, :1])
126
+ else:
127
+ tau = self._eval_tau
128
+ ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0
129
+
130
+ if y is None:
131
+ if self.mode == self.Mode.AE:
132
+ with torch.no_grad():
133
+ training = self.ae.training
134
+ self.ae.eval()
135
+ z = self.ae.encode(x)
136
+ self.ae.train(training)
137
+ else:
138
+ z = self._unscale(self.cfm(x, ψ0=ψ0))
139
+
140
+ h = self.ae.decode(z)
141
+ else:
142
+ ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
143
+
144
+ if self.mode == self.Mode.CFM:
145
+ _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
146
+
147
+ h = ae_output.decoded
148
+
149
+ if h is not None and self.global_step is not None and self.global_step % 100 == 0:
150
+ self._visualize(x[:1], y[:1], h[:1])
151
+
152
+ return h
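How the pieces compose at inference time, as a sketch with hypothetical dimensions (the real `z_scale` and channel sizes come from the enhancer's hparams): with no target `y` and no mode set, `forward` samples a latent with the CFM and decodes it with the autoencoder.

    import torch
    from modules.repos_static.resemble_enhance.enhancer.lcfm import CFM, IRMAE, LCFM

    ae = IRMAE(input_dim=128, output_dim=128, latent_dim=64)
    cfm = CFM(cond_dim=128, output_dim=64)
    lcfm = LCFM(ae, cfm, z_scale=1.0).eval()
    lcfm.eval_tau_(0.5)  # τ blends a given ψ0 toward noise (inert here, since ψ0 is None)

    mel = torch.randn(1, 128, 50)  # (b d t) condition mel
    with torch.no_grad():
        h = lcfm(mel)  # CFM samples a latent, the AE decoder maps it back
    print(h.shape)     # torch.Size([1, 128, 50])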
modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py ADDED
@@ -0,0 +1,147 @@
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ @torch.jit.script
11
+ def _fused_tanh_sigmoid(h):
12
+ a, b = h.chunk(2, dim=1)
13
+ h = a.tanh() * b.sigmoid()
14
+ return h
15
+
16
+
17
+ class WNLayer(nn.Module):
18
+ """
19
+ A DiffWave-like WN
20
+ """
21
+
22
+ def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
23
+ super().__init__()
24
+
25
+ local_output_dim = hidden_dim * 2
26
+
27
+ if global_dim is not None:
28
+ self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
29
+
30
+ if local_dim is not None:
31
+ self.lconv = nn.Conv1d(local_dim, local_output_dim, 1)
32
+
33
+ self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")
34
+
35
+ self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)
36
+
37
+ def forward(self, z, l, g):
38
+ identity = z
39
+
40
+ if g is not None:
41
+ if g.dim() == 2:
42
+ g = g.unsqueeze(-1)
43
+ z = z + self.gconv(g)
44
+
45
+ z = self.dconv(z)
46
+
47
+ if l is not None:
48
+ z = z + self.lconv(l)
49
+
50
+ z = _fused_tanh_sigmoid(z)
51
+
52
+ h = self.out(z)
53
+
54
+ z, s = h.chunk(2, dim=1)
55
+
56
+ o = (z + identity) / math.sqrt(2)
57
+
58
+ return o, s
59
+
60
+
61
+ class WN(nn.Module):
62
+ def __init__(
63
+ self,
64
+ input_dim,
65
+ output_dim,
66
+ local_dim=None,
67
+ global_dim=None,
68
+ n_layers=30,
69
+ kernel_size=3,
70
+ dilation_cycle=5,
71
+ hidden_dim=512,
72
+ ):
73
+ super().__init__()
74
+ assert kernel_size % 2 == 1
75
+ assert hidden_dim % 2 == 0
76
+
77
+ self.input_dim = input_dim
78
+ self.hidden_dim = hidden_dim
79
+ self.local_dim = local_dim
80
+ self.global_dim = global_dim
81
+
82
+ self.start = nn.Conv1d(input_dim, hidden_dim, 1)
83
+ if local_dim is not None:
84
+ self.local_norm = nn.InstanceNorm1d(local_dim)
85
+
86
+ self.layers = nn.ModuleList(
87
+ [
88
+ WNLayer(
89
+ hidden_dim=hidden_dim,
90
+ local_dim=local_dim,
91
+ global_dim=global_dim,
92
+ kernel_size=kernel_size,
93
+ dilation=2 ** (i % dilation_cycle),
94
+ )
95
+ for i in range(n_layers)
96
+ ]
97
+ )
98
+
99
+ self.end = nn.Conv1d(hidden_dim, output_dim, 1)
100
+
101
+ def forward(self, z, l=None, g=None):
102
+ """
103
+ Args:
104
+ z: input (b c t)
105
+ l: local condition (b c t)
106
+ g: global condition (b d)
107
+ """
108
+ z = self.start(z)
109
+
110
+ if l is not None:
111
+ l = self.local_norm(l)
112
+
113
+ # Skips
114
+ s_list = []
115
+
116
+ for layer in self.layers:
117
+ z, s = layer(z, l, g)
118
+ s_list.append(s)
119
+
120
+ s_list = torch.stack(s_list, dim=0).sum(dim=0)
121
+ s_list = s_list / math.sqrt(len(self.layers))
122
+
123
+ o = self.end(s_list)
124
+
125
+ return o
126
+
127
+ def summarize(self, length=100):
128
+ from ptflops import get_model_complexity_info
129
+
130
+ x = torch.randn(1, self.input_dim, length)
131
+
132
+ macs, params = get_model_complexity_info(
133
+ self,
134
+ (self.input_dim, length),
135
+ as_strings=True,
136
+ print_per_layer_stat=True,
137
+ verbose=True,
138
+ )
139
+
140
+ print(f"Input shape: {x.shape}")
141
+ print(f"Computational complexity: {macs}")
142
+ print(f"Number of parameters: {params}")
143
+
144
+
145
+ if __name__ == "__main__":
146
+ model = WN(input_dim=64, output_dim=64)
147
+ model.summarize()
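Besides the complexity summary in `__main__` above, a conditioned forward sketch showing the shapes CFM feeds it (dimensions hypothetical):

    import torch
    from modules.repos_static.resemble_enhance.enhancer.lcfm.wn import WN

    net = WN(input_dim=64, output_dim=64, local_dim=128, global_dim=32)
    z = torch.randn(1, 64, 50)   # noisy latent (b c t)
    l = torch.randn(1, 128, 50)  # local condition, e.g. mel frames
    g = torch.randn(1, 32)       # global condition, e.g. a time embedding
    print(net(z, l=l, g=g).shape)  # torch.Size([1, 64, 50])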
modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .univnet import UnivNet
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if 'sinc' in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(x == 0,
21
+ torch.tensor(1., device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x)
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ # LICENSE is in incl_licenses directory.
28
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29
+ even = (kernel_size % 2 == 0)
30
+ half_size = kernel_size // 2
31
+
32
+ #For kaiser window
33
+ delta_f = 4 * half_width
34
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
+ if A > 50.:
36
+ beta = 0.1102 * (A - 8.7)
37
+ elif A >= 21.:
38
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39
+ else:
40
+ beta = 0.
41
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
+
43
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44
+ if even:
45
+ time = (torch.arange(-half_size, half_size) + 0.5)
46
+ else:
47
+ time = torch.arange(kernel_size) - half_size
48
+ if cutoff == 0:
49
+ filter_ = torch.zeros_like(time)
50
+ else:
51
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
53
+ # of the constant component in the input signal.
54
+ filter_ /= filter_.sum()
55
+ filter = filter_.view(1, 1, kernel_size)
56
+
57
+ return filter
58
+
59
+
60
+ class LowPassFilter1d(nn.Module):
61
+ def __init__(self,
62
+ cutoff=0.5,
63
+ half_width=0.6,
64
+ stride: int = 1,
65
+ padding: bool = True,
66
+ padding_mode: str = 'replicate',
67
+ kernel_size: int = 12):
68
+ # kernel_size should be even number for stylegan3 setup,
69
+ # in this implementation, odd number is also possible.
70
+ super().__init__()
71
+ if cutoff < -0.:
72
+ raise ValueError("Minimum cutoff must be larger than zero.")
73
+ if cutoff > 0.5:
74
+ raise ValueError("A cutoff above 0.5 does not make sense.")
75
+ self.kernel_size = kernel_size
76
+ self.even = (kernel_size % 2 == 0)
77
+ self.pad_left = kernel_size // 2 - int(self.even)
78
+ self.pad_right = kernel_size // 2
79
+ self.stride = stride
80
+ self.padding = padding
81
+ self.padding_mode = padding_mode
82
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83
+ self.register_buffer("filter", filter)
84
+
85
+ #input [B, C, T]
86
+ def forward(self, x):
87
+ _, C, _ = x.shape
88
+
89
+ if self.padding:
90
+ x = F.pad(x, (self.pad_left, self.pad_right),
91
+ mode=self.padding_mode)
92
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
93
+ stride=self.stride, groups=C)
94
+
95
+ return out
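A quick sketch of the filter in isolation; `cutoff` is in cycles per sample, so 0.25 keeps roughly the lower half of the band, and replicate padding preserves length:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.univnet.alias_free_torch.filter import (
        LowPassFilter1d,
    )

    lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    x = torch.randn(1, 1, 1000)
    print(lpf(x).shape)  # torch.Size([1, 1, 1000]): same length, band-limited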
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ self.stride = ratio
16
+ self.pad = self.kernel_size // ratio - 1
17
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
+ half_width=0.6 / ratio,
21
+ kernel_size=self.kernel_size)
22
+ self.register_buffer("filter", filter)
23
+
24
+ # x: [B, C, T]
25
+ def forward(self, x):
26
+ _, C, _ = x.shape
27
+
28
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
+ x = self.ratio * F.conv_transpose1d(
30
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
+ x = x[..., self.pad_left:-self.pad_right]
32
+
33
+ return x
34
+
35
+
36
+ class DownSample1d(nn.Module):
37
+ def __init__(self, ratio=2, kernel_size=None):
38
+ super().__init__()
39
+ self.ratio = ratio
40
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
+ half_width=0.6 / ratio,
43
+ stride=ratio,
44
+ kernel_size=self.kernel_size)
45
+
46
+ def forward(self, x):
47
+ xx = self.lowpass(x)
48
+
49
+ return xx
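The two modules are designed to be length-exact inverses, which the anti-aliased activation wrapper in the next file relies on; a round-trip sketch:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.univnet.alias_free_torch.resample import (
        DownSample1d,
        UpSample1d,
    )

    up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
    x = torch.randn(1, 4, 256)
    print(up(x).shape)        # torch.Size([1, 4, 512])
    print(down(up(x)).shape)  # torch.Size([1, 4, 256]): back to the input length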
modules/repos_static/resemble_enhance/enhancer/univnet/amp.py ADDED
@@ -0,0 +1,101 @@
1
+ # Adapted from https://github.com/NVIDIA/BigVGAN
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import nn
8
+ from torch.nn.utils.parametrizations import weight_norm
9
+
10
+ from .alias_free_torch import DownSample1d, UpSample1d
11
+
12
+
13
+ class SnakeBeta(nn.Module):
14
+ """
15
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
16
+ Shape:
17
+ - Input: (B, C, T)
18
+ - Output: (B, C, T), same shape as the input
19
+ Parameters:
20
+ - alpha - trainable parameter that controls frequency
21
+ - beta - trainable parameter that controls magnitude
22
+ References:
23
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
24
+ https://arxiv.org/abs/2006.08195
25
+ Examples:
26
+ >>> a1 = SnakeBeta(256)
27
+ >>> x = torch.randn(2, 256, 100)  # (B, C, T)
28
+ >>> x = a1(x)
29
+ """
30
+
31
+ def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
32
+ """
33
+ Initialization.
34
+ INPUT:
35
+ - in_features: shape of the input
36
+ - alpha - trainable parameter that controls frequency
37
+ - beta - trainable parameter that controls magnitude
38
+ alpha is initialized to 1 by default, higher values = higher-frequency.
39
+ beta is initialized to 1 by default, higher values = higher-magnitude.
40
+ alpha will be trained along with the rest of your model.
41
+ """
42
+ super().__init__()
43
+ self.in_features = in_features
44
+ self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
45
+ self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
46
+ self.clamp = clamp
47
+
48
+ def forward(self, x):
49
+ """
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ SnakeBeta(x) := x + (1 / beta) * sin^2(alpha * x)
53
+ """
54
+ alpha = self.log_alpha.exp().clamp(*self.clamp)
55
+ alpha = alpha[None, :, None]
56
+
57
+ beta = self.log_beta.exp().clamp(*self.clamp)
58
+ beta = beta[None, :, None]
59
+
60
+ x = x + (1.0 / beta) * (x * alpha).sin().pow(2)
61
+
62
+ return x
63
+
64
+
65
+ class UpActDown(nn.Module):
66
+ def __init__(
67
+ self,
68
+ act,
69
+ up_ratio: int = 2,
70
+ down_ratio: int = 2,
71
+ up_kernel_size: int = 12,
72
+ down_kernel_size: int = 12,
73
+ ):
74
+ super().__init__()
75
+ self.up_ratio = up_ratio
76
+ self.down_ratio = down_ratio
77
+ self.act = act
78
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
79
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
80
+
81
+ def forward(self, x):
82
+ # x: [B,C,T]
83
+ x = self.upsample(x)
84
+ x = self.act(x)
85
+ x = self.downsample(x)
86
+ return x
87
+
88
+
89
+ class AMPBlock(nn.Sequential):
90
+ def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
91
+ super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations))
92
+
93
+ def _make_layer(self, channels, kernel_size, dilation):
94
+ return nn.Sequential(
95
+ weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")),
96
+ UpActDown(act=SnakeBeta(channels)),
97
+ weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")),
98
+ )
99
+
100
+ def forward(self, x):
101
+ return x + super().forward(x)
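A sketch of the anti-aliased activation pattern: SnakeBeta is applied at twice the sample rate (upsample, activate, downsample) so its harmonics do not alias, and AMPBlock keeps shapes residual-compatible. Channel and length values below are hypothetical:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.univnet.amp import AMPBlock, SnakeBeta

    x = torch.randn(2, 8, 128)    # (B, C, T)
    print(SnakeBeta(8)(x).shape)  # torch.Size([2, 8, 128])

    block = AMPBlock(8, kernel_size=3, dilations=(1, 3, 5))
    print(block(x).shape)         # torch.Size([2, 8, 128]): residual, shape-preserving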
modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py ADDED
@@ -0,0 +1,210 @@
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import Tensor, nn
6
+ from torch.nn.utils.parametrizations import weight_norm
7
+
8
+ from ..hparams import HParams
9
+ from .mrstft import get_stft_cfgs
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class PeriodNetwork(nn.Module):
15
+ def __init__(self, period):
16
+ super().__init__()
17
+ self.period = period
18
+ wn = weight_norm
19
+ self.convs = nn.ModuleList(
20
+ [
21
+ wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
22
+ wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
23
+ wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
24
+ wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
25
+ wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
26
+ ]
27
+ )
28
+ self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
29
+
30
+ def forward(self, x):
31
+ """
32
+ Args:
33
+ x: [B, 1, T]
34
+ """
35
+ assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
36
+
37
+ # 1d to 2d
38
+ b, c, t = x.shape
39
+ if t % self.period != 0: # pad first
40
+ n_pad = self.period - (t % self.period)
41
+ x = F.pad(x, (0, n_pad), "reflect")
42
+ t = t + n_pad
43
+ x = x.view(b, c, t // self.period, self.period)
44
+
45
+ for l in self.convs:
46
+ x = l(x)
47
+ x = F.leaky_relu(x, 0.2)
48
+ x = self.conv_post(x)
49
+ x = torch.flatten(x, 1, -1)
50
+
51
+ return x
52
+
53
+
54
+ class SpecNetwork(nn.Module):
55
+ def __init__(self, stft_cfg: dict):
56
+ super().__init__()
57
+ wn = weight_norm
58
+ self.stft_cfg = stft_cfg
59
+ self.convs = nn.ModuleList(
60
+ [
61
+ wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
62
+ wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
63
+ wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
64
+ wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
65
+ wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
66
+ ]
67
+ )
68
+ self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
69
+
70
+ def forward(self, x):
71
+ """
72
+ Args:
73
+ x: [B, 1, T]
74
+ """
75
+ x = self.spectrogram(x)
76
+ x = x.unsqueeze(1)
77
+ for l in self.convs:
78
+ x = l(x)
79
+ x = F.leaky_relu(x, 0.2)
80
+ x = self.conv_post(x)
81
+ x = x.flatten(1, -1)
82
+ return x
83
+
84
+ def spectrogram(self, x):
85
+ """
86
+ Args:
87
+ x: [B, 1, T]
88
+ """
89
+ x = x.squeeze(1)
90
+ dtype = x.dtype
91
+ stft_cfg = dict(self.stft_cfg)
92
+ x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
93
+ mag = x.norm(p=2, dim=-1) # [B, F, TT]
94
+ mag = mag.to(dtype) # [B, F, TT]
95
+ return mag
96
+
97
+
98
+ class MD(nn.ModuleList):
99
+ def __init__(self, l: list):
100
+ super().__init__([self._create_network(x) for x in l])
101
+ self._loss_type = None
102
+
103
+ def loss_type_(self, loss_type):
104
+ self._loss_type = loss_type
105
+
106
+ def _create_network(self, _):
107
+ raise NotImplementedError
108
+
109
+ def _forward_each(self, d, x, y):
110
+ assert self._loss_type is not None, "loss_type is not set."
111
+ loss_type = self._loss_type
112
+
113
+ if loss_type == "hinge":
114
+ if y == 0:
115
+ # d(x) should be small -> -1
116
+ loss = F.relu(1 + d(x)).mean()
117
+ elif y == 1:
118
+ # d(x) should be large -> 1
119
+ loss = F.relu(1 - d(x)).mean()
120
+ else:
121
+ raise ValueError(f"Invalid y: {y}")
122
+ elif loss_type == "wgan":
123
+ if y == 0:
124
+ loss = d(x).mean()
125
+ elif y == 1:
126
+ loss = -d(x).mean()
127
+ else:
128
+ raise ValueError(f"Invalid y: {y}")
129
+ else:
130
+ raise ValueError(f"Invalid loss_type: {loss_type}")
131
+
132
+ return loss
133
+
134
+ def forward(self, x, y) -> Tensor:
135
+ losses = [self._forward_each(d, x, y) for d in self]
136
+ return torch.stack(losses).mean()
137
+
138
+
139
+ class MPD(MD):
140
+ def __init__(self):
141
+ super().__init__([2, 3, 7, 13, 17])
142
+
143
+ def _create_network(self, period):
144
+ return PeriodNetwork(period)
145
+
146
+
147
+ class MRD(MD):
148
+ def __init__(self, stft_cfgs):
149
+ super().__init__(stft_cfgs)
150
+
151
+ def _create_network(self, stft_cfg):
152
+ return SpecNetwork(stft_cfg)
153
+
154
+
155
+ class Discriminator(nn.Module):
156
+ @property
157
+ def wav_rate(self):
158
+ return self.hp.wav_rate
159
+
160
+ def __init__(self, hp: HParams):
161
+ super().__init__()
162
+ self.hp = hp
163
+ self.stft_cfgs = get_stft_cfgs(hp)
164
+ self.mpd = MPD()
165
+ self.mrd = MRD(self.stft_cfgs)
166
+ self.dummy_float: Tensor
167
+ self.register_buffer("dummy_float", torch.zeros(0), persistent=False)
168
+
169
+ def loss_type_(self, loss_type):
170
+ self.mpd.loss_type_(loss_type)
171
+ self.mrd.loss_type_(loss_type)
172
+
173
+ def forward(self, fake, real=None):
174
+ """
175
+ Args:
176
+ fake: [B T]
177
+ real: [B T]
178
+ """
179
+ fake = fake.to(self.dummy_float)
180
+
181
+ if real is None:
182
+ self.loss_type_("wgan")
183
+ else:
184
+ length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
185
+ assert length_difference < 0.05, f"length_difference should be smaller than 5%, got {length_difference:.1%}"
186
+
187
+ self.loss_type_("hinge")
188
+ real = real.to(self.dummy_float)
189
+
190
+ fake = fake[..., : real.shape[-1]]
191
+ real = real[..., : fake.shape[-1]]
192
+
193
+ losses = {}
194
+
195
+ assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
196
+ assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."
197
+
198
+ fake = fake.unsqueeze(1)
199
+
200
+ if real is None:
201
+ losses["mpd"] = self.mpd(fake, 1)
202
+ losses["mrd"] = self.mrd(fake, 1)
203
+ else:
204
+ real = real.unsqueeze(1)
205
+ losses["mpd_fake"] = self.mpd(fake, 0)
206
+ losses["mpd_real"] = self.mpd(real, 1)
207
+ losses["mrd_fake"] = self.mrd(fake, 0)
208
+ losses["mrd_real"] = self.mrd(real, 1)
209
+
210
+ return losses
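A sketch of the two calling modes (the real pipeline passes the enhancer-level HParams subclass; the base one suffices here because only `wav_rate` is consulted): with a real waveform it returns hinge losses for discriminator training, without one it returns WGAN-style generator losses.

    import torch
    from modules.repos_static.resemble_enhance.enhancer.univnet.discriminator import Discriminator
    from modules.repos_static.resemble_enhance.hparams import HParams

    disc = Discriminator(HParams())
    fake, real = torch.randn(1, 44100), torch.randn(1, 44100)
    d_losses = disc(fake.detach(), real)  # hinge: mpd/mrd x fake/real terms
    g_losses = disc(fake)                 # wgan-style: mpd, mrd
    print(sorted(d_losses), sorted(g_losses))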
modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py ADDED
@@ -0,0 +1,281 @@
1
+ """ refer from https://github.com/zceng/LVCNet """
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from .amp import AMPBlock
10
+
11
+
12
+ class KernelPredictor(torch.nn.Module):
13
+ """Kernel predictor for the location-variable convolutions"""
14
+
15
+ def __init__(
16
+ self,
17
+ cond_channels,
18
+ conv_in_channels,
19
+ conv_out_channels,
20
+ conv_layers,
21
+ conv_kernel_size=3,
22
+ kpnet_hidden_channels=64,
23
+ kpnet_conv_size=3,
24
+ kpnet_dropout=0.0,
25
+ kpnet_nonlinear_activation="LeakyReLU",
26
+ kpnet_nonlinear_activation_params={"negative_slope": 0.1},
27
+ ):
28
+ """
29
+ Args:
30
+ cond_channels (int): number of channel for the conditioning sequence,
31
+ conv_in_channels (int): number of channel for the input sequence,
32
+ conv_out_channels (int): number of channel for the output sequence,
33
+ conv_layers (int): number of layers
34
+ """
35
+ super().__init__()
36
+
37
+ self.conv_in_channels = conv_in_channels
38
+ self.conv_out_channels = conv_out_channels
39
+ self.conv_kernel_size = conv_kernel_size
40
+ self.conv_layers = conv_layers
41
+
42
+ kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
43
+ kpnet_bias_channels = conv_out_channels * conv_layers # l_b
44
+
45
+ self.input_conv = nn.Sequential(
46
+ weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
47
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
48
+ )
49
+
50
+ self.residual_convs = nn.ModuleList()
51
+ padding = (kpnet_conv_size - 1) // 2
52
+ for _ in range(3):
53
+ self.residual_convs.append(
54
+ nn.Sequential(
55
+ nn.Dropout(kpnet_dropout),
56
+ weight_norm(
57
+ nn.Conv1d(
58
+ kpnet_hidden_channels,
59
+ kpnet_hidden_channels,
60
+ kpnet_conv_size,
61
+ padding=padding,
62
+ bias=True,
63
+ )
64
+ ),
65
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
66
+ weight_norm(
67
+ nn.Conv1d(
68
+ kpnet_hidden_channels,
69
+ kpnet_hidden_channels,
70
+ kpnet_conv_size,
71
+ padding=padding,
72
+ bias=True,
73
+ )
74
+ ),
75
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
76
+ )
77
+ )
78
+ self.kernel_conv = weight_norm(
79
+ nn.Conv1d(
80
+ kpnet_hidden_channels,
81
+ kpnet_kernel_channels,
82
+ kpnet_conv_size,
83
+ padding=padding,
84
+ bias=True,
85
+ )
86
+ )
87
+ self.bias_conv = weight_norm(
88
+ nn.Conv1d(
89
+ kpnet_hidden_channels,
90
+ kpnet_bias_channels,
91
+ kpnet_conv_size,
92
+ padding=padding,
93
+ bias=True,
94
+ )
95
+ )
96
+
97
+ def forward(self, c):
98
+ """
99
+ Args:
100
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
101
+ """
102
+ batch, _, cond_length = c.shape
103
+ c = self.input_conv(c)
104
+ for residual_conv in self.residual_convs:
105
+ residual_conv.to(c.device)
106
+ c = c + residual_conv(c)
107
+ k = self.kernel_conv(c)
108
+ b = self.bias_conv(c)
109
+ kernels = k.contiguous().view(
110
+ batch,
111
+ self.conv_layers,
112
+ self.conv_in_channels,
113
+ self.conv_out_channels,
114
+ self.conv_kernel_size,
115
+ cond_length,
116
+ )
117
+ bias = b.contiguous().view(
118
+ batch,
119
+ self.conv_layers,
120
+ self.conv_out_channels,
121
+ cond_length,
122
+ )
123
+
124
+ return kernels, bias
125
+
126
+
127
+ class LVCBlock(torch.nn.Module):
128
+ """the location-variable convolutions"""
129
+
130
+ def __init__(
131
+ self,
132
+ in_channels,
133
+ cond_channels,
134
+ stride,
135
+ dilations=[1, 3, 9, 27],
136
+ lReLU_slope=0.2,
137
+ conv_kernel_size=3,
138
+ cond_hop_length=256,
139
+ kpnet_hidden_channels=64,
140
+ kpnet_conv_size=3,
141
+ kpnet_dropout=0.0,
142
+ add_extra_noise=False,
143
+ downsampling=False,
144
+ ):
145
+ super().__init__()
146
+
147
+ self.add_extra_noise = add_extra_noise
148
+
149
+ self.cond_hop_length = cond_hop_length
150
+ self.conv_layers = len(dilations)
151
+ self.conv_kernel_size = conv_kernel_size
152
+
153
+ self.kernel_predictor = KernelPredictor(
154
+ cond_channels=cond_channels,
155
+ conv_in_channels=in_channels,
156
+ conv_out_channels=2 * in_channels,
157
+ conv_layers=len(dilations),
158
+ conv_kernel_size=conv_kernel_size,
159
+ kpnet_hidden_channels=kpnet_hidden_channels,
160
+ kpnet_conv_size=kpnet_conv_size,
161
+ kpnet_dropout=kpnet_dropout,
162
+ kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
163
+ )
164
+
165
+ if downsampling:
166
+ self.convt_pre = nn.Sequential(
167
+ nn.LeakyReLU(lReLU_slope),
168
+ weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
169
+ nn.AvgPool1d(stride, stride),
170
+ )
171
+ else:
172
+ if stride == 1:
173
+ self.convt_pre = nn.Sequential(
174
+ nn.LeakyReLU(lReLU_slope),
175
+ weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
176
+ )
177
+ else:
178
+ self.convt_pre = nn.Sequential(
179
+ nn.LeakyReLU(lReLU_slope),
180
+ weight_norm(
181
+ nn.ConvTranspose1d(
182
+ in_channels,
183
+ in_channels,
184
+ 2 * stride,
185
+ stride=stride,
186
+ padding=stride // 2 + stride % 2,
187
+ output_padding=stride % 2,
188
+ )
189
+ ),
190
+ )
191
+
192
+ self.amp_block = AMPBlock(in_channels)
193
+
194
+ self.conv_blocks = nn.ModuleList()
195
+ for d in dilations:
196
+ self.conv_blocks.append(
197
+ nn.Sequential(
198
+ nn.LeakyReLU(lReLU_slope),
199
+ weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
200
+ nn.LeakyReLU(lReLU_slope),
201
+ )
202
+ )
203
+
204
+ def forward(self, x, c):
205
+ """forward propagation of the location-variable convolutions.
206
+ Args:
207
+ x (Tensor): the input sequence (batch, in_channels, in_length)
208
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
209
+
210
+ Returns:
211
+ Tensor: the output sequence (batch, in_channels, in_length)
212
+ """
213
+ _, in_channels, _ = x.shape # (B, c_g, L')
214
+
215
+ x = self.convt_pre(x) # (B, c_g, stride * L')
216
+
217
+ # Add one amp block just after the upsampling
218
+ x = self.amp_block(x) # (B, c_g, stride * L')
219
+
220
+ kernels, bias = self.kernel_predictor(c)
221
+
222
+ if self.add_extra_noise:
223
+ # Add extra noise to part of the feature
224
+ a, b = x.chunk(2, dim=1)
225
+ b = b + torch.randn_like(b) * 0.1
226
+ x = torch.cat([a, b], dim=1)
227
+
228
+ for i, conv in enumerate(self.conv_blocks):
229
+ output = conv(x) # (B, c_g, stride * L')
230
+
231
+ k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
232
+ b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
233
+
234
+ output = self.location_variable_convolution(
235
+ output, k, b, hop_size=self.cond_hop_length
236
+ ) # (B, 2 * c_g, stride * L'): LVC
237
+ x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
238
+ output[:, in_channels:, :]
239
+ ) # (B, c_g, stride * L'): GAU
240
+
241
+ return x
242
+
243
+ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
244
+ """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
245
+ Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
246
+ Args:
247
+ x (Tensor): the input sequence (batch, in_channels, in_length).
248
+ kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
249
+ bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
250
+ dilation (int): the dilation of convolution.
251
+ hop_size (int): the hop_size of the conditioning sequence.
252
+ Returns:
253
+ (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
254
+ """
255
+ batch, _, in_length = x.shape
256
+ batch, _, out_channels, kernel_size, kernel_length = kernel.shape
257
+
258
+ assert in_length == (
259
+ kernel_length * hop_size
260
+ ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
261
+
262
+ padding = dilation * int((kernel_size - 1) / 2)
263
+ x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
264
+ x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
265
+
266
+ if hop_size < dilation:
267
+ x = F.pad(x, (0, dilation), "constant", 0)
268
+ x = x.unfold(
269
+ 3, dilation, dilation
270
+ ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
271
+ x = x[:, :, :, :, :hop_size]
272
+ x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
273
+ x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
274
+
275
+ o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
276
+ o = o.to(memory_format=torch.channels_last_3d)
277
+ bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
278
+ o = o + bias
279
+ o = o.contiguous().view(batch, out_channels, -1)
280
+
281
+ return o
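A single-block sketch: `stride` upsamples the signal while `cond_hop_length` must equal the number of output samples per conditioning frame at that layer (UnivNet passes the cumulative product of its strides). Values below are hypothetical:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.univnet.lvcnet import LVCBlock

    block = LVCBlock(in_channels=16, cond_channels=64, stride=7, cond_hop_length=7)
    x = torch.randn(1, 16, 32)  # (B, c_g, L')
    c = torch.randn(1, 64, 32)  # one conditioning frame per input step
    print(block(x, c).shape)    # torch.Size([1, 16, 224]) = stride * L'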
modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py ADDED
@@ -0,0 +1,128 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from ..hparams import HParams
12
+
13
+
14
+ def _make_stft_cfg(hop_length, win_length=None):
15
+ if win_length is None:
16
+ win_length = 4 * hop_length
17
+ n_fft = 2 ** (win_length - 1).bit_length()
18
+ return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
19
+
20
+
21
+ def get_stft_cfgs(hp: HParams):
22
+ assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
23
+ return [_make_stft_cfg(h) for h in (100, 256, 512)]
24
+
25
+
26
+ def stft(x, n_fft, hop_length, win_length, window):
27
+ dtype = x.dtype
28
+ x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
29
+ x = x.abs().to(dtype)
30
+ x = x.transpose(2, 1) # (b f t) -> (b t f)
31
+ return x
32
+
33
+
34
+ class SpectralConvergengeLoss(nn.Module):
35
+ def forward(self, x_mag, y_mag):
36
+ """Calculate forward propagation.
37
+ Args:
38
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
39
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
40
+ Returns:
41
+ Tensor: Spectral convergence loss value.
42
+ """
43
+ return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
44
+
45
+
46
+ class LogSTFTMagnitudeLoss(nn.Module):
47
+ def forward(self, x_mag, y_mag):
48
+ """Calculate forward propagation.
49
+ Args:
50
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
51
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
52
+ Returns:
53
+ Tensor: Log STFT magnitude loss value.
54
+ """
55
+ return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag))
56
+
57
+
58
+ class STFTLoss(nn.Module):
59
+ def __init__(self, hp, stft_cfg: dict, window="hann_window"):
60
+ super().__init__()
61
+ self.hp = hp
62
+ self.stft_cfg = stft_cfg
63
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
64
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
65
+ self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False)
66
+
67
+ def forward(self, x, y):
68
+ """Calculate forward propagation.
69
+ Args:
70
+ x (Tensor): Predicted signal (B, T).
71
+ y (Tensor): Groundtruth signal (B, T).
72
+ Returns:
73
+ dict: spectral convergence loss ("sc") and
74
+ log STFT magnitude loss ("mag").
75
+ """
76
+ stft_cfg = dict(self.stft_cfg)
77
+ x_mag = stft(x, **stft_cfg, window=self.window) # (b t) -> (b t f)
78
+ y_mag = stft(y, **stft_cfg, window=self.window)
79
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
80
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
81
+ return dict(sc=sc_loss, mag=mag_loss)
82
+
83
+
84
+ class MRSTFTLoss(nn.Module):
85
+ def __init__(self, hp: HParams, window="hann_window"):
86
+ """Initialize Multi resolution STFT loss module.
87
+ Args:
88
+ hp (HParams): hyperparameters; the STFT resolutions are derived from them via get_stft_cfgs.
89
+ window (str): Window function type.
90
+ """
91
+ super().__init__()
92
+ stft_cfgs = get_stft_cfgs(hp)
93
+ self.stft_losses = nn.ModuleList()
94
+ self.hp = hp
95
+ for c in stft_cfgs:
96
+ self.stft_losses += [STFTLoss(hp, c, window=window)]
97
+
98
+ def forward(self, x, y):
99
+ """Calculate forward propagation.
100
+ Args:
101
+ x (Tensor): Predicted signal (b t).
102
+ y (Tensor): Groundtruth signal (b t).
103
+ Returns:
104
+ Tensor: Multi resolution spectral convergence loss value.
105
+ Tensor: Multi resolution log STFT magnitude loss value.
106
+ """
107
+ assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."
108
+
109
+ dtype = x.dtype
110
+
111
+ x = x.float()
112
+ y = y.float()
113
+
114
+ # Align length
115
+ x = x[..., : y.shape[-1]]
116
+ y = y[..., : x.shape[-1]]
117
+
118
+ losses = {}
119
+
120
+ for f in self.stft_losses:
121
+ d = f(x, y)
122
+ for k, v in d.items():
123
+ losses.setdefault(k, []).append(v)
124
+
125
+ for k, v in losses.items():
126
+ losses[k] = torch.stack(v, dim=0).mean().to(dtype)
127
+
128
+ return losses
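Usage sketch; the loss dict aggregates the three resolutions (hops 100, 256, 512 at 44.1 kHz), and the base HParams works here because only `wav_rate` is read:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.univnet.mrstft import MRSTFTLoss
    from modules.repos_static.resemble_enhance.hparams import HParams

    loss_fn = MRSTFTLoss(HParams())
    pred, target = torch.randn(2, 44100), torch.randn(2, 44100)
    losses = loss_fn(pred, target)
    print(losses["sc"].item(), losses["mag"].item())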
modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py ADDED
@@ -0,0 +1,94 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import Tensor, nn
5
+ from torch.nn.utils.parametrizations import weight_norm
6
+
7
+ from ..hparams import HParams
8
+ from .lvcnet import LVCBlock
9
+ from .mrstft import MRSTFTLoss
10
+
11
+
12
+ class UnivNet(nn.Module):
13
+ @property
14
+ def d_noise(self):
15
+ return 128
16
+
17
+ @property
18
+ def strides(self):
19
+ return [7, 5, 4, 3]
20
+
21
+ @property
22
+ def dilations(self):
23
+ return [1, 3, 9, 27]
24
+
25
+ @property
26
+ def nc(self):
27
+ return self.hp.univnet_nc
28
+
29
+ @property
30
+ def scale_factor(self) -> int:
31
+ return self.hp.hop_size
32
+
33
+ def __init__(self, hp: HParams, d_input):
34
+ super().__init__()
35
+ self.d_input = d_input
36
+
37
+ self.hp = hp
38
+
39
+ self.blocks = nn.ModuleList(
40
+ [
41
+ LVCBlock(
42
+ self.nc,
43
+ d_input,
44
+ stride=stride,
45
+ dilations=self.dilations,
46
+ cond_hop_length=hop_length,
47
+ kpnet_conv_size=3,
48
+ )
49
+ for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
50
+ ]
51
+ )
52
+
53
+ self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
54
+
55
+ self.conv_post = nn.Sequential(
56
+ nn.LeakyReLU(0.2),
57
+ weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
58
+ nn.Tanh(),
59
+ )
60
+
61
+ self.mrstft = MRSTFTLoss(hp)
62
+
63
+ @property
64
+ def eps(self):
65
+ return 1e-5
66
+
67
+ def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
68
+ """
69
+ Args:
70
+ x: (b c t), acoustic features
71
+ y: (b t), waveform
72
+ Returns:
73
+ z: (b t), waveform
74
+ """
75
+ assert x.ndim == 3, "x must be 3D tensor"
76
+ assert y is None or y.ndim == 2, "y must be 2D tensor"
77
+ assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
78
+ assert npad >= 0, "npad must be positive or zero"
79
+
80
+ x = F.pad(x, (0, npad), "constant", 0)
81
+ z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
82
+ z = self.conv_pre(z) # (b c t)
83
+
84
+ for block in self.blocks:
85
+ z = block(z, x) # (b c t)
86
+
87
+ z = self.conv_post(z) # (b 1 t)
88
+ z = z[..., : -self.scale_factor * npad]
89
+ z = z.squeeze(1) # (b t)
90
+
91
+ if y is not None:
92
+ self.losses = self.mrstft(z, y)
93
+
94
+ return z
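A vocoding sketch. One assumption to flag: `univnet_nc` lives on the enhancer-level HParams (enhancer/hparams.py, not shown in this excerpt), so the default construction below presumes that subclass provides it. The four strides multiply to 420, matching `hop_size`, so each input frame becomes 420 samples:

    import torch
    from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
    from modules.repos_static.resemble_enhance.enhancer.univnet import UnivNet

    hp = HParams()  # assumed to default-construct with univnet_nc set
    vocoder = UnivNet(hp, d_input=hp.num_mels).eval()
    mel = torch.randn(1, hp.num_mels, 64)  # (b c t)
    with torch.no_grad():
        wav = vocoder(mel)
    print(wav.shape)  # torch.Size([1, 26880]) = (1, 64 * 420)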
modules/repos_static/resemble_enhance/hparams.py ADDED
@@ -0,0 +1,128 @@
1
+ import logging
2
+ from dataclasses import asdict, dataclass
3
+ from pathlib import Path
4
+
5
+ from omegaconf import OmegaConf
6
+ from rich.console import Console
7
+ from rich.panel import Panel
8
+ from rich.table import Table
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ console = Console()
13
+
14
+
15
+ def _make_stft_cfg(hop_length, win_length=None):
16
+ if win_length is None:
17
+ win_length = 4 * hop_length
18
+ n_fft = 2 ** (win_length - 1).bit_length()
19
+ return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
20
+
21
+
22
+ def _build_rich_table(rows, columns, title=None):
23
+ table = Table(title=title, header_style=None)
24
+ for column in columns:
25
+ table.add_column(column.capitalize(), justify="left")
26
+ for row in rows:
27
+ table.add_row(*map(str, row))
28
+ return Panel(table, expand=False)
29
+
30
+
31
+ def _rich_print_dict(d, title="Config", key="Key", value="Value"):
32
+ console.print(_build_rich_table(d.items(), [key, value], title))
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class HParams:
37
+ # Dataset
38
+ fg_dir: Path = Path("data/fg")
39
+ bg_dir: Path = Path("data/bg")
40
+ rir_dir: Path = Path("data/rir")
41
+ load_fg_only: bool = False
42
+ praat_augment_prob: float = 0
43
+
44
+ # Audio settings
45
+ wav_rate: int = 44_100
46
+ n_fft: int = 2048
47
+ win_size: int = 2048
48
+ hop_size: int = 420 # 9.5ms
49
+ num_mels: int = 128
50
+ stft_magnitude_min: float = 1e-4
51
+ preemphasis: float = 0.97
52
+ mix_alpha_range: tuple[float, float] = (0.2, 0.8)
53
+
54
+ # Training
55
+ nj: int = 64
56
+ training_seconds: float = 1.0
57
+ batch_size_per_gpu: int = 16
58
+ min_lr: float = 1e-5
59
+ max_lr: float = 1e-4
60
+ warmup_steps: int = 1000
61
+ max_steps: int = 1_000_000
62
+ gradient_clipping: float = 1.0
63
+
64
+ @property
65
+ def deepspeed_config(self):
66
+ return {
67
+ "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
68
+ "optimizer": {
69
+ "type": "Adam",
70
+ "params": {"lr": float(self.min_lr)},
71
+ },
72
+ "scheduler": {
73
+ "type": "WarmupDecayLR",
74
+ "params": {
75
+ "warmup_min_lr": float(self.min_lr),
76
+ "warmup_max_lr": float(self.max_lr),
77
+ "warmup_num_steps": self.warmup_steps,
78
+ "total_num_steps": self.max_steps,
79
+ "warmup_type": "linear",
80
+ },
81
+ },
82
+ "gradient_clipping": self.gradient_clipping,
83
+ }
84
+
85
+ @property
86
+ def stft_cfgs(self):
87
+ assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
88
+ return [_make_stft_cfg(h) for h in (100, 256, 512)]
89
+
90
+ @classmethod
91
+ def from_yaml(cls, path: Path) -> "HParams":
92
+ logger.info(f"Reading hparams from {path}")
93
+ # First merge to fix types (e.g., str -> Path)
94
+ return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
95
+
96
+ def save_if_not_exists(self, run_dir: Path):
97
+ path = run_dir / "hparams.yaml"
98
+ if path.exists():
99
+ logger.info(f"{path} already exists, not saving")
100
+ return
101
+ path.parent.mkdir(parents=True, exist_ok=True)
102
+ OmegaConf.save(asdict(self), str(path))
103
+
104
+ @classmethod
105
+ def load(cls, run_dir, yaml: Path | None = None):
106
+ hps = []
107
+
108
+ if (run_dir / "hparams.yaml").exists():
109
+ hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
110
+
111
+ if yaml is not None:
112
+ hps.append(cls.from_yaml(yaml))
113
+
114
+ if len(hps) == 0:
115
+ hps.append(cls())
116
+
117
+ for hp in hps[1:]:
118
+ if hp != hps[0]:
119
+ errors = {}
120
+ for k, v in asdict(hp).items():
121
+ if getattr(hps[0], k) != v:
122
+ errors[k] = f"{getattr(hps[0], k)} != {v}"
123
+ raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
124
+
125
+ return hps[0]
126
+
127
+ def print(self):
128
+ _rich_print_dict(asdict(self), title="HParams")
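Because the dataclass is frozen, overrides go through the constructor (or a YAML merged by `load`); a short sketch:

    from pathlib import Path
    from modules.repos_static.resemble_enhance.hparams import HParams

    hp = HParams(batch_size_per_gpu=8)     # field overrides at construction time
    hp.print()                             # rich table of all fields
    hp2 = HParams.load(Path("runs/demo"))  # falls back to defaults if no hparams.yaml exists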
modules/repos_static/resemble_enhance/inference.py ADDED
@@ -0,0 +1,163 @@
+ import logging
+ import time
+
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.utils.parametrize import remove_parametrizations
+ from torchaudio.functional import resample
+ from torchaudio.transforms import MelSpectrogram
+ from tqdm import trange
+
+ from .hparams import HParams
+
+ logger = logging.getLogger(__name__)
+
+
+ @torch.inference_mode()
+ def inference_chunk(model, dwav, sr, device, npad=441):
+     assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
+     del sr
+
+     length = dwav.shape[-1]
+     abs_max = dwav.abs().max().clamp(min=1e-7)
+
+     assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
+     dwav = dwav.to(device)
+     dwav = dwav / abs_max  # Normalize
+     dwav = F.pad(dwav, (0, npad))
+     hwav = model(dwav[None])[0].cpu()  # (T,)
+     hwav = hwav[:length]  # Trim padding
+     hwav = hwav * abs_max  # Unnormalize
+
+     return hwav
+
+
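+ # Circular cross-correlation in the frequency domain:
+ # ifft(fft(x) * conj(fft(y))) scores every circular shift of y against x.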
+ def compute_corr(x, y):
+     return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()
+
+
+ def compute_offset(chunk1, chunk2, sr=44100):
+     """
+     Args:
+         chunk1: (T,)
+         chunk2: (T,)
+     Returns:
+         offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
+     """
+     hop_length = sr // 200  # 5 ms resolution
+     win_length = hop_length * 4
+     n_fft = 2 ** (win_length - 1).bit_length()
+
+     mel_fn = MelSpectrogram(
+         sample_rate=sr,
+         n_fft=n_fft,
+         win_length=win_length,
+         hop_length=hop_length,
+         n_mels=80,
+         f_min=0.0,
+         f_max=sr // 2,
+     )
+
+     spec1 = mel_fn(chunk1).log1p()
+     spec2 = mel_fn(chunk2).log1p()
+
+     corr = compute_corr(spec1, spec2)  # (F, T)
+     corr = corr.mean(dim=0)  # (T,)
+
+     argmax = corr.argmax().item()
+
+     if argmax > len(corr) // 2:
+         argmax -= len(corr)
+
+     offset = -argmax * hop_length
+
+     return offset
+
+
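+ # Overlap-add of enhanced chunks: each seam is first re-aligned using
+ # compute_offset, then blended with a linear crossfade over the overlap region.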
+ def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
+     signal_length = (len(chunks) - 1) * hop_length + chunk_length
+     overlap_length = chunk_length - hop_length
+     signal = torch.zeros(signal_length, device=chunks[0].device)
+
+     fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
+     fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
+     fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
+     fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])
+
+     for i, chunk in enumerate(chunks):
+         start = i * hop_length
+         end = start + chunk_length
+
+         if len(chunk) < chunk_length:
+             chunk = F.pad(chunk, (0, chunk_length - len(chunk)))
+
+         if i > 0:
+             pre_region = chunks[i - 1][-overlap_length:]
+             cur_region = chunk[:overlap_length]
+             offset = compute_offset(pre_region, cur_region, sr=sr)
+             start -= offset
+             end -= offset
+
+         if i == 0:
+             chunk = chunk * fadeout
+         elif i == len(chunks) - 1:
+             chunk = chunk * fadein
+         else:
+             chunk = chunk * fadein * fadeout
+
+         signal[start:end] += chunk[: len(signal[start:end])]
+
+     signal = signal[:length]
+
+     return signal
+
+
+ def remove_weight_norm_recursively(module):
+     for _, module in module.named_modules():
+         try:
+             remove_parametrizations(module, "weight")
+         except Exception:
+             pass
+
+
+ def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
+     remove_weight_norm_recursively(model)
+
+     hp: HParams = model.hp
+
+     dwav = resample(
+         dwav,
+         orig_freq=sr,
+         new_freq=hp.wav_rate,
+         lowpass_filter_width=64,
+         rolloff=0.9475937167399596,
+         resampling_method="sinc_interp_kaiser",
+         beta=14.769656459379492,
+     )
+
+     del sr  # Everything is in hp.wav_rate now
+
+     sr = hp.wav_rate
+
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+
+     start_time = time.perf_counter()
+
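+     # Split the waveform into chunk_seconds-long chunks; consecutive chunks share
+     # overlap_seconds of audio so merge_chunks can crossfade the seams.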
+     chunk_length = int(sr * chunk_seconds)
+     overlap_length = int(sr * overlap_seconds)
+     hop_length = chunk_length - overlap_length
+
+     chunks = []
+     for start in trange(0, dwav.shape[-1], hop_length):
+         chunks.append(inference_chunk(model, dwav[start : start + chunk_length], sr, device))
+
+     hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
+
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+
+     elapsed_time = time.perf_counter() - start_time
+     logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")
+
+     return hwav, sr
modules/repos_static/resemble_enhance/melspec.py ADDED
@@ -0,0 +1,61 @@
+ import numpy as np
+ import torch
+ from torch import nn
+ from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
+
+ from .hparams import HParams
+
+
+ class MelSpectrogram(nn.Module):
+     def __init__(self, hp: HParams):
+         """
+         Torch implementation of Resemble's mel extraction.
+         Note that the values are NOT identical to librosa's implementation
+         due to floating point precisions.
+         """
+         super().__init__()
+         self.hp = hp
+         self.melspec = TorchMelSpectrogram(
+             hp.wav_rate,
+             n_fft=hp.n_fft,
+             win_length=hp.win_size,
+             hop_length=hp.hop_size,
+             f_min=0,
+             f_max=hp.wav_rate // 2,
+             n_mels=hp.num_mels,
+             power=1,
+             normalized=False,
+             # NOTE: Following librosa's default.
+             pad_mode="constant",
+             norm="slaney",
+             mel_scale="slaney",
+         )
+         self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
+         self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
+         self.preemphasis = hp.preemphasis
+         self.hop_size = hp.hop_size
+
+     def forward(self, wav, pad=True):
+         """
+         Args:
+             wav: [B, T]
+         """
+         device = wav.device
+         if wav.is_mps:
+             wav = wav.cpu()
+             self.to(wav.device)
+         if self.preemphasis > 0:
+             wav = torch.nn.functional.pad(wav, [1, 0], value=0)
+             wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
+         mel = self.melspec(wav)
+         mel = self._amp_to_db(mel)
+         mel_normed = self._normalize(mel)
+         assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size  # Sanity check
+         mel_normed = mel_normed.to(device)
+         return mel_normed  # (B, M, T)
+
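+     # _amp_to_db clamps magnitudes at stft_magnitude_min before the log, so the dB
+     # floor is min_level_db; _normalize then maps [min_level_db, headroom_db] to [0, 1].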
+     def _normalize(self, s, headroom_db=15):
+         return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
+
+     def _amp_to_db(self, x):
+         return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
modules/repos_static/resemble_enhance/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .logging import setup_logging
+ from .utils import save_mels, tree_map
modules/repos_static/resemble_enhance/utils/control.py ADDED
@@ -0,0 +1,26 @@
+ import logging
+ import selectors
+ import sys
+ from functools import cache
+
+ from .distributed import global_leader_only
+
+ _logger = logging.getLogger(__name__)
+
+
+ @cache
+ def _get_stdin_selector():
+     selector = selectors.DefaultSelector()
+     selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
+     return selector
+
+
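+ # Non-blocking stdin poll: select(timeout=0) returns immediately, so this yields
+ # "" unless a full line is already buffered.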
+ @global_leader_only(boardcast_return=True)
+ def non_blocking_input():
+     s = ""
+     selector = _get_stdin_selector()
+     events = selector.select(timeout=0)
+     for key, _ in events:
+         s: str = key.fileobj.readline().strip()
+     _logger.info(f'Got stdin "{s}".')
+     return s
modules/repos_static/resemble_enhance/utils/logging.py ADDED
@@ -0,0 +1,38 @@
+ import logging
+ from pathlib import Path
+
+ from rich.logging import RichHandler
+
+ from .distributed import global_leader_only
+
+
+ @global_leader_only
+ def setup_logging(run_dir):
+     handlers = []
+     stdout_handler = RichHandler()
+     stdout_handler.setLevel(logging.INFO)
+     handlers.append(stdout_handler)
+
+     if run_dir is not None:
+         filename = Path(run_dir) / "log.txt"
+         filename.parent.mkdir(parents=True, exist_ok=True)
+         file_handler = logging.FileHandler(filename, mode="a")
+         file_handler.setLevel(logging.DEBUG)
+         handlers.append(file_handler)
+
+     # Update all existing loggers
+     for name in ["DeepSpeed"]:
+         logger = logging.getLogger(name)
+         if isinstance(logger, logging.Logger):
+             for handler in list(logger.handlers):
+                 logger.removeHandler(handler)
+             for handler in handlers:
+                 logger.addHandler(handler)
+
+     # Set the default logger
+     logging.basicConfig(
+         level=logging.getLevelName("INFO"),
+         format="%(message)s",
+         datefmt="[%X]",
+         handlers=handlers,
+     )
modules/repos_static/resemble_enhance/utils/utils.py ADDED
@@ -0,0 +1,73 @@
+ from typing import Callable, TypeVar, overload
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+
+ def save_mels(path, *, targ_mel, pred_mel, cond_mel):
+     n = 3 if cond_mel is None else 4
+
+     plt.figure(figsize=(10, n * 4))
+
+     i = 1
+
+     plt.subplot(n, 1, i)
+     plt.imshow(pred_mel, origin="lower", interpolation="none")
+     plt.title(f"Pred mel {pred_mel.shape}")
+     i += 1
+
+     plt.subplot(n, 1, i)
+     plt.imshow(targ_mel, origin="lower", interpolation="none")
+     plt.title(f"GT mel {targ_mel.shape}")
+     i += 1
+
+     plt.subplot(n, 1, i)
+     pred_mel = pred_mel[:, : targ_mel.shape[1]]
+     targ_mel = targ_mel[:, : pred_mel.shape[1]]
+     plt.imshow(np.abs(pred_mel - targ_mel), origin="lower", interpolation="none")
+     plt.title(f"Diff mel {pred_mel.shape}, mse={np.mean((pred_mel - targ_mel)**2):.4f}")
+     i += 1
+
+     if cond_mel is not None:
+         plt.subplot(n, 1, i)
+         plt.imshow(cond_mel, origin="lower", interpolation="none")
+         plt.title(f"Cond mel {cond_mel.shape}")
+         i += 1
+
+     plt.savefig(path, dpi=480)
+     plt.close()
+
+
+ T = TypeVar("T")
+
+
+ @overload
+ def tree_map(fn: Callable, x: list[T]) -> list[T]:
+     ...
+
+
+ @overload
+ def tree_map(fn: Callable, x: tuple[T, ...]) -> tuple[T, ...]:
+     ...
+
+
+ @overload
+ def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]:
+     ...
+
+
+ @overload
+ def tree_map(fn: Callable, x: T) -> T:
+     ...
+
+
+
64
+ def tree_map(fn: Callable, x):
65
+ if isinstance(x, list):
66
+ x = [tree_map(fn, xi) for xi in x]
67
+ elif isinstance(x, tuple):
68
+ x = (tree_map(fn, xi) for xi in x)
69
+ elif isinstance(x, dict):
70
+ x = {k: tree_map(fn, v) for k, v in x.items()}
71
+ else:
72
+ x = fn(x)
73
+ return x
modules/speaker.py CHANGED
@@ -99,6 +99,10 @@ class SpeakerManager:
          self.speakers[speaker_file] = Speaker.from_file(
              self.speaker_dir + speaker_file
          )
+         # Drop entries whose files were deleted on disk, keeping self.speakers in sync
+         for fname, spk in list(self.speakers.items()):
+             if not os.path.exists(self.speaker_dir + fname):
+                 del self.speakers[fname]
 
      def list_speakers(self):
          return list(self.speakers.values())
modules/webui/speaker/__init__.py ADDED
File without changes
modules/webui/speaker/speaker_creator.py ADDED
@@ -0,0 +1,171 @@
+ import gradio as gr
+ import torch
+ from modules.speaker import Speaker
+ from modules.utils.SeedContext import SeedContext
+ from modules.hf import spaces
+ from modules.models import load_chat_tts
+ from modules.utils.rng import np_rng
+ from modules.webui.webui_utils import get_speakers, tts_generate
+
+ import tempfile
+
+ names_list = [
+     "Alice",
+     "Bob",
+     "Carol",
+     "Carlos",
+     "Charlie",
+     "Chuck",
+     "Chad",
+     "Craig",
+     "Dan",
+     "Dave",
+     "David",
+     "Erin",
+     "Eve",
+     "Yves",
+     "Faythe",
+     "Frank",
+     "Grace",
+     "Heidi",
+     "Ivan",
+     "Judy",
+     "Mallory",
+     "Mallet",
+     "Darth",
+     "Michael",
+     "Mike",
+     "Niaj",
+     "Olivia",
+     "Oscar",
+     "Peggy",
+     "Pat",
+     "Rupert",
+     "Sybil",
+     "Trent",
+     "Ted",
+     "Trudy",
+     "Victor",
+     "Vanna",
+     "Walter",
+     "Wendy",
+ ]
+
+
+ @torch.inference_mode()
+ @spaces.GPU
+ def create_spk_from_seed(
+     seed: int,
+     name: str,
+     gender: str,
+     desc: str,
+ ):
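+     # Sampling inside SeedContext should make the drawn speaker embedding
+     # reproducible for a given seed (assuming SeedContext pins the global RNG state).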
+     chat_tts = load_chat_tts()
+     with SeedContext(seed):
+         emb = chat_tts.sample_random_speaker()
+     spk = Speaker(seed=-2, name=name, gender=gender, describe=desc)
+     spk.emb = emb
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
+         torch.save(spk, tmp_file)
+         tmp_file_path = tmp_file.name
+
+     return tmp_file_path
+
+
+ @torch.inference_mode()
+ @spaces.GPU
+ def test_spk_voice(seed: int, text: str):
+     return tts_generate(
+         spk=seed,
+         text=text,
+     )
+
+
+ def random_speaker():
+     seed = np_rng()
+     name = names_list[seed % len(names_list)]
+     return seed, name
+
+
+ creator_ui_desc = """
+ ## Speaker Creator
+ Use this panel to quickly roll random speaker.pt files, gacha-style.
+
+ 1. **Generate a speaker**: enter a seed, name, gender, and description, then click the "Generate speaker.pt" button; the resulting speaker configuration is saved as a .pt file.
+ 2. **Test a speaker's voice**: enter some test text and click the "Test Voice" button; the generated audio plays in "Output Audio".
+ 3. **Randomize a speaker**: click the "Random Speaker" button to draw a random seed and name, then edit the remaining fields and test as needed.
+ """
+
+
+ def speaker_creator_ui():
+     def on_generate(seed, name, gender, desc):
+         file_path = create_spk_from_seed(seed, name, gender, desc)
+         return file_path
+
+     def create_test_voice_card(seed_input):
+         with gr.Group():
+             gr.Markdown("🎤Test voice")
+             with gr.Row():
+                 test_voice_btn = gr.Button("Test Voice", variant="secondary")
+
+                 with gr.Column(scale=4):
+                     test_text = gr.Textbox(
+                         label="Test Text",
+                         placeholder="Please input test text",
+                         value="说话人测试 123456789 [uv_break] ok, test done [lbreak]",
+                     )
+                     with gr.Row():
+                         current_seed = gr.Label(label="Current Seed", value=-1)
+                 with gr.Column(scale=4):
+                     output_audio = gr.Audio(label="Output Audio")
+
+         test_voice_btn.click(
+             fn=test_spk_voice,
+             inputs=[seed_input, test_text],
+             outputs=[output_audio],
+         )
+         test_voice_btn.click(
+             fn=lambda x: x,
+             inputs=[seed_input],
+             outputs=[current_seed],
+         )
+
+     gr.Markdown(creator_ui_desc)
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             with gr.Group():
+                 gr.Markdown("ℹ️Speaker info")
+                 seed_input = gr.Number(label="Seed", value=2)
+                 name_input = gr.Textbox(
+                     label="Name", placeholder="Enter speaker name", value="Bob"
+                 )
+                 gender_input = gr.Textbox(
+                     label="Gender", placeholder="Enter gender", value="*"
+                 )
+                 desc_input = gr.Textbox(
+                     label="Description",
+                     placeholder="Enter description",
+                 )
+                 random_button = gr.Button("Random Speaker")
+             with gr.Group():
+                 gr.Markdown("🔊Generate speaker.pt")
+                 generate_button = gr.Button("Save .pt file")
+                 output_file = gr.File(label="Save to File")
+         with gr.Column(scale=5):
+             create_test_voice_card(seed_input=seed_input)
+             create_test_voice_card(seed_input=seed_input)
+             create_test_voice_card(seed_input=seed_input)
+             create_test_voice_card(seed_input=seed_input)
+
+     random_button.click(
+         random_speaker,
+         outputs=[seed_input, name_input],
+     )
+
+     generate_button.click(
+         fn=on_generate,
+         inputs=[seed_input, name_input, gender_input, desc_input],
+         outputs=[output_file],
+     )
modules/webui/speaker/speaker_merger.py ADDED
@@ -0,0 +1,255 @@
+ import io
+ import gradio as gr
+ import torch
+
+ from modules.hf import spaces
+ from modules.webui.webui_utils import get_speakers, tts_generate
+ from modules.speaker import speaker_mgr, Speaker
+
+ import tempfile
+
+
+ def spk_to_tensor(spk):
+     spk = spk.split(" : ")[1].strip() if " : " in spk else spk
+     if spk == "None" or spk == "":
+         return None
+     return speaker_mgr.get_speaker(spk).emb
+
+
+ def get_speaker_show_name(spk):
+     if spk.gender == "*" or spk.gender == "":
+         return spk.name
+     return f"{spk.gender} : {spk.name}"
+
+
+ def merge_spk(
+     spk_a,
+     spk_a_w,
+     spk_b,
+     spk_b_w,
+     spk_c,
+     spk_c_w,
+     spk_d,
+     spk_d_w,
+ ):
+     tensor_a = spk_to_tensor(spk_a)
+     tensor_b = spk_to_tensor(spk_b)
+     tensor_c = spk_to_tensor(spk_c)
+     tensor_d = spk_to_tensor(spk_d)
+
+     assert (
+         tensor_a is not None
+         or tensor_b is not None
+         or tensor_c is not None
+         or tensor_d is not None
+     ), "At least one speaker should be selected"
+
+     merge_tensor = torch.zeros_like(
+         tensor_a
+         if tensor_a is not None
+         else (
+             tensor_b
+             if tensor_b is not None
+             else tensor_c if tensor_c is not None else tensor_d
+         )
+     )
+
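+     # Weighted average of the selected embeddings: accumulate weight * embedding,
+     # then normalize by the total weight.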
+     total_weight = 0
+     if tensor_a is not None:
+         merge_tensor += spk_a_w * tensor_a
+         total_weight += spk_a_w
+     if tensor_b is not None:
+         merge_tensor += spk_b_w * tensor_b
+         total_weight += spk_b_w
+     if tensor_c is not None:
+         merge_tensor += spk_c_w * tensor_c
+         total_weight += spk_c_w
+     if tensor_d is not None:
+         merge_tensor += spk_d_w * tensor_d
+         total_weight += spk_d_w
+
+     if total_weight > 0:
+         merge_tensor /= total_weight
+
+     merged_spk = Speaker.from_tensor(merge_tensor)
+     merged_spk.name = "<MIX>"
+
+     return merged_spk
+
+
+ @torch.inference_mode()
+ @spaces.GPU
+ def merge_and_test_spk_voice(
+     spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
+ ):
+     merged_spk = merge_spk(
+         spk_a,
+         spk_a_w,
+         spk_b,
+         spk_b_w,
+         spk_c,
+         spk_c_w,
+         spk_d,
+         spk_d_w,
+     )
+     return tts_generate(
+         spk=merged_spk,
+         text=test_text,
+     )
+
+
+ @torch.inference_mode()
+ @spaces.GPU
+ def merge_spk_to_file(
+     spk_a,
+     spk_a_w,
+     spk_b,
+     spk_b_w,
+     spk_c,
+     spk_c_w,
+     spk_d,
+     spk_d_w,
+     speaker_name,
+     speaker_gender,
+     speaker_desc,
+ ):
+     merged_spk = merge_spk(
+         spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w
+     )
+     merged_spk.name = speaker_name
+     merged_spk.gender = speaker_gender
+     merged_spk.desc = speaker_desc
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
+         torch.save(merged_spk, tmp_file)
+         tmp_file_path = tmp_file.name
+
+     return tmp_file_path
+
+
+ merge_desc = """
+ ## Speaker Merger
+
+ In this panel you can pick several speakers, assign each a weight, and synthesize and test a new merged voice. The features in detail:
+
+ 1. Pick speakers: choose up to four speakers (A, B, C, D) from the dropdowns; each has a weight slider from 0 to 10 that controls how strongly it contributes to the merged voice.
+ 2. Synthesize speech: after picking speakers and setting weights, enter text in the "Test Text" box and click the "Test Voice" button to generate and play the merged voice.
+ 3. Save the speaker: fill in the new speaker's name, gender, and description in the "Speaker info" section on the right, then click the "Save Speaker" button; the saved speaker file appears under "Merged Speaker" for download.
+ """
+
+
+ def get_spk_choices():
+     speakers = get_speakers()
+
+     speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers]
+     return speaker_names
+
+
+ # Show four pickers (A, B, C, D); select one or more speakers, audition the mix, and export it
+ def create_speaker_merger():
+     speaker_names = get_spk_choices()
+
+     gr.Markdown(merge_desc)
+
+     def spk_picker(label_tail: str):
+         with gr.Row():
+             spk_a = gr.Dropdown(
+                 choices=speaker_names, value="None", label=f"Speaker {label_tail}"
+             )
+             refresh_a_btn = gr.Button("🔄", variant="secondary")
+
+         def refresh_a():
+             speaker_mgr.refresh_speakers()
+             speaker_names = get_spk_choices()
+             return gr.update(choices=speaker_names)
+
+         refresh_a_btn.click(refresh_a, outputs=[spk_a])
+         spk_a_w = gr.Slider(
+             value=1,
+             minimum=0,
+             maximum=10,
+             step=0.1,
+             label=f"Weight {label_tail}",
+         )
+         return spk_a, spk_a_w
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             with gr.Row():
+                 with gr.Group():
+                     spk_a, spk_a_w = spk_picker("A")
+
+                 with gr.Group():
+                     spk_b, spk_b_w = spk_picker("B")
+
+                 with gr.Group():
+                     spk_c, spk_c_w = spk_picker("C")
+
+                 with gr.Group():
+                     spk_d, spk_d_w = spk_picker("D")
+
+             with gr.Row():
+                 with gr.Column(scale=3):
+                     with gr.Group():
+                         gr.Markdown("🎤Test voice")
+                         with gr.Row():
+                             test_voice_btn = gr.Button(
+                                 "Test Voice", variant="secondary"
+                             )
+
+                             with gr.Column(scale=4):
+                                 test_text = gr.Textbox(
+                                     label="Test Text",
+                                     placeholder="Please input test text",
+                                     value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
+                                 )
+
+                         output_audio = gr.Audio(label="Output Audio")
+
+         with gr.Column(scale=1):
+             with gr.Group():
+                 gr.Markdown("🗃️Save to file")
+
+                 speaker_name = gr.Textbox(label="Name", value="forge_speaker_merged")
+                 speaker_gender = gr.Textbox(label="Gender", value="*")
+                 speaker_desc = gr.Textbox(label="Description", value="merged speaker")
+
+                 save_btn = gr.Button("Save Speaker", variant="primary")
+
+                 merged_spker = gr.File(
+                     label="Merged Speaker", interactive=False, type="binary"
+                 )
+
+     test_voice_btn.click(
+         merge_and_test_spk_voice,
+         inputs=[
+             spk_a,
+             spk_a_w,
+             spk_b,
+             spk_b_w,
+             spk_c,
+             spk_c_w,
+             spk_d,
+             spk_d_w,
+             test_text,
+         ],
+         outputs=[output_audio],
+     )
+
+     save_btn.click(
+         merge_spk_to_file,
+         inputs=[
+             spk_a,
+             spk_a_w,
+             spk_b,
+             spk_b_w,
+             spk_c,
+             spk_c_w,
+             spk_d,
+             spk_d_w,
+             speaker_name,
+             speaker_gender,
+             speaker_desc,
+         ],
+         outputs=[merged_spker],
+     )