# wuxulong19950206
# add files
# 9270314
r'''
Author: wuxulong19950206 1287173754@qq.com
Date: 2024-03-12 22:44:31
LastEditors: wuxulong19950206 1287173754@qq.com
LastEditTime: 2024-03-12 23:05:02
FilePath: \text_to_speech\mtts\models\vocoder\VocGAN\vocgan.py
Description: This is the default setting. Please set `customMade` and open koroFileHeader to view and adjust the configuration: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
'''
import argparse
import glob
import os
import numpy as np
import torch
import tqdm
from scipy.io.wavfile import write
from .denoiser import Denoiser
from .model.generator import ModifiedGenerator
from .utils.hparams import HParam, load_hparam_str
# Equals 2 ** 15; presumably the int16 PCM scale factor (TODO confirm).
# No active code visible in this module uses it.
MAX_WAV_VALUE = 32768.0
from .download_utils import download_url
# Location of the pretrained checkpoint; VocGan.__init__ passes it to
# download_url when the checkpoint file is not found locally.
url = 'https://zenodo.org/record/4743731/files/vctk_pretrained_model_3180.pt'
class VocGan:
    """Wrapper around the VocGAN generator that converts mels to waveforms.

    On construction the pretrained checkpoint is located (and downloaded
    from ``url`` if it is missing), the generator described by the
    hyper-parameters stored in the checkpoint is built, and its weights are
    loaded. Instances are callable: ``vocoder(mel)`` is the same as
    ``vocoder.synthesize(mel)``.
    """

    def __init__(self, device='cuda:0', config=None, denoise=False):
        """Load the pretrained generator.

        Args:
            device: Device the model runs on. Used only when ``config`` is
                ``None`` or has no ``"device"`` entry.
            config: Optional mapping with the keys ``"checkpoint"``
                (directory holding the checkpoint file), ``"denoise"`` and
                ``"device"``. Values found here take precedence over the
                keyword arguments.
            denoise: Whether ``synthesize`` passes the audio through the
                denoiser. Used only when ``config`` is ``None`` or has no
                ``"denoise"`` entry.
        """
        checkpoint_path = self._setting(config, "checkpoint", None)
        if checkpoint_path is None:
            # No directory configured: fall back to a per-user cache folder.
            checkpoint_path = os.path.join(
                os.path.expanduser("~"), ".cache", "vocgan")
        denoise = self._setting(config, "denoise", denoise)
        device = self._setting(config, "device", device)

        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_file = os.path.join(
            checkpoint_path, 'vctk_pretrained_model_3180.pt')
        if not os.path.exists(checkpoint_file):
            download_url(url, checkpoint_path)

        # NOTE(security): torch.load unpickles the file, and the checkpoint
        # may have just been downloaded; only use trusted checkpoint sources.
        checkpoint = torch.load(checkpoint_file, map_location=device)

        # The hyper-parameters always come from the checkpoint itself.
        hp = load_hparam_str(checkpoint['hp_str'])
        self.hp = hp

        self.model = ModifiedGenerator(hp.audio.n_mel_channels,
                                       hp.model.n_residual_layers,
                                       ratios=hp.model.generator_ratio,
                                       mult=hp.model.mult,
                                       out_band=hp.model.out_channels).to(device)
        self.model.load_state_dict(checkpoint['model_g'])
        self.model.eval(inference=True)

        self.denoise = denoise
        self.device = device
        # Built lazily by _get_denoiser the first time denoising is needed.
        self._denoiser = None

    @staticmethod
    def _setting(config, key, fallback):
        """Return ``config[key]``, or ``fallback`` if config is None or lacks key."""
        if config is None:
            return fallback
        try:
            return config[key]
        except KeyError:
            return fallback

    def _get_denoiser(self):
        """Return the denoiser for this model, creating it on first use."""
        # NOTE(review): reuse assumes Denoiser depends only on the model and
        # device, both of which are fixed after construction.
        denoiser = getattr(self, "_denoiser", None)
        if denoiser is None:
            denoiser = Denoiser(self.model, device=self.device)
            self._denoiser = denoiser
        return denoiser

    def synthesize(self, mel):
        """Generate a waveform from a mel-spectrogram.

        Args:
            mel: A ``torch.Tensor`` or array-like. A 2-D input gets a leading
                dimension added before it is passed to the generator.
                Non-tensor input is converted to a float32 tensor.

        Returns:
            numpy.ndarray: The generated audio with the last
            ``hop_length * 10`` samples removed. The values are returned as
            produced by the model; no scaling to int16 is applied.
        """
        with torch.no_grad():
            if not isinstance(mel, torch.Tensor):
                # Force float32 so float64 / integer arrays reach the
                # generator in the dtype of its weights.
                mel = torch.tensor(mel, dtype=torch.float32)
            if mel.dim() == 2:
                mel = mel.unsqueeze(0)
            mel = mel.to(self.device)

            audio = self.model.inference(mel)
            audio = audio.squeeze(0)
            if self.denoise:
                audio = self._get_denoiser()(audio, 0.01)
            audio = audio.squeeze()
            # Drop the trailing ten hops of the generated signal.
            audio = audio[:-(self.hp.audio.hop_length * 10)]
            return audio.cpu().detach().numpy()

    def __call__(self, mel):
        """Shorthand for :meth:`synthesize`."""
        return self.synthesize(mel)