Spaces:
Sleeping
Sleeping
''' | |
Author: wuxulong19950206 1287173754@qq.com | |
Date: 2024-03-12 22:44:31 | |
LastEditors: wuxulong19950206 1287173754@qq.com | |
LastEditTime: 2024-03-12 23:05:02 | |
FilePath: \text_to_speech\mtts\models\vocoder\VocGAN\vocgan.py | |
Description: This is the default setting; please set `customMade` and open koroFileHeader to view and adjust the configuration: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
''' | |
import argparse | |
import glob | |
import os | |
import numpy as np | |
import torch | |
import tqdm | |
from scipy.io.wavfile import write | |
from .denoiser import Denoiser | |
from .model.generator import ModifiedGenerator | |
from .utils.hparams import HParam, load_hparam_str | |
# 2**15, the magnitude of the signed 16-bit sample range.
# NOTE(review): presumably meant for scaling float audio to int16 PCM before
# writing a wav file — confirm against the callers that use it.
MAX_WAV_VALUE = 32768.0
from .download_utils import download_url | |
# Download location of the pretrained checkpoint that VocGan fetches when the
# file is not already present in its checkpoint directory.
url = 'https://zenodo.org/record/4743731/files/vctk_pretrained_model_3180.pt'
class VocGan:
    """Inference wrapper around the pretrained VocGAN generator.

    Construction makes sure the pretrained checkpoint file exists (downloading
    it when missing), rebuilds the generator from the hyper-parameters stored
    inside the checkpoint and switches it to inference mode.  Calling the
    instance, or :meth:`synthesize`, turns a mel-spectrogram into a waveform.
    """

    # File name of the pretrained checkpoint inside the checkpoint directory.
    _CHECKPOINT_NAME = 'vctk_pretrained_model_3180.pt'

    def __init__(self, device='cuda:0', config=None, denoise=False):
        """Load the generator weights and prepare the model for inference.

        Args:
            device: Device used when ``config`` does not provide ``"device"``.
            config: Optional mapping with the keys ``"checkpoint"`` (directory
                holding the checkpoint file), ``"device"`` and ``"denoise"``.
                Values present here take precedence over the keyword
                arguments; missing keys fall back to them.
            denoise: Whether to run the denoiser on synthesized audio; used
                when ``config`` does not provide ``"denoise"``.
        """
        checkpoint_path, device, denoise = self._resolve_settings(
            device, config, denoise)

        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_file = os.path.join(checkpoint_path, self._CHECKPOINT_NAME)
        if not os.path.exists(checkpoint_file):
            download_url(url, checkpoint_path)

        checkpoint = torch.load(checkpoint_file, map_location=device)
        # The hyper-parameters always come from the checkpoint itself:
        # ``config`` carries runtime settings, not a hyper-parameter file.
        hp = load_hparam_str(checkpoint['hp_str'])
        self.hp = hp

        self.model = ModifiedGenerator(hp.audio.n_mel_channels,
                                       hp.model.n_residual_layers,
                                       ratios=hp.model.generator_ratio,
                                       mult=hp.model.mult,
                                       out_band=hp.model.out_channels).to(device)
        self.model.load_state_dict(checkpoint['model_g'])
        # ``eval(inference=True)`` is the generator's own override (the stock
        # torch ``eval`` takes no such keyword).
        self.model.eval(inference=True)
        # NOTE(review): kept from the original in case the inference switch
        # re-creates parameters; likely a no-op — confirm before removing.
        self.model = self.model.to(device)

        self.denoise = denoise
        self.device = device
        # Built lazily on the first denoised synthesis and then reused.
        self._denoiser = None

    @staticmethod
    def _resolve_settings(device, config, denoise):
        """Return ``(checkpoint_dir, device, denoise)`` from args and config."""
        default_dir = os.path.join(os.path.expanduser('~'), '.cache', 'vocgan')
        if config is None:
            # Without a config the keyword arguments are authoritative.
            return default_dir, device, denoise
        return (config.get('checkpoint', default_dir),
                config.get('device', device),
                config.get('denoise', denoise))

    def _get_denoiser(self):
        """Return the denoiser for ``self.model``, creating it on first use."""
        # NOTE(review): reuse assumes Denoiser keeps no per-call state and
        # that ``self.model`` is not swapped after construction — confirm.
        denoiser = getattr(self, '_denoiser', None)
        if denoiser is None:
            denoiser = Denoiser(self.model, device=self.device)
            self._denoiser = denoiser
        return denoiser

    def synthesize(self, mel):
        """Generate a waveform from a mel-spectrogram.

        Args:
            mel: A ``torch.Tensor`` or anything ``torch.tensor`` accepts.  A
                2-D input gets a leading axis prepended before it is passed
                to the model (presumably the batch axis — TODO confirm).

        Returns:
            A NumPy array holding the generated audio with the final
            ``hop_length * 10`` samples removed.
        """
        with torch.no_grad():
            if not isinstance(mel, torch.Tensor):
                mel = torch.tensor(mel)
            if mel.dim() == 2:
                mel = mel.unsqueeze(0)
            mel = mel.to(self.device)

            audio = self.model.inference(mel)
            audio = audio.squeeze(0)
            if self.denoise:
                audio = self._get_denoiser()(audio, 0.01)
            audio = audio.squeeze()
            # Drop the last ten hops of output.
            # NOTE(review): presumably this removes audio produced from
            # padding added inside ``model.inference`` — confirm.
            audio = audio[:-(self.hp.audio.hop_length * 10)]
            # The samples are returned as produced by the model; no scaling,
            # clamping or cast to 16-bit integers is applied here.
            audio = audio.cpu().detach().numpy()
        return audio

    def __call__(self, mel):
        """Shorthand for :meth:`synthesize`."""
        return self.synthesize(mel)