File size: 6,120 Bytes
627d3d7
d2b7e94
 
 
9d9fe0d
627d3d7
 
d2b7e94
 
f83b1b7
32b2aaa
 
 
da8d589
 
627d3d7
da8d589
627d3d7
 
da8d589
 
 
627d3d7
da8d589
627d3d7
da8d589
f83b1b7
 
da8d589
 
 
9d9fe0d
da8d589
 
9d9fe0d
da8d589
 
9d9fe0d
da8d589
 
 
 
 
627d3d7
da8d589
 
 
627d3d7
 
 
 
 
 
 
da8d589
 
 
 
 
 
 
f83b1b7
da8d589
 
 
 
 
 
 
 
 
 
 
 
 
 
627d3d7
 
 
 
 
 
 
 
 
 
 
bed01bd
627d3d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8d589
 
 
f83b1b7
d2b7e94
da8d589
 
 
f83b1b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8d589
f83b1b7
 
da8d589
f83b1b7
da8d589
f83b1b7
da8d589
f83b1b7
 
 
 
da8d589
f83b1b7
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import gc
import logging
from pathlib import Path
from threading import Lock
from typing import Literal

import numpy as np
import torch

from modules.devices import devices
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
from modules.repos_static.resemble_enhance.inference import inference
from modules.utils.constants import MODELS_DIR

logger = logging.getLogger(__name__)

resemble_enhance = None
lock = Lock()


class ResembleEnhance:
    """Thin wrapper around the resemble-enhance ``Enhancer`` network.

    The instance remembers the device and dtype to run on. The model itself
    is only created by :meth:`load_model`; until then ``hparams`` and
    ``enhancer`` are ``None`` and the inference methods fail their
    "Model not loaded" assertion.
    """

    def __init__(self, device: torch.device, dtype=torch.float32):
        """Store the target device/dtype without loading any weights.

        Args:
            device: Device the enhancer is moved to and inference runs on.
            dtype: Dtype the enhancer is cast to and passed to inference.
        """
        self.device = device
        self.dtype = dtype

        # Both stay None until load_model() has completed.
        self.hparams: HParams = None
        self.enhancer: Enhancer = None

    def load_model(self):
        """Build the enhancer from ``MODELS_DIR/resemble-enhance`` and load its weights.

        Reads the hyper-parameters from the model directory, restores the
        ``"module"`` entry of ``mp_rank_00_model_states.pt`` into a fresh
        ``Enhancer``, moves it to ``self.device``/``self.dtype`` and puts it
        in eval mode.
        """
        model_dir = Path(MODELS_DIR) / "resemble-enhance"
        hparams = HParams.load(model_dir)
        enhancer = Enhancer(hparams)
        # Load onto CPU first; the module is moved to the target device below.
        # NOTE(review): torch.load unpickles the checkpoint, so the file must
        # come from a trusted source.
        state_dict = torch.load(
            model_dir / "mp_rank_00_model_states.pt",
            map_location="cpu",
        )["module"]
        enhancer.load_state_dict(state_dict)
        enhancer.to(device=self.device, dtype=self.dtype).eval()

        self.hparams = hparams
        self.enhancer = enhancer

    @torch.inference_mode()
    def denoise(self, dwav, sr) -> tuple[torch.Tensor, int]:
        """Run only the enhancer's denoiser sub-model over ``dwav``.

        Args:
            dwav: Waveform forwarded unchanged to ``inference``.
            sr: Sample rate of ``dwav``.

        Returns:
            The result of ``inference`` (annotated as waveform tensor and
            sample rate).

        Raises:
            AssertionError: If the model or its denoiser is not loaded.
        """
        assert self.enhancer is not None, "Model not loaded"
        assert self.enhancer.denoiser is not None, "Denoiser not loaded"
        enhancer = self.enhancer
        return inference(
            model=enhancer.denoiser,
            dwav=dwav,
            sr=sr,
            # Was ``self.devicem`` (typo), which raised AttributeError on
            # every call.
            device=self.device,
            dtype=self.dtype,
        )

    @torch.inference_mode()
    def enhance(
        self,
        dwav,
        sr,
        nfe=32,
        solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
        lambd=0.5,
        tau=0.5,
    ) -> tuple[torch.Tensor, int]:
        """Run the full enhancer over ``dwav`` with the given settings.

        The settings are validated, applied to the loaded enhancer through
        ``configurate_`` and then ``inference`` is run on the whole model.

        Args:
            dwav: Waveform forwarded unchanged to ``inference``.
            sr: Sample rate of ``dwav``.
            nfe: Must satisfy ``0 < nfe <= 128``.
            solver: One of ``"midpoint"``, ``"rk4"`` or ``"euler"``.
            lambd: Must lie in ``[0, 1]``.
            tau: Must lie in ``[0, 1]``.

        Returns:
            The result of ``inference`` (annotated as waveform tensor and
            sample rate).

        Raises:
            AssertionError: If an argument is out of range or the model is
                not loaded.
        """
        # Assertions are kept (rather than ValueError) so callers that catch
        # AssertionError keep working.
        assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
        assert solver in (
            "midpoint",
            "rk4",
            "euler",
        ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
        assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
        assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
        assert self.enhancer is not None, "Model not loaded"
        enhancer = self.enhancer
        enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
        return inference(
            model=enhancer, dwav=dwav, sr=sr, device=self.device, dtype=self.dtype
        )


def load_enhancer() -> ResembleEnhance:
    """Return the shared ``ResembleEnhance`` instance, loading it on first use.

    The module-level ``lock`` is held while checking and creating the
    instance. The global is only assigned after ``load_model()`` succeeds, so
    a failed load does not leave an unloaded instance cached; the next call
    simply retries.

    Returns:
        The loaded, module-wide ``ResembleEnhance`` instance.
    """
    global resemble_enhance
    with lock:
        if resemble_enhance is None:
            logger.info("Loading ResembleEnhance model")
            instance = ResembleEnhance(
                device=devices.get_device_for("enhancer"), dtype=devices.dtype
            )
            instance.load_model()
            # Publish only a fully loaded model.
            resemble_enhance = instance
            logger.info("ResembleEnhance model loaded")
        return resemble_enhance


def unload_enhancer():
    """Drop the shared ``ResembleEnhance`` instance, if one is loaded.

    Under the module-level ``lock`` the global reference is cleared, then
    ``devices.torch_gc()`` and ``gc.collect()`` are run. Does nothing when no
    instance is loaded.
    """
    global resemble_enhance
    with lock:
        if resemble_enhance is None:
            return
        logger.info("Unloading ResembleEnhance model")
        # Rebinding to None releases this module's reference to the model.
        resemble_enhance = None
        devices.torch_gc()
        gc.collect()
        logger.info("ResembleEnhance model unloaded")


def reload_enhancer():
    """Unload the shared enhancer and load it again.

    Calls ``unload_enhancer()`` followed by ``load_enhancer()``; each of them
    takes the module-level lock on its own, so the lock is not held across
    the two steps.
    """
    logger.info("Reloading ResembleEnhance model")
    unload_enhancer()
    load_enhancer()
    logger.info("ResembleEnhance model reloaded")


def apply_audio_enhance_full(
    audio_data: np.ndarray,
    sr: int,
    nfe=32,
    solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
    lambd=0.5,
    tau=0.5,
):
    """Enhance a NumPy waveform with caller-chosen enhancer settings.

    The array is converted to a squeezed float32 CPU tensor, passed to the
    shared enhancer's ``enhance`` method and converted back to NumPy.

    Args:
        audio_data: Input waveform.
        sr: Sample rate of ``audio_data``.
        nfe, solver, lambd, tau: Forwarded to ``ResembleEnhance.enhance``.

    Returns:
        ``(enhanced_audio, sample_rate)`` with the audio as a NumPy array and
        the sample rate as ``int``.
    """
    # FIXME: switching this to to(device) might be a small optimization?
    waveform = torch.from_numpy(audio_data).float().squeeze().cpu()
    model = load_enhancer()

    enhanced, out_sr = model.enhance(
        waveform, sr, tau=tau, nfe=nfe, solver=solver, lambd=lambd
    )

    return enhanced.cpu().numpy(), int(out_sr)


def apply_audio_enhance(
    audio_data: np.ndarray, sr: int, enable_denoise: bool, enable_enhance: bool
):
    """Optionally enhance a NumPy waveform using fixed preset settings.

    When both flags are false the inputs are returned untouched. Otherwise
    the shared enhancer runs with ``tau=0.5``, ``nfe=64``, ``solver="rk4"``
    and ``lambd`` set to 0.9 if ``enable_denoise`` is true, else 0.1.

    Args:
        audio_data: Input waveform.
        sr: Sample rate of ``audio_data``.
        enable_denoise: Selects the higher ``lambd`` preset.
        enable_enhance: Enables processing even without denoising.

    Returns:
        ``(audio, sample_rate)``; on the processing path the audio is a NumPy
        array and the sample rate an ``int``.
    """
    if not (enable_denoise or enable_enhance):
        return audio_data, sr

    # FIXME: switching this to to(device) might be a small optimization?
    waveform = torch.from_numpy(audio_data).float().squeeze().cpu()
    model = load_enhancer()

    # Past the guard above at least one flag is set, so enhancement always runs.
    strength = 0.9 if enable_denoise else 0.1
    enhanced, out_sr = model.enhance(
        waveform, sr, tau=0.5, nfe=64, solver="rk4", lambd=strength
    )

    return enhanced.cpu().numpy(), int(out_sr)


if __name__ == "__main__":
    # Manual experimentation scratchpad: only the two imports and the device
    # assignment below are live; everything else is commented-out trial code.
    # NOTE(review): the commented snippets call load_enhancer(device), but
    # load_enhancer() in this file takes no arguments — adjust before reuse.
    import gradio as gr
    import torchaudio

    device = torch.device("cuda")

    # def enhance(file):
    #     print(file)
    #     ench = load_enhancer(device)
    #     dwav, sr = torchaudio.load(file)
    #     dwav = dwav.mean(dim=0).to(device)
    #     enhanced, e_sr = ench.enhance(dwav, sr)
    #     return e_sr, enhanced.cpu().numpy()

    # # Just an arbitrary example
    # gr.Interface(
    #     fn=enhance, inputs=[gr.Audio(type="filepath")], outputs=[gr.Audio()]
    # ).launch()

    # load_chat_tts()

    # ench = load_enhancer(device)

    # devices.torch_gc()

    # wav, sr = torchaudio.load("test.wav")

    # print(wav.shape, type(wav), sr, type(sr))
    # # exit()

    # wav = wav.squeeze(0).cuda()

    # print(wav.device)

    # denoised, d_sr = ench.denoise(wav, sr)
    # denoised = denoised.unsqueeze(0)
    # print(denoised.shape)
    # torchaudio.save("denoised.wav", denoised.cpu(), d_sr)

    # for solver in ("midpoint", "rk4", "euler"):
    #     for lambd in (0.1, 0.5, 0.9):
    #         for tau in (0.1, 0.5, 0.9):
    #             enhanced, e_sr = ench.enhance(
    #                 wav, sr, solver=solver, lambd=lambd, tau=tau, nfe=128
    #             )
    #             enhanced = enhanced.unsqueeze(0)
    #             print(enhanced.shape)
    #             torchaudio.save(
    #                 f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced.cpu(), e_sr
    #             )