Spaces:
Running
on
Zero
Running
on
Zero
zhzluke96
committed on
Commit
·
32b2aaa
1
Parent(s):
2ca1c87
update
Browse files — This view is limited to 50 files because it contains too many changes.
See raw diff
- modules/Enhancer/ResembleEnhance.py +3 -8
- modules/repos_static/__init__.py +0 -0
- modules/repos_static/readme.md +5 -0
- modules/repos_static/resemble_enhance/__init__.py +0 -0
- modules/repos_static/resemble_enhance/common.py +55 -0
- modules/repos_static/resemble_enhance/data/__init__.py +48 -0
- modules/repos_static/resemble_enhance/data/dataset.py +171 -0
- modules/repos_static/resemble_enhance/data/distorter/__init__.py +1 -0
- modules/repos_static/resemble_enhance/data/distorter/base.py +104 -0
- modules/repos_static/resemble_enhance/data/distorter/custom.py +85 -0
- modules/repos_static/resemble_enhance/data/distorter/distorter.py +32 -0
- modules/repos_static/resemble_enhance/data/distorter/sox.py +176 -0
- modules/repos_static/resemble_enhance/data/utils.py +43 -0
- modules/repos_static/resemble_enhance/denoiser/__init__.py +0 -0
- modules/repos_static/resemble_enhance/denoiser/__main__.py +30 -0
- modules/repos_static/resemble_enhance/denoiser/denoiser.py +181 -0
- modules/repos_static/resemble_enhance/denoiser/hparams.py +9 -0
- modules/repos_static/resemble_enhance/denoiser/inference.py +31 -0
- modules/repos_static/resemble_enhance/denoiser/unet.py +144 -0
- modules/repos_static/resemble_enhance/enhancer/__init__.py +0 -0
- modules/repos_static/resemble_enhance/enhancer/__main__.py +129 -0
- modules/repos_static/resemble_enhance/enhancer/download.py +30 -0
- modules/repos_static/resemble_enhance/enhancer/enhancer.py +185 -0
- modules/repos_static/resemble_enhance/enhancer/hparams.py +23 -0
- modules/repos_static/resemble_enhance/enhancer/inference.py +48 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py +2 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py +372 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py +123 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py +152 -0
- modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py +147 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py +1 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py +5 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py +95 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py +49 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/amp.py +101 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py +210 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py +281 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py +128 -0
- modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py +94 -0
- modules/repos_static/resemble_enhance/hparams.py +128 -0
- modules/repos_static/resemble_enhance/inference.py +163 -0
- modules/repos_static/resemble_enhance/melspec.py +61 -0
- modules/repos_static/resemble_enhance/utils/__init__.py +2 -0
- modules/repos_static/resemble_enhance/utils/control.py +26 -0
- modules/repos_static/resemble_enhance/utils/logging.py +38 -0
- modules/repos_static/resemble_enhance/utils/utils.py +73 -0
- modules/speaker.py +4 -0
- modules/webui/speaker/__init__.py +0 -0
- modules/webui/speaker/speaker_creator.py +171 -0
- modules/webui/speaker/speaker_merger.py +255 -0
modules/Enhancer/ResembleEnhance.py
CHANGED
|
@@ -1,13 +1,8 @@
|
|
| 1 |
import os
|
| 2 |
from typing import List
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
from resemble_enhance.enhancer.hparams import HParams
|
| 7 |
-
from resemble_enhance.inference import inference
|
| 8 |
-
except:
|
| 9 |
-
HParams = dict
|
| 10 |
-
Enhancer = dict
|
| 11 |
|
| 12 |
import torch
|
| 13 |
|
|
|
|
| 1 |
import os
|
| 2 |
from typing import List
|
| 3 |
+
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
|
| 4 |
+
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
|
| 5 |
+
from modules.repos_static.resemble_enhance.inference import inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
|
modules/repos_static/__init__.py
ADDED
|
File without changes
|
modules/repos_static/readme.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# repos static
|
| 2 |
+
|
| 3 |
+
## resemble_enhance
|
| 4 |
+
|
| 5 |
+
https://github.com/resemble-ai/resemble-enhance/tree/main
|
modules/repos_static/resemble_enhance/__init__.py
ADDED
|
File without changes
|
modules/repos_static/resemble_enhance/common.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Normalizer(nn.Module):
    """Normalize tensors with scalar running statistics tracked by an EMA.

    The running mean and variance are stored as 0-dim buffers that start out
    as NaN. Until the first update they are reported as mean 0 / std 1, so
    the module behaves as the identity.

    Args:
        momentum: Weight given to the newest batch statistic in the EMA.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        # NaN is the "no statistics collected yet" marker checked by `started`.
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))

    @property
    def started(self):
        """Whether at least one update has replaced the NaN placeholder."""
        return not torch.isnan(self.running_mean_unsafe)

    @property
    def running_mean(self):
        """Running mean, or a zero tensor before the first update."""
        if not self.started:
            return torch.zeros_like(self.running_mean_unsafe)
        return self.running_mean_unsafe

    @property
    def running_std(self):
        """Running standard deviation, or a ones tensor before the first update."""
        if not self.started:
            return torch.ones_like(self.running_var_unsafe)
        return (self.running_var_unsafe + self.eps).sqrt()

    @torch.no_grad()
    def _ema(self, a: Tensor, x: Tensor):
        """Blend the previous value `a` with the new observation `x`."""
        return (1 - self.momentum) * a + self.momentum * x

    @torch.no_grad()
    def update_(self, x):
        """Update the running statistics in place from the batch `x`.

        The whole update runs without gradient tracking. Previously only the
        EMA branch did, so the very first update stored `x.mean()` / `x.var()`
        with the autograd graph of `x` still attached to the buffers.
        """
        if not self.started:
            self.running_mean_unsafe = x.mean()
            self.running_var_unsafe = x.var()
        else:
            self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
            # The variance observation is measured around the freshly updated mean.
            self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())

    def forward(self, x: Tensor, update=True):
        """Normalize `x`, refreshing the statistics first when training.

        Args:
            x: Tensor to normalize.
            update: When False, the statistics are left untouched even in
                training mode.

        Returns:
            `(x - running_mean) / running_std`.
        """
        if self.training and update:
            self.update_(x)
        # Python-float snapshot of the statistics used for this call.
        self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
        x = (x - self.running_mean) / self.running_std
        return x

    def inverse(self, x: Tensor):
        """Undo `forward` using the current running statistics."""
        return x * self.running_std + self.running_mean
|
modules/repos_static/resemble_enhance/data/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
from ..hparams import HParams
|
| 7 |
+
from .dataset import Dataset
|
| 8 |
+
from .utils import mix_fg_bg, rglob_audio_files
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
    """Build the training and validation datasets from `hp.fg_dir`.

    Args:
        hp: Hyper-parameters; `fg_dir` is read here and `hp` is forwarded to
            both `Dataset` instances.
        mode: Forwarded to `Dataset`.
        val_size: Number of files held out (from the end of the shuffled
            list) for validation.
        seed: Seed of the private RNG used for the shuffle.

    Returns:
        A `(train_ds, val_ds)` tuple.
    """
    # Directory iteration order is filesystem-dependent, so the seeded shuffle
    # is only reproducible when it starts from a canonical (sorted) ordering.
    paths = sorted(rglob_audio_files(hp.fg_dir))
    logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")

    random.Random(seed).shuffle(paths)
    train_paths = paths[:-val_size]
    val_paths = paths[-val_size:]

    train_ds = Dataset(train_paths, hp, training=True, mode=mode)
    val_ds = Dataset(val_paths, hp, training=False, mode=mode)

    logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")

    return train_ds, val_ds
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_dataloaders(hp: HParams, mode):
    """Create the training and validation `DataLoader`s.

    The training loader shuffles, drops the last incomplete batch and uses
    `hp.batch_size_per_gpu`; the validation loader yields one sample at a
    time in order. Both use `hp.nj` worker processes.

    Returns:
        A `(train_dl, val_dl)` tuple.
    """
    train_ds, val_ds = _create_datasets(hp=hp, mode=mode)
    workers = hp.nj

    train_dl = DataLoader(
        train_ds,
        batch_size=hp.batch_size_per_gpu,
        shuffle=True,
        num_workers=workers,
        drop_last=True,
        collate_fn=train_ds.collate_fn,
    )

    val_dl = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        drop_last=False,
        collate_fn=val_ds.collate_fn,
    )

    return train_dl, val_dl
|
modules/repos_static/resemble_enhance/data/dataset.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
import torchaudio.functional as AF
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
+
from torch.utils.data import Dataset as DatasetBase
|
| 11 |
+
|
| 12 |
+
from ..hparams import HParams
|
| 13 |
+
from .distorter import Distorter
|
| 14 |
+
from .utils import rglob_audio_files
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _normalize(x):
|
| 20 |
+
return x / (np.abs(x).max() + 1e-7)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _collate(batch, key, tensor=True, pad=True):
|
| 24 |
+
l = [d[key] for d in batch]
|
| 25 |
+
if l[0] is None:
|
| 26 |
+
return None
|
| 27 |
+
if tensor:
|
| 28 |
+
l = [torch.from_numpy(x) for x in l]
|
| 29 |
+
if pad:
|
| 30 |
+
assert tensor, "Can't pad non-tensor"
|
| 31 |
+
l = pad_sequence(l, batch_first=True)
|
| 32 |
+
return l
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def praat_augment(wav, sr):
    """Apply Praat's "Change gender" command with random formant/pitch settings.

    Args:
        wav: 1-D waveform.
        sr: Sample rate passed to `parselmouth.Sound`.

    Returns:
        The first channel of the processed sound as a float32 array.

    Raises:
        ImportError: If `parselmouth` is not installed.
    """
    try:
        import parselmouth
    except ImportError:
        raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")

    # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
    # https://github.com/YannickJadoul/Parselmouth/issues/68
    # note that this function may hang if the praat version is 0.4.3
    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"

    source = parselmouth.Sound(wav, sr)

    # Draw order matters for seeded runs: formant ratio first, then pitch range.
    formant_shift_ratio = random.uniform(1.1, 1.5)
    pitch_range_factor = random.uniform(0.5, 2.0)

    shifted = parselmouth.praat.call(
        source, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0
    )
    return np.array(shifted.values)[0].astype(np.float32)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Dataset(DatasetBase):
    """Dataset of foreground/background waveforms and their distorted versions.

    Each item is a dict with keys `fg_wav`, `bg_wav`, `fg_dwav` and `bg_dwav`;
    the last three are None when `hp.load_fg_only` is set.
    """

    def __init__(
        self,
        fg_paths: list[Path],
        hp: HParams,
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        """
        Args:
            fg_paths: Foreground audio files.
            hp: Hyper-parameters; background files are discovered under
                `hp.bg_dir`.
            training: Enables random cropping, random background choice and
                the training-only augmentations.
            max_retries: Attempts made by `__getitem__` before giving up.
            silent_fg_prob: Chance (training only) of an all-zero foreground.
            mode: Must be "enhancer" or "denoiser".

        Raises:
            ValueError: If no foreground or background files are available.
        """
        super().__init__()

        # NOTE(review): the default ``mode=False`` can never pass this check,
        # so callers always have to supply a mode explicitly.
        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"

        self.hp = hp
        self.fg_paths = fg_paths
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        if len(self.fg_paths) == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        if len(self.bg_paths) == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")

        self.training = training
        self.max_retries = max_retries
        self.silent_fg_prob = silent_fg_prob
        self.mode = mode
        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        """Load `path`, resample to `hp.wav_rate`, downmix, crop/pad and peak-normalize.

        Args:
            path: Audio file to read with torchaudio.
            length: Target number of samples. When None in training mode it
                defaults to `hp.training_seconds * hp.wav_rate`; when None
                otherwise, the full waveform is kept.
            random_crop: Crop at a random offset instead of from the start.
        """
        loaded, sr = torchaudio.load(path)

        resampled = AF.resample(
            waveform=loaded,
            orig_freq=sr,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )

        wav = resampled.float().numpy()

        # Collapse a 2-D array to 1-D by averaging over its first axis.
        if wav.ndim == 2:
            wav = np.mean(wav, axis=0)

        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        if length is not None:
            if random_crop:
                offset = random.randint(0, max(0, len(wav) - length))
                wav = wav[offset : offset + length]
            else:
                wav = wav[:length]

            # Right-pad with zeros when the file is shorter than requested.
            if len(wav) < length:
                wav = np.pad(wav, (0, length - len(wav)))

        return _normalize(wav)

    def _getitem_unsafe(self, index: int):
        """Build one sample; any loading/augmentation error propagates."""
        hp = self.hp
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            fg_wav = np.zeros(int(hp.training_seconds * hp.wav_rate), dtype=np.float32)
        else:
            fg_wav = self._load_wav(fg_path)
            # The random draw happens before the training check, as in the
            # original, so the RNG stream is consumed identically.
            if random.random() < hp.praat_augment_prob and self.training:
                fg_wav = praat_augment(fg_wav, hp.wav_rate)

        bg_wav = None
        fg_dwav = None
        bg_dwav = None

        if not hp.load_fg_only:
            fg_dwav = _normalize(self.distorter(fg_wav, hp.wav_rate)).astype(np.float32)

            if self.training:
                bg_path = random.choice(self.bg_paths)
            else:
                # Deterministic for validation
                bg_path = self.bg_paths[index % len(self.bg_paths)]

            bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
            bg_dwav = _normalize(self.distorter(bg_wav, hp.wav_rate)).astype(np.float32)

        return dict(
            fg_wav=fg_wav,
            bg_wav=bg_wav,
            fg_dwav=fg_dwav,
            bg_dwav=bg_dwav,
        )

    def __getitem__(self, index: int):
        """Return a sample, retrying with random indices on failure.

        Raises:
            RuntimeError: If the final attempt also fails.
        """
        final_attempt = self.max_retries - 1
        for attempt in range(self.max_retries):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                if attempt == final_attempt:
                    raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
                logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
                index = np.random.randint(0, len(self))

    def __len__(self):
        """Number of foreground files."""
        return len(self.fg_paths)

    @staticmethod
    def collate_fn(batch):
        """Collate samples into padded batch tensors (None for absent fields)."""
        fields = ("fg_wav", "bg_wav", "fg_dwav", "bg_dwav")
        return {f"{field}s": _collate(batch, field) for field in fields}
|
modules/repos_static/resemble_enhance/data/distorter/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .distorter import Distorter
|
modules/repos_static/resemble_enhance/data/distorter/base.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
_DEBUG = bool(os.environ.get("DEBUG", False))
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Effect:
    """Base class for waveform effects; subclasses implement `apply`."""

    def apply(self, wav: np.ndarray, sr: int):
        """
        Args:
            wav: (T)
            sr: sample rate
        Returns:
            wav: (T) with the same sample rate of `sr`
        """
        raise NotImplementedError

    def __call__(self, wav: np.ndarray, sr: int):
        """
        Run `apply`, checking that input is 1-D and the shape is preserved.

        Args:
            wav: (T)
            sr: sample rate
        Returns:
            wav: (T) with the same sample rate of `sr`
        """
        assert len(wav.shape) == 1, wav.shape

        # Timing is only collected (and printed) in debug mode.
        started_at = time.time() if _DEBUG else None

        in_shape = wav.shape
        assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
        wav = self.apply(wav, sr)
        assert in_shape == wav.shape, f"{self}: {in_shape} != {wav.shape}."

        if started_at is not None:
            finished_at = time.time()
            print(f"{self.__class__.__name__}: {finished_at - started_at:.3f} sec")

        return wav
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Chain(Effect):
    """Apply the given effects one after another, in order."""

    def __init__(self, *effects):
        super().__init__()

        self.effects = effects

    def apply(self, wav, sr):
        """Feed `wav` through every effect; an empty chain returns it unchanged."""
        processed = wav
        for stage in self.effects:
            processed = stage(processed, sr)
        return processed
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Maybe(Effect):
    """Apply the wrapped effect with probability `prob` (always, in debug mode)."""

    def __init__(self, prob, effect):
        super().__init__()

        self.prob = prob
        self.effect = effect

        if _DEBUG:
            warnings.warn("DEBUG mode is on. Maybe -> Must.")
            self.prob = 1

    def apply(self, wav, sr):
        """Return `wav` untouched when the random draw exceeds `prob`."""
        skip = random.random() > self.prob
        return wav if skip else self.effect(wav, sr)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class Choice(Effect):
    """Apply one effect picked with `np.random.choice`.

    Extra keyword arguments (e.g. `p=[...]`) are forwarded to
    `np.random.choice` on every call.
    """

    def __init__(self, *effects, **kwargs):
        super().__init__()
        self.effects = effects
        self.kwargs = kwargs

    def apply(self, wav, sr):
        """Pick an effect and run it on `wav`."""
        chosen = np.random.choice(self.effects, **self.kwargs)
        return chosen(wav, sr)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Permutation(Effect):
    """Apply a random ordered selection of `n` of the given effects.

    Args:
        *effects: Candidate effects.
        n: How many effects to apply. When None, the count is drawn per call
            from `Binomial(len(effects), 0.5)`.
    """

    def __init__(self, *effects, n: int | None = None):
        super().__init__()
        self.effects = effects
        self.n = n

    def apply(self, wav, sr):
        """Chain `n` distinct effects in random order and run them on `wav`.

        Raises:
            ValueError: If `n` is negative or exceeds the number of effects.
        """
        if self.n is None:
            n = np.random.binomial(len(self.effects), 0.5)
        else:
            n = self.n
        if n == 0:
            return wav
        # `random.sample` returns a uniformly random ordered selection without
        # replacement. That is the same distribution as choosing uniformly
        # among all n-permutations, without materialising the factorially
        # many candidates first.
        effects = random.sample(self.effects, int(n))
        return Chain(*effects)(wav, sr)
|
modules/repos_static/resemble_enhance/data/distorter/custom.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from functools import cached_property
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import librosa
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy import signal
|
| 10 |
+
|
| 11 |
+
from ..utils import walk_paths
|
| 12 |
+
from .base import Effect
|
| 13 |
+
|
| 14 |
+
_logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class RandomRIR(Effect):
    """Convolve the waveform with a room impulse response loaded from disk.

    With no `rir_dir` (or no matching files) the effect is a no-op.
    """

    rir_dir: Path | None
    rir_rate: int = 44_000
    rir_suffix: str = ".npy"
    deterministic: bool = False

    @cached_property
    def rir_paths(self):
        """Impulse-response files under `rir_dir` ending in `rir_suffix`, listed once."""
        if self.rir_dir is None:
            return []
        return list(walk_paths(self.rir_dir, self.rir_suffix))

    def _sample_rir(self):
        """Load one impulse response: the first file if deterministic, else a random one."""
        candidates = self.rir_paths
        if len(candidates) == 0:
            return None

        rir_path = candidates[0] if self.deterministic else random.choice(candidates)

        rir = np.squeeze(np.load(rir_path))
        assert isinstance(rir, np.ndarray)

        return rir

    def apply(self, wav, sr):
        """Resample to `rir_rate`, convolve, limit the peak, and resample back.

        The result is padded or trimmed to the input length.
        """
        # ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158

        if len(self.rir_paths) == 0:
            return wav

        length = len(wav)

        at_rir_rate = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
        rir = self._sample_rir()

        convolved = signal.convolve(at_rir_rate, rir, mode="same")

        # Scale down only when the convolution pushed the peak above 0.99.
        actlev = np.max(np.abs(convolved))
        if actlev > 0.99:
            convolved = (convolved / actlev) * 0.98

        wav = librosa.resample(convolved, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")

        if abs(length - len(wav)) > 10:
            _logger.warning(f"length mismatch: {length} vs {len(wav)}")

        missing = length - len(wav)
        if missing > 0:
            wav = np.pad(wav, (0, missing))
        elif missing < 0:
            wav = wav[:length]

        return wav
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class RandomGaussianNoise(Effect):
    """Blend the waveform with Gaussian noise rescaled to the same energy.

    Args:
        alpha_range: `(low, high)` range of the mixing weight `alpha`; the
            output is `alpha * wav + (1 - alpha) * noise`.
    """

    def __init__(self, alpha_range=(0.8, 1)):
        super().__init__()
        self.alpha_range = alpha_range

    def apply(self, wav, sr):
        """Return `wav` mixed with energy-matched noise."""
        noise = np.random.randn(*wav.shape)

        # Rescale the noise so its total energy equals that of the signal.
        gain = np.sqrt(np.sum(wav**2) / np.sum(noise**2))
        noise = noise * gain

        low, high = self.alpha_range
        alpha = random.uniform(low, high)
        return wav * alpha + noise * (1 - alpha)
|
modules/repos_static/resemble_enhance/data/distorter/distorter.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ...hparams import HParams
|
| 2 |
+
from .base import Chain, Choice, Permutation
|
| 3 |
+
from .custom import RandomGaussianNoise, RandomRIR
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Distorter(Chain):
    """Effect chain that degrades clean audio.

    * Not training: a deterministic RIR followed by a deterministic reverb.
    * Training, mode "denoiser": a random permutation of the effects.
    * Training, any other mode: that permutation with probability 0.8, and
      an empty chain (clean audio) with probability 0.2.
    """

    def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
        # Lazy import
        from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb

        if not training:
            super().__init__(
                RandomRIR(hp.rir_dir, deterministic=True),
                RandomReverb(deterministic=True),
            )
            return

        band_limit = Choice(
            RandomLowpassDistorter(),
            RandomBandpassDistorter(),
        )
        permutation = Permutation(
            RandomRIR(hp.rir_dir),
            RandomReverb(),
            RandomGaussianNoise(),
            RandomOverdrive(),
            RandomEqualizer(),
            band_limit,
        )

        if mode == "denoiser":
            super().__init__(permutation)
        else:
            # 80%: distortion, 20%: clean
            super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
|
modules/repos_static/resemble_enhance/data/distorter/sox.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import warnings
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import augment
|
| 12 |
+
except ImportError:
|
| 13 |
+
raise ImportError(
|
| 14 |
+
"augment is not installed, please install it first using:"
|
| 15 |
+
"\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from .base import Effect
|
| 19 |
+
|
| 20 |
+
_logger = logging.getLogger(__name__)
|
| 21 |
+
_DEBUG = bool(os.environ.get("DEBUG", False))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AttachableEffect(Effect):
    """Effect expressed as steps attached to an `augment.EffectChain`."""

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Add this effect to `chain` and return the resulting chain."""
        raise NotImplementedError

    def apply(self, wav: np.ndarray, sr: int):
        """Build a fresh chain, attach this effect and run it on `wav`."""
        chain = self.attach(augment.EffectChain())
        batch = torch.from_numpy(wav)[None].float()  # (1, T)
        processed = chain.apply(batch, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
        return processed.numpy()[0]  # (T,)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SoxEffect(AttachableEffect):
    """Attach the chain method named `effect_name`, called with the stored arguments."""

    def __init__(self, effect_name: str, *args, **kwargs):
        self.effect_name = effect_name
        self.args = args
        self.kwargs = kwargs

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Call `chain.<effect_name>(*args, **kwargs)` and return its result.

        Raises:
            ValueError: If `chain` has no attribute called `effect_name`.
        """
        name = self.effect_name
        _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")

        if not hasattr(chain, name):
            raise ValueError(f"EffectChain has no attribute {self.effect_name}")

        add_effect = getattr(chain, name)
        return add_effect(*self.args, **self.kwargs)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Maybe(AttachableEffect):
    """
    Attach an effect with a probability.
    """

    def __init__(self, prob: float, effect: AttachableEffect):
        self.prob = prob
        self.effect = effect
        if _DEBUG:
            # Debug mode forces the effect to always be attached.
            warnings.warn("DEBUG mode is on. Maybe -> Must.")
            self.prob = 1

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Return `chain` unchanged when the random draw exceeds `prob`."""
        skip = random.random() > self.prob
        return chain if skip else self.effect.attach(chain)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Chain(AttachableEffect):
    """
    Attach a chain of effects.
    """

    def __init__(self, *effects: AttachableEffect):
        self.effects = effects

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Attach every effect in order, threading the chain through each."""
        current = chain
        for member in self.effects:
            current = member.attach(current)
        return current
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Choice(AttachableEffect):
    """
    Attach one of the effects randomly.
    """

    def __init__(self, *effects: AttachableEffect):
        self.effects = effects

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Pick one effect with `random.choice` and attach it."""
        picked = random.choice(self.effects)
        return picked.attach(chain)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Generator:
    """Callable that produces a freshly sampled effect argument as a string."""

    def __call__(self) -> str:
        """Return the next argument value; subclasses must override."""
        raise NotImplementedError
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Uniform(Generator):
    """Sample a float uniformly from `[low, high]` and return it as a string."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self) -> str:
        sampled = random.uniform(self.low, self.high)
        return str(sampled)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Randint(Generator):
    """Sample an integer from `[low, high]` (both inclusive) and return it as a string."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self) -> str:
        sampled = random.randint(self.low, self.high)
        return str(sampled)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Concat(Generator):
    """Concatenate literal strings and the outputs of other generators."""

    def __init__(self, *parts: Generator | str):
        self.parts = parts

    def __call__(self):
        pieces = []
        for part in self.parts:
            # Literal strings are used as-is; anything else is called to sample.
            pieces.append(part if isinstance(part, str) else part())
        return "".join(pieces)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class RandomLowpassDistorter(SoxEffect):
    """`sinc` effect with a random `-n` value and a "-<freq>" argument drawn from `[low, high]`."""

    def __init__(self, low=2000, high=16000):
        n_value = Randint(50, 200)
        freq_arg = Concat("-", Uniform(low, high))
        super().__init__("sinc", "-n", n_value, freq_arg)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class RandomBandpassDistorter(SoxEffect):
    """`sinc` effect with a random `-n` value and a random "<start>-<stop>" argument."""

    def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
        n_value = Randint(50, 200)
        band_arg = partial(self._fn, low, high, min_width, max_width)
        super().__init__("sinc", "-n", n_value, band_arg)

    @staticmethod
    def _fn(low, high, min_width, max_width):
        """Sample a "<start>-<stop>" string: start in `[low, high]`, width in `[min_width, max_width]`."""
        start = random.randint(low, high)
        width = random.randint(min_width, max_width)
        stop = start + width
        return f"{start}-{stop}"
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class RandomEqualizer(SoxEffect):
    """`equalizer` effect with three randomly sampled arguments.

    The arguments are a uniform value from `[low, high]`, an integer from
    `[q_low, q_high]` formatted as "<n>q", and an integer from
    `[db_low, db_high]`.
    """

    def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
        def sample_q():
            return f"{random.randint(q_low, q_high)}q"

        def sample_db():
            return random.randint(db_low, db_high)

        super().__init__("equalizer", Uniform(low, high), sample_q, sample_db)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class RandomOverdrive(SoxEffect):
    """Sox ``overdrive`` effect with uniformly sampled gain and colour arguments.

    Args:
        gain_low, gain_high: range of the first (gain) argument.
        colour_low, colour_high: range of the second (colour) argument.
    """

    def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
        gain = Uniform(gain_low, gain_high)
        colour = Uniform(colour_low, colour_high)
        super().__init__("overdrive", gain, colour)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class RandomReverb(Chain):
    """Chain of a sox ``reverb`` effect followed by ``channels 1``.

    Args:
        deterministic: when true, all three reverb arguments are pinned to 50
            (sampled from the degenerate range 50..50); otherwise each is drawn
            uniformly from 0..100.
    """

    def __init__(self, deterministic=False):
        bounds = (50, 50) if deterministic else (0, 100)
        # Three independent generators, one per positional reverb argument.
        reverb_args = [Uniform(*bounds) for _ in range(3)]
        super().__init__(
            SoxEffect("reverb", *reverb_args),
            SoxEffect("channels", 1),
        )
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class Flanger(SoxEffect):
    """Sox ``flanger`` effect invoked without any extra arguments."""

    def __init__(self):
        effect_name = "flanger"
        super().__init__(effect_name)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class Phaser(SoxEffect):
    """Sox ``phaser`` effect invoked without any extra arguments."""

    def __init__(self):
        effect_name = "phaser"
        super().__init__(effect_name)
|
modules/repos_static/resemble_enhance/data/utils.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Callable
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def walk_paths(root, suffix):
    """Recursively yield every non-directory entry under *root* whose suffix equals *suffix*.

    Args:
        root: directory to traverse (anything accepted by ``Path``).
        suffix: exact extension to match, including the leading dot (e.g. ``".wav"``).

    Yields:
        ``Path`` objects, depth-first, in ``Path.iterdir`` order.
    """
    for entry in Path(root).iterdir():
        if not entry.is_dir():
            if entry.suffix == suffix:
                yield entry
            continue
        # Directories are descended into rather than matched.
        yield from walk_paths(entry, suffix)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def rglob_audio_files(path: Path):
    """Return all ``.wav`` files under *path*, followed by all ``.flac`` files.

    The tree is walked once per extension, so results are grouped by extension.
    """
    found = []
    for extension in (".wav", ".flac"):
        found.extend(walk_paths(path, extension))
    return found
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
|
| 20 |
+
"""
|
| 21 |
+
Args:
|
| 22 |
+
fg: (b, t)
|
| 23 |
+
bg: (b, t)
|
| 24 |
+
"""
|
| 25 |
+
assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
|
| 26 |
+
fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
|
| 27 |
+
bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
|
| 28 |
+
|
| 29 |
+
fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
|
| 30 |
+
bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
|
| 31 |
+
|
| 32 |
+
fg = fg / (fg_energy + eps).sqrt()
|
| 33 |
+
bg = bg / (bg_energy + eps).sqrt()
|
| 34 |
+
|
| 35 |
+
if callable(alpha):
|
| 36 |
+
alpha = alpha()
|
| 37 |
+
|
| 38 |
+
assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
|
| 39 |
+
|
| 40 |
+
mx = alpha * fg + (1 - alpha) * bg
|
| 41 |
+
mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
|
| 42 |
+
|
| 43 |
+
return mx
|
modules/repos_static/resemble_enhance/denoiser/__init__.py
ADDED
|
File without changes
|
modules/repos_static/resemble_enhance/denoiser/__main__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
|
| 7 |
+
from .inference import denoise
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.inference_mode()
def main():
    """Command-line entry point: denoise every matching audio file under ``in_dir``.

    Files are found recursively by suffix, denoised with the model loaded from
    ``--run_dir`` and written to ``out_dir`` under the same relative path.
    Only the first channel of each input file is processed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
    parser.add_argument("out_dir", type=Path, help="Output folder")
    parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
    parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
    parser.add_argument("--device", type=str, default="cuda", help="Device")
    args = parser.parse_args()

    device = args.device

    # The default device is "cuda"; fall back to CPU instead of failing on
    # machines without CUDA, mirroring the enhancer command-line tool.
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available but --device is set to cuda, using CPU instead")
        device = "cpu"

    for path in args.in_dir.glob(f"**/*{args.suffix}"):
        print(f"Processing {path} ..")
        dwav, sr = torchaudio.load(path)
        # dwav[0] selects the first channel; hwav[None] below restores the channel axis.
        hwav, sr = denoise(dwav[0], sr, args.run_dir, device)
        out_path = args.out_dir / path.relative_to(args.in_dir)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        torchaudio.save(out_path, hwav[None], sr)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if __name__ == "__main__":
|
| 30 |
+
main()
|
modules/repos_static/resemble_enhance/denoiser/denoiser.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
from ..melspec import MelSpectrogram
|
| 8 |
+
from .hparams import HParams
|
| 9 |
+
from .unet import UNet
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _normalize(x: Tensor) -> Tensor:
|
| 15 |
+
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Denoiser(nn.Module):
    """Spectral-masking denoiser built around a 2-D ``UNet``.

    The waveform is transformed with an STFT into magnitude / cos(phase) /
    sin(phase) channels, the network predicts a magnitude mask and a phase
    residual, and the modified spectrum is inverted back to a waveform.
    """

    @property
    def stft_cfg(self) -> dict:
        """STFT keyword arguments derived from ``hp.hop_size`` (window = FFT size = 4 hops)."""
        hop_size = self.hp.hop_size
        return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)

    @property
    def n_fft(self):
        """FFT size used by the STFT."""
        return self.stft_cfg["n_fft"]

    @property
    def eps(self):
        """Small constant used to keep the magnitude in ``_magphase`` strictly positive."""
        return 1e-7

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        # Three input/output channels: (mag, cos, sin) in, (mask, real, imag) out.
        self.net = UNet(input_dim=3, output_dim=3)
        self.mel_fn = MelSpectrogram(hp)

        # Non-persistent buffer (excluded from the state dict); forward() uses it
        # as the target of ``.to(...)`` so inputs follow the module's device/dtype.
        self.dummy: Tensor
        self.register_buffer("dummy", torch.zeros(1), persistent=False)

    def to_mel(self, x: Tensor, drop_last=True):
        """
        Args:
            x: (b t), wavs
            drop_last: if true, drop the final frame of the mel spectrogram
        Returns:
            o: (b c t), mels
        """
        if drop_last:
            return self.mel_fn(x)[..., :-1]  # (b d t)
        return self.mel_fn(x)

    def _stft(self, x):
        """
        Args:
            x: (b t)
        Returns:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]
        """
        # Remember the caller's dtype/device so the outputs can be restored to them.
        dtype = x.dtype
        device = x.device

        # MPS tensors are processed on the CPU (presumably because the STFT is
        # not usable on that backend); results are moved back below.
        if x.is_mps:
            x = x.cpu()

        window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
        s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True)  # (b f t+1)

        # Drop the last frame; _istft pads one frame back before inversion.
        s = s[..., :-1]  # (b f t)

        mag = s.abs()  # (b f t)

        phi = s.angle()  # (b f t)
        cos = phi.cos()  # (b f t)
        sin = phi.sin()  # (b f t)

        mag = mag.to(dtype=dtype, device=device)
        cos = cos.to(dtype=dtype, device=device)
        sin = sin.to(dtype=dtype, device=device)

        return mag, cos, sin

    def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """
        Args:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]
        Returns:
            x: (b t)
        """
        device = mag.device
        dtype = mag.dtype

        # Same CPU detour for MPS tensors as in _stft.
        if mag.is_mps:
            mag = mag.cpu()
            cos = cos.cpu()
            sin = sin.cpu()

        real = mag * cos  # (b f t)
        imag = mag * sin  # (b f t)

        s = torch.complex(real, imag)  # (b f t)

        # NaNs in the input are only reported here, not repaired.
        if s.isnan().any():
            logger.warning("NaN detected in ISTFT input.")

        # Re-append one frame (a copy of the last one) to undo the drop in _stft.
        s = F.pad(s, (0, 1), "replicate")  # (b f t+1)

        window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
        x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)

        # NaNs in the output are replaced with zeros so downstream code keeps working.
        if x.isnan().any():
            logger.warning("NaN detected in ISTFT output, set to zero.")
            x = torch.where(x.isnan(), torch.zeros_like(x), x)

        x = x.to(dtype=dtype, device=device)

        return x

    def _magphase(self, real, imag):
        """Convert (real, imag) to (magnitude, cos, sin); ``eps`` inside the sqrt avoids division by zero."""
        mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """
        Args:
            mag: (b f t)
            cos: (b f t)
            sin: (b f t)
        Returns (in this order — note that sin comes before cos):
            mag_mask: (b f t) in [0, 1], magnitude mask
            sin_res: (b f t) in [-1, 1], phase residual
            cos_res: (b f t) in [-1, 1], phase residual
        """
        x = torch.stack([mag, cos, sin], dim=1)  # (b 3 f t)
        mag_mask, real, imag = self.net(x).unbind(1)  # (b 3 f t)
        mag_mask = mag_mask.sigmoid()  # (b f t)
        real = real.tanh()  # (b f t)
        imag = imag.tanh()  # (b f t)
        # Only the direction of (real, imag) is kept; its magnitude is discarded.
        _, cos_res, sin_res = self._magphase(real, imag)  # (b f t)
        return mag_mask, sin_res, cos_res

    def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
        """Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
        sep_mag = F.relu(mag * mag_mask)
        # Angle-addition identities: rotate the input phase by the residual angle.
        sep_cos = cos * cos_res - sin * sin_res
        sep_sin = sin * cos_res + cos * sin_res
        return sep_mag, sep_cos, sep_sin

    def forward(self, x: Tensor, y: Tensor | None = None):
        """
        Args:
            x: (b t), a mixed audio
            y: (b t), a fg audio; when given, an L1 loss against the output
                is stored in ``self.losses``
        Returns:
            o: (b t), the denoised audio, padded to the length of ``x``
        """
        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
        # Move the input to this module's device/dtype, then peak-normalise it.
        x = x.to(self.dummy)
        x = _normalize(x)

        if y is not None:
            assert y.dim() == 2, f"Expected (b t), got {y.size()}"
            y = y.to(self.dummy)
            y = _normalize(y)

        mag, cos, sin = self._stft(x)  # each (b f t)
        # The unpacking order matches _predict's return order (sin before cos).
        mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
        sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)

        o = self._istft(sep_mag, sep_cos, sep_sin)

        # Right-pad the output so its length matches the input.
        npad = x.shape[-1] - o.shape[-1]
        o = F.pad(o, (0, npad))

        if y is not None:
            self.losses = dict(l1=F.l1_loss(o, y))

        return o
|
modules/repos_static/resemble_enhance/denoiser/hparams.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from ..hparams import HParams as HParamsBase
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass(frozen=True)
class HParams(HParamsBase):
    """Denoiser hyper-parameters: the shared base ``HParams`` plus two extra fields."""

    # Batch size per GPU (not read in this part of the file; presumably used by training).
    batch_size_per_gpu: int = 128
    # NOTE(review): presumably the probability of applying a distortion to a
    # training sample — confirm against the data pipeline.
    distort_prob: float = 0.5
|
modules/repos_static/resemble_enhance/denoiser/inference.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from functools import cache
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ..denoiser.denoiser import Denoiser
|
| 7 |
+
|
| 8 |
+
from ..inference import inference
|
| 9 |
+
from .hparams import HParams
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@cache
def load_denoiser(run_dir, device):
    """Build a ``Denoiser``, optionally restoring weights from *run_dir*.

    Results are memoised per ``(run_dir, device)`` pair, so repeated calls with
    the same arguments return the same module instance.

    Args:
        run_dir: run folder holding the hyper-parameters and the checkpoint at
            ``ds/G/default/mp_rank_00_model_states.pt``. When ``None``, a
            freshly initialised model with default hyper-parameters is
            returned as-is (no weights loaded, not moved to *device*).
        device: device the restored model is moved to.

    Returns:
        The denoiser; in eval mode and on *device* when *run_dir* is given.
    """
    if run_dir is None:
        return Denoiser(HParams())

    model = Denoiser(HParams.load(run_dir))
    checkpoint = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
    # Load onto the CPU first; the model is moved to the target device afterwards.
    weights = torch.load(checkpoint, map_location="cpu")["module"]
    model.load_state_dict(weights)
    model.eval()
    model.to(device)
    return model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Run the (cached) denoiser from *run_dir* on *dwav* and return the inference result.

    Args:
        dwav: input waveform.
        sr: sample rate of *dwav*.
        run_dir: run folder passed to ``load_denoiser``.
        device: device used for the model and for inference.
    """
    model = load_denoiser(run_dir, device)
    result = inference(model=model, dwav=dwav, sr=sr, device=device)
    return result
|
modules/repos_static/resemble_enhance/denoiser/unet.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PreactResBlock(nn.Sequential):
    """Pre-activation residual block: two (GroupNorm, GELU, 3x3 Conv) stages plus a skip.

    Args:
        dim: number of input and output channels. One normalisation group is
            used per 16 channels, with at least one group so that widths below
            16 are accepted.
    """

    def __init__(self, dim):
        # ``dim // 16`` is 0 for dim < 16, which GroupNorm cannot handle; clamp to 1.
        # For dim >= 16 the group count is unchanged.
        num_groups = max(dim // 16, 1)
        super().__init__(
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the sequential stack applied to ``x``."""
        return x + super().forward(x)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class UNetBlock(nn.Module):
    """One U-Net stage: optional upsample, skip addition, conv, two residual blocks, optional downsample.

    Args:
        input_dim: channels of the incoming feature map.
        output_dim: channels produced by the stage (defaults to ``input_dim``).
        scale_factor: > 1 resizes the input before processing, < 1 resizes the
            output after processing, 1 leaves the resolution untouched.
    """

    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
        super().__init__()
        output_dim = input_dim if output_dim is None else output_dim
        self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
        self.res_block1 = PreactResBlock(output_dim)
        self.res_block2 = PreactResBlock(output_dim)
        # Both resamplers start as the same no-op; at most one is replaced below.
        self.downsample = self.upsample = nn.Identity()
        if scale_factor > 1:
            self.upsample = nn.Upsample(scale_factor=scale_factor)
        elif scale_factor < 1:
            self.downsample = nn.Upsample(scale_factor=scale_factor)

    def forward(self, x, h=None):
        """
        Args:
            x: (b c h w), last output
            h: (b c h w), skip output
        Returns:
            o: (b c h w), output
            s: (b c h w), skip output
        """
        resized = self.upsample(x)
        if h is not None:
            assert resized.shape == h.shape, f"{resized.shape} != {h.shape}"
            resized = resized + h
        features = self.res_block2(self.res_block1(self.pre_conv(resized)))
        # The skip output is the feature map before any downsampling.
        return self.downsample(features), features
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class UNet(nn.Module):
    """2-D U-Net with additive skip connections.

    Each encoder stage doubles the channel count and halves the resolution;
    each decoder stage reverses that and adds the matching encoder skip.

    Args:
        input_dim: channels of the input.
        output_dim: channels of the output.
        hidden_dim: channels after the input projection.
        num_blocks: number of encoder stages (and of decoder stages).
        num_middle_blocks: number of resolution-preserving stages at the bottleneck.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        def width(level):
            # Channel count at a given depth of the network.
            return hidden_dim * 2**level

        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)

        encoder = []
        for level in range(num_blocks):
            encoder.append(UNetBlock(input_dim=width(level), output_dim=width(level + 1), scale_factor=0.5))
        self.encoder_blocks = nn.ModuleList(encoder)

        self.middle_blocks = nn.ModuleList(
            [UNetBlock(input_dim=width(num_blocks)) for _ in range(num_middle_blocks)]
        )

        decoder = []
        for level in reversed(range(num_blocks)):
            decoder.append(UNetBlock(input_dim=width(level + 1), output_dim=width(level), scale_factor=2))
        self.decoder_blocks = nn.ModuleList(decoder)

        self.head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, output_dim, 1),
        )

    @property
    def scale_factor(self):
        """Total downsampling factor of the encoder (2 per encoder stage)."""
        return 2 ** len(self.encoder_blocks)

    def pad_to_fit(self, x):
        """
        Args:
            x: (b c h w), input
        Returns:
            x: (b c h' w'), input padded at the bottom/right so that h' and w'
               are multiples of ``scale_factor``
        """
        factor = self.scale_factor
        pad_h = (-x.shape[2]) % factor
        pad_w = (-x.shape[3]) % factor
        return F.pad(x, (0, pad_w, 0, pad_h))

    def forward(self, x):
        """
        Args:
            x: (b c h w), input
        Returns:
            o: (b c h w), output
        """
        original_shape = x.shape

        x = self.input_proj(self.pad_to_fit(x))

        skips = []
        for stage in self.encoder_blocks:
            x, skip = stage(x)
            skips.append(skip)

        for stage in self.middle_blocks:
            x, _ = stage(x)

        # Decoder stages consume the skips deepest-first.
        for stage, skip in zip(self.decoder_blocks, skips[::-1]):
            x, _ = stage(x, skip)

        x = self.head(x)

        # Crop away the padding added by pad_to_fit.
        return x[..., : original_shape[2], : original_shape[3]]

    def test(self, shape=(3, 512, 256)):
        """Print MACs and parameter counts for an input of *shape* using ``ptflops``."""
        import ptflops

        macs, params = ptflops.get_model_complexity_info(
            self,
            shape,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        print(f"macs: {macs}")
        print(f"params: {params}")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def main():
    """Build a 3-in / 3-out ``UNet`` and print its complexity report."""
    UNet(3, 3).test()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
main()
|
modules/repos_static/resemble_enhance/enhancer/__init__.py
ADDED
|
File without changes
|
modules/repos_static/resemble_enhance/enhancer/__main__.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
import time
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from .inference import denoise, enhance
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@torch.inference_mode()
def main():
    """Command-line entry point: enhance (or only denoise) every matching file under ``in_dir``.

    Inputs are found recursively by suffix, down-mixed to mono by averaging the
    channels, processed, and written to ``out_dir`` under the same relative
    path. With ``--parallel_mode`` the file order is shuffled and files whose
    output already exists are skipped, so several jobs can share one folder.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
    parser.add_argument("out_dir", type=Path, help="Output folder")
    parser.add_argument(
        "--run_dir",
        type=Path,
        default=None,
        help="Path to the enhancer run folder, if None, use the default model",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default=".wav",
        help="Audio file suffix",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation, recommended to use CUDA",
    )
    parser.add_argument(
        "--denoise_only",
        action="store_true",
        help="Only apply denoising without enhancement",
    )
    parser.add_argument(
        "--lambd",
        type=float,
        default=1.0,
        help="Denoise strength for enhancement (0.0 to 1.0)",
    )
    parser.add_argument(
        "--tau",
        type=float,
        default=0.5,
        help="CFM prior temperature (0.0 to 1.0)",
    )
    parser.add_argument(
        "--solver",
        type=str,
        default="midpoint",
        choices=["midpoint", "rk4", "euler"],
        help="Numerical solver to use",
    )
    parser.add_argument(
        "--nfe",
        type=int,
        default=64,
        help="Number of function evaluations",
    )
    parser.add_argument(
        "--parallel_mode",
        action="store_true",
        help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
    )

    args = parser.parse_args()

    device = args.device

    # Degrade gracefully on machines without CUDA instead of failing later.
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available but --device is set to cuda, using CPU instead")
        device = "cpu"

    start_time = time.perf_counter()

    run_dir = args.run_dir

    paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))

    if args.parallel_mode:
        # Shuffling reduces the chance that concurrent jobs pick the same file.
        random.shuffle(paths)

    if len(paths) == 0:
        print(f"No {args.suffix} files found in the following path: {args.in_dir}")
        return

    pbar = tqdm(paths)

    # Count only files actually written, so the summary stays accurate when
    # parallel mode skips outputs that already exist.
    num_processed = 0

    for path in pbar:
        out_path = args.out_dir / path.relative_to(args.in_dir)
        if args.parallel_mode and out_path.exists():
            continue
        pbar.set_description(f"Processing {out_path}")
        dwav, sr = torchaudio.load(path)
        dwav = dwav.mean(0)  # average the channels -> mono
        if args.denoise_only:
            hwav, sr = denoise(
                dwav=dwav,
                sr=sr,
                device=device,
                run_dir=args.run_dir,
            )
        else:
            hwav, sr = enhance(
                dwav=dwav,
                sr=sr,
                device=device,
                nfe=args.nfe,
                solver=args.solver,
                lambd=args.lambd,
                tau=args.tau,
                run_dir=run_dir,
            )
        out_path.parent.mkdir(parents=True, exist_ok=True)
        torchaudio.save(out_path, hwav[None], sr)
        num_processed += 1

    # Cool emoji effect saying the job is done
    elapsed_time = time.perf_counter() - start_time
    print(f"🌟 Enhancement done! {num_processed} files processed in {elapsed_time:.2f}s")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
main()
|
modules/repos_static/resemble_enhance/enhancer/download.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
RUN_NAME = "enhancer_stage2"
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_source_url(relpath):
    """Return the Hugging Face download URL of *relpath* inside the ``RUN_NAME`` folder."""
    url = f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
    return url
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
|
| 16 |
+
if run_dir is None:
|
| 17 |
+
run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
|
| 18 |
+
return Path(run_dir) / relpath
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def download(run_dir: str | Path | None = None):
    """Fetch any missing model files into the run folder and return that folder.

    Files that already exist locally are left untouched.

    Args:
        run_dir: target run folder; ``None`` selects the default location
            chosen by ``get_target_path``.
    """
    wanted = ("hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt")
    for relpath in wanted:
        target = get_target_path(relpath, run_dir=run_dir)
        if not target.exists():
            source = get_source_url(relpath)
            target.parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(source, str(target))
    # An empty relative path resolves to the run folder itself.
    return get_target_path("", run_dir=run_dir)
|
modules/repos_static/resemble_enhance/enhancer/enhancer.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
from torch.distributions import Beta
|
| 8 |
+
|
| 9 |
+
from ..common import Normalizer
|
| 10 |
+
from ..denoiser.inference import load_denoiser
|
| 11 |
+
from ..melspec import MelSpectrogram
|
| 12 |
+
from .hparams import HParams
|
| 13 |
+
from .lcfm import CFM, IRMAE, LCFM
|
| 14 |
+
from .univnet import UnivNet
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _maybe(fn):
|
| 20 |
+
def _fn(*args):
|
| 21 |
+
if args[0] is None:
|
| 22 |
+
return None
|
| 23 |
+
return fn(*args)
|
| 24 |
+
|
| 25 |
+
return _fn
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _normalize_wav(x: Tensor):
|
| 29 |
+
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Enhancer(nn.Module):
|
| 33 |
+
def __init__(self, hp: HParams):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.hp = hp
|
| 36 |
+
|
| 37 |
+
n_mels = self.hp.num_mels
|
| 38 |
+
vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
|
| 39 |
+
latent_dim = self.hp.lcfm_latent_dim
|
| 40 |
+
|
| 41 |
+
self.lcfm = LCFM(
|
| 42 |
+
IRMAE(
|
| 43 |
+
input_dim=n_mels,
|
| 44 |
+
output_dim=vocoder_input_dim,
|
| 45 |
+
latent_dim=latent_dim,
|
| 46 |
+
),
|
| 47 |
+
CFM(
|
| 48 |
+
cond_dim=n_mels,
|
| 49 |
+
output_dim=self.hp.lcfm_latent_dim,
|
| 50 |
+
solver_nfe=self.hp.cfm_solver_nfe,
|
| 51 |
+
solver_method=self.hp.cfm_solver_method,
|
| 52 |
+
time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
|
| 53 |
+
),
|
| 54 |
+
z_scale=self.hp.lcfm_z_scale,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
self.lcfm.set_mode_(self.hp.lcfm_training_mode)
|
| 58 |
+
|
| 59 |
+
self.mel_fn = MelSpectrogram(hp)
|
| 60 |
+
self.vocoder = UnivNet(self.hp, vocoder_input_dim)
|
| 61 |
+
self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
|
| 62 |
+
self.normalizer = Normalizer()
|
| 63 |
+
|
| 64 |
+
self._eval_lambd = 0.0
|
| 65 |
+
|
| 66 |
+
self.dummy: Tensor
|
| 67 |
+
self.register_buffer("dummy", torch.zeros(1))
|
| 68 |
+
|
| 69 |
+
if self.hp.enhancer_stage1_run_dir is not None:
|
| 70 |
+
pretrained_path = (
|
| 71 |
+
self.hp.enhancer_stage1_run_dir
|
| 72 |
+
/ "ds/G/default/mp_rank_00_model_states.pt"
|
| 73 |
+
)
|
| 74 |
+
self._load_pretrained(pretrained_path)
|
| 75 |
+
|
| 76 |
+
logger.info(f"{self.__class__.__name__} summary")
|
| 77 |
+
logger.info(f"{self.summarize()}")
|
| 78 |
+
|
| 79 |
+
def _load_pretrained(self, path):
|
| 80 |
+
# Clone is necessary as otherwise it holds a reference to the original model
|
| 81 |
+
cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
|
| 82 |
+
denoiser_state_dict = {
|
| 83 |
+
k: v.clone() for k, v in self.denoiser.state_dict().items()
|
| 84 |
+
}
|
| 85 |
+
state_dict = torch.load(path, map_location="cpu")["module"]
|
| 86 |
+
self.load_state_dict(state_dict, strict=False)
|
| 87 |
+
self.lcfm.cfm.load_state_dict(cfm_state_dict) # Reset cfm
|
| 88 |
+
self.denoiser.load_state_dict(denoiser_state_dict) # Reset denoiser
|
| 89 |
+
logger.info(f"Loaded pretrained model from {path}")
|
| 90 |
+
|
| 91 |
+
def summarize(self):
|
| 92 |
+
npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
|
| 93 |
+
npa = lambda m: sum(p.numel() for p in m.parameters())
|
| 94 |
+
rows = []
|
| 95 |
+
for name, module in self.named_children():
|
| 96 |
+
rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
|
| 97 |
+
rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
|
| 98 |
+
df = pd.DataFrame(rows)
|
| 99 |
+
return df.to_markdown(index=False)
|
| 100 |
+
|
| 101 |
+
def to_mel(self, x: Tensor, drop_last=True):
|
| 102 |
+
"""
|
| 103 |
+
Args:
|
| 104 |
+
x: (b t), wavs
|
| 105 |
+
Returns:
|
| 106 |
+
o: (b c t), mels
|
| 107 |
+
"""
|
| 108 |
+
if drop_last:
|
| 109 |
+
return self.mel_fn(x)[..., :-1] # (b d t)
|
| 110 |
+
return self.mel_fn(x)
|
| 111 |
+
|
| 112 |
+
def _may_denoise(self, x: Tensor, y: Tensor | None = None):
|
| 113 |
+
if self.hp.lcfm_training_mode == "cfm":
|
| 114 |
+
return self.denoiser(x, y)
|
| 115 |
+
return x
|
| 116 |
+
|
| 117 |
+
def configurate_(self, nfe, solver, lambd, tau):
|
| 118 |
+
"""
|
| 119 |
+
Args:
|
| 120 |
+
nfe: number of function evaluations
|
| 121 |
+
solver: solver method
|
| 122 |
+
lambd: denoiser strength [0, 1]
|
| 123 |
+
tau: prior temperature [0, 1]
|
| 124 |
+
"""
|
| 125 |
+
self.lcfm.cfm.solver.configurate_(nfe, solver)
|
| 126 |
+
self.lcfm.eval_tau_(tau)
|
| 127 |
+
self._eval_lambd = lambd
|
| 128 |
+
|
| 129 |
+
def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
    """Enhance a batch of waveforms.

    Args:
        x: (b t), mix wavs (fg + bg)
        y: (b t), fg clean wavs
        z: (b t), fg distorted wavs
    Returns:
        o: (b t), reconstructed wavs, or None when the LCFM module
            returns None.
    """
    assert x.dim() == 2, f"Expected (b t), got {x.size()}"
    assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"

    cfm_mode = self.hp.lcfm_training_mode == "cfm"
    if cfm_mode:
        self.normalizer.eval()

    x = _normalize_wav(x)
    y = _maybe(_normalize_wav)(y)
    z = _maybe(_normalize_wav)(z)

    x_mel_original = self.normalizer(self.to_mel(x), update=False)  # (b d t)

    def blend_with_denoised(weight):
        # Mix the (detached) denoised mel with the original mel.
        denoised = self.normalizer(
            self.to_mel(self._may_denoise(x, z)), update=False
        ).detach()
        return weight * denoised + (1 - weight) * x_mel_original

    if not cfm_mode:
        x_mel_denoised = x_mel_original
    elif self.training:
        # A random per-sample blending weight is drawn before denoising.
        lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
        x_mel_denoised = blend_with_denoised(lambd[:, None, None])
        self._visualize(x_mel_original, x_mel_denoised)
    else:
        lambd = self._eval_lambd
        if lambd == 0:
            # Zero strength: skip the denoiser altogether.
            x_mel_denoised = x_mel_original
        else:
            x_mel_denoised = blend_with_denoised(lambd)

    y_mel = _maybe(self.normalizer)(_maybe(self.to_mel)(y))  # (b d t)

    lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original)  # (b d t)

    return None if lcfm_decoded is None else self.vocoder(lcfm_decoded, y)
|
modules/repos_static/resemble_enhance/enhancer/hparams.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from ..hparams import HParams as HParamsBase
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
class HParams(HParamsBase):
    """Enhancer hyperparameters, extending the shared ``HParamsBase``.

    The dataclass is frozen, so instances cannot be mutated after creation.
    """

    # CFM solver settings. NOTE(review): presumably forwarded to the CFM
    # module / its solver when the enhancer is built -- confirm at the
    # construction site.
    cfm_solver_method: str = "midpoint"
    cfm_solver_nfe: int = 64
    cfm_time_mapping_divisor: int = 4
    univnet_nc: int = 96

    lcfm_latent_dim: int = 64
    # The enhancer compares this value against "cfm" to decide whether the
    # denoiser is applied.
    lcfm_training_mode: str = "ae"
    lcfm_z_scale: float = 5

    vocoder_extra_dim: int = 32

    gan_training_start_step: int | None = 5_000
    enhancer_stage1_run_dir: Path | None = None

    denoiser_run_dir: Path | None = None
|
modules/repos_static/resemble_enhance/enhancer/inference.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from functools import cache
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ..inference import inference
|
| 8 |
+
from .download import download
|
| 9 |
+
from .hparams import HParams
|
| 10 |
+
from .enhancer import Enhancer
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@cache
def load_enhancer(run_dir: str | Path | None, device):
    """Build an ``Enhancer``, load its checkpoint and move it to ``device``.

    ``run_dir`` is first passed through ``download``; the hyperparameters
    and the ``mp_rank_00_model_states.pt`` checkpoint are then read from
    the directory it returns. Results are memoised per
    ``(run_dir, device)`` by ``functools.cache``, so repeated calls return
    the same model object.

    Returns:
        The enhancer in eval mode on ``device``.
    """
    resolved_dir = download(run_dir)
    model = Enhancer(HParams.load(resolved_dir))

    checkpoint_file = (
        resolved_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
    )
    # Load onto the CPU first; the model is moved to `device` afterwards.
    weights = torch.load(checkpoint_file, map_location="cpu")["module"]
    model.load_state_dict(weights)

    model.eval()
    model.to(device)
    return model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@torch.inference_mode()
def denoise(dwav, sr, device, run_dir=None):
    """Run only the denoiser sub-module of the cached enhancer on ``dwav``.

    The enhancer is obtained from ``load_enhancer(run_dir, device)`` and
    its ``denoiser`` attribute is handed to ``inference`` together with
    the waveform, sample rate and device. The return value is whatever
    ``inference`` returns.
    """
    denoiser = load_enhancer(run_dir, device).denoiser
    return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@torch.inference_mode()
def enhance(
    dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None
):
    """Run the full enhancer on ``dwav`` after validating the settings.

    Args:
        dwav: waveform handed to ``inference``.
        sr: sample rate handed to ``inference``.
        device: device the enhancer is loaded on and inference runs on.
        nfe: number of function evaluations, in (0, 128].
        solver: one of "midpoint", "rk4" or "euler".
        lambd: denoiser strength, in [0, 1].
        tau: prior temperature, in [0, 1].
        run_dir: forwarded to ``load_enhancer``.

    Raises:
        AssertionError: if any setting is outside its allowed range.
    """
    supported_solvers = ("midpoint", "rk4", "euler")

    # All checks run before the (potentially expensive) model load.
    assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
    assert (
        solver in supported_solvers
    ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
    assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
    assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"

    model = load_enhancer(run_dir, device)
    model.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
    return inference(model=model, dwav=dwav, sr=sr, device=device)
|
modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .irmae import IRMAE
|
| 2 |
+
from .lcfm import CFM, LCFM
|
modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Protocol
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import scipy
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
from tqdm import trange
|
| 13 |
+
|
| 14 |
+
from .wn import WN
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class VelocityField(Protocol):
    """Structural type of the callable integrated by the solver steps.

    The callable is invoked with keyword arguments only.
    """

    def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
        # The solver step functions call this as f(t=..., ψt=..., dt=...)
        # and scale the returned value by dt to advance ψt.
        ...
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Solver:
    """Fixed-step ODE integrator (euler / midpoint / rk4) over a warped time grid.

    The budget is expressed as a number of function evaluations (``nfe``);
    the number of integration steps is derived from it according to how
    many evaluations one step of the chosen method costs. While a training
    loop is running, intermediate states can be recorded into a GIF.
    """

    # Velocity-field evaluations consumed by one step of each method.
    _EVALS_PER_STEP = {"euler": 1, "midpoint": 2, "rk4": 4}

    def __init__(
        self,
        method="midpoint",
        nfe=32,
        viz_name="solver",
        viz_every=100,
        mel_fn=None,
        time_mapping_divisor=4,
        verbose=False,
    ):
        """
        Args:
            method: "euler", "midpoint" or "rk4".
            nfe: number of function evaluations.
            viz_name: name used for the visualization output path.
            viz_every: visualize when the global step is a multiple of this.
            mel_fn: optional callable used to plot a mel of waveform states.
            time_mapping_divisor: ``n`` passed to ``exponential_decay_mapping``.
            verbose: show a progress bar while solving.
        """
        self.configurate_(nfe=nfe, method=method)

        self.verbose = verbose
        self.viz_every = viz_every
        self.viz_name = viz_name

        self._camera = None
        self._mel_fn = mel_fn
        self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)

    def configurate_(self, nfe=None, method=None):
        """Update ``nfe`` and/or ``method`` in place; ``None`` keeps the current value.

        If the budget cannot pay for a single step of the requested method
        (fewer evaluations than one step needs), euler is used instead.
        Without this fallback the solver would take zero steps and return
        its starting point unchanged.
        """
        if nfe is None:
            nfe = self.nfe

        if method is None:
            method = self.method

        evals_per_step = self._EVALS_PER_STEP.get(method, 1)
        if 1 <= nfe < evals_per_step:
            logger.warning(f"{nfe} NFE is not supported for {method}, using euler method instead.")
            method = "euler"

        self.nfe = nfe
        self.method = method

    @property
    def time_mapping(self):
        """The time-warping function applied to the integration grid."""
        return self._time_mapping

    @staticmethod
    def exponential_decay_mapping(t, n=4):
        """Warp ``t`` with ``h(t) = (a**t - 1) / (a - 1)``, where ``h(1/n) = 0.5``.

        Args:
            t: time value(s) to map (scalar, array or tensor).
            n: target step
        """
        # Imported explicitly so the call does not depend on `import scipy`
        # having loaded the `optimize` submodule.
        from scipy.optimize import fsolve

        def h(t, a):
            return (a**t - 1) / (a - 1)

        # Solve h(1/n) = 0.5. fsolve returns a length-1 array; index it
        # rather than calling float() on the array, which NumPy >= 1.25
        # deprecates for arrays with ndim > 0.
        root = fsolve(lambda a: h(1 / n, a) - 0.5, x0=0)
        a = float(root[0])

        return h(t, a=a)

    @torch.no_grad()
    def _maybe_camera_snap(self, *, ψt, t):
        """Plot the first item of ``ψt`` and snap a frame if a camera is active."""
        camera = self._camera
        if camera is not None:
            if ψt.shape[1] == 1:
                # Waveform, b 1 t, plot every 100 samples
                plt.subplot(211)
                plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
                if self._mel_fn is not None:
                    plt.subplot(212)
                    mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
                    plt.imshow(mel, origin="lower", interpolation="none")
            elif ψt.shape[1] == 2:
                # Complex
                plt.subplot(121)
                plt.imshow(
                    ψt.detach().cpu().numpy()[0, 0],
                    origin="lower",
                    interpolation="none",
                )
                plt.subplot(122)
                plt.imshow(
                    ψt.detach().cpu().numpy()[0, 1],
                    origin="lower",
                    interpolation="none",
                )
            else:
                # Spectrogram, b c t
                plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
            ax = plt.gca()
            ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
            camera.snap()

    @staticmethod
    def _euler_step(t, ψt, dt, f: "VelocityField"):
        """One explicit Euler step (1 evaluation of ``f``)."""
        return ψt + dt * f(t=t, ψt=ψt, dt=dt)

    @staticmethod
    def _midpoint_step(t, ψt, dt, f: "VelocityField"):
        """One midpoint step (2 evaluations of ``f``)."""
        return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)

    @staticmethod
    def _rk4_step(t, ψt, dt, f: "VelocityField"):
        """One classical Runge-Kutta step (4 evaluations of ``f``)."""
        k1 = f(t=t, ψt=ψt, dt=dt)
        k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
        k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
        k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
        return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

    @property
    def _step(self):
        """Step function for the configured method.

        Raises:
            ValueError: if ``self.method`` is not a known method.
        """
        if self.method == "euler":
            return self._euler_step
        elif self.method == "midpoint":
            return self._midpoint_step
        elif self.method == "rk4":
            return self._rk4_step
        else:
            raise ValueError(f"Unknown method: {self.method}")

    def get_running_train_loop(self):
        """Return the running train loop, or None if its module cannot be imported."""
        try:
            # Lazy import
            from ...utils.train_loop import TrainLoop

            return TrainLoop.get_running_loop()
        except ImportError:
            return None

    @property
    def visualizing(self):
        """Whether the current solve should be recorded.

        True when a train loop is running, its global step is a multiple of
        ``viz_every`` and the GIF for this step does not exist yet. Returns
        None when no train loop is available.
        """
        loop = self.get_running_train_loop()
        if loop is None:
            return None
        out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
        return loop.global_step % self.viz_every == 0 and not out_path.exists()

    def _reset_camera(self):
        """Start a new recording; silently does nothing if that is not possible."""
        try:
            from celluloid import Camera

            self._camera = Camera(plt.figure())
        except Exception:
            # Visualization is best-effort (celluloid is optional), but do not
            # swallow KeyboardInterrupt / SystemExit like a bare except would.
            logger.debug("Could not set up the visualization camera", exc_info=True)

    def _maybe_dump_camera(self):
        """Write the recorded frames to a GIF and drop the camera."""
        camera = self._camera
        loop = self.get_running_train_loop()
        if camera is not None and loop is not None:
            animation = camera.animate()
            out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
            out_path.parent.mkdir(exist_ok=True, parents=True)
            animation.save(out_path, writer="pillow", fps=4)
            plt.close()
            self._camera = None

    @property
    def n_steps(self):
        """Number of integration steps the ``nfe`` budget pays for.

        Raises:
            ValueError: if ``self.method`` is not a known method.
        """
        evals_per_step = self._EVALS_PER_STEP.get(self.method)
        if evals_per_step is None:
            raise ValueError(f"Unknown method: {self.method}")
        if evals_per_step == 1:
            return self.nfe
        return self.nfe // evals_per_step

    def solve(self, f: "VelocityField", ψ0: Tensor, t0=0.0, t1=1.0):
        """Integrate ``f`` from ``t0`` to ``t1`` starting at ``ψ0``.

        The uniform grid between ``t0`` and ``t1`` is passed through the
        time mapping before stepping.

        Returns:
            The state after the last step.
        """
        ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))

        if self.visualizing:
            self._reset_camera()

        if self.verbose:
            steps = trange(self.n_steps, desc="CFM inference")
        else:
            steps = range(self.n_steps)

        ψt = ψ0

        for i in steps:
            dt = ts[i + 1] - ts[i]
            t = ts[i]
            self._maybe_camera_snap(ψt=ψt, t=t)
            ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)

        self._maybe_camera_snap(ψt=ψt, t=ts[-1])

        ψ1 = ψt
        del ψt

        self._maybe_dump_camera()

        return ψ1

    def __call__(self, f: "VelocityField", ψ0: Tensor, t0=0.0, t1=1.0):
        """Alias for :meth:`solve`."""
        return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class SinusodialTimeEmbedding(nn.Module):
    """Embed time values as sines and cosines at log-spaced frequencies.

    The frequencies are ``10**p`` for ``d_embed // 2`` values of ``p``
    evenly spaced between 0 and 4. The output concatenates all sines
    followed by all cosines along a new last axis of size ``d_embed``.
    """

    def __init__(self, d_embed):
        super().__init__()
        self.d_embed = d_embed
        assert d_embed % 2 == 0

    def forward(self, t):
        t = t.unsqueeze(-1)  # ... 1
        exponents = torch.linspace(0, 4, self.d_embed // 2).to(t)
        # Left-pad with singleton axes so the exponents line up with t.
        for _ in range(t.dim() - exponents.dim()):
            exponents = exponents.unsqueeze(0)  # ... d/2
        phase = t * 10**exponents
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@dataclass(eq=False)
class CFM(nn.Module):
    """Conditional flow matching module.

    ψ0 stands for the gaussian noise, and ψ1 is the data point.

    Following the CFM convention, generation (the reverse process) runs
    from t=0 to t=1 and the forward process runs from t=1 to t=0.

    Calling the module with a target computes the training losses (stored
    in ``self.losses``); calling it without one samples with the solver.
    """

    cond_dim: int
    output_dim: int
    time_emb_dim: int = 128
    viz_name: str = "cfm"
    solver_nfe: int = 32
    solver_method: str = "midpoint"
    time_mapping_divisor: int = 4

    def __post_init__(self):
        # The dataclass-generated __init__ has set the fields above; the
        # nn.Module machinery is initialised here before sub-modules are added.
        super().__init__()
        self.solver = Solver(
            method=self.solver_method,
            nfe=self.solver_nfe,
            viz_name=self.viz_name,
            viz_every=1,
            time_mapping_divisor=self.time_mapping_divisor,
        )
        self.emb = SinusodialTimeEmbedding(self.time_emb_dim)
        self.net = WN(
            input_dim=self.output_dim,
            output_dim=self.output_dim,
            local_dim=self.cond_dim,
            global_dim=self.time_emb_dim,
        )

    def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
        """
        Perturb ψ1 to ψt. Not implemented in this class.
        """
        raise NotImplementedError

    def _sample_ψ0(self, x: Tensor):
        """Draw the gaussian starting point.

        Args:
            x: (b c t), which implies the shape of ψ0
        """
        noise_shape = list(x.shape)
        noise_shape[1] = self.output_dim

        generator = None
        if not self.training:
            # deterministic sampling during eval
            generator = torch.Generator(device=x.device)
            generator.manual_seed(0)

        return torch.randn(noise_shape, device=x.device, dtype=x.dtype, generator=generator)

    @property
    def sigma(self):
        """Standard deviation of the noise added around the interpolation path."""
        return 1e-4

    def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
        """
        Eq (22): noisy linear interpolation between ψ0 and ψ1 at time t.
        """
        # Right-pad t with singleton axes so it broadcasts against ψ1.
        for _ in range(ψ1.dim() - t.dim()):
            t = t.unsqueeze(-1)
        mean = t * ψ1 + (1 - t) * ψ0
        return mean + torch.randn_like(mean) * self.sigma

    def _to_u(self, *, ψ1, ψ0: Tensor):
        """
        Eq (21): the regression target for the velocity.
        """
        return ψ1 - ψ0

    def _to_v(self, *, ψt, x, t: float | Tensor):
        """Predict the velocity with the network.

        Args:
            ψt: (b c t)
            x: (b c t)
            t: (b), or a Python number that is broadcast over the batch
        Returns:
            v: (b c t)
        """
        if isinstance(t, (float, int)):
            t = torch.full(ψt.shape[:1], t).to(ψt)
        time_emb = self.emb(t.clamp(0, 1))  # (b d)
        return self.net(ψt, l=x, g=time_emb)

    def compute_losses(self, x, y, ψ0) -> dict:
        """Compute the flow-matching loss at random times.

        Args:
            x: (b c t)
            y: (b c t)
            ψ0: starting point, sampled when None
        Returns:
            losses: dict
        """
        t = torch.rand(len(x), device=x.device, dtype=x.dtype)
        t = self.solver.time_mapping(t)

        if ψ0 is None:
            ψ0 = self._sample_ψ0(x)

        ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)

        predicted = self._to_v(ψt=ψt, t=t, x=x)
        target = self._to_u(ψ1=y, ψ0=ψ0)

        return {"l1": F.l1_loss(predicted, target)}

    @torch.inference_mode()
    def sample(self, x, ψ0=None, t0=0.0):
        """Integrate the learned velocity field with the solver.

        Args:
            x: (b c t)
        Returns:
            y: (b ... t)
        """
        if ψ0 is None:
            ψ0 = self._sample_ψ0(x)

        def velocity(t, ψt, dt):
            # dt is part of the solver's calling convention but unused here.
            return self._to_v(ψt=ψt, t=t, x=x)

        return self.solver(f=velocity, ψ0=ψ0, t0=t0)

    def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
        """Sample when ``y`` is None; otherwise store losses and return ``y``."""
        if y is None:
            return self.sample(x, ψ0=ψ0, t0=t0)
        self.losses = self.compute_losses(x, y, ψ0=ψ0)
        return y
|
modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 8 |
+
|
| 9 |
+
from ...common import Normalizer
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class IRMAEOutput:
    """Result bundle returned by the autoencoder's forward pass."""

    latent: Tensor  # latent vector
    decoded: Tensor | None  # decoder output, include extra dim
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ResBlock(nn.Sequential):
|
| 21 |
+
def __init__(self, channels, dilations=[1, 2, 4, 8]):
|
| 22 |
+
wn = weight_norm
|
| 23 |
+
super().__init__(
|
| 24 |
+
nn.GroupNorm(32, channels),
|
| 25 |
+
nn.GELU(),
|
| 26 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
|
| 27 |
+
nn.GroupNorm(32, channels),
|
| 28 |
+
nn.GELU(),
|
| 29 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
|
| 30 |
+
nn.GroupNorm(32, channels),
|
| 31 |
+
nn.GELU(),
|
| 32 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
|
| 33 |
+
nn.GroupNorm(32, channels),
|
| 34 |
+
nn.GELU(),
|
| 35 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def forward(self, x: Tensor):
|
| 39 |
+
return x + super().forward(x)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class IRMAE(nn.Module):
    """Convolutional autoencoder with a chain of linear 1x1 convolutions before the latent.

    The encoder ends in ``num_irms`` bias-free 1x1 convolutions followed by
    a tanh; the decoder maps the latent to ``output_dim`` channels and a
    small head maps the decoder output back to ``input_dim`` channels for
    the reconstruction loss.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        hidden_dim=1024,
        num_irms=4,
    ):
        """
        Args:
            input_dim: input dimension
            output_dim: output dimension
            latent_dim: latent dimension
            hidden_dim: hidden layer dimension
            num_irms: number of implicit rank minimization matrices
        """
        # Initialise the Module machinery before setting any attribute.
        super().__init__()
        self.input_dim = input_dim

        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
            *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
            nn.Tanh(),
        )

        self.decoder = nn.Sequential(
            nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            nn.Conv1d(hidden_dim, output_dim, 1),
        )

        self.head = nn.Sequential(
            nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
            nn.GELU(),
            nn.Conv1d(hidden_dim, input_dim, 1),
        )

        self.estimator = Normalizer()

    def encode(self, x):
        """Encode ``x`` and record summary statistics of the latent in ``self.stats``.

        Args:
            x: (b c t) tensor
        Returns:
            z: the encoder output
        """
        z = self.encoder(x)  # (b c t)
        _ = self.estimator(z)  # Estimate the global mean and std of z
        self.stats = {}
        self.stats["z_mean"] = z.mean().item()
        self.stats["z_std"] = z.std().item()
        self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
        self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
        self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
        return z

    def decode(self, z):
        """
        Args:
            z: (b c t) tensor
        Returns:
            the decoder output
        """
        return self.decoder(z)

    def forward(self, x, skip_decoding=False):
        """Encode ``x`` and, unless skipped, decode it and store the MSE loss.

        Args:
            x: (b c t) tensor
            skip_decoding: if True, skip the decoding step; ``decoded`` is
                then None and ``self.losses`` is not updated.
        Returns:
            IRMAEOutput with the latent and the decoder output.
        """
        z = self.encode(x)  # q(z|x)

        if skip_decoding:
            # This speeds up the training in cfm only mode
            decoded = None
        else:
            decoded = self.decode(z)  # p(x|z)
            predicted = self.head(decoded)
            self.losses = dict(mse=F.mse_loss(predicted, x))

        return IRMAEOutput(latent=z, decoded=decoded)
|
modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
from .cfm import CFM
|
| 10 |
+
from .irmae import IRMAE, IRMAEOutput
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def freeze_(module):
    """Disable gradients for every parameter of ``module``, in place."""
    for param in module.parameters():
        param.requires_grad_(False)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LCFM(nn.Module):
    """Latent conditional flow matching: a CFM operating in an autoencoder's latent space.

    The module is trained in one of two modes: in AE mode the CFM is
    frozen, in CFM mode the autoencoder is frozen.
    """

    class Mode(Enum):
        AE = "ae"
        CFM = "cfm"

    def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
        """
        Args:
            ae: the autoencoder.
            cfm: the flow-matching module working on (scaled) latents.
            z_scale: factor applied to latents before they are given to
                the CFM, and divided out of the CFM's samples.
        """
        super().__init__()
        self.ae = ae
        self.cfm = cfm
        self.z_scale = z_scale
        self._mode = None
        self._eval_tau = 0.5

    @property
    def mode(self):
        """The current training mode, or None if ``set_mode_`` was never called."""
        return self._mode

    def set_mode_(self, mode):
        """Select the training mode and freeze the sub-module that is not trained.

        Args:
            mode: a ``Mode`` member or its string value ("ae" / "cfm").
        Raises:
            ValueError: if ``mode`` is not a valid mode.
        """
        mode = self.Mode(mode)
        self._mode = mode

        # Compare against the enum class; reaching members through another
        # member (mode.AE) emits a DeprecationWarning on Python 3.11.
        if mode is self.Mode.AE:
            freeze_(self.cfm)
            logger.info("Freeze cfm")
        elif mode is self.Mode.CFM:
            freeze_(self.ae)
            logger.info("Freeze ae (encoder and decoder)")
        else:
            raise ValueError(f"Unknown training mode: {mode}")

    def get_running_train_loop(self):
        """Return the running train loop, or None if its module cannot be imported."""
        try:
            # Lazy import
            from ...utils.train_loop import TrainLoop

            return TrainLoop.get_running_loop()
        except ImportError:
            return None

    @property
    def global_step(self):
        """Global step of the running train loop, or None without one."""
        loop = self.get_running_train_loop()
        if loop is None:
            return None
        return loop.global_step

    @torch.no_grad()
    def _visualize(self, x, y, y_):
        """Save a 2x2 figure: target, posterior decode, CFM-sampled decode, random-latent decode."""
        loop = self.get_running_train_loop()
        if loop is None:
            return

        plt.subplot(221)
        plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("GT")

        plt.subplot(222)
        y_ = y_[:, : y.shape[1]]
        plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("Posterior")

        plt.subplot(223)
        z_ = self.cfm(x)
        y__ = self.ae.decode(z_)
        y__ = y__[:, : y.shape[1]]
        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("C-Prior")
        del y__

        plt.subplot(224)
        z_ = torch.randn_like(z_)
        y__ = self.ae.decode(z_)
        y__ = y__[:, : y.shape[1]]
        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("Prior")
        del z_, y__

        path = loop.make_current_step_viz_path("recon", ".png")
        path.parent.mkdir(exist_ok=True, parents=True)
        plt.tight_layout()
        plt.savefig(path, dpi=500)
        plt.close()

    def _scale(self, z: Tensor):
        return z * self.z_scale

    def _unscale(self, z: Tensor):
        return z / self.z_scale

    def eval_tau_(self, tau):
        """Set the noise/starting-point mixing weight used outside training."""
        self._eval_tau = tau

    def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
        """
        Args:
            x: (b d t), condition mel
            y: (b d t), target mel
            ψ0: (b d t), starting mel
        Returns:
            The autoencoder's decoder output, or None when decoding was
            skipped (training the CFM with a target).
        """
        if self.mode == self.Mode.CFM:
            self.ae.eval()  # Always set to eval when training cfm

        if ψ0 is not None:
            ψ0 = self._scale(self.ae.encode(ψ0))
            if self.training:
                tau = torch.rand_like(ψ0[:, :1, :1])
            else:
                tau = self._eval_tau
            # Mix gaussian noise with the encoded starting point.
            ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0

        if y is None:
            if self.mode == self.Mode.AE:
                with torch.no_grad():
                    training = self.ae.training
                    self.ae.eval()
                    try:
                        z = self.ae.encode(x)
                    finally:
                        # Restore the previous train/eval state even if
                        # encoding fails.
                        self.ae.train(training)
            else:
                z = self._unscale(self.cfm(x, ψ0=ψ0))

            h = self.ae.decode(z)
        else:
            ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)

            if self.mode == self.Mode.CFM:
                _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)

            h = ae_output.decoded

        # The figure needs the target mel, so it is skipped when y is absent
        # (subscripting None would raise). The step is resolved only once.
        if h is not None and y is not None:
            step = self.global_step
            if step is not None and step % 100 == 0:
                self._visualize(x[:1], y[:1], h[:1])

        return h
|
modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.jit.script
def _fused_tanh_sigmoid(h):
    """Apply a gated tanh/sigmoid activation.

    The input is split into two equal chunks along dim 1; the result is
    ``tanh(first chunk) * sigmoid(second chunk)``, so the output has half as
    many channels as the input.
    """
    filter_half, gate_half = h.chunk(2, dim=1)
    return filter_half.tanh() * gate_half.sigmoid()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class WNLayer(nn.Module):
    """
    A single gated residual layer of a DiffWave-like WN stack.

    The layer runs a dilated convolution over the hidden signal, optionally
    adds projected local/global conditioning, applies the gated tanh/sigmoid
    activation and finally emits a residual output and a skip output.
    """

    def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
        super().__init__()

        # Width of the pre-activation tensor: one half feeds tanh, the other sigmoid.
        gated_dim = 2 * hidden_dim

        # Conditioning projections only exist when the matching dim is given.
        if global_dim is not None:
            self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
        if local_dim is not None:
            self.lconv = nn.Conv1d(local_dim, gated_dim, 1)

        self.dconv = nn.Conv1d(hidden_dim, gated_dim, kernel_size, dilation=dilation, padding="same")
        self.out = nn.Conv1d(hidden_dim, gated_dim, kernel_size=1)

    def forward(self, z, l, g):
        """
        Args:
            z: hidden input; added back to the residual output.
            l: local condition projected by ``lconv``, or None.
            g: global condition projected by ``gconv``, or None; a 2-D tensor
                gets a trailing singleton axis so it broadcasts over the last dim.
        Returns:
            Tuple ``(o, s)``: the residual output scaled by 1/sqrt(2) and the
            skip output.
        """
        residual = z

        if g is not None:
            cond = g.unsqueeze(-1) if g.dim() == 2 else g
            z = z + self.gconv(cond)

        gated = self.dconv(z)
        if l is not None:
            gated = gated + self.lconv(l)
        gated = _fused_tanh_sigmoid(gated)

        # First half continues down the stack, second half is the skip path.
        next_z, skip = self.out(gated).chunk(2, dim=1)

        return (next_z + residual) / math.sqrt(2), skip
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class WN(nn.Module):
    """
    Stack of ``WNLayer`` blocks with cyclic dilations and summed skip outputs.

    A 1x1 convolution lifts the input to ``hidden_dim``, every layer
    contributes a skip tensor, and a final 1x1 convolution maps the
    normalised skip sum to ``output_dim``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        local_dim=None,
        global_dim=None,
        n_layers=30,
        kernel_size=3,
        dilation_cycle=5,
        hidden_dim=512,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        assert hidden_dim % 2 == 0

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.local_dim = local_dim
        self.global_dim = global_dim

        self.start = nn.Conv1d(input_dim, hidden_dim, 1)
        if local_dim is not None:
            self.local_norm = nn.InstanceNorm1d(local_dim)

        # Dilation doubles per layer and wraps around every `dilation_cycle` layers.
        stack = []
        for index in range(n_layers):
            stack.append(
                WNLayer(
                    hidden_dim=hidden_dim,
                    local_dim=local_dim,
                    global_dim=global_dim,
                    kernel_size=kernel_size,
                    dilation=2 ** (index % dilation_cycle),
                )
            )
        self.layers = nn.ModuleList(stack)

        self.end = nn.Conv1d(hidden_dim, output_dim, 1)

    def forward(self, z, l=None, g=None):
        """
        Args:
            z: input (b c t)
            l: local condition (b c t)
            g: global condition (b d)
        Returns:
            Output of the final 1x1 convolution applied to the scaled skip sum.
        """
        hidden = self.start(z)

        if l is not None:
            l = self.local_norm(l)

        # Collect one skip tensor per layer.
        skips = []
        for layer in self.layers:
            hidden, skip = layer(hidden, l, g)
            skips.append(skip)

        # Sum the skips and divide by sqrt(number of layers).
        skip_total = torch.stack(skips, dim=0).sum(dim=0)
        skip_total = skip_total / math.sqrt(len(self.layers))

        return self.end(skip_total)

    def summarize(self, length=100):
        """Print a ptflops complexity report for an input of ``length`` frames."""
        from ptflops import get_model_complexity_info

        sample = torch.randn(1, self.input_dim, length)

        macs, params = get_model_complexity_info(
            self,
            (self.input_dim, length),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        print(f"Input shape: {sample.shape}")
        print(f"Computational complexity: {macs}")
        print(f"Number of parameters: {params}")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
    # Script entry point: build an unconditioned WN with default depth/width
    # and print its complexity summary.
    WN(input_dim=64, output_dim=64).summarize()
|
modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .univnet import UnivNet
|
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
from .filter import *
|
| 5 |
+
from .resample import *
|
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
if hasattr(torch, "sinc"):
    # Use the native implementation whenever this torch build exposes one.
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Normalised sinc, i.e. sin(pi * x) / (pi * x), with sinc(0) defined as 1.
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        one = torch.tensor(1., device=x.device, dtype=x.dtype)
        ratio = torch.sin(math.pi * x) / math.pi / x
        # `ratio` is NaN at x == 0; torch.where substitutes the limit value there.
        return torch.where(x == 0, one, ratio)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
| 26 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
| 27 |
+
# LICENSE is in incl_licenses directory.
|
| 28 |
+
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: cutoff value fed to the sinc as ``2 * cutoff * time``
            (presumably a fraction of the sampling rate -- confirm with callers).
            ``0`` yields an all-zero kernel.
        half_width: transition half width; only used to derive the Kaiser
            window ``beta``.
        kernel_size: number of taps; both even and odd sizes are supported.

    Returns:
        Floating-point tensor of shape ``[1, 1, kernel_size]``. For a non-zero
        cutoff the taps are normalised to sum to 1.
    """
    is_even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    if cutoff == 0:
        # Degenerate filter: take dtype/shape from the float window so the
        # result is a usable convolution kernel for odd kernel sizes as well
        # (the integer sample grid would otherwise give an integer tensor).
        taps = torch.zeros_like(window)
    else:
        # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
        if is_even:
            # Even kernels sample the sinc at half-integer offsets.
            time = torch.arange(-half_size, half_size) + 0.5
        else:
            time = torch.arange(kernel_size) - half_size
        taps = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        taps /= taps.sum()

    # Reshape on every path so the function always returns a [1, 1, K] kernel.
    return taps.view(1, 1, kernel_size)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LowPassFilter1d(nn.Module):
    """Depthwise 1-D low-pass filter built from a Kaiser-windowed sinc kernel.

    The same kernel is applied to every channel independently, with optional
    padding of the last dimension and an optional stride.
    """

    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        is_even = kernel_size % 2 == 0
        half = kernel_size // 2

        self.kernel_size = kernel_size
        self.even = is_even
        # Even kernels pad one sample less on the left than on the right.
        self.pad_left = half - int(is_even)
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    # input [B, C, T]
    def forward(self, x):
        """Filter ``x`` of shape [B, C, T] along its last dimension."""
        _batch, channels, _time = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right),
                      mode=self.padding_mode)

        # One copy of the kernel per channel; groups=channels keeps channels separate.
        kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=channels)
|
modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from .filter import LowPassFilter1d
|
| 7 |
+
from .filter import kaiser_sinc_filter1d
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class UpSample1d(nn.Module):
    """Upsample the last dimension by ``ratio`` using a transposed convolution
    with a Kaiser-windowed sinc kernel."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        if kernel_size is None:
            # Default kernel length derived from the ratio (always even).
            kernel_size = int(6 * ratio // 2) * 2

        self.ratio = ratio
        self.kernel_size = kernel_size
        self.stride = ratio
        self.pad = kernel_size // ratio - 1

        # Number of samples cropped from each side after the transposed conv.
        overhang = kernel_size - ratio
        self.pad_left = self.pad * ratio + overhang // 2
        self.pad_right = self.pad * ratio + (overhang + 1) // 2

        self.register_buffer(
            "filter",
            kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                 half_width=0.6 / ratio,
                                 kernel_size=kernel_size),
        )

    # x: [B, C, T]
    def forward(self, x):
        """Upsample ``x`` of shape [B, C, T] along its last dimension."""
        _batch, channels, _time = x.shape

        padded = F.pad(x, (self.pad, self.pad), mode='replicate')
        kernel = self.filter.expand(channels, -1, -1)
        # The result of the depthwise transposed conv is multiplied by `ratio`.
        upsampled = self.ratio * F.conv_transpose1d(
            padded, kernel, stride=self.stride, groups=channels)

        return upsampled[..., self.pad_left:-self.pad_right]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class DownSample1d(nn.Module):
    """Downsample by ``ratio`` with a strided ``LowPassFilter1d``."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        if kernel_size is None:
            # Default kernel length derived from the ratio (always even).
            kernel_size = int(6 * ratio // 2) * 2

        self.ratio = ratio
        self.kernel_size = kernel_size
        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       stride=ratio,
                                       kernel_size=kernel_size)

    def forward(self, x):
        """Return the low-pass filtered, strided version of ``x``."""
        return self.lowpass(x)
|
modules/repos_static/resemble_enhance/enhancer/univnet/amp.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/NVIDIA/BigVGAN
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 9 |
+
|
| 10 |
+
from .alias_free_torch import DownSample1d, UpSample1d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SnakeBeta(nn.Module):
    """
    Snake activation with separate per-channel frequency and magnitude parameters.

    Computes ``x + (1 / beta) * sin(alpha * x) ** 2`` where ``alpha`` and
    ``beta`` are stored in log space and clamped to ``clamp`` after
    exponentiation.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - log_alpha - trainable log of the parameter that controls frequency
        - log_beta - trainable log of the parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> act = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = act(x)
    """

    def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
        """
        Initialization.
        INPUT:
            - in_features: number of channels (size of dim 1 of the input)
            - alpha: initial value used for BOTH the frequency and the
              magnitude parameter (there is no separate beta argument)
            - clamp: (min, max) bounds applied to alpha and beta in forward
            With the default alpha=1.0 both parameters start at 1;
            higher alpha = higher frequency, higher beta = smaller periodic term.
            Both parameters are trained along with the rest of the model.
        """
        super().__init__()
        self.in_features = in_features
        initial_log = math.log(alpha)
        self.log_alpha = nn.Parameter(torch.zeros(in_features) + initial_log)
        self.log_beta = nn.Parameter(torch.zeros(in_features) + initial_log)
        self.clamp = clamp

    def _bounded(self, log_value):
        """Exponentiate, clamp and reshape a per-channel parameter to (1, C, 1)."""
        return log_value.exp().clamp(*self.clamp)[None, :, None]

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        alpha = self._bounded(self.log_alpha)
        beta = self._bounded(self.log_beta)

        return x + (1.0 / beta) * (x * alpha).sin().pow(2)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class UpActDown(nn.Module):
    """Apply an activation between an ``UpSample1d`` and a ``DownSample1d`` stage."""

    def __init__(
        self,
        act,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        """
        Args:
            act: activation module applied to the upsampled signal.
            up_ratio: upsampling factor.
            down_ratio: downsampling factor.
            up_kernel_size: kernel size passed to ``UpSample1d``.
            down_kernel_size: kernel size passed to ``DownSample1d``.
        """
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = act
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """Upsample, activate, then downsample ``x`` of shape [B, C, T]."""
        return self.downsample(self.act(self.upsample(x)))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class AMPBlock(nn.Sequential):
    """Residual block made of one conv / UpActDown(SnakeBeta) / conv stage per dilation."""

    def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
        stages = [self._make_layer(channels, kernel_size, rate) for rate in dilations]
        super().__init__(*stages)

    def _make_layer(self, channels, kernel_size, dilation):
        """Build one stage: dilated conv, up/act/down activation, undilated conv."""
        dilated_conv = weight_norm(
            nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")
        )
        activation = UpActDown(act=SnakeBeta(channels))
        plain_conv = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same"))
        return nn.Sequential(dilated_conv, activation, plain_conv)

    def forward(self, x):
        """Return ``x`` plus the output of the sequential stages."""
        transformed = super().forward(x)
        return x + transformed
|
modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 7 |
+
|
| 8 |
+
from ..hparams import HParams
|
| 9 |
+
from .mrstft import get_stft_cfgs
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PeriodNetwork(nn.Module):
|
| 15 |
+
def __init__(self, period):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.period = period
|
| 18 |
+
wn = weight_norm
|
| 19 |
+
self.convs = nn.ModuleList(
|
| 20 |
+
[
|
| 21 |
+
wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
|
| 22 |
+
wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
|
| 23 |
+
wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
|
| 24 |
+
wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
|
| 25 |
+
wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
|
| 26 |
+
]
|
| 27 |
+
)
|
| 28 |
+
self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
"""
|
| 32 |
+
Args:
|
| 33 |
+
x: [B, 1, T]
|
| 34 |
+
"""
|
| 35 |
+
assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
|
| 36 |
+
|
| 37 |
+
# 1d to 2d
|
| 38 |
+
b, c, t = x.shape
|
| 39 |
+
if t % self.period != 0: # pad first
|
| 40 |
+
n_pad = self.period - (t % self.period)
|
| 41 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
| 42 |
+
t = t + n_pad
|
| 43 |
+
x = x.view(b, c, t // self.period, self.period)
|
| 44 |
+
|
| 45 |
+
for l in self.convs:
|
| 46 |
+
x = l(x)
|
| 47 |
+
x = F.leaky_relu(x, 0.2)
|
| 48 |
+
x = self.conv_post(x)
|
| 49 |
+
x = torch.flatten(x, 1, -1)
|
| 50 |
+
|
| 51 |
+
return x
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SpecNetwork(nn.Module):
|
| 55 |
+
def __init__(self, stft_cfg: dict):
|
| 56 |
+
super().__init__()
|
| 57 |
+
wn = weight_norm
|
| 58 |
+
self.stft_cfg = stft_cfg
|
| 59 |
+
self.convs = nn.ModuleList(
|
| 60 |
+
[
|
| 61 |
+
wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
| 62 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 63 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 64 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 65 |
+
wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
"""
|
| 72 |
+
Args:
|
| 73 |
+
x: [B, 1, T]
|
| 74 |
+
"""
|
| 75 |
+
x = self.spectrogram(x)
|
| 76 |
+
x = x.unsqueeze(1)
|
| 77 |
+
for l in self.convs:
|
| 78 |
+
x = l(x)
|
| 79 |
+
x = F.leaky_relu(x, 0.2)
|
| 80 |
+
x = self.conv_post(x)
|
| 81 |
+
x = x.flatten(1, -1)
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
def spectrogram(self, x):
|
| 85 |
+
"""
|
| 86 |
+
Args:
|
| 87 |
+
x: [B, 1, T]
|
| 88 |
+
"""
|
| 89 |
+
x = x.squeeze(1)
|
| 90 |
+
dtype = x.dtype
|
| 91 |
+
stft_cfg = dict(self.stft_cfg)
|
| 92 |
+
x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
|
| 93 |
+
mag = x.norm(p=2, dim=-1) # [B, F, TT]
|
| 94 |
+
mag = mag.to(dtype) # [B, F, TT]
|
| 95 |
+
return mag
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class MD(nn.ModuleList):
    """Base class for a list of sub-discriminators sharing one adversarial loss.

    Subclasses implement ``_create_network`` to turn each configuration item
    into a network. ``loss_type_`` must be called before ``forward``.
    """

    def __init__(self, l: list):
        networks = [self._create_network(item) for item in l]
        super().__init__(networks)
        self._loss_type = None

    def loss_type_(self, loss_type):
        """Select the loss ("hinge" or "wgan") used by subsequent forward calls."""
        self._loss_type = loss_type

    def _create_network(self, _):
        """Build the sub-discriminator for one configuration item."""
        raise NotImplementedError

    def _forward_each(self, d, x, y):
        """Return the scalar loss of sub-discriminator ``d`` on ``x`` for target ``y``.

        ``y`` is 0 when the score should be pushed down and 1 when it should
        be pushed up. Raises ValueError for any other ``y`` or an unknown
        loss type.
        """
        assert self._loss_type is not None, "loss_type is not set."
        loss_type = self._loss_type

        if loss_type == "hinge":
            if y == 0:
                # d(x) should be small -> -1
                margin = 1 + d(x)
            elif y == 1:
                # d(x) should be large -> 1
                margin = 1 - d(x)
            else:
                raise ValueError(f"Invalid y: {y}")
            return F.relu(margin).mean()

        if loss_type == "wgan":
            if y == 0:
                return d(x).mean()
            if y == 1:
                return -d(x).mean()
            raise ValueError(f"Invalid y: {y}")

        raise ValueError(f"Invalid loss_type: {loss_type}")

    def forward(self, x, y) -> Tensor:
        """Average the per-network losses over all sub-discriminators."""
        per_network = torch.stack([self._forward_each(net, x, y) for net in self])
        return per_network.mean()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class MPD(MD):
    """Multi-period discriminator: one ``PeriodNetwork`` per period."""

    def __init__(self):
        periods = [2, 3, 7, 13, 17]
        super().__init__(periods)

    def _create_network(self, period):
        """Build the sub-discriminator for a single period."""
        return PeriodNetwork(period)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class MRD(MD):
    """Multi-resolution discriminator: one ``SpecNetwork`` per STFT configuration."""

    def __init__(self, stft_cfgs):
        super().__init__(l=stft_cfgs)

    def _create_network(self, stft_cfg):
        """Build the sub-discriminator for a single STFT configuration."""
        return SpecNetwork(stft_cfg)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class Discriminator(nn.Module):
    """Combined multi-period (MPD) and multi-resolution (MRD) discriminator.

    ``forward`` returns a dict of scalar losses. With only ``fake`` given it
    uses the "wgan" loss and scores ``fake`` with target 1; with ``real``
    given it uses the "hinge" loss with fake -> 0 and real -> 1.
    """

    @property
    def wav_rate(self):
        return self.hp.wav_rate

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        self.stft_cfgs = get_stft_cfgs(hp)
        self.mpd = MPD()
        self.mrd = MRD(self.stft_cfgs)
        # Empty non-persistent buffer; inputs are cast with `.to(self.dummy_float)`
        # so they follow this module's dtype and device.
        self.dummy_float: Tensor
        self.register_buffer("dummy_float", torch.zeros(0), persistent=False)

    def loss_type_(self, loss_type):
        """Set the loss type on both sub-discriminator groups."""
        for group in (self.mpd, self.mrd):
            group.loss_type_(loss_type)

    def forward(self, fake, real=None):
        """
        Args:
            fake: [B T]
            real: [B T]
        Returns:
            Dict with keys "mpd"/"mrd" when ``real`` is None, otherwise
            "mpd_fake", "mpd_real", "mrd_fake" and "mrd_real".
        """
        fake = fake.to(self.dummy_float)
        has_real = real is not None

        if has_real:
            length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
            assert length_difference < 0.05, f"length_difference should be smaller than 5%"

            self.loss_type_("hinge")
            real = real.to(self.dummy_float)

            # Crop both signals to the shorter of the two lengths.
            fake = fake[..., : real.shape[-1]]
            real = real[..., : fake.shape[-1]]
        else:
            self.loss_type_("wgan")

        assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
        assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."

        # The sub-discriminators expect a channel axis: [B, 1, T].
        fake = fake.unsqueeze(1)

        if not has_real:
            return {
                "mpd": self.mpd(fake, 1),
                "mrd": self.mrd(fake, 1),
            }

        real = real.unsqueeze(1)
        return {
            "mpd_fake": self.mpd(fake, 0),
            "mpd_real": self.mpd(real, 1),
            "mrd_fake": self.mrd(fake, 0),
            "mrd_real": self.mrd(real, 1),
        }
|
modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" refer from https://github.com/zceng/LVCNet """
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 8 |
+
|
| 9 |
+
from .amp import AMPBlock
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class KernelPredictor(torch.nn.Module):
|
| 13 |
+
"""Kernel predictor for the location-variable convolutions"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
cond_channels,
|
| 18 |
+
conv_in_channels,
|
| 19 |
+
conv_out_channels,
|
| 20 |
+
conv_layers,
|
| 21 |
+
conv_kernel_size=3,
|
| 22 |
+
kpnet_hidden_channels=64,
|
| 23 |
+
kpnet_conv_size=3,
|
| 24 |
+
kpnet_dropout=0.0,
|
| 25 |
+
kpnet_nonlinear_activation="LeakyReLU",
|
| 26 |
+
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
cond_channels (int): number of channel for the conditioning sequence,
|
| 31 |
+
conv_in_channels (int): number of channel for the input sequence,
|
| 32 |
+
conv_out_channels (int): number of channel for the output sequence,
|
| 33 |
+
conv_layers (int): number of layers
|
| 34 |
+
"""
|
| 35 |
+
super().__init__()
|
| 36 |
+
|
| 37 |
+
self.conv_in_channels = conv_in_channels
|
| 38 |
+
self.conv_out_channels = conv_out_channels
|
| 39 |
+
self.conv_kernel_size = conv_kernel_size
|
| 40 |
+
self.conv_layers = conv_layers
|
| 41 |
+
|
| 42 |
+
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
| 43 |
+
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
| 44 |
+
|
| 45 |
+
self.input_conv = nn.Sequential(
|
| 46 |
+
weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
| 47 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.residual_convs = nn.ModuleList()
|
| 51 |
+
padding = (kpnet_conv_size - 1) // 2
|
| 52 |
+
for _ in range(3):
|
| 53 |
+
self.residual_convs.append(
|
| 54 |
+
nn.Sequential(
|
| 55 |
+
nn.Dropout(kpnet_dropout),
|
| 56 |
+
weight_norm(
|
| 57 |
+
nn.Conv1d(
|
| 58 |
+
kpnet_hidden_channels,
|
| 59 |
+
kpnet_hidden_channels,
|
| 60 |
+
kpnet_conv_size,
|
| 61 |
+
padding=padding,
|
| 62 |
+
bias=True,
|
| 63 |
+
)
|
| 64 |
+
),
|
| 65 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
| 66 |
+
weight_norm(
|
| 67 |
+
nn.Conv1d(
|
| 68 |
+
kpnet_hidden_channels,
|
| 69 |
+
kpnet_hidden_channels,
|
| 70 |
+
kpnet_conv_size,
|
| 71 |
+
padding=padding,
|
| 72 |
+
bias=True,
|
| 73 |
+
)
|
| 74 |
+
),
|
| 75 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
| 76 |
+
)
|
| 77 |
+
)
|
| 78 |
+
self.kernel_conv = weight_norm(
|
| 79 |
+
nn.Conv1d(
|
| 80 |
+
kpnet_hidden_channels,
|
| 81 |
+
kpnet_kernel_channels,
|
| 82 |
+
kpnet_conv_size,
|
| 83 |
+
padding=padding,
|
| 84 |
+
bias=True,
|
| 85 |
+
)
|
| 86 |
+
)
|
| 87 |
+
self.bias_conv = weight_norm(
|
| 88 |
+
nn.Conv1d(
|
| 89 |
+
kpnet_hidden_channels,
|
| 90 |
+
kpnet_bias_channels,
|
| 91 |
+
kpnet_conv_size,
|
| 92 |
+
padding=padding,
|
| 93 |
+
bias=True,
|
| 94 |
+
)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def forward(self, c):
|
| 98 |
+
"""
|
| 99 |
+
Args:
|
| 100 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
| 101 |
+
"""
|
| 102 |
+
batch, _, cond_length = c.shape
|
| 103 |
+
c = self.input_conv(c)
|
| 104 |
+
for residual_conv in self.residual_convs:
|
| 105 |
+
residual_conv.to(c.device)
|
| 106 |
+
c = c + residual_conv(c)
|
| 107 |
+
k = self.kernel_conv(c)
|
| 108 |
+
b = self.bias_conv(c)
|
| 109 |
+
kernels = k.contiguous().view(
|
| 110 |
+
batch,
|
| 111 |
+
self.conv_layers,
|
| 112 |
+
self.conv_in_channels,
|
| 113 |
+
self.conv_out_channels,
|
| 114 |
+
self.conv_kernel_size,
|
| 115 |
+
cond_length,
|
| 116 |
+
)
|
| 117 |
+
bias = b.contiguous().view(
|
| 118 |
+
batch,
|
| 119 |
+
self.conv_layers,
|
| 120 |
+
self.conv_out_channels,
|
| 121 |
+
cond_length,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
return kernels, bias
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class LVCBlock(torch.nn.Module):
|
| 128 |
+
"""the location-variable convolutions"""
|
| 129 |
+
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
in_channels,
|
| 133 |
+
cond_channels,
|
| 134 |
+
stride,
|
| 135 |
+
dilations=[1, 3, 9, 27],
|
| 136 |
+
lReLU_slope=0.2,
|
| 137 |
+
conv_kernel_size=3,
|
| 138 |
+
cond_hop_length=256,
|
| 139 |
+
kpnet_hidden_channels=64,
|
| 140 |
+
kpnet_conv_size=3,
|
| 141 |
+
kpnet_dropout=0.0,
|
| 142 |
+
add_extra_noise=False,
|
| 143 |
+
downsampling=False,
|
| 144 |
+
):
|
| 145 |
+
super().__init__()
|
| 146 |
+
|
| 147 |
+
self.add_extra_noise = add_extra_noise
|
| 148 |
+
|
| 149 |
+
self.cond_hop_length = cond_hop_length
|
| 150 |
+
self.conv_layers = len(dilations)
|
| 151 |
+
self.conv_kernel_size = conv_kernel_size
|
| 152 |
+
|
| 153 |
+
self.kernel_predictor = KernelPredictor(
|
| 154 |
+
cond_channels=cond_channels,
|
| 155 |
+
conv_in_channels=in_channels,
|
| 156 |
+
conv_out_channels=2 * in_channels,
|
| 157 |
+
conv_layers=len(dilations),
|
| 158 |
+
conv_kernel_size=conv_kernel_size,
|
| 159 |
+
kpnet_hidden_channels=kpnet_hidden_channels,
|
| 160 |
+
kpnet_conv_size=kpnet_conv_size,
|
| 161 |
+
kpnet_dropout=kpnet_dropout,
|
| 162 |
+
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
if downsampling:
|
| 166 |
+
self.convt_pre = nn.Sequential(
|
| 167 |
+
nn.LeakyReLU(lReLU_slope),
|
| 168 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
|
| 169 |
+
nn.AvgPool1d(stride, stride),
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
if stride == 1:
|
| 173 |
+
self.convt_pre = nn.Sequential(
|
| 174 |
+
nn.LeakyReLU(lReLU_slope),
|
| 175 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
self.convt_pre = nn.Sequential(
|
| 179 |
+
nn.LeakyReLU(lReLU_slope),
|
| 180 |
+
weight_norm(
|
| 181 |
+
nn.ConvTranspose1d(
|
| 182 |
+
in_channels,
|
| 183 |
+
in_channels,
|
| 184 |
+
2 * stride,
|
| 185 |
+
stride=stride,
|
| 186 |
+
padding=stride // 2 + stride % 2,
|
| 187 |
+
output_padding=stride % 2,
|
| 188 |
+
)
|
| 189 |
+
),
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
self.amp_block = AMPBlock(in_channels)
|
| 193 |
+
|
| 194 |
+
self.conv_blocks = nn.ModuleList()
|
| 195 |
+
for d in dilations:
|
| 196 |
+
self.conv_blocks.append(
|
| 197 |
+
nn.Sequential(
|
| 198 |
+
nn.LeakyReLU(lReLU_slope),
|
| 199 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
|
| 200 |
+
nn.LeakyReLU(lReLU_slope),
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
def forward(self, x, c):
|
| 205 |
+
"""forward propagation of the location-variable convolutions.
|
| 206 |
+
Args:
|
| 207 |
+
x (Tensor): the input sequence (batch, in_channels, in_length)
|
| 208 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
Tensor: the output sequence (batch, in_channels, in_length)
|
| 212 |
+
"""
|
| 213 |
+
_, in_channels, _ = x.shape # (B, c_g, L')
|
| 214 |
+
|
| 215 |
+
x = self.convt_pre(x) # (B, c_g, stride * L')
|
| 216 |
+
|
| 217 |
+
# Add one amp block just after the upsampling
|
| 218 |
+
x = self.amp_block(x) # (B, c_g, stride * L')
|
| 219 |
+
|
| 220 |
+
kernels, bias = self.kernel_predictor(c)
|
| 221 |
+
|
| 222 |
+
if self.add_extra_noise:
|
| 223 |
+
# Add extra noise to part of the feature
|
| 224 |
+
a, b = x.chunk(2, dim=1)
|
| 225 |
+
b = b + torch.randn_like(b) * 0.1
|
| 226 |
+
x = torch.cat([a, b], dim=1)
|
| 227 |
+
|
| 228 |
+
for i, conv in enumerate(self.conv_blocks):
|
| 229 |
+
output = conv(x) # (B, c_g, stride * L')
|
| 230 |
+
|
| 231 |
+
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
| 232 |
+
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
| 233 |
+
|
| 234 |
+
output = self.location_variable_convolution(
|
| 235 |
+
output, k, b, hop_size=self.cond_hop_length
|
| 236 |
+
) # (B, 2 * c_g, stride * L'): LVC
|
| 237 |
+
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
| 238 |
+
output[:, in_channels:, :]
|
| 239 |
+
) # (B, c_g, stride * L'): GAU
|
| 240 |
+
|
| 241 |
+
return x
|
| 242 |
+
|
| 243 |
+
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
| 244 |
+
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
| 245 |
+
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
| 246 |
+
Args:
|
| 247 |
+
x (Tensor): the input sequence (batch, in_channels, in_length).
|
| 248 |
+
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
| 249 |
+
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
| 250 |
+
dilation (int): the dilation of convolution.
|
| 251 |
+
hop_size (int): the hop_size of the conditioning sequence.
|
| 252 |
+
Returns:
|
| 253 |
+
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
| 254 |
+
"""
|
| 255 |
+
batch, _, in_length = x.shape
|
| 256 |
+
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
| 257 |
+
|
| 258 |
+
assert in_length == (
|
| 259 |
+
kernel_length * hop_size
|
| 260 |
+
), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
|
| 261 |
+
|
| 262 |
+
padding = dilation * int((kernel_size - 1) / 2)
|
| 263 |
+
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
| 264 |
+
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
| 265 |
+
|
| 266 |
+
if hop_size < dilation:
|
| 267 |
+
x = F.pad(x, (0, dilation), "constant", 0)
|
| 268 |
+
x = x.unfold(
|
| 269 |
+
3, dilation, dilation
|
| 270 |
+
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
| 271 |
+
x = x[:, :, :, :, :hop_size]
|
| 272 |
+
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
| 273 |
+
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
| 274 |
+
|
| 275 |
+
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
| 276 |
+
o = o.to(memory_format=torch.channels_last_3d)
|
| 277 |
+
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
| 278 |
+
o = o + bias
|
| 279 |
+
o = o.contiguous().view(batch, out_channels, -1)
|
| 280 |
+
|
| 281 |
+
return o
|
modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# Copyright 2019 Tomoki Hayashi
|
| 4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from ..hparams import HParams
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _make_stft_cfg(hop_length, win_length=None):
|
| 15 |
+
if win_length is None:
|
| 16 |
+
win_length = 4 * hop_length
|
| 17 |
+
n_fft = 2 ** (win_length - 1).bit_length()
|
| 18 |
+
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_stft_cfgs(hp: HParams):
    """Return the STFT configs (hop lengths 100, 256 and 512) for the multi-resolution loss.

    Raises:
        AssertionError: if ``hp.wav_rate`` is not 44100.
    """
    assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
    hop_lengths = (100, 256, 512)
    return list(map(_make_stft_cfg, hop_lengths))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def stft(x, n_fft, hop_length, win_length, window):
    """Return the magnitude STFT of ``x`` laid out as (batch, frames, freq_bins).

    The transform runs in float32 and the magnitude is cast back to the
    dtype of ``x``.
    """
    in_dtype = x.dtype
    spectrum = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
    magnitude = spectrum.abs().to(in_dtype)
    return magnitude.transpose(2, 1)  # (b f t) -> (b t f)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SpectralConvergengeLoss(nn.Module):
    """Spectral convergence: Frobenius norm of the error relative to the target."""

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value.
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        return error_norm / target_norm
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class LogSTFTMagnitudeLoss(nn.Module):
    """Mean absolute error between ``log1p`` magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        log_pred = torch.log1p(x_mag)
        log_target = torch.log1p(y_mag)
        return F.l1_loss(log_pred, log_target)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class STFTLoss(nn.Module):
    """Spectral-convergence and log-magnitude losses at a single STFT resolution."""

    def __init__(self, hp, stft_cfg: dict, window="hann_window"):
        """Store the STFT config and register the analysis window as a buffer.

        Args:
            hp: hyper-parameters object, kept on the module as ``self.hp``.
            stft_cfg (dict): keyword arguments forwarded to ``stft``; must
                contain ``"win_length"``.
            window (str): name of a ``torch`` window factory.
        """
        super().__init__()
        self.hp = hp
        self.stft_cfg = stft_cfg
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        window_factory = getattr(torch, window)
        # Non-persistent: the window follows the module's device but is not saved in the state dict.
        self.register_buffer("window", window_factory(stft_cfg["win_length"]), persistent=False)

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            dict: ``{"sc": spectral convergence loss, "mag": log STFT magnitude loss}``.
        """
        cfg = dict(self.stft_cfg)
        pred_mag = stft(x, **cfg, window=self.window)  # (b t) -> (b t f)
        true_mag = stft(y, **cfg, window=self.window)
        return {
            "sc": self.spectral_convergenge_loss(pred_mag, true_mag),
            "mag": self.log_stft_magnitude_loss(pred_mag, true_mag),
        }
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss: averages ``STFTLoss`` terms over several resolutions."""

    def __init__(self, hp: HParams, window="hann_window"):
        """Initialize Multi resolution STFT loss module.

        Args:
            hp (HParams): hyper-parameters; the resolutions come from ``get_stft_cfgs(hp)``.
            window (str): Window function type.
        """
        super().__init__()
        self.stft_losses = nn.ModuleList([STFTLoss(hp, cfg, window=window) for cfg in get_stft_cfgs(hp)])
        self.hp = hp

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (b t).
            y (Tensor): Groundtruth signal (b t).

        Returns:
            dict: one entry per loss name produced by the per-resolution losses,
            each the mean over resolutions, cast to the dtype of ``x``.
        """
        assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."

        out_dtype = x.dtype
        x, y = x.float(), y.float()

        # Align length: truncate both signals to the shorter one.
        common = min(x.shape[-1], y.shape[-1])
        x = x[..., :common]
        y = y[..., :common]

        collected = {}
        for loss_fn in self.stft_losses:
            for name, value in loss_fn(x, y).items():
                collected.setdefault(name, []).append(value)

        return {name: torch.stack(values, dim=0).mean().to(out_dtype) for name, values in collected.items()}
|
modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 6 |
+
|
| 7 |
+
from ..hparams import HParams
|
| 8 |
+
from .lvcnet import LVCBlock
|
| 9 |
+
from .mrstft import MRSTFTLoss
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class UnivNet(nn.Module):
|
| 13 |
+
@property
|
| 14 |
+
def d_noise(self):
|
| 15 |
+
return 128
|
| 16 |
+
|
| 17 |
+
@property
|
| 18 |
+
def strides(self):
|
| 19 |
+
return [7, 5, 4, 3]
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def dilations(self):
|
| 23 |
+
return [1, 3, 9, 27]
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def nc(self):
|
| 27 |
+
return self.hp.univnet_nc
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def scale_factor(self) -> int:
|
| 31 |
+
return self.hp.hop_size
|
| 32 |
+
|
| 33 |
+
def __init__(self, hp: HParams, d_input):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.d_input = d_input
|
| 36 |
+
|
| 37 |
+
self.hp = hp
|
| 38 |
+
|
| 39 |
+
self.blocks = nn.ModuleList(
|
| 40 |
+
[
|
| 41 |
+
LVCBlock(
|
| 42 |
+
self.nc,
|
| 43 |
+
d_input,
|
| 44 |
+
stride=stride,
|
| 45 |
+
dilations=self.dilations,
|
| 46 |
+
cond_hop_length=hop_length,
|
| 47 |
+
kpnet_conv_size=3,
|
| 48 |
+
)
|
| 49 |
+
for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
|
| 50 |
+
]
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
|
| 54 |
+
|
| 55 |
+
self.conv_post = nn.Sequential(
|
| 56 |
+
nn.LeakyReLU(0.2),
|
| 57 |
+
weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
|
| 58 |
+
nn.Tanh(),
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
self.mrstft = MRSTFTLoss(hp)
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def eps(self):
|
| 65 |
+
return 1e-5
|
| 66 |
+
|
| 67 |
+
def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
|
| 68 |
+
"""
|
| 69 |
+
Args:
|
| 70 |
+
x: (b c t), acoustic features
|
| 71 |
+
y: (b t), waveform
|
| 72 |
+
Returns:
|
| 73 |
+
z: (b t), waveform
|
| 74 |
+
"""
|
| 75 |
+
assert x.ndim == 3, "x must be 3D tensor"
|
| 76 |
+
assert y is None or y.ndim == 2, "y must be 2D tensor"
|
| 77 |
+
assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
|
| 78 |
+
assert npad >= 0, "npad must be positive or zero"
|
| 79 |
+
|
| 80 |
+
x = F.pad(x, (0, npad), "constant", 0)
|
| 81 |
+
z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
|
| 82 |
+
z = self.conv_pre(z) # (b c t)
|
| 83 |
+
|
| 84 |
+
for block in self.blocks:
|
| 85 |
+
z = block(z, x) # (b c t)
|
| 86 |
+
|
| 87 |
+
z = self.conv_post(z) # (b 1 t)
|
| 88 |
+
z = z[..., : -self.scale_factor * npad]
|
| 89 |
+
z = z.squeeze(1) # (b t)
|
| 90 |
+
|
| 91 |
+
if y is not None:
|
| 92 |
+
self.losses = self.mrstft(z, y)
|
| 93 |
+
|
| 94 |
+
return z
|
modules/repos_static/resemble_enhance/hparams.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import asdict, dataclass
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from omegaconf import OmegaConf
|
| 6 |
+
from rich.console import Console
|
| 7 |
+
from rich.panel import Panel
|
| 8 |
+
from rich.table import Table
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
console = Console()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _make_stft_cfg(hop_length, win_length=None):
|
| 16 |
+
if win_length is None:
|
| 17 |
+
win_length = 4 * hop_length
|
| 18 |
+
n_fft = 2 ** (win_length - 1).bit_length()
|
| 19 |
+
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _build_rich_table(rows, columns, title=None):
    """Build a rich ``Table`` from ``rows`` and wrap it in a non-expanding ``Panel``.

    Column headers are capitalized and left-justified; every cell is
    converted with ``str``.
    """
    table = Table(title=title, header_style=None)
    for header in columns:
        table.add_column(header.capitalize(), justify="left")
    for cells in rows:
        table.add_row(*(str(cell) for cell in cells))
    return Panel(table, expand=False)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    """Print the items of ``d`` as a two-column table on the module console."""
    headers = [key, value]
    panel = _build_rich_table(d.items(), headers, title)
    console.print(panel)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass(frozen=True)
|
| 36 |
+
class HParams:
|
| 37 |
+
# Dataset
|
| 38 |
+
fg_dir: Path = Path("data/fg")
|
| 39 |
+
bg_dir: Path = Path("data/bg")
|
| 40 |
+
rir_dir: Path = Path("data/rir")
|
| 41 |
+
load_fg_only: bool = False
|
| 42 |
+
praat_augment_prob: float = 0
|
| 43 |
+
|
| 44 |
+
# Audio settings
|
| 45 |
+
wav_rate: int = 44_100
|
| 46 |
+
n_fft: int = 2048
|
| 47 |
+
win_size: int = 2048
|
| 48 |
+
hop_size: int = 420 # 9.5ms
|
| 49 |
+
num_mels: int = 128
|
| 50 |
+
stft_magnitude_min: float = 1e-4
|
| 51 |
+
preemphasis: float = 0.97
|
| 52 |
+
mix_alpha_range: tuple[float, float] = (0.2, 0.8)
|
| 53 |
+
|
| 54 |
+
# Training
|
| 55 |
+
nj: int = 64
|
| 56 |
+
training_seconds: float = 1.0
|
| 57 |
+
batch_size_per_gpu: int = 16
|
| 58 |
+
min_lr: float = 1e-5
|
| 59 |
+
max_lr: float = 1e-4
|
| 60 |
+
warmup_steps: int = 1000
|
| 61 |
+
max_steps: int = 1_000_000
|
| 62 |
+
gradient_clipping: float = 1.0
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def deepspeed_config(self):
|
| 66 |
+
return {
|
| 67 |
+
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
|
| 68 |
+
"optimizer": {
|
| 69 |
+
"type": "Adam",
|
| 70 |
+
"params": {"lr": float(self.min_lr)},
|
| 71 |
+
},
|
| 72 |
+
"scheduler": {
|
| 73 |
+
"type": "WarmupDecayLR",
|
| 74 |
+
"params": {
|
| 75 |
+
"warmup_min_lr": float(self.min_lr),
|
| 76 |
+
"warmup_max_lr": float(self.max_lr),
|
| 77 |
+
"warmup_num_steps": self.warmup_steps,
|
| 78 |
+
"total_num_steps": self.max_steps,
|
| 79 |
+
"warmup_type": "linear",
|
| 80 |
+
},
|
| 81 |
+
},
|
| 82 |
+
"gradient_clipping": self.gradient_clipping,
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def stft_cfgs(self):
|
| 87 |
+
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
|
| 88 |
+
return [_make_stft_cfg(h) for h in (100, 256, 512)]
|
| 89 |
+
|
| 90 |
+
@classmethod
|
| 91 |
+
def from_yaml(cls, path: Path) -> "HParams":
|
| 92 |
+
logger.info(f"Reading hparams from {path}")
|
| 93 |
+
# First merge to fix types (e.g., str -> Path)
|
| 94 |
+
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
|
| 95 |
+
|
| 96 |
+
def save_if_not_exists(self, run_dir: Path):
|
| 97 |
+
path = run_dir / "hparams.yaml"
|
| 98 |
+
if path.exists():
|
| 99 |
+
logger.info(f"{path} already exists, not saving")
|
| 100 |
+
return
|
| 101 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 102 |
+
OmegaConf.save(asdict(self), str(path))
|
| 103 |
+
|
| 104 |
+
@classmethod
|
| 105 |
+
def load(cls, run_dir, yaml: Path | None = None):
|
| 106 |
+
hps = []
|
| 107 |
+
|
| 108 |
+
if (run_dir / "hparams.yaml").exists():
|
| 109 |
+
hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
|
| 110 |
+
|
| 111 |
+
if yaml is not None:
|
| 112 |
+
hps.append(cls.from_yaml(yaml))
|
| 113 |
+
|
| 114 |
+
if len(hps) == 0:
|
| 115 |
+
hps.append(cls())
|
| 116 |
+
|
| 117 |
+
for hp in hps[1:]:
|
| 118 |
+
if hp != hps[0]:
|
| 119 |
+
errors = {}
|
| 120 |
+
for k, v in asdict(hp).items():
|
| 121 |
+
if getattr(hps[0], k) != v:
|
| 122 |
+
errors[k] = f"{getattr(hps[0], k)} != {v}"
|
| 123 |
+
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
|
| 124 |
+
|
| 125 |
+
return hps[0]
|
| 126 |
+
|
| 127 |
+
def print(self):
|
| 128 |
+
_rich_print_dict(asdict(self), title="HParams")
|
modules/repos_static/resemble_enhance/inference.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
| 7 |
+
from torchaudio.functional import resample
|
| 8 |
+
from torchaudio.transforms import MelSpectrogram
|
| 9 |
+
from tqdm import trange
|
| 10 |
+
|
| 11 |
+
from .hparams import HParams
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@torch.inference_mode()
def inference_chunk(model, dwav, sr, device, npad=441):
    """Run ``model`` on one peak-normalized 1D chunk and restore its scale.

    The chunk is divided by its absolute peak, right-padded with ``npad``
    zeros, passed through the model as a batch of one, then trimmed back to
    the input length and multiplied by the peak again. The result is on CPU.
    """
    assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"

    n_samples = dwav.shape[-1]
    # Clamp so a silent chunk does not cause a division by zero.
    peak = dwav.abs().max().clamp(min=1e-7)

    assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
    normalized = dwav.to(device) / peak
    padded = F.pad(normalized, (0, npad))
    enhanced = model(padded[None])[0].cpu()  # (T,)
    return enhanced[:n_samples] * peak
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def compute_corr(x, y):
    """Magnitude of the circular cross-correlation of ``x`` and ``y`` along the last dim (via FFT)."""
    spectrum_x = torch.fft.fft(x)
    spectrum_y_conj = torch.fft.fft(y).conj()
    return torch.fft.ifft(spectrum_x * spectrum_y_conj).abs()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def compute_offset(chunk1, chunk2, sr=44100):
    """
    Estimate the time shift between two chunks from their log-mel spectrograms.

    Args:
        chunk1: (T,)
        chunk2: (T,)
    Returns:
        offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
    """
    hop = sr // 200  # 5 ms resolution
    win = hop * 4
    fft_size = 2 ** (win - 1).bit_length()

    to_mel = MelSpectrogram(
        sample_rate=sr,
        n_fft=fft_size,
        win_length=win,
        hop_length=hop,
        n_mels=80,
        f_min=0.0,
        f_max=sr // 2,
    )

    log_mel_1 = to_mel(chunk1).log1p()
    log_mel_2 = to_mel(chunk2).log1p()

    # Correlate every mel band over time, then average the bands.
    corr = compute_corr(log_mel_1, log_mel_2).mean(dim=0)  # (F, T) -> (T,)

    lag = corr.argmax().item()
    n_frames = len(corr)
    # Lags past the midpoint of the circular correlation are negative shifts.
    if lag > n_frames // 2:
        lag -= n_frames

    return -lag * hop
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
    """Overlap-add ``chunks`` into one signal with linear cross-fades.

    Consecutive chunks overlap by ``chunk_length - hop_length`` samples. From
    the second chunk on, the write position is shifted by the offset that
    ``compute_offset`` estimates between the previous chunk's tail and the
    current chunk's head. The result is truncated to ``length`` when given.
    """
    device = chunks[0].device
    last_index = len(chunks) - 1
    overlap_length = chunk_length - hop_length
    signal = torch.zeros(last_index * hop_length + chunk_length, device=device)

    flat = torch.ones(hop_length, device=device)
    fadein = torch.cat([torch.linspace(0, 1, overlap_length, device=device), flat])
    fadeout = torch.cat([flat, torch.linspace(1, 0, overlap_length, device=device)])

    for i, chunk in enumerate(chunks):
        start = i * hop_length

        # Short chunks are zero-padded on the right to the full chunk length.
        if len(chunk) < chunk_length:
            chunk = F.pad(chunk, (0, chunk_length - len(chunk)))

        if i > 0:
            tail = chunks[i - 1][-overlap_length:]
            head = chunk[:overlap_length]
            start -= compute_offset(tail, head, sr=sr)
        end = start + chunk_length

        # First chunk only fades out, last chunk only fades in, the rest do both.
        if i == 0:
            faded = chunk * fadeout
        elif i == last_index:
            faded = chunk * fadein
        else:
            faded = chunk * fadein * fadeout

        target = signal[start:end]  # view into ``signal``; the in-place add updates it
        target += faded[: len(target)]

    return signal[:length]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def remove_weight_norm_recursively(module):
|
| 116 |
+
for _, module in module.named_modules():
|
| 117 |
+
try:
|
| 118 |
+
remove_parametrizations(module, "weight")
|
| 119 |
+
except Exception:
|
| 120 |
+
pass
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
    """Enhance a waveform chunk by chunk and stitch the results together.

    Weight-norm parametrizations are removed from ``model`` (in place), the
    input is resampled to ``model.hp.wav_rate``, processed in overlapping
    chunks with ``inference_chunk`` and merged with ``merge_chunks``.

    Returns:
        tuple: ``(hwav, sr)`` - the merged waveform and its sample rate
        (``model.hp.wav_rate``).
    """

    def _sync():
        # Wait for pending CUDA work so the timing below is meaningful.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    remove_weight_norm_recursively(model)

    hp: HParams = model.hp

    dwav = resample(
        dwav,
        orig_freq=sr,
        new_freq=hp.wav_rate,
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="sinc_interp_kaiser",
        beta=14.769656459379492,
    )

    # Everything is in hp.wav_rate now
    sr = hp.wav_rate

    _sync()
    start_time = time.perf_counter()

    chunk_length = int(sr * chunk_seconds)
    hop_length = chunk_length - int(sr * overlap_seconds)
    total_length = dwav.shape[-1]

    chunks = [
        inference_chunk(model, dwav[begin : begin + chunk_length], sr, device)
        for begin in trange(0, total_length, hop_length)
    ]

    hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=total_length)

    _sync()
    elapsed_time = time.perf_counter() - start_time
    logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")

    return hwav, sr
|
modules/repos_static/resemble_enhance/melspec.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
|
| 5 |
+
|
| 6 |
+
from .hparams import HParams
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MelSpectrogram(nn.Module):
    """Normalized log-mel spectrogram extractor configured from ``HParams``."""

    def __init__(self, hp: HParams):
        """
        Torch implementation of Resemble's mel extraction.
        Note that the values are NOT identical to librosa's implementation
        due to floating point precisions.
        """
        super().__init__()
        self.hp = hp
        self.melspec = TorchMelSpectrogram(
            hp.wav_rate,
            n_fft=hp.n_fft,
            win_length=hp.win_size,
            hop_length=hp.hop_size,
            f_min=0,
            f_max=hp.wav_rate // 2,
            n_mels=hp.num_mels,
            power=1,
            normalized=False,
            # NOTE: Following librosa's default.
            pad_mode="constant",
            norm="slaney",
            mel_scale="slaney",
        )
        floor = torch.FloatTensor([hp.stft_magnitude_min])
        self.register_buffer("stft_magnitude_min", floor)
        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
        self.preemphasis = hp.preemphasis
        self.hop_size = hp.hop_size

    def forward(self, wav, pad=True):
        """
        Args:
            wav: [B, T]
            pad: when true, sanity-check that the frame count is
                ``1 + T // hop_size``.

        Returns:
            The normalized dB mel spectrogram, moved back to the device ``wav``
            arrived on.
        """
        original_device = wav.device
        if wav.is_mps:
            # MPS input is processed on CPU; the module itself is moved along.
            wav = wav.cpu()
            self.to(wav.device)

        if self.preemphasis > 0:
            # First-order pre-emphasis: y[t] = x[t] - a * x[t - 1], with x[-1] = 0.
            shifted = torch.nn.functional.pad(wav, [1, 0], value=0)
            wav = shifted[..., 1:] - self.preemphasis * shifted[..., :-1]

        mel_db = self._amp_to_db(self.melspec(wav))
        mel_normed = self._normalize(mel_db)
        assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size  # Sanity check
        return mel_normed.to(original_device)

    def _normalize(self, s, headroom_db=15):
        """Map dB values so that ``min_level_db`` becomes 0 and ``headroom_db`` becomes 1."""
        span = -self.min_level_db + headroom_db
        return (s - self.min_level_db) / span

    def _amp_to_db(self, x):
        """Convert amplitudes to dB after flooring them at ``hp.stft_magnitude_min``."""
        floored = x.clamp_min(self.hp.stft_magnitude_min)
        return floored.log10() * 20
|
modules/repos_static/resemble_enhance/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .logging import setup_logging
|
| 2 |
+
from .utils import save_mels, tree_map
|
modules/repos_static/resemble_enhance/utils/control.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import selectors
|
| 3 |
+
import sys
|
| 4 |
+
from functools import cache
|
| 5 |
+
|
| 6 |
+
from .distributed import global_leader_only
|
| 7 |
+
|
| 8 |
+
_logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@cache
def _get_stdin_selector():
    """Return a selector watching ``sys.stdin`` for readability (created once, then cached)."""
    stdin_selector = selectors.DefaultSelector()
    stdin_selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
    return stdin_selector
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@global_leader_only(boardcast_return=True)
def non_blocking_input():
    """Return a stripped line from stdin if one is ready, else ``""``, without blocking."""
    s = ""
    # timeout=0 polls: the loop body only runs when stdin already has data.
    ready = _get_stdin_selector().select(timeout=0)
    for key, _ in ready:
        s = key.fileobj.readline().strip()
        _logger.info(f'Get stdin "{s}".')
    return s
|
modules/repos_static/resemble_enhance/utils/logging.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from rich.logging import RichHandler
|
| 5 |
+
|
| 6 |
+
from .distributed import global_leader_only
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@global_leader_only
def setup_logging(run_dir):
    """Configure logging with a rich console handler and an optional log file.

    Args:
        run_dir: directory that receives ``log.txt`` (opened in append mode);
            pass ``None`` to log to the console only.
    """
    console = RichHandler()
    console.setLevel(logging.INFO)
    handlers = [console]

    if run_dir is not None:
        log_path = Path(run_dir) / "log.txt"
        log_path.parent.mkdir(parents=True, exist_ok=True)
        to_file = logging.FileHandler(log_path, mode="a")
        to_file.setLevel(logging.DEBUG)
        handlers.append(to_file)

    # Update all existing loggers
    for name in ["DeepSpeed"]:
        logger = logging.getLogger(name)
        if not isinstance(logger, logging.Logger):
            continue
        # Swap whatever handlers were attached for the ones built above.
        for stale in list(logger.handlers):
            logger.removeHandler(stale)
        for handler in handlers:
            logger.addHandler(handler)

    # Set the default logger
    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        datefmt="[%X]",
        handlers=handlers,
    )
|
modules/repos_static/resemble_enhance/utils/utils.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, TypeVar, overload
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def save_mels(path, *, targ_mel, pred_mel, cond_mel):
    """Save a stacked figure of mel images: prediction, target, |difference|, and optional condition.

    Args:
        path: destination passed to ``plt.savefig``.
        targ_mel: target mel, indexed as a 2-D array.
        pred_mel: predicted mel, indexed as a 2-D array.
        cond_mel: conditioning mel, or ``None`` to omit the fourth panel.
    """
    rows = 4 if cond_mel is not None else 3
    plt.figure(figsize=(10, rows * 4))

    def _panel(index, image, title):
        plt.subplot(rows, 1, index)
        plt.imshow(image, origin="lower", interpolation="none")
        plt.title(title)

    _panel(1, pred_mel, f"Pred mel {pred_mel.shape}")
    _panel(2, targ_mel, f"GT mel {targ_mel.shape}")

    # Crop both to the shorter second axis so they can be subtracted.
    pred_mel = pred_mel[:, : targ_mel.shape[1]]
    targ_mel = targ_mel[:, : pred_mel.shape[1]]
    diff = pred_mel - targ_mel
    _panel(3, np.abs(diff), f"Diff mel {pred_mel.shape}, mse={np.mean(diff**2):.4f}")

    if cond_mel is not None:
        _panel(4, cond_mel, f"Cond mel {cond_mel.shape}")

    plt.savefig(path, dpi=480)
    plt.close()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
T = TypeVar("T")


@overload
def tree_map(fn: Callable, x: list[T]) -> list[T]: ...


@overload
def tree_map(fn: Callable, x: tuple[T, ...]) -> tuple[T, ...]: ...


@overload
def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]: ...


@overload
def tree_map(fn: Callable, x: T) -> T: ...


def tree_map(fn: Callable, x):
    """Apply ``fn`` to every leaf of a nested list/tuple/dict structure.

    Lists, tuples and dicts are rebuilt recursively (dict keys are kept as-is);
    any other value is treated as a leaf and replaced by ``fn(value)``.

    Args:
        fn: callable applied to each leaf.
        x: a leaf, or an arbitrarily nested list/tuple/dict of leaves.

    Returns:
        A new structure with the same nesting as ``x``. Tuple subclasses
        (e.g. namedtuples) come back as plain tuples.
    """
    if isinstance(x, list):
        return [tree_map(fn, xi) for xi in x]
    if isinstance(x, tuple):
        # A bare parenthesized comprehension is a generator expression, not a
        # tuple; materialize it so the container type is preserved.
        return tuple(tree_map(fn, xi) for xi in x)
    if isinstance(x, dict):
        return {k: tree_map(fn, v) for k, v in x.items()}
    return fn(x)
|
modules/speaker.py
CHANGED
|
@@ -99,6 +99,10 @@ class SpeakerManager:
|
|
| 99 |
self.speakers[speaker_file] = Speaker.from_file(
|
| 100 |
self.speaker_dir + speaker_file
|
| 101 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def list_speakers(self):
|
| 104 |
return list(self.speakers.values())
|
|
|
|
| 99 |
self.speakers[speaker_file] = Speaker.from_file(
|
| 100 |
self.speaker_dir + speaker_file
|
| 101 |
)
|
| 102 |
+
# 检查是否有被删除的,同步到 speakers
|
| 103 |
+
for fname, spk in self.speakers.items():
|
| 104 |
+
if not os.path.exists(self.speaker_dir + fname):
|
| 105 |
+
del self.speakers[fname]
|
| 106 |
|
| 107 |
def list_speakers(self):
|
| 108 |
return list(self.speakers.values())
|
modules/webui/speaker/__init__.py
ADDED
|
File without changes
|
modules/webui/speaker/speaker_creator.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from modules.speaker import Speaker
|
| 4 |
+
from modules.utils.SeedContext import SeedContext
|
| 5 |
+
from modules.hf import spaces
|
| 6 |
+
from modules.models import load_chat_tts
|
| 7 |
+
from modules.utils.rng import np_rng
|
| 8 |
+
from modules.webui.webui_utils import get_speakers, tts_generate
|
| 9 |
+
|
| 10 |
+
import tempfile
|
| 11 |
+
|
| 12 |
+
# Placeholder first names; random_speaker() picks one by seed modulo the list length.
names_list = [
    "Alice", "Bob", "Carol", "Carlos", "Charlie", "Chuck", "Chad", "Craig",
    "Dan", "Dave", "David", "Erin", "Eve", "Yves", "Faythe", "Frank",
    "Grace", "Heidi", "Ivan", "Judy", "Mallory", "Mallet", "Darth", "Michael",
    "Mike", "Niaj", "Olivia", "Oscar", "Peggy", "Pat", "Rupert", "Sybil",
    "Trent", "Ted", "Trudy", "Victor", "Vanna", "Walter", "Wendy",
]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@torch.inference_mode()
@spaces.GPU
def create_spk_from_seed(
    seed: int,
    name: str,
    gender: str,
    desc: str,
):
    """Sample a speaker embedding under ``seed`` and save the speaker to a temp ``.pt`` file.

    Args:
        seed: seed handed to ``SeedContext`` while the embedding is sampled.
        name: stored on the created ``Speaker``.
        gender: stored on the created ``Speaker``.
        desc: passed to ``Speaker`` as ``describe``.

    Returns:
        Path of the temporary file containing the ``torch.save``-d speaker.
    """
    model = load_chat_tts()
    with SeedContext(seed):
        embedding = model.sample_random_speaker()

    speaker = Speaker(seed=-2, name=name, gender=gender, describe=desc)
    speaker.emb = embedding

    # delete=False: the file has to outlive this call so the caller can use the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as handle:
        torch.save(speaker, handle)
        return handle.name
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@torch.inference_mode()
@spaces.GPU
def test_spk_voice(seed: int, text: str):
    """Run ``tts_generate`` on ``text`` with ``seed`` passed as the speaker and return its result."""
    return tts_generate(spk=seed, text=text)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def random_speaker():
    """Draw a seed from ``np_rng`` and pair it with a name chosen from ``names_list``.

    Returns:
        ``(seed, name)`` where ``name`` is ``names_list[seed % len(names_list)]``.
    """
    seed = np_rng()
    return seed, names_list[seed % len(names_list)]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Markdown help text rendered at the top of the Speaker Creator panel.
creator_ui_desc = """
## Speaker Creator
使用本面板快捷抽卡生成 speaker.pt 文件。

1. **生成说话人**:输入种子、名字、性别和描述。点击 "Generate speaker.pt" 按钮,生成的说话人配置会保存为.pt文件。
2. **测试说话人声音**:输入测试文本。点击 "Test Voice" 按钮,生成的音频会在 "Output Audio" 中播放。
3. **随机生成说话人**:点击 "Random Speaker" 按钮,随机生成一个种子和名字,可以进一步编辑其他信息并测试。
"""
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def speaker_creator_ui():
    """Build the Speaker Creator panel: speaker info form, save button and four test-voice cards.

    NOTE(review): the container nesting below was reconstructed from a flattened
    diff view - confirm the layout against the repository.
    """

    def on_generate(seed, name, gender, desc):
        return create_spk_from_seed(seed, name, gender, desc)

    def create_test_voice_card(seed_input):
        with gr.Group():
            gr.Markdown("🎤Test voice")
            with gr.Row():
                test_voice_btn = gr.Button("Test Voice", variant="secondary")

                with gr.Column(scale=4):
                    test_text = gr.Textbox(
                        label="Test Text",
                        placeholder="Please input test text",
                        value="说话人测试 123456789 [uv_break] ok, test done [lbreak]",
                    )
                    with gr.Row():
                        current_seed = gr.Label(label="Current Seed", value=-1)
                        with gr.Column(scale=4):
                            output_audio = gr.Audio(label="Output Audio")

        # First handler synthesizes audio; the second echoes the seed into the label.
        test_voice_btn.click(
            fn=test_spk_voice,
            inputs=[seed_input, test_text],
            outputs=[output_audio],
        )
        test_voice_btn.click(
            fn=lambda x: x,
            inputs=[seed_input],
            outputs=[current_seed],
        )

    gr.Markdown(creator_ui_desc)

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("ℹ️Speaker info")
                seed_input = gr.Number(label="Seed", value=2)
                name_input = gr.Textbox(
                    label="Name", placeholder="Enter speaker name", value="Bob"
                )
                gender_input = gr.Textbox(
                    label="Gender", placeholder="Enter gender", value="*"
                )
                desc_input = gr.Textbox(
                    label="Description",
                    placeholder="Enter description",
                )
                random_button = gr.Button("Random Speaker")
            with gr.Group():
                gr.Markdown("🔊Generate speaker.pt")
                generate_button = gr.Button("Save .pt file")
                output_file = gr.File(label="Save to File")
        with gr.Column(scale=5):
            # Four identical cards, all bound to the same seed input.
            for _ in range(4):
                create_test_voice_card(seed_input=seed_input)

    random_button.click(
        random_speaker,
        outputs=[seed_input, name_input],
    )

    speaker_fields = [seed_input, name_input, gender_input, desc_input]
    generate_button.click(
        fn=on_generate,
        inputs=speaker_fields,
        outputs=[output_file],
    )
|
modules/webui/speaker/speaker_merger.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from modules.hf import spaces
|
| 6 |
+
from modules.webui.webui_utils import get_speakers, tts_generate
|
| 7 |
+
from modules.speaker import speaker_mgr, Speaker
|
| 8 |
+
|
| 9 |
+
import tempfile
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def spk_to_tensor(spk):
    """Resolve a speaker dropdown label to that speaker's embedding.

    Args:
        spk: a label that is either a bare speaker name or "<gender> : <name>";
            ``None``, "" and "None" all mean that no speaker is selected.

    Returns:
        The ``emb`` of the speaker looked up through ``speaker_mgr``, or ``None``
        when no speaker is selected.
    """
    # An empty selection may arrive as None rather than a string.
    if spk is None:
        return None

    # Split only on the first separator so a name that itself contains " : "
    # is not truncated.
    name = spk.split(" : ", 1)[1].strip() if " : " in spk else spk
    if name in ("None", ""):
        return None
    return speaker_mgr.get_speaker(name).emb
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_speaker_show_name(spk):
    """Return the label shown for ``spk``: its name, prefixed with "<gender> : " when a gender is set.

    A gender of "*" or "" counts as unset.
    """
    if spk.gender in ("*", ""):
        return spk.name
    return f"{spk.gender} : {spk.name}"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def merge_spk(
    spk_a,
    spk_a_w,
    spk_b,
    spk_b_w,
    spk_c,
    spk_c_w,
    spk_d,
    spk_d_w,
):
    """Blend up to four speakers into one by a weighted average of their embeddings.

    Args:
        spk_a, spk_b, spk_c, spk_d: speaker labels resolved with ``spk_to_tensor``;
            labels that resolve to ``None`` are ignored.
        spk_a_w, spk_b_w, spk_c_w, spk_d_w: weight applied to the matching speaker.

    Returns:
        A ``Speaker`` built from the merged tensor and named "<MIX>".

    Raises:
        AssertionError: if none of the four labels resolves to an embedding.
    """
    candidates = [
        (spk_to_tensor(spk_a), spk_a_w),
        (spk_to_tensor(spk_b), spk_b_w),
        (spk_to_tensor(spk_c), spk_c_w),
        (spk_to_tensor(spk_d), spk_d_w),
    ]
    selected = [(emb, weight) for emb, weight in candidates if emb is not None]

    assert selected, "At least one speaker should be selected"

    # Accumulate into zeros shaped like the first selected embedding.
    merge_tensor = torch.zeros_like(selected[0][0])
    total_weight = 0
    for emb, weight in selected:
        merge_tensor += weight * emb
        total_weight += weight

    # With a zero total weight the accumulated tensor is left unnormalized.
    if total_weight > 0:
        merge_tensor /= total_weight

    merged_spk = Speaker.from_tensor(merge_tensor)
    merged_spk.name = "<MIX>"

    return merged_spk
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@torch.inference_mode()
@spaces.GPU
def merge_and_test_spk_voice(
    spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
):
    """Merge the selected speakers with ``merge_spk`` and run ``tts_generate`` on ``test_text``."""
    blended = merge_spk(spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w)
    return tts_generate(spk=blended, text=test_text)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@torch.inference_mode()
@spaces.GPU
def merge_spk_to_file(
    spk_a,
    spk_a_w,
    spk_b,
    spk_b_w,
    spk_c,
    spk_c_w,
    spk_d,
    spk_d_w,
    speaker_name,
    speaker_gender,
    speaker_desc,
):
    """Merge the selected speakers, label the result and save it to a temp ``.pt`` file.

    Returns:
        Path of the temporary file containing the ``torch.save``-d merged speaker.
    """
    blended = merge_spk(spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w)
    blended.name = speaker_name
    blended.gender = speaker_gender
    # NOTE(review): the speaker creator passes the description to Speaker as
    # `describe=`; confirm that `desc` is the attribute Speaker actually reads.
    blended.desc = speaker_desc

    # delete=False: the file has to outlive this call so the caller can use the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as handle:
        torch.save(blended, handle)
        return handle.name
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Markdown help text rendered at the top of the Speaker Merger panel.
merge_desc = """
## Speaker Merger

在本面板中,您可以选择多个说话人并指定他们的权重,合成新的语音并进行测试。以下是各个功能的详细说明:

1. 选择说话人: 您可以从下拉菜单中选择最多四个说话人(A、B、C、D),每个说话人都有一个对应的权重滑块,范围从0到10。权重决定了每个说话人在合成语音中的影响程度。
2. 合成语音: 在选择好说话人和设置好权重后,您可以在“Test Text”框中输入要测试的文本,然后点击“测试语音”按钮来生成并播放合成的语音。
3. 保存说话人: 您还可以在右侧的“说话人信息”部分填写新的说话人的名称、性别和描述,并点击“Save Speaker”按钮来保存合成的说话人。保存后的说话人文件将显示在“Merged Speaker”栏中,供下载使用。
"""
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_spk_choices():
    """Return the dropdown choices: "None" followed by the display name of every speaker."""
    choices = ["None"]
    choices.extend(get_speaker_show_name(entry) for entry in get_speakers())
    return choices
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Show four pickers (A, B, C, D); select one or more, then audition the mix and export it.
def create_speaker_merger():
    """Build the Speaker Merger panel: four speaker/weight pickers, a test-voice card and a save form.

    NOTE(review): the container nesting below was reconstructed from a flattened
    diff view - confirm the layout against the repository.
    """
    speaker_names = get_spk_choices()

    gr.Markdown(merge_desc)

    def spk_picker(label_tail: str):
        with gr.Row():
            spk_dropdown = gr.Dropdown(
                choices=speaker_names, value="None", label=f"Speaker {label_tail}"
            )
            refresh_btn = gr.Button("🔄", variant="secondary")

        def refresh_a():
            # Refresh the speaker manager, then rebuild the dropdown choices.
            speaker_mgr.refresh_speakers()
            return gr.update(choices=get_spk_choices())

        refresh_btn.click(refresh_a, outputs=[spk_dropdown])
        weight_slider = gr.Slider(
            value=1,
            minimum=0,
            maximum=10,
            step=0.1,
            label=f"Weight {label_tail}",
        )
        return spk_dropdown, weight_slider

    # Flat [spk, weight, spk, weight, ...] list in A-D order, matching the
    # leading positional parameters of the merge callbacks.
    merge_inputs = []

    with gr.Row():
        with gr.Column(scale=5):
            with gr.Row():
                for tail in ("A", "B", "C", "D"):
                    with gr.Group():
                        merge_inputs.extend(spk_picker(tail))

            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Group():
                        gr.Markdown("🎤Test voice")
                        with gr.Row():
                            test_voice_btn = gr.Button(
                                "Test Voice", variant="secondary"
                            )

                            with gr.Column(scale=4):
                                test_text = gr.Textbox(
                                    label="Test Text",
                                    placeholder="Please input test text",
                                    value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
                                )

                                output_audio = gr.Audio(label="Output Audio")

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("🗃️Save to file")

                speaker_name = gr.Textbox(label="Name", value="forge_speaker_merged")
                speaker_gender = gr.Textbox(label="Gender", value="*")
                speaker_desc = gr.Textbox(label="Description", value="merged speaker")

                save_btn = gr.Button("Save Speaker", variant="primary")

                merged_spker = gr.File(
                    label="Merged Speaker", interactive=False, type="binary"
                )

    test_voice_btn.click(
        merge_and_test_spk_voice,
        inputs=[*merge_inputs, test_text],
        outputs=[output_audio],
    )

    save_btn.click(
        merge_spk_to_file,
        inputs=[*merge_inputs, speaker_name, speaker_gender, speaker_desc],
        outputs=[merged_spker],
    )
|