diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c7d9f3332a950355d5a77d85000f05e6f45435ea 100644 --- a/.gitattributes +++ b/.gitattributes @@ -25,7 +25,6 @@ *.safetensors filter=lfs diff=lfs merge=lfs -text saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text *.tflite filter=lfs diff=lfs merge=lfs -text *.tgz filter=lfs diff=lfs merge=lfs -text *.wasm filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index fd20fddf874731c364880a33eb9acd43c1512365..fb07aa3eb740e5d6b84172cbf73d2523768269ee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,382 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore -*.pyc +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Oo]ut/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database 
connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
+*.vbw + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ + + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# build +build +monotonic_align/core.c +*.o +*.so +*.dll + +# data +/config.json +/*.pth +*.wav +/monotonic_align/monotonic_align +/resources +/MoeGoe.spec +/dist/MoeGoe +/dist + +.idea \ No newline at end of file diff --git a/app.py b/app.py index d444a565cc47f685371774965e681ab71d037238..58ace7595ff167bbdf7da7c16e2a96c0d0a9b2ff 100644 --- a/app.py +++ b/app.py @@ -2,17 +2,16 @@ import argparse import logging import os import re -import tempfile - -import edge_tts -import gradio as gr import gradio.processing_utils as gr_pu +import gradio as gr import librosa import numpy as np import soundfile from scipy.io import wavfile - +import tempfile +import edge_tts import utils + from inference.infer_tool import Svc logging.getLogger('numba').setLevel(logging.WARNING) @@ -29,7 +28,10 @@ tts_voice = { "英文女": "en-US-AnaNeural" } -hubert_model = utils.get_speech_encoder("vec768l12", device="cpu") +hubert_dict = { + "vec768l12": utils.get_speech_encoder("vec768l12", device="cpu"), + "vec256l9": utils.get_speech_encoder("vec256l9", device="cpu") +} def create_fn(model, spk): @@ -43,7 +45,8 @@ def create_fn(model, spk): audio = librosa.to_mono(audio.transpose(1, 0)) temp_path = "temp.wav" soundfile.write(temp_path, audio, sampling_rate, format="wav") - model.hubert_model = hubert_model + + model.hubert_model = hubert_dict[model.speech_encoder] out_audio = model.slice_inference(raw_audio_path=temp_path, spk=spk, slice_db=-40, @@ -63,9 +66,7 @@ def create_fn(model, spk): input_text = re.sub(r"[\n\,\(\) ]", "", input_text) voice = tts_voice[gender] ratestr = "+{:.0%}".format(tts_rate) if tts_rate >= 0 else "{:.0%}".format(tts_rate) - communicate = edge_tts.Communicate(text=input_text, - voice=voice, - rate=ratestr) + communicate = edge_tts.Communicate(text=input_text, voice=voice, rate=ratestr) with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: temp_path = tmp_file.name await communicate.save(temp_path) @@ -107,8 +108,8 @@ if __name__ == '__main__': with gr.Column(): with gr.Row(): vc_transform = gr.Number(label="音高调整 (正负半音,12为1个八度)", value=0) - f0_predictor = gr.Radio(label="f0预测器 (harvest适合讲话,crepe适合唱歌)", - choices=['crepe', 'harvest', 'dio', 'pm'], value='crepe') + f0_predictor = gr.Radio(label="f0预测器 (推荐rmvpe)", + 
choices=['crepe', 'harvest', 'rmvpe'], value='rmvpe') auto_f0 = gr.Checkbox(label="自动音高预测 (文本转语音或讲话可选,会导致唱歌跑调)", value=False) with gr.Tabs(): diff --git a/cluster/__init__.py b/cluster/__init__.py deleted file mode 100644 index f1b9bde04e73e9218a5d534227caa4c25332f424..0000000000000000000000000000000000000000 --- a/cluster/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np -import torch -from sklearn.cluster import KMeans - -def get_cluster_model(ckpt_path): - checkpoint = torch.load(ckpt_path) - kmeans_dict = {} - for spk, ckpt in checkpoint.items(): - km = KMeans(ckpt["n_features_in_"]) - km.__dict__["n_features_in_"] = ckpt["n_features_in_"] - km.__dict__["_n_threads"] = ckpt["_n_threads"] - km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"] - kmeans_dict[spk] = km - return kmeans_dict - -def get_cluster_result(model, x, speaker): - """ - x: np.array [t, 256] - return cluster class result - """ - return model[speaker].predict(x) - -def get_cluster_center_result(model, x,speaker): - """x: np.array [t, 256]""" - predict = model[speaker].predict(x) - return model[speaker].cluster_centers_[predict] - -def get_center(model, x,speaker): - return model[speaker].cluster_centers_[x] diff --git a/cluster/kmeans.py b/cluster/kmeans.py deleted file mode 100644 index 6111ea45e66a15d41b5b904be6f75affd3c4369f..0000000000000000000000000000000000000000 --- a/cluster/kmeans.py +++ /dev/null @@ -1,201 +0,0 @@ -import math,pdb -import torch,pynvml -from torch.nn.functional import normalize -from time import time -import numpy as np -# device=torch.device("cuda:0") -def _kpp(data: torch.Tensor, k: int, sample_size: int = -1): - """ Picks k points in the data based on the kmeans++ method. - - Parameters - ---------- - data : torch.Tensor - Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D - data, rank 2 multidimensional data, in which case one - row is one observation. - k : int - Number of samples to generate. - sample_size : int - sample data to avoid memory overflow during calculation - - Returns - ------- - init : ndarray - A 'k' by 'N' containing the initial centroids. - - References - ---------- - .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of - careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium - on Discrete Algorithms, 2007. - .. 
[2] scipy/cluster/vq.py: _kpp - """ - batch_size=data.shape[0] - if batch_size>sample_size: - data = data[torch.randint(0, batch_size,[sample_size], device=data.device)] - dims = data.shape[1] if len(data.shape) > 1 else 1 - init = torch.zeros((k, dims)).to(data.device) - r = torch.distributions.uniform.Uniform(0, 1) - for i in range(k): - if i == 0: - init[i, :] = data[torch.randint(data.shape[0], [1])] - else: - D2 = torch.cdist(init[:i, :][None, :], data[None, :], p=2)[0].amin(dim=0) - probs = D2 / torch.sum(D2) - cumprobs = torch.cumsum(probs, dim=0) - init[i, :] = data[torch.searchsorted(cumprobs, r.sample([1]).to(data.device))] - return init -class KMeansGPU: - ''' - Kmeans clustering algorithm implemented with PyTorch - - Parameters: - n_clusters: int, - Number of clusters - - max_iter: int, default: 100 - Maximum number of iterations - - tol: float, default: 0.0001 - Tolerance - - verbose: int, default: 0 - Verbosity - - mode: {'euclidean', 'cosine'}, default: 'euclidean' - Type of distance measure - - init_method: {'random', 'point', '++'} - Type of initialization - - minibatch: {None, int}, default: None - Batch size of MinibatchKmeans algorithm - if None perform full KMeans algorithm - - Attributes: - centroids: torch.Tensor, shape: [n_clusters, n_features] - cluster centroids - ''' - def __init__(self, n_clusters, max_iter=200, tol=1e-4, verbose=0, mode="euclidean",device=torch.device("cuda:0")): - self.n_clusters = n_clusters - self.max_iter = max_iter - self.tol = tol - self.verbose = verbose - self.mode = mode - self.device=device - pynvml.nvmlInit() - gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index) - info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) - self.minibatch=int(33e6/self.n_clusters*info.free/ 1024 / 1024 / 1024) - print("free_mem/GB:",info.free/ 1024 / 1024 / 1024,"minibatch:",self.minibatch) - - @staticmethod - def cos_sim(a, b): - """ - Compute cosine similarity of 2 sets of vectors - - Parameters: - a: torch.Tensor, shape: [m, n_features] - - b: torch.Tensor, shape: [n, n_features] - """ - return normalize(a, dim=-1) @ normalize(b, dim=-1).transpose(-2, -1) - - @staticmethod - def euc_sim(a, b): - """ - Compute euclidean similarity of 2 sets of vectors - Parameters: - a: torch.Tensor, shape: [m, n_features] - b: torch.Tensor, shape: [n, n_features] - """ - return 2 * a @ b.transpose(-2, -1) -(a**2).sum(dim=1)[..., :, None] - (b**2).sum(dim=1)[..., None, :] - - def max_sim(self, a, b): - """ - Compute maximum similarity (or minimum distance) of each vector - in a with all of the vectors in b - Parameters: - a: torch.Tensor, shape: [m, n_features] - b: torch.Tensor, shape: [n, n_features] - """ - if self.mode == 'cosine': - sim_func = self.cos_sim - elif self.mode == 'euclidean': - sim_func = self.euc_sim - sim = sim_func(a, b) - max_sim_v, max_sim_i = sim.max(dim=-1) - return max_sim_v, max_sim_i - - def fit_predict(self, X): - """ - Combination of fit() and predict() methods. - This is faster than calling fit() and predict() seperately. 
- Parameters: - X: torch.Tensor, shape: [n_samples, n_features] - centroids: {torch.Tensor, None}, default: None - if given, centroids will be initialized with given tensor - if None, centroids will be randomly chosen from X - Return: - labels: torch.Tensor, shape: [n_samples] - - mini_=33kk/k*remain - mini=min(mini_,fea_shape) - offset=log2(k/1000)*1.5 - kpp_all=min(mini_*10/offset,fea_shape) - kpp_sample=min(mini_/12/offset,fea_shape) - """ - assert isinstance(X, torch.Tensor), "input must be torch.Tensor" - assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point" - assert X.ndim == 2, "input must be a 2d tensor with shape: [n_samples, n_features] " - # print("verbose:%s"%self.verbose) - - offset = np.power(1.5,np.log(self.n_clusters / 1000))/np.log(2) - with torch.no_grad(): - batch_size= X.shape[0] - # print(self.minibatch, int(self.minibatch * 10 / offset), batch_size) - start_time = time() - if (self.minibatch*10//offset< batch_size): - x = X[torch.randint(0, batch_size,[int(self.minibatch*10/offset)])].to(self.device) - else: - x = X.to(self.device) - # print(x.device) - self.centroids = _kpp(x, self.n_clusters, min(int(self.minibatch/12/offset),batch_size)) - del x - torch.cuda.empty_cache() - # self.centroids = self.centroids.to(self.device) - num_points_in_clusters = torch.ones(self.n_clusters, device=self.device, dtype=X.dtype)#全1 - closest = None#[3098036]#int64 - if(self.minibatch>=batch_size//2 and self.minibatch=batch_size): - X=X.to(self.device) - for i in range(self.max_iter): - iter_time = time() - if self.minibatch= 2: - print('iter:', i, 'error:', error.item(), 'time spent:', round(time()-iter_time, 4)) - if error <= self.tol: - break - - if self.verbose >= 1: - print(f'used {i+1} iterations ({round(time()-start_time, 4)}s) to cluster {batch_size} items into {self.n_clusters} clusters') - return closest diff --git a/cluster/train_cluster.py b/cluster/train_cluster.py deleted file mode 100644 index 8644566388a4107c4442da14c0de090bcd4a91b8..0000000000000000000000000000000000000000 --- a/cluster/train_cluster.py +++ /dev/null @@ -1,84 +0,0 @@ -import time,pdb -import tqdm -from time import time as ttime -import os -from pathlib import Path -import logging -import argparse -from kmeans import KMeansGPU -import torch -import numpy as np -from sklearn.cluster import KMeans,MiniBatchKMeans - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) -from time import time as ttime -import pynvml,torch - -def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):#gpu_minibatch真拉,虽然库支持但是也不考虑 - logger.info(f"Loading features from {in_dir}") - features = [] - nums = 0 - for path in tqdm.tqdm(in_dir.glob("*.soft.pt")): - # for name in os.listdir(in_dir): - # path="%s/%s"%(in_dir,name) - features.append(torch.load(path,map_location="cpu").squeeze(0).numpy().T) - # print(features[-1].shape) - features = np.concatenate(features, axis=0) - print(nums, features.nbytes/ 1024**2, "MB , shape:",features.shape, features.dtype) - features = features.astype(np.float32) - logger.info(f"Clustering features of shape: {features.shape}") - t = time.time() - if(use_gpu==False): - if use_minibatch: - kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features) - else: - kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features) - else: - kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0,max_iter=500,tol=1e-2)# - 
features=torch.from_numpy(features)#.to(device) - labels = kmeans.fit_predict(features)# - - print(time.time()-t, "s") - - x = { - "n_features_in_": kmeans.n_features_in_ if use_gpu==False else features.shape[1], - "_n_threads": kmeans._n_threads if use_gpu==False else 4, - "cluster_centers_": kmeans.cluster_centers_ if use_gpu==False else kmeans.centroids.cpu().numpy(), - } - print("end") - - return x - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--dataset', type=Path, default="./dataset/44k", - help='path of training data directory') - parser.add_argument('--output', type=Path, default="logs/44k", - help='path of model output directory') - parser.add_argument('--gpu',action='store_true', default=False , - help='to use GPU') - - - args = parser.parse_args() - - checkpoint_dir = args.output - dataset = args.dataset - use_gpu = args.gpu - n_clusters = 10000 - - ckpt = {} - for spk in os.listdir(dataset): - if os.path.isdir(dataset/spk): - print(f"train kmeans for {spk}...") - in_dir = dataset/spk - x = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=use_gpu) - ckpt[spk] = x - - checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt" - checkpoint_path.parent.mkdir(exist_ok=True, parents=True) - torch.save( - ckpt, - checkpoint_path, - ) - diff --git a/configs/config.json b/configs/config.json deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/diffusion/data_loaders.py b/diffusion/data_loaders.py deleted file mode 100644 index bf18572329019d7a8f1df01799eda207c16dd7ff..0000000000000000000000000000000000000000 --- a/diffusion/data_loaders.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import random -import re -import numpy as np -import librosa -import torch -import random -from utils import repeat_expand_2d -from tqdm import tqdm -from torch.utils.data import Dataset - -def traverse_dir( - root_dir, - extensions, - amount=None, - str_include=None, - str_exclude=None, - is_pure=False, - is_sort=False, - is_ext=True): - - file_list = [] - cnt = 0 - for root, _, files in os.walk(root_dir): - for file in files: - if any([file.endswith(f".{ext}") for ext in extensions]): - # path - mix_path = os.path.join(root, file) - pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path - - # amount - if (amount is not None) and (cnt == amount): - if is_sort: - file_list.sort() - return file_list - - # check string - if (str_include is not None) and (str_include not in pure_path): - continue - if (str_exclude is not None) and (str_exclude in pure_path): - continue - - if not is_ext: - ext = pure_path.split('.')[-1] - pure_path = pure_path[:-(len(ext)+1)] - file_list.append(pure_path) - cnt += 1 - if is_sort: - file_list.sort() - return file_list - - -def get_data_loaders(args, whole_audio=False): - data_train = AudioDataset( - filelists = args.data.training_files, - waveform_sec=args.data.duration, - hop_size=args.data.block_size, - sample_rate=args.data.sampling_rate, - load_all_data=args.train.cache_all_data, - whole_audio=whole_audio, - extensions=args.data.extensions, - n_spk=args.model.n_spk, - spk=args.spk, - device=args.train.cache_device, - fp16=args.train.cache_fp16, - use_aug=True) - loader_train = torch.utils.data.DataLoader( - data_train , - batch_size=args.train.batch_size if not whole_audio else 1, - shuffle=True, - num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0, - persistent_workers=(args.train.num_workers > 0) 
if args.train.cache_device=='cpu' else False, - pin_memory=True if args.train.cache_device=='cpu' else False - ) - data_valid = AudioDataset( - filelists = args.data.validation_files, - waveform_sec=args.data.duration, - hop_size=args.data.block_size, - sample_rate=args.data.sampling_rate, - load_all_data=args.train.cache_all_data, - whole_audio=True, - spk=args.spk, - extensions=args.data.extensions, - n_spk=args.model.n_spk) - loader_valid = torch.utils.data.DataLoader( - data_valid, - batch_size=1, - shuffle=False, - num_workers=0, - pin_memory=True - ) - return loader_train, loader_valid - - -class AudioDataset(Dataset): - def __init__( - self, - filelists, - waveform_sec, - hop_size, - sample_rate, - spk, - load_all_data=True, - whole_audio=False, - extensions=['wav'], - n_spk=1, - device='cpu', - fp16=False, - use_aug=False, - ): - super().__init__() - - self.waveform_sec = waveform_sec - self.sample_rate = sample_rate - self.hop_size = hop_size - self.filelists = filelists - self.whole_audio = whole_audio - self.use_aug = use_aug - self.data_buffer={} - self.pitch_aug_dict = {} - # np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item() - if load_all_data: - print('Load all the data filelists:', filelists) - else: - print('Load the f0, volume data filelists:', filelists) - with open(filelists,"r") as f: - self.paths = f.read().splitlines() - for name_ext in tqdm(self.paths, total=len(self.paths)): - name = os.path.splitext(name_ext)[0] - path_audio = name_ext - duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate) - - path_f0 = name_ext + ".f0.npy" - f0,_ = np.load(path_f0,allow_pickle=True) - f0 = torch.from_numpy(np.array(f0,dtype=float)).float().unsqueeze(-1).to(device) - - path_volume = name_ext + ".vol.npy" - volume = np.load(path_volume) - volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device) - - path_augvol = name_ext + ".aug_vol.npy" - aug_vol = np.load(path_augvol) - aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device) - - if n_spk is not None and n_spk > 1: - spk_name = name_ext.split("/")[-2] - spk_id = spk[spk_name] if spk_name in spk else 0 - if spk_id < 0 or spk_id >= n_spk: - raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 0 to n_spk-1 ') - else: - spk_id = 0 - spk_id = torch.LongTensor(np.array([spk_id])).to(device) - - if load_all_data: - ''' - audio, sr = librosa.load(path_audio, sr=self.sample_rate) - if len(audio.shape) > 1: - audio = librosa.to_mono(audio) - audio = torch.from_numpy(audio).to(device) - ''' - path_mel = name_ext + ".mel.npy" - mel = np.load(path_mel) - mel = torch.from_numpy(mel).to(device) - - path_augmel = name_ext + ".aug_mel.npy" - aug_mel,keyshift = np.load(path_augmel, allow_pickle=True) - aug_mel = np.array(aug_mel,dtype=float) - aug_mel = torch.from_numpy(aug_mel).to(device) - self.pitch_aug_dict[name_ext] = keyshift - - path_units = name_ext + ".soft.pt" - units = torch.load(path_units).to(device) - units = units[0] - units = repeat_expand_2d(units,f0.size(0)).transpose(0,1) - - if fp16: - mel = mel.half() - aug_mel = aug_mel.half() - units = units.half() - - self.data_buffer[name_ext] = { - 'duration': duration, - 'mel': mel, - 'aug_mel': aug_mel, - 'units': units, - 'f0': f0, - 'volume': volume, - 'aug_vol': aug_vol, - 'spk_id': spk_id - } - else: - path_augmel = name_ext + ".aug_mel.npy" - aug_mel,keyshift = np.load(path_augmel, allow_pickle=True) - self.pitch_aug_dict[name_ext] = keyshift - 
self.data_buffer[name_ext] = { - 'duration': duration, - 'f0': f0, - 'volume': volume, - 'aug_vol': aug_vol, - 'spk_id': spk_id - } - - - def __getitem__(self, file_idx): - name_ext = self.paths[file_idx] - data_buffer = self.data_buffer[name_ext] - # check duration. if too short, then skip - if data_buffer['duration'] < (self.waveform_sec + 0.1): - return self.__getitem__( (file_idx + 1) % len(self.paths)) - - # get item - return self.get_data(name_ext, data_buffer) - - def get_data(self, name_ext, data_buffer): - name = os.path.splitext(name_ext)[0] - frame_resolution = self.hop_size / self.sample_rate - duration = data_buffer['duration'] - waveform_sec = duration if self.whole_audio else self.waveform_sec - - # load audio - idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1) - start_frame = int(idx_from / frame_resolution) - units_frame_len = int(waveform_sec / frame_resolution) - aug_flag = random.choice([True, False]) and self.use_aug - ''' - audio = data_buffer.get('audio') - if audio is None: - path_audio = os.path.join(self.path_root, 'audio', name) + '.wav' - audio, sr = librosa.load( - path_audio, - sr = self.sample_rate, - offset = start_frame * frame_resolution, - duration = waveform_sec) - if len(audio.shape) > 1: - audio = librosa.to_mono(audio) - # clip audio into N seconds - audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size] - audio = torch.from_numpy(audio).float() - else: - audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size] - ''' - # load mel - mel_key = 'aug_mel' if aug_flag else 'mel' - mel = data_buffer.get(mel_key) - if mel is None: - mel = name_ext + ".mel.npy" - mel = np.load(mel) - mel = mel[start_frame : start_frame + units_frame_len] - mel = torch.from_numpy(mel).float() - else: - mel = mel[start_frame : start_frame + units_frame_len] - - # load f0 - f0 = data_buffer.get('f0') - aug_shift = 0 - if aug_flag: - aug_shift = self.pitch_aug_dict[name_ext] - f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len] - - # load units - units = data_buffer.get('units') - if units is None: - path_units = name_ext + ".soft.pt" - units = torch.load(path_units) - units = units[0] - units = repeat_expand_2d(units,f0.size(0)).transpose(0,1) - - units = units[start_frame : start_frame + units_frame_len] - - # load volume - vol_key = 'aug_vol' if aug_flag else 'volume' - volume = data_buffer.get(vol_key) - volume_frames = volume[start_frame : start_frame + units_frame_len] - - # load spk_id - spk_id = data_buffer.get('spk_id') - - # load shift - aug_shift = torch.from_numpy(np.array([[aug_shift]])).float() - - return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext) - - def __len__(self): - return len(self.paths) \ No newline at end of file diff --git a/diffusion/diffusion.py b/diffusion/diffusion.py deleted file mode 100644 index decc1d31503e93e6611b02ced7b9c6f00b95db58..0000000000000000000000000000000000000000 --- a/diffusion/diffusion.py +++ /dev/null @@ -1,317 +0,0 @@ -from collections import deque -from functools import partial -from inspect import isfunction -import torch.nn.functional as F -import librosa.sequence -import numpy as np -import torch -from torch import nn -from tqdm import tqdm - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def extract(a, t, x_shape): - b, *_ = t.shape - 
out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() - - -def linear_beta_schedule(timesteps, max_beta=0.02): - """ - linear schedule - """ - betas = np.linspace(1e-4, max_beta, timesteps) - return betas - - -def cosine_beta_schedule(timesteps, s=0.008): - """ - cosine schedule - as proposed in https://openreview.net/forum?id=-NEXDKk8gZ - """ - steps = timesteps + 1 - x = np.linspace(0, steps, steps) - alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return np.clip(betas, a_min=0, a_max=0.999) - - -beta_schedule = { - "cosine": cosine_beta_schedule, - "linear": linear_beta_schedule, -} - - -class GaussianDiffusion(nn.Module): - def __init__(self, - denoise_fn, - out_dims=128, - timesteps=1000, - k_step=1000, - max_beta=0.02, - spec_min=-12, - spec_max=2): - super().__init__() - self.denoise_fn = denoise_fn - self.out_dims = out_dims - betas = beta_schedule['linear'](timesteps, max_beta=max_beta) - - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.k_step = k_step - - self.noise_list = deque(maxlen=4) - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) - - self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims]) - self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims]) - - def q_mean_variance(self, x_start, t): - mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract(1. 
- self.alphas_cumprod, t, x_start.shape) - log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, cond): - noise_pred = self.denoise_fn(x, t, cond=cond) - x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) - - x_recon.clamp_(-1., 1.) - - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False): - """ - Use the PLMS method from - [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). 
- """ - - def get_x_pred(x, noise_t, t): - a_t = extract(self.alphas_cumprod, t, x.shape) - a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape) - a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() - - x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / ( - a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) - x_pred = x + x_delta - - return x_pred - - noise_list = self.noise_list - noise_pred = self.denoise_fn(x, t, cond=cond) - - if len(noise_list) == 0: - x_pred = get_x_pred(x, noise_pred, t) - noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond) - noise_pred_prime = (noise_pred + noise_pred_prev) / 2 - elif len(noise_list) == 1: - noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 - elif len(noise_list) == 2: - noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 - else: - noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 - - x_prev = get_x_pred(x, noise_pred_prime, t) - noise_list.append(noise_pred) - - return x_prev - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'): - noise = default(noise, lambda: torch.randn_like(x_start)) - - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - x_recon = self.denoise_fn(x_noisy, t, cond) - - if loss_type == 'l1': - loss = (noise - x_recon).abs().mean() - elif loss_type == 'l2': - loss = F.mse_loss(noise, x_recon) - else: - raise NotImplementedError() - - return loss - - def forward(self, - condition, - gt_spec=None, - infer=True, - infer_speedup=10, - method='dpm-solver', - k_step=300, - use_tqdm=True): - """ - conditioning diffusion, use fastspeech2 encoder output as the condition - """ - cond = condition.transpose(1, 2) - b, device = condition.shape[0], condition.device - - if not infer: - spec = self.norm_spec(gt_spec) - t = torch.randint(0, self.k_step, (b,), device=device).long() - norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] - return self.p_losses(norm_spec, t, cond=cond) - else: - shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) - - if gt_spec is None: - t = self.k_step - x = torch.randn(shape, device=device) - else: - t = k_step - norm_spec = self.norm_spec(gt_spec) - norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] - x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long()) - - if method is not None and infer_speedup > 1: - if method == 'dpm-solver': - from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver - # 1. Define the noise schedule. - noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) - - # 2. Convert your discrete-time `model` to the continuous-time - # noise prediction model. Here is an example for a diffusion model - # `model` with the noise prediction type ("noise") . - def my_wrapper(fn): - def wrapped(x, t, **kwargs): - ret = fn(x, t, **kwargs) - if use_tqdm: - self.bar.update(1) - return ret - - return wrapped - - model_fn = model_wrapper( - my_wrapper(self.denoise_fn), - noise_schedule, - model_type="noise", # or "x_start" or "v" or "score" - model_kwargs={"cond": cond} - ) - - # 3. Define dpm-solver and sample by singlestep DPM-Solver. 
- # (We recommend singlestep DPM-Solver for unconditional sampling) - # You can adjust the `steps` to balance the computation - # costs and the sample quality. - dpm_solver = DPM_Solver(model_fn, noise_schedule) - - steps = t // infer_speedup - if use_tqdm: - self.bar = tqdm(desc="sample time step", total=steps) - x = dpm_solver.sample( - x, - steps=steps, - order=3, - skip_type="time_uniform", - method="singlestep", - ) - if use_tqdm: - self.bar.close() - elif method == 'pndm': - self.noise_list = deque(maxlen=4) - if use_tqdm: - for i in tqdm( - reversed(range(0, t, infer_speedup)), desc='sample time step', - total=t // infer_speedup, - ): - x = self.p_sample_plms( - x, torch.full((b,), i, device=device, dtype=torch.long), - infer_speedup, cond=cond - ) - else: - for i in reversed(range(0, t, infer_speedup)): - x = self.p_sample_plms( - x, torch.full((b,), i, device=device, dtype=torch.long), - infer_speedup, cond=cond - ) - else: - raise NotImplementedError(method) - else: - if use_tqdm: - for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): - x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) - else: - for i in reversed(range(0, t)): - x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) - x = x.squeeze(1).transpose(1, 2) # [B, T, M] - return self.denorm_spec(x) - - def norm_spec(self, x): - return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 - - def denorm_spec(self, x): - return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min diff --git a/diffusion/diffusion_onnx.py b/diffusion/diffusion_onnx.py deleted file mode 100644 index 1c1e80321de162b5233801efa3423739f7f92bdc..0000000000000000000000000000000000000000 --- a/diffusion/diffusion_onnx.py +++ /dev/null @@ -1,612 +0,0 @@ -from collections import deque -from functools import partial -from inspect import isfunction -import torch.nn.functional as F -import librosa.sequence -import numpy as np -from torch.nn import Conv1d -from torch.nn import Mish -import torch -from torch import nn -from tqdm import tqdm -import math - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def extract(a, t): - return a[t].reshape((1, 1, 1, 1)) - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() - - -def linear_beta_schedule(timesteps, max_beta=0.02): - """ - linear schedule - """ - betas = np.linspace(1e-4, max_beta, timesteps) - return betas - - -def cosine_beta_schedule(timesteps, s=0.008): - """ - cosine schedule - as proposed in https://openreview.net/forum?id=-NEXDKk8gZ - """ - steps = timesteps + 1 - x = np.linspace(0, steps, steps) - alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return np.clip(betas, a_min=0, a_max=0.999) - - -beta_schedule = { - "cosine": cosine_beta_schedule, - "linear": linear_beta_schedule, -} - - -def extract_1(a, t): - return a[t].reshape((1, 1, 1, 1)) - - -def predict_stage0(noise_pred, noise_pred_prev): - return (noise_pred + noise_pred_prev) / 2 - - -def predict_stage1(noise_pred, noise_list): - return (noise_pred * 3 - - noise_list[-1]) / 2 - - -def predict_stage2(noise_pred, 
noise_list): - return (noise_pred * 23 - - noise_list[-1] * 16 - + noise_list[-2] * 5) / 12 - - -def predict_stage3(noise_pred, noise_list): - return (noise_pred * 55 - - noise_list[-1] * 59 - + noise_list[-2] * 37 - - noise_list[-3] * 9) / 24 - - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - self.half_dim = dim // 2 - self.emb = 9.21034037 / (self.half_dim - 1) - self.emb = torch.exp(torch.arange(self.half_dim) * torch.tensor(-self.emb)).unsqueeze(0) - self.emb = self.emb.cpu() - - def forward(self, x): - emb = self.emb * x - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -class ResidualBlock(nn.Module): - def __init__(self, encoder_hidden, residual_channels, dilation): - super().__init__() - self.residual_channels = residual_channels - self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) - self.diffusion_projection = nn.Linear(residual_channels, residual_channels) - self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) - self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) - - def forward(self, x, conditioner, diffusion_step): - diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) - conditioner = self.conditioner_projection(conditioner) - y = x + diffusion_step - y = self.dilated_conv(y) + conditioner - - gate, filter_1 = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) - - y = torch.sigmoid(gate) * torch.tanh(filter_1) - y = self.output_projection(y) - - residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) - - return (x + residual) / 1.41421356, skip - - -class DiffNet(nn.Module): - def __init__(self, in_dims, n_layers, n_chans, n_hidden): - super().__init__() - self.encoder_hidden = n_hidden - self.residual_layers = n_layers - self.residual_channels = n_chans - self.input_projection = Conv1d(in_dims, self.residual_channels, 1) - self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels) - dim = self.residual_channels - self.mlp = nn.Sequential( - nn.Linear(dim, dim * 4), - Mish(), - nn.Linear(dim * 4, dim) - ) - self.residual_layers = nn.ModuleList([ - ResidualBlock(self.encoder_hidden, self.residual_channels, 1) - for i in range(self.residual_layers) - ]) - self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1) - self.output_projection = Conv1d(self.residual_channels, in_dims, 1) - nn.init.zeros_(self.output_projection.weight) - - def forward(self, spec, diffusion_step, cond): - x = spec.squeeze(0) - x = self.input_projection(x) # x [B, residual_channel, T] - x = F.relu(x) - # skip = torch.randn_like(x) - diffusion_step = diffusion_step.float() - diffusion_step = self.diffusion_embedding(diffusion_step) - diffusion_step = self.mlp(diffusion_step) - - x, skip = self.residual_layers[0](x, cond, diffusion_step) - # noinspection PyTypeChecker - for layer in self.residual_layers[1:]: - x, skip_connection = layer.forward(x, cond, diffusion_step) - skip = skip + skip_connection - x = skip / math.sqrt(len(self.residual_layers)) - x = self.skip_projection(x) - x = F.relu(x) - x = self.output_projection(x) # [B, 80, T] - return x.unsqueeze(1) - - -class AfterDiffusion(nn.Module): - def __init__(self, spec_max, spec_min, v_type='a'): - super().__init__() - self.spec_max = spec_max - self.spec_min = spec_min - self.type = v_type - - def forward(self, x): - x = x.squeeze(1).permute(0, 2, 1) - mel_out = (x + 
1) / 2 * (self.spec_max - self.spec_min) + self.spec_min - if self.type == 'nsf-hifigan-log10': - mel_out = mel_out * 0.434294 - return mel_out.transpose(2, 1) - - -class Pred(nn.Module): - def __init__(self, alphas_cumprod): - super().__init__() - self.alphas_cumprod = alphas_cumprod - - def forward(self, x_1, noise_t, t_1, t_prev): - a_t = extract(self.alphas_cumprod, t_1).cpu() - a_prev = extract(self.alphas_cumprod, t_prev).cpu() - a_t_sq, a_prev_sq = a_t.sqrt().cpu(), a_prev.sqrt().cpu() - x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / ( - a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) - x_pred = x_1 + x_delta.cpu() - - return x_pred - - -class GaussianDiffusion(nn.Module): - def __init__(self, - out_dims=128, - n_layers=20, - n_chans=384, - n_hidden=256, - timesteps=1000, - k_step=1000, - max_beta=0.02, - spec_min=-12, - spec_max=2): - super().__init__() - self.denoise_fn = DiffNet(out_dims, n_layers, n_chans, n_hidden) - self.out_dims = out_dims - self.mel_bins = out_dims - self.n_hidden = n_hidden - betas = beta_schedule['linear'](timesteps, max_beta=max_beta) - - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.k_step = k_step - - self.noise_list = deque(maxlen=4) - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) - - self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims]) - self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims]) - self.ad = AfterDiffusion(self.spec_max, self.spec_min) - self.xp = Pred(self.alphas_cumprod) - - def q_mean_variance(self, x_start, t): - mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract(1. 
- self.alphas_cumprod, t, x_start.shape) - log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, cond): - noise_pred = self.denoise_fn(x, t, cond=cond) - x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) - - x_recon.clamp_(-1., 1.) - - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False): - """ - Use the PLMS method from - [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). 
- """ - - def get_x_pred(x, noise_t, t): - a_t = extract(self.alphas_cumprod, t) - a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t))) - a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() - - x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / ( - a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) - x_pred = x + x_delta - - return x_pred - - noise_list = self.noise_list - noise_pred = self.denoise_fn(x, t, cond=cond) - - if len(noise_list) == 0: - x_pred = get_x_pred(x, noise_pred, t) - noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond) - noise_pred_prime = (noise_pred + noise_pred_prev) / 2 - elif len(noise_list) == 1: - noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 - elif len(noise_list) == 2: - noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 - else: - noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 - - x_prev = get_x_pred(x, noise_pred_prime, t) - noise_list.append(noise_pred) - - return x_prev - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'): - noise = default(noise, lambda: torch.randn_like(x_start)) - - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - x_recon = self.denoise_fn(x_noisy, t, cond) - - if loss_type == 'l1': - loss = (noise - x_recon).abs().mean() - elif loss_type == 'l2': - loss = F.mse_loss(noise, x_recon) - else: - raise NotImplementedError() - - return loss - - def org_forward(self, - condition, - init_noise=None, - gt_spec=None, - infer=True, - infer_speedup=100, - method='pndm', - k_step=1000, - use_tqdm=True): - """ - conditioning diffusion, use fastspeech2 encoder output as the condition - """ - cond = condition - b, device = condition.shape[0], condition.device - if not infer: - spec = self.norm_spec(gt_spec) - t = torch.randint(0, self.k_step, (b,), device=device).long() - norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] - return self.p_losses(norm_spec, t, cond=cond) - else: - shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) - - if gt_spec is None: - t = self.k_step - if init_noise is None: - x = torch.randn(shape, device=device) - else: - x = init_noise - else: - t = k_step - norm_spec = self.norm_spec(gt_spec) - norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] - x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long()) - - if method is not None and infer_speedup > 1: - if method == 'dpm-solver': - from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver - # 1. Define the noise schedule. - noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) - - # 2. Convert your discrete-time `model` to the continuous-time - # noise prediction model. Here is an example for a diffusion model - # `model` with the noise prediction type ("noise") . - def my_wrapper(fn): - def wrapped(x, t, **kwargs): - ret = fn(x, t, **kwargs) - if use_tqdm: - self.bar.update(1) - return ret - - return wrapped - - model_fn = model_wrapper( - my_wrapper(self.denoise_fn), - noise_schedule, - model_type="noise", # or "x_start" or "v" or "score" - model_kwargs={"cond": cond} - ) - - # 3. 
Define dpm-solver and sample by singlestep DPM-Solver. - # (We recommend singlestep DPM-Solver for unconditional sampling) - # You can adjust the `steps` to balance the computation - # costs and the sample quality. - dpm_solver = DPM_Solver(model_fn, noise_schedule) - - steps = t // infer_speedup - if use_tqdm: - self.bar = tqdm(desc="sample time step", total=steps) - x = dpm_solver.sample( - x, - steps=steps, - order=3, - skip_type="time_uniform", - method="singlestep", - ) - if use_tqdm: - self.bar.close() - elif method == 'pndm': - self.noise_list = deque(maxlen=4) - if use_tqdm: - for i in tqdm( - reversed(range(0, t, infer_speedup)), desc='sample time step', - total=t // infer_speedup, - ): - x = self.p_sample_plms( - x, torch.full((b,), i, device=device, dtype=torch.long), - infer_speedup, cond=cond - ) - else: - for i in reversed(range(0, t, infer_speedup)): - x = self.p_sample_plms( - x, torch.full((b,), i, device=device, dtype=torch.long), - infer_speedup, cond=cond - ) - else: - raise NotImplementedError(method) - else: - if use_tqdm: - for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): - x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) - else: - for i in reversed(range(0, t)): - x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) - x = x.squeeze(1).transpose(1, 2) # [B, T, M] - return self.denorm_spec(x).transpose(2, 1) - - def norm_spec(self, x): - return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 - - def denorm_spec(self, x): - return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min - - def get_x_pred(self, x_1, noise_t, t_1, t_prev): - a_t = extract(self.alphas_cumprod, t_1) - a_prev = extract(self.alphas_cumprod, t_prev) - a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() - x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / ( - a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) - x_pred = x_1 + x_delta - return x_pred - - def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True): - cond = torch.randn([1, self.n_hidden, 10]).cpu() - if init_noise is None: - x = torch.randn((1, 1, self.mel_bins, cond.shape[2]), dtype=torch.float32).cpu() - else: - x = init_noise - pndms = 100 - - org_y_x = self.org_forward(cond, init_noise=x) - - device = cond.device - n_frames = cond.shape[2] - step_range = torch.arange(0, self.k_step, pndms, dtype=torch.long, device=device).flip(0) - plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) - noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) - - ot = step_range[0] - ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) - if export_denoise: - torch.onnx.export( - self.denoise_fn, - (x.cpu(), ot_1.cpu(), cond.cpu()), - f"{project_name}_denoise.onnx", - input_names=["noise", "time", "condition"], - output_names=["noise_pred"], - dynamic_axes={ - "noise": [3], - "condition": [2] - }, - opset_version=16 - ) - - for t in step_range: - t_1 = torch.full((1,), t, device=device, dtype=torch.long) - noise_pred = self.denoise_fn(x, t_1, cond) - t_prev = t_1 - pndms - t_prev = t_prev * (t_prev > 0) - if plms_noise_stage == 0: - if export_pred: - torch.onnx.export( - self.xp, - (x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()), - f"{project_name}_pred.onnx", - input_names=["noise", "noise_pred", "time", "time_prev"], - output_names=["noise_pred_o"], - dynamic_axes={ - 
"noise": [3], - "noise_pred": [3] - }, - opset_version=16 - ) - - x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) - noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) - noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) - - elif plms_noise_stage == 1: - noise_pred_prime = predict_stage1(noise_pred, noise_list) - - elif plms_noise_stage == 2: - noise_pred_prime = predict_stage2(noise_pred, noise_list) - - else: - noise_pred_prime = predict_stage3(noise_pred, noise_list) - - noise_pred = noise_pred.unsqueeze(0) - - if plms_noise_stage < 3: - noise_list = torch.cat((noise_list, noise_pred), dim=0) - plms_noise_stage = plms_noise_stage + 1 - - else: - noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) - - x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) - if export_after: - torch.onnx.export( - self.ad, - x.cpu(), - f"{project_name}_after.onnx", - input_names=["x"], - output_names=["mel_out"], - dynamic_axes={ - "x": [3] - }, - opset_version=16 - ) - x = self.ad(x) - - print((x == org_y_x).all()) - return x - - def forward(self, condition=None, init_noise=None, pndms=None, k_step=None): - cond = condition - x = init_noise - - device = cond.device - n_frames = cond.shape[2] - step_range = torch.arange(0, k_step.item(), pndms.item(), dtype=torch.long, device=device).flip(0) - plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) - noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) - - ot = step_range[0] - ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) - - for t in step_range: - t_1 = torch.full((1,), t, device=device, dtype=torch.long) - noise_pred = self.denoise_fn(x, t_1, cond) - t_prev = t_1 - pndms - t_prev = t_prev * (t_prev > 0) - if plms_noise_stage == 0: - x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) - noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) - noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) - - elif plms_noise_stage == 1: - noise_pred_prime = predict_stage1(noise_pred, noise_list) - - elif plms_noise_stage == 2: - noise_pred_prime = predict_stage2(noise_pred, noise_list) - - else: - noise_pred_prime = predict_stage3(noise_pred, noise_list) - - noise_pred = noise_pred.unsqueeze(0) - - if plms_noise_stage < 3: - noise_list = torch.cat((noise_list, noise_pred), dim=0) - plms_noise_stage = plms_noise_stage + 1 - - else: - noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) - - x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) - x = self.ad(x) - return x diff --git a/diffusion/dpm_solver_pytorch.py b/diffusion/dpm_solver_pytorch.py deleted file mode 100644 index dee5e280661b61e0a99038ce0bd240db51344ead..0000000000000000000000000000000000000000 --- a/diffusion/dpm_solver_pytorch.py +++ /dev/null @@ -1,1201 +0,0 @@ -import math - -import torch - - -class NoiseScheduleVP: - def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., - ): - """Create a wrapper class for the forward SDE (VP type). - - *** - Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. - We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. - *** - - The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). 
- We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). - Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: - - log_alpha_t = self.marginal_log_mean_coeff(t) - sigma_t = self.marginal_std(t) - lambda_t = self.marginal_lambda(t) - - Moreover, as lambda(t) is an invertible function, we also support its inverse function: - - t = self.inverse_lambda(lambda_t) - - =============================================================== - - We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). - - 1. For discrete-time DPMs: - - For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: - t_i = (i + 1) / N - e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. - We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. - - Args: - betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) - alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) - - Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. - - **Important**: Please pay special attention for the args for `alphas_cumprod`: - The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that - q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). - Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have - alpha_{t_n} = \sqrt{\hat{alpha_n}}, - and - log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). - - - 2. For continuous-time DPMs: - - We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: - - Args: - beta_min: A `float` number. The smallest beta for the linear schedule. - beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. - T: A `float` number. The ending time of the forward process. - - =============================================================== - - Args: - schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. - Returns: - A wrapper object of the forward SDE (VP type). - - =============================================================== - - Example: - - # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', betas=betas) - - # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) - - # For continuous-time DPMs (VPSDE), linear schedule: - >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) - - """ - - if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError( - "Unsupported noise schedule {}. 
The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) - - self.schedule = schedule - if schedule == 'discrete': - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) - self.schedule = schedule - if schedule == 'cosine': - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1. - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. - """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. - """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. 
+ self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, -): - """Create a wrapper function for the noise prediction model. - - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. - - We support four types of the diffusion model by setting `model_type`: - - 1. "noise": noise prediction model. (Trained by predicting noise). - - 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). - - 3. "v": velocity prediction model. (Trained by predicting the velocity). - The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. - - [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." - arXiv preprint arXiv:2202.00512 (2022). - [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." - arXiv preprint arXiv:2210.02303 (2022). - - 4. "score": marginal score function. (Trained by denoising score matching). - Note that the score function and the noise prediction model follows a simple relationship: - ``` - noise(x_t, t) = -sigma_t * score(x_t, t) - ``` - - We support three types of guided sampling by DPMs by setting `guidance_type`: - 1. "uncond": unconditional sampling by DPMs. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - The input `classifier_fn` has the following format: - `` - classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) - `` - - [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," - in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. - - 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. - The input `model` has the following format: - `` - model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score - `` - And if cond == `unconditional_condition`, the model output is the unconditional DPM output. - - [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." - arXiv preprint arXiv:2207.12598 (2022). - - - The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) - or continuous-time labels (i.e. epsilon to T). - - We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: - `` - def model_fn(x, t_continuous) -> noise: - t_input = get_model_input_time(t_continuous) - return noise_pred(model, x, t_input, **model_kwargs) - `` - where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. - - =============================================================== - - Args: - model: A diffusion model with the corresponding format described above. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - model_type: A `str`. 
The parameterization type of the diffusion model. - "noise" or "x_start" or "v" or "score". - model_kwargs: A `dict`. A dict for the other inputs of the model function. - guidance_type: A `str`. The type of the guidance for sampling. - "uncond" or "classifier" or "classifier-free". - condition: A pytorch tensor. The condition for the guided sampling. - Only used for "classifier" or "classifier-free" guidance type. - unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. - Only used for "classifier-free" guidance type. - guidance_scale: A `float`. The scale for the guided sampling. - classifier_fn: A classifier function. Only used for the classifier guidance. - classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. - Returns: - A noise prediction model that accepts the noised data and the continuous time as the inputs. - """ - - def get_model_input_time(t_continuous): - """ - Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. - For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. - For continuous-time DPMs, we just use `t_continuous`. - """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N - else: - return t_continuous - - def noise_pred_fn(x, t_continuous, cond=None): - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - t_input = get_model_input_time(t_continuous) - if cond is None: - output = model(x, t_input, **model_kwargs) - else: - output = model(x, t_input, cond, **model_kwargs) - if model_type == "noise": - return output - elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) - elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x - elif model_type == "score": - sigma_t = noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return -expand_dims(sigma_t, dims) * output - - def cond_grad_fn(x, t_input): - """ - Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). - """ - with torch.enable_grad(): - x_in = x.detach().requires_grad_(True) - log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) - return torch.autograd.grad(log_prob.sum(), x_in)[0] - - def model_fn(x, t_continuous): - """ - The noise predicition model function that is used for DPM-Solver. - """ - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - if guidance_type == "uncond": - return noise_pred_fn(x, t_continuous) - elif guidance_type == "classifier": - assert classifier_fn is not None - t_input = get_model_input_time(t_continuous) - cond_grad = cond_grad_fn(x, t_input) - sigma_t = noise_schedule.marginal_std(t_continuous) - noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad - elif guidance_type == "classifier-free": - if guidance_scale == 1. 
or unconditional_condition is None: - return noise_pred_fn(x, t_continuous, cond=condition) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - return noise_uncond + guidance_scale * (noise - noise_uncond) - - assert model_type in ["noise", "x_start", "v"] - assert guidance_type in ["uncond", "classifier", "classifier-free"] - return model_fn - - -class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): - """Construct a DPM-Solver. - - We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). - If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). - If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). - In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. - The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. - - Args: - model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): - `` - def model_fn(x, t_continuous): - return noise - `` - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. - thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. - max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. - """ - self.model = model_fn - self.noise_schedule = noise_schedule - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.max_val = max_val - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with thresholding). - """ - noise = self.noise_prediction_fn(x, t) - dims = x.dim() - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) - if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.predict_x0: - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_time_steps(self, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - - Args: - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 
- - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - N: A `int`. The total number of the spacing of the time steps. - device: A torch device. - Returns: - A pytorch tensor of the time steps, with the shape (N + 1,). - """ - if skip_type == 'logSNR': - lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': - t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - - We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". - Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: - - If order == 1: - We take `steps` of DPM-Solver-1 (i.e. DDIM). - - If order == 2: - - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of DPM-Solver-2. - - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If order == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. - - ============================================ - Args: - order: A `int`. The max order for the solver (2 or 3). - steps: A `int`. The total number of function evaluations (NFE). - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - device: A torch device. - Returns: - orders: A list of the solver order of each step. 
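As a worked illustration of the bookkeeping described above (a hedged, pure-Python sketch; the helper name is mine, not part of the original file): with steps=20 and order=3 we get K = 20 // 3 + 1 = 7 and, since 20 % 3 == 2, orders = [3]*6 + [2], which spends exactly 20 function evaluations.

```python
def singlestep_order_schedule(steps, order):
    # Mirrors the DPM-Solver-fast rules documented above (illustrative only).
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (K - 2) + [2, 1]
        if steps % 3 == 1:
            return [3] * (K - 1) + [1]
        return [3] * (K - 1) + [2]
    if order == 2:
        return [2] * (steps // 2) + ([] if steps % 2 == 0 else [1])
    return [1] * steps

print(singlestep_order_schedule(20, 3))          # [3, 3, 3, 3, 3, 3, 2]
assert sum(singlestep_order_schedule(20, 3)) == 20
```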
- """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] - elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] - else: - orders = [3, ] * (K - 1) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [2, ] * K - else: - K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] - elif order == 1: - K = 1 - orders = [1, ] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': - # To reproduce the results in DPM-Solver paper - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) - else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)] - return timesteps_outer, orders - - def denoise_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. - """ - return self.data_prediction_fn(x, s) - - def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): - """ - DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - if self.predict_x0: - phi_1 = torch.expm1(-h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - else: - phi_1 = torch.expm1(h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-2 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the second-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. 
The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 0.5 - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) - alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_1 = torch.expm1(-h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_1 = torch.expm1(h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) - ) - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} - else: - return x_t - - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-3 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). - If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. 
We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 1. / 3. - if r2 is None: - r2 = 2. / 3. - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - lambda_s2 = lambda_s + r2 * h - s1 = ns.inverse_lambda(lambda_s1) - s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) - alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_12 = torch.expm1(-r2 * h) - phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_12 = torch.expm1(r2 * h) - phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. 
* (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 - ) - - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} - else: - return x_t - - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): - """ - Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - ns = self.noise_schedule - dims = x.dim() - model_prev_1, model_prev_0 = model_prev_list - t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - if self.predict_x0: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 - ) - else: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 - ) - return x_t - - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): - """ - Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. 
We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - model_prev_2, model_prev_1, model_prev_0 = model_prev_list - t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_1 = lambda_prev_1 - lambda_prev_2 - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) - D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) - if self.predict_x0: - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 - ) - else: - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 - ) - return x_t - - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): - """ - Singlestep DPM-Solver with the order `order` from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - r1: A `float`. The hyperparameter of the second-order or third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) - elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) - elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): - """ - Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. 
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) - elif order == 2: - return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - elif order == 3: - return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): - """ - The adaptive step size solver based on singlestep DPM-Solver. - - Args: - x: A pytorch tensor. The initial value at time `t_T`. - order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - h_init: A `float`. The initial step size (for logSNR). - atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. - rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. - theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. - t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the - current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_0: A pytorch tensor. The approximated solution at time `t_0`. - - [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. - """ - ns = self.noise_schedule - s = t_T * torch.ones((x.shape[0],)).to(x) - lambda_s = ns.marginal_lambda(s) - lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) - h = h_init * torch.ones_like(s).to(x) - x_prev = x - nfe = 0 - if order == 2: - r1 = 0.5 - lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) - elif order == 3: - r1, r2 = 1. / 3., 2. / 3. 
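# Sketch of the control flow that follows, assuming the names defined in this function:
# `lower_update` produces an estimate one order lower than `higher_update` (the two share the
# intermediate model evaluations returned via `return_intermediate`); their difference, scaled
# by the atol/rtol tolerance `delta`, gives the local error estimate E; a step is accepted when
# E <= 1, and the logSNR step size h is rescaled by theta * E**(-1/order), capped at lambda_0 - lambda_s.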
- lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) - else: - raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) - while torch.abs((s - t_0)).mean() > t_err: - t = ns.inverse_lambda(lambda_s + h) - x_lower, lower_noise_kwargs = lower_update(x, s, t) - x_higher = higher_update(x, s, t, **lower_noise_kwargs) - delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) - norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) - E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): - x = x_higher - s = t - x_prev = x_lower - lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) - nfe += order - print('adaptive solver nfe', nfe) - return x - - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078, - rtol=0.05, - ): - """ - Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. - - ===================================================== - - We support the following algorithms for both noise prediction model and data prediction model: - - 'singlestep': - Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. - We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). - The total number of function evaluations (NFE) == `steps`. - Given a fixed NFE == `steps`, the sampling procedure is: - - If `order` == 1: - - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. - - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If `order` == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. - - 'multistep': - Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. - We initialize the first `order` values by lower order multistep solvers. - Given a fixed NFE == `steps`, the sampling procedure is: - Denote K = steps. - - If `order` == 1: - - We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. - - If `order` == 3: - - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. - - 'singlestep_fixed': - Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). 
- We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. - - 'adaptive': - Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). - We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. - You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs - (NFE) and the sample quality. - - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. - - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. - - ===================================================== - - Some advices for choosing the algorithm: - - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: - Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, - skip_type='time_uniform', method='multistep') - - We support three types of `skip_type`: - - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** - - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. - - 'time_quadratic': quadratic time for the time steps. - - ===================================================== - Args: - x: A pytorch tensor. The initial value at time `t_start` - e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. - steps: A `int`. The total number of function evaluations (NFE). - t_start: A `float`. The starting time of the sampling. - If `T` is None, we use self.noise_schedule.T (default is 1.0). - t_end: A `float`. The ending time of the sampling. - If `t_end` is None, we use 1. / self.noise_schedule.total_N. - e.g. if total_N == 1000, we have `t_end` == 1e-3. - For discrete-time DPMs: - - We recommend `t_end` == 1. / self.noise_schedule.total_N. - For continuous-time DPMs: - - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. - order: A `int`. The order of DPM-Solver. - skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. - method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise: A `bool`. Whether to denoise at the final step. Default is False. - If `denoise` is True, the total NFE is (`steps` + 1). - solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. - atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - Returns: - x_end: A pytorch tensor. The approximated solution at time `t_end`. - - """ - t_0 = 1. 
/ self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - device = x.device - if method == 'adaptive': - with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - with torch.no_grad(): - vec_t = timesteps[0].expand((x.shape[0])) - model_prev_list = [self.model_fn(x, vec_t)] - t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in range(1, order): - vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) - model_prev_list.append(self.model_fn(x, vec_t)) - t_prev_list.append(vec_t) - # Compute the remaining values by `order`-th order multistep DPM-Solver. - for step in range(order, steps + 1): - vec_t = timesteps[step].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order, - solver_type=solver_type) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. - if step < steps: - model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': - K = steps // order - orders = [order, ] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for i, order in enumerate(orders): - t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - vec_s, vec_t = t_T_inner.repeat(x.shape[0]), t_0_inner.repeat(x.shape[0]) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) - if denoise: - x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) - return x - - -############################################################# -# other utility functions -############################################################# - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) - - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). - xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. 
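A quick usage sketch for the interface documented above (the numbers are made up; only the shapes follow the docstring, with x of shape [N, C] and xp, yp of shape [C, K]):

```python
import torch

xp = torch.tensor([[0.0, 1.0, 2.0]])    # keypoint x-coordinates, C = 1, K = 3
yp = torch.tensor([[0.0, 10.0, 20.0]])  # keypoint y-values
x = torch.tensor([[0.5], [3.0]])        # N = 2 queries; 3.0 lies beyond the last keypoint

y = interpolate_fn(x, xp, yp)
# y == tensor([[ 5.0],
#              [30.0]])  -- linear between keypoints, extrapolated from the outermost pair outside them
```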
- """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to the dim `dims`. - - Args: - `v`: a PyTorch tensor with shape [N]. - `dim`: a `int`. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. - """ - return v[(...,) + (None,) * (dims - 1)] diff --git a/diffusion/how to export onnx.md b/diffusion/how to export onnx.md deleted file mode 100644 index 6d22719fd1a8e9d034e6224cc95f4b50d44a0320..0000000000000000000000000000000000000000 --- a/diffusion/how to export onnx.md +++ /dev/null @@ -1,4 +0,0 @@ -- Open [onnx_export](onnx_export.py) -- project_name = "dddsp" change "project_name" to your project name -- model_path = f'{project_name}/model_500000.pt' change "model_path" to your model path -- Run \ No newline at end of file diff --git a/diffusion/infer_gt_mel.py b/diffusion/infer_gt_mel.py deleted file mode 100644 index 033b821a5d21a1232f1786bce5616b12e01488ad..0000000000000000000000000000000000000000 --- a/diffusion/infer_gt_mel.py +++ /dev/null @@ -1,74 +0,0 @@ -import numpy as np -import torch -import torch.nn.functional as F -from diffusion.unit2mel import load_model_vocoder - - -class DiffGtMel: - def __init__(self, project_path=None, device=None): - self.project_path = project_path - if device is not None: - self.device = device - else: - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.model = None - self.vocoder = None - self.args = None - - def flush_model(self, project_path, ddsp_config=None): - if (self.model is None) or (project_path != self.project_path): - model, vocoder, args = load_model_vocoder(project_path, device=self.device) - if self.check_args(ddsp_config, args): - self.model = model - self.vocoder = vocoder - self.args = args - - def check_args(self, args1, args2): - if args1.data.block_size != args2.data.block_size: - raise ValueError("DDSP与DIFF模型的block_size不一致") - if args1.data.sampling_rate != args2.data.sampling_rate: - raise ValueError("DDSP与DIFF模型的sampling_rate不一致") - if args1.data.encoder != args2.data.encoder: - raise ValueError("DDSP与DIFF模型的encoder不一致") - return True - - def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', - spk_mix_dict=None, start_frame=0): - input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate) - out_mel = self.model( - hubert, - 
f0, - volume, - spk_id=spk_id, - spk_mix_dict=spk_mix_dict, - gt_spec=input_mel, - infer=True, - infer_speedup=acc, - method=method, - k_step=k_step, - use_tqdm=False) - if start_frame > 0: - out_mel = out_mel[:, start_frame:, :] - f0 = f0[:, start_frame:, :] - output = self.vocoder.infer(out_mel, f0) - if start_frame > 0: - output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0)) - return output - - def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', silence_front=0, - use_silence=False, spk_mix_dict=None): - start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size) - if use_silence: - audio = audio[:, start_frame * self.vocoder.vocoder_hop_size:] - f0 = f0[:, start_frame:, :] - hubert = hubert[:, start_frame:, :] - volume = volume[:, start_frame:, :] - _start_frame = 0 - else: - _start_frame = start_frame - audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step, - method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame) - if use_silence: - if start_frame > 0: - audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0)) - return audio diff --git a/diffusion/logger/__init__.py b/diffusion/logger/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/diffusion/logger/saver.py b/diffusion/logger/saver.py deleted file mode 100644 index ef78b52b6bcd32106f962b731d3784d72d5f0cce..0000000000000000000000000000000000000000 --- a/diffusion/logger/saver.py +++ /dev/null @@ -1,150 +0,0 @@ -''' -author: wayn391@mastertones -''' - -import os -import json -import time -import yaml -import datetime -import torch -import matplotlib.pyplot as plt -from . 
import utils -from torch.utils.tensorboard import SummaryWriter - -class Saver(object): - def __init__( - self, - args, - initial_global_step=-1): - - self.expdir = args.env.expdir - self.sample_rate = args.data.sampling_rate - - # cold start - self.global_step = initial_global_step - self.init_time = time.time() - self.last_time = time.time() - - # makedirs - os.makedirs(self.expdir, exist_ok=True) - - # path - self.path_log_info = os.path.join(self.expdir, 'log_info.txt') - - # ckpt - os.makedirs(self.expdir, exist_ok=True) - - # writer - self.writer = SummaryWriter(os.path.join(self.expdir, 'logs')) - - # save config - path_config = os.path.join(self.expdir, 'config.yaml') - with open(path_config, "w") as out_config: - yaml.dump(dict(args), out_config) - - - def log_info(self, msg): - '''log method''' - if isinstance(msg, dict): - msg_list = [] - for k, v in msg.items(): - tmp_str = '' - if isinstance(v, int): - tmp_str = '{}: {:,}'.format(k, v) - else: - tmp_str = '{}: {}'.format(k, v) - - msg_list.append(tmp_str) - msg_str = '\n'.join(msg_list) - else: - msg_str = msg - - # dsplay - print(msg_str) - - # save - with open(self.path_log_info, 'a') as fp: - fp.write(msg_str+'\n') - - def log_value(self, dict): - for k, v in dict.items(): - self.writer.add_scalar(k, v, self.global_step) - - def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5): - spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) - spec = spec_cat[0] - if isinstance(spec, torch.Tensor): - spec = spec.cpu().numpy() - fig = plt.figure(figsize=(12, 9)) - plt.pcolor(spec.T, vmin=vmin, vmax=vmax) - plt.tight_layout() - self.writer.add_figure(name, fig, self.global_step) - - def log_audio(self, dict): - for k, v in dict.items(): - self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate) - - def get_interval_time(self, update=True): - cur_time = time.time() - time_interval = cur_time - self.last_time - if update: - self.last_time = cur_time - return time_interval - - def get_total_time(self, to_str=True): - total_time = time.time() - self.init_time - if to_str: - total_time = str(datetime.timedelta( - seconds=total_time))[:-5] - return total_time - - def save_model( - self, - model, - optimizer, - name='model', - postfix='', - to_json=False): - # path - if postfix: - postfix = '_' + postfix - path_pt = os.path.join( - self.expdir , name+postfix+'.pt') - - # check - print(' [*] model checkpoint saved: {}'.format(path_pt)) - - # save - if optimizer is not None: - torch.save({ - 'global_step': self.global_step, - 'model': model.state_dict(), - 'optimizer': optimizer.state_dict()}, path_pt) - else: - torch.save({ - 'global_step': self.global_step, - 'model': model.state_dict()}, path_pt) - - # to json - if to_json: - path_json = os.path.join( - self.expdir , name+'.json') - utils.to_json(path_params, path_json) - - def delete_model(self, name='model', postfix=''): - # path - if postfix: - postfix = '_' + postfix - path_pt = os.path.join( - self.expdir , name+postfix+'.pt') - - # delete - if os.path.exists(path_pt): - os.remove(path_pt) - print(' [*] model checkpoint deleted: {}'.format(path_pt)) - - def global_step_increment(self): - self.global_step += 1 - - diff --git a/diffusion/logger/utils.py b/diffusion/logger/utils.py deleted file mode 100644 index 485681ced897980dc0bf5b149308245bbd708de9..0000000000000000000000000000000000000000 --- a/diffusion/logger/utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import os -import yaml -import json -import pickle -import torch - -def 
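# A condensed sketch, under stated assumptions, of the checkpoint format used by the
# deleted diffusion/logger/saver.py: a checkpoint is a plain dict holding the global
# step, the model state_dict and, optionally, the optimizer state_dict. The Linear
# model and file name here are stand-ins for illustration only.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
global_step = 1000

ckpt = {
    'global_step': global_step,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),   # omitted when optimizer saving is disabled
}
torch.save(ckpt, 'model_1000.pt')

state = torch.load('model_1000.pt', map_location='cpu')
model.load_state_dict(state['model'])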
traverse_dir( - root_dir, - extensions, - amount=None, - str_include=None, - str_exclude=None, - is_pure=False, - is_sort=False, - is_ext=True): - - file_list = [] - cnt = 0 - for root, _, files in os.walk(root_dir): - for file in files: - if any([file.endswith(f".{ext}") for ext in extensions]): - # path - mix_path = os.path.join(root, file) - pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path - - # amount - if (amount is not None) and (cnt == amount): - if is_sort: - file_list.sort() - return file_list - - # check string - if (str_include is not None) and (str_include not in pure_path): - continue - if (str_exclude is not None) and (str_exclude in pure_path): - continue - - if not is_ext: - ext = pure_path.split('.')[-1] - pure_path = pure_path[:-(len(ext)+1)] - file_list.append(pure_path) - cnt += 1 - if is_sort: - file_list.sort() - return file_list - - - -class DotDict(dict): - def __getattr__(*args): - val = dict.get(*args) - return DotDict(val) if type(val) is dict else val - - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - -def get_network_paras_amount(model_dict): - info = dict() - for model_name, model in model_dict.items(): - # all_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - - info[model_name] = trainable_params - return info - - -def load_config(path_config): - with open(path_config, "r") as config: - args = yaml.safe_load(config) - args = DotDict(args) - # print(args) - return args - -def save_config(path_config,config): - config = dict(config) - with open(path_config, "w") as f: - yaml.dump(config, f) - -def to_json(path_params, path_json): - params = torch.load(path_params, map_location=torch.device('cpu')) - raw_state_dict = {} - for k, v in params.items(): - val = v.flatten().numpy().tolist() - raw_state_dict[k] = val - - with open(path_json, 'w') as outfile: - json.dump(raw_state_dict, outfile,indent= "\t") - - -def convert_tensor_to_numpy(tensor, is_squeeze=True): - if is_squeeze: - tensor = tensor.squeeze() - if tensor.requires_grad: - tensor = tensor.detach() - if tensor.is_cuda: - tensor = tensor.cpu() - return tensor.numpy() - - -def load_model( - expdir, - model, - optimizer, - name='model', - postfix='', - device='cpu'): - if postfix == '': - postfix = '_' + postfix - path = os.path.join(expdir, name+postfix) - path_pt = traverse_dir(expdir, ['pt'], is_ext=False) - global_step = 0 - if len(path_pt) > 0: - steps = [s[len(path):] for s in path_pt] - maxstep = max([int(s) if s.isdigit() else 0 for s in steps]) - if maxstep >= 0: - path_pt = path+str(maxstep)+'.pt' - else: - path_pt = path+'best.pt' - print(' [*] restoring model from', path_pt) - ckpt = torch.load(path_pt, map_location=torch.device(device)) - global_step = ckpt['global_step'] - model.load_state_dict(ckpt['model'], strict=False) - if ckpt.get('optimizer') != None: - optimizer.load_state_dict(ckpt['optimizer']) - return global_step, model, optimizer diff --git a/diffusion/onnx_export.py b/diffusion/onnx_export.py deleted file mode 100644 index 5deda785cf22b341f7d2e6399ef5fcdad6fe129e..0000000000000000000000000000000000000000 --- a/diffusion/onnx_export.py +++ /dev/null @@ -1,226 +0,0 @@ -from diffusion_onnx import GaussianDiffusion -import os -import yaml -import torch -import torch.nn as nn -import numpy as np -from wavenet import WaveNet -import torch.nn.functional as F -import diffusion - -class DotDict(dict): - def __getattr__(*args): - val = dict.get(*args) - return 
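# A minimal sketch of the DotDict + yaml pattern that the deleted logger/utils.py and
# onnx_export.py rely on: nested dict keys become attribute lookups, so a config
# loaded with yaml.safe_load can be read as args.model.n_layers. The inline config
# text is a made-up example, not the project's real config.yaml.
import yaml

class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

args = DotDict(yaml.safe_load("model:\n  n_layers: 20\n  n_chans: 384\n"))
print(args.model.n_layers)   # -> 20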
DotDict(val) if type(val) is dict else val - - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - -def load_model_vocoder( - model_path, - device='cpu'): - config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') - with open(config_file, "r") as config: - args = yaml.safe_load(config) - args = DotDict(args) - - # load model - model = Unit2Mel( - args.data.encoder_out_channels, - args.model.n_spk, - args.model.use_pitch_aug, - 128, - args.model.n_layers, - args.model.n_chans, - args.model.n_hidden) - - print(' [Loading] ' + model_path) - ckpt = torch.load(model_path, map_location=torch.device(device)) - model.to(device) - model.load_state_dict(ckpt['model']) - model.eval() - return model, args - - -class Unit2Mel(nn.Module): - def __init__( - self, - input_channel, - n_spk, - use_pitch_aug=False, - out_dims=128, - n_layers=20, - n_chans=384, - n_hidden=256): - super().__init__() - self.unit_embed = nn.Linear(input_channel, n_hidden) - self.f0_embed = nn.Linear(1, n_hidden) - self.volume_embed = nn.Linear(1, n_hidden) - if use_pitch_aug: - self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) - else: - self.aug_shift_embed = None - self.n_spk = n_spk - if n_spk is not None and n_spk > 1: - self.spk_embed = nn.Embedding(n_spk, n_hidden) - - # diffusion - self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden) - self.hidden_size = n_hidden - self.speaker_map = torch.zeros((self.n_spk,1,1,n_hidden)) - - - - def forward(self, units, mel2ph, f0, volume, g = None): - - ''' - input: - B x n_frames x n_unit - return: - dict of B x n_frames x feat - ''' - - decoder_inp = F.pad(units, [0, 0, 1, 0]) - mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, units.shape[-1]]) - units = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H] - - x = self.unit_embed(units) + self.f0_embed((1 + f0.unsqueeze(-1) / 700).log()) + self.volume_embed(volume.unsqueeze(-1)) - - if self.n_spk is not None and self.n_spk > 1: # [N, S] * [S, B, 1, H] - g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] - g = g * self.speaker_map # [N, S, B, 1, H] - g = torch.sum(g, dim=1) # [N, 1, B, 1, H] - g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] - x = x.transpose(1, 2) + g - return x - else: - return x.transpose(1, 2) - - - def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, - gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): - - ''' - input: - B x n_frames x n_unit - return: - dict of B x n_frames x feat - ''' - x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) - if self.n_spk is not None and self.n_spk > 1: - if spk_mix_dict is not None: - spk_embed_mix = torch.zeros((1,1,self.hidden_size)) - for k, v in spk_mix_dict.items(): - spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) - spk_embeddd = self.spk_embed(spk_id_torch) - self.speaker_map[k] = spk_embeddd - spk_embed_mix = spk_embed_mix + v * spk_embeddd - x = x + spk_embed_mix - else: - x = x + self.spk_embed(spk_id - 1) - self.speaker_map = self.speaker_map.unsqueeze(0) - self.speaker_map = self.speaker_map.detach() - return x.transpose(1, 2) - - def OnnxExport(self, project_name=None, init_noise=None, export_encoder=True, export_denoise=True, export_pred=True, export_after=True): - hubert_hidden_size = 768 - n_frames = 100 - hubert = torch.randn((1, n_frames, hubert_hidden_size)) - mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() - f0 = torch.randn((1, 
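# A shape-level sketch of the speaker-mix step in the ONNX-export Unit2Mel.forward
# above: per-frame mixing weights [N, S] are broadcast against a speaker embedding
# table initialised as [S, 1, 1, H] and summed over the speaker axis. Sizes are
# illustrative assumptions.
import torch

n_frames, n_spk, n_hidden = 100, 3, 256
spk_mix = torch.full((n_frames, n_spk), 1.0 / n_spk)   # [N, S]
speaker_map = torch.randn(n_spk, 1, 1, n_hidden)       # [S, 1, 1, H]

g = spk_mix.reshape(n_frames, n_spk, 1, 1, 1)          # [N, S, 1, 1, 1]
g = (g * speaker_map).sum(dim=1)                       # [N, 1, 1, H]
g = g.transpose(0, -1).transpose(0, -2).squeeze(0)     # [1, H, N] == [B, H, N]
print(g.shape)                                         # torch.Size([1, 256, 100])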
n_frames)) - volume = torch.randn((1, n_frames)) - spk_mix = [] - spks = {} - if self.n_spk is not None and self.n_spk > 1: - for i in range(self.n_spk): - spk_mix.append(1.0/float(self.n_spk)) - spks.update({i:1.0/float(self.n_spk)}) - spk_mix = torch.tensor(spk_mix) - spk_mix = spk_mix.repeat(n_frames, 1) - orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) - outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix) - if export_encoder: - torch.onnx.export( - self, - (hubert, mel2ph, f0, volume, spk_mix), - f"{project_name}_encoder.onnx", - input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], - output_names=["mel_pred"], - dynamic_axes={ - "hubert": [1], - "f0": [1], - "volume": [1], - "mel2ph": [1], - "spk_mix": [0], - }, - opset_version=16 - ) - - self.decoder.OnnxExport(project_name, init_noise=init_noise, export_denoise=export_denoise, export_pred=export_pred, export_after=export_after) - - def ExportOnnx(self, project_name=None): - hubert_hidden_size = 768 - n_frames = 100 - hubert = torch.randn((1, n_frames, hubert_hidden_size)) - mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() - f0 = torch.randn((1, n_frames)) - volume = torch.randn((1, n_frames)) - spk_mix = [] - spks = {} - if self.n_spk is not None and self.n_spk > 1: - for i in range(self.n_spk): - spk_mix.append(1.0/float(self.n_spk)) - spks.update({i:1.0/float(self.n_spk)}) - spk_mix = torch.tensor(spk_mix) - orgouttt = self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) - outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix) - - torch.onnx.export( - self, - (hubert, mel2ph, f0, volume, spk_mix), - f"{project_name}_encoder.onnx", - input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], - output_names=["mel_pred"], - dynamic_axes={ - "hubert": [1], - "f0": [1], - "volume": [1], - "mel2ph": [1] - }, - opset_version=16 - ) - - condition = torch.randn(1,self.decoder.n_hidden,n_frames) - noise = torch.randn((1, 1, self.decoder.mel_bins, condition.shape[2]), dtype=torch.float32) - pndm_speedup = torch.LongTensor([100]) - K_steps = torch.LongTensor([1000]) - self.decoder = torch.jit.script(self.decoder) - self.decoder(condition, noise, pndm_speedup, K_steps) - - torch.onnx.export( - self.decoder, - (condition, noise, pndm_speedup, K_steps), - f"{project_name}_diffusion.onnx", - input_names=["condition", "noise", "pndm_speedup", "K_steps"], - output_names=["mel"], - dynamic_axes={ - "condition": [2], - "noise": [3], - }, - opset_version=16 - ) - - -if __name__ == "__main__": - project_name = "dddsp" - model_path = f'{project_name}/model_500000.pt' - - model, _ = load_model_vocoder(model_path) - - # 分开Diffusion导出(需要使用MoeSS/MoeVoiceStudio或者自己编写Pndm/Dpm采样) - model.OnnxExport(project_name, export_encoder=True, export_denoise=True, export_pred=True, export_after=True) - - # 合并Diffusion导出(Encoder和Diffusion分开,直接将Encoder的结果和初始噪声输入Diffusion即可) - # model.ExportOnnx(project_name) - diff --git a/diffusion/solver.py b/diffusion/solver.py deleted file mode 100644 index aaf0b21591b42fa903424f8d44fef88d7d791e57..0000000000000000000000000000000000000000 --- a/diffusion/solver.py +++ /dev/null @@ -1,195 +0,0 @@ -import os -import time -import numpy as np -import torch -import librosa -from diffusion.logger.saver import Saver -from diffusion.logger import utils -from torch import autocast -from torch.cuda.amp import GradScaler - -def test(args, model, vocoder, loader_test, saver): - print(' [*] testing...') - model.eval() - - # losses - test_loss = 
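# A minimal sketch of the dynamic-axes export pattern used in OnnxExport above: the
# frame axes are marked dynamic so the exported graph accepts arbitrary lengths.
# The tiny module and output file name are assumptions for illustration only.
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, hidden=768, out=128):
        super().__init__()
        self.proj = nn.Linear(hidden, out)
    def forward(self, hubert):                     # [B, n_frames, hidden]
        return self.proj(hubert).transpose(1, 2)   # [B, out, n_frames]

model = TinyEncoder().eval()
dummy = torch.randn(1, 100, 768)
torch.onnx.export(
    model, (dummy,), "tiny_encoder.onnx",
    input_names=["hubert"], output_names=["mel_pred"],
    dynamic_axes={"hubert": [1], "mel_pred": [2]},   # frame axes are dynamic
    opset_version=16,
)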
0. - - # intialization - num_batches = len(loader_test) - rtf_all = [] - - # run - with torch.no_grad(): - for bidx, data in enumerate(loader_test): - fn = data['name'][0].split("/")[-1] - speaker = data['name'][0].split("/")[-2] - print('--------') - print('{}/{} - {}'.format(bidx, num_batches, fn)) - - # unpack data - for k in data.keys(): - if not k.startswith('name'): - data[k] = data[k].to(args.device) - print('>>', data['name'][0]) - - # forward - st_time = time.time() - mel = model( - data['units'], - data['f0'], - data['volume'], - data['spk_id'], - gt_spec=None, - infer=True, - infer_speedup=args.infer.speedup, - method=args.infer.method) - signal = vocoder.infer(mel, data['f0']) - ed_time = time.time() - - # RTF - run_time = ed_time - st_time - song_time = signal.shape[-1] / args.data.sampling_rate - rtf = run_time / song_time - print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) - rtf_all.append(rtf) - - # loss - for i in range(args.train.batch_size): - loss = model( - data['units'], - data['f0'], - data['volume'], - data['spk_id'], - gt_spec=data['mel'], - infer=False) - test_loss += loss.item() - - # log mel - saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel) - - # log audi - path_audio = data['name_ext'][0] - audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate) - if len(audio.shape) > 1: - audio = librosa.to_mono(audio) - audio = torch.from_numpy(audio).unsqueeze(0).to(signal) - saver.log_audio({f"{speaker}_{fn}_gt.wav": audio,f"{speaker}_{fn}_pred.wav": signal}) - # report - test_loss /= args.train.batch_size - test_loss /= num_batches - - # check - print(' [test_loss] test_loss:', test_loss) - print(' Real Time Factor', np.mean(rtf_all)) - return test_loss - - -def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test): - # saver - saver = Saver(args, initial_global_step=initial_global_step) - - # model size - params_count = utils.get_network_paras_amount({'model': model}) - saver.log_info('--- model size ---') - saver.log_info(params_count) - - # run - num_batches = len(loader_train) - model.train() - saver.log_info('======= start training =======') - scaler = GradScaler() - if args.train.amp_dtype == 'fp32': - dtype = torch.float32 - elif args.train.amp_dtype == 'fp16': - dtype = torch.float16 - elif args.train.amp_dtype == 'bf16': - dtype = torch.bfloat16 - else: - raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype) - saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step") - for epoch in range(args.train.epochs): - for batch_idx, data in enumerate(loader_train): - saver.global_step_increment() - optimizer.zero_grad() - - # unpack data - for k in data.keys(): - if not k.startswith('name'): - data[k] = data[k].to(args.device) - - # forward - if dtype == torch.float32: - loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], - aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False) - else: - with autocast(device_type=args.device, dtype=dtype): - loss = model(data['units'], data['f0'], data['volume'], data['spk_id'], - aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False) - - # handle nan loss - if torch.isnan(loss): - raise ValueError(' [x] nan loss ') - else: - # backpropagate - if dtype == torch.float32: - loss.backward() - optimizer.step() - else: - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - scheduler.step() - - # log loss - if saver.global_step % args.train.interval_log == 0: - 
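# A self-contained sketch of the mixed-precision update in the solver.py train loop
# above: autocast wraps the forward pass and GradScaler handles backward/step, with a
# plain fp32 path mirroring the amp_dtype == 'fp32' branch. Model and data are
# placeholder assumptions.
import torch
from torch import autocast
from torch.cuda.amp import GradScaler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=(device == 'cuda'))
x, y = torch.randn(8, 16, device=device), torch.randn(8, 1, device=device)

optimizer.zero_grad()
if device == 'cuda':
    with autocast(device_type=device, dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:   # fp32 fallback
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()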
current_lr = optimizer.param_groups[0]['lr'] - saver.log_info( - 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( - epoch, - batch_idx, - num_batches, - args.env.expdir, - args.train.interval_log/saver.get_interval_time(), - current_lr, - loss.item(), - saver.get_total_time(), - saver.global_step - ) - ) - - saver.log_value({ - 'train/loss': loss.item() - }) - - saver.log_value({ - 'train/lr': current_lr - }) - - # validation - if saver.global_step % args.train.interval_val == 0: - optimizer_save = optimizer if args.train.save_opt else None - - # save latest - saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') - last_val_step = saver.global_step - args.train.interval_val - if last_val_step % args.train.interval_force_save != 0: - saver.delete_model(postfix=f'{last_val_step}') - - # run testing set - test_loss = test(args, model, vocoder, loader_test, saver) - - # log loss - saver.log_info( - ' --- --- \nloss: {:.3f}. '.format( - test_loss, - ) - ) - - saver.log_value({ - 'validation/loss': test_loss - }) - - model.train() - - diff --git a/diffusion/unit2mel.py b/diffusion/unit2mel.py deleted file mode 100644 index 52293b13da8e1afeef6fa5586aeaf01cbcc27fb7..0000000000000000000000000000000000000000 --- a/diffusion/unit2mel.py +++ /dev/null @@ -1,147 +0,0 @@ -import os -import yaml -import torch -import torch.nn as nn -import numpy as np -from .diffusion import GaussianDiffusion -from .wavenet import WaveNet -from .vocoder import Vocoder - -class DotDict(dict): - def __getattr__(*args): - val = dict.get(*args) - return DotDict(val) if type(val) is dict else val - - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - -def load_model_vocoder( - model_path, - device='cpu', - config_path = None - ): - if config_path is None: config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') - else: config_file = config_path - - with open(config_file, "r") as config: - args = yaml.safe_load(config) - args = DotDict(args) - - # load vocoder - vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device) - - # load model - model = Unit2Mel( - args.data.encoder_out_channels, - args.model.n_spk, - args.model.use_pitch_aug, - vocoder.dimension, - args.model.n_layers, - args.model.n_chans, - args.model.n_hidden) - - print(' [Loading] ' + model_path) - ckpt = torch.load(model_path, map_location=torch.device(device)) - model.to(device) - model.load_state_dict(ckpt['model']) - model.eval() - return model, vocoder, args - - -class Unit2Mel(nn.Module): - def __init__( - self, - input_channel, - n_spk, - use_pitch_aug=False, - out_dims=128, - n_layers=20, - n_chans=384, - n_hidden=256): - super().__init__() - self.unit_embed = nn.Linear(input_channel, n_hidden) - self.f0_embed = nn.Linear(1, n_hidden) - self.volume_embed = nn.Linear(1, n_hidden) - if use_pitch_aug: - self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) - else: - self.aug_shift_embed = None - self.n_spk = n_spk - if n_spk is not None and n_spk > 1: - self.spk_embed = nn.Embedding(n_spk, n_hidden) - - self.n_hidden = n_hidden - # diffusion - self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims) - self.input_channel = input_channel - - def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, - gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): - - ''' - input: - B x n_frames x n_unit - return: - dict of B x 
n_frames x feat - ''' - x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) - if self.n_spk is not None and self.n_spk > 1: - if spk_mix_dict is not None: - spk_embed_mix = torch.zeros((1,1,self.hidden_size)) - for k, v in spk_mix_dict.items(): - spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) - spk_embeddd = self.spk_embed(spk_id_torch) - self.speaker_map[k] = spk_embeddd - spk_embed_mix = spk_embed_mix + v * spk_embeddd - x = x + spk_embed_mix - else: - x = x + self.spk_embed(spk_id - 1) - self.speaker_map = self.speaker_map.unsqueeze(0) - self.speaker_map = self.speaker_map.detach() - return x.transpose(1, 2) - - def init_spkmix(self, n_spk): - self.speaker_map = torch.zeros((n_spk,1,1,self.n_hidden)) - hubert_hidden_size = self.input_channel - n_frames = 10 - hubert = torch.randn((1, n_frames, hubert_hidden_size)) - mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() - f0 = torch.randn((1, n_frames)) - volume = torch.randn((1, n_frames)) - spks = {} - for i in range(n_spk): - spks.update({i:1.0/float(self.n_spk)}) - orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) - - def forward(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, - gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): - - ''' - input: - B x n_frames x n_unit - return: - dict of B x n_frames x feat - ''' - - x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) - if self.n_spk is not None and self.n_spk > 1: - if spk_mix_dict is not None: - for k, v in spk_mix_dict.items(): - spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) - x = x + v * self.spk_embed(spk_id_torch) - else: - if spk_id.shape[1] > 1: - g = spk_id.reshape((spk_id.shape[0], spk_id.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] - g = g * self.speaker_map # [N, S, B, 1, H] - g = torch.sum(g, dim=1) # [N, 1, B, 1, H] - g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] - x = x + g - else: - x = x + self.spk_embed(spk_id) - if self.aug_shift_embed is not None and aug_shift is not None: - x = x + self.aug_shift_embed(aug_shift / 5) - x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm) - - return x - diff --git a/diffusion/vocoder.py b/diffusion/vocoder.py deleted file mode 100644 index bbaa47f64fd5a3191a24dfaa054c423fa86e5bae..0000000000000000000000000000000000000000 --- a/diffusion/vocoder.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from vdecoder.nsf_hifigan.nvSTFT import STFT -from vdecoder.nsf_hifigan.models import load_model,load_config -from torchaudio.transforms import Resample - - -class Vocoder: - def __init__(self, vocoder_type, vocoder_ckpt, device = None): - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = device - - if vocoder_type == 'nsf-hifigan': - self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device) - elif vocoder_type == 'nsf-hifigan-log10': - self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device) - else: - raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") - - self.resample_kernel = {} - self.vocoder_sample_rate = self.vocoder.sample_rate() - self.vocoder_hop_size = self.vocoder.hop_size() - self.dimension = self.vocoder.dimension() - - def extract(self, audio, sample_rate, keyshift=0): - - # resample - if sample_rate == self.vocoder_sample_rate: - audio_res = audio - 
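# A shape-only sketch of how Unit2Mel fuses its conditioning streams, as in the
# forward/init_spkembed code above: linear projections of the content units, a
# log-compressed f0 and the volume envelope are simply summed into one hidden
# sequence. Sizes and the random inputs are illustrative assumptions.
import torch
import torch.nn as nn

n_frames, n_units, n_hidden = 100, 768, 256
unit_embed = nn.Linear(n_units, n_hidden)
f0_embed = nn.Linear(1, n_hidden)
volume_embed = nn.Linear(1, n_hidden)

units = torch.randn(1, n_frames, n_units)
f0 = torch.rand(1, n_frames, 1) * 400 + 50    # Hz, kept positive for the log
volume = torch.rand(1, n_frames, 1)

x = unit_embed(units) + f0_embed((1 + f0 / 700).log()) + volume_embed(volume)
print(x.shape)   # torch.Size([1, 100, 256])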
else: - key_str = str(sample_rate) - if key_str not in self.resample_kernel: - self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device) - audio_res = self.resample_kernel[key_str](audio) - - # extract - mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins - return mel - - def infer(self, mel, f0): - f0 = f0[:,:mel.size(1),0] # B, n_frames - audio = self.vocoder(mel, f0) - return audio - - -class NsfHifiGAN(torch.nn.Module): - def __init__(self, model_path, device=None): - super().__init__() - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = device - self.model_path = model_path - self.model = None - self.h = load_config(model_path) - self.stft = STFT( - self.h.sampling_rate, - self.h.num_mels, - self.h.n_fft, - self.h.win_size, - self.h.hop_size, - self.h.fmin, - self.h.fmax) - - def sample_rate(self): - return self.h.sampling_rate - - def hop_size(self): - return self.h.hop_size - - def dimension(self): - return self.h.num_mels - - def extract(self, audio, keyshift=0): - mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins - return mel - - def forward(self, mel, f0): - if self.model is None: - print('| Load HifiGAN: ', self.model_path) - self.model, self.h = load_model(self.model_path, device=self.device) - with torch.no_grad(): - c = mel.transpose(1, 2) - audio = self.model(c, f0) - return audio - -class NsfHifiGANLog10(NsfHifiGAN): - def forward(self, mel, f0): - if self.model is None: - print('| Load HifiGAN: ', self.model_path) - self.model, self.h = load_model(self.model_path, device=self.device) - with torch.no_grad(): - c = 0.434294 * mel.transpose(1, 2) - audio = self.model(c, f0) - return audio \ No newline at end of file diff --git a/diffusion/wavenet.py b/diffusion/wavenet.py deleted file mode 100644 index 3d48c7eaaa0e8191b27a5d1890eb657cbcc0d143..0000000000000000000000000000000000000000 --- a/diffusion/wavenet.py +++ /dev/null @@ -1,108 +0,0 @@ -import math -from math import sqrt - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Mish - - -class Conv1d(torch.nn.Conv1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - nn.init.kaiming_normal_(self.weight) - - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -class ResidualBlock(nn.Module): - def __init__(self, encoder_hidden, residual_channels, dilation): - super().__init__() - self.residual_channels = residual_channels - self.dilated_conv = nn.Conv1d( - residual_channels, - 2 * residual_channels, - kernel_size=3, - padding=dilation, - dilation=dilation - ) - self.diffusion_projection = nn.Linear(residual_channels, residual_channels) - self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1) - self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) - - def forward(self, x, conditioner, diffusion_step): - diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) - conditioner = self.conditioner_projection(conditioner) - y = x + diffusion_step - - y = self.dilated_conv(y) + conditioner - - # Using 
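# A one-liner sketch of the 0.434294 factor in NsfHifiGANLog10 above: the incoming
# mel is presumably in natural log, and log10(x) = ln(x) / ln(10) ≈ 0.434294 * ln(x),
# so the scale converts it for a log10-trained vocoder.
import math
import torch

vals = torch.tensor([1.0, 10.0, 100.0])
print(torch.allclose(0.434294 * torch.log(vals), torch.log10(vals), atol=1e-4))  # True
print(1.0 / math.log(10))   # 0.43429448...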
torch.split instead of torch.chunk to avoid using onnx::Slice - gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) - y = torch.sigmoid(gate) * torch.tanh(filter) - - y = self.output_projection(y) - - # Using torch.split instead of torch.chunk to avoid using onnx::Slice - residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) - return (x + residual) / math.sqrt(2.0), skip - - -class WaveNet(nn.Module): - def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256): - super().__init__() - self.input_projection = Conv1d(in_dims, n_chans, 1) - self.diffusion_embedding = SinusoidalPosEmb(n_chans) - self.mlp = nn.Sequential( - nn.Linear(n_chans, n_chans * 4), - Mish(), - nn.Linear(n_chans * 4, n_chans) - ) - self.residual_layers = nn.ModuleList([ - ResidualBlock( - encoder_hidden=n_hidden, - residual_channels=n_chans, - dilation=1 - ) - for i in range(n_layers) - ]) - self.skip_projection = Conv1d(n_chans, n_chans, 1) - self.output_projection = Conv1d(n_chans, in_dims, 1) - nn.init.zeros_(self.output_projection.weight) - - def forward(self, spec, diffusion_step, cond): - """ - :param spec: [B, 1, M, T] - :param diffusion_step: [B, 1] - :param cond: [B, M, T] - :return: - """ - x = spec.squeeze(1) - x = self.input_projection(x) # [B, residual_channel, T] - - x = F.relu(x) - diffusion_step = self.diffusion_embedding(diffusion_step) - diffusion_step = self.mlp(diffusion_step) - skip = [] - for layer in self.residual_layers: - x, skip_connection = layer(x, cond, diffusion_step) - skip.append(skip_connection) - - x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) - x = self.skip_projection(x) - x = F.relu(x) - x = self.output_projection(x) # [B, mel_bins, T] - return x[:, None, :, :] diff --git a/inference/infer_tool.py b/inference/infer_tool.py index efa4949e7e86b659b7efd3d6cc952f3db96bdfd5..8e47d5efc5e2f33e99602b4fd1abcc8cbcab5128 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -1,27 +1,24 @@ +import gc import hashlib import io import json import logging import os +import pickle import time from pathlib import Path -from inference import slicer -import gc import librosa import numpy as np + # import onnxruntime import soundfile import torch import torchaudio -import cluster import utils +from inference import slicer from models import SynthesizerTrn -import pickle - -from diffusion.unit2mel import load_model_vocoder -import yaml logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -131,6 +128,7 @@ class Svc(object): spk_mix_enable=False, feature_retrieval=False ): + self.hubert_model = None self.net_g_path = net_g_path self.only_diffusion = only_diffusion self.shallow_diffusion = shallow_diffusion @@ -141,53 +139,26 @@ class Svc(object): self.dev = torch.device(device) self.net_g_ms = None if not self.only_diffusion: - self.hps_ms = utils.get_hparams_from_file(config_path) + self.hps_ms = utils.get_hparams_from_file(config_path, True) self.target_sample = self.hps_ms.data.sampling_rate self.hop_size = self.hps_ms.data.hop_length self.spk2id = self.hps_ms.spk - try: - self.vol_embedding = self.hps_ms.model.vol_embedding - except Exception as e: - self.vol_embedding = False - try: - self.speech_encoder = self.hps_ms.model.speech_encoder - except Exception as e: - self.speech_encoder = 'vec768l12' + self.unit_interpolate_mode = self.hps_ms.data.unit_interpolate_mode if self.hps_ms.data.unit_interpolate_mode is not None else 'left' + self.vol_embedding = 
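# A minimal sketch of the gated activation in the WaveNet ResidualBlock above: the
# 2*C-channel conv output is split (torch.split rather than torch.chunk, keeping the
# ONNX graph free of onnx::Slice) into a gate half and a filter half.
import torch

residual_channels, T = 4, 10
y = torch.randn(1, 2 * residual_channels, T)
gate, filt = torch.split(y, [residual_channels, residual_channels], dim=1)
out = torch.sigmoid(gate) * torch.tanh(filt)
print(out.shape)   # torch.Size([1, 4, 10])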
self.hps_ms.model.vol_embedding if self.hps_ms.model.vol_embedding is not None else False + self.speech_encoder = self.hps_ms.model.speech_encoder if self.hps_ms.model.speech_encoder is not None else 'vec768l12' self.nsf_hifigan_enhance = nsf_hifigan_enhance - if self.shallow_diffusion or self.only_diffusion: - if os.path.exists(diffusion_model_path) and os.path.exists(diffusion_model_path): - self.diffusion_model, self.vocoder, self.diffusion_args = load_model_vocoder(diffusion_model_path, - self.dev, - config_path=diffusion_config_path) - if self.only_diffusion: - self.target_sample = self.diffusion_args.data.sampling_rate - self.hop_size = self.diffusion_args.data.block_size - self.spk2id = self.diffusion_args.spk - self.speech_encoder = self.diffusion_args.data.encoder - if spk_mix_enable: - self.diffusion_model.init_spkmix(len(self.spk2id)) - else: - print("No diffusion model or config found. Shallow diffusion mode will False") - self.shallow_diffusion = self.only_diffusion = False # load hubert and model self.load_model(spk_mix_enable) # self.hubert_model = utils.get_speech_encoder(self.speech_encoder, device=self.dev) self.volume_extractor = utils.Volume_Extractor(self.hop_size) - if os.path.exists(cluster_model_path): - if self.feature_retrieval: - with open(cluster_model_path, "rb") as f: - self.cluster_model = pickle.load(f) - self.big_npy = None - self.now_spk_id = -1 - else: - self.cluster_model = cluster.get_cluster_model(cluster_model_path) - else: - self.feature_retrieval = False - if self.shallow_diffusion: self.nsf_hifigan_enhance = False + self.feature_retrieval = False + + if self.shallow_diffusion: + self.nsf_hifigan_enhance = False if self.nsf_hifigan_enhance: from modules.enhancer import Enhancer self.enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model', device=self.dev) @@ -199,6 +170,7 @@ class Svc(object): self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, **self.hps_ms.model) _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None) + self.dtype = list(self.net_g_ms.parameters())[0].dtype if "half" in self.net_g_path and torch.cuda.is_available(): _ = self.net_g_ms.half().eval().to(self.dev) else: @@ -208,11 +180,13 @@ class Svc(object): def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter, f0_predictor, cr_threshold=0.05): - f0_predictor_object = utils.get_f0_predictor(f0_predictor, hop_length=self.hop_size, - sampling_rate=self.target_sample, device=self.dev, - threshold=cr_threshold) + if not hasattr(self, + "f0_predictor_object") or self.f0_predictor_object is None or f0_predictor != self.f0_predictor_object.name: + self.f0_predictor_object = utils.get_f0_predictor(f0_predictor, hop_length=self.hop_size, + sampling_rate=self.target_sample, device=self.dev, + threshold=cr_threshold) + f0, uv = self.f0_predictor_object.compute_f0_uv(wav) - f0, uv = f0_predictor_object.compute_f0_uv(wav) if f0_filter and sum(f0) == 0: raise F0FilterException("No voice detected") f0 = torch.FloatTensor(f0).to(self.dev) @@ -222,36 +196,13 @@ class Svc(object): f0 = f0.unsqueeze(0) uv = uv.unsqueeze(0) - wav16k = librosa.resample(wav, orig_sr=self.target_sample, target_sr=16000) - wav16k = torch.from_numpy(wav16k).to(self.dev) + wav = torch.from_numpy(wav).to(self.dev) + if not hasattr(self, "audio16k_resample_transform"): + self.audio16k_resample_transform = torchaudio.transforms.Resample(self.target_sample, 16000).to(self.dev) + wav16k = self.audio16k_resample_transform(wav[None, :])[0] + c = self.hubert_model.encoder(wav16k) - c = 
utils.repeat_expand_2d(c.squeeze(0), f0.shape[1]) - - if cluster_infer_ratio != 0: - if self.feature_retrieval: - speaker_id = self.spk2id.get(speaker) - if speaker_id is None: - raise RuntimeError("The name you entered is not in the speaker list!") - if not speaker_id and type(speaker) is int: - if len(self.spk2id.__dict__) >= speaker: - speaker_id = speaker - feature_index = self.cluster_model[speaker_id] - feat_np = c.transpose(0, 1).cpu().numpy() - if self.big_npy is None or self.now_spk_id != speaker_id: - self.big_npy = feature_index.reconstruct_n(0, feature_index.ntotal) - self.now_spk_id = speaker_id - print("starting feature retrieval...") - score, ix = feature_index.search(feat_np, k=8) - weight = np.square(1 / score) - weight /= weight.sum(axis=1, keepdims=True) - npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) - c = cluster_infer_ratio * npy + (1 - cluster_infer_ratio) * feat_np - c = torch.FloatTensor(c).to(self.dev).transpose(0, 1) - print("end feature retrieval...") - else: - cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, speaker).T - cluster_c = torch.FloatTensor(cluster_c).to(self.dev) - c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c + c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1], self.unit_interpolate_mode) c = c.unsqueeze(0) return c, f0, uv @@ -270,7 +221,11 @@ class Svc(object): second_encoding=False, loudness_envelope_adjustment=1 ): - wav, sr = librosa.load(raw_path, sr=self.target_sample) + torchaudio.set_audio_backend("soundfile") + wav, sr = torchaudio.load(raw_path) + if not hasattr(self, "audio_resample_transform") or self.audio16k_resample_transform.orig_freq != sr: + self.audio_resample_transform = torchaudio.transforms.Resample(sr, self.target_sample) + wav = self.audio_resample_transform(wav).numpy()[0] if spk_mix: c, f0, uv = self.get_unit_f0(wav, tran, 0, None, f0_filter, f0_predictor, cr_threshold=cr_threshold) n_frames = f0.size(1) @@ -286,8 +241,9 @@ class Svc(object): c, f0, uv = self.get_unit_f0(wav, tran, cluster_infer_ratio, speaker, f0_filter, f0_predictor, cr_threshold=cr_threshold) n_frames = f0.size(1) - if "half" in self.net_g_path and torch.cuda.is_available(): - c = c.half() + c = c.to(self.dtype) + f0 = f0.to(self.dtype) + uv = uv.to(self.dtype) with torch.no_grad(): start = time.time() vol = None @@ -301,17 +257,22 @@ class Svc(object): else: audio = torch.FloatTensor(wav).to(self.dev) audio_mel = None + if self.dtype != torch.float32: + c = c.to(torch.float32) + f0 = f0.to(torch.float32) + uv = uv.to(torch.float32) if self.only_diffusion or self.shallow_diffusion: - vol = self.volume_extractor.extract(audio[None, :])[None, :, None].to(self.dev) if vol == None else vol[ + vol = self.volume_extractor.extract(audio[None, :])[None, :, None].to(self.dev) if vol is None else vol[ :, :, None] if self.shallow_diffusion and second_encoding: - audio16k = librosa.resample(audio.detach().cpu().numpy(), orig_sr=self.target_sample, - target_sr=16000) - audio16k = torch.from_numpy(audio16k).to(self.dev) + if not hasattr(self, "audio16k_resample_transform"): + self.audio16k_resample_transform = torchaudio.transforms.Resample(self.target_sample, 16000).to( + self.dev) + audio16k = self.audio16k_resample_transform(audio[None, :])[0] c = self.hubert_model.encoder(audio16k) - c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1]) + c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1], self.unit_interpolate_mode) f0 = f0[:, :, None] c = c.transpose(-1, -2) audio_mel = 
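# A numpy-only sketch of the removed feature-retrieval blend: the k nearest index
# vectors are averaged with inverse-squared-distance weights and then mixed back into
# the original features by cluster_infer_ratio. Random arrays stand in for the faiss
# index search results; only the weighting arithmetic is taken from the code above.
import numpy as np

rng = np.random.default_rng(0)
big_npy = rng.normal(size=(1000, 256))         # reconstructed index vectors
feat = rng.normal(size=(100, 256))             # [n_frames, dim] content features
score = rng.uniform(0.1, 2.0, size=(100, 8))   # distances to k=8 neighbours (stand-in)
ix = rng.integers(0, 1000, size=(100, 8))      # neighbour ids (stand-in)

weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True)
retrieved = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)

cluster_infer_ratio = 0.5
blended = cluster_infer_ratio * retrieved + (1 - cluster_infer_ratio) * feat
print(blended.shape)   # (100, 256)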
self.diffusion_model( @@ -460,7 +421,8 @@ class Svc(object): datas = [data] for k, dat in enumerate(datas): per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds != 0 else length - if clip_seconds != 0: print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======') + if clip_seconds != 0: + print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======') # padd pad_len = int(audio_sr * pad_seconds) dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) @@ -496,51 +458,3 @@ class Svc(object): return np.array(audio) -class RealTimeVC: - def __init__(self): - self.last_chunk = None - self.last_o = None - self.chunk_len = 16000 # chunk length - self.pre_len = 3840 # cross fade length, multiples of 640 - - # Input and output are 1-dimensional numpy waveform arrays - - def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path, - cluster_infer_ratio=0, - auto_predict_f0=False, - noice_scale=0.4, - f0_filter=False): - - import maad - audio, sr = torchaudio.load(input_wav_path) - audio = audio.cpu().numpy()[0] - temp_wav = io.BytesIO() - if self.last_chunk is None: - input_wav_path.seek(0) - - audio, sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, - cluster_infer_ratio=cluster_infer_ratio, - auto_predict_f0=auto_predict_f0, - noice_scale=noice_scale, - f0_filter=f0_filter) - - audio = audio.cpu().numpy() - self.last_chunk = audio[-self.pre_len:] - self.last_o = audio - return audio[-self.chunk_len:] - else: - audio = np.concatenate([self.last_chunk, audio]) - soundfile.write(temp_wav, audio, sr, format="wav") - temp_wav.seek(0) - - audio, sr = svc_model.infer(speaker_id, f_pitch_change, temp_wav, - cluster_infer_ratio=cluster_infer_ratio, - auto_predict_f0=auto_predict_f0, - noice_scale=noice_scale, - f0_filter=f0_filter) - - audio = audio.cpu().numpy() - ret = maad.util.crossfade(self.last_o, audio, self.pre_len) - self.last_chunk = audio[-self.pre_len:] - self.last_o = audio - return ret[self.chunk_len:2 * self.chunk_len] diff --git a/inference/infer_tool_grad.py b/inference/infer_tool_grad.py index 561c22c55e4f0527d038bbce3cef317393ded542..136e9048ec73e0d60f32fde80acc349b97eff366 100644 --- a/inference/infer_tool_grad.py +++ b/inference/infer_tool_grad.py @@ -1,22 +1,18 @@ -import hashlib -import json +import io import logging import os -import time -from pathlib import Path -import io + import librosa -import maad import numpy as np -from inference import slicer import parselmouth import soundfile import torch import torchaudio -from hubert import hubert_model import utils +from inference import slicer from models import SynthesizerTrn + logging.getLogger('numba').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -93,7 +89,7 @@ class VitsSvc(object): def set_device(self, device): self.device = torch.device(device) self.hubert_soft.to(self.device) - if self.SVCVITS != None: + if self.SVCVITS is not None: self.SVCVITS.to(self.device) def loadCheckpoint(self, path): diff --git a/inference/slicer.py b/inference/slicer.py index 05b3df0842d56ad700bfed931e90a988b2149a34..b05840bcf6bdced0b6e2adbecb1a1dd5b3dee462 100644 --- a/inference/slicer.py +++ b/inference/slicer.py @@ -117,8 +117,8 @@ class Slicer: return chunk_dict -def cut(input_audio, db_thresh=-30, min_len=5000): - audio, sr = librosa.load(input_audio, sr=None) +def cut(audio_path, db_thresh=-30, min_len=5000): + audio, sr = librosa.load(audio_path, sr=None) slicer = Slicer( sr=sr, threshold=db_thresh, diff 
--git a/inference_main.py b/inference_main.py deleted file mode 100644 index d3311757984e262658cc406f8d28febbe9620844..0000000000000000000000000000000000000000 --- a/inference_main.py +++ /dev/null @@ -1,181 +0,0 @@ -import io -import logging -import time -from pathlib import Path -from spkmix import spk_mix_map -import librosa -import matplotlib.pyplot as plt -import numpy as np -import soundfile -from inference import infer_tool -from inference import slicer -from inference.infer_tool import Svc - -logging.getLogger('numba').setLevel(logging.WARNING) -chunks_dict = infer_tool.read_temp("inference/chunks_temp.json") - - -def main(): - import argparse - - parser = argparse.ArgumentParser(description='sovits4 inference') - - # 一定要设置的部分 - parser.add_argument('-m', '--model_path', type=str, default="logs/44k/", help='模型路径') - parser.add_argument('-c', '--config_path', type=str, default="configs/", help='配置文件路径') - parser.add_argument('-cl', '--clip', type=float, default=0, help='音频强制切片,默认0为自动切片,单位为秒/s') - parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=["test.wav"], - help='wav文件名列表,放在raw文件夹下') - parser.add_argument('-t', '--trans', type=int, nargs='+', default=[0], help='音高调整,支持正负(半音)') - parser.add_argument('-s', '--spk_list', type=str, nargs='+', default=['buyizi'], help='合成目标说话人名称') - - # 可选项部分 - parser.add_argument('-a', '--auto_predict_f0', action='store_true', default=False, - help='语音转换自动预测音高,转换歌声时不要打开这个会严重跑调') - parser.add_argument('-cm', '--cluster_model_path', type=str, default="logs/44k/kmeans_10000.pt", - help='聚类模型或特征检索索引路径,如果没有训练聚类或特征检索则随便填') - parser.add_argument('-cr', '--cluster_infer_ratio', type=float, default=0, - help='聚类方案或特征检索占比,范围0-1,若没有训练聚类模型或特征检索则默认0即可') - parser.add_argument('-lg', '--linear_gradient', type=float, default=0, - help='两段音频切片的交叉淡入长度,如果强制切片后出现人声不连贯可调整该数值,如果连贯建议采用默认值0,单位为秒') - parser.add_argument('-f0p', '--f0_predictor', type=str, default="harvest", - help='选择F0预测器,可选择crepe,pm,dio,harvest,默认为pm(注意:crepe为原F0使用均值滤波器)') - parser.add_argument('-eh', '--enhance', action='store_true', default=False, - help='是否使用NSF_HIFIGAN增强器,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭') - parser.add_argument('-shd', '--shallow_diffusion', action='store_true', default=False, - help='是否使用浅层扩散,使用后可解决一部分电音问题,默认关闭,该选项打开时,NSF_HIFIGAN增强器将会被禁止') - parser.add_argument('-usm', '--use_spk_mix', action='store_true', default=False, help='是否使用角色融合') - parser.add_argument('-lea', '--loudness_envelope_adjustment', type=float, default=1, - help='输入源响度包络替换输出响度包络融合比例,越靠近1越使用输出响度包络') - parser.add_argument('-fr', '--feature_retrieval', action='store_true', default=False, - help='是否使用特征检索,如果使用聚类模型将被禁用,且cm与cr参数将会变成特征检索的索引路径与混合比例') - - # 浅扩散设置 - parser.add_argument('-dm', '--diffusion_model_path', type=str, default="logs/44k/diffusion/model_0.pt", - help='扩散模型路径') - parser.add_argument('-dc', '--diffusion_config_path', type=str, default="logs/44k/diffusion/config.yaml", - help='扩散模型配置文件路径') - parser.add_argument('-ks', '--k_step', type=int, default=100, help='扩散步数,越大越接近扩散模型的结果,默认100') - parser.add_argument('-se', '--second_encoding', action='store_true', default=False, - help='二次编码,浅扩散前会对原始音频进行二次编码,玄学选项,有时候效果好,有时候效果差') - parser.add_argument('-od', '--only_diffusion', action='store_true', default=False, - help='纯扩散模式,该模式不会加载sovits模型,以扩散模型推理') - - # 不用动的部分 - parser.add_argument('-sd', '--slice_db', type=int, default=-40, - help='默认-40,嘈杂的音频可以-30,干声保留呼吸可以-50') - parser.add_argument('-d', '--device', type=str, default=None, help='推理设备,None则为自动选择cpu和gpu') - 
parser.add_argument('-ns', '--noice_scale', type=float, default=0.4, help='噪音级别,会影响咬字和音质,较为玄学') - parser.add_argument('-p', '--pad_seconds', type=float, default=0.5, - help='推理音频pad秒数,由于未知原因开头结尾会有异响,pad一小段静音段后就不会出现') - parser.add_argument('-wf', '--wav_format', type=str, default='flac', help='音频输出格式') - parser.add_argument('-lgr', '--linear_gradient_retain', type=float, default=0.75, - help='自动音频切片后,需要舍弃每段切片的头尾。该参数设置交叉长度保留的比例,范围0-1,左开右闭') - parser.add_argument('-eak', '--enhancer_adaptive_key', type=int, default=0, - help='使增强器适应更高的音域(单位为半音数)|默认为0') - parser.add_argument('-ft', '--f0_filter_threshold', type=float, default=0.05, - help='F0过滤阈值,只有使用crepe时有效. 数值范围从0-1. 降低该值可减少跑调概率,但会增加哑音') - - def preprocess_args(args1): - spk1 = args1.spk_list[0] - args1.model_path += f"{spk1}.pth" - args1.config_path += f"config_{spk1}.json" - args1.clip = 30 - - if spk1 == 'tomori': - args1.feature_retrieval = True - args1.cluster_model_path = "logs/44k/tomori_index.pkl" - args1.cluster_infer_ratio = 0.5 - args1.f0_predictor = 'crepe' - - return args1 - - args = parser.parse_args() - args = preprocess_args(args) - - clean_names = args.clean_names - trans = args.trans - spk_list = args.spk_list - slice_db = args.slice_db - wav_format = args.wav_format - auto_predict_f0 = args.auto_predict_f0 - cluster_infer_ratio = args.cluster_infer_ratio - noice_scale = args.noice_scale - pad_seconds = args.pad_seconds - clip = args.clip - lg = args.linear_gradient - lgr = args.linear_gradient_retain - f0p = args.f0_predictor - enhance = args.enhance - enhancer_adaptive_key = args.enhancer_adaptive_key - cr_threshold = args.f0_filter_threshold - diffusion_model_path = args.diffusion_model_path - diffusion_config_path = args.diffusion_config_path - k_step = args.k_step - only_diffusion = args.only_diffusion - shallow_diffusion = args.shallow_diffusion - use_spk_mix = args.use_spk_mix - second_encoding = args.second_encoding - loudness_envelope_adjustment = args.loudness_envelope_adjustment - - svc_model = Svc(args.model_path, - args.config_path, - args.device, - args.cluster_model_path, - enhance, - diffusion_model_path, - diffusion_config_path, - shallow_diffusion, - only_diffusion, - use_spk_mix, - args.feature_retrieval) - - infer_tool.mkdir(["raw", "results"]) - - if len(spk_mix_map) <= 1: - use_spk_mix = False - if use_spk_mix: - spk_list = [spk_mix_map] - - infer_tool.fill_a_to_b(trans, clean_names) - for clean_name, tran in zip(clean_names, trans): - raw_audio_path = f"raw/{clean_name}" - if "." 
not in raw_audio_path: - raw_audio_path += ".wav" - infer_tool.format_wav(raw_audio_path) - for spk in spk_list: - kwarg = { - "raw_audio_path": raw_audio_path, - "spk": spk, - "tran": tran, - "slice_db": slice_db, - "cluster_infer_ratio": cluster_infer_ratio, - "auto_predict_f0": auto_predict_f0, - "noice_scale": noice_scale, - "pad_seconds": pad_seconds, - "clip_seconds": clip, - "lg_num": lg, - "lgr_num": lgr, - "f0_predictor": f0p, - "enhancer_adaptive_key": enhancer_adaptive_key, - "cr_threshold": cr_threshold, - "k_step": k_step, - "use_spk_mix": use_spk_mix, - "second_encoding": second_encoding, - "loudness_envelope_adjustment": loudness_envelope_adjustment - } - audio = svc_model.slice_inference(**kwarg) - key = "auto" if auto_predict_f0 else f"{tran}key" - cluster_name = "" if cluster_infer_ratio == 0 else f"_{cluster_infer_ratio}" - isdiffusion = "sovits" - if shallow_diffusion: isdiffusion = "sovdiff" - if only_diffusion: isdiffusion = "diff" - if use_spk_mix: - spk = "spk_mix" - res_path = f'results/{clean_name}_{key}_{spk}{cluster_name}_{isdiffusion}.{wav_format}' - soundfile.write(res_path, audio, svc_model.target_sample, format=wav_format) - svc_model.clear_empty() - - -if __name__ == '__main__': - main() diff --git a/models.py b/models.py index ac40c3cda6b5ef351049b0348711f90e2985ce1e..24338fa2c1f6c15e60f5f341c7e3df2301f74eb8 100644 --- a/models.py +++ b/models.py @@ -1,20 +1,17 @@ -import copy -import math import torch from torch import nn +from torch.nn import Conv1d, Conv2d from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm import modules.attentions as attentions import modules.commons as commons import modules.modules as modules - -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm - import utils -from modules.commons import init_weights, get_padding +from modules.commons import get_padding from utils import f0_to_coarse + class ResidualCouplingBlock(nn.Module): def __init__(self, channels, @@ -23,7 +20,9 @@ class ResidualCouplingBlock(nn.Module): dilation_rate, n_layers, n_flows=4, - gin_channels=0): + gin_channels=0, + share_parameter=False + ): super().__init__() self.channels = channels self.hidden_channels = hidden_channels @@ -34,10 +33,53 @@ class ResidualCouplingBlock(nn.Module): self.gin_channels = gin_channels self.flows = nn.ModuleList() + + self.wn = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=gin_channels) if share_parameter else None + for i in range(n_flows): self.flows.append( modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, - gin_channels=gin_channels, mean_only=True)) + gin_channels=gin_channels, mean_only=True, wn_sharing_parameter=self.wn)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + +class TransformerCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + share_parameter=False + ): + + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.n_flows = n_flows + 
self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + + self.wn = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = self.gin_channels) if share_parameter else None + + for i in range(n_flows): + self.flows.append( + modules.TransformerCouplingLayer(channels, hidden_channels, kernel_size, n_layers, n_heads, p_dropout, filter_channels, mean_only=True, wn_sharing_parameter=self.wn, gin_channels = self.gin_channels)) self.flows.append(modules.Flip()) def forward(self, x, x_mask, g=None, reverse=False): @@ -125,7 +167,7 @@ class DiscriminatorP(torch.nn.Module): super(DiscriminatorP, self).__init__() self.period = period self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), @@ -160,7 +202,7 @@ class DiscriminatorP(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv1d(1, 16, 15, 1, padding=7)), norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), @@ -321,6 +363,12 @@ class SynthesizerTrn(nn.Module): sampling_rate=44100, vol_embedding=False, vocoder_name = "nsf-hifigan", + use_depthwise_conv = False, + use_automatic_f0_prediction = True, + flow_share_parameter = False, + n_flow_layer = 4, + n_layers_trans_flow = 3, + use_transformer_flow = False, **kwargs): super().__init__() @@ -343,6 +391,9 @@ class SynthesizerTrn(nn.Module): self.ssl_dim = ssl_dim self.vol_embedding = vol_embedding self.emb_g = nn.Embedding(n_speakers, gin_channels) + self.use_depthwise_conv = use_depthwise_conv + self.use_automatic_f0_prediction = use_automatic_f0_prediction + self.n_layers_trans_flow = n_layers_trans_flow if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels) @@ -367,9 +418,11 @@ class SynthesizerTrn(nn.Module): "upsample_initial_channel": upsample_initial_channel, "upsample_kernel_sizes": upsample_kernel_sizes, "gin_channels": gin_channels, + "use_depthwise_conv":use_depthwise_conv } - + modules.set_Conv1dModel(self.use_depthwise_conv) + if vocoder_name == "nsf-hifigan": from vdecoder.hifigan.models import Generator self.dec = Generator(h=hps) @@ -382,17 +435,21 @@ class SynthesizerTrn(nn.Module): self.dec = Generator(h=hps) self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - self.f0_decoder = F0Decoder( - 1, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - spk_channels=gin_channels - ) + if use_transformer_flow: + self.flow = TransformerCouplingBlock(inter_channels, hidden_channels, filter_channels, n_heads, n_layers_trans_flow, 5, p_dropout, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter) + else: + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter) + if 
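# A toy sketch of the flow pattern shared by ResidualCouplingBlock and the new
# TransformerCouplingBlock above: the same list of invertible steps runs in order for
# the forward pass and in reversed order with reverse=True for inference. The affine
# ToyFlow is a stand-in, not the project's coupling layers.
import torch
import torch.nn as nn

class ToyFlow(nn.Module):
    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            return x * 2.0 + 1.0, None   # (output, logdet placeholder) going forward
        return (x - 1.0) / 2.0           # exact inverse when reverse=True

flows = nn.ModuleList([ToyFlow() for _ in range(4)])
x = torch.randn(2, 8, 16)
x_mask = torch.ones(2, 1, 16)

z = x
for flow in flows:
    z, _ = flow(z, x_mask, reverse=False)
x_rec = z
for flow in reversed(flows):
    x_rec = flow(x_rec, x_mask, reverse=True)
print(torch.allclose(x_rec, x, atol=1e-5))   # True: the stack is invertible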
self.use_automatic_f0_prediction: + self.f0_decoder = F0Decoder( + 1, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=gin_channels + ) self.emb_uv = nn.Embedding(2, hidden_channels) self.character_mix = False @@ -407,17 +464,21 @@ class SynthesizerTrn(nn.Module): g = self.emb_g(g).transpose(1,2) # vol proj - vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol!=None and self.vol_embedding else 0 + vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 # ssl prenet x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol - + # f0 predict - lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 - norm_lf0 = utils.normalize_f0(lf0, x_mask, uv) - pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) - + if self.use_automatic_f0_prediction: + lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 + norm_lf0 = utils.normalize_f0(lf0, x_mask, uv) + pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) + else: + lf0 = 0 + norm_lf0 = 0 + pred_lf0 = 0 # encoder z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0)) z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) @@ -431,6 +492,7 @@ class SynthesizerTrn(nn.Module): return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 + @torch.no_grad() def infer(self, c, f0, uv, g=None, noice_scale=0.35, seed=52468, predict_f0=False, vol = None): if c.device == torch.device("cuda"): @@ -452,11 +514,13 @@ class SynthesizerTrn(nn.Module): x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) # vol proj - vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol!=None and self.vol_embedding else 0 - - x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol - if predict_f0: + vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 + + x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol + + + if self.use_automatic_f0_prediction and predict_f0: lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) 
/ 500 norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) diff --git a/modules/DSConv.py b/modules/DSConv.py new file mode 100644 index 0000000000000000000000000000000000000000..44c2bf60e9cd2b837ca95fb6436768782057014a --- /dev/null +++ b/modules/DSConv.py @@ -0,0 +1,76 @@ +import torch.nn as nn +from torch.nn.utils import remove_weight_norm, weight_norm + + +class Depthwise_Separable_Conv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + dilation = 1, + bias = True, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight') + self.point_conv = remove_weight_norm(self.point_conv, name = 'weight') + +class Depthwise_Separable_TransposeConv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + output_padding = 0, + bias = True, + dilation = 1, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + remove_weight_norm(self.depth_conv, name = 'weight') + remove_weight_norm(self.point_conv, name = 'weight') + + +def weight_norm_modules(module, name = 'weight', dim = 0): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.weight_norm() + return module + else: + return weight_norm(module,name,dim) + +def remove_weight_norm_modules(module, name = 'weight'): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.remove_weight_norm() + else: + remove_weight_norm(module,name) \ No newline at end of file diff --git a/modules/F0Predictor/CrepeF0Predictor.py b/modules/F0Predictor/CrepeF0Predictor.py index e0052881b9b7b3aa373ebf69eb553815a564f610..c0854b64ed3bff96ed3381a7ef666c784aefd995 100644 --- a/modules/F0Predictor/CrepeF0Predictor.py +++ b/modules/F0Predictor/CrepeF0Predictor.py @@ -1,7 +1,9 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -from modules.F0Predictor.crepe import 
CrepePitchExtractor import torch +from modules.F0Predictor.crepe import CrepePitchExtractor +from modules.F0Predictor.F0Predictor import F0Predictor + + class CrepeF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,device=None,sampling_rate=44100,threshold=0.05,model="full"): self.F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=threshold,model=model) @@ -11,6 +13,7 @@ class CrepeF0Predictor(F0Predictor): self.device = device self.threshold = threshold self.sampling_rate = sampling_rate + self.name = "crepe" def compute_f0(self,wav,p_len=None): x = torch.FloatTensor(wav).to(self.device) diff --git a/modules/F0Predictor/DioF0Predictor.py b/modules/F0Predictor/DioF0Predictor.py index 4ab27de23cae4dbc282e30f84501afebd1a37518..178dd2e8a02b79e5af113300f00d6a4dc2fb2a07 100644 --- a/modules/F0Predictor/DioF0Predictor.py +++ b/modules/F0Predictor/DioF0Predictor.py @@ -1,6 +1,8 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -import pyworld import numpy as np +import pyworld + +from modules.F0Predictor.F0Predictor import F0Predictor + class DioF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): @@ -8,44 +10,31 @@ class DioF0Predictor(F0Predictor): self.f0_min = f0_min self.f0_max = f0_max self.sampling_rate = sampling_rate + self.name = "dio" def interpolate_f0(self,f0): ''' 对F0进行插值处理 ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 - data = np.reshape(f0, (f0.size, 1)) - - vuv_vector = np.zeros((data.size, 1), dtype=np.float32) - vuv_vector[data > 0.0] = 1.0 - vuv_vector[data <= 0.0] = 0.0 - - ip_data = data - - frame_number = data.size - last_value = 0.0 - for i in range(frame_number): - if data[i] <= 0.0: - j = i + 1 - for j in range(i + 1, frame_number): - if data[j] > 0.0: - break - if j < frame_number - 1: - if last_value > 0.0: - step = (data[j] - data[i - 1]) / float(j - i) - for k in range(i, j): - ip_data[k] = data[i - 1] + step * (k - i + 1) - else: - for k in range(i, j): - ip_data[k] = data[j] - else: - for k in range(i, frame_number): - ip_data[k] = last_value - else: - ip_data[i] = data[i] #这里可能存在一个没有必要的拷贝 - last_value = data[i] - - return ip_data[:,0], vuv_vector[:,0] + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector def resize_f0(self,x, target_len): source = np.array(x) diff --git a/modules/F0Predictor/FCPEF0Predictor.py b/modules/F0Predictor/FCPEF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..1096e040110d5f526e388d989c08b25937eac8f5 --- /dev/null +++ b/modules/F0Predictor/FCPEF0Predictor.py @@ -0,0 +1,109 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn.functional as F + +from modules.F0Predictor.F0Predictor import F0Predictor + +from .fcpe.model import FCPEInfer + + +class FCPEF0Predictor(F0Predictor): + def __init__(self, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sampling_rate=44100, + threshold=0.05): + self.fcpe = 
FCPEInfer(model_path="pretrain/fcpe.pt", device=device, dtype=dtype) + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + self.threshold = threshold + self.sampling_rate = sampling_rate + self.dtype = dtype + self.name = "fcpe" + + def repeat_expand( + self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" + ): + ndim = content.ndim + + if content.ndim == 1: + content = content[None, None] + elif content.ndim == 2: + content = content[None] + + assert content.ndim == 3 + + is_np = isinstance(content, np.ndarray) + if is_np: + content = torch.from_numpy(content) + + results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) + + if is_np: + results = results.numpy() + + if ndim == 1: + return results[0, 0] + elif ndim == 2: + return results[0] + + def post_process(self, x, sampling_rate, f0, pad_to): + if isinstance(f0, np.ndarray): + f0 = torch.from_numpy(f0).float().to(x.device) + + if pad_to is None: + return f0 + + f0 = self.repeat_expand(f0, pad_to) + + vuv_vector = torch.zeros_like(f0) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + # 去掉0频率, 并线性插值 + nzindex = torch.nonzero(f0).squeeze() + f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() + time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() + time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0] + + if f0.shape[0] <= 0: + return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(), vuv_vector.cpu().numpy() + if f0.shape[0] == 1: + return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[ + 0]).cpu().numpy(), vuv_vector.cpu().numpy() + + # 大概可以用 torch 重写? 
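# --- Editor's note (illustrative sketch, not part of the patch) ----------------------
# In SynthesizerTrn above, the new use_automatic_f0_prediction branch trains F0Decoder
# on a mel-scaled pitch target: lf0 = 2595 * log10(1 + f0 / 700) / 500, i.e. the mel
# value of f0 divided by 500 to keep it in a small numeric range. The helper below is
# only a restatement of that formula for illustration.
import torch

def f0_to_norm_lf0(f0_hz: torch.Tensor) -> torch.Tensor:
    return 2595.0 * torch.log10(1.0 + f0_hz / 700.0) / 500.0

print(f0_to_norm_lf0(torch.tensor([100.0, 440.0, 1000.0])))
# tensor([0.3010, 1.0993, 2.0000]) approximately
# --------------------------------------------------------------------------------------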
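# --- Editor's note (illustrative sketch, not part of the patch) ----------------------
# Depthwise_Separable_Conv1D (modules/DSConv.py above) factorizes a standard Conv1d
# into a per-channel convolution (groups=in_channels) followed by a 1x1 pointwise mix;
# the new use_depthwise_conv option appears to swap this in via modules.set_Conv1dModel.
# Parameter count (bias ignored) drops from C_in*C_out*K to C_in*K + C_in*C_out:
import torch.nn as nn

c_in, c_out, k = 192, 192, 5
standard  = nn.Conv1d(c_in, c_out, k, padding=k // 2, bias=False)
depthwise = nn.Conv1d(c_in, c_in, k, padding=k // 2, groups=c_in, bias=False)
pointwise = nn.Conv1d(c_in, c_out, 1, bias=False)

def n_params(*mods):
    return sum(p.numel() for m in mods for p in m.parameters())

print(n_params(standard))              # 184320  (= 192*192*5)
print(n_params(depthwise, pointwise))  # 37824   (= 192*5 + 192*192)
# --------------------------------------------------------------------------------------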
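# --- Editor's note (illustrative sketch, not part of the patch) ----------------------
# The rewritten interpolate_f0 (Dio/Harvest/PM predictors) and the post_process used by
# the new FCPE/RMVPE predictors replace the old frame-by-frame loop: unvoiced frames
# (f0 == 0) are filled by linear interpolation over time between voiced frames with
# np.interp, and a separate voiced/unvoiced mask is returned alongside the curve.
import numpy as np

hop_length, sampling_rate = 512, 44100
f0 = np.array([0.0, 220.0, 0.0, 0.0, 230.0, 0.0], dtype=np.float32)

vuv = (f0 > 0.0).astype(np.float32)
nz = np.nonzero(f0)[0]
t_voiced = hop_length / sampling_rate * nz
t_all = np.arange(f0.shape[0]) * hop_length / sampling_rate
f0_filled = np.interp(t_all, t_voiced, f0[nz], left=f0[nz][0], right=f0[nz][-1])

print(np.round(f0_filled, 2))  # [220.   220.   223.33 226.67 230.   230.  ]
print(vuv)                     # [0. 1. 0. 0. 1. 0.]
# --------------------------------------------------------------------------------------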
+ f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + # vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + + return f0, vuv_vector.cpu().numpy() + + def compute_f0(self, wav, p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0] // self.hop_length + else: + assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" + f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn, rtn + return self.post_process(x, self.sampling_rate, f0, p_len)[0] + + def compute_f0_uv(self, wav, p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0] // self.hop_length + else: + assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" + f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn, rtn + return self.post_process(x, self.sampling_rate, f0, p_len) \ No newline at end of file diff --git a/modules/F0Predictor/HarvestF0Predictor.py b/modules/F0Predictor/HarvestF0Predictor.py index 122bdbb4c736feb4a8d974eca03df71aede76f69..f36b332f7b42802918ce3e232a6609413394acf9 100644 --- a/modules/F0Predictor/HarvestF0Predictor.py +++ b/modules/F0Predictor/HarvestF0Predictor.py @@ -1,6 +1,8 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -import pyworld import numpy as np +import pyworld + +from modules.F0Predictor.F0Predictor import F0Predictor + class HarvestF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): @@ -8,45 +10,31 @@ class HarvestF0Predictor(F0Predictor): self.f0_min = f0_min self.f0_max = f0_max self.sampling_rate = sampling_rate + self.name = "harvest" def interpolate_f0(self,f0): ''' 对F0进行插值处理 ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 - data = np.reshape(f0, (f0.size, 1)) - - vuv_vector = np.zeros((data.size, 1), dtype=np.float32) - vuv_vector[data > 0.0] = 1.0 - vuv_vector[data <= 0.0] = 0.0 - - ip_data = data - - frame_number = data.size - last_value = 0.0 - for i in range(frame_number): - if data[i] <= 0.0: - j = i + 1 - for j in range(i + 1, frame_number): - if data[j] > 0.0: - break - if j < frame_number - 1: - if last_value > 0.0: - step = (data[j] - data[i - 1]) / float(j - i) - for k in range(i, j): - ip_data[k] = data[i - 1] + step * (k - i + 1) - else: - for k in range(i, j): - ip_data[k] = data[j] - else: - for k in range(i, frame_number): - ip_data[k] = last_value - else: - ip_data[i] = data[i] #这里可能存在一个没有必要的拷贝 - last_value = data[i] - - return ip_data[:,0], vuv_vector[:,0] + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector def resize_f0(self,x, target_len): source = np.array(x) source[source<0.001] = np.nan diff --git a/modules/F0Predictor/PMF0Predictor.py 
b/modules/F0Predictor/PMF0Predictor.py index ccf4128436c5b7e5a3e720d4597bad0c622d0920..2af3f6e7ee7c5c4e10899f9988e1d9b92aa52157 100644 --- a/modules/F0Predictor/PMF0Predictor.py +++ b/modules/F0Predictor/PMF0Predictor.py @@ -1,6 +1,8 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -import parselmouth import numpy as np +import parselmouth + +from modules.F0Predictor.F0Predictor import F0Predictor + class PMF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): @@ -8,45 +10,32 @@ class PMF0Predictor(F0Predictor): self.f0_min = f0_min self.f0_max = f0_max self.sampling_rate = sampling_rate - + self.name = "pm" def interpolate_f0(self,f0): ''' 对F0进行插值处理 ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 - data = np.reshape(f0, (f0.size, 1)) - - vuv_vector = np.zeros((data.size, 1), dtype=np.float32) - vuv_vector[data > 0.0] = 1.0 - vuv_vector[data <= 0.0] = 0.0 - - ip_data = data - - frame_number = data.size - last_value = 0.0 - for i in range(frame_number): - if data[i] <= 0.0: - j = i + 1 - for j in range(i + 1, frame_number): - if data[j] > 0.0: - break - if j < frame_number - 1: - if last_value > 0.0: - step = (data[j] - data[i - 1]) / float(j - i) - for k in range(i, j): - ip_data[k] = data[i - 1] + step * (k - i + 1) - else: - for k in range(i, j): - ip_data[k] = data[j] - else: - for k in range(i, frame_number): - ip_data[k] = last_value - else: - ip_data[i] = data[i] #这里可能存在一个没有必要的拷贝 - last_value = data[i] + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector - return ip_data[:,0], vuv_vector[:,0] def compute_f0(self,wav,p_len=None): x = wav diff --git a/modules/F0Predictor/RMVPEF0Predictor.py b/modules/F0Predictor/RMVPEF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..9313887be084e99059e6c76fffba323de1f3c835 --- /dev/null +++ b/modules/F0Predictor/RMVPEF0Predictor.py @@ -0,0 +1,107 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn.functional as F + +from modules.F0Predictor.F0Predictor import F0Predictor + +from .rmvpe import RMVPE + + +class RMVPEF0Predictor(F0Predictor): + def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05): + self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device) + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + self.threshold = threshold + self.sampling_rate = sampling_rate + self.dtype = dtype + self.name = "rmvpe" + + def repeat_expand( + self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" + ): + ndim = content.ndim + + if content.ndim == 1: + content = content[None, None] + elif content.ndim == 2: + content = content[None] + + assert content.ndim == 3 + + is_np = isinstance(content, np.ndarray) + if is_np: + content = torch.from_numpy(content) + + results = 
torch.nn.functional.interpolate(content, size=target_len, mode=mode) + + if is_np: + results = results.numpy() + + if ndim == 1: + return results[0, 0] + elif ndim == 2: + return results[0] + + def post_process(self, x, sampling_rate, f0, pad_to): + if isinstance(f0, np.ndarray): + f0 = torch.from_numpy(f0).float().to(x.device) + + if pad_to is None: + return f0 + + f0 = self.repeat_expand(f0, pad_to) + + vuv_vector = torch.zeros_like(f0) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + # 去掉0频率, 并线性插值 + nzindex = torch.nonzero(f0).squeeze() + f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() + time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() + time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0] + + if f0.shape[0] <= 0: + return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(),vuv_vector.cpu().numpy() + if f0.shape[0] == 1: + return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0]).cpu().numpy() ,vuv_vector.cpu().numpy() + + # 大概可以用 torch 重写? + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + + return f0,vuv_vector.cpu().numpy() + + def compute_f0(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn,rtn + return self.post_process(x,self.sampling_rate,f0,p_len)[0] + + def compute_f0_uv(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn,rtn + return self.post_process(x,self.sampling_rate,f0,p_len) \ No newline at end of file diff --git a/modules/F0Predictor/crepe.py b/modules/F0Predictor/crepe.py index c6fb45c79bcd306202a2c0282b3d73a8074ced5d..e68f19cb39eb79931926ffd312fb61e30bf39d72 100644 --- a/modules/F0Predictor/crepe.py +++ b/modules/F0Predictor/crepe.py @@ -1,14 +1,14 @@ -from typing import Optional,Union +from typing import Optional, Union + try: from typing import Literal -except Exception as e: +except Exception: from typing_extensions import Literal import numpy as np import torch import torchcrepe from torch import nn from torch.nn import functional as F -import scipy #from:https://github.com/fishaudio/fish-diffusion @@ -97,19 +97,19 @@ class BasePitchExtractor: f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0] if f0.shape[0] <= 0: - return torch.zeros(pad_to, dtype=torch.float, device=x.device),torch.zeros(pad_to, dtype=torch.float, device=x.device) - + return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy() if f0.shape[0] == 1: - return torch.ones(pad_to, dtype=torch.float, device=x.device) 
* f0[0],torch.ones(pad_to, dtype=torch.float, device=x.device) + return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy() # 大概可以用 torch 重写? f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) - vuv_vector = vuv_vector.cpu().numpy() - vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) - return f0,vuv_vector + return f0,vuv_vector.cpu().numpy() class MaskedAvgPool1d(nn.Module): @@ -323,7 +323,7 @@ class CrepePitchExtractor(BasePitchExtractor): else: pd = torchcrepe.filter.median(pd, 3) - pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, 512) + pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, self.hop_length) f0 = torchcrepe.threshold.At(self.threshold)(f0, pd) if self.use_fast_filters: @@ -334,7 +334,7 @@ class CrepePitchExtractor(BasePitchExtractor): f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0] if torch.all(f0 == 0): - rtn = f0.cpu().numpy() if pad_to==None else np.zeros(pad_to) + rtn = f0.cpu().numpy() if pad_to is None else np.zeros(pad_to) return rtn,rtn return self.post_process(x, sampling_rate, f0, pad_to) diff --git a/modules/F0Predictor/fcpe/__init__.py b/modules/F0Predictor/fcpe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a33fdace676fb8e13eaf02b693442d707812b06b --- /dev/null +++ b/modules/F0Predictor/fcpe/__init__.py @@ -0,0 +1,3 @@ +from .model import FCPEInfer # noqa: F401 +from .nvSTFT import STFT # noqa: F401 +from .pcmer import PCmer # noqa: F401 diff --git a/modules/F0Predictor/fcpe/model.py b/modules/F0Predictor/fcpe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..91ad6baadf3fa122bd373c52815e0eb60ed025b3 --- /dev/null +++ b/modules/F0Predictor/fcpe/model.py @@ -0,0 +1,262 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm +from torchaudio.transforms import Resample + +from .nvSTFT import STFT +from .pcmer import PCmer + + +def l2_regularization(model, l2_alpha): + l2_loss = [] + for module in model.modules(): + if type(module) is nn.Conv2d: + l2_loss.append((module.weight ** 2).sum() / 2.0) + return l2_alpha * sum(l2_loss) + + +class FCPE(nn.Module): + def __init__( + self, + input_channel=128, + out_dims=360, + n_layers=12, + n_chans=512, + use_siren=False, + use_full=False, + loss_mse_scale=10, + loss_l2_regularization=False, + loss_l2_regularization_scale=1, + loss_grad1_mse=False, + loss_grad1_mse_scale=1, + f0_max=1975.5, + f0_min=32.70, + confidence=False, + threshold=0.05, + use_input_conv=True + ): + super().__init__() + if use_siren is True: + raise ValueError("Siren is not supported yet.") + if use_full is True: + raise ValueError("Full model is not supported yet.") + + self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10 + self.loss_l2_regularization = loss_l2_regularization if (loss_l2_regularization is not None) else False + self.loss_l2_regularization_scale = loss_l2_regularization_scale if (loss_l2_regularization_scale + is not None) else 1 + self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False + self.loss_grad1_mse_scale = loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1 + self.f0_max = f0_max if (f0_max is not None) else 1975.5 + self.f0_min = f0_min if (f0_min is not None) else 32.70 + self.confidence = confidence if 
(confidence is not None) else False + self.threshold = threshold if (threshold is not None) else 0.05 + self.use_input_conv = use_input_conv if (use_input_conv is not None) else True + + self.cent_table_b = torch.Tensor( + np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], + out_dims)) + self.register_buffer("cent_table", self.cent_table_b) + + # conv in stack + _leaky = nn.LeakyReLU() + self.stack = nn.Sequential( + nn.Conv1d(input_channel, n_chans, 3, 1, 1), + nn.GroupNorm(4, n_chans), + _leaky, + nn.Conv1d(n_chans, n_chans, 3, 1, 1)) + + # transformer + self.decoder = PCmer( + num_layers=n_layers, + num_heads=8, + dim_model=n_chans, + dim_keys=n_chans, + dim_values=n_chans, + residual_dropout=0.1, + attention_dropout=0.1) + self.norm = nn.LayerNorm(n_chans) + + # out + self.n_out = out_dims + self.dense_out = weight_norm( + nn.Linear(n_chans, self.n_out)) + + def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder = "local_argmax"): + """ + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + """ + if cdecoder == "argmax": + self.cdecoder = self.cents_decoder + elif cdecoder == "local_argmax": + self.cdecoder = self.cents_local_decoder + if self.use_input_conv: + x = self.stack(mel.transpose(1, 2)).transpose(1, 2) + else: + x = mel + x = self.decoder(x) + x = self.norm(x) + x = self.dense_out(x) # [B,N,D] + x = torch.sigmoid(x) + if not infer: + gt_cent_f0 = self.f0_to_cent(gt_f0) # mel f0 #[B,N,1] + gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0) # #[B,N,out_dim] + loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, gt_cent_f0) # bce loss + # l2 regularization + if self.loss_l2_regularization: + loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale) + x = loss_all + if infer: + x = self.cdecoder(x) + x = self.cent_to_f0(x) + if not return_hz_f0: + x = (1 + x / 700).log() + return x + + def cents_decoder(self, y, mask=True): + B, N, _ = y.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True) # cents: [B,N,1] + if mask: + confident = torch.max(y, dim=-1, keepdim=True)[0] + confident_mask = torch.ones_like(confident) + confident_mask[confident <= self.threshold] = float("-INF") + rtn = rtn * confident_mask + if self.confidence: + return rtn, confident + else: + return rtn + + def cents_local_decoder(self, y, mask=True): + B, N, _ = y.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + confident, max_index = torch.max(y, dim=-1, keepdim=True) + local_argmax_index = torch.arange(0,9).to(max_index.device) + (max_index - 4) + local_argmax_index[local_argmax_index<0] = 0 + local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1 + ci_l = torch.gather(ci,-1,local_argmax_index) + y_l = torch.gather(y,-1,local_argmax_index) + rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True) # cents: [B,N,1] + if mask: + confident_mask = torch.ones_like(confident) + confident_mask[confident <= self.threshold] = float("-INF") + rtn = rtn * confident_mask + if self.confidence: + return rtn, confident + else: + return rtn + + def cent_to_f0(self, cent): + return 10. * 2 ** (cent / 1200.) + + def f0_to_cent(self, f0): + return 1200. * torch.log2(f0 / 10.) + + def gaussian_blurred_cent(self, cents): # cents: [B,N,1] + mask = (cents > 0.1) & (cents < (1200. 
* np.log2(self.f0_max / 10.))) + B, N, _ = cents.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + return torch.exp(-torch.square(ci - cents) / 1250) * mask.float() + + +class FCPEInfer: + def __init__(self, model_path, device=None, dtype=torch.float32): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + ckpt = torch.load(model_path, map_location=torch.device(self.device)) + self.args = DotDict(ckpt["config"]) + self.dtype = dtype + model = FCPE( + input_channel=self.args.model.input_channel, + out_dims=self.args.model.out_dims, + n_layers=self.args.model.n_layers, + n_chans=self.args.model.n_chans, + use_siren=self.args.model.use_siren, + use_full=self.args.model.use_full, + loss_mse_scale=self.args.loss.loss_mse_scale, + loss_l2_regularization=self.args.loss.loss_l2_regularization, + loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, + loss_grad1_mse=self.args.loss.loss_grad1_mse, + loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, + f0_max=self.args.model.f0_max, + f0_min=self.args.model.f0_min, + confidence=self.args.model.confidence, + ) + model.to(self.device).to(self.dtype) + model.load_state_dict(ckpt['model']) + model.eval() + self.model = model + self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device) + + @torch.no_grad() + def __call__(self, audio, sr, threshold=0.05): + self.model.threshold = threshold + audio = audio[None,:] + mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype) + f0 = self.model(mel=mel, infer=True, return_hz_f0=True) + return f0 + + +class Wav2Mel: + + def __init__(self, args, device=None, dtype=torch.float32): + # self.args = args + self.sampling_rate = args.mel.sampling_rate + self.hop_size = args.mel.hop_size + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + self.dtype = dtype + self.stft = STFT( + args.mel.sampling_rate, + args.mel.num_mels, + args.mel.n_fft, + args.mel.win_size, + args.mel.hop_size, + args.mel.fmin, + args.mel.fmax + ) + self.resample_kernel = {} + + def extract_nvstft(self, audio, keyshift=0, train=False): + mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2) # B, n_frames, bins + return mel + + def extract_mel(self, audio, sample_rate, keyshift=0, train=False): + audio = audio.to(self.dtype).to(self.device) + # resample + if sample_rate == self.sampling_rate: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128) + self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + # extract + mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train) # B, n_frames, bins + n_frames = int(audio.shape[1] // self.hop_size) + 1 + if n_frames > int(mel.shape[1]): + mel = torch.cat((mel, mel[:, -1:, :]), 1) + if n_frames < int(mel.shape[1]): + mel = mel[:, :n_frames, :] + return mel + + def __call__(self, audio, sample_rate, keyshift=0, train=False): + return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train) + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ diff --git a/modules/F0Predictor/fcpe/nvSTFT.py b/modules/F0Predictor/fcpe/nvSTFT.py new 
file mode 100644 index 0000000000000000000000000000000000000000..b97435f8977d659f594b41fa3f8993ee85f02ee9 --- /dev/null +++ b/modules/F0Predictor/fcpe/nvSTFT.py @@ -0,0 +1,133 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.nn.functional as F +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 48000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 48000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, keyshift=0, speed=1, center=False, train=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(n_fft * factor)) + win_size_new = int(np.round(win_size * factor)) + hop_length_new = int(np.round(hop_length * speed)) + if not train: + mel_basis = self.mel_basis + hann_window = self.hann_window + else: + mel_basis = {} + hann_window = {} + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + mel_basis_key = str(fmax)+'_'+str(y.device) + if mel_basis_key not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, 
fmax=fmax) + mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) + + keyshift_key = str(keyshift)+'_'+str(y.device) + if keyshift_key not in hann_window: + hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) + + pad_left = (win_size_new - hop_length_new) //2 + pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) + if pad_right < y.size(-1): + mode = 'reflect' + else: + mode = 'constant' + y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) + y = y.squeeze(1) + + spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9)) + if keyshift != 0: + size = n_fft // 2 + 1 + resize = spec.size(1) + if resize < size: + spec = F.pad(spec, (0, 0, 0, size-resize)) + spec = spec[:, :size, :] * win_size / win_size_new + spec = torch.matmul(mel_basis[mel_basis_key], spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/modules/F0Predictor/fcpe/pcmer.py b/modules/F0Predictor/fcpe/pcmer.py new file mode 100644 index 0000000000000000000000000000000000000000..5c12678007ad62e1d370533fe37307226dc48492 --- /dev/null +++ b/modules/F0Predictor/fcpe/pcmer.py @@ -0,0 +1,369 @@ +import math +from functools import partial + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from local_attention import LocalAttention +from torch import nn + +#import fast_transformers.causal_product.causal_product_cuda + +def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): + b, h, *_ = data.shape + # (batch size, head, length, model_dim) + + # normalize model dim + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
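# --- Editor's note (illustrative sketch, not part of the patch) ----------------------
# The FCPE model above works on a cent scale referenced to 10 Hz: its cent_table spans
# f0_to_cent(f0_min) .. f0_to_cent(f0_max) over out_dims bins, and cents_decoder /
# cents_local_decoder reduce the per-bin activations to one cent value by a (locally
# windowed) weighted average. The two conversions, restated standalone:
import torch

def f0_to_cent(f0):
    return 1200.0 * torch.log2(f0 / 10.0)

def cent_to_f0(cent):
    return 10.0 * 2.0 ** (cent / 1200.0)

f0 = torch.tensor([55.0, 440.0])
print(f0_to_cent(f0))              # tensor([2951.32, 6551.32]) approximately
print(cent_to_f0(f0_to_cent(f0)))  # tensor([ 55., 440.]) -- exact round trip
# --------------------------------------------------------------------------------------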
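# --- Editor's note (illustrative sketch, not part of the patch) ----------------------
# STFT.get_mel in nvSTFT.py above supports key shifting by scaling the FFT and window
# sizes with factor = 2 ** (keyshift / 12) (one factor of 2^(1/12) per semitone), then
# padding or cropping the magnitude spectrum back to n_fft // 2 + 1 bins and rescaling
# by win_size / win_size_new. The arithmetic for a +3 semitone shift:
keyshift = 3
factor = 2 ** (keyshift / 12)             # ~1.1892
n_fft, win_size = 1024, 1024
n_fft_new = int(round(n_fft * factor))    # 1218
win_size_new = int(round(win_size * factor))
print(factor, n_fft_new, win_size_new)    # 1.1892... 1218 1218
# --------------------------------------------------------------------------------------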
+ + # what is ration?, projection_matrix.shape[0] --> 266 + + ratio = (projection_matrix.shape[0] ** -0.5) + + projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) + projection = projection.type_as(data) + + #data_dash = w^T x + data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) + + + # diag_data = D**2 + diag_data = data ** 2 + diag_data = torch.sum(diag_data, dim=-1) + diag_data = (diag_data / 2.0) * (data_normalizer ** 2) + diag_data = diag_data.unsqueeze(dim=-1) + + #print () + if is_query: + data_dash = ratio * ( + torch.exp(data_dash - diag_data - + torch.max(data_dash, dim=-1, keepdim=True).values) + eps) + else: + data_dash = ratio * ( + torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps) + + return data_dash.type_as(data) + +def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None): + unstructured_block = torch.randn((cols, cols), device = device) + q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced') + q, r = map(lambda t: t.to(device), (q, r)) + + # proposed by @Parskatt + # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf + if qr_uniform_q: + d = torch.diag(r, 0) + q *= d.sign() + return q.t() +def exists(val): + return val is not None + +def empty(tensor): + return tensor.numel() == 0 + +def default(val, d): + return val if exists(val) else d + +def cast_tuple(val): + return (val,) if not isinstance(val, tuple) else val + +class PCmer(nn.Module): + """The encoder that is used in the Transformer model.""" + + def __init__(self, + num_layers, + num_heads, + dim_model, + dim_keys, + dim_values, + residual_dropout, + attention_dropout): + super().__init__() + self.num_layers = num_layers + self.num_heads = num_heads + self.dim_model = dim_model + self.dim_values = dim_values + self.dim_keys = dim_keys + self.residual_dropout = residual_dropout + self.attention_dropout = attention_dropout + + self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)]) + + # METHODS ######################################################################################################## + + def forward(self, phone, mask=None): + + # apply all layers to the input + for (i, layer) in enumerate(self._layers): + phone = layer(phone, mask) + # provide the final sequence + return phone + + +# ==================================================================================================================== # +# CLASS _ E N C O D E R L A Y E R # +# ==================================================================================================================== # + + +class _EncoderLayer(nn.Module): + """One layer of the encoder. + + Attributes: + attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence. + feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism. + """ + + def __init__(self, parent: PCmer): + """Creates a new instance of ``_EncoderLayer``. + + Args: + parent (Encoder): The encoder that the layers is created for. + """ + super().__init__() + + + self.conformer = ConformerConvModule(parent.dim_model) + self.norm = nn.LayerNorm(parent.dim_model) + self.dropout = nn.Dropout(parent.residual_dropout) + + # selfatt -> fastatt: performer! 
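# --- Editor's note (illustrative sketch, not part of the patch) ----------------------
# The fast attention wired in just below uses Performer-style random features
# (softmax_kernel above): attention is computed as q' (k'^T v) with a row-wise
# normalizer, linear in sequence length instead of quadratic. Shapes only, no learned
# projections; q_prime/k_prime stand in for the kernelized queries and keys.
import torch

b, h, n, d, m = 1, 8, 1024, 64, 256          # batch, heads, frames, head dim, features
q_prime = torch.rand(b, h, n, m)             # non-negative feature maps
k_prime = torch.rand(b, h, n, m)
v = torch.randn(b, h, n, d)

context = torch.einsum('bhnm,bhnd->bhmd', k_prime, v)                 # O(n * m * d)
d_inv = 1.0 / (torch.einsum('bhnm,bhm->bhn', q_prime, k_prime.sum(dim=2)) + 1e-8)
out = torch.einsum('bhnm,bhmd,bhn->bhnd', q_prime, context, d_inv)

print(out.shape)                             # torch.Size([1, 8, 1024, 64])
# --------------------------------------------------------------------------------------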
+ self.attn = SelfAttention(dim = parent.dim_model, + heads = parent.num_heads, + causal = False) + + # METHODS ######################################################################################################## + + def forward(self, phone, mask=None): + + # compute attention sub-layer + phone = phone + (self.attn(self.norm(phone), mask=mask)) + + phone = phone + (self.conformer(phone)) + + return phone + +def calc_same_padding(kernel_size): + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + +# helper classes + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + +class Transpose(nn.Module): + def __init__(self, dims): + super().__init__() + assert len(dims) == 2, 'dims must be a tuple of two dimensions' + self.dims = dims + + def forward(self, x): + return x.transpose(*self.dims) + +class GLU(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + out, gate = x.chunk(2, dim=self.dim) + return out * gate.sigmoid() + +class DepthWiseConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size, padding): + super().__init__() + self.padding = padding + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in) + + def forward(self, x): + x = F.pad(x, self.padding) + return self.conv(x) + +class ConformerConvModule(nn.Module): + def __init__( + self, + dim, + causal = False, + expansion_factor = 2, + kernel_size = 31, + dropout = 0.): + super().__init__() + + inner_dim = dim * expansion_factor + padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0) + + self.net = nn.Sequential( + nn.LayerNorm(dim), + Transpose((1, 2)), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding), + #nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Transpose((1, 2)), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + +def linear_attention(q, k, v): + if v is None: + #print (k.size(), q.size()) + out = torch.einsum('...ed,...nd->...ne', k, q) + return out + + else: + k_cumsum = k.sum(dim = -2) + #k_cumsum = k.sum(dim = -2) + D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8) + + context = torch.einsum('...nd,...ne->...de', k, v) + #print ("TRUEEE: ", context.size(), q.size(), D_inv.size()) + out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) + return out + +def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None): + nb_full_blocks = int(nb_rows / nb_columns) + #print (nb_full_blocks) + block_list = [] + + for _ in range(nb_full_blocks): + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) + block_list.append(q) + # block_list[n] is a orthogonal matrix ... 
(model_dim * model_dim) + #print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1))) + #print (nb_rows, nb_full_blocks, nb_columns) + remaining_rows = nb_rows - nb_full_blocks * nb_columns + #print (remaining_rows) + if remaining_rows > 0: + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) + #print (q[:remaining_rows].size()) + block_list.append(q[:remaining_rows]) + + final_matrix = torch.cat(block_list) + + if scaling == 0: + multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) + elif scaling == 1: + multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) + else: + raise ValueError(f'Invalid scaling {scaling}') + + return torch.diag(multiplier) @ final_matrix + +class FastAttention(nn.Module): + def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False): + super().__init__() + nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) + + self.dim_heads = dim_heads + self.nb_features = nb_features + self.ortho_scaling = ortho_scaling + + self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q) + projection_matrix = self.create_projection() + self.register_buffer('projection_matrix', projection_matrix) + + self.generalized_attention = generalized_attention + self.kernel_fn = kernel_fn + + # if this is turned on, no projection will be used + # queries and keys will be softmax-ed as in the original efficient attention paper + self.no_projection = no_projection + + self.causal = causal + + @torch.no_grad() + def redraw_projection_matrix(self): + projections = self.create_projection() + self.projection_matrix.copy_(projections) + del projections + + def forward(self, q, k, v): + device = q.device + + if self.no_projection: + q = q.softmax(dim = -1) + k = torch.exp(k) if self.causal else k.softmax(dim = -2) + else: + create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) + + q = create_kernel(q, is_query = True) + k = create_kernel(k, is_query = False) + + attn_fn = linear_attention if not self.causal else self.causal_linear_fn + if v is None: + out = attn_fn(q, k, None) + return out + else: + out = attn_fn(q, k, v) + return out +class SelfAttention(nn.Module): + def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False): + super().__init__() + assert dim % heads == 0, 'dimension must be divisible by number of heads' + dim_head = default(dim_head, dim // heads) + inner_dim = dim_head * heads + self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection) + + self.heads = heads + self.global_heads = heads - local_heads + self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None + + #print (heads, nb_features, dim_head) + #name_embedding = 
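# --- Editor's note (illustrative sketch, not part of the patch) ----------------------
# gaussian_orthogonal_random_matrix above builds FastAttention's projection from blocks
# of orthogonal rows (QR of a Gaussian block), which makes the random feature
# directions less redundant than i.i.d. Gaussian rows; with scaling=0 each row is then
# rescaled to the norm of a Gaussian vector.
import torch

cols = 64
q, _ = torch.linalg.qr(torch.randn(cols, cols))   # q has orthonormal columns
rows = q.t()                                      # rows used as feature directions
print(torch.allclose(rows @ rows.t(), torch.eye(cols), atol=1e-5))  # True
# --------------------------------------------------------------------------------------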
torch.zeros(110, heads, dim_head, dim_head) + #self.name_embedding = nn.Parameter(name_embedding, requires_grad=True) + + + self.to_q = nn.Linear(dim, inner_dim) + self.to_k = nn.Linear(dim, inner_dim) + self.to_v = nn.Linear(dim, inner_dim) + self.to_out = nn.Linear(inner_dim, dim) + self.dropout = nn.Dropout(dropout) + + @torch.no_grad() + def redraw_projection_matrix(self): + self.fast_attention.redraw_projection_matrix() + #torch.nn.init.zeros_(self.name_embedding) + #print (torch.sum(self.name_embedding)) + def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs): + _, _, _, h, gh = *x.shape, self.heads, self.global_heads + + cross_attend = exists(context) + + context = default(context, x) + context_mask = default(context_mask, mask) if not cross_attend else context_mask + #print (torch.sum(self.name_embedding)) + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) + + attn_outs = [] + #print (name) + #print (self.name_embedding[name].size()) + if not empty(q): + if exists(context_mask): + global_mask = context_mask[:, None, :, None] + v.masked_fill_(~global_mask, 0.) + if cross_attend: + pass + #print (torch.sum(self.name_embedding)) + #out = self.fast_attention(q,self.name_embedding[name],None) + #print (torch.sum(self.name_embedding[...,-1:])) + else: + out = self.fast_attention(q, k, v) + attn_outs.append(out) + + if not empty(lq): + assert not cross_attend, 'local attention is not compatible with cross attention' + out = self.local_attn(lq, lk, lv, input_mask = mask) + attn_outs.append(out) + + out = torch.cat(attn_outs, dim = 1) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return self.dropout(out) \ No newline at end of file diff --git a/modules/F0Predictor/rmvpe/__init__.py b/modules/F0Predictor/rmvpe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2dcf9e971ac4fcea29fe2e312d591fd0447f95d --- /dev/null +++ b/modules/F0Predictor/rmvpe/__init__.py @@ -0,0 +1,10 @@ +from .constants import * # noqa: F403 +from .inference import RMVPE # noqa: F401 +from .model import E2E, E2E0 # noqa: F401 +from .spec import MelSpectrogram # noqa: F401 +from .utils import ( # noqa: F401 + cycle, + summary, + to_local_average_cents, + to_viterbi_cents, +) diff --git a/modules/F0Predictor/rmvpe/constants.py b/modules/F0Predictor/rmvpe/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f52efc9b40f49bb746dae6807a817bffce4375 --- /dev/null +++ b/modules/F0Predictor/rmvpe/constants.py @@ -0,0 +1,9 @@ +SAMPLE_RATE = 16000 + +N_CLASS = 360 + +N_MELS = 128 +MEL_FMIN = 30 +MEL_FMAX = SAMPLE_RATE // 2 +WINDOW_LENGTH = 1024 +CONST = 1997.3794084376191 diff --git a/modules/F0Predictor/rmvpe/deepunet.py b/modules/F0Predictor/rmvpe/deepunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b0171d562ac58526c7693a15124e181c78ad0a18 --- /dev/null +++ b/modules/F0Predictor/rmvpe/deepunet.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn + +from .constants import N_MELS + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super(ConvBlockRes, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + 
nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + + nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + if in_channels != out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) + self.is_shortcut = True + else: + self.is_shortcut = False + + def forward(self, x): + if self.is_shortcut: + return self.conv(x) + self.shortcut(x) + else: + return self.conv(x) + x + + +class ResEncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): + super(ResEncoderBlock, self).__init__() + self.n_blocks = n_blocks + self.conv = nn.ModuleList() + self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if self.kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x): + for i in range(self.n_blocks): + x = self.conv[i](x) + if self.kernel_size is not None: + return x, self.pool(x) + else: + return x + + +class ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super(ResDecoderBlock, self).__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.n_blocks = n_blocks + self.conv1 = nn.Sequential( + nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList() + self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) + for i in range(n_blocks-1): + self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for i in range(self.n_blocks): + x = self.conv2[i](x) + return x + + +class Encoder(nn.Module): + def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): + super(Encoder, self).__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for i in range(self.n_encoders): + self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x): + concat_tensors = [] + x = self.bn(x) + for i in range(self.n_encoders): + _, x = self.layers[i](x) + concat_tensors.append(_) + return x, concat_tensors + + +class Intermediate(nn.Module): + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super(Intermediate, self).__init__() + self.n_inters = n_inters + self.layers = nn.ModuleList() + self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) + for i in range(self.n_inters-1): + self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) + + def forward(self, x): + for i in range(self.n_inters): + x = self.layers[i](x) + return x + + +class Decoder(nn.Module): + def __init__(self, 
in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super(Decoder, self).__init__() + self.layers = nn.ModuleList() + self.n_decoders = n_decoders + for i in range(self.n_decoders): + out_channels = in_channels // 2 + self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) + in_channels = out_channels + + def forward(self, x, concat_tensors): + for i in range(self.n_decoders): + x = self.layers[i](x, concat_tensors[-1-i]) + return x + + +class TimbreFilter(nn.Module): + def __init__(self, latent_rep_channels): + super(TimbreFilter, self).__init__() + self.layers = nn.ModuleList() + for latent_rep in latent_rep_channels: + self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) + + def forward(self, x_tensors): + out_tensors = [] + for i, layer in enumerate(self.layers): + out_tensors.append(layer(x_tensors[i])) + return out_tensors + + +class DeepUnet(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + concat_tensors = self.tf(concat_tensors) + x = self.decoder(x, concat_tensors) + return x + + +class DeepUnet0(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet0, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x diff --git a/modules/F0Predictor/rmvpe/inference.py b/modules/F0Predictor/rmvpe/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..02d21881e5ccbf969759f4ef8030abce3083ce8c --- /dev/null +++ b/modules/F0Predictor/rmvpe/inference.py @@ -0,0 +1,57 @@ +import torch +import torch.nn.functional as F +from torchaudio.transforms import Resample + +from .constants import * # noqa: F403 +from .model import E2E0 +from .spec import MelSpectrogram +from .utils import to_local_average_cents, to_viterbi_cents + + +class RMVPE: + def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160): + self.resample_kernel = {} + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + model = E2E0(4, 1, (2, 2)) + ckpt = torch.load(model_path, map_location=torch.device(self.device)) + model.load_state_dict(ckpt['model']) + model = model.to(dtype).to(self.device) + model.eval() + self.model = model + self.dtype = dtype + self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 + self.resample_kernel = {} + + def mel2hidden(self, mel): + with torch.no_grad(): + n_frames = mel.shape[-1] + mel 
= F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') + hidden = self.model(mel) + return hidden[:, :n_frames] + + def decode(self, hidden, thred=0.03, use_viterbi=False): + if use_viterbi: + cents_pred = to_viterbi_cents(hidden, thred=thred) + else: + cents_pred = to_local_average_cents(hidden, thred=thred) + f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device) + return f0 + + def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False): + audio = audio.unsqueeze(0).to(self.dtype).to(self.device) + if sample_rate == 16000: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) + self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + mel_extractor = self.mel_extractor.to(self.device) + mel = mel_extractor(audio_res, center=True).to(self.dtype) + hidden = self.mel2hidden(mel) + f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi) + return f0 diff --git a/modules/F0Predictor/rmvpe/model.py b/modules/F0Predictor/rmvpe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b6b643b113a0eee9a9142016c15444273002c5 --- /dev/null +++ b/modules/F0Predictor/rmvpe/model.py @@ -0,0 +1,67 @@ +from torch import nn + +from .constants import * # noqa: F403 +from .deepunet import DeepUnet, DeepUnet0 +from .seq import BiGRU +from .spec import MelSpectrogram + + +class E2E(nn.Module): + def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E, self).__init__() + self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 + self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 + nn.Linear(512, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, x): + mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + # x = self.fc(x) + hidden_vec = 0 + if len(self.fc) == 4: + for i in range(len(self.fc)): + x = self.fc[i](x) + if i == 0: + hidden_vec = x + return hidden_vec, x + + +class E2E0(nn.Module): + def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E0, self).__init__() + self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 + nn.Linear(512, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, mel): + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + return x diff --git 
a/modules/F0Predictor/rmvpe/seq.py b/modules/F0Predictor/rmvpe/seq.py new file mode 100644 index 0000000000000000000000000000000000000000..0d48e49d72e14d34f048ca0b5824ea1f335e9a0d --- /dev/null +++ b/modules/F0Predictor/rmvpe/seq.py @@ -0,0 +1,20 @@ +import torch.nn as nn + + +class BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiGRU, self).__init__() + self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.gru(x)[0] + + +class BiLSTM(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiLSTM, self).__init__() + self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.lstm(x)[0] + diff --git a/modules/F0Predictor/rmvpe/spec.py b/modules/F0Predictor/rmvpe/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..349d05e4541ccad31cbbb24372a89cda7c0aacc0 --- /dev/null +++ b/modules/F0Predictor/rmvpe/spec.py @@ -0,0 +1,67 @@ +import numpy as np +import torch +import torch.nn.functional as F +from librosa.filters import mel + + +class MelSpectrogram(torch.nn.Module): + def __init__( + self, + n_mel_channels, + sampling_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp = 1e-5 + ): + super().__init__() + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.n_fft = win_length if n_fft is None else n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + + keyshift_key = str(keyshift)+'_'+str(audio.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) + + fft = torch.stft( + audio, + n_fft=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window=self.hann_window[keyshift_key], + center=center, + return_complex=True) + magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size-resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + + mel_output = torch.matmul(self.mel_basis, magnitude) + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + return log_mel_spec \ No newline at end of file diff --git a/modules/F0Predictor/rmvpe/utils.py b/modules/F0Predictor/rmvpe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4395255f8608da2bce0b1f15d6bd2b2bd02c1fe7 --- /dev/null +++ b/modules/F0Predictor/rmvpe/utils.py @@ -0,0 +1,107 @@ +import sys +from functools import reduce + +import librosa +import numpy as np +import torch +from torch.nn.modules.module import _addindent + +from .constants import * # noqa: F403 + + +def cycle(iterable): + while True: + for item in 
iterable: + yield item + + +def summary(model, file=sys.stdout): + def repr(model): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = model.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + total_params = 0 + for key, module in model._modules.items(): + mod_str, num_params = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + total_params += num_params + lines = extra_lines + child_lines + + for name, p in model._parameters.items(): + if hasattr(p, 'shape'): + total_params += reduce(lambda x, y: x * y, p.shape) + + main_str = model._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + if file is sys.stdout: + main_str += ', \033[92m{:,}\033[0m params'.format(total_params) + else: + main_str += ', {:,} params'.format(total_params) + return main_str, total_params + + string, count = repr(model) + if file is not None: + if isinstance(file, str): + file = open(file, 'w') + print(string, file=file) + file.flush() + + return count + + +def to_local_average_cents(salience, center=None, thred=0.05): + """ + find the weighted average cents near the argmax bin + """ + + if not hasattr(to_local_average_cents, 'cents_mapping'): + # the bin number-to-cents mapping + to_local_average_cents.cents_mapping = ( + 20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405 + + if salience.ndim == 1: + if center is None: + center = int(torch.argmax(salience)) + start = max(0, center - 4) + end = min(len(salience), center + 5) + salience = salience[start:end] + product_sum = torch.sum( + salience * to_local_average_cents.cents_mapping[start:end]) + weight_sum = torch.sum(salience) + return product_sum / weight_sum if torch.max(salience) > thred else 0 + if salience.ndim == 2: + return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in + range(salience.shape[0])]).to(salience.device) + + raise Exception("label should be either 1d or 2d ndarray") + +def to_viterbi_cents(salience, thred=0.05): + # Create viterbi transition matrix + if not hasattr(to_viterbi_cents, 'transition'): + xx, yy = torch.meshgrid(range(N_CLASS), range(N_CLASS)) # noqa: F405 + transition = torch.maximum(30 - abs(xx - yy), 0) + transition = transition / transition.sum(axis=1, keepdims=True) + to_viterbi_cents.transition = transition + + # Convert to probability + prob = salience.T + prob = prob / prob.sum(axis=0) + + # Perform viterbi decoding + path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64) + + return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in + range(len(path))]).to(salience.device) + \ No newline at end of file diff --git a/modules/attentions.py b/modules/attentions.py index f9c11ca4a3acb86bf1abc04d9dcfa82a4ed4061f..f9d75bc65e45f8e27460c18e0d267605a752f013 100644 --- a/modules/attentions.py +++ b/modules/attentions.py @@ -1,18 +1,17 @@ -import copy import math -import numpy as np + import torch from torch import nn from torch.nn import functional as F import modules.commons as commons -import modules.modules as modules +from modules.DSConv import weight_norm_modules from modules.modules import LayerNorm class FFT(nn.Module): def 
__init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., - proximal_bias=False, proximal_init=True, **kwargs): + proximal_bias=False, proximal_init=True, isflow = False, **kwargs): super().__init__() self.hidden_channels = hidden_channels self.filter_channels = filter_channels @@ -22,7 +21,11 @@ class FFT(nn.Module): self.p_dropout = p_dropout self.proximal_bias = proximal_bias self.proximal_init = proximal_init - + if isflow: + cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1) + self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + self.cond_layer = weight_norm_modules(cond_layer, name='weight') + self.gin_channels = kwargs["gin_channels"] self.drop = nn.Dropout(p_dropout) self.self_attn_layers = nn.ModuleList() self.norm_layers_0 = nn.ModuleList() @@ -37,14 +40,25 @@ class FFT(nn.Module): FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) self.norm_layers_1.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask): + def forward(self, x, x_mask, g = None): """ x: decoder input h: encoder output """ + if g is not None: + g = self.cond_layer(g) + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) x = x * x_mask for i in range(self.n_layers): + if g is not None: + x = self.cond_pre(x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] + x = commons.fused_add_tanh_sigmoid_multiply( + x, + g_l, + torch.IntTensor([self.hidden_channels])) y = self.self_attn_layers[i](x, x, self_attn_mask) y = self.drop(y) x = self.norm_layers_0[i](x + y) @@ -243,7 +257,7 @@ class MultiHeadAttention(nn.Module): return ret def _get_relative_embeddings(self, relative_embeddings, length): - max_relative_position = 2 * self.window_size + 1 + 2 * self.window_size + 1 # Pad first before slice to avoid using cond ops. 
pad_length = max(length - (self.window_size + 1), 0) slice_start_position = max((self.window_size + 1) - length, 0) diff --git a/modules/commons.py b/modules/commons.py index 074888006392e956ce204d8368362dbb2cd4e304..761379da55793b7f2eca1c9ba511ec767ac1d90e 100644 --- a/modules/commons.py +++ b/modules/commons.py @@ -1,9 +1,9 @@ import math -import numpy as np + import torch -from torch import nn from torch.nn import functional as F + def slice_pitch_segments(x, ids_str, segment_size=4): ret = torch.zeros_like(x[:, :segment_size]) for i in range(x.size(0)): @@ -24,10 +24,12 @@ def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ - if classname.find("Conv") != -1: + if "Depthwise_Separable" in classname: + m.depth_conv.weight.data.normal_(mean, std) + m.point_conv.weight.data.normal_(mean, std) + elif classname.find("Conv") != -1: m.weight.data.normal_(mean, std) - def get_padding(kernel_size, dilation=1): return int((kernel_size*dilation - dilation)/2) @@ -134,12 +136,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): return acts -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - def shift_1d(x): x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] return x @@ -157,7 +153,6 @@ def generate_path(duration, mask): duration: [b, 1, t_x] mask: [b, 1, t_y, t_x] """ - device = duration.device b, _, t_y, t_x = mask.shape cum_duration = torch.cumsum(duration, -1) diff --git a/modules/enhancer.py b/modules/enhancer.py index 37676311f7d8dc4ddc2a5244dedc27b2437e04f5..a3f0dd0460ff6d6153f9277dfa90763bc03861db 100644 --- a/modules/enhancer.py +++ b/modules/enhancer.py @@ -1,10 +1,12 @@ import numpy as np import torch import torch.nn.functional as F -from vdecoder.nsf_hifigan.nvSTFT import STFT -from vdecoder.nsf_hifigan.models import load_model from torchaudio.transforms import Resample +from vdecoder.nsf_hifigan.models import load_model +from vdecoder.nsf_hifigan.nvSTFT import STFT + + class Enhancer: def __init__(self, enhancer_type, enhancer_ckpt, device=None): if device is None: diff --git a/modules/losses.py b/modules/losses.py index cd21799eccde350c3aac0bdd661baf96ed220147..494e979a60ba069114cac609bf6454a99c1019e3 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -1,7 +1,4 @@ -import torch -from torch.nn import functional as F - -import modules.commons as commons +import torch def feature_loss(fmap_r, fmap_g): diff --git a/modules/mel_processing.py b/modules/mel_processing.py index 99c5b35beb83f3b288af0fac5b49ebf2c69f062c..c21e4bffb6d9f5fd7b45a84176b3e6206f7d29db 100644 --- a/modules/mel_processing.py +++ b/modules/mel_processing.py @@ -1,16 +1,5 @@ -import math -import os -import random import torch -from torch import nn -import torch.nn.functional as F import torch.utils.data -import numpy as np -import librosa -import librosa.util as librosa_util -from librosa.util import normalize, pad_center, tiny -from scipy.signal import get_window -from scipy.io.wavfile import read from librosa.filters import mel as librosa_mel_fn MAX_WAV_VALUE = 32768.0 @@ -62,9 +51,14 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = y.squeeze(1) + + y_dtype = y.dtype + if y.dtype == torch.bfloat16: + y = y.to(torch.float32) spec = torch.stft(y, 
n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.view_as_real(spec).to(y_dtype) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return spec @@ -83,30 +77,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) - - global mel_basis, hann_window - dtype_device = str(y.dtype) + '_' + str(y.device) - fmax_dtype_device = str(fmax) + '_' + dtype_device - wnsize_dtype_device = str(win_size) + '_' + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') - y = y.squeeze(1) - - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) - + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) + spec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax) + return spec diff --git a/modules/modules.py b/modules/modules.py index 54290fd207b25e93831bd21005990ea137e6b50e..a622d4f264a8d89a62a1b549efa71f4c37eb7ca1 100644 --- a/modules/modules.py +++ b/modules/modules.py @@ -1,20 +1,24 @@ -import copy -import math -import numpy as np -import scipy import torch from torch import nn from torch.nn import functional as F -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm - +import modules.attentions as attentions import modules.commons as commons -from modules.commons import init_weights, get_padding - +from modules.commons import get_padding, init_weights +from modules.DSConv import ( + Depthwise_Separable_Conv1D, + remove_weight_norm_modules, + weight_norm_modules, +) LRELU_SLOPE = 0.1 +Conv1dModel = nn.Conv1d + +def set_Conv1dModel(use_depthwise_conv): + global Conv1dModel + Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d + class LayerNorm(nn.Module): def __init__(self, channels, eps=1e-5): @@ -44,13 +48,13 @@ class ConvReluNorm(nn.Module): self.conv_layers = nn.ModuleList() self.norm_layers = nn.ModuleList() - self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.norm_layers.append(LayerNorm(hidden_channels)) self.relu_drop = nn.Sequential( nn.ReLU(), nn.Dropout(p_dropout)) for _ in range(n_layers-1): - self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, 
padding=kernel_size//2)) + self.conv_layers.append(Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj.weight.data.zero_() @@ -66,47 +70,6 @@ class ConvReluNorm(nn.Module): return x * x_mask -class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size ** i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, - groups=channels, dilation=dilation, padding=padding - )) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask - - class WN(torch.nn.Module): def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): super(WN, self).__init__() @@ -124,14 +87,14 @@ class WN(torch.nn.Module): if gin_channels != 0: cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + self.cond_layer = weight_norm_modules(cond_layer, name='weight') for i in range(n_layers): dilation = dilation_rate ** i padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, + in_layer = Conv1dModel(hidden_channels, 2*hidden_channels, kernel_size, dilation=dilation, padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + in_layer = weight_norm_modules(in_layer, name='weight') self.in_layers.append(in_layer) # last one is not necessary @@ -141,7 +104,7 @@ class WN(torch.nn.Module): res_skip_channels = hidden_channels res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + res_skip_layer = weight_norm_modules(res_skip_layer, name='weight') self.res_skip_layers.append(res_skip_layer) def forward(self, x, x_mask, g=None, **kwargs): @@ -176,32 +139,32 @@ class WN(torch.nn.Module): def remove_weight_norm(self): if self.gin_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) + remove_weight_norm_modules(self.cond_layer) for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) + remove_weight_norm_modules(l) for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) + remove_weight_norm_modules(l) class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() self.convs1 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], 
padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))) ]) self.convs1.apply(init_weights) self.convs2 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))) ]) self.convs2.apply(init_weights) @@ -223,18 +186,18 @@ class ResBlock1(torch.nn.Module): def remove_weight_norm(self): for l in self.convs1: - remove_weight_norm(l) + remove_weight_norm_modules(l) for l in self.convs2: - remove_weight_norm(l) + remove_weight_norm_modules(l) class ResBlock2(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3)): super(ResBlock2, self).__init__() self.convs = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))) ]) self.convs.apply(init_weights) @@ -252,7 +215,7 @@ class ResBlock2(torch.nn.Module): def remove_weight_norm(self): for l in self.convs: - remove_weight_norm(l) + remove_weight_norm_modules(l) class Log(nn.Module): @@ -303,7 +266,9 @@ class ResidualCouplingLayer(nn.Module): n_layers, p_dropout=0, gin_channels=0, - mean_only=False): + mean_only=False, + wn_sharing_parameter=None + ): assert channels % 2 == 0, "channels should be divisible by 2" super().__init__() self.channels = channels @@ -315,7 +280,56 @@ class ResidualCouplingLayer(nn.Module): self.mean_only = mean_only self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels]*2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1,2]) + return x, 
logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + +class TransformerCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels = 0 + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.post.weight.data.zero_() self.post.bias.data.zero_() diff --git a/onnxexport/model_onnx.py b/onnxexport/model_onnx.py deleted file mode 100644 index e28bae95ec1e53aa05d06fc784ff86d55f228d60..0000000000000000000000000000000000000000 --- a/onnxexport/model_onnx.py +++ /dev/null @@ -1,335 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -import modules.attentions as attentions -import modules.commons as commons -import modules.modules as modules - -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm - -import utils -from modules.commons import init_weights, get_padding -from vdecoder.hifigan.models import Generator -from utils import f0_to_coarse - - -class ResidualCouplingBlock(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0): - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append( - modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, - gin_channels=gin_channels, mean_only=True)) - self.flows.append(modules.Flip()) - - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) - return x - - -class Encoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None): - # print(x.shape,x_lengths.shape) - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, 
g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - -class TextEncoder(nn.Module): - def __init__(self, - out_channels, - hidden_channels, - kernel_size, - n_layers, - gin_channels=0, - filter_channels=None, - n_heads=None, - p_dropout=None): - super().__init__() - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.gin_channels = gin_channels - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - self.f0_emb = nn.Embedding(256, hidden_channels) - - self.enc_ = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - - def forward(self, x, x_mask, f0=None, z=None): - x = x + self.f0_emb(f0).transpose(1, 2) - x = self.enc_(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + z * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ]) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class F0Decoder(nn.Module): - def __init__(self, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - spk_channels=0): - super().__init__() - self.out_channels = out_channels - self.hidden_channels = 
hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.spk_channels = spk_channels - - self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1) - self.decoder = attentions.FFT( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - self.proj = nn.Conv1d(hidden_channels, out_channels, 1) - self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1) - self.cond = nn.Conv1d(spk_channels, hidden_channels, 1) - - def forward(self, x, norm_f0, x_mask, spk_emb=None): - x = torch.detach(x) - if spk_emb is not None: - x = x + self.cond(spk_emb) - x += self.f0_prenet(norm_f0) - x = self.prenet(x) * x_mask - x = self.decoder(x * x_mask, x_mask) - x = self.proj(x) * x_mask - return x - - -class SynthesizerTrn(nn.Module): - """ - Synthesizer for Training - """ - - def __init__(self, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels, - ssl_dim, - n_speakers, - sampling_rate=44100, - **kwargs): - super().__init__() - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.gin_channels = gin_channels - self.ssl_dim = ssl_dim - self.emb_g = nn.Embedding(n_speakers, gin_channels) - - self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) - - self.enc_p = TextEncoder( - inter_channels, - hidden_channels, - filter_channels=filter_channels, - n_heads=n_heads, - n_layers=n_layers, - kernel_size=kernel_size, - p_dropout=p_dropout - ) - hps = { - "sampling_rate": sampling_rate, - "inter_channels": inter_channels, - "resblock": resblock, - "resblock_kernel_sizes": resblock_kernel_sizes, - "resblock_dilation_sizes": resblock_dilation_sizes, - "upsample_rates": upsample_rates, - "upsample_initial_channel": upsample_initial_channel, - "upsample_kernel_sizes": upsample_kernel_sizes, - "gin_channels": gin_channels, - } - self.dec = Generator(h=hps) - self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - self.f0_decoder = F0Decoder( - 1, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - spk_channels=gin_channels - ) - self.emb_uv = nn.Embedding(2, hidden_channels) - self.predict_f0 = False - - def forward(self, c, f0, mel2ph, uv, noise=None, g=None): - - decoder_inp = F.pad(c, [0, 0, 1, 0]) - mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]]) - c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H] - - c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) - g = g.unsqueeze(0) - g = self.emb_g(g).transpose(1, 2) - x_mask = 
torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) - x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) - - if self.predict_f0: - lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 - norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) - pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) - f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1) - - z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise) - z = self.flow(z_p, c_mask, g=g, reverse=True) - o = self.dec(z * c_mask, g=g, f0=f0) - return o diff --git a/diffusion/__init__.py b/pretrain/__init__.py similarity index 100% rename from diffusion/__init__.py rename to pretrain/__init__.py diff --git a/pretrain/checkpoint_best_legacy_500.pt b/pretrain/checkpoint_best_legacy_500.pt index 9a2f13fb9c7047dff746e2d5d88c0d0a5aecf643..72f47ab58564f01d5cc8b05c63bdf96d944551ff 100644 --- a/pretrain/checkpoint_best_legacy_500.pt +++ b/pretrain/checkpoint_best_legacy_500.pt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:60d936ec5a566776fc392e69ad8b630d14eb588111233fe313436e200a7b187b -size 1330114945 +oid sha256:f54b40fd2802423a5643779c4861af1e9ee9c1564dc9d32f54f20b5ffba7db96 +size 189507909 diff --git a/pretrain/meta.py b/pretrain/meta.py index cc35dd3c0dfe8436e7d635f2db507cedca75ed49..c591573a6aca9fefbb15561e76e7fcbac8e90961 100644 --- a/pretrain/meta.py +++ b/pretrain/meta.py @@ -12,9 +12,17 @@ def download_dict(): "url": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt", "output": "./pretrain/hubert-soft-0d54a1f4.pt" }, + "whisper-ppg-small": { + "url": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", + "output": "./pretrain/small.pt" + }, "whisper-ppg": { "url": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", "output": "./pretrain/medium.pt" + }, + "whisper-ppg-large": { + "url": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "output": "./pretrain/large-v2.pt" } } diff --git a/pretrain/nsf_hifigan/.gitattributes b/pretrain/nsf_hifigan/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..382c42cde8152ff21085245e6da2cff66b783f3e --- /dev/null +++ b/pretrain/nsf_hifigan/.gitattributes @@ -0,0 +1,2 @@ +model filter=lfs diff=lfs merge=lfs -text +pretrain/*.pt filter=lfs diff=lfs merge=lfs -text diff --git a/pretrain/nsf_hifigan/model b/pretrain/nsf_hifigan/model new file mode 100644 index 0000000000000000000000000000000000000000..6ff8d81f7fe19ab507232cdd35667f3ccba9893c --- /dev/null +++ b/pretrain/nsf_hifigan/model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c576b63b7ed952161b70fad34e0562ace502ce689195520d8a2a6c051de29d6 +size 56825430 diff --git a/pretrain/rmvpe.pt b/pretrain/rmvpe.pt new file mode 100644 index 0000000000000000000000000000000000000000..c70b9e1c71c721763a91e200607fb1d17494ae6c --- /dev/null +++ b/pretrain/rmvpe.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19dc1809cf4cdb0a18db93441816bc327e14e5644b72eeaae5220560c6736fe2 +size 368492925 diff --git a/requirements.txt b/requirements.txt index a441bdd77752cfe6b0af41a1a2d7a9c2ff5a279b..aa206de1199696c6846e4b74b3d378a7ff01c1aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,32 
+1,22 @@ ffmpeg-python -Flask -Flask_Cors -gradio==3.18.0 numpy -playsound pydub requests scipy sounddevice SoundFile -starlette -torch -torchaudio +torch==2.2.1 +torchaudio==2.2.1 tqdm scikit-maad praat-parselmouth -onnx -onnxsim -onnxoptimizer fairseq librosa edge-tts pyworld -tensorboard -tensorboardX transformers pyyaml pynvml ffmpeg torchcrepe - +faiss-cpu diff --git a/utils.py b/utils.py index c35561562ebe704e15804830101345620a9934d5..95b6d8882867a81bc638237957dd3141b7bc1210 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,4 @@ +import argparse import glob import json import logging @@ -5,11 +6,15 @@ import os import re import subprocess import sys +import traceback +from multiprocessing import cpu_count +import faiss import librosa import numpy as np import torch from scipy.io.wavfile import read +from sklearn.cluster import MiniBatchKMeans from torch.nn import functional as F MATPLOTLIB_FLAG = False @@ -38,7 +43,6 @@ def normalize_f0(f0, x_mask, uv, random_scale=True): if torch.isnan(f0_norm).any(): exit(0) return f0_norm * x_mask - def plot_data_to_numpy(x, y): global MATPLOTLIB_FLAG if not MATPLOTLIB_FLAG: @@ -61,125 +65,20 @@ def plot_data_to_numpy(x, y): plt.close() return data -def interpolate_f0(f0): - ''' - 对F0进行插值处理 - ''' - - data = np.reshape(f0, (f0.size, 1)) - - vuv_vector = np.zeros((data.size, 1), dtype=np.float32) - vuv_vector[data > 0.0] = 1.0 - vuv_vector[data <= 0.0] = 0.0 - - ip_data = data - - frame_number = data.size - last_value = 0.0 - for i in range(frame_number): - if data[i] <= 0.0: - j = i + 1 - for j in range(i + 1, frame_number): - if data[j] > 0.0: - break - if j < frame_number - 1: - if last_value > 0.0: - step = (data[j] - data[i - 1]) / float(j - i) - for k in range(i, j): - ip_data[k] = data[i - 1] + step * (k - i + 1) - else: - for k in range(i, j): - ip_data[k] = data[j] - else: - for k in range(i, frame_number): - ip_data[k] = last_value - else: - ip_data[i] = data[i] - last_value = data[i] - - return ip_data[:,0], vuv_vector[:,0] - -def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): - import parselmouth - x = wav_numpy - if p_len is None: - p_len = x.shape[0]//hop_length - else: - assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" - time_step = hop_length / sampling_rate * 1000 - f0_min = 50 - f0_max = 1100 - f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac( - time_step=time_step / 1000, voicing_threshold=0.6, - pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] - - pad_size=(p_len - len(f0) + 1) // 2 - if(pad_size>0 or p_len - len(f0) - pad_size>0): - f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') - return f0 - -def resize_f0(x, target_len): - source = np.array(x) - source[source<0.001] = np.nan - target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source) - res = np.nan_to_num(target) - return res - -def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): - import pyworld - if p_len is None: - p_len = wav_numpy.shape[0]//hop_length - f0, t = pyworld.dio( - wav_numpy.astype(np.double), - fs=sampling_rate, - f0_ceil=800, - frame_period=1000 * hop_length / sampling_rate, - ) - f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate) - for index, pitch in enumerate(f0): - f0[index] = round(pitch, 1) - return resize_f0(f0, p_len) def f0_to_coarse(f0): - is_torch = isinstance(f0, torch.Tensor) - f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * 
np.log(1 + f0 / 700) - f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 - - f0_mel[f0_mel <= 1] = 1 - f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 - f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int) - assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) + f0_mel = 1127 * (1 + f0 / 700).log() + a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) + b = f0_mel_min * a - 1. + f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel) + # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1)) + f0_coarse = torch.round(f0_mel).long() + f0_coarse = f0_coarse * (f0_coarse > 0) + f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) + f0_coarse = f0_coarse * (f0_coarse < f0_bin) + f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1)) return f0_coarse -def get_hubert_model(): - vec_path = "hubert/checkpoint_best_legacy_500.pt" - print("load model(s) from {}".format(vec_path)) - from fairseq import checkpoint_utils - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [vec_path], - suffix="", - ) - model = models[0] - model.eval() - return model - -def get_hubert_content(hmodel, wav_16k_tensor): - feats = wav_16k_tensor - if feats.dim() == 2: # double channels - feats = feats.mean(-1) - assert feats.dim() == 1, feats.dim() - feats = feats.view(1, -1) - padding_mask = torch.BoolTensor(feats.shape).fill_(False) - inputs = { - "source": feats.to(wav_16k_tensor.device), - "padding_mask": padding_mask.to(wav_16k_tensor.device), - "output_layer": 9, # layer 9 - } - with torch.no_grad(): - logits = hmodel.extract_features(**inputs) - feats = hmodel.final_proj(logits[0]) - return feats.transpose(1, 2) - def get_content(cmodel, y): with torch.no_grad(): c = cmodel.extract_features(y.squeeze(1))[0] @@ -198,7 +97,13 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs): f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate) elif f0_predictor == "dio": from modules.F0Predictor.DioF0Predictor import DioF0Predictor - f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate) + f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate) + elif f0_predictor == "rmvpe": + from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor + f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) + elif f0_predictor == "fcpe": + from modules.F0Predictor.FCPEF0Predictor import FCPEF0Predictor + f0_predictor_object = FCPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) else: raise Exception("Unknown f0 predictor") return f0_predictor_object @@ -240,6 +145,9 @@ def get_speech_encoder(speech_encoder,device=None,**kargs): elif speech_encoder == "whisper-ppg-large": from vencoder.WhisperPPGLarge import WhisperPPGLarge speech_encoder_object = WhisperPPGLarge(device = device) + elif speech_encoder == "wavlmbase+": + from vencoder.WavLMBasePlus import WavLMBasePlus + speech_encoder_object = WavLMBasePlus(device = device) else: raise Exception("Unknown speech encoder") return speech_encoder_object @@ -252,6 +160,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: 
optimizer.load_state_dict(checkpoint_dict['optimizer']) saved_state_dict = checkpoint_dict['model'] + model = model.to(list(saved_state_dict.values())[0].dtype) if hasattr(model, 'module'): state_dict = model.module.state_dict() else: @@ -263,10 +172,11 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False # print("load", k) new_state_dict[k] = saved_state_dict[k] assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) - except: - print("error, %s is not in the checkpoint" % k) - logger.info("%s is not in the checkpoint" % k) - new_state_dict[k] = v + except Exception: + if "enc_q" not in k or "emb_g" not in k: + print("%s is not in the checkpoint,please check your checkpoint.If you're using pretrain model,just ignore this warning." % k) + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v if hasattr(model, 'module'): model.module.load_state_dict(new_state_dict) else: @@ -276,6 +186,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False checkpoint_path, iteration)) return model, optimizer, learning_rate, iteration + def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): logger.info("Saving model and optimizer state at iteration {} to {}".format( iteration, checkpoint_path)) @@ -298,15 +209,20 @@ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_tim False -> lexicographically delete ckpts """ ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))] - name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1))) - time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))) + def name_key(_f): + return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) + def time_key(_f): + return os.path.getmtime(os.path.join(path_to_models, _f)) sort_key = time_key if sort_by_time else name_key - x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key) + def x_sorted(_x): + return sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], key=sort_key) to_del = [os.path.join(path_to_models, fn) for fn in (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])] - del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}") - del_routine = lambda x: [os.remove(x), del_info(x)] - rs = [del_routine(fn) for fn in to_del] + def del_info(fn): + return logger.info(f".. 
Free up space by deleting ckpt {fn}") + def del_routine(x): + return [os.remove(x), del_info(x)] + [del_routine(fn) for fn in to_del] def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): for k, v in scalars.items(): @@ -393,11 +309,52 @@ def load_filepaths_and_text(filename, split="|"): return filepaths_and_text -def get_hparams_from_file(config_path): - with open(config_path, "r") as f: +def get_hparams(init=True): + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default="./configs/config.json", + help='JSON file for configuration') + parser.add_argument('-m', '--model', type=str, required=True, + help='Model name') + + args = parser.parse_args() + model_dir = os.path.join("./logs", args.model) + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + config_path = args.config + config_save_path = os.path.join(model_dir, "config.json") + if init: + with open(config_path, "r") as f: + data = f.read() + with open(config_save_path, "w") as f: + f.write(data) + else: + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: data = f.read() config = json.loads(data) + hparams =HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path, infer_mode = False): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + hparams =HParams(**config) if not infer_mode else InferHParams(**config) return hparams @@ -436,7 +393,13 @@ def get_logger(model_dir, filename="train.log"): return logger -def repeat_expand_2d(content, target_len): +def repeat_expand_2d(content, target_len, mode = 'left'): + # content : [h, t] + return repeat_expand_2d_left(content, target_len) if mode == 'left' else repeat_expand_2d_other(content, target_len, mode) + + + +def repeat_expand_2d_left(content, target_len): # content : [h, t] src_len = content.shape[-1] @@ -453,7 +416,27 @@ def repeat_expand_2d(content, target_len): return target +# mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area' +def repeat_expand_2d_other(content, target_len, mode = 'nearest'): + # content : [h, t] + content = content[None,:,:] + target = F.interpolate(content,size=target_len,mode=mode)[0] + return target + +def mix_model(model_paths,mix_rate,mode): + mix_rate = torch.FloatTensor(mix_rate)/100 + model_tem = torch.load(model_paths[0]) + models = [torch.load(path)["model"] for path in model_paths] + if mode == 0: + mix_rate = F.softmax(mix_rate,dim=0) + for k in model_tem["model"].keys(): + model_tem["model"][k] = torch.zeros_like(model_tem["model"][k]) + for i,model in enumerate(models): + model_tem["model"][k] += model[k]*mix_rate[i] + torch.save(model_tem,os.path.join(os.path.curdir,"output.pth")) + return os.path.join(os.path.curdir,"output.pth") + def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出音频,rate是2的占比 from RVC # print(data1.max(),data2.max()) rms1 = librosa.feature.rms( @@ -475,6 +458,58 @@ def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出 ) return data2 +def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI + n_cpu = cpu_count() + print("The feature index is constructing.") + exp_dir = 
os.path.join(root_dir,spk_name) + listdir_res = [] + for file in os.listdir(exp_dir): + if ".wav.soft.pt" in file: + listdir_res.append(os.path.join(exp_dir,file)) + if len(listdir_res) == 0: + raise Exception("You need to run preprocess_hubert_f0.py!") + npys = [] + for name in sorted(listdir_res): + phone = torch.load(name)[0].transpose(-1,-2).numpy() + npys.append(phone) + big_npy = np.concatenate(npys, 0) + big_npy_idx = np.arange(big_npy.shape[0]) + np.random.shuffle(big_npy_idx) + big_npy = big_npy[big_npy_idx] + if big_npy.shape[0] > 2e5: + # if(1): + info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0] + print(info) + try: + big_npy = ( + MiniBatchKMeans( + n_clusters=10000, + verbose=True, + batch_size=256 * n_cpu, + compute_labels=False, + init="random", + ) + .fit(big_npy) + .cluster_centers_ + ) + except Exception: + info = traceback.format_exc() + print(info) + n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) + index = faiss.index_factory(big_npy.shape[1] , "IVF%s,Flat" % n_ivf) + index_ivf = faiss.extract_index_ivf(index) # + index_ivf.nprobe = 1 + index.train(big_npy) + batch_size_add = 8192 + for i in range(0, big_npy.shape[0], batch_size_add): + index.add(big_npy[i : i + batch_size_add]) + # faiss.write_index( + # index, + # f"added_{spk_name}.index" + # ) + print("Successfully build index") + return index + class HParams(): def __init__(self, **kwargs): @@ -510,6 +545,18 @@ class HParams(): def get(self,index): return self.__dict__.get(index) + +class InferHParams(HParams): + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = InferHParams(**v) + self[k] = v + + def __getattr__(self,index): + return self.get(index) + + class Volume_Extractor: def __init__(self, hop_size = 512): self.hop_size = hop_size @@ -520,6 +567,6 @@ class Volume_Extractor: n_frames = int(audio.size(-1) // self.hop_size) audio2 = audio ** 2 audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect') - volume = torch.FloatTensor([torch.mean(audio2[:,int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)]) + volume = torch.nn.functional.unfold(audio2[:,None,None,:],(1,self.hop_size),stride=self.hop_size)[:,:,:n_frames].mean(dim=1)[0] volume = torch.sqrt(volume) return volume diff --git a/vdecoder/hifigan/models.py b/vdecoder/hifigan/models.py index 9747301f350bb269e62601017fe4633ce271b27e..107553368ff1798f72df21c6d5a965260f5a60fd 100644 --- a/vdecoder/hifigan/models.py +++ b/vdecoder/hifigan/models.py @@ -1,13 +1,15 @@ -import os import json -from .env import AttrDict +import os + import numpy as np import torch -import torch.nn.functional as F import torch.nn as nn -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from .utils import init_weights, get_padding +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict +from .utils import get_padding, init_weights LRELU_SLOPE = 0.1 @@ -126,6 +128,7 @@ class SineGen(torch.nn.Module): self.sampling_rate = samp_rate self.voiced_threshold = voiced_threshold self.flag_for_pulse = flag_for_pulse + self.onnx = False def _f02uv(self, f0): # generate uv signal @@ -191,37 +194,81 @@ class SineGen(torch.nn.Module): sines = torch.cos(i_phase * 2 * np.pi) return sines - def 
forward(self, f0): + def forward(self, f0, upp=None): """ sine_tensor, uv = forward(f0) input F0: tensor(batchsize=1, length, dim=1) f0 for unvoiced steps should be 0 output sine_tensor: tensor(batchsize=1, length, dim) output uv: tensor(batchsize=1, length, 1) """ - with torch.no_grad(): - f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, - device=f0.device) - # fundamental component - fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + if self.onnx: + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化 + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + else: + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) - # generate sine waveforms - sine_waves = self._f02sine(fn) * self.sine_amp + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp - # generate uv signal - # uv = torch.ones(f0.shape) - # uv = uv * (f0 > self.voiced_threshold) - uv = self._f02uv(f0) + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) - # noise: for unvoiced should be similar to sine_amp - # std = self.sine_amp/3 -> max value ~ self.sine_amp - # . for voiced regions is self.noise_std - noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 - noise = noise_amp * torch.randn_like(sine_waves) + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) - # first: set the unvoiced part to 0 by uv - # then: additive noise - sine_waves = sine_waves * uv + noise - return sine_waves, uv, noise + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise class SourceModuleHnNSF(torch.nn.Module): @@ -257,7 +304,7 @@ class SourceModuleHnNSF(torch.nn.Module): self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) self.l_tanh = torch.nn.Tanh() - def forward(self, x): + def forward(self, x, upp=None): """ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) F0_sampled (batchsize, length, 1) @@ -265,8 +312,8 @@ class SourceModuleHnNSF(torch.nn.Module): noise_source (batchsize, length 1) """ # source for harmonic branch - sine_wavs, uv, _ = self.l_sin_gen(x) - sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) # source for noise branch, in the same shape as uv noise = torch.randn_like(uv) * self.sine_amp / 3 @@ -292,11 +339,11 @@ class Generator(torch.nn.Module): c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) self.ups.append(weight_norm( ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), - k, u, padding=(k - u) // 2))) + k, u, padding=(k - u +1 ) // 2))) if i + 1 < len(h["upsample_rates"]): # stride_f0 = np.prod(h["upsample_rates"][i + 1:]) self.noise_convs.append(Conv1d( - 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) else: self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) self.resblocks = nn.ModuleList() @@ -309,12 +356,19 @@ class Generator(torch.nn.Module): self.ups.apply(init_weights) self.conv_post.apply(init_weights) self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + self.upp = np.prod(h["upsample_rates"]) + self.onnx = False + + def OnnxExport(self): + self.onnx = True + self.m_source.l_sin_gen.onnx = True def forward(self, x, f0, g=None): # print(1,x.shape,f0.shape,f0[:, None].shape) - f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + if not self.onnx: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t # print(2,f0.shape) - har_source, noi_source, uv = self.m_source(f0) + har_source, noi_source, uv = self.m_source(f0, self.upp) har_source = har_source.transpose(1, 2) x = self.conv_pre(x) x = x + self.cond(g) @@ -353,7 +407,7 @@ class DiscriminatorP(torch.nn.Module): def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super(DiscriminatorP, self).__init__() self.period = period - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), @@ -412,7 +466,7 @@ class MultiPeriodDiscriminator(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False 
else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv1d(1, 128, 15, 1, padding=7)), norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), diff --git a/vdecoder/hifigan/nvSTFT.py b/vdecoder/hifigan/nvSTFT.py index 88597d62a505715091f9ba62d38bf0a85a31b95a..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 100644 --- a/vdecoder/hifigan/nvSTFT.py +++ b/vdecoder/hifigan/nvSTFT.py @@ -1,15 +1,13 @@ -import math import os -os.environ["LRU_CACHE_CAPACITY"] = "3" -import random + +import librosa +import numpy as np +import soundfile as sf import torch import torch.utils.data -import numpy as np -import librosa -from librosa.util import normalize from librosa.filters import mel as librosa_mel_fn -from scipy.io.wavfile import read -import soundfile as sf + +os.environ["LRU_CACHE_CAPACITY"] = "3" def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): sampling_rate = None diff --git a/vdecoder/hifigan/utils.py b/vdecoder/hifigan/utils.py index 9c93c996d3cc73c30d71c1fc47056e4230f35c0f..fba57ba37a4739ef234ae25ee6dcc1feebf2fd22 100644 --- a/vdecoder/hifigan/utils.py +++ b/vdecoder/hifigan/utils.py @@ -1,22 +1,8 @@ import glob import os -import matplotlib + import torch from torch.nn.utils import weight_norm -# matplotlib.use("Agg") -import matplotlib.pylab as plt - - -def plot_spectrogram(spectrogram): - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') - plt.colorbar(im, ax=ax) - - fig.canvas.draw() - plt.close() - - return fig def init_weights(m, mean=0.0, std=0.01): diff --git a/vdecoder/hifiganwithsnake/alias/__init__.py b/vdecoder/hifiganwithsnake/alias/__init__.py index a2318b63198250856809c0cb46210a4147b829bc..be97a33248ae6378c6736586774abda11cfbdeba 100644 --- a/vdecoder/hifiganwithsnake/alias/__init__.py +++ b/vdecoder/hifiganwithsnake/alias/__init__.py @@ -1,6 +1,6 @@ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. 
-from .filter import * -from .resample import * -from .act import * \ No newline at end of file +from .act import * # noqa: F403 +from .filter import * # noqa: F403 +from .resample import * # noqa: F403 diff --git a/vdecoder/hifiganwithsnake/alias/act.py b/vdecoder/hifiganwithsnake/alias/act.py index 308344fb6ccbc39317c584a3ee1fb2f29084678e..e46b3467b73b90df51c1d19032b90d26595aca6e 100644 --- a/vdecoder/hifiganwithsnake/alias/act.py +++ b/vdecoder/hifiganwithsnake/alias/act.py @@ -4,10 +4,10 @@ import torch import torch.nn as nn import torch.nn.functional as F - -from torch import sin, pow +from torch import pow, sin from torch.nn import Parameter -from .resample import UpSample1d, DownSample1d + +from .resample import DownSample1d, UpSample1d class Activation1d(nn.Module): @@ -112,17 +112,18 @@ class SnakeAlias(nn.Module): up_ratio: int = 2, down_ratio: int = 2, up_kernel_size: int = 12, - down_kernel_size: int = 12): + down_kernel_size: int = 12, + C = None): super().__init__() self.up_ratio = up_ratio self.down_ratio = down_ratio self.act = SnakeBeta(channels, alpha_logscale=True) - self.upsample = UpSample1d(up_ratio, up_kernel_size) - self.downsample = DownSample1d(down_ratio, down_kernel_size) + self.upsample = UpSample1d(up_ratio, up_kernel_size, C) + self.downsample = DownSample1d(down_ratio, down_kernel_size, C) # x: [B,C,T] - def forward(self, x): - x = self.upsample(x) + def forward(self, x, C=None): + x = self.upsample(x, C) x = self.act(x) x = self.downsample(x) diff --git a/vdecoder/hifiganwithsnake/alias/filter.py b/vdecoder/hifiganwithsnake/alias/filter.py index 7ad6ea87c1f10ddd94c544037791d7a4634d5ae1..3942eb3ae547a2f500d5c47defdd70cd29ea4655 100644 --- a/vdecoder/hifiganwithsnake/alias/filter.py +++ b/vdecoder/hifiganwithsnake/alias/filter.py @@ -1,10 +1,11 @@ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 # LICENSE is in incl_licenses directory. +import math + import torch import torch.nn as nn import torch.nn.functional as F -import math if 'sinc' in dir(torch): sinc = torch.sinc @@ -64,7 +65,8 @@ class LowPassFilter1d(nn.Module): stride: int = 1, padding: bool = True, padding_mode: str = 'replicate', - kernel_size: int = 12): + kernel_size: int = 12, + C=None): # kernel_size should be even number for stylegan3 setup, # in this implementation, odd number is also possible. 
super().__init__() @@ -81,15 +83,28 @@ class LowPassFilter1d(nn.Module): self.padding_mode = padding_mode filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) self.register_buffer("filter", filter) + self.conv1d_block = None + if C is not None: + self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),] + self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1)) + self.conv1d_block[0].requires_grad_(False) #input [B, C, T] def forward(self, x): - _, C, _ = x.shape + if self.conv1d_block[0].weight.device != x.device: + self.conv1d_block[0] = self.conv1d_block[0].to(x.device) + if self.conv1d_block is None: + _, C, _ = x.shape - if self.padding: - x = F.pad(x, (self.pad_left, self.pad_right), - mode=self.padding_mode) - out = F.conv1d(x, self.filter.expand(C, -1, -1), - stride=self.stride, groups=C) + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + else: + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = self.conv1d_block[0](x) return out \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/alias/resample.py b/vdecoder/hifiganwithsnake/alias/resample.py index 750e6c3402cc5ac939c4b9d075246562e0e1d1a7..a364403f0977bc8bcffbb4764081e4bd3619467a 100644 --- a/vdecoder/hifiganwithsnake/alias/resample.py +++ b/vdecoder/hifiganwithsnake/alias/resample.py @@ -3,12 +3,12 @@ import torch.nn as nn from torch.nn import functional as F -from .filter import LowPassFilter1d -from .filter import kaiser_sinc_filter1d + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d class UpSample1d(nn.Module): - def __init__(self, ratio=2, kernel_size=None): + def __init__(self, ratio=2, kernel_size=None, C=None): super().__init__() self.ratio = ratio self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size @@ -20,28 +20,51 @@ class UpSample1d(nn.Module): half_width=0.6 / ratio, kernel_size=self.kernel_size) self.register_buffer("filter", filter) + self.conv_transpose1d_block = None + if C is not None: + self.conv_transpose1d_block = [nn.ConvTranspose1d(C, + C, + kernel_size=self.kernel_size, + stride=self.stride, + groups=C, + bias=False + ),] + self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone()) + self.conv_transpose1d_block[0].requires_grad_(False) + + # x: [B, C, T] - def forward(self, x): - _, C, _ = x.shape - - x = F.pad(x, (self.pad, self.pad), mode='replicate') - x = self.ratio * F.conv_transpose1d( - x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) - x = x[..., self.pad_left:-self.pad_right] - + def forward(self, x, C=None): + if self.conv_transpose1d_block[0].weight.device != x.device: + self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device) + if self.conv_transpose1d_block is None: + if C is None: + _, C, _ = x.shape + # print("snake.conv_t.in:",x.shape) + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + # print("snake.conv_t.out:",x.shape) + x = x[..., self.pad_left:-self.pad_right] + else: + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * self.conv_transpose1d_block[0](x) + x = x[..., self.pad_left:-self.pad_right] return x class DownSample1d(nn.Module): - def __init__(self, ratio=2, kernel_size=None): + def __init__(self, 
ratio=2, kernel_size=None, C=None): super().__init__() self.ratio = ratio self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, stride=ratio, - kernel_size=self.kernel_size) + kernel_size=self.kernel_size, + C=C) + def forward(self, x): xx = self.lowpass(x) diff --git a/vdecoder/hifiganwithsnake/models.py b/vdecoder/hifiganwithsnake/models.py index 664547b0c3dd24f95da0356a056658a2f28dc928..08bbda9b77b095d81ca8d8a9e5e8ebe20fa9bcfa 100644 --- a/vdecoder/hifiganwithsnake/models.py +++ b/vdecoder/hifiganwithsnake/models.py @@ -1,15 +1,18 @@ -import os import json -from .env import AttrDict +import os + import numpy as np import torch -import torch.nn.functional as F import torch.nn as nn -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from .utils import init_weights, get_padding +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + from vdecoder.hifiganwithsnake.alias.act import SnakeAlias +from .env import AttrDict +from .utils import get_padding, init_weights + LRELU_SLOPE = 0.1 @@ -33,7 +36,7 @@ def load_model(model_path, device='cuda'): class ResBlock1(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), C=None): super(ResBlock1, self).__init__() self.h = h self.convs1 = nn.ModuleList([ @@ -58,15 +61,15 @@ class ResBlock1(torch.nn.Module): self.num_layers = len(self.convs1) + len(self.convs2) self.activations = nn.ModuleList([ - SnakeAlias(channels) for _ in range(self.num_layers) + SnakeAlias(channels, C=C) for _ in range(self.num_layers) ]) - def forward(self, x): + def forward(self, x, DIM=None): acts1, acts2 = self.activations[::2], self.activations[1::2] for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): - xt = a1(x) + xt = a1(x, DIM) xt = c1(xt) - xt = a2(xt) + xt = a2(xt, DIM) xt = c2(xt) x = xt + x return x @@ -79,7 +82,7 @@ class ResBlock1(torch.nn.Module): class ResBlock2(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), C=None): super(ResBlock2, self).__init__() self.h = h self.convs = nn.ModuleList([ @@ -92,12 +95,12 @@ class ResBlock2(torch.nn.Module): self.num_layers = len(self.convs) self.activations = nn.ModuleList([ - SnakeAlias(channels) for _ in range(self.num_layers) + SnakeAlias(channels, C=C) for _ in range(self.num_layers) ]) - def forward(self, x): + def forward(self, x, DIM=None): for c,a in zip(self.convs, self.activations): - xt = a(x) + xt = a(x, DIM) xt = c(xt) x = xt + x return x @@ -138,6 +141,7 @@ class SineGen(torch.nn.Module): self.sampling_rate = samp_rate self.voiced_threshold = voiced_threshold self.flag_for_pulse = flag_for_pulse + self.onnx = False def _f02uv(self, f0): # generate uv signal @@ -203,37 +207,82 @@ class SineGen(torch.nn.Module): sines = torch.cos(i_phase * 2 * np.pi) return sines - def forward(self, f0): + def forward(self, f0, upp=None): """ sine_tensor, uv = forward(f0) input F0: tensor(batchsize=1, length, dim=1) f0 for unvoiced steps should be 0 output sine_tensor: tensor(batchsize=1, length, dim) output uv: tensor(batchsize=1, length, 1) """ - with torch.no_grad(): - f0_buf = torch.zeros(f0.shape[0], 
f0.shape[1], self.dim, - device=f0.device) - # fundamental component - fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + if self.onnx: + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化 + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + else: + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) - # generate sine waveforms - sine_waves = self._f02sine(fn) * self.sine_amp + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp - # generate uv signal - # uv = torch.ones(f0.shape) - # uv = uv * (f0 > self.voiced_threshold) - uv = self._f02uv(f0) + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) - # noise: for unvoiced should be similar to sine_amp - # std = self.sine_amp/3 -> max value ~ self.sine_amp - # . for voiced regions is self.noise_std - noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 - noise = noise_amp * torch.randn_like(sine_waves) + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) - # first: set the unvoiced part to 0 by uv - # then: additive noise - sine_waves = sine_waves * uv + noise - return sine_waves, uv, noise + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise class SourceModuleHnNSF(torch.nn.Module): @@ -269,7 +318,7 @@ class SourceModuleHnNSF(torch.nn.Module): self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) self.l_tanh = torch.nn.Tanh() - def forward(self, x): + def forward(self, x, upp=None): """ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) F0_sampled (batchsize, length, 1) @@ -277,8 +326,8 @@ class SourceModuleHnNSF(torch.nn.Module): noise_source (batchsize, length 1) """ # source for harmonic branch - sine_wavs, uv, _ = self.l_sin_gen(x) - sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) # source for noise branch, in the same shape as uv noise = torch.randn_like(uv) * self.sine_amp / 3 @@ -304,39 +353,47 @@ class Generator(torch.nn.Module): c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) self.ups.append(weight_norm( ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), - k, u, padding=(k - u) // 2))) + k, u, padding=(k - u + 1) // 2))) if i + 1 < len(h["upsample_rates"]): # stride_f0 = np.prod(h["upsample_rates"][i + 1:]) self.noise_convs.append(Conv1d( - 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+ 1) // 2)) else: self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) self.resblocks = nn.ModuleList() self.snakes = nn.ModuleList() for i in range(len(self.ups)): ch = h["upsample_initial_channel"] // (2 ** (i + 1)) - self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)))) + self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)), C = h["upsample_initial_channel"] >> i)) for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): - self.resblocks.append(resblock(h, ch, k, d)) + self.resblocks.append(resblock(h, ch, k, d, C = h["upsample_initial_channel"] >> (i + 1))) self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) self.ups.apply(init_weights) self.conv_post.apply(init_weights) - self.snake_post = SnakeAlias(ch) + self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups)) self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + self.upp = np.prod(h["upsample_rates"]) + self.onnx = False + + def OnnxExport(self): + self.onnx = True + self.m_source.l_sin_gen.onnx = True def forward(self, x, f0, g=None): # print(1,x.shape,f0.shape,f0[:, None].shape) - f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + if not self.onnx: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t # print(2,f0.shape) - har_source, noi_source, uv = self.m_source(f0) + har_source, noi_source, uv = self.m_source(f0, self.upp) har_source = har_source.transpose(1, 2) x = self.conv_pre(x) x = x + self.cond(g) # print(124,x.shape,har_source.shape) for i in range(self.num_upsamples): + # print(f"self.snakes.{i}.pre:", x.shape) x = self.snakes[i](x) - # print(3,x.shape) + # 
print(f"self.snakes.{i}.after:", x.shape) x = self.ups[i](x) x_source = self.noise_convs[i](har_source) # print(4,x_source.shape,har_source.shape,x.shape) @@ -347,6 +404,7 @@ class Generator(torch.nn.Module): xs = self.resblocks[i * self.num_kernels + j](x) else: xs += self.resblocks[i * self.num_kernels + j](x) + # print(f"self.resblocks.{i}.after:", xs.shape) x = xs / self.num_kernels x = self.snake_post(x) x = self.conv_post(x) @@ -368,7 +426,7 @@ class DiscriminatorP(torch.nn.Module): def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super(DiscriminatorP, self).__init__() self.period = period - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), @@ -427,7 +485,7 @@ class MultiPeriodDiscriminator(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv1d(1, 128, 15, 1, padding=7)), norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), diff --git a/vdecoder/hifiganwithsnake/nvSTFT.py b/vdecoder/hifiganwithsnake/nvSTFT.py index 88597d62a505715091f9ba62d38bf0a85a31b95a..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 100644 --- a/vdecoder/hifiganwithsnake/nvSTFT.py +++ b/vdecoder/hifiganwithsnake/nvSTFT.py @@ -1,15 +1,13 @@ -import math import os -os.environ["LRU_CACHE_CAPACITY"] = "3" -import random + +import librosa +import numpy as np +import soundfile as sf import torch import torch.utils.data -import numpy as np -import librosa -from librosa.util import normalize from librosa.filters import mel as librosa_mel_fn -from scipy.io.wavfile import read -import soundfile as sf + +os.environ["LRU_CACHE_CAPACITY"] = "3" def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): sampling_rate = None diff --git a/vdecoder/hifiganwithsnake/utils.py b/vdecoder/hifiganwithsnake/utils.py index 9c93c996d3cc73c30d71c1fc47056e4230f35c0f..e519e2b7ed8fe5f93266d21d727a30173699f88b 100644 --- a/vdecoder/hifiganwithsnake/utils.py +++ b/vdecoder/hifiganwithsnake/utils.py @@ -1,10 +1,10 @@ import glob import os -import matplotlib -import torch -from torch.nn.utils import weight_norm + # matplotlib.use("Agg") import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm def plot_spectrogram(spectrogram): diff --git a/vdecoder/nsf_hifigan/models.py b/vdecoder/nsf_hifigan/models.py index c2c889ec2fbd215702298ba2b7c411c6f5630d80..8a35b134d814008c3990d019d1de502ff10dd86f 100644 --- a/vdecoder/nsf_hifigan/models.py +++ b/vdecoder/nsf_hifigan/models.py @@ -1,13 +1,15 @@ -import os import json -from .env import AttrDict +import os + import numpy as np import torch -import torch.nn.functional as F import torch.nn as nn -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from .utils import init_weights, get_padding +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict 
+from .utils import get_padding, init_weights LRELU_SLOPE = 0.1 @@ -289,7 +291,7 @@ class DiscriminatorP(torch.nn.Module): def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super(DiscriminatorP, self).__init__() self.period = period - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), @@ -348,7 +350,7 @@ class MultiPeriodDiscriminator(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv1d(1, 128, 15, 1, padding=7)), norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), diff --git a/vdecoder/nsf_hifigan/nvSTFT.py b/vdecoder/nsf_hifigan/nvSTFT.py index 62bd5a008f81929054f036c81955d5d73377f772..e756cca561a45bde435f36447e6681bfa17e34aa 100644 --- a/vdecoder/nsf_hifigan/nvSTFT.py +++ b/vdecoder/nsf_hifigan/nvSTFT.py @@ -1,16 +1,14 @@ -import math import os -os.environ["LRU_CACHE_CAPACITY"] = "3" -import random -import torch -import torch.utils.data -import numpy as np + import librosa -from librosa.util import normalize -from librosa.filters import mel as librosa_mel_fn -from scipy.io.wavfile import read +import numpy as np import soundfile as sf +import torch import torch.nn.functional as F +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): sampling_rate = None diff --git a/vdecoder/nsf_hifigan/utils.py b/vdecoder/nsf_hifigan/utils.py index 84bff024f4d2e2de194b2a88ee7bbe5f0d33f67c..58d0e701d377e318fe0302743c27bdb4d6e089ec 100644 --- a/vdecoder/nsf_hifigan/utils.py +++ b/vdecoder/nsf_hifigan/utils.py @@ -1,10 +1,12 @@ import glob import os + import matplotlib +import matplotlib.pylab as plt import torch from torch.nn.utils import weight_norm + matplotlib.use("Agg") -import matplotlib.pylab as plt def plot_spectrogram(spectrogram): diff --git a/vencoder/CNHubertLarge.py b/vencoder/CNHubertLarge.py index 9db93781c36884c4096fa6fa5a12a95d385e80b8..f43694762f92c5d839d358825f157f5d1a4ff6f6 100644 --- a/vencoder/CNHubertLarge.py +++ b/vencoder/CNHubertLarge.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch from fairseq import checkpoint_utils +from vencoder.encoder import SpeechEncoder + + class CNHubertLarge(SpeechEncoder): - def __init__(self,vec_path = "pretrain/chinese-hubert-large-fairseq-ckpt.pt",device=None): + def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 1024 models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( @@ -20,7 +23,7 @@ class CNHubertLarge(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) padding_mask = torch.BoolTensor(feats.shape).fill_(False) @@ -29,5 +32,5 @@ class CNHubertLarge(SpeechEncoder): "padding_mask": 
padding_mask.to(wav.device) } with torch.no_grad(): - logits = self.model.extract_features(**inputs) + logits = self.model.extract_features(**inputs) return logits[0].transpose(1, 2) \ No newline at end of file diff --git a/vencoder/ContentVec256L12_Onnx.py b/vencoder/ContentVec256L12_Onnx.py index 9ad5085e02654fd1fcfbdad7d476bfa9b763d2c6..466e6c128b88acdfb94392662086e6752d503a27 100644 --- a/vencoder/ContentVec256L12_Onnx.py +++ b/vencoder/ContentVec256L12_Onnx.py @@ -1,25 +1,30 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec256L12_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/vec-256-layer-12.onnx",device=None): + def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() diff --git a/vencoder/ContentVec256L9.py b/vencoder/ContentVec256L9.py index b0089c789cd87cfd3b1badb2fc45cb1b88041eab..c973090dd4cdaa3d8ca07d9007c26633883c36a7 100644 --- a/vencoder/ContentVec256L9.py +++ b/vencoder/ContentVec256L9.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch from fairseq import checkpoint_utils +from vencoder.encoder import SpeechEncoder + + class ContentVec256L9(SpeechEncoder): - def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [vec_path], @@ -20,7 +23,7 @@ class ContentVec256L9(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) padding_mask = torch.BoolTensor(feats.shape).fill_(False) @@ -30,6 +33,6 @@ class ContentVec256L9(SpeechEncoder): "output_layer": 9, # layer 9 } with torch.no_grad(): - logits = self.model.extract_features(**inputs) - feats = self.model.final_proj(logits[0]) + logits = self.model.extract_features(**inputs) + feats = self.model.final_proj(logits[0]) return feats.transpose(1, 2) diff --git a/vencoder/ContentVec256L9_Onnx.py b/vencoder/ContentVec256L9_Onnx.py index fae2b928252801795b038f51451b234e007f6f03..a27e1f76655d9dc9fcc41d05d11b4a1ac5d85b90 100644 --- a/vencoder/ContentVec256L9_Onnx.py +++ b/vencoder/ContentVec256L9_Onnx.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec256L9_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/vec-256-layer-9.onnx",device=None): + def 
__init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: @@ -19,10 +22,11 @@ class ContentVec256L9_Onnx(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) + \ No newline at end of file diff --git a/vencoder/ContentVec768L12.py b/vencoder/ContentVec768L12.py index 0d1591c8843b920d5685e822354e8e6adc9a9e19..066b824b68447b5c860730c9f11b7be415068b46 100644 --- a/vencoder/ContentVec768L12.py +++ b/vencoder/ContentVec768L12.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch from fairseq import checkpoint_utils +from vencoder.encoder import SpeechEncoder + + class ContentVec768L12(SpeechEncoder): - def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( @@ -20,7 +23,7 @@ class ContentVec768L12(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) padding_mask = torch.BoolTensor(feats.shape).fill_(False) @@ -30,5 +33,5 @@ class ContentVec768L12(SpeechEncoder): "output_layer": 12, # layer 12 } with torch.no_grad(): - logits = self.model.extract_features(**inputs) - return logits[0].transpose(1, 2) \ No newline at end of file + logits = self.model.extract_features(**inputs) + return logits[0].transpose(1, 2) diff --git a/vencoder/ContentVec768L12_Onnx.py b/vencoder/ContentVec768L12_Onnx.py index 8dde0f173ed60169282128cc51eb1c200c5d82c5..e737594526fd09f19353b85c11d4c357a325af48 100644 --- a/vencoder/ContentVec768L12_Onnx.py +++ b/vencoder/ContentVec768L12_Onnx.py @@ -1,28 +1,33 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec768L12_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/vec-768-layer-12.onnx",device=None): + def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = 
feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/ContentVec768L9_Onnx.py b/vencoder/ContentVec768L9_Onnx.py index 7cdac4cd93478d3ddddb4b76dd9d9ccc5d1af2d4..3bd0f337bbf5fa261ea43adfab2377fced7c9e7c 100644 --- a/vencoder/ContentVec768L9_Onnx.py +++ b/vencoder/ContentVec768L9_Onnx.py @@ -1,28 +1,33 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec768L9_Onnx(SpeechEncoder): def __init__(self,vec_path = "pretrain/vec-768-layer-9.onnx",device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/DPHubert.py b/vencoder/DPHubert.py index 95b98b8b2e08e76139ce652bbbdb60dc42248a19..130064ff3ea5c24017be2f0faa204fc4c7dbd078 100644 --- a/vencoder/DPHubert.py +++ b/vencoder/DPHubert.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch + from vencoder.dphubert.model import wav2vec2_model +from vencoder.encoder import SpeechEncoder + class DPHubert(SpeechEncoder): - def __init__(self,vec_path = "pretrain/DPHuBERT-sp0.75.pth",device=None): + def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) if device is None: self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -17,10 +20,10 @@ class DPHubert(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() - feats = feats[None,:] + feats = feats[None, :] with torch.no_grad(): with torch.inference_mode(): - units = self.model(feats)[0] - return units.transpose(1,2) + units = self.model(feats)[0] + return units.transpose(1,2) diff --git a/vencoder/HubertSoft.py b/vencoder/HubertSoft.py index c7155e9edd8b3d898643f59111cd0c7a83067749..423c159c44f0e5cb820a911a47b71ae1478d725d 100644 --- a/vencoder/HubertSoft.py +++ b/vencoder/HubertSoft.py @@ -1,8 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch + +from vencoder.encoder import SpeechEncoder from vencoder.hubert import hubert_model + + class HubertSoft(SpeechEncoder): - def __init__(self,vec_path = "pretrain/hubert-soft-0d54a1f4.pt",device=None): + def 
__init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) hubert_soft = hubert_model.hubert_soft(vec_path) if device is None: @@ -15,10 +19,10 @@ class HubertSoft(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats[None,None,:] with torch.no_grad(): with torch.inference_mode(): - units = self.model.units(feats) - return units.transpose(1,2) + units = self.model.units(feats) + return units.transpose(1,2) diff --git a/vencoder/HubertSoft_Onnx.py b/vencoder/HubertSoft_Onnx.py index 06f10a4ca79c429ed59ab9743578128e8db506cc..038d78e8ffa0804cb63b146f8122b3f2bba2f637 100644 --- a/vencoder/HubertSoft_Onnx.py +++ b/vencoder/HubertSoft_Onnx.py @@ -1,28 +1,33 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class HubertSoft_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/hubert-soft.onnx",device=None): + def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/WavLMBasePlus.py b/vencoder/WavLMBasePlus.py new file mode 100644 index 0000000000000000000000000000000000000000..99df15be73c0c4774cea83a376f79fb68405bfa1 --- /dev/null +++ b/vencoder/WavLMBasePlus.py @@ -0,0 +1,32 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.wavlm.WavLM import WavLM, WavLMConfig + + +class WavLMBasePlus(SpeechEncoder): + def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + checkpoint = torch.load(vec_path) + self.cfg = WavLMConfig(checkpoint['cfg']) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.hidden_dim = self.cfg.encoder_embed_dim + self.model = WavLM(self.cfg) + self.model.load_state_dict(checkpoint['model']) + self.model.to(self.dev).eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + if self.cfg.normalize: + feats = torch.nn.functional.layer_norm(feats, feats.shape) + with torch.no_grad(): + with torch.inference_mode(): + units = self.model.extract_features(feats[None, :])[0] + return 
units.transpose(1, 2) diff --git a/vencoder/WhisperPPG.py b/vencoder/WhisperPPG.py index aa988b0a6d05696ea519d1652e5801302ba8a6c6..86af53e69b5f60f143a4acce0949c24812e327d1 100644 --- a/vencoder/WhisperPPG.py +++ b/vencoder/WhisperPPG.py @@ -1,12 +1,13 @@ -from vencoder.encoder import SpeechEncoder import torch -from vencoder.whisper.model import Whisper, ModelDimensions -from vencoder.whisper.audio import pad_or_trim, log_mel_spectrogram +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper class WhisperPPG(SpeechEncoder): - def __init__(self,vec_path = "pretrain/medium.pt",device=None): + def __init__(self, vec_path="pretrain/medium.pt", device=None): + super().__init__() if device is None: self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: @@ -26,5 +27,5 @@ class WhisperPPG(SpeechEncoder): mel = log_mel_spectrogram(audio).to(self.dev) with torch.no_grad(): ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() - ppg = torch.FloatTensor(ppg[:ppgln,]).to(self.dev) - return ppg[None,:,:].transpose(1, 2) + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/WhisperPPGLarge.py b/vencoder/WhisperPPGLarge.py index cab1ca646a1559c2a05b24ec38474408f27b3f08..e1d3ea212bff50c11c2711077c67800b06318e3a 100644 --- a/vencoder/WhisperPPGLarge.py +++ b/vencoder/WhisperPPGLarge.py @@ -1,12 +1,13 @@ -from vencoder.encoder import SpeechEncoder import torch -from vencoder.whisper.model import Whisper, ModelDimensions -from vencoder.whisper.audio import pad_or_trim, log_mel_spectrogram +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper class WhisperPPGLarge(SpeechEncoder): - def __init__(self,vec_path = "pretrain/large-v2.pt",device=None): + def __init__(self, vec_path="pretrain/large-v2.pt", device=None): + super().__init__() if device is None: self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: @@ -26,5 +27,5 @@ class WhisperPPGLarge(SpeechEncoder): mel = log_mel_spectrogram(audio).to(self.dev) with torch.no_grad(): ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() - ppg = torch.FloatTensor(ppg[:ppgln,]).to(self.dev) - return ppg[None,:,:].transpose(1, 2) + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/dphubert/components.py b/vencoder/dphubert/components.py index 0cc82a35581db1289a7ced76f1793e907ffbe05f..be5cc8ce28f11f4f1339578a9d2658740f103283 100644 --- a/vencoder/dphubert/components.py +++ b/vencoder/dphubert/components.py @@ -5,19 +5,19 @@ https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components """ +import math from collections import defaultdict from typing import List, Optional, Tuple -import math import torch -from torch import nn, Tensor -from torch.nn import Module, Parameter +from torch import Tensor, nn +from torch.nn import Module from .hardconcrete import HardConcrete from .pruning_utils import ( - prune_linear_layer, prune_conv1d_layer, prune_layer_norm, + prune_linear_layer, ) diff --git a/vencoder/dphubert/utils/import_huggingface_wavlm.py b/vencoder/dphubert/utils/import_huggingface_wavlm.py index 1a2ea31c14df5450298ddc5e1f56c98769144828..24a3f38ae9cc08e19010b2876b19dc9082873377 100644 --- 
a/vencoder/dphubert/utils/import_huggingface_wavlm.py +++ b/vencoder/dphubert/utils/import_huggingface_wavlm.py @@ -10,7 +10,7 @@ from typing import Any, Dict from torch.nn import Module -from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model +from ..model import Wav2Vec2Model, wav2vec2_model, wavlm_model _LG = logging.getLogger(__name__) diff --git a/vencoder/encoder.py b/vencoder/encoder.py index 2cf5678533cf16f2e81248535d35e4c3c1c5799a..9ad120da34893d64b47b8ebeeaaed1f822a2e0be 100644 --- a/vencoder/encoder.py +++ b/vencoder/encoder.py @@ -1,12 +1,13 @@ class SpeechEncoder(object): - def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None): - self.model = None #This is Model + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + self.model = None # This is Model self.hidden_dim = 768 pass - def encoder(self,wav): - ''' - input: wav:[batchsize,signal_length] + + def encoder(self, wav): + """ + input: wav:[signal_length] output: embedding:[batchsize,hidden_dim,wav_frame] - ''' - pass \ No newline at end of file + """ + pass diff --git a/vencoder/wavlm/WavLM.py b/vencoder/wavlm/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3986fdcc00033a9e8f1bfcd25df3799f40ed90 --- /dev/null +++ b/vencoder/wavlm/WavLM.py @@ -0,0 +1,741 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import logging +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from vencoder.wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GLU_Linear, + GradMultiply, + MultiheadAttention, + SamePad, + TransposeLast, + get_activation_fn, + init_bert_params, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + 
self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), 
bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + 
kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + diff --git a/vencoder/wavlm/modules.py b/vencoder/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..add4a1aa0042cbcbf5c3b28d4d72f017b507717d --- /dev/null +++ b/vencoder/wavlm/modules.py @@ -0,0 +1,828 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): 
+ """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. 
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/vencoder/whisper/audio.py b/vencoder/whisper/audio.py index 3bdb70ba9357e95ff05853dcc06437c3401ef3be..05890dc195a376181c21072eb0a8af24cf29928a 100644 --- a/vencoder/whisper/audio.py +++ b/vencoder/whisper/audio.py @@ -1,4 +1,3 @@ -import os from functools import lru_cache from typing import Union @@ -6,11 +5,10 @@ import ffmpeg import numpy as np import torch import torch.nn.functional as F +from librosa.filters import mel as librosa_mel_fn from .utils import exact_div -from librosa.filters import mel as librosa_mel_fn - # hard-coded audio hyperparameters SAMPLE_RATE = 16000 N_FFT = 400 diff --git a/vencoder/whisper/decoding.py 
b/vencoder/whisper/decoding.py index 603546d4c9ff67514d2567576935b974fe373bef..45e50b1c33c2c8f9ca6572e6175b8d6051ae02ee 100644 --- a/vencoder/whisper/decoding.py +++ b/vencoder/whisper/decoding.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -32,7 +32,7 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) if tokenizer is None: tokenizer = get_tokenizer(model.is_multilingual) if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: - raise ValueError(f"This model doesn't have language tokens so it can't perform lang id") + raise ValueError("This model doesn't have language tokens so it can't perform lang id") single = mel.ndim == 2 if single: diff --git a/vencoder/whisper/model.py b/vencoder/whisper/model.py index cb3781c17a1e78a33bf62246e5134e8512206d0d..f3de4d32cb9646964074401aad176dbef9ef2125 100644 --- a/vencoder/whisper/model.py +++ b/vencoder/whisper/model.py @@ -1,14 +1,13 @@ from dataclasses import dataclass -from typing import Dict -from typing import Iterable, Optional +from typing import Dict, Iterable, Optional import numpy as np import torch import torch.nn.functional as F -from torch import Tensor -from torch import nn +from torch import Tensor, nn -from .decoding import detect_language as detect_language_function, decode as decode_function +from .decoding import decode as decode_function +from .decoding import detect_language as detect_language_function @dataclass diff --git a/vencoder/whisper/tokenizer.py b/vencoder/whisper/tokenizer.py index a27cb359ee891590d3f793624f9f8ec768a26cc3..b15645dc7e15ca9f601413076299b362293eae6d 100644 --- a/vencoder/whisper/tokenizer.py +++ b/vencoder/whisper/tokenizer.py @@ -196,7 +196,7 @@ class Tokenizer: def language_token(self) -> int: """Returns the token id corresponding to the value of the `language` field""" if self.language is None: - raise ValueError(f"This tokenizer does not have language token configured") + raise ValueError("This tokenizer does not have language token configured") additional_tokens = dict( zip(