darksakura committed
Commit 537486f
Parent: 6f5bbf2

Upload 165 files

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. cluster/__init__.py +1 -1
  2. cluster/__pycache__/__init__.cpython-38.pyc +0 -0
  3. cluster/__pycache__/kmeans.cpython-38.pyc +0 -0
  4. cluster/km_train.py +80 -0
  5. cluster/kmeans.py +204 -0
  6. cluster/train_cluster.py +33 -37
  7. diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  8. diffusion/__pycache__/data_loaders.cpython-38.pyc +0 -0
  9. diffusion/__pycache__/diffusion.cpython-38.pyc +0 -0
  10. diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc +0 -0
  11. diffusion/__pycache__/solver.cpython-38.pyc +0 -0
  12. diffusion/__pycache__/unit2mel.cpython-38.pyc +0 -0
  13. diffusion/__pycache__/vocoder.cpython-38.pyc +0 -0
  14. diffusion/__pycache__/wavenet.cpython-38.pyc +0 -0
  15. diffusion/data_loaders.py +12 -8
  16. diffusion/diffusion.py +90 -11
  17. diffusion/diffusion_onnx.py +13 -11
  18. diffusion/dpm_solver_pytorch.py +425 -319
  19. diffusion/infer_gt_mel.py +1 -1
  20. diffusion/logger/__pycache__/__init__.cpython-38.pyc +0 -0
  21. diffusion/logger/__pycache__/saver.cpython-38.pyc +0 -0
  22. diffusion/logger/__pycache__/utils.cpython-38.pyc +0 -0
  23. diffusion/logger/saver.py +6 -11
  24. diffusion/logger/utils.py +5 -4
  25. diffusion/onnx_export.py +22 -13
  26. diffusion/solver.py +14 -9
  27. diffusion/uni_pc.py +733 -0
  28. diffusion/unit2mel.py +31 -11
  29. diffusion/vocoder.py +4 -3
  30. modules/DSConv.py +76 -0
  31. modules/F0Predictor/CrepeF0Predictor.py +4 -2
  32. modules/F0Predictor/DioF0Predictor.py +22 -34
  33. modules/F0Predictor/HarvestF0Predictor.py +21 -34
  34. modules/F0Predictor/PMF0Predictor.py +22 -33
  35. modules/F0Predictor/RMVPEF0Predictor.py +106 -0
  36. modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc +0 -0
  37. modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc +0 -0
  38. modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc +0 -0
  39. modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc +0 -0
  40. modules/F0Predictor/__pycache__/RMVPEF0Predictor.cpython-38.pyc +0 -0
  41. modules/F0Predictor/__pycache__/__init__.cpython-38.pyc +0 -0
  42. modules/F0Predictor/__pycache__/crepe.cpython-38.pyc +0 -0
  43. modules/F0Predictor/crepe.py +11 -11
  44. modules/F0Predictor/rmvpe/__init__.py +10 -0
  45. modules/F0Predictor/rmvpe/__pycache__/__init__.cpython-38.pyc +0 -0
  46. modules/F0Predictor/rmvpe/__pycache__/constants.cpython-38.pyc +0 -0
  47. modules/F0Predictor/rmvpe/__pycache__/deepunet.cpython-38.pyc +0 -0
  48. modules/F0Predictor/rmvpe/__pycache__/inference.cpython-38.pyc +0 -0
  49. modules/F0Predictor/rmvpe/__pycache__/model.cpython-38.pyc +0 -0
  50. modules/F0Predictor/rmvpe/__pycache__/seq.cpython-38.pyc +0 -0
cluster/__init__.py CHANGED
@@ -1,7 +1,7 @@
-import numpy as np
 import torch
 from sklearn.cluster import KMeans
 
+
 def get_cluster_model(ckpt_path):
     checkpoint = torch.load(ckpt_path)
     kmeans_dict = {}
cluster/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.06 kB).
 
cluster/__pycache__/kmeans.cpython-38.pyc ADDED
Binary file (6.93 kB).
 
cluster/km_train.py ADDED
@@ -0,0 +1,80 @@
+import time,pdb
+import tqdm
+from time import time as ttime
+import os
+from pathlib import Path
+import logging
+import argparse
+from cluster.kmeans import KMeansGPU
+import torch
+import numpy as np
+from sklearn.cluster import KMeans,MiniBatchKMeans
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+from time import time as ttime
+import pynvml,torch
+
+def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False, use_gpu=False):  # GPU minibatch is really poor; the library supports it, but it is not considered here
+    logger.info(f"Loading features from {in_dir}")
+    features = []
+    nums = 0
+    for path in tqdm.tqdm(in_dir.glob("*.soft.pt")):
+        # for name in os.listdir(in_dir):
+        #     path = "%s/%s" % (in_dir, name)
+        features.append(torch.load(path, map_location="cpu").squeeze(0).numpy().T)
+        # print(features[-1].shape)
+    features = np.concatenate(features, axis=0)
+    print(nums, features.nbytes / 1024**2, "MB , shape:", features.shape, features.dtype)
+    features = features.astype(np.float32)
+    logger.info(f"Clustering features of shape: {features.shape}")
+    t = time.time()
+    if use_gpu == False:
+        if use_minibatch:
+            kmeans = MiniBatchKMeans(n_clusters=n_clusters, verbose=verbose, batch_size=4096, max_iter=80).fit(features)
+        else:
+            kmeans = KMeans(n_clusters=n_clusters, verbose=verbose).fit(features)
+    else:
+        kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0, max_iter=500, tol=1e-2)
+        features = torch.from_numpy(features)  # .to(device)
+        labels = kmeans.fit_predict(features)
+
+    print(time.time() - t, "s")
+
+    x = {
+        "n_features_in_": kmeans.n_features_in_ if use_gpu == False else features.shape[0],
+        "_n_threads": kmeans._n_threads if use_gpu == False else 4,
+        "cluster_centers_": kmeans.cluster_centers_ if use_gpu == False else kmeans.centroids.cpu().numpy(),
+    }
+    print("end")
+
+    return x
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=Path, default="./dataset/44k",
+                        help='path of training data directory')
+    parser.add_argument('--output', type=Path, default="logs/44k",
+                        help='path of model output directory')
+
+    args = parser.parse_args()
+
+    checkpoint_dir = args.output
+    dataset = args.dataset
+    n_clusters = 1000
+
+    ckpt = {}
+    for spk in os.listdir(dataset):
+        if os.path.isdir(dataset/spk):
+            print(f"train kmeans for {spk}...")
+            in_dir = dataset/spk
+            x = train_cluster(in_dir, n_clusters, use_minibatch=False, verbose=False, use_gpu=True)
+            ckpt[spk] = x
+
+    checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt"
+    checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
+    torch.save(
+        ckpt,
+        checkpoint_path,
+    )
+
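The checkpoint this script saves holds, per speaker, the same three fields that get_cluster_model() in cluster/__init__.py reads back. A minimal reload sketch (the kmeans_1000.pt path and the reconstruction into a sklearn KMeans object are illustrative assumptions, not part of the commit):

import torch
from sklearn.cluster import KMeans

# Sketch: reload the per-speaker k-means checkpoint written by km_train.py.
# Path and KMeans reconstruction are assumed for illustration.
ckpt = torch.load("logs/44k/kmeans_1000.pt")
for spk, kmeans_dict in ckpt.items():
    km = KMeans(n_clusters=kmeans_dict["cluster_centers_"].shape[0])
    for key in ["n_features_in_", "_n_threads", "cluster_centers_"]:
        km.__dict__[key] = kmeans_dict[key]
    print(spk, km.cluster_centers_.shape)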
cluster/kmeans.py ADDED
@@ -0,0 +1,204 @@
+from time import time
+
+import numpy as np
+import pynvml
+import torch
+from torch.nn.functional import normalize
+
+
+# device=torch.device("cuda:0")
+def _kpp(data: torch.Tensor, k: int, sample_size: int = -1):
+    """ Picks k points in the data based on the kmeans++ method.
+
+    Parameters
+    ----------
+    data : torch.Tensor
+        Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
+        data, rank 2 multidimensional data, in which case one
+        row is one observation.
+    k : int
+        Number of samples to generate.
+    sample_size : int
+        sample data to avoid memory overflow during calculation
+
+    Returns
+    -------
+    init : ndarray
+        A 'k' by 'N' containing the initial centroids.
+
+    References
+    ----------
+    .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
+       careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
+       on Discrete Algorithms, 2007.
+    .. [2] scipy/cluster/vq.py: _kpp
+    """
+    batch_size = data.shape[0]
+    if batch_size > sample_size:
+        data = data[torch.randint(0, batch_size, [sample_size], device=data.device)]
+    dims = data.shape[1] if len(data.shape) > 1 else 1
+    init = torch.zeros((k, dims)).to(data.device)
+    r = torch.distributions.uniform.Uniform(0, 1)
+    for i in range(k):
+        if i == 0:
+            init[i, :] = data[torch.randint(data.shape[0], [1])]
+        else:
+            D2 = torch.cdist(init[:i, :][None, :], data[None, :], p=2)[0].amin(dim=0)
+            probs = D2 / torch.sum(D2)
+            cumprobs = torch.cumsum(probs, dim=0)
+            init[i, :] = data[torch.searchsorted(cumprobs, r.sample([1]).to(data.device))]
+    return init
+class KMeansGPU:
+    '''
+    Kmeans clustering algorithm implemented with PyTorch
+
+    Parameters:
+      n_clusters: int,
+        Number of clusters
+
+      max_iter: int, default: 100
+        Maximum number of iterations
+
+      tol: float, default: 0.0001
+        Tolerance
+
+      verbose: int, default: 0
+        Verbosity
+
+      mode: {'euclidean', 'cosine'}, default: 'euclidean'
+        Type of distance measure
+
+      init_method: {'random', 'point', '++'}
+        Type of initialization
+
+      minibatch: {None, int}, default: None
+        Batch size of MinibatchKmeans algorithm
+        if None perform full KMeans algorithm
+
+    Attributes:
+      centroids: torch.Tensor, shape: [n_clusters, n_features]
+        cluster centroids
+    '''
+    def __init__(self, n_clusters, max_iter=200, tol=1e-4, verbose=0, mode="euclidean", device=torch.device("cuda:0")):
+        self.n_clusters = n_clusters
+        self.max_iter = max_iter
+        self.tol = tol
+        self.verbose = verbose
+        self.mode = mode
+        self.device = device
+        pynvml.nvmlInit()
+        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index)
+        info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
+        self.minibatch = int(33e6 / self.n_clusters * info.free / 1024 / 1024 / 1024)
+        print("free_mem/GB:", info.free / 1024 / 1024 / 1024, "minibatch:", self.minibatch)
+
+    @staticmethod
+    def cos_sim(a, b):
+        """
+        Compute cosine similarity of 2 sets of vectors
+
+        Parameters:
+          a: torch.Tensor, shape: [m, n_features]
+
+          b: torch.Tensor, shape: [n, n_features]
+        """
+        return normalize(a, dim=-1) @ normalize(b, dim=-1).transpose(-2, -1)
+
+    @staticmethod
+    def euc_sim(a, b):
+        """
+        Compute euclidean similarity of 2 sets of vectors
+        Parameters:
+          a: torch.Tensor, shape: [m, n_features]
+          b: torch.Tensor, shape: [n, n_features]
+        """
+        return 2 * a @ b.transpose(-2, -1) - (a**2).sum(dim=1)[..., :, None] - (b**2).sum(dim=1)[..., None, :]
+
+    def max_sim(self, a, b):
+        """
+        Compute maximum similarity (or minimum distance) of each vector
+        in a with all of the vectors in b
+        Parameters:
+          a: torch.Tensor, shape: [m, n_features]
+          b: torch.Tensor, shape: [n, n_features]
+        """
+        if self.mode == 'cosine':
+            sim_func = self.cos_sim
+        elif self.mode == 'euclidean':
+            sim_func = self.euc_sim
+        sim = sim_func(a, b)
+        max_sim_v, max_sim_i = sim.max(dim=-1)
+        return max_sim_v, max_sim_i
+
+    def fit_predict(self, X):
+        """
+        Combination of fit() and predict() methods.
+        This is faster than calling fit() and predict() separately.
+        Parameters:
+          X: torch.Tensor, shape: [n_samples, n_features]
+          centroids: {torch.Tensor, None}, default: None
+            if given, centroids will be initialized with given tensor
+            if None, centroids will be randomly chosen from X
+        Return:
+          labels: torch.Tensor, shape: [n_samples]
+
+        mini_ = 33kk/k*remain
+        mini = min(mini_, fea_shape)
+        offset = log2(k/1000)*1.5
+        kpp_all = min(mini_*10/offset, fea_shape)
+        kpp_sample = min(mini_/12/offset, fea_shape)
+        """
+        assert isinstance(X, torch.Tensor), "input must be torch.Tensor"
+        assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point"
+        assert X.ndim == 2, "input must be a 2d tensor with shape: [n_samples, n_features]"
+        # print("verbose:%s"%self.verbose)
+
+        offset = np.power(1.5, np.log(self.n_clusters / 1000)) / np.log(2)
+        with torch.no_grad():
+            batch_size = X.shape[0]
+            # print(self.minibatch, int(self.minibatch * 10 / offset), batch_size)
+            start_time = time()
+            if self.minibatch * 10 // offset < batch_size:
+                x = X[torch.randint(0, batch_size, [int(self.minibatch * 10 / offset)])].to(self.device)
+            else:
+                x = X.to(self.device)
+            # print(x.device)
+            self.centroids = _kpp(x, self.n_clusters, min(int(self.minibatch / 12 / offset), batch_size))
+            del x
+            torch.cuda.empty_cache()
+            # self.centroids = self.centroids.to(self.device)
+            num_points_in_clusters = torch.ones(self.n_clusters, device=self.device, dtype=X.dtype)  # all ones
+            closest = None  # [3098036] # int64
+            if self.minibatch >= batch_size // 2 and self.minibatch < batch_size:
+                X = X[torch.randint(0, batch_size, [self.minibatch])].to(self.device)
+            elif self.minibatch >= batch_size:
+                X = X.to(self.device)
+            for i in range(self.max_iter):
+                iter_time = time()
+                if self.minibatch < batch_size // 2:  # the usable minibatch is too small; data must be shuttled from RAM to VRAM every iteration
+                    x = X[torch.randint(0, batch_size, [self.minibatch])].to(self.device)
+                else:  # otherwise just cache everything on the GPU
+                    x = X
+
+                closest = self.max_sim(a=x, b=self.centroids)[1].to(torch.int16)  # [3098036] # int64 # 0~999
+                matched_clusters, counts = closest.unique(return_counts=True)  # int64 # 1k
+                expanded_closest = closest[None].expand(self.n_clusters, -1)  # [1000, 3098036] # int16 # 0~999
+                mask = (expanded_closest == torch.arange(self.n_clusters, device=self.device)[:, None]).to(X.dtype)  # the comparand is int64 * 1000
+                c_grad = mask @ x / mask.sum(-1)[..., :, None]
+                c_grad[c_grad != c_grad] = 0  # remove NaNs
+                error = (c_grad - self.centroids).pow(2).sum()
+                if self.minibatch is not None:
+                    lr = 1 / num_points_in_clusters[:, None] * 0.9 + 0.1
+                else:
+                    lr = 1
+                matched_clusters = matched_clusters.long()
+                num_points_in_clusters[matched_clusters] += counts  # IndexError: tensors used as indices must be long, byte or bool tensors
+                self.centroids = self.centroids * (1 - lr) + c_grad * lr
+                if self.verbose >= 2:
+                    print('iter:', i, 'error:', error.item(), 'time spent:', round(time() - iter_time, 4))
+                if error <= self.tol:
+                    break
+
+        if self.verbose >= 1:
+            print(f'used {i+1} iterations ({round(time()-start_time, 4)}s) to cluster {batch_size} items into {self.n_clusters} clusters')
+        return closest
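A minimal usage sketch for the KMeansGPU class added here (it requires a CUDA device and pynvml in __init__; the feature tensor below is synthetic):

import torch
from cluster.kmeans import KMeansGPU

features = torch.randn(100_000, 256)  # synthetic [n_samples, n_features] floats

# __init__ sizes the minibatch from free VRAM via pynvml; fit_predict runs
# k-means++ initialisation (_kpp) and returns per-sample cluster indices.
kmeans = KMeansGPU(n_clusters=1000, mode='euclidean', max_iter=500, tol=1e-2, verbose=1)
labels = kmeans.fit_predict(features)        # shape [100_000], int16 cluster ids
centroids = kmeans.centroids.cpu().numpy()   # shape [1000, 256]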
cluster/train_cluster.py CHANGED
@@ -1,67 +1,79 @@
+import argparse
+import logging
 import os
-from glob import glob
+import time
 from pathlib import Path
-import torch
-import logging
-import argparse
-import torch
+
 import numpy as np
-from sklearn.cluster import KMeans, MiniBatchKMeans
+import torch
 import tqdm
+from kmeans import KMeansGPU
+from sklearn.cluster import KMeans, MiniBatchKMeans
+
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-import time
-import random
 
-def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False):
+def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False, use_gpu=False):  # GPU minibatch is really poor; the library supports it, but it is not considered here
+    if str(in_dir).endswith(".ipynb_checkpoints"):
+        logger.info(f"Ignore {in_dir}")
 
     logger.info(f"Loading features from {in_dir}")
    features = []
     nums = 0
     for path in tqdm.tqdm(in_dir.glob("*.soft.pt")):
-        features.append(torch.load(path).squeeze(0).numpy().T)
+        # for name in os.listdir(in_dir):
+        #     path = "%s/%s" % (in_dir, name)
+        features.append(torch.load(path, map_location="cpu").squeeze(0).numpy().T)
         # print(features[-1].shape)
     features = np.concatenate(features, axis=0)
     print(nums, features.nbytes / 1024**2, "MB , shape:", features.shape, features.dtype)
     features = features.astype(np.float32)
     logger.info(f"Clustering features of shape: {features.shape}")
     t = time.time()
-    if use_minibatch:
-        kmeans = MiniBatchKMeans(n_clusters=n_clusters, verbose=verbose, batch_size=4096, max_iter=80).fit(features)
+    if use_gpu is False:
+        if use_minibatch:
+            kmeans = MiniBatchKMeans(n_clusters=n_clusters, verbose=verbose, batch_size=4096, max_iter=80).fit(features)
+        else:
+            kmeans = KMeans(n_clusters=n_clusters, verbose=verbose).fit(features)
     else:
-        kmeans = KMeans(n_clusters=n_clusters, verbose=verbose).fit(features)
+        kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0, max_iter=500, tol=1e-2)
+        features = torch.from_numpy(features)  # .to(device)
+        kmeans.fit_predict(features)
+
     print(time.time() - t, "s")
 
     x = {
-        "n_features_in_": kmeans.n_features_in_,
-        "_n_threads": kmeans._n_threads,
-        "cluster_centers_": kmeans.cluster_centers_,
+        "n_features_in_": kmeans.n_features_in_ if use_gpu is False else features.shape[1],
+        "_n_threads": kmeans._n_threads if use_gpu is False else 4,
+        "cluster_centers_": kmeans.cluster_centers_ if use_gpu is False else kmeans.centroids.cpu().numpy(),
     }
     print("end")
 
     return x
 
-
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument('--dataset', type=Path, default="./dataset/44k",
                         help='path of training data directory')
     parser.add_argument('--output', type=Path, default="logs/44k",
                         help='path of model output directory')
+    parser.add_argument('--gpu', action='store_true', default=False,
+                        help='to use GPU')
+
 
     args = parser.parse_args()
 
     checkpoint_dir = args.output
     dataset = args.dataset
+    use_gpu = args.gpu
     n_clusters = 10000
-
+
     ckpt = {}
     for spk in os.listdir(dataset):
         if os.path.isdir(dataset/spk):
             print(f"train kmeans for {spk}...")
             in_dir = dataset/spk
-            x = train_cluster(in_dir, n_clusters, verbose=False)
+            x = train_cluster(in_dir, n_clusters, use_minibatch=False, verbose=False, use_gpu=use_gpu)
             ckpt[spk] = x
 
     checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt"
@@ -70,20 +82,4 @@ if __name__ == "__main__":
         ckpt,
         checkpoint_path,
     )
-
-
-    # import cluster
-    # for spk in tqdm.tqdm(os.listdir("dataset")):
-    #     if os.path.isdir(f"dataset/{spk}"):
-    #         print(f"start kmeans inference for {spk}...")
-    #         for feature_path in tqdm.tqdm(glob(f"dataset/{spk}/*.discrete.npy", recursive=True)):
-    #             mel_path = feature_path.replace(".discrete.npy", ".mel.npy")
-    #             mel_spectrogram = np.load(mel_path)
-    #             feature_len = mel_spectrogram.shape[-1]
-    #             c = np.load(feature_path)
-    #             c = utils.tools.repeat_expand_2d(torch.FloatTensor(c), feature_len).numpy()
-    #             feature = c.T
-    #             feature_class = cluster.get_cluster_result(feature, spk)
-    #             np.save(feature_path.replace(".discrete.npy", ".discrete_class.npy"), feature_class)
-
-
+
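With the new --gpu flag the reworked script is driven from the command line; both invocations below are illustrative (paths are the argparse defaults, and __main__ hard-codes use_minibatch=False):

python cluster/train_cluster.py --dataset ./dataset/44k --output logs/44k        # sklearn KMeans on CPU
python cluster/train_cluster.py --dataset ./dataset/44k --output logs/44k --gpu  # KMeansGPU path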
diffusion/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/__init__.cpython-38.pyc and b/diffusion/__pycache__/__init__.cpython-38.pyc differ
 
diffusion/__pycache__/data_loaders.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/data_loaders.cpython-38.pyc and b/diffusion/__pycache__/data_loaders.cpython-38.pyc differ
 
diffusion/__pycache__/diffusion.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/diffusion.cpython-38.pyc and b/diffusion/__pycache__/diffusion.cpython-38.pyc differ
 
diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc and b/diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc differ
 
diffusion/__pycache__/solver.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/solver.cpython-38.pyc and b/diffusion/__pycache__/solver.cpython-38.pyc differ
 
diffusion/__pycache__/unit2mel.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/unit2mel.cpython-38.pyc and b/diffusion/__pycache__/unit2mel.cpython-38.pyc differ
 
diffusion/__pycache__/vocoder.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/vocoder.cpython-38.pyc and b/diffusion/__pycache__/vocoder.cpython-38.pyc differ
 
diffusion/__pycache__/wavenet.cpython-38.pyc CHANGED
Binary files a/diffusion/__pycache__/wavenet.cpython-38.pyc and b/diffusion/__pycache__/wavenet.cpython-38.pyc differ
 
diffusion/data_loaders.py CHANGED
@@ -1,13 +1,14 @@
 import os
 import random
-import re
-import numpy as np
+
 import librosa
+import numpy as np
 import torch
-import random
-from utils import repeat_expand_2d
-from tqdm import tqdm
 from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from utils import repeat_expand_2d
+
 
 def traverse_dir(
     root_dir,
@@ -63,6 +64,7 @@ def get_data_loaders(args, whole_audio=False):
         spk=args.spk,
         device=args.train.cache_device,
         fp16=args.train.cache_fp16,
+        unit_interpolate_mode=args.data.unit_interpolate_mode,
         use_aug=True)
     loader_train = torch.utils.data.DataLoader(
         data_train,
@@ -81,6 +83,7 @@ def get_data_loaders(args, whole_audio=False):
         whole_audio=True,
         spk=args.spk,
         extensions=args.data.extensions,
+        unit_interpolate_mode=args.data.unit_interpolate_mode,
         n_spk=args.model.n_spk)
     loader_valid = torch.utils.data.DataLoader(
         data_valid,
@@ -107,6 +110,7 @@ class AudioDataset(Dataset):
                  device='cpu',
                  fp16=False,
                  use_aug=False,
+                 unit_interpolate_mode='left'
                  ):
         super().__init__()
 
@@ -118,6 +122,7 @@ class AudioDataset(Dataset):
         self.use_aug = use_aug
         self.data_buffer = {}
         self.pitch_aug_dict = {}
+        self.unit_interpolate_mode = unit_interpolate_mode
         # np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item()
         if load_all_data:
             print('Load all the data filelists:', filelists)
@@ -126,7 +131,6 @@ class AudioDataset(Dataset):
             with open(filelists, "r") as f:
                 self.paths = f.read().splitlines()
             for name_ext in tqdm(self.paths, total=len(self.paths)):
-                name = os.path.splitext(name_ext)[0]
                 path_audio = name_ext
                 duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate)
 
@@ -171,7 +175,7 @@ class AudioDataset(Dataset):
                 path_units = name_ext + ".soft.pt"
                 units = torch.load(path_units).to(device)
                 units = units[0]
-                units = repeat_expand_2d(units, f0.size(0)).transpose(0, 1)
+                units = repeat_expand_2d(units, f0.size(0), unit_interpolate_mode).transpose(0, 1)
 
                 if fp16:
                     mel = mel.half()
@@ -263,7 +267,7 @@ class AudioDataset(Dataset):
         path_units = name_ext + ".soft.pt"
         units = torch.load(path_units)
         units = units[0]
-        units = repeat_expand_2d(units, f0.size(0)).transpose(0, 1)
+        units = repeat_expand_2d(units, f0.size(0), self.unit_interpolate_mode).transpose(0, 1)
 
         units = units[start_frame : start_frame + units_frame_len]
 
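The new unit_interpolate_mode argument is forwarded to repeat_expand_2d (defined in utils, which is not part of this diff): it stretches a [dims, n_units] content tensor to the f0 frame count. A rough sketch of the idea, assuming the default 'left' mode behaves like nearest-neighbour repetition (this sketch is not the utils implementation):

import torch
import torch.nn.functional as F

def repeat_expand_2d_sketch(content: torch.Tensor, target_len: int, mode: str = "left"):
    # content: [dims, n_units] -> [dims, target_len]; 'left' is assumed to
    # repeat the nearest unit to the left, anything else blends linearly.
    interp_mode = "nearest" if mode == "left" else "linear"
    return F.interpolate(content[None], size=target_len, mode=interp_mode)[0]

units = torch.randn(256, 100)                                # 100 units, 256 dims
units = repeat_expand_2d_sketch(units, 437).transpose(0, 1)  # align to 437 f0 frames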
diffusion/diffusion.py CHANGED
@@ -1,10 +1,10 @@
 from collections import deque
 from functools import partial
 from inspect import isfunction
-import torch.nn.functional as F
-import librosa.sequence
+
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 from tqdm import tqdm
 
@@ -26,8 +26,10 @@
 
 
 def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
+    def repeat_noise():
+        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    def noise():
+        return torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()
 
 
@@ -67,6 +69,7 @@ class GaussianDiffusion(nn.Module):
                  max_beta=0.02,
                  spec_min=-12,
                  spec_max=2):
+
         super().__init__()
         self.denoise_fn = denoise_fn
         self.out_dims = out_dims
@@ -78,7 +81,7 @@ class GaussianDiffusion(nn.Module):
 
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
-        self.k_step = k_step
+        self.k_step = k_step if k_step > 0 and k_step < timesteps else timesteps
 
         self.noise_list = deque(maxlen=4)
 
@@ -139,6 +142,18 @@ class GaussianDiffusion(nn.Module):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
+    @torch.no_grad()
+    def p_sample_ddim(self, x, t, interval, cond):
+        """
+        Use the DDIM method from
+        """
+        a_t = extract(self.alphas_cumprod, t, x.shape)
+        a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
+
+        noise_pred = self.denoise_fn(x, t, cond=cond)
+        x_prev = a_prev.sqrt() * (x / a_t.sqrt() + (((1 - a_prev) / a_prev).sqrt() - ((1 - a_t) / a_t).sqrt()) * noise_pred)
+        return x_prev
+
     @torch.no_grad()
     def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
         b, *_, device = *x.shape, x.device
@@ -239,8 +254,12 @@ class GaussianDiffusion(nn.Module):
             x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())
 
             if method is not None and infer_speedup > 1:
-                if method == 'dpm-solver':
-                    from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
+                if method == 'dpm-solver' or method == 'dpm-solver++':
+                    from .dpm_solver_pytorch import (
+                        DPM_Solver,
+                        NoiseScheduleVP,
+                        model_wrapper,
+                    )
                     # 1. Define the noise schedule.
                     noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])
 
@@ -267,17 +286,20 @@ class GaussianDiffusion(nn.Module):
                     # (We recommend singlestep DPM-Solver for unconditional sampling)
                     # You can adjust the `steps` to balance the computation
                     # costs and the sample quality.
-                    dpm_solver = DPM_Solver(model_fn, noise_schedule)
-
+                    if method == 'dpm-solver':
+                        dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
+                    elif method == 'dpm-solver++':
+                        dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
+
                     steps = t // infer_speedup
                     if use_tqdm:
                         self.bar = tqdm(desc="sample time step", total=steps)
                     x = dpm_solver.sample(
                         x,
                         steps=steps,
-                        order=3,
+                        order=2,
                         skip_type="time_uniform",
-                        method="singlestep",
+                        method="multistep",
                     )
                     if use_tqdm:
                         self.bar.close()
@@ -298,6 +320,63 @@ class GaussianDiffusion(nn.Module):
                             x, torch.full((b,), i, device=device, dtype=torch.long),
                             infer_speedup, cond=cond
                         )
+                elif method == 'ddim':
+                    if use_tqdm:
+                        for i in tqdm(
+                                reversed(range(0, t, infer_speedup)), desc='sample time step',
+                                total=t // infer_speedup,
+                        ):
+                            x = self.p_sample_ddim(
+                                x, torch.full((b,), i, device=device, dtype=torch.long),
+                                infer_speedup, cond=cond
+                            )
+                    else:
+                        for i in reversed(range(0, t, infer_speedup)):
+                            x = self.p_sample_ddim(
+                                x, torch.full((b,), i, device=device, dtype=torch.long),
+                                infer_speedup, cond=cond
+                            )
+                elif method == 'unipc':
+                    from .uni_pc import NoiseScheduleVP, UniPC, model_wrapper
+                    # 1. Define the noise schedule.
+                    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])
+
+                    # 2. Convert your discrete-time `model` to the continuous-time
+                    #    noise prediction model. Here is an example for a diffusion model
+                    #    `model` with the noise prediction type ("noise").
+                    def my_wrapper(fn):
+                        def wrapped(x, t, **kwargs):
+                            ret = fn(x, t, **kwargs)
+                            if use_tqdm:
+                                self.bar.update(1)
+                            return ret
+
+                        return wrapped
+
+                    model_fn = model_wrapper(
+                        my_wrapper(self.denoise_fn),
+                        noise_schedule,
+                        model_type="noise",  # or "x_start" or "v" or "score"
+                        model_kwargs={"cond": cond}
+                    )
+
+                    # 3. Define uni_pc and sample by multistep UniPC.
+                    #    You can adjust the `steps` to balance the computation
+                    #    costs and the sample quality.
+                    uni_pc = UniPC(model_fn, noise_schedule, variant='bh2')
+
+                    steps = t // infer_speedup
+                    if use_tqdm:
+                        self.bar = tqdm(desc="sample time step", total=steps)
+                    x = uni_pc.sample(
+                        x,
+                        steps=steps,
+                        order=2,
+                        skip_type="time_uniform",
+                        method="multistep",
+                    )
+                    if use_tqdm:
+                        self.bar.close()
                 else:
                     raise NotImplementedError(method)
             else:
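The p_sample_ddim method added above is the deterministic DDIM update (the eta = 0 case). Writing \bar\alpha_t for alphas_cumprod[t] and \Delta for the step interval (infer_speedup), the line computing x_prev is

x_{t-\Delta} = \sqrt{\bar\alpha_{t-\Delta}} \left( \frac{x_t}{\sqrt{\bar\alpha_t}} + \left( \sqrt{\tfrac{1-\bar\alpha_{t-\Delta}}{\bar\alpha_{t-\Delta}}} - \sqrt{\tfrac{1-\bar\alpha_t}{\bar\alpha_t}} \right) \epsilon_\theta(x_t, t) \right)

so each step rescales the sample and corrects it with the predicted noise, which is what lets the 'ddim' branch jump infer_speedup timesteps per model call.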
diffusion/diffusion_onnx.py CHANGED
@@ -1,15 +1,14 @@
+import math
 from collections import deque
 from functools import partial
 from inspect import isfunction
-import torch.nn.functional as F
-import librosa.sequence
 import numpy as np
-from torch.nn import Conv1d
-from torch.nn import Mish
 import torch
+import torch.nn.functional as F
 from torch import nn
+from torch.nn import Conv1d, Mish
 from tqdm import tqdm
-import math
 
 
 def exists(x):
@@ -27,8 +26,10 @@ def extract(a, t):
 
 
 def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
+    def repeat_noise():
+        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    def noise():
+        return torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()
 
 
@@ -389,7 +390,11 @@ class GaussianDiffusion(nn.Module):
 
         if method is not None and infer_speedup > 1:
             if method == 'dpm-solver':
-                from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
+                from .dpm_solver_pytorch import (
+                    DPM_Solver,
+                    NoiseScheduleVP,
+                    model_wrapper,
+                )
                 # 1. Define the noise schedule.
                 noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])
 
@@ -576,9 +581,6 @@ class GaussianDiffusion(nn.Module):
         plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device)
         noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device)
 
-        ot = step_range[0]
-        ot_1 = torch.full((1,), ot, device=device, dtype=torch.long)
-
        for t in step_range:
             t_1 = torch.full((1,), t, device=device, dtype=torch.long)
             noise_pred = self.denoise_fn(x, t_1, cond)
diffusion/dpm_solver_pytorch.py CHANGED
@@ -1,5 +1,3 @@
-import math
-
 import torch
 
 
@@ -11,7 +9,8 @@ class NoiseScheduleVP:
             alphas_cumprod=None,
             continuous_beta_0=0.1,
             continuous_beta_1=20.,
-            ):
        """Create a wrapper class for the forward SDE (VP type).

        ***
@@ -46,7 +45,7 @@
            betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
            alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

-        Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.

        **Important**: Please pay special attention for the args for `alphas_cumprod`:
            The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
@@ -59,21 +58,19 @@

        2. For continuous-time DPMs:

-        We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
-        schedule are the default settings in DDPM and improved-DDPM:

        Args:
            beta_min: A `float` number. The smallest beta for the linear schedule.
            beta_max: A `float` number. The largest beta for the linear schedule.
-            cosine_s: A `float` number. The hyperparameter in the cosine schedule.
-            cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
            T: A `float` number. The ending time of the forward process.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
-                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

@@ -92,10 +89,8 @@

        """

-        if schedule not in ['discrete', 'linear', 'cosine']:
-            raise ValueError(
-                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
-                    schedule))

        self.schedule = schedule
        if schedule == 'discrete':
@@ -104,40 +99,37 @@
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
-            self.total_N = len(log_alphas)
            self.T = 1.
-            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
-            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
-            self.cosine_s = 0.008
-            self.cosine_beta_max = 999.
-            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
-                    1. + self.cosine_s) / math.pi - self.cosine_s
-            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
-            self.schedule = schedule
-            if schedule == 'cosine':
-                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
-                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
-                self.T = 0.9946
-            else:
-                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
-            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
-                                  self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
-        elif self.schedule == 'cosine':
-            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
-            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
-            return log_alpha_t

    def marginal_alpha(self, t):
        """
@@ -165,32 +157,25 @@
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
-            Delta = self.beta_0 ** 2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
-            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
-                               torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
-        else:
-            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
-            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
-                    1. + self.cosine_s) / math.pi - self.cosine_s
-            t = t_fn(log_alpha)
-            return t


def model_wrapper(
-        model,
-        noise_schedule,
-        model_type="noise",
-        model_kwargs={},
-        guidance_type="uncond",
-        condition=None,
-        unconditional_condition=None,
-        guidance_scale=1.,
-        classifier_fn=None,
-        classifier_kwargs={},
):
    """Create a wrapper function for the noise prediction model.

@@ -293,8 +278,6 @@ def model_wrapper(
        return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
-        if t_continuous.reshape((-1,)).shape[0] == 1:
-            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
@@ -304,16 +287,13 @@
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
-            dims = x.dim()
-            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
-            dims = x.dim()
-            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
-            dims = x.dim()
-            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
@@ -328,8 +308,6 @@
        """
        The noise predicition model function that is used for DPM-Solver.
        """
-        if t_continuous.reshape((-1,)).shape[0] == 1:
-            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
@@ -338,7 +316,7 @@
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
-            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
@@ -349,20 +327,34 @@
            noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
            return noise_uncond + guidance_scale * (noise - noise_uncond)

-    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn


class DPM_Solver:
-    def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
        """Construct a DPM-Solver.

-        We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
-        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
-        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
-        In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
-        The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.

        Args:
            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
@@ -370,18 +362,65 @@ class DPM_Solver:
                def model_fn(x, t_continuous):
                    return noise
            ``
            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
-            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
-            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
-            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
-
-        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
        """
-        self.model = model_fn
        self.noise_schedule = noise_schedule
-        self.predict_x0 = predict_x0
-        self.thresholding = thresholding
-        self.max_val = max_val

    def noise_prediction_fn(self, x, t):
        """
@@ -391,24 +430,20 @@ class DPM_Solver:

    def data_prediction_fn(self, x, t):
        """
-        Return the data prediction model (with thresholding).
        """
        noise = self.noise_prediction_fn(x, t)
-        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
-        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
-        if self.thresholding:
-            p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
-            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
-            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
-            x0 = torch.clamp(x0, -s, s) / s
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
-        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)
@@ -437,11 +472,10 @@ class DPM_Solver:
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
-            t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
-            raise ValueError(
-                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
@@ -478,32 +512,31 @@ class DPM_Solver:
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
-                orders = [3, ] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
-                orders = [3, ] * (K - 1) + [1]
            else:
-                orders = [3, ] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
-                orders = [2, ] * K
            else:
                K = steps // 2 + 1
-                orders = [2, ] * (K - 1) + [1]
        elif order == 1:
            K = 1
-            orders = [1, ] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
-            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
-                torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
        return timesteps_outer, orders

-    def denoise_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
        """
@@ -515,8 +548,8 @@ class DPM_Solver:

        Args:
            x: A pytorch tensor. The initial value at time `s`.
-            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
-            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`.
@@ -524,20 +557,19 @@ class DPM_Solver:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
-        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

-        if self.predict_x0:
            phi_1 = torch.expm1(-h)
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = (
-                expand_dims(sigma_t / sigma_s, dims) * x
-                - expand_dims(alpha_t * phi_1, dims) * model_s
            )
            if return_intermediate:
                return x_t, {'model_s': model_s}
@@ -548,70 +580,66 @@ class DPM_Solver:
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = (
-                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
-                - expand_dims(sigma_t * phi_1, dims) * model_s
            )
            if return_intermediate:
                return x_t, {'model_s': model_s}
            else:
                return x_t

-    def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
-                                            solver_type='dpm_solver'):
        """
        Singlestep solver DPM-Solver-2 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
-            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
-            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the second-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
-            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
-                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
-        if solver_type not in ['dpm_solver', 'taylor']:
-            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
        if r1 is None:
            r1 = 0.5
        ns = self.noise_schedule
-        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        s1 = ns.inverse_lambda(lambda_s1)
-        log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
-            s1), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

-        if self.predict_x0:
            phi_11 = torch.expm1(-r1 * h)
            phi_1 = torch.expm1(-h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = (
-                expand_dims(sigma_s1 / sigma_s, dims) * x
-                - expand_dims(alpha_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
-            if solver_type == 'dpm_solver':
                x_t = (
-                    expand_dims(sigma_t / sigma_s, dims) * x
-                    - expand_dims(alpha_t * phi_1, dims) * model_s
-                    - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
                )
            elif solver_type == 'taylor':
                x_t = (
-                    expand_dims(sigma_t / sigma_s, dims) * x
-                    - expand_dims(alpha_t * phi_1, dims) * model_s
-                    + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
-                            model_s1 - model_s)
                )
        else:
            phi_11 = torch.expm1(r1 * h)
@@ -620,36 +648,35 @@ class DPM_Solver:
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = (
-                expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
-                - expand_dims(sigma_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
-            if solver_type == 'dpm_solver':
                x_t = (
-                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
-                    - expand_dims(sigma_t * phi_1, dims) * model_s
-                    - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
                )
            elif solver_type == 'taylor':
                x_t = (
-                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
-                    - expand_dims(sigma_t * phi_1, dims) * model_s
-                    - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
                )
        if return_intermediate:
            return x_t, {'model_s': model_s, 'model_s1': model_s1}
        else:
            return x_t

-    def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
-                                           return_intermediate=False, solver_type='dpm_solver'):
        """
        Singlestep solver DPM-Solver-3 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
-            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
-            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
@@ -657,32 +684,29 @@ class DPM_Solver:
            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
-            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
-                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
-        if solver_type not in ['dpm_solver', 'taylor']:
-            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
        if r1 is None:
            r1 = 1. / 3.
        if r2 is None:
            r2 = 2. / 3.
        ns = self.noise_schedule
-        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        lambda_s2 = lambda_s + r2 * h
        s1 = ns.inverse_lambda(lambda_s1)
        s2 = ns.inverse_lambda(lambda_s2)
-        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
-            s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
-        sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
-            s2), ns.marginal_std(t)
        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)

-        if self.predict_x0:
            phi_11 = torch.expm1(-r1 * h)
            phi_12 = torch.expm1(-r2 * h)
            phi_1 = torch.expm1(-h)
@@ -694,21 +718,21 @@ class DPM_Solver:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (
-                    expand_dims(sigma_s1 / sigma_s, dims) * x
-                    - expand_dims(alpha_s1 * phi_11, dims) * model_s
                )
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
-                expand_dims(sigma_s2 / sigma_s, dims) * x
-                - expand_dims(alpha_s2 * phi_12, dims) * model_s
-                + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
-            if solver_type == 'dpm_solver':
                x_t = (
-                    expand_dims(sigma_t / sigma_s, dims) * x
-                    - expand_dims(alpha_t * phi_1, dims) * model_s
-                    + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
                )
            elif solver_type == 'taylor':
                D1_0 = (1. / r1) * (model_s1 - model_s)
@@ -716,10 +740,10 @@ class DPM_Solver:
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
-                    expand_dims(sigma_t / sigma_s, dims) * x
-                    - expand_dims(alpha_t * phi_1, dims) * model_s
-                    + expand_dims(alpha_t * phi_2, dims) * D1
-                    - expand_dims(alpha_t * phi_3, dims) * D2
                )
        else:
            phi_11 = torch.expm1(r1 * h)
@@ -733,21 +757,21 @@ class DPM_Solver:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (
-                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
-                    - expand_dims(sigma_s1 * phi_11, dims) * model_s
                )
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
-                expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
-                - expand_dims(sigma_s2 * phi_12, dims) * model_s
-                - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
-            if solver_type == 'dpm_solver':
                x_t = (
-                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
-                    - expand_dims(sigma_t * phi_1, dims) * model_s
-                    - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
                )
            elif solver_type == 'taylor':
                D1_0 = (1. / r1) * (model_s1 - model_s)
@@ -755,10 +779,10 @@ class DPM_Solver:
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
-                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
-                    - expand_dims(sigma_t * phi_1, dims) * model_s
-                    - expand_dims(sigma_t * phi_2, dims) * D1
-                    - expand_dims(sigma_t * phi_3, dims) * D2
                )

        if return_intermediate:
@@ -766,28 +790,26 @@ class DPM_Solver:
        else:
            return x_t

-    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
        """
        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
-            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
-            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
-            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
-                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
-        if solver_type not in ['dpm_solver', 'taylor']:
-            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
785
  ns = self.noise_schedule
786
- dims = x.dim()
787
- model_prev_1, model_prev_0 = model_prev_list
788
- t_prev_1, t_prev_0 = t_prev_list
789
- lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
790
- t_prev_0), ns.marginal_lambda(t)
791
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
792
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
793
  alpha_t = torch.exp(log_alpha_t)
@@ -795,55 +817,55 @@ class DPM_Solver:
795
  h_0 = lambda_prev_0 - lambda_prev_1
796
  h = lambda_t - lambda_prev_0
797
  r0 = h_0 / h
798
- D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
799
- if self.predict_x0:
800
- if solver_type == 'dpm_solver':
 
801
  x_t = (
802
- expand_dims(sigma_t / sigma_prev_0, dims) * x
803
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
804
- - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
805
  )
806
  elif solver_type == 'taylor':
807
  x_t = (
808
- expand_dims(sigma_t / sigma_prev_0, dims) * x
809
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
810
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
811
  )
812
  else:
813
- if solver_type == 'dpm_solver':
 
814
  x_t = (
815
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
816
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
817
- - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
818
  )
819
  elif solver_type == 'taylor':
820
  x_t = (
821
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
822
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
823
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
824
  )
825
  return x_t
826
 
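For intuition, the multistep updates above are driven by the scaled finite difference D1_0 = (1. / r0) * (model_prev_0 - model_prev_1), where r0 = h_0 / h is the ratio of the previous logSNR step to the current one. A minimal numeric sketch (toy scalar values, not taken from the solver):

    # Toy scalars standing in for cached model outputs at the two previous times.
    model_prev_1, model_prev_0 = 0.8, 1.0
    h_0, h = 0.3, 0.6                                  # previous / current logSNR step sizes
    r0 = h_0 / h                                       # = 0.5
    D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)   # slope estimate, = 0.4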
827
- def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
828
  """
829
  Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
830
 
831
  Args:
832
  x: A pytorch tensor. The initial value at time `s`.
833
  model_prev_list: A list of pytorch tensor. The previous computed model values.
834
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
835
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
836
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
837
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
838
  Returns:
839
  x_t: A pytorch tensor. The approximated solution at time `t`.
840
  """
841
  ns = self.noise_schedule
842
- dims = x.dim()
843
  model_prev_2, model_prev_1, model_prev_0 = model_prev_list
844
  t_prev_2, t_prev_1, t_prev_0 = t_prev_list
845
- lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
846
- t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
847
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
848
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
849
  alpha_t = torch.exp(log_alpha_t)
@@ -852,39 +874,44 @@ class DPM_Solver:
852
  h_0 = lambda_prev_0 - lambda_prev_1
853
  h = lambda_t - lambda_prev_0
854
  r0, r1 = h_0 / h, h_1 / h
855
- D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
856
- D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
857
- D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
858
- D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
859
- if self.predict_x0:
860
  x_t = (
861
- expand_dims(sigma_t / sigma_prev_0, dims) * x
862
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
863
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
864
- - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
865
  )
866
  else:
867
  x_t = (
868
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
869
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
870
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
871
- - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
872
  )
873
  return x_t
874
 
875
- def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
876
- r2=None):
877
  """
878
  Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
879
 
880
  Args:
881
  x: A pytorch tensor. The initial value at time `s`.
882
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
883
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
884
  order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
885
  return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
886
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
887
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
888
  r1: A `float`. The hyperparameter of the second-order or third-order solver.
889
  r2: A `float`. The hyperparameter of the third-order solver.
890
  Returns:
@@ -893,26 +920,24 @@ class DPM_Solver:
893
  if order == 1:
894
  return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
895
  elif order == 2:
896
- return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
897
- solver_type=solver_type, r1=r1)
898
  elif order == 3:
899
- return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
900
- solver_type=solver_type, r1=r1, r2=r2)
901
  else:
902
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
903
 
904
- def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
905
  """
906
  Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
907
 
908
  Args:
909
  x: A pytorch tensor. The initial value at time `s`.
910
  model_prev_list: A list of pytorch tensor. The previous computed model values.
911
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
912
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
913
  order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
914
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
915
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
916
  Returns:
917
  x_t: A pytorch tensor. The approximated solution at time `t`.
918
  """
@@ -925,8 +950,7 @@ class DPM_Solver:
925
  else:
926
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
927
 
928
- def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
929
- solver_type='dpm_solver'):
930
  """
931
  The adaptive step size solver based on singlestep DPM-Solver.
932
 
@@ -941,15 +965,15 @@ class DPM_Solver:
941
  theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
942
  t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
943
  current time and `t_0` is less than `t_err`. The default setting is 1e-5.
944
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
945
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
946
  Returns:
947
  x_0: A pytorch tensor. The approximated solution at time `t_0`.
948
 
949
  [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
950
  """
951
  ns = self.noise_schedule
952
- s = t_T * torch.ones((x.shape[0],)).to(x)
953
  lambda_s = ns.marginal_lambda(s)
954
  lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
955
  h = h_init * torch.ones_like(s).to(x)
@@ -957,18 +981,16 @@ class DPM_Solver:
957
  nfe = 0
958
  if order == 2:
959
  r1 = 0.5
960
- lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
961
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
962
- solver_type=solver_type,
963
- **kwargs)
964
  elif order == 3:
965
  r1, r2 = 1. / 3., 2. / 3.
966
- lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
967
- return_intermediate=True,
968
- solver_type=solver_type)
969
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
970
- solver_type=solver_type,
971
- **kwargs)
972
  else:
973
  raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
974
  while torch.abs((s - t_0)).mean() > t_err:
@@ -976,7 +998,8 @@ class DPM_Solver:
976
  x_lower, lower_noise_kwargs = lower_update(x, s, t)
977
  x_higher = higher_update(x, s, t, **lower_noise_kwargs)
978
  delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
979
- norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
 
980
  E = norm_fn((x_higher - x_lower) / delta).max()
981
  if torch.all(E <= 1.):
982
  x = x_higher
@@ -988,10 +1011,45 @@ class DPM_Solver:
988
  print('adaptive solver nfe', nfe)
989
  return x
990
 
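The accept/reject test above compares the lower- and higher-order estimates under a mixed absolute/relative tolerance. A self-contained sketch of that error norm, with random tensors standing in for the solver states:

    import torch

    atol, rtol = 0.0078, 0.05
    x_prev = torch.randn(2, 3, 8, 8)
    x_lower = x_prev + 0.01 * torch.randn_like(x_prev)      # lower-order estimate
    x_higher = x_lower + 1e-3 * torch.randn_like(x_lower)   # higher-order estimate
    delta = torch.max(torch.full_like(x_lower, atol),
                      rtol * torch.max(x_lower.abs(), x_prev.abs()))
    E = torch.sqrt(((x_higher - x_lower) / delta)
                   .reshape(x_prev.shape[0], -1).square().mean(dim=-1)).max()
    accept = bool(E <= 1.)   # accept the step when the scaled error is at most 1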
991
- def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
992
- method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078,
993
- rtol=0.05,
994
- ):
995
  """
996
  Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
997
 
@@ -1040,15 +1098,19 @@ class DPM_Solver:
1040
 
1041
  Some advice on choosing the algorithm:
1042
  - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1043
- Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
1044
- e.g.
1045
- >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
1046
  >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1047
  skip_type='time_uniform', method='singlestep')
1048
  - For **guided sampling with large guidance scale** by DPMs:
1049
- Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
1050
  e.g.
1051
- >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
1052
  >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1053
  skip_type='time_uniform', method='multistep')
1054
 
@@ -1074,72 +1136,116 @@ class DPM_Solver:
1074
  order: A `int`. The order of DPM-Solver.
1075
  skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1076
  method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1077
- denoise: A `bool`. Whether to denoise at the final step. Default is False.
1078
- If `denoise` is True, the total NFE is (`steps` + 1).
1079
- solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1080
  atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1081
  rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1082
  Returns:
1083
  x_end: A pytorch tensor. The approximated solution at time `t_end`.
1084
 
1085
  """
1086
  t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1087
  t_T = self.noise_schedule.T if t_start is None else t_start
1088
  device = x.device
1089
- if method == 'adaptive':
1090
- with torch.no_grad():
1091
- x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1092
- solver_type=solver_type)
1093
- elif method == 'multistep':
1094
- assert steps >= order
1095
- timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1096
- assert timesteps.shape[0] - 1 == steps
1097
- with torch.no_grad():
1098
- vec_t = timesteps[0].expand((x.shape[0]))
1099
- model_prev_list = [self.model_fn(x, vec_t)]
1100
- t_prev_list = [vec_t]
1101
  # Init the first `order` values by lower order multistep DPM-Solver.
1102
- for init_order in range(1, order):
1103
- vec_t = timesteps[init_order].expand(x.shape[0])
1104
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1105
- solver_type=solver_type)
1106
- model_prev_list.append(self.model_fn(x, vec_t))
1107
- t_prev_list.append(vec_t)
1108
  # Compute the remaining values by `order`-th order multistep DPM-Solver.
1109
  for step in range(order, steps + 1):
1110
- vec_t = timesteps[step].expand(x.shape[0])
1111
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order,
1112
- solver_type=solver_type)
1113
  for i in range(order - 1):
1114
  t_prev_list[i] = t_prev_list[i + 1]
1115
  model_prev_list[i] = model_prev_list[i + 1]
1116
- t_prev_list[-1] = vec_t
1117
  # We do not need to evaluate the final model value.
1118
  if step < steps:
1119
- model_prev_list[-1] = self.model_fn(x, vec_t)
1120
- elif method in ['singlestep', 'singlestep_fixed']:
1121
- if method == 'singlestep':
1122
- timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1123
- skip_type=skip_type,
1124
- t_T=t_T, t_0=t_0,
1125
- device=device)
1126
- elif method == 'singlestep_fixed':
1127
- K = steps // order
1128
- orders = [order, ] * K
1129
- timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1130
- for i, order in enumerate(orders):
1131
- t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1132
- timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1133
- N=order, device=device)
1134
- lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1135
- vec_s, vec_t = t_T_inner.repeat(x.shape[0]), t_0_inner.repeat(x.shape[0])
1136
- h = lambda_inner[-1] - lambda_inner[0]
1137
- r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1138
- r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1139
- x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1140
- if denoise:
1141
- x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1142
- return x
1143
 
1144
 
1145
  #############################################################
@@ -1198,4 +1304,4 @@ def expand_dims(v, dims):
1198
  Returns:
1199
  a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1200
  """
1201
- return v[(...,) + (None,) * (dims - 1)]

1
  import torch
2
 
3
 
 
9
  alphas_cumprod=None,
10
  continuous_beta_0=0.1,
11
  continuous_beta_1=20.,
12
+ dtype=torch.float32,
13
+ ):
14
  """Create a wrapper class for the forward SDE (VP type).
15
 
16
  ***
 
45
  betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
46
  alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
47
 
48
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
49
 
50
  **Important**: Please pay special attention for the args for `alphas_cumprod`:
51
  The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
 
58
 
59
  2. For continuous-time DPMs:
60
 
61
+ We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
62
+ schedule are the default settings in Yang Song's ScoreSDE:
63
 
64
  Args:
65
  beta_min: A `float` number. The smallest beta for the linear schedule.
66
  beta_max: A `float` number. The largest beta for the linear schedule.
 
 
67
  T: A `float` number. The ending time of the forward process.
68
 
69
  ===============================================================
70
 
71
  Args:
72
  schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
73
+ 'linear' for continuous-time DPMs.
74
  Returns:
75
  A wrapper object of the forward SDE (VP type).
76
 
 
89
 
90
  """
91
 
92
+ if schedule not in ['discrete', 'linear']:
93
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule))
 
 
94
 
95
  self.schedule = schedule
96
  if schedule == 'discrete':
 
99
  else:
100
  assert alphas_cumprod is not None
101
  log_alphas = 0.5 * torch.log(alphas_cumprod)
 
102
  self.T = 1.
103
+ self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
104
+ self.total_N = self.log_alpha_array.shape[1]
105
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
106
  else:
107
+ self.T = 1.
108
  self.total_N = 1000
109
  self.beta_0 = continuous_beta_0
110
  self.beta_1 = continuous_beta_1
111
+
112
+ def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
113
+ """
114
+ For some beta schedules such as the cosine schedule, the log-SNR has numerical issues.
115
+ We clip the log-SNR near t=T to -5.1 to ensure stability.
116
+ Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.
117
+ """
118
+ log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
119
+ lambs = log_alphas - log_sigmas
120
+ idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
121
+ if idx > 0:
122
+ log_alphas = log_alphas[:-idx]
123
+ return log_alphas
124
 
125
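numerical_clip_alpha drops the tail of the schedule where the log-SNR lambda = log(alpha) - log(sigma) falls below clipped_lambda. A sketch of the same quantities on a hypothetical cosine-like schedule (the schedule here is a placeholder, not the repo's):

    import torch

    N = 1000
    t = torch.linspace(1e-3, 1., N)
    log_alphas = torch.log(torch.cos(0.5 * torch.pi * t).clamp(min=1e-6))
    log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
    lambs = log_alphas - log_sigmas                   # log-SNR, decreasing in t
    log_alphas_clipped = log_alphas[lambs >= -5.1]    # keep only the numerically stable head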
  def marginal_log_mean_coeff(self, t):
126
  """
127
  Compute log(alpha_t) of a given continuous-time label t in [0, T].
128
  """
129
  if self.schedule == 'discrete':
130
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
 
131
  elif self.schedule == 'linear':
132
  return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
133
 
134
  def marginal_alpha(self, t):
135
  """
 
157
  """
158
  if self.schedule == 'linear':
159
  tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
160
+ Delta = self.beta_0**2 + tmp
161
  return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
162
  elif self.schedule == 'discrete':
163
  log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
164
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
165
  return t.reshape((-1,))
166
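inverse_lambda is the exact inverse of marginal_lambda; for the linear schedule, the quadratic formula above recovers t in closed form. A quick round-trip check, assuming the ScoreSDE defaults beta_0 = 0.1 and beta_1 = 20:

    import torch

    beta_0, beta_1 = 0.1, 20.
    t = torch.tensor([0.25, 0.5, 0.75])
    log_alpha = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    lamb = log_alpha - 0.5 * torch.log(1. - torch.exp(2. * log_alpha))
    tmp = 2. * (beta_1 - beta_0) * torch.logaddexp(-2. * lamb, torch.zeros(1))
    t_rec = tmp / (torch.sqrt(beta_0 ** 2 + tmp) + beta_0) / (beta_1 - beta_0)
    assert torch.allclose(t_rec, t, atol=1e-4)   # marginal_lambda round-trips through inverse_lambda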
 
167
 
168
  def model_wrapper(
169
+ model,
170
+ noise_schedule,
171
+ model_type="noise",
172
+ model_kwargs={},
173
+ guidance_type="uncond",
174
+ condition=None,
175
+ unconditional_condition=None,
176
+ guidance_scale=1.,
177
+ classifier_fn=None,
178
+ classifier_kwargs={},
179
  ):
180
  """Create a wrapper function for the noise prediction model.
181
 
 
278
  return t_continuous
279
 
280
  def noise_pred_fn(x, t_continuous, cond=None):
 
 
281
  t_input = get_model_input_time(t_continuous)
282
  if cond is None:
283
  output = model(x, t_input, **model_kwargs)
 
287
  return output
288
  elif model_type == "x_start":
289
  alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
290
+ return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
 
291
  elif model_type == "v":
292
  alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
293
+ return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
 
294
  elif model_type == "score":
295
  sigma_t = noise_schedule.marginal_std(t_continuous)
296
+ return -expand_dims(sigma_t, x.dim()) * output
 
297
 
298
  def cond_grad_fn(x, t_input):
299
  """
 
308
  """
309
  The noise prediction model function that is used for DPM-Solver.
310
  """
311
  if guidance_type == "uncond":
312
  return noise_pred_fn(x, t_continuous)
313
  elif guidance_type == "classifier":
 
316
  cond_grad = cond_grad_fn(x, t_input)
317
  sigma_t = noise_schedule.marginal_std(t_continuous)
318
  noise = noise_pred_fn(x, t_continuous)
319
+ return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
320
  elif guidance_type == "classifier-free":
321
  if guidance_scale == 1. or unconditional_condition is None:
322
  return noise_pred_fn(x, t_continuous, cond=condition)
 
327
  noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
328
  return noise_uncond + guidance_scale * (noise - noise_uncond)
329
 
330
+ assert model_type in ["noise", "x_start", "v", "score"]
331
  assert guidance_type in ["uncond", "classifier", "classifier-free"]
332
  return model_fn
333
 
334
 
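Before constructing a DPM_Solver, the raw network is normalized into a continuous-time noise prediction function by model_wrapper. A minimal usage sketch; the epsilon network and betas below are placeholders, not part of this repo:

    import torch

    class DummyEps(torch.nn.Module):
        # Stands in for a trained noise-prediction UNet.
        def forward(self, x, t):
            return torch.zeros_like(x)

    betas = torch.linspace(1e-4, 0.02, 1000)
    ns = NoiseScheduleVP(schedule='discrete', betas=betas)
    model_fn = model_wrapper(DummyEps(), ns, model_type="noise", guidance_type="uncond")
    eps = model_fn(torch.randn(2, 3, 8, 8), torch.full((2,), 0.5))   # continuous t in (0, 1]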
335
  class DPM_Solver:
336
+ def __init__(
337
+ self,
338
+ model_fn,
339
+ noise_schedule,
340
+ algorithm_type="dpmsolver++",
341
+ correcting_x0_fn=None,
342
+ correcting_xt_fn=None,
343
+ thresholding_max_val=1.,
344
+ dynamic_thresholding_ratio=0.995,
345
+ ):
346
  """Construct a DPM-Solver.
347
 
348
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
349
+
350
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
351
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
352
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
353
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
354
+ DPMs (such as stable-diffusion).
355
+
356
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
357
+ both x0 and xt.
358
 
359
  Args:
360
  model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
 
362
  def model_fn(x, t_continuous):
363
  return noise
364
  ``
365
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
366
  noise_schedule: A noise schedule object, such as NoiseScheduleVP.
367
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
368
+ correcting_x0_fn: A `str` or a function with the following format:
369
+ ```
370
+ def correcting_x0_fn(x0, t):
371
+ x0_new = ...
372
+ return x0_new
373
+ ```
374
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
375
+ ```
376
+ x0_pred = data_pred_model(xt, t)
377
+ if correcting_x0_fn is not None:
378
+ x0_pred = correcting_x0_fn(x0_pred, t)
379
+ xt_1 = update(x0_pred, xt, t)
380
+ ```
381
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
382
+ correcting_xt_fn: A function with the following format:
383
+ ```
384
+ def correcting_xt_fn(xt, t, step):
385
+ x_new = ...
386
+ return x_new
387
+ ```
388
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
389
+ ```
390
+ xt = ...
391
+ xt = correcting_xt_fn(xt, t, step)
392
+ ```
393
+ thresholding_max_val: A `float`. The max value for thresholding.
394
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
395
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
396
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
397
+
398
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
399
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
400
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
401
  """
402
+ self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
403
  self.noise_schedule = noise_schedule
404
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
405
+ self.algorithm_type = algorithm_type
406
+ if correcting_x0_fn == "dynamic_thresholding":
407
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
408
+ else:
409
+ self.correcting_x0_fn = correcting_x0_fn
410
+ self.correcting_xt_fn = correcting_xt_fn
411
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
412
+ self.thresholding_max_val = thresholding_max_val
413
+
414
+ def dynamic_thresholding_fn(self, x0, t):
415
+ """
416
+ The dynamic thresholding method.
417
+ """
418
+ dims = x0.dim()
419
+ p = self.dynamic_thresholding_ratio
420
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
421
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
422
+ x0 = torch.clamp(x0, -s, s) / s
423
+ return x0
424
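Dynamic thresholding clamps each predicted x0 to its per-sample p-quantile and rescales into [-1, 1]. The same recipe on a standalone tensor, with the default p = 0.995 and max value 1:

    import torch

    x0 = 3. * torch.randn(2, 3, 8, 8)           # deliberately over-saturated x0 predictions
    p, max_val = 0.995, 1.
    s = torch.quantile(x0.abs().reshape(x0.shape[0], -1), p, dim=1)
    s = torch.maximum(s, max_val * torch.ones_like(s)).view(-1, 1, 1, 1)
    x0_clipped = torch.clamp(x0, -s, s) / s     # values now lie within [-1, 1]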
 
425
  def noise_prediction_fn(self, x, t):
426
  """
 
430
 
431
  def data_prediction_fn(self, x, t):
432
  """
433
+ Return the data prediction model (with corrector).
434
  """
435
  noise = self.noise_prediction_fn(x, t)
 
436
  alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
437
+ x0 = (x - sigma_t * noise) / alpha_t
438
+ if self.correcting_x0_fn is not None:
439
+ x0 = self.correcting_x0_fn(x0, t)
 
 
 
440
  return x0
441
 
442
  def model_fn(self, x, t):
443
  """
444
  Convert the model to the noise prediction model or the data prediction model.
445
  """
446
+ if self.algorithm_type == "dpmsolver++":
447
  return self.data_prediction_fn(x, t)
448
  else:
449
  return self.noise_prediction_fn(x, t)
 
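The two prediction modes are linked by the forward identity x_t = alpha_t * x0 + sigma_t * eps, so data_prediction_fn is just an affine re-parameterization of the noise prediction. A one-line sanity check with toy values:

    import torch

    alpha_t, sigma_t = 0.8, 0.6
    x0, eps = torch.randn(4), torch.randn(4)
    xt = alpha_t * x0 + sigma_t * eps
    assert torch.allclose((xt - sigma_t * eps) / alpha_t, x0, atol=1e-6)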
472
  return torch.linspace(t_T, t_0, N + 1).to(device)
473
  elif skip_type == 'time_quadratic':
474
  t_order = 2
475
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
476
  return t
477
  else:
478
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
 
479
 
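The skip types only differ in how [t_0, t_T] is discretized. A sketch of the 'time_uniform' and 'time_quadratic' grids (the 'logSNR' grid additionally needs the schedule's inverse_lambda):

    import torch

    t_T, t_0, N = 1., 1e-3, 10
    time_uniform = torch.linspace(t_T, t_0, N + 1)
    t_order = 2
    time_quadratic = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order)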
480
  def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
481
  """
 
512
  if order == 3:
513
  K = steps // 3 + 1
514
  if steps % 3 == 0:
515
+ orders = [3,] * (K - 2) + [2, 1]
516
  elif steps % 3 == 1:
517
+ orders = [3,] * (K - 1) + [1]
518
  else:
519
+ orders = [3,] * (K - 1) + [2]
520
  elif order == 2:
521
  if steps % 2 == 0:
522
  K = steps // 2
523
+ orders = [2,] * K
524
  else:
525
  K = steps // 2 + 1
526
+ orders = [2,] * (K - 1) + [1]
527
  elif order == 1:
528
  K = 1
529
+ orders = [1,] * steps
530
  else:
531
  raise ValueError("'order' must be '1' or '2' or '3'.")
532
  if skip_type == 'logSNR':
533
  # To reproduce the results in DPM-Solver paper
534
  timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
535
  else:
536
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
 
537
  return timesteps_outer, orders
538
 
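For the singlestep method, the steps are packed into blocks of the requested order and the tail is padded with lower orders. For example, steps=20 with order=3 takes the steps % 3 == 2 branch:

    steps, order = 20, 3
    K = steps // 3 + 1              # 7 outer intervals
    orders = [3] * (K - 1) + [2]    # 20 % 3 == 2, so the last interval is second order
    assert sum(orders) == steps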
539
+ def denoise_to_zero_fn(self, x, s):
540
  """
541
  Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
542
  """
 
548
 
549
  Args:
550
  x: A pytorch tensor. The initial value at time `s`.
551
+ s: A pytorch tensor. The starting time, with the shape (1,).
552
+ t: A pytorch tensor. The ending time, with the shape (1,).
553
  model_s: A pytorch tensor. The model function evaluated at time `s`.
554
  If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
555
  return_intermediate: A `bool`. If true, also return the model value at time `s`.
 
557
  x_t: A pytorch tensor. The approximated solution at time `t`.
558
  """
559
  ns = self.noise_schedule
 
560
  lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
561
  h = lambda_t - lambda_s
562
  log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
563
  sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
564
  alpha_t = torch.exp(log_alpha_t)
565
 
566
+ if self.algorithm_type == "dpmsolver++":
567
  phi_1 = torch.expm1(-h)
568
  if model_s is None:
569
  model_s = self.model_fn(x, s)
570
  x_t = (
571
+ sigma_t / sigma_s * x
572
+ - alpha_t * phi_1 * model_s
573
  )
574
  if return_intermediate:
575
  return x_t, {'model_s': model_s}
 
580
  if model_s is None:
581
  model_s = self.model_fn(x, s)
582
  x_t = (
583
+ torch.exp(log_alpha_t - log_alpha_s) * x
584
+ - (sigma_t * phi_1) * model_s
585
  )
586
  if return_intermediate:
587
  return x_t, {'model_s': model_s}
588
  else:
589
  return x_t
590
 
591
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'):
 
592
  """
593
  Singlestep solver DPM-Solver-2 from time `s` to time `t`.
594
 
595
  Args:
596
  x: A pytorch tensor. The initial value at time `s`.
597
+ s: A pytorch tensor. The starting time, with the shape (1,).
598
+ t: A pytorch tensor. The ending time, with the shape (1,).
599
  r1: A `float`. The hyperparameter of the second-order solver.
600
  model_s: A pytorch tensor. The model function evaluated at time `s`.
601
  If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
602
  return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
603
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
604
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
605
  Returns:
606
  x_t: A pytorch tensor. The approximated solution at time `t`.
607
  """
608
+ if solver_type not in ['dpmsolver', 'taylor']:
609
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
610
  if r1 is None:
611
  r1 = 0.5
612
  ns = self.noise_schedule
 
613
  lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
614
  h = lambda_t - lambda_s
615
  lambda_s1 = lambda_s + r1 * h
616
  s1 = ns.inverse_lambda(lambda_s1)
617
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
 
618
  sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
619
  alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
620
 
621
+ if self.algorithm_type == "dpmsolver++":
622
  phi_11 = torch.expm1(-r1 * h)
623
  phi_1 = torch.expm1(-h)
624
 
625
  if model_s is None:
626
  model_s = self.model_fn(x, s)
627
  x_s1 = (
628
+ (sigma_s1 / sigma_s) * x
629
+ - (alpha_s1 * phi_11) * model_s
630
  )
631
  model_s1 = self.model_fn(x_s1, s1)
632
+ if solver_type == 'dpmsolver':
633
  x_t = (
634
+ (sigma_t / sigma_s) * x
635
+ - (alpha_t * phi_1) * model_s
636
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
637
  )
638
  elif solver_type == 'taylor':
639
  x_t = (
640
+ (sigma_t / sigma_s) * x
641
+ - (alpha_t * phi_1) * model_s
642
+ + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
 
643
  )
644
  else:
645
  phi_11 = torch.expm1(r1 * h)
 
648
  if model_s is None:
649
  model_s = self.model_fn(x, s)
650
  x_s1 = (
651
+ torch.exp(log_alpha_s1 - log_alpha_s) * x
652
+ - (sigma_s1 * phi_11) * model_s
653
  )
654
  model_s1 = self.model_fn(x_s1, s1)
655
+ if solver_type == 'dpmsolver':
656
  x_t = (
657
+ torch.exp(log_alpha_t - log_alpha_s) * x
658
+ - (sigma_t * phi_1) * model_s
659
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
660
  )
661
  elif solver_type == 'taylor':
662
  x_t = (
663
+ torch.exp(log_alpha_t - log_alpha_s) * x
664
+ - (sigma_t * phi_1) * model_s
665
+ - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
666
  )
667
  if return_intermediate:
668
  return x_t, {'model_s': model_s, 'model_s1': model_s1}
669
  else:
670
  return x_t
671
 
672
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'):
 
673
  """
674
  Singlestep solver DPM-Solver-3 from time `s` to time `t`.
675
 
676
  Args:
677
  x: A pytorch tensor. The initial value at time `s`.
678
+ s: A pytorch tensor. The starting time, with the shape (1,).
679
+ t: A pytorch tensor. The ending time, with the shape (1,).
680
  r1: A `float`. The hyperparameter of the third-order solver.
681
  r2: A `float`. The hyperparameter of the third-order solver.
682
  model_s: A pytorch tensor. The model function evaluated at time `s`.
 
684
  model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
685
  If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
686
  return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
687
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
688
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
689
  Returns:
690
  x_t: A pytorch tensor. The approximated solution at time `t`.
691
  """
692
+ if solver_type not in ['dpmsolver', 'taylor']:
693
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
694
  if r1 is None:
695
  r1 = 1. / 3.
696
  if r2 is None:
697
  r2 = 2. / 3.
698
  ns = self.noise_schedule
 
699
  lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
700
  h = lambda_t - lambda_s
701
  lambda_s1 = lambda_s + r1 * h
702
  lambda_s2 = lambda_s + r2 * h
703
  s1 = ns.inverse_lambda(lambda_s1)
704
  s2 = ns.inverse_lambda(lambda_s2)
705
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
706
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
707
  alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
708
 
709
+ if self.algorithm_type == "dpmsolver++":
710
  phi_11 = torch.expm1(-r1 * h)
711
  phi_12 = torch.expm1(-r2 * h)
712
  phi_1 = torch.expm1(-h)
 
718
  model_s = self.model_fn(x, s)
719
  if model_s1 is None:
720
  x_s1 = (
721
+ (sigma_s1 / sigma_s) * x
722
+ - (alpha_s1 * phi_11) * model_s
723
  )
724
  model_s1 = self.model_fn(x_s1, s1)
725
  x_s2 = (
726
+ (sigma_s2 / sigma_s) * x
727
+ - (alpha_s2 * phi_12) * model_s
728
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
729
  )
730
  model_s2 = self.model_fn(x_s2, s2)
731
+ if solver_type == 'dpmsolver':
732
  x_t = (
733
+ (sigma_t / sigma_s) * x
734
+ - (alpha_t * phi_1) * model_s
735
+ + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
736
  )
737
  elif solver_type == 'taylor':
738
  D1_0 = (1. / r1) * (model_s1 - model_s)
 
740
  D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
741
  D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
742
  x_t = (
743
+ (sigma_t / sigma_s) * x
744
+ - (alpha_t * phi_1) * model_s
745
+ + (alpha_t * phi_2) * D1
746
+ - (alpha_t * phi_3) * D2
747
  )
748
  else:
749
  phi_11 = torch.expm1(r1 * h)
 
757
  model_s = self.model_fn(x, s)
758
  if model_s1 is None:
759
  x_s1 = (
760
+ (torch.exp(log_alpha_s1 - log_alpha_s)) * x
761
+ - (sigma_s1 * phi_11) * model_s
762
  )
763
  model_s1 = self.model_fn(x_s1, s1)
764
  x_s2 = (
765
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
766
+ - (sigma_s2 * phi_12) * model_s
767
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
768
  )
769
  model_s2 = self.model_fn(x_s2, s2)
770
+ if solver_type == 'dpmsolver':
771
  x_t = (
772
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
773
+ - (sigma_t * phi_1) * model_s
774
+ - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
775
  )
776
  elif solver_type == 'taylor':
777
  D1_0 = (1. / r1) * (model_s1 - model_s)
 
779
  D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
780
  D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
781
  x_t = (
782
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
783
+ - (sigma_t * phi_1) * model_s
784
+ - (sigma_t * phi_2) * D1
785
+ - (sigma_t * phi_3) * D2
786
  )
787
 
788
  if return_intermediate:
 
790
  else:
791
  return x_t
792
 
793
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
794
  """
795
  Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
796
 
797
  Args:
798
  x: A pytorch tensor. The initial value at time `s`.
799
  model_prev_list: A list of pytorch tensor. The previous computed model values.
800
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
801
+ t: A pytorch tensor. The ending time, with the shape (1,).
802
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
803
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
804
  Returns:
805
  x_t: A pytorch tensor. The approximated solution at time `t`.
806
  """
807
+ if solver_type not in ['dpmsolver', 'taylor']:
808
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
809
  ns = self.noise_schedule
810
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
811
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
812
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
813
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
814
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
815
  alpha_t = torch.exp(log_alpha_t)
 
817
  h_0 = lambda_prev_0 - lambda_prev_1
818
  h = lambda_t - lambda_prev_0
819
  r0 = h_0 / h
820
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
821
+ if self.algorithm_type == "dpmsolver++":
822
+ phi_1 = torch.expm1(-h)
823
+ if solver_type == 'dpmsolver':
824
  x_t = (
825
+ (sigma_t / sigma_prev_0) * x
826
+ - (alpha_t * phi_1) * model_prev_0
827
+ - 0.5 * (alpha_t * phi_1) * D1_0
828
  )
829
  elif solver_type == 'taylor':
830
  x_t = (
831
+ (sigma_t / sigma_prev_0) * x
832
+ - (alpha_t * phi_1) * model_prev_0
833
+ + (alpha_t * (phi_1 / h + 1.)) * D1_0
834
  )
835
  else:
836
+ phi_1 = torch.expm1(h)
837
+ if solver_type == 'dpmsolver':
838
  x_t = (
839
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
840
+ - (sigma_t * phi_1) * model_prev_0
841
+ - 0.5 * (sigma_t * phi_1) * D1_0
842
  )
843
  elif solver_type == 'taylor':
844
  x_t = (
845
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
846
+ - (sigma_t * phi_1) * model_prev_0
847
+ - (sigma_t * (phi_1 / h - 1.)) * D1_0
848
  )
849
  return x_t
850
 
851
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
852
  """
853
  Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
854
 
855
  Args:
856
  x: A pytorch tensor. The initial value at time `s`.
857
  model_prev_list: A list of pytorch tensor. The previous computed model values.
858
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
859
+ t: A pytorch tensor. The ending time, with the shape (1,).
860
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
861
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
862
  Returns:
863
  x_t: A pytorch tensor. The approximated solution at time `t`.
864
  """
865
  ns = self.noise_schedule
 
866
  model_prev_2, model_prev_1, model_prev_0 = model_prev_list
867
  t_prev_2, t_prev_1, t_prev_0 = t_prev_list
868
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
 
869
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
870
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
871
  alpha_t = torch.exp(log_alpha_t)
 
874
  h_0 = lambda_prev_0 - lambda_prev_1
875
  h = lambda_t - lambda_prev_0
876
  r0, r1 = h_0 / h, h_1 / h
877
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
878
+ D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
879
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
880
+ D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
881
+ if self.algorithm_type == "dpmsolver++":
882
+ phi_1 = torch.expm1(-h)
883
+ phi_2 = phi_1 / h + 1.
884
+ phi_3 = phi_2 / h - 0.5
885
  x_t = (
886
+ (sigma_t / sigma_prev_0) * x
887
+ - (alpha_t * phi_1) * model_prev_0
888
+ + (alpha_t * phi_2) * D1
889
+ - (alpha_t * phi_3) * D2
890
  )
891
  else:
892
+ phi_1 = torch.expm1(h)
893
+ phi_2 = phi_1 / h - 1.
894
+ phi_3 = phi_2 / h - 0.5
895
  x_t = (
896
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
897
+ - (sigma_t * phi_1) * model_prev_0
898
+ - (sigma_t * phi_2) * D1
899
+ - (sigma_t * phi_3) * D2
900
  )
901
  return x_t
902
 
903
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None):
 
904
  """
905
  Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
906
 
907
  Args:
908
  x: A pytorch tensor. The initial value at time `s`.
909
+ s: A pytorch tensor. The starting time, with the shape (1,).
910
+ t: A pytorch tensor. The ending time, with the shape (1,).
911
  order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
912
  return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
913
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
914
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
915
  r1: A `float`. The hyperparameter of the second-order or third-order solver.
916
  r2: A `float`. The hyperparameter of the third-order solver.
917
  Returns:
 
920
  if order == 1:
921
  return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
922
  elif order == 2:
923
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
 
924
  elif order == 3:
925
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
 
926
  else:
927
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
928
 
929
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
930
  """
931
  Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
932
 
933
  Args:
934
  x: A pytorch tensor. The initial value at time `s`.
935
  model_prev_list: A list of pytorch tensor. The previous computed model values.
936
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
937
+ t: A pytorch tensor. The ending time, with the shape (1,).
938
  order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
939
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
940
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
941
  Returns:
942
  x_t: A pytorch tensor. The approximated solution at time `t`.
943
  """
 
950
  else:
951
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
952
 
953
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'):
 
954
  """
955
  The adaptive step size solver based on singlestep DPM-Solver.
956
 
 
965
  theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
966
  t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
967
  current time and `t_0` is less than `t_err`. The default setting is 1e-5.
968
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
969
+ The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
970
  Returns:
971
  x_0: A pytorch tensor. The approximated solution at time `t_0`.
972
 
973
  [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
974
  """
975
  ns = self.noise_schedule
976
+ s = t_T * torch.ones((1,)).to(x)
977
  lambda_s = ns.marginal_lambda(s)
978
  lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
979
  h = h_init * torch.ones_like(s).to(x)
 
981
  nfe = 0
982
  if order == 2:
983
  r1 = 0.5
984
+ def lower_update(x, s, t):
985
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=True)
986
+ def higher_update(x, s, t, **kwargs):
987
+ return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
988
  elif order == 3:
989
  r1, r2 = 1. / 3., 2. / 3.
990
+ def lower_update(x, s, t):
991
+ return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
992
+ def higher_update(x, s, t, **kwargs):
993
+ return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
994
  else:
995
  raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
996
  while torch.abs((s - t_0)).mean() > t_err:
 
998
  x_lower, lower_noise_kwargs = lower_update(x, s, t)
999
  x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1000
  delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1001
+ def norm_fn(v):
1002
+ return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1003
  E = norm_fn((x_higher - x_lower) / delta).max()
1004
  if torch.all(E <= 1.):
1005
  x = x_higher
 
1011
  print('adaptive solver nfe', nfe)
1012
  return x
1013
 
1014
+ def add_noise(self, x, t, noise=None):
1015
+ """
1016
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1017
+
1018
+ Args:
1019
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1020
+ t: A `torch.Tensor` with shape `(t_size,)`.
1021
+ Returns:
1022
+ xt with shape `(t_size, batch_size, *shape)`.
1023
+ """
1024
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1025
+ if noise is None:
1026
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1027
+ x = x.reshape((-1, *x.shape))
1028
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1029
+ if t.shape[0] == 1:
1030
+ return xt.squeeze(0)
1031
+ else:
1032
+ return xt
1033
+
1034
+ def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1035
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1036
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1037
+ ):
1038
+ """
1039
+ Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.
1040
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
1041
+ """
1042
+ t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
1043
+ t_T = self.noise_schedule.T if t_end is None else t_end
1044
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1045
+ return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
1046
+ method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type,
1047
+ atol=atol, rtol=rtol, return_intermediate=return_intermediate)
1048
+
1049
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1050
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1051
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1052
+ ):
1053
  """
1054
  Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1055
 
 
1098
 
1099
  Some advice on choosing the algorithm:
1100
  - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1101
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1102
+ e.g., DPM-Solver:
1103
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1104
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1105
+ skip_type='time_uniform', method='singlestep')
1106
+ e.g., DPM-Solver++:
1107
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1108
  >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1109
  skip_type='time_uniform', method='singlestep')
1110
  - For **guided sampling with large guidance scale** by DPMs:
1111
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1112
  e.g.
1113
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1114
  >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1115
  skip_type='time_uniform', method='multistep')
1116
 
 
1136
  order: A `int`. The order of DPM-Solver.
1137
  skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1138
  method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1139
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1140
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1141
+
1142
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1143
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID of
1144
+ diffusion models sampled by diffusion SDEs on low-resolution images
1145
+ (such as CIFAR-10). However, we observed that it does not matter for
1146
+ high-resolution images. As it needs an additional NFE, we do not recommend
1147
+ it for high-resolution images.
1148
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1149
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1150
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1151
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1152
+ solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1153
  atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1154
  rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1155
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1156
+ When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
1157
  Returns:
1158
  x_end: A pytorch tensor. The approximated solution at time `t_end`.
1159
 
1160
  """
1161
  t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1162
  t_T = self.noise_schedule.T if t_start is None else t_start
1163
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1164
+ if return_intermediate:
1165
+ assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
1166
+ if self.correcting_xt_fn is not None:
1167
+ assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
1168
  device = x.device
1169
+ intermediates = []
1170
+ with torch.no_grad():
1171
+ if method == 'adaptive':
1172
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
1173
+ elif method == 'multistep':
1174
+ assert steps >= order
1175
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1176
+ assert timesteps.shape[0] - 1 == steps
1177
+ # Init the initial values.
1178
+ step = 0
1179
+ t = timesteps[step]
1180
+ t_prev_list = [t]
1181
+ model_prev_list = [self.model_fn(x, t)]
1182
+ if self.correcting_xt_fn is not None:
1183
+ x = self.correcting_xt_fn(x, t, step)
1184
+ if return_intermediate:
1185
+ intermediates.append(x)
1186
  # Init the first `order` values by lower order multistep DPM-Solver.
1187
+ for step in range(1, order):
1188
+ t = timesteps[step]
1189
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type)
1190
+ if self.correcting_xt_fn is not None:
1191
+ x = self.correcting_xt_fn(x, t, step)
1192
+ if return_intermediate:
1193
+ intermediates.append(x)
1194
+ t_prev_list.append(t)
1195
+ model_prev_list.append(self.model_fn(x, t))
1196
  # Compute the remaining values by `order`-th order multistep DPM-Solver.
1197
  for step in range(order, steps + 1):
1198
+ t = timesteps[step]
1199
+ # We only use lower order for steps < 10
1200
+ if lower_order_final and steps < 10:
1201
+ step_order = min(order, steps + 1 - step)
1202
+ else:
1203
+ step_order = order
1204
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type)
1205
+ if self.correcting_xt_fn is not None:
1206
+ x = self.correcting_xt_fn(x, t, step)
1207
+ if return_intermediate:
1208
+ intermediates.append(x)
1209
  for i in range(order - 1):
1210
  t_prev_list[i] = t_prev_list[i + 1]
1211
  model_prev_list[i] = model_prev_list[i + 1]
1212
+ t_prev_list[-1] = t
1213
  # We do not need to evaluate the final model value.
1214
  if step < steps:
1215
+ model_prev_list[-1] = self.model_fn(x, t)
1216
+ elif method in ['singlestep', 'singlestep_fixed']:
1217
+ if method == 'singlestep':
1218
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
1219
+ elif method == 'singlestep_fixed':
1220
+ K = steps // order
1221
+ orders = [order,] * K
1222
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1223
+ for step, order in enumerate(orders):
1224
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1225
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device)
1226
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1227
+ h = lambda_inner[-1] - lambda_inner[0]
1228
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1229
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1230
+ x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1231
+ if self.correcting_xt_fn is not None:
1232
+ x = self.correcting_xt_fn(x, t, step)
1233
+ if return_intermediate:
1234
+ intermediates.append(x)
1235
+ else:
1236
+ raise ValueError("Got wrong method {}".format(method))
1237
+ if denoise_to_zero:
1238
+ t = torch.ones((1,)).to(device) * t_0
1239
+ x = self.denoise_to_zero_fn(x, t)
1240
+ if self.correcting_xt_fn is not None:
1241
+ x = self.correcting_xt_fn(x, t, step + 1)
1242
+ if return_intermediate:
1243
+ intermediates.append(x)
1244
+ if return_intermediate:
1245
+ return x, intermediates
1246
+ else:
1247
+ return x
1248
+
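A minimal usage sketch for the `sample` API above (a sketch, assuming `model_fn` and `noise_schedule` were built earlier in this file via `model_wrapper` and `NoiseScheduleVP`; the beta range, shapes, and step counts are illustrative):

import torch

ns = NoiseScheduleVP('discrete', betas=torch.linspace(1e-4, 2e-2, 1000))
solver = DPM_Solver(model_fn, ns, algorithm_type="dpmsolver++")
x_T = torch.randn(4, 3, 32, 32)    # start from pure noise at t = T
x_0, xs = solver.sample(
    x_T, steps=20, order=2, method='multistep',
    lower_order_final=True,        # stabilizes few-step sampling, per the docstring
    denoise_to_zero=False,         # True would cost one extra NFE
    return_intermediate=True,      # also collect every intermediate x_t
)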


  #############################################################

  Returns:
  a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
  """
+ return v[(...,) + (None,)*(dims - 1)]
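The `expand_dims` helper relies on a small indexing trick; a quick illustration (a sketch, not part of the commit):

import torch

v = torch.arange(4)                  # shape [4]
out = v[(...,) + (None,) * 3]        # dims = 4 -> shape [4, 1, 1, 1]
assert out.shape == (4, 1, 1, 1)     # now broadcasts against [N, C, H, W] tensors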
diffusion/infer_gt_mel.py CHANGED
@@ -1,6 +1,6 @@
- import numpy as np
  import torch
  import torch.nn.functional as F
+
  from diffusion.unit2mel import load_model_vocoder
diffusion/logger/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/diffusion/logger/__pycache__/__init__.cpython-38.pyc and b/diffusion/logger/__pycache__/__init__.cpython-38.pyc differ
 
diffusion/logger/__pycache__/saver.cpython-38.pyc CHANGED
Binary files a/diffusion/logger/__pycache__/saver.cpython-38.pyc and b/diffusion/logger/__pycache__/saver.cpython-38.pyc differ
 
diffusion/logger/__pycache__/utils.cpython-38.pyc CHANGED
Binary files a/diffusion/logger/__pycache__/utils.cpython-38.pyc and b/diffusion/logger/__pycache__/utils.cpython-38.pyc differ
 
diffusion/logger/saver.py CHANGED
@@ -2,16 +2,16 @@
  author: wayn391@mastertones
  '''

+ import datetime
  import os
- import json
  import time
- import yaml
- import datetime
- import torch
+
  import matplotlib.pyplot as plt
- from . import utils
+ import torch
+ import yaml
  from torch.utils.tensorboard import SummaryWriter

+
  class Saver(object):
      def __init__(
          self,
@@ -125,12 +125,7 @@ class Saver(object):
          torch.save({
              'global_step': self.global_step,
              'model': model.state_dict()}, path_pt)
-
-         # to json
-         if to_json:
-             path_json = os.path.join(
-                 self.expdir , name+'.json')
-             utils.to_json(path_params, path_json)
+

      def delete_model(self, name='model', postfix=''):
          # path
diffusion/logger/utils.py CHANGED
@@ -1,8 +1,9 @@
- import os
- import yaml
  import json
- import pickle
+ import os
+
  import torch
+ import yaml
+

  def traverse_dir(
      root_dir,
@@ -121,6 +122,6 @@ def load_model(
      ckpt = torch.load(path_pt, map_location=torch.device(device))
      global_step = ckpt['global_step']
      model.load_state_dict(ckpt['model'], strict=False)
-     if ckpt.get('optimizer') != None:
+     if ckpt.get("optimizer") is not None:
          optimizer.load_state_dict(ckpt['optimizer'])
      return global_step, model, optimizer
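The `is not None` change matters because optimizer state is optional in these checkpoints: `dict.get` returns `None` for a missing key, and PEP 8 asks for identity comparison with `None`. A tiny self-contained sketch of the guarded pattern (checkpoint contents assumed):

ckpt = {'global_step': 0, 'model': {}}   # e.g. a checkpoint saved without optimizer state
assert ckpt.get('optimizer') is None      # missing key -> None, so the optimizer load is skipped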
diffusion/onnx_export.py CHANGED
@@ -1,12 +1,12 @@
- from diffusion_onnx import GaussianDiffusion
  import os
- import yaml
+
+ import numpy as np
  import torch
  import torch.nn as nn
- import numpy as np
- from wavenet import WaveNet
  import torch.nn.functional as F
- import diffusion
+ import yaml
+ from diffusion_onnx import GaussianDiffusion
+

  class DotDict(dict):
      def __getattr__(*args):
@@ -33,7 +33,9 @@ def load_model_vocoder(
          128,
          args.model.n_layers,
          args.model.n_chans,
-         args.model.n_hidden)
+         args.model.n_hidden,
+         args.model.timesteps,
+         args.model.k_step_max)

      print(' [Loading] ' + model_path)
      ckpt = torch.load(model_path, map_location=torch.device(device))
@@ -52,8 +54,11 @@ class Unit2Mel(nn.Module):
          out_dims=128,
          n_layers=20,
          n_chans=384,
-         n_hidden=256):
+         n_hidden=256,
+         timesteps=1000,
+         k_step_max=1000):
          super().__init__()
+
          self.unit_embed = nn.Linear(input_channel, n_hidden)
          self.f0_embed = nn.Linear(1, n_hidden)
          self.volume_embed = nn.Linear(1, n_hidden)
@@ -64,9 +69,13 @@ class Unit2Mel(nn.Module):
          self.n_spk = n_spk
          if n_spk is not None and n_spk > 1:
              self.spk_embed = nn.Embedding(n_spk, n_hidden)
-
+
+         self.timesteps = timesteps if timesteps is not None else 1000
+         self.k_step_max = k_step_max if k_step_max is not None and k_step_max > 0 and k_step_max < self.timesteps else self.timesteps
+
+
          # diffusion
-         self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden)
+         self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden, self.timesteps, self.k_step_max)
          self.hidden_size = n_hidden
          self.speaker_map = torch.zeros((self.n_spk,1,1,n_hidden))

@@ -138,8 +147,8 @@ class Unit2Mel(nn.Module):
          spks.update({i:1.0/float(self.n_spk)})
          spk_mix = torch.tensor(spk_mix)
          spk_mix = spk_mix.repeat(n_frames, 1)
-         orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
-         outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix)
+         self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
+         self.forward(hubert, mel2ph, f0, volume, spk_mix)
          if export_encoder:
              torch.onnx.export(
                  self,
@@ -173,8 +182,8 @@ class Unit2Mel(nn.Module):
          spk_mix.append(1.0/float(self.n_spk))
          spks.update({i:1.0/float(self.n_spk)})
          spk_mix = torch.tensor(spk_mix)
-         orgouttt = self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
-         outtt = self.forward(hubert, mel2ph, f0, volume, spk_mix)
+         self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
+         self.forward(hubert, mel2ph, f0, volume, spk_mix)

          torch.onnx.export(
              self,
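The new `timesteps` / `k_step_max` parameters are threaded from `config.yaml` through `load_model_vocoder` into `Unit2Mel` and its `GaussianDiffusion` decoder. A construction sketch (the positional encoder arguments before `out_dims` are hypothetical placeholders; real values come from the config):

model = Unit2Mel(
    768, 2, True,                 # hypothetical: input_channel, n_spk, pitch-aug flag
    out_dims=128, n_layers=20, n_chans=384, n_hidden=256,
    timesteps=1000,               # full length of the diffusion chain
    k_step_max=100,               # shallow model: only the last 100 steps are learned
)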
diffusion/solver.py CHANGED
@@ -1,13 +1,15 @@
- import os
  import time
+
+ import librosa
  import numpy as np
  import torch
- import librosa
- from diffusion.logger.saver import Saver
- from diffusion.logger import utils
  from torch import autocast
  from torch.cuda.amp import GradScaler

+ from diffusion.logger import utils
+ from diffusion.logger.saver import Saver
+
+
  def test(args, model, vocoder, loader_test, saver):
      print(' [*] testing...')
      model.eval()
@@ -40,10 +42,12 @@ def test(args, model, vocoder, loader_test, saver):
                  data['f0'],
                  data['volume'],
                  data['spk_id'],
-                 gt_spec=None,
+                 gt_spec=None if model.k_step_max == model.timesteps else data['mel'],
                  infer=True,
                  infer_speedup=args.infer.speedup,
-                 method=args.infer.method)
+                 method=args.infer.method,
+                 k_step=model.k_step_max
+             )
              signal = vocoder.infer(mel, data['f0'])
              ed_time = time.time()
@@ -62,7 +66,8 @@ def test(args, model, vocoder, loader_test, saver):
                  data['volume'],
                  data['spk_id'],
                  gt_spec=data['mel'],
-                 infer=False)
+                 infer=False,
+                 k_step=model.k_step_max)
              test_loss += loss.item()

              # log mel
@@ -121,11 +126,11 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
          # forward
          if dtype == torch.float32:
              loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
-                 aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
+                 aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=model.k_step_max)
          else:
              with autocast(device_type=args.device, dtype=dtype):
                  loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
-                     aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False)
+                     aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=model.k_step_max)

          # handle nan loss
          if torch.isnan(loss):
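The shallow-diffusion branch added above works as follows: a model trained with `k_step_max < timesteps` cannot start from pure noise, so validation feeds the ground-truth mel as the starting point and only denoises the last `k_step_max` steps. A condensed sketch of that call pattern (restating the diff, names as in `test`):

shallow = model.k_step_max != model.timesteps          # e.g. 100 != 1000
mel = model(
    data['units'], data['f0'], data['volume'], data['spk_id'],
    gt_spec=data['mel'] if shallow else None,          # noised starting point when shallow
    infer=True, infer_speedup=args.infer.speedup,
    method=args.infer.method, k_step=model.k_step_max,
)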
diffusion/uni_pc.py ADDED
@@ -0,0 +1,733 @@
+ import math
+
+ import torch
+
+
+ class NoiseScheduleVP:
+     def __init__(
+         self,
+         schedule='discrete',
+         betas=None,
+         alphas_cumprod=None,
+         continuous_beta_0=0.1,
+         continuous_beta_1=20.,
+         dtype=torch.float32,
+     ):
+         """Create a wrapper class for the forward SDE (VP type).
+         ***
+         Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+         We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+         ***
+         The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+         We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+         Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+             log_alpha_t = self.marginal_log_mean_coeff(t)
+             sigma_t = self.marginal_std(t)
+             lambda_t = self.marginal_lambda(t)
+         Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+             t = self.inverse_lambda(lambda_t)
+         ===============================================================
+         We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+         1. For discrete-time DPMs:
+             For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+                 t_i = (i + 1) / N
+             e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+             We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+             Args:
+                 betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+                 alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+             Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+             **Important**: Please pay special attention to the args for `alphas_cumprod`:
+                 The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+                     q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+                 Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+                     alpha_{t_n} = \sqrt{\hat{alpha_n}},
+                 and
+                     log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+         2. For continuous-time DPMs:
+             We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+             schedule are the default settings in DDPM and improved-DDPM:
+             Args:
+                 beta_min: A `float` number. The smallest beta for the linear schedule.
+                 beta_max: A `float` number. The largest beta for the linear schedule.
+                 cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+                 cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+                 T: A `float` number. The ending time of the forward process.
+         ===============================================================
+         Args:
+             schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+                 'linear' or 'cosine' for continuous-time DPMs.
+         Returns:
+             A wrapper object of the forward SDE (VP type).
+
+         ===============================================================
+         Example:
+         # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+         >>> ns = NoiseScheduleVP('discrete', betas=betas)
+         # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+         >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+         # For continuous-time DPMs (VPSDE), linear schedule:
+         >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+         """
+
+         if schedule not in ['discrete', 'linear', 'cosine']:
+             raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+
+         self.schedule = schedule
+         if schedule == 'discrete':
+             if betas is not None:
+                 log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+             else:
+                 assert alphas_cumprod is not None
+                 log_alphas = 0.5 * torch.log(alphas_cumprod)
+             self.total_N = len(log_alphas)
+             self.T = 1.
+             self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
+             self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
+         else:
+             self.total_N = 1000
+             self.beta_0 = continuous_beta_0
+             self.beta_1 = continuous_beta_1
+             self.cosine_s = 0.008
+             self.cosine_beta_max = 999.
+             self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+             self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+             self.schedule = schedule
+             if schedule == 'cosine':
+                 # For the cosine schedule, T = 1 will have numerical issues, so we manually set the ending time T.
+                 # Note that T = 0.9946 may not be the optimal setting; however, we find it works well.
+                 self.T = 0.9946
+             else:
+                 self.T = 1.
+
+     def marginal_log_mean_coeff(self, t):
+         """
+         Compute log(alpha_t) of a given continuous-time label t in [0, T].
+         """
+         if self.schedule == 'discrete':
+             return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
+         elif self.schedule == 'linear':
+             return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+         elif self.schedule == 'cosine':
+             def log_alpha_fn(s):
+                 return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
+             log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+             return log_alpha_t
+
+     def marginal_alpha(self, t):
+         """
+         Compute alpha_t of a given continuous-time label t in [0, T].
+         """
+         return torch.exp(self.marginal_log_mean_coeff(t))
+
+     def marginal_std(self, t):
+         """
+         Compute sigma_t of a given continuous-time label t in [0, T].
+         """
+         return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+     def marginal_lambda(self, t):
+         """
+         Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+         """
+         log_mean_coeff = self.marginal_log_mean_coeff(t)
+         log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+         return log_mean_coeff - log_std
+
+     def inverse_lambda(self, lamb):
+         """
+         Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+         """
+         if self.schedule == 'linear':
+             tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+             Delta = self.beta_0**2 + tmp
+             return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+         elif self.schedule == 'discrete':
+             log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+             t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
+             return t.reshape((-1,))
+         else:
+             log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+             def t_fn(log_alpha_t):
+                 return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2.0 * (1.0 + self.cosine_s) / math.pi - self.cosine_s
+             t = t_fn(log_alpha)
+             return t
+
+
+ def model_wrapper(
+     model,
+     noise_schedule,
+     model_type="noise",
+     model_kwargs={},
+     guidance_type="uncond",
+     condition=None,
+     unconditional_condition=None,
+     guidance_scale=1.,
+     classifier_fn=None,
+     classifier_kwargs={},
+ ):
+     """Create a wrapper function for the noise prediction model.
+     """
+
+     def get_model_input_time(t_continuous):
+         """
+         Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+         For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+         For continuous-time DPMs, we just use `t_continuous`.
+         """
+         if noise_schedule.schedule == 'discrete':
+             return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
+         else:
+             return t_continuous
+
+     def noise_pred_fn(x, t_continuous, cond=None):
+         t_input = get_model_input_time(t_continuous)
+         if cond is None:
+             output = model(x, t_input, **model_kwargs)
+         else:
+             output = model(x, t_input, cond, **model_kwargs)
+         if model_type == "noise":
+             return output
+         elif model_type == "x_start":
+             alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+             return (x - alpha_t * output) / sigma_t
+         elif model_type == "v":
+             alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+             return alpha_t * output + sigma_t * x
+         elif model_type == "score":
+             sigma_t = noise_schedule.marginal_std(t_continuous)
+             return -sigma_t * output
+
+     def cond_grad_fn(x, t_input):
+         """
+         Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+         """
+         with torch.enable_grad():
+             x_in = x.detach().requires_grad_(True)
+             log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+             return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+     def model_fn(x, t_continuous):
+         """
+         The noise prediction model function that is used for DPM-Solver.
+         """
+         if guidance_type == "uncond":
+             return noise_pred_fn(x, t_continuous)
+         elif guidance_type == "classifier":
+             assert classifier_fn is not None
+             t_input = get_model_input_time(t_continuous)
+             cond_grad = cond_grad_fn(x, t_input)
+             sigma_t = noise_schedule.marginal_std(t_continuous)
+             noise = noise_pred_fn(x, t_continuous)
+             return noise - guidance_scale * sigma_t * cond_grad
+         elif guidance_type == "classifier-free":
+             if guidance_scale == 1. or unconditional_condition is None:
+                 return noise_pred_fn(x, t_continuous, cond=condition)
+             else:
+                 x_in = torch.cat([x] * 2)
+                 t_in = torch.cat([t_continuous] * 2)
+                 c_in = torch.cat([unconditional_condition, condition])
+                 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+                 return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+     assert model_type in ["noise", "x_start", "v"]
+     assert guidance_type in ["uncond", "classifier", "classifier-free"]
+     return model_fn
+
+
+ class UniPC:
+     def __init__(
+         self,
+         model_fn,
+         noise_schedule,
+         algorithm_type="data_prediction",
+         correcting_x0_fn=None,
+         correcting_xt_fn=None,
+         thresholding_max_val=1.,
+         dynamic_thresholding_ratio=0.995,
+         variant='bh1'
+     ):
+         """Construct a UniPC.
+
+         We support both data_prediction and noise_prediction.
+         """
+         self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
+         self.noise_schedule = noise_schedule
+         assert algorithm_type in ["data_prediction", "noise_prediction"]
+
+         if correcting_x0_fn == "dynamic_thresholding":
+             self.correcting_x0_fn = self.dynamic_thresholding_fn
+         else:
+             self.correcting_x0_fn = correcting_x0_fn
+
+         self.correcting_xt_fn = correcting_xt_fn
+         self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
+         self.thresholding_max_val = thresholding_max_val
+
+         self.variant = variant
+         self.predict_x0 = algorithm_type == "data_prediction"
+
+     def dynamic_thresholding_fn(self, x0, t=None):
+         """
+         The dynamic thresholding method.
+         """
+         dims = x0.dim()
+         p = self.dynamic_thresholding_ratio
+         s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+         s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
+         x0 = torch.clamp(x0, -s, s) / s
+         return x0
+
+     def noise_prediction_fn(self, x, t):
+         """
+         Return the noise prediction model.
+         """
+         return self.model(x, t)
+
+     def data_prediction_fn(self, x, t):
+         """
+         Return the data prediction model (with corrector).
+         """
+         noise = self.noise_prediction_fn(x, t)
+         alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+         x0 = (x - sigma_t * noise) / alpha_t
+         if self.correcting_x0_fn is not None:
+             x0 = self.correcting_x0_fn(x0)
+         return x0
+
+     def model_fn(self, x, t):
+         """
+         Convert the model to the noise prediction model or the data prediction model.
+         """
+         if self.predict_x0:
+             return self.data_prediction_fn(x, t)
+         else:
+             return self.noise_prediction_fn(x, t)
+
+     def get_time_steps(self, skip_type, t_T, t_0, N, device):
+         """Compute the intermediate time steps for sampling.
+         """
+         if skip_type == 'logSNR':
+             lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+             lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+             logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+             return self.noise_schedule.inverse_lambda(logSNR_steps)
+         elif skip_type == 'time_uniform':
+             return torch.linspace(t_T, t_0, N + 1).to(device)
+         elif skip_type == 'time_quadratic':
+             t_order = 2
+             t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
+             return t
+         else:
+             raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+     def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+         """
+         Get the order of each step for sampling by the singlestep DPM-Solver.
+         """
+         if order == 3:
+             K = steps // 3 + 1
+             if steps % 3 == 0:
+                 orders = [3,] * (K - 2) + [2, 1]
+             elif steps % 3 == 1:
+                 orders = [3,] * (K - 1) + [1]
+             else:
+                 orders = [3,] * (K - 1) + [2]
+         elif order == 2:
+             if steps % 2 == 0:
+                 K = steps // 2
+                 orders = [2,] * K
+             else:
+                 K = steps // 2 + 1
+                 orders = [2,] * (K - 1) + [1]
+         elif order == 1:
+             K = steps
+             orders = [1,] * steps
+         else:
+             raise ValueError("'order' must be '1' or '2' or '3'.")
+         if skip_type == 'logSNR':
+             # To reproduce the results in DPM-Solver paper
+             timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+         else:
+             timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
+         return timesteps_outer, orders
+
+     def denoise_to_zero_fn(self, x, s):
+         """
+         Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
+         """
+         return self.data_prediction_fn(x, s)
+
+     def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
+         if len(t.shape) == 0:
+             t = t.view(-1)
+         if 'bh' in self.variant:
+             return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
+         else:
+             assert self.variant == 'vary_coeff'
+             return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
+
+     def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
+         #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
+         ns = self.noise_schedule
+         assert order <= len(model_prev_list)
+
+         # first compute rks
+         t_prev_0 = t_prev_list[-1]
+         lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+         lambda_t = ns.marginal_lambda(t)
+         model_prev_0 = model_prev_list[-1]
+         sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+         log_alpha_t = ns.marginal_log_mean_coeff(t)
+         alpha_t = torch.exp(log_alpha_t)
+
+         h = lambda_t - lambda_prev_0
+
+         rks = []
+         D1s = []
+         for i in range(1, order):
+             t_prev_i = t_prev_list[-(i + 1)]
+             model_prev_i = model_prev_list[-(i + 1)]
+             lambda_prev_i = ns.marginal_lambda(t_prev_i)
+             rk = (lambda_prev_i - lambda_prev_0) / h
+             rks.append(rk)
+             D1s.append((model_prev_i - model_prev_0) / rk)
+
+         rks.append(1.)
+         rks = torch.tensor(rks, device=x.device)
+
+         K = len(rks)
+         # build C matrix
+         C = []
+
+         col = torch.ones_like(rks)
+         for k in range(1, K + 1):
+             C.append(col)
+             col = col * rks / (k + 1)
+         C = torch.stack(C, dim=1)
+
+         if len(D1s) > 0:
+             D1s = torch.stack(D1s, dim=1)  # (B, K, ...)
+             C_inv_p = torch.linalg.inv(C[:-1, :-1])
+             A_p = C_inv_p
+
+         if use_corrector:
+             #print('using corrector')
+             C_inv = torch.linalg.inv(C)
+             A_c = C_inv
+
+         hh = -h if self.predict_x0 else h
+         h_phi_1 = torch.expm1(hh)
+         h_phi_ks = []
+         factorial_k = 1
+         h_phi_k = h_phi_1
+         for k in range(1, K + 2):
+             h_phi_ks.append(h_phi_k)
+             h_phi_k = h_phi_k / hh - 1 / factorial_k
+             factorial_k *= (k + 1)
+
+         model_t = None
+         if self.predict_x0:
+             x_t_ = (
+                 sigma_t / sigma_prev_0 * x
+                 - alpha_t * h_phi_1 * model_prev_0
+             )
+             # now predictor
+             x_t = x_t_
+             if len(D1s) > 0:
+                 # compute the residuals for predictor
+                 for k in range(K - 1):
+                     x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
+             # now corrector
+             if use_corrector:
+                 model_t = self.model_fn(x_t, t)
+                 D1_t = (model_t - model_prev_0)
+                 x_t = x_t_
+                 k = 0
+                 for k in range(K - 1):
+                     x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
+                 x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+         else:
+             log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+             x_t_ = (
+                 (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
+                 - (sigma_t * h_phi_1) * model_prev_0
+             )
+             # now predictor
+             x_t = x_t_
+             if len(D1s) > 0:
+                 # compute the residuals for predictor
+                 for k in range(K - 1):
+                     x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
+             # now corrector
+             if use_corrector:
+                 model_t = self.model_fn(x_t, t)
+                 D1_t = (model_t - model_prev_0)
+                 x_t = x_t_
+                 k = 0
+                 for k in range(K - 1):
+                     x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
+                 x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+         return x_t, model_t
+
+     def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
+         #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
+         ns = self.noise_schedule
+         assert order <= len(model_prev_list)
+
+         # first compute rks
+         t_prev_0 = t_prev_list[-1]
+         lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+         lambda_t = ns.marginal_lambda(t)
+         model_prev_0 = model_prev_list[-1]
+         sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+         log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+         alpha_t = torch.exp(log_alpha_t)
+
+         h = lambda_t - lambda_prev_0
+
+         rks = []
+         D1s = []
+         for i in range(1, order):
+             t_prev_i = t_prev_list[-(i + 1)]
+             model_prev_i = model_prev_list[-(i + 1)]
+             lambda_prev_i = ns.marginal_lambda(t_prev_i)
+             rk = (lambda_prev_i - lambda_prev_0) / h
+             rks.append(rk)
+             D1s.append((model_prev_i - model_prev_0) / rk)
+
+         rks.append(1.)
+         rks = torch.tensor(rks, device=x.device)
+
+         R = []
+         b = []
+
+         hh = -h if self.predict_x0 else h
+         h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
+         h_phi_k = h_phi_1 / hh - 1
+
+         factorial_i = 1
+
+         if self.variant == 'bh1':
+             B_h = hh
+         elif self.variant == 'bh2':
+             B_h = torch.expm1(hh)
+         else:
+             raise NotImplementedError()
+
+         for i in range(1, order + 1):
+             R.append(torch.pow(rks, i - 1))
+             b.append(h_phi_k * factorial_i / B_h)
+             factorial_i *= (i + 1)
+             h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+         R = torch.stack(R)
+         b = torch.cat(b)
+
+         # now predictor
+         use_predictor = len(D1s) > 0 and x_t is None
+         if len(D1s) > 0:
+             D1s = torch.stack(D1s, dim=1)  # (B, K, ...)
+             if x_t is None:
+                 # for order 2, we use a simplified version
+                 if order == 2:
+                     rhos_p = torch.tensor([0.5], device=b.device)
+                 else:
+                     rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+         else:
+             D1s = None
+
+         if use_corrector:
+             #print('using corrector')
+             # for order 1, we use a simplified version
+             if order == 1:
+                 rhos_c = torch.tensor([0.5], device=b.device)
+             else:
+                 rhos_c = torch.linalg.solve(R, b)
+
+         model_t = None
+         if self.predict_x0:
+             x_t_ = (
+                 sigma_t / sigma_prev_0 * x
+                 - alpha_t * h_phi_1 * model_prev_0
+             )
+
+             if x_t is None:
+                 if use_predictor:
+                     pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+                 else:
+                     pred_res = 0
+                 x_t = x_t_ - alpha_t * B_h * pred_res
+
+             if use_corrector:
+                 model_t = self.model_fn(x_t, t)
+                 if D1s is not None:
+                     corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+                 else:
+                     corr_res = 0
+                 D1_t = (model_t - model_prev_0)
+                 x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+         else:
+             x_t_ = (
+                 torch.exp(log_alpha_t - log_alpha_prev_0) * x
+                 - sigma_t * h_phi_1 * model_prev_0
+             )
+             if x_t is None:
+                 if use_predictor:
+                     pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+                 else:
+                     pred_res = 0
+                 x_t = x_t_ - sigma_t * B_h * pred_res
+
+             if use_corrector:
+                 model_t = self.model_fn(x_t, t)
+                 if D1s is not None:
+                     corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+                 else:
+                     corr_res = 0
+                 D1_t = (model_t - model_prev_0)
+                 x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+         return x_t, model_t
+
+     def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
+         method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False,
+     ):
+         """
+         Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`.
+         """
+         t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+         t_T = self.noise_schedule.T if t_start is None else t_start
+         assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
+         if return_intermediate:
+             assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
+         if self.correcting_xt_fn is not None:
+             assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
+         device = x.device
+         intermediates = []
+         with torch.no_grad():
+             if method == 'multistep':
+                 assert steps >= order
+                 timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+                 assert timesteps.shape[0] - 1 == steps
+                 # Init the initial values.
+                 step = 0
+                 t = timesteps[step]
+                 t_prev_list = [t]
+                 model_prev_list = [self.model_fn(x, t)]
+                 if self.correcting_xt_fn is not None:
+                     x = self.correcting_xt_fn(x, t, step)
+                 if return_intermediate:
+                     intermediates.append(x)
+
+                 # Init the first `order` values by lower order multistep UniPC.
+                 for step in range(1, order):
+                     t = timesteps[step]
+                     x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True)
+                     if model_x is None:
+                         model_x = self.model_fn(x, t)
+                     if self.correcting_xt_fn is not None:
+                         x = self.correcting_xt_fn(x, t, step)
+                     if return_intermediate:
+                         intermediates.append(x)
+                     t_prev_list.append(t)
+                     model_prev_list.append(model_x)
+
+                 # Compute the remaining values by `order`-th order multistep UniPC.
+                 for step in range(order, steps + 1):
+                     t = timesteps[step]
+                     if lower_order_final:
+                         step_order = min(order, steps + 1 - step)
+                     else:
+                         step_order = order
+                     if step == steps:
+                         #print('do not run corrector at the last step')
+                         use_corrector = False
+                     else:
+                         use_corrector = True
+                     x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector)
+                     if self.correcting_xt_fn is not None:
+                         x = self.correcting_xt_fn(x, t, step)
+                     if return_intermediate:
+                         intermediates.append(x)
+                     for i in range(order - 1):
+                         t_prev_list[i] = t_prev_list[i + 1]
+                         model_prev_list[i] = model_prev_list[i + 1]
+                     t_prev_list[-1] = t
+                     # We do not need to evaluate the final model value.
+                     if step < steps:
+                         if model_x is None:
+                             model_x = self.model_fn(x, t)
+                         model_prev_list[-1] = model_x
+             else:
+                 raise ValueError("Got wrong method {}".format(method))
+
+             if denoise_to_zero:
+                 t = torch.ones((1,)).to(device) * t_0
+                 x = self.denoise_to_zero_fn(x, t)
+                 if self.correcting_xt_fn is not None:
+                     x = self.correcting_xt_fn(x, t, step + 1)
+                 if return_intermediate:
+                     intermediates.append(x)
+         if return_intermediate:
+             return x, intermediates
+         else:
+             return x
+
+
+ #############################################################
+ # other utility functions
+ #############################################################
+
+ def interpolate_fn(x, xp, yp):
+     """
+     A piecewise linear function y = f(x), using xp and yp as keypoints.
+     We implement f(x) in a differentiable way (i.e. applicable for autograd).
+     The function f(x) is well-defined for the whole x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+
+     Args:
+         x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+         xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+         yp: PyTorch tensor with shape [C, K].
+     Returns:
+         The function values f(x), with shape [N, C].
+     """
+     N, K = x.shape[0], xp.shape[1]
+     all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+     sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+     x_idx = torch.argmin(x_indices, dim=2)
+     cand_start_idx = x_idx - 1
+     start_idx = torch.where(
+         torch.eq(x_idx, 0),
+         torch.tensor(1, device=x.device),
+         torch.where(
+             torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+         ),
+     )
+     end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+     start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+     end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+     start_idx2 = torch.where(
+         torch.eq(x_idx, 0),
+         torch.tensor(0, device=x.device),
+         torch.where(
+             torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+         ),
+     )
+     y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+     start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+     end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+     cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+     return cand
+
+
+ def expand_dims(v, dims):
+     """
+     Expand the tensor `v` to the dim `dims`.
+
+     Args:
+         `v`: a PyTorch tensor with shape [N].
+         `dims`: an `int`.
+     Returns:
+         a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+     """
+     return v[(...,) + (None,)*(dims - 1)]
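The new `uni_pc.py` mirrors the DPM-Solver API above. A minimal end-to-end sketch (the zero-output `toy_model`, the beta range, and the tensor shape are placeholders, not part of the commit):

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP('discrete', betas=betas)

def toy_model(x, t):                   # stands in for the trained noise predictor
    return torch.zeros_like(x)

model_fn = model_wrapper(toy_model, ns, model_type="noise", guidance_type="uncond")
sampler = UniPC(model_fn, ns, algorithm_type="data_prediction", variant='bh2')
x_T = torch.randn(1, 1, 128, 100)      # e.g. a mel-spectrogram-shaped latent, [B, C, H, W]
x_0 = sampler.sample(x_T, steps=10, order=2, method='multistep')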
diffusion/unit2mel.py CHANGED
@@ -1,11 +1,14 @@
  import os
- import yaml
+
+ import numpy as np
  import torch
  import torch.nn as nn
- import numpy as np
+ import yaml
+
  from .diffusion import GaussianDiffusion
- from .wavenet import WaveNet
  from .vocoder import Vocoder
+ from .wavenet import WaveNet
+

  class DotDict(dict):
      def __getattr__(*args):
@@ -21,9 +24,11 @@ def load_model_vocoder(
      device='cpu',
      config_path = None
      ):
-     if config_path is None: config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
-     else: config_file = config_path
-
+     if config_path is None:
+         config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
+     else:
+         config_file = config_path
+
      with open(config_file, "r") as config:
          args = yaml.safe_load(config)
          args = DotDict(args)
@@ -39,13 +44,17 @@ def load_model_vocoder(
          vocoder.dimension,
          args.model.n_layers,
          args.model.n_chans,
-         args.model.n_hidden)
+         args.model.n_hidden,
+         args.model.timesteps,
+         args.model.k_step_max
+         )

      print(' [Loading] ' + model_path)
      ckpt = torch.load(model_path, map_location=torch.device(device))
      model.to(device)
      model.load_state_dict(ckpt['model'])
      model.eval()
+     print(f'Loaded diffusion model, sampler is {args.infer.method}, speedup: {args.infer.speedup} ')
      return model, vocoder, args
@@ -58,7 +67,10 @@ class Unit2Mel(nn.Module):
          out_dims=128,
          n_layers=20,
          n_chans=384,
-         n_hidden=256):
+         n_hidden=256,
+         timesteps=1000,
+         k_step_max=1000
+         ):
          super().__init__()
          self.unit_embed = nn.Linear(input_channel, n_hidden)
          self.f0_embed = nn.Linear(1, n_hidden)
@@ -71,9 +83,12 @@ class Unit2Mel(nn.Module):
          if n_spk is not None and n_spk > 1:
              self.spk_embed = nn.Embedding(n_spk, n_hidden)

+         self.timesteps = timesteps if timesteps is not None else 1000
+         self.k_step_max = k_step_max if k_step_max is not None and k_step_max > 0 and k_step_max < self.timesteps else self.timesteps
+
          self.n_hidden = n_hidden
          # diffusion
-         self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)
+         self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), timesteps=self.timesteps, k_step=self.k_step_max, out_dims=out_dims)
          self.input_channel = input_channel

      def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
@@ -106,13 +121,12 @@ class Unit2Mel(nn.Module):
          hubert_hidden_size = self.input_channel
          n_frames = 10
          hubert = torch.randn((1, n_frames, hubert_hidden_size))
-         mel2ph = torch.arange(end=n_frames).unsqueeze(0).long()
          f0 = torch.randn((1, n_frames))
          volume = torch.randn((1, n_frames))
          spks = {}
          for i in range(n_spk):
              spks.update({i:1.0/float(self.n_spk)})
-         orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
+         self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)

      def forward(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
          gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
@@ -124,6 +138,12 @@ class Unit2Mel(nn.Module):
              dict of B x n_frames x feat
          '''

+         if not self.training and gt_spec is not None and k_step > self.k_step_max:
+             raise Exception("The shallow diffusion k_step is greater than the maximum diffusion k_step (k_step_max)!")
+
+         if not self.training and gt_spec is None and self.k_step_max != self.timesteps:
+             raise Exception("This model can only be used for shallow diffusion and cannot infer alone!")
+
          x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
          if self.n_spk is not None and self.n_spk > 1:
              if spk_mix_dict is not None:
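The `k_step_max` clamp and the two inference-time guards above encode the shallow-diffusion contract. A small self-contained sketch of the clamp rule (values illustrative):

# Sketch of the clamp added in __init__ above: an invalid or missing k_step_max
# falls back to the full chain length.
def clamp_k_step(k_step_max, timesteps=1000):
    return k_step_max if k_step_max is not None and 0 < k_step_max < timesteps else timesteps

assert clamp_k_step(100) == 100     # shallow-diffusion-only model
assert clamp_k_step(None) == 1000   # older configs without the field
assert clamp_k_step(0) == 1000      # zero or negative treated as "full diffusion"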
diffusion/vocoder.py CHANGED
@@ -1,9 +1,10 @@
  import torch
- from vdecoder.nsf_hifigan.nvSTFT import STFT
- from vdecoder.nsf_hifigan.models import load_model,load_config
  from torchaudio.transforms import Resample

+ from vdecoder.nsf_hifigan.models import load_config, load_model
+ from vdecoder.nsf_hifigan.nvSTFT import STFT
+
+
  class Vocoder:
      def __init__(self, vocoder_type, vocoder_ckpt, device = None):
          if device is None:
modules/DSConv.py ADDED
@@ -0,0 +1,76 @@
+ import torch.nn as nn
+ from torch.nn.utils import remove_weight_norm, weight_norm
+
+
+ class Depthwise_Separable_Conv1D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride = 1,
+         padding = 0,
+         dilation = 1,
+         bias = True,
+         padding_mode = 'zeros',  # TODO: refine this type
+         device=None,
+         dtype=None
+     ):
+         super().__init__()
+         self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
+         self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device, dtype=dtype)
+
+     def forward(self, input):
+         return self.point_conv(self.depth_conv(input))
+
+     def weight_norm(self):
+         self.depth_conv = weight_norm(self.depth_conv, name='weight')
+         self.point_conv = weight_norm(self.point_conv, name='weight')
+
+     def remove_weight_norm(self):
+         self.depth_conv = remove_weight_norm(self.depth_conv, name='weight')
+         self.point_conv = remove_weight_norm(self.point_conv, name='weight')
+
+
+ class Depthwise_Separable_TransposeConv1D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride = 1,
+         padding = 0,
+         output_padding = 0,
+         bias = True,
+         dilation = 1,
+         padding_mode = 'zeros',  # TODO: refine this type
+         device=None,
+         dtype=None
+     ):
+         super().__init__()
+         self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, stride=stride, output_padding=output_padding, padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
+         self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device, dtype=dtype)
+
+     def forward(self, input):
+         return self.point_conv(self.depth_conv(input))
+
+     def weight_norm(self):
+         self.depth_conv = weight_norm(self.depth_conv, name='weight')
+         self.point_conv = weight_norm(self.point_conv, name='weight')
+
+     def remove_weight_norm(self):
+         remove_weight_norm(self.depth_conv, name='weight')
+         remove_weight_norm(self.point_conv, name='weight')
+
+
+ def weight_norm_modules(module, name='weight', dim=0):
+     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
+         module.weight_norm()
+         return module
+     else:
+         return weight_norm(module, name, dim)
+
+
+ def remove_weight_norm_modules(module, name='weight'):
+     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
+         module.remove_weight_norm()
+     else:
+         remove_weight_norm(module, name)
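A usage sketch for the new depthwise-separable modules and their dispatching weight-norm helpers (shapes illustrative):

import torch

conv = Depthwise_Separable_Conv1D(in_channels=64, out_channels=128, kernel_size=3, padding=1)
conv = weight_norm_modules(conv)       # dispatches to conv.weight_norm() for DSConv modules
x = torch.randn(2, 64, 100)            # (batch, channels, frames)
y = conv(x)
assert y.shape == (2, 128, 100)        # depthwise conv keeps length; pointwise conv remaps channels
remove_weight_norm_modules(conv)       # strip the reparametrization, e.g. before export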
modules/F0Predictor/CrepeF0Predictor.py CHANGED
@@ -1,7 +1,9 @@
- from modules.F0Predictor.F0Predictor import F0Predictor
- from modules.F0Predictor.crepe import CrepePitchExtractor
  import torch
+
+ from modules.F0Predictor.crepe import CrepePitchExtractor
+ from modules.F0Predictor.F0Predictor import F0Predictor
+
+
  class CrepeF0Predictor(F0Predictor):
      def __init__(self,hop_length=512,f0_min=50,f0_max=1100,device=None,sampling_rate=44100,threshold=0.05,model="full"):
          self.F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=threshold,model=model)
modules/F0Predictor/DioF0Predictor.py CHANGED
@@ -1,6 +1,8 @@
-from modules.F0Predictor.F0Predictor import F0Predictor
-import pyworld
 import numpy as np
+import pyworld
+
+from modules.F0Predictor.F0Predictor import F0Predictor
+
 
 class DioF0Predictor(F0Predictor):
     def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
@@ -13,39 +15,25 @@ class DioF0Predictor(F0Predictor):
         '''
         Interpolate F0
         '''
+        vuv_vector = np.zeros_like(f0, dtype=np.float32)
+        vuv_vector[f0 > 0.0] = 1.0
+        vuv_vector[f0 <= 0.0] = 0.0
 
-        data = np.reshape(f0, (f0.size, 1))
-
-        vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
-        vuv_vector[data > 0.0] = 1.0
-        vuv_vector[data <= 0.0] = 0.0
-
-        ip_data = data
-
-        frame_number = data.size
-        last_value = 0.0
-        for i in range(frame_number):
-            if data[i] <= 0.0:
-                j = i + 1
-                for j in range(i + 1, frame_number):
-                    if data[j] > 0.0:
-                        break
-                if j < frame_number - 1:
-                    if last_value > 0.0:
-                        step = (data[j] - data[i - 1]) / float(j - i)
-                        for k in range(i, j):
-                            ip_data[k] = data[i - 1] + step * (k - i + 1)
-                    else:
-                        for k in range(i, j):
-                            ip_data[k] = data[j]
-                else:
-                    for k in range(i, frame_number):
-                        ip_data[k] = last_value
-            else:
-                ip_data[i] = data[i]  # this copy may be unnecessary
-                last_value = data[i]
-
-        return ip_data[:,0], vuv_vector[:,0]
+        nzindex = np.nonzero(f0)[0]
+        data = f0[nzindex]
+        nzindex = nzindex.astype(np.float32)
+        time_org = self.hop_length / self.sampling_rate * nzindex
+        time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
+
+        if data.shape[0] <= 0:
+            return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
+
+        if data.shape[0] == 1:
+            return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
+
+        f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
+
+        return f0,vuv_vector
 
     def resize_f0(self,x, target_len):
         source = np.array(x)
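
The rewrite above (mirrored in HarvestF0Predictor and PMF0Predictor below) replaces the frame-by-frame gap-filling loop with a single np.interp over the timestamps of the voiced frames, trading a nested Python loop for one vectorized NumPy call. A standalone sketch of the new behaviour, with illustrative hop/sample-rate values and a toy F0 array:

import numpy as np

hop_length, sampling_rate = 512, 44100
f0 = np.array([0.0, 220.0, 0.0, 0.0, 233.1, 0.0], dtype=np.float32)

vuv = (f0 > 0.0).astype(np.float32)            # voiced/unvoiced flags

nz = np.nonzero(f0)[0]                         # indices of voiced frames
time_org = hop_length / sampling_rate * nz     # their timestamps
time_frame = np.arange(f0.shape[0]) * hop_length / sampling_rate

# Unvoiced gaps are filled linearly; leading/trailing zeros clamp to the
# nearest voiced value via left=/right=.
filled = np.interp(time_frame, time_org, f0[nz], left=f0[nz][0], right=f0[nz][-1])
print(filled)  # approx [220.0, 220.0, 224.4, 228.7, 233.1, 233.1]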
modules/F0Predictor/HarvestF0Predictor.py CHANGED
@@ -1,6 +1,8 @@
-from modules.F0Predictor.F0Predictor import F0Predictor
-import pyworld
 import numpy as np
+import pyworld
+
+from modules.F0Predictor.F0Predictor import F0Predictor
+
 
 class HarvestF0Predictor(F0Predictor):
     def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
@@ -13,40 +15,25 @@ class HarvestF0Predictor(F0Predictor):
         '''
         Interpolate F0
         '''
+        vuv_vector = np.zeros_like(f0, dtype=np.float32)
+        vuv_vector[f0 > 0.0] = 1.0
+        vuv_vector[f0 <= 0.0] = 0.0
 
-        data = np.reshape(f0, (f0.size, 1))
-
-        vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
-        vuv_vector[data > 0.0] = 1.0
-        vuv_vector[data <= 0.0] = 0.0
-
-        ip_data = data
-
-        frame_number = data.size
-        last_value = 0.0
-        for i in range(frame_number):
-            if data[i] <= 0.0:
-                j = i + 1
-                for j in range(i + 1, frame_number):
-                    if data[j] > 0.0:
-                        break
-                if j < frame_number - 1:
-                    if last_value > 0.0:
-                        step = (data[j] - data[i - 1]) / float(j - i)
-                        for k in range(i, j):
-                            ip_data[k] = data[i - 1] + step * (k - i + 1)
-                    else:
-                        for k in range(i, j):
-                            ip_data[k] = data[j]
-                else:
-                    for k in range(i, frame_number):
-                        ip_data[k] = last_value
-            else:
-                ip_data[i] = data[i]  # this copy may be unnecessary
-                last_value = data[i]
-
-        return ip_data[:,0], vuv_vector[:,0]
+        nzindex = np.nonzero(f0)[0]
+        data = f0[nzindex]
+        nzindex = nzindex.astype(np.float32)
+        time_org = self.hop_length / self.sampling_rate * nzindex
+        time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
 
+        if data.shape[0] <= 0:
+            return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
+
+        if data.shape[0] == 1:
+            return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
+
+        f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
+
+        return f0,vuv_vector
     def resize_f0(self,x, target_len):
         source = np.array(x)
         source[source<0.001] = np.nan
modules/F0Predictor/PMF0Predictor.py CHANGED
@@ -1,6 +1,8 @@
-from modules.F0Predictor.F0Predictor import F0Predictor
-import parselmouth
 import numpy as np
+import parselmouth
+
+from modules.F0Predictor.F0Predictor import F0Predictor
+
 
 class PMF0Predictor(F0Predictor):
     def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
@@ -14,39 +16,26 @@ class PMF0Predictor(F0Predictor):
         '''
         Interpolate F0
         '''
+        vuv_vector = np.zeros_like(f0, dtype=np.float32)
+        vuv_vector[f0 > 0.0] = 1.0
+        vuv_vector[f0 <= 0.0] = 0.0
 
-        data = np.reshape(f0, (f0.size, 1))
-
-        vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
-        vuv_vector[data > 0.0] = 1.0
-        vuv_vector[data <= 0.0] = 0.0
-
-        ip_data = data
-
-        frame_number = data.size
-        last_value = 0.0
-        for i in range(frame_number):
-            if data[i] <= 0.0:
-                j = i + 1
-                for j in range(i + 1, frame_number):
-                    if data[j] > 0.0:
-                        break
-                if j < frame_number - 1:
-                    if last_value > 0.0:
-                        step = (data[j] - data[i - 1]) / float(j - i)
-                        for k in range(i, j):
-                            ip_data[k] = data[i - 1] + step * (k - i + 1)
-                    else:
-                        for k in range(i, j):
-                            ip_data[k] = data[j]
-                else:
-                    for k in range(i, frame_number):
-                        ip_data[k] = last_value
-            else:
-                ip_data[i] = data[i]  # this copy may be unnecessary
-                last_value = data[i]
+        nzindex = np.nonzero(f0)[0]
+        data = f0[nzindex]
+        nzindex = nzindex.astype(np.float32)
+        time_org = self.hop_length / self.sampling_rate * nzindex
+        time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
+
+        if data.shape[0] <= 0:
+            return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
+
+        if data.shape[0] == 1:
+            return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
+
+        f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
+
+        return f0,vuv_vector
 
-        return ip_data[:,0], vuv_vector[:,0]
 
     def compute_f0(self,wav,p_len=None):
         x = wav
modules/F0Predictor/RMVPEF0Predictor.py ADDED
@@ -0,0 +1,106 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from modules.F0Predictor.F0Predictor import F0Predictor
+
+from .rmvpe import RMVPE
+
+
+class RMVPEF0Predictor(F0Predictor):
+    def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05):
+        self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device)
+        self.hop_length = hop_length
+        self.f0_min = f0_min
+        self.f0_max = f0_max
+        if device is None:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        else:
+            self.device = device
+        self.threshold = threshold
+        self.sampling_rate = sampling_rate
+        self.dtype = dtype
+
+    def repeat_expand(
+        self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
+    ):
+        ndim = content.ndim
+
+        if content.ndim == 1:
+            content = content[None, None]
+        elif content.ndim == 2:
+            content = content[None]
+
+        assert content.ndim == 3
+
+        is_np = isinstance(content, np.ndarray)
+        if is_np:
+            content = torch.from_numpy(content)
+
+        results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
+
+        if is_np:
+            results = results.numpy()
+
+        if ndim == 1:
+            return results[0, 0]
+        elif ndim == 2:
+            return results[0]
+
+    def post_process(self, x, sampling_rate, f0, pad_to):
+        if isinstance(f0, np.ndarray):
+            f0 = torch.from_numpy(f0).float().to(x.device)
+
+        if pad_to is None:
+            return f0
+
+        f0 = self.repeat_expand(f0, pad_to)
+
+        vuv_vector = torch.zeros_like(f0)
+        vuv_vector[f0 > 0.0] = 1.0
+        vuv_vector[f0 <= 0.0] = 0.0
+
+        # strip zero-frequency frames, then interpolate linearly
+        nzindex = torch.nonzero(f0).squeeze()
+        f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
+        time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
+        time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
+
+        vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
+
+        if f0.shape[0] <= 0:
+            return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy()
+        if f0.shape[0] == 1:
+            return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy()
+
+        # could probably be rewritten in torch?
+        f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+        #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
+
+        return f0,vuv_vector.cpu().numpy()
+
+    def compute_f0(self,wav,p_len=None):
+        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+        if p_len is None:
+            p_len = x.shape[0]//self.hop_length
+        else:
+            assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+        f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
+        if torch.all(f0 == 0):
+            rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+            return rtn,rtn
+        return self.post_process(x,self.sampling_rate,f0,p_len)[0]
+
+    def compute_f0_uv(self,wav,p_len=None):
+        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+        if p_len is None:
+            p_len = x.shape[0]//self.hop_length
+        else:
+            assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+        f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
+        if torch.all(f0 == 0):
+            rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+            return rtn,rtn
+        return self.post_process(x,self.sampling_rate,f0,p_len)
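
Usage of the new predictor mirrors the other F0 predictors: construct it, then call compute_f0_uv on a mono waveform. A hypothetical sketch, assuming the RMVPE checkpoint is present at pretrain/rmvpe.pt (the path hardcoded in __init__ above) and that soundfile is available for loading the audio:

import soundfile as sf  # assumed audio loader; any mono float array works

from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor

wav, sr = sf.read("example.wav")   # hypothetical 44.1 kHz mono clip
predictor = RMVPEF0Predictor(hop_length=512, sampling_rate=sr, device="cpu")

f0, uv = predictor.compute_f0_uv(wav)  # per-frame F0 plus voiced flags
print(f0.shape, uv.shape)              # both len(wav) // hop_length frames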
modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc CHANGED
Binary files a/modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc differ
 
modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc CHANGED
Binary files a/modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc differ
 
modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc CHANGED
Binary files a/modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc differ
 
modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc CHANGED
Binary files a/modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc differ
 
modules/F0Predictor/__pycache__/RMVPEF0Predictor.cpython-38.pyc ADDED
Binary file (3.28 kB). View file
 
modules/F0Predictor/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/modules/F0Predictor/__pycache__/__init__.cpython-38.pyc and b/modules/F0Predictor/__pycache__/__init__.cpython-38.pyc differ
 
modules/F0Predictor/__pycache__/crepe.cpython-38.pyc CHANGED
Binary files a/modules/F0Predictor/__pycache__/crepe.cpython-38.pyc and b/modules/F0Predictor/__pycache__/crepe.cpython-38.pyc differ
 
modules/F0Predictor/crepe.py CHANGED
@@ -1,14 +1,14 @@
-from typing import Optional,Union
+from typing import Optional, Union
+
 try:
     from typing import Literal
-except Exception as e:
+except Exception:
     from typing_extensions import Literal
 import numpy as np
 import torch
 import torchcrepe
 from torch import nn
 from torch.nn import functional as F
-import scipy
 
 #from:https://github.com/fishaudio/fish-diffusion
 
@@ -97,19 +97,19 @@ class BasePitchExtractor:
         f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
         time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
         time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
+
+        vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
 
         if f0.shape[0] <= 0:
-            return torch.zeros(pad_to, dtype=torch.float, device=x.device),torch.zeros(pad_to, dtype=torch.float, device=x.device)
-
+            return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy()
         if f0.shape[0] == 1:
-            return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],torch.ones(pad_to, dtype=torch.float, device=x.device)
+            return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy()
 
         # could probably be rewritten in torch?
         f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
-        vuv_vector = vuv_vector.cpu().numpy()
-        vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
+        #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
 
-        return f0,vuv_vector
+        return f0,vuv_vector.cpu().numpy()
 
 
 class MaskedAvgPool1d(nn.Module):
@@ -323,7 +323,7 @@ class CrepePitchExtractor(BasePitchExtractor):
         else:
             pd = torchcrepe.filter.median(pd, 3)
 
-        pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, 512)
+        pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, self.hop_length)
         f0 = torchcrepe.threshold.At(self.threshold)(f0, pd)
 
         if self.use_fast_filters:
@@ -334,7 +334,7 @@ class CrepePitchExtractor(BasePitchExtractor):
         f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0]
 
         if torch.all(f0 == 0):
-            rtn = f0.cpu().numpy() if pad_to==None else np.zeros(pad_to)
+            rtn = f0.cpu().numpy() if pad_to is None else np.zeros(pad_to)
             return rtn,rtn
 
         return self.post_process(x, sampling_rate, f0, pad_to)
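
In BasePitchExtractor.post_process the voiced/unvoiced vector is now resized with F.interpolate (default nearest mode) instead of scipy.ndimage.zoom, which is what lets the scipy import above be dropped; the old call stays behind as a comment. A minimal sketch of the replacement on a toy tensor:

import torch
import torch.nn.functional as F

vuv = torch.tensor([1.0, 1.0, 0.0, 1.0])  # toy voiced/unvoiced flags
pad_to = 8

# interpolate expects (batch, channel, length), hence the [None, None, :]
resized = F.interpolate(vuv[None, None, :], size=pad_to)[0][0]
print(resized)  # tensor([1., 1., 1., 1., 0., 0., 1., 1.]): each flag repeated twice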
modules/F0Predictor/rmvpe/__init__.py ADDED
@@ -0,0 +1,10 @@
+from .constants import *  # noqa: F403
+from .inference import RMVPE  # noqa: F401
+from .model import E2E, E2E0  # noqa: F401
+from .spec import MelSpectrogram  # noqa: F401
+from .utils import (  # noqa: F401
+    cycle,
+    summary,
+    to_local_average_cents,
+    to_viterbi_cents,
+)
modules/F0Predictor/rmvpe/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (416 Bytes). View file
 
modules/F0Predictor/rmvpe/__pycache__/constants.cpython-38.pyc ADDED
Binary file (300 Bytes). View file
 
modules/F0Predictor/rmvpe/__pycache__/deepunet.cpython-38.pyc ADDED
Binary file (6.86 kB). View file
 
modules/F0Predictor/rmvpe/__pycache__/inference.cpython-38.pyc ADDED
Binary file (2.54 kB). View file
 
modules/F0Predictor/rmvpe/__pycache__/model.cpython-38.pyc ADDED
Binary file (2.41 kB). View file
 
modules/F0Predictor/rmvpe/__pycache__/seq.cpython-38.pyc ADDED
Binary file (1.18 kB). View file