Spaces:
Running
on
Zero
Running
on
Zero
| import gc | |
| import os | |
| from typing import List, Literal | |
| import numpy as np | |
| from modules.devices import devices | |
| from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer | |
| from modules.repos_static.resemble_enhance.enhancer.hparams import HParams | |
| from modules.repos_static.resemble_enhance.inference import inference | |
| import torch | |
| from modules.utils.constants import MODELS_DIR | |
| from pathlib import Path | |
| from threading import Lock | |
| from modules import config | |
| import logging | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Shared ResembleEnhance instance: None until load_enhancer() creates it,
# and reset to None again by unload_enhancer().
resemble_enhance = None
# Held by load_enhancer()/unload_enhancer() while they touch the shared
# instance above.
lock = Lock()
class ResembleEnhance:
    """Thin wrapper around the resemble-enhance ``Enhancer`` model.

    Stores the target device/dtype and, once :meth:`load_model` has run, the
    loaded hyper-parameters and network that :meth:`denoise` and
    :meth:`enhance` hand to ``inference``.
    """

    def __init__(self, device: torch.device, dtype=torch.float32):
        """Remember the device and dtype to run on; no weights are loaded here.

        Args:
            device: Device the model is moved to and inference is run on.
            dtype: Dtype the model is cast to and passed to ``inference``.
        """
        self.device = device
        self.dtype = dtype
        # Both remain None until load_model() completes.
        self.enhancer: "Enhancer | None" = None
        self.hparams: "HParams | None" = None

    def load_model(self):
        """Load hyper-parameters and weights from ``MODELS_DIR/resemble-enhance``.

        On success ``self.hparams`` and ``self.enhancer`` are populated; the
        enhancer is moved to ``self.device``/``self.dtype`` and put in eval
        mode before its state dict is loaded.
        """
        model_dir = Path(MODELS_DIR) / "resemble-enhance"
        hparams = HParams.load(model_dir)
        enhancer = Enhancer(hparams).to(device=self.device, dtype=self.dtype).eval()
        # NOTE(review): torch.load unpickles the checkpoint file, so only
        # trusted checkpoints should ever be placed in the models directory.
        state_dict = torch.load(
            model_dir / "mp_rank_00_model_states.pt",
            map_location=self.device,
        )["module"]
        enhancer.load_state_dict(state_dict)
        # Assign only after everything above succeeded, so a failed load never
        # leaves a half-initialised model on the instance.
        self.hparams = hparams
        self.enhancer = enhancer

    def denoise(self, dwav, sr) -> tuple[torch.Tensor, int]:
        """Run ``inference`` with the enhancer's ``denoiser`` sub-model.

        Args:
            dwav: Waveform forwarded unchanged to ``inference``.
            sr: Sample rate forwarded unchanged to ``inference``.

        Returns:
            Whatever ``inference`` returns (annotated as waveform tensor and
            sample rate).

        Raises:
            AssertionError: If the model, or its denoiser, is not loaded.
        """
        assert self.enhancer is not None, "Model not loaded"
        assert self.enhancer.denoiser is not None, "Denoiser not loaded"
        return inference(
            model=self.enhancer.denoiser,
            dwav=dwav,
            sr=sr,
            # Was `self.devicem`, which raised AttributeError on every call.
            device=self.device,
            dtype=self.dtype,
        )

    def enhance(
        self,
        dwav,
        sr,
        nfe=32,
        solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
        lambd=0.5,
        tau=0.5,
    ) -> tuple[torch.Tensor, int]:
        """Configure the enhancer and run ``inference`` with the full model.

        Args:
            dwav: Waveform forwarded unchanged to ``inference``.
            sr: Sample rate forwarded unchanged to ``inference``.
            nfe: Must satisfy ``0 < nfe <= 128``; passed to ``configurate_``.
            solver: One of ``"midpoint"``, ``"rk4"``, ``"euler"``; passed to
                ``configurate_``.
            lambd: Must be in ``[0, 1]``; passed to ``configurate_``.
            tau: Must be in ``[0, 1]``; passed to ``configurate_``.

        Returns:
            Whatever ``inference`` returns (annotated as waveform tensor and
            sample rate).

        Raises:
            AssertionError: If an argument is out of range or the model is
                not loaded. Arguments are validated before the model check.
        """
        assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
        assert solver in (
            "midpoint",
            "rk4",
            "euler",
        ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
        assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
        assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
        assert self.enhancer is not None, "Model not loaded"
        enhancer = self.enhancer
        # configurate_ mutates the shared model in place before inference.
        enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
        return inference(
            model=enhancer, dwav=dwav, sr=sr, device=self.device, dtype=self.dtype
        )
def load_enhancer() -> ResembleEnhance:
    """Return the shared ResembleEnhance instance, loading it on first use.

    The instance is built with ``devices.device`` / ``devices.dtype`` while
    holding the module lock. If loading raises, the module-level
    ``resemble_enhance`` stays ``None`` so a later call retries the load
    instead of handing out an instance whose model was never loaded.

    Returns:
        The loaded, module-wide ResembleEnhance instance.
    """
    global resemble_enhance
    with lock:
        if resemble_enhance is None:
            logger.info("Loading ResembleEnhance model")
            instance = ResembleEnhance(device=devices.device, dtype=devices.dtype)
            instance.load_model()
            # Publish only after load_model() succeeded.
            resemble_enhance = instance
            logger.info("ResembleEnhance model loaded")
        return resemble_enhance
def unload_enhancer():
    """Drop the shared ResembleEnhance instance, if one is loaded.

    Does nothing when no instance exists. Otherwise the module-level
    reference is cleared under the module lock, then Python garbage
    collection and ``devices.torch_gc()`` are run.
    """
    global resemble_enhance
    with lock:
        if resemble_enhance is None:
            return
        logger.info("Unloading ResembleEnhance model")
        # Rebinding to None is enough to drop the reference; a preceding
        # `del` would leave the module global unbound for a moment.
        resemble_enhance = None
        # Collect Python garbage first so objects freed by it are gone before
        # the device-side cleanup runs.
        # NOTE(review): assumes devices.torch_gc() releases cached device
        # memory -- confirm against modules.devices.
        gc.collect()
        devices.torch_gc()
        logger.info("ResembleEnhance model unloaded")
def reload_enhancer():
    """Unload the shared ResembleEnhance instance and load a fresh one.

    Calls unload_enhancer() followed by load_enhancer(); each takes the module
    lock separately, so the lock is not held across the whole reload.
    """
    logger.info("Reloading ResembleEnhance model")
    unload_enhancer()
    load_enhancer()
    logger.info("ResembleEnhance model reloaded")
def apply_audio_enhance_full(
    audio_data: np.ndarray,
    sr: int,
    nfe=32,
    solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
    lambd=0.5,
    tau=0.5,
):
    """Enhance a NumPy waveform with caller-chosen enhancer settings.

    The array is converted to a squeezed float32 CPU tensor, passed through
    the shared enhancer's ``enhance`` method with the given ``nfe``,
    ``solver``, ``lambd`` and ``tau``, and converted back to NumPy.

    Returns:
        ``(audio_data, sample_rate)`` where the sample rate is cast to int.
    """
    # FIXME: switching this to to(device) might be a small optimisation?
    waveform = torch.from_numpy(audio_data)
    waveform = waveform.float().squeeze().cpu()

    model = load_enhancer()
    enhanced, out_sr = model.enhance(
        waveform, sr, tau=tau, nfe=nfe, solver=solver, lambd=lambd
    )

    return enhanced.cpu().numpy(), int(out_sr)
def apply_audio_enhance(
    audio_data: np.ndarray, sr: int, enable_denoise: bool, enable_enhance: bool
):
    """Optionally enhance a NumPy waveform using fixed preset settings.

    When both flags are false the inputs are returned untouched. Otherwise
    the shared enhancer is run with ``tau=0.5``, ``nfe=64``, ``solver="rk4"``
    and ``lambd`` set to 0.9 if ``enable_denoise`` is true, else 0.1.

    Returns:
        ``(audio_data, sample_rate)``; on the processed path the sample rate
        is cast to int.
    """
    if not (enable_denoise or enable_enhance):
        return audio_data, sr

    # FIXME: switching this to to(device) might be a small optimisation?
    waveform = torch.from_numpy(audio_data)
    waveform = waveform.float().squeeze().cpu()

    model = load_enhancer()

    # At least one flag is set once we get here, so the enhancer always runs.
    if enable_denoise:
        lambd = 0.9
    else:
        lambd = 0.1
    waveform, sr = model.enhance(
        waveform, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd
    )

    return waveform.cpu().numpy(), int(sr)
if __name__ == "__main__":
    # Manual experimentation entry point. Only the imports and the device
    # assignment below are live; every example is commented out.
    import torchaudio
    import gradio as gr
    device = torch.device("cuda")
    # NOTE(review): the commented-out examples call load_enhancer(device), but
    # load_enhancer() in this file takes no arguments -- update them before
    # re-enabling.
    # def enhance(file):
    #     print(file)
    #     ench = load_enhancer(device)
    #     dwav, sr = torchaudio.load(file)
    #     dwav = dwav.mean(dim=0).to(device)
    #     enhanced, e_sr = ench.enhance(dwav, sr)
    #     return e_sr, enhanced.cpu().numpy()
    # # Just an arbitrary example
    # gr.Interface(
    #     fn=enhance, inputs=[gr.Audio(type="filepath")], outputs=[gr.Audio()]
    # ).launch()
    # load_chat_tts()
    # ench = load_enhancer(device)
    # devices.torch_gc()
    # wav, sr = torchaudio.load("test.wav")
    # print(wav.shape, type(wav), sr, type(sr))
    # # exit()
    # wav = wav.squeeze(0).cuda()
    # print(wav.device)
    # denoised, d_sr = ench.denoise(wav, sr)
    # denoised = denoised.unsqueeze(0)
    # print(denoised.shape)
    # torchaudio.save("denoised.wav", denoised.cpu(), d_sr)
    # for solver in ("midpoint", "rk4", "euler"):
    #     for lambd in (0.1, 0.5, 0.9):
    #         for tau in (0.1, 0.5, 0.9):
    #             enhanced, e_sr = ench.enhance(
    #                 wav, sr, solver=solver, lambd=lambd, tau=tau, nfe=128
    #             )
    #             enhanced = enhanced.unsqueeze(0)
    #             print(enhanced.shape)
    #             torchaudio.save(
    #                 f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced.cpu(), e_sr
    #             )