Spaces:
Runtime error
Runtime error
Upload vocoder/distribution.py with huggingface_hub
Browse files- vocoder/distribution.py +132 -0
vocoder/distribution.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def log_sum_exp(x):
    """Numerically stable log-sum-exp over the last axis of ``x``.

    The per-row maximum is subtracted before exponentiating so that large
    inputs do not overflow, and added back after the log.

    Args:
        x (Tensor): input tensor; the reduction runs over its last axis.

    Returns:
        Tensor: ``log(sum(exp(x)))`` along the last axis, with that axis
        removed.
    """
    # TF ordering: the reduced axis is the trailing one.
    axis = x.dim() - 1
    peak, _ = torch.max(x, dim=axis, keepdim=True)
    # An infinite maximum would turn ``x - peak`` into inf - inf = NaN.
    # Shifting by zero in that case lets exp/log yield the correct +/-inf.
    shift = torch.where(peak.abs() == float("inf"),
                        torch.zeros_like(peak), peak)
    return shift.squeeze(axis) + torch.log(
        torch.sum(torch.exp(x - shift), dim=axis))
|
13 |
+
|
14 |
+
|
15 |
+
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
|
16 |
+
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
                                  log_scale_min=None, reduce=True):
    """Negative log-likelihood of ``y`` under a discretized logistic mixture.

    Args:
        y_hat (Tensor): mixture parameters, B x T x C with
            C = 3 * num_mixtures. The last axis is laid out as
            [mixture logits | means | log-scales].
        y (Tensor): targets, B x T x 1; expanded to B x T x num_mixtures.
            The edge-bin thresholds (+/-0.999) treat values beyond them as
            the first / last quantization bin.
        num_classes (int): number of quantization levels; each bin has
            half-width ``1 / (num_classes - 1)``.
        log_scale_min (float): lower clamp applied to the log-scales.
            Defaults to ``log(1e-14)``.
        reduce (bool): if True return the mean loss as a scalar, otherwise
            the per-step losses.

    Returns:
        Tensor: scalar mean loss if ``reduce`` else a B x T x 1 tensor.

    Raises:
        AssertionError: if ``y_hat`` is not 3-D or its last axis is not a
            multiple of 3.
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    # Validate before touching the data so a malformed input fails here with
    # a clear assertion rather than inside a tensor op.
    assert y_hat.dim() == 3
    assert y_hat.size(-1) % 3 == 0
    nr_mix = y_hat.size(-1) // 3

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix],
                             min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    half_bin = 1. / (num_classes - 1)
    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + half_bin)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - half_bin)
    cdf_min = torch.sigmoid(min_in)

    # log probability of the lowest bin (everything below its upper edge);
    # equivalent to log(sigmoid(plus_in)) but stable for large |plus_in|
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability of the highest bin (everything above its lower edge);
    # equivalent to log(1 - sigmoid(min_in))
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability mass of an interior bin
    cdf_delta = cdf_plus - cdf_min

    # log density at the bin centre; used as a fallback when the bin mass is
    # too small to take the log of reliably
    mid_in = inv_stdv * centered_y
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # tf equivalent:
    #   log_probs = tf.where(x < -0.999, log_cdf_plus,
    #                        tf.where(x > 0.999, log_one_minus_cdf_min,
    #                                 tf.where(cdf_delta > 1e-5,
    #                                          tf.log(tf.maximum(cdf_delta, 1e-12)),
    #                                          log_pdf_mid - np.log(127.5))))
    #
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    #
    # torch.where selects element-wise, so a non-finite value in a branch
    # that is not taken cannot leak into the result (0 * inf would be NaN).
    interior = torch.where(
        cdf_delta > 1e-5,
        torch.log(torch.clamp(cdf_delta, min=1e-12)),
        log_pdf_mid - float(np.log((num_classes - 1) / 2)))
    upper = torch.where(y > 0.999, log_one_minus_cdf_min, interior)
    log_probs = torch.where(y < -0.999, log_cdf_plus, upper)

    # weight each component by its (log) mixture probability
    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    # log-sum-exp over the mixture axis, computed stably by subtracting the
    # per-step maximum before exponentiating
    peak, _ = torch.max(log_probs, dim=-1, keepdim=True)
    log_likelihood = peak.squeeze(-1) + torch.log(
        torch.sum(torch.exp(log_probs - peak), dim=-1))

    if reduce:
        return -torch.mean(log_likelihood)
    return -log_likelihood.unsqueeze(-1)
|
85 |
+
|
86 |
+
|
87 |
+
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions

    Args:
        y (Tensor): B x C x T, with C = 3 * num_mixtures laid out along the
            channel axis as [mixture logits | means | log-scales]
        log_scale_min (float): Log scale minimum value. Defaults to
            ``log(1e-14)``.

    Returns:
        Tensor: B x T sample, clipped to the range [-1, 1].

    Raises:
        AssertionError: if the channel axis is not a multiple of 3.
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample the mixture indicator from the softmax: add Gumbel noise to the
    # logits and take the arg-max. The uniform draw is kept away from 0 and 1
    # so that both logs stay finite.
    noise = logit_probs.new_empty(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    perturbed = logit_probs.detach() - torch.log(-torch.log(noise))
    _, argmax = perturbed.max(dim=-1)

    # select the chosen component's parameters. gather reads only the chosen
    # entry, so a non-finite value in an unselected component cannot
    # contaminate the result (a one-hot multiply would give 0 * inf = NaN).
    index = argmax.unsqueeze(-1)
    means = torch.gather(y[:, :, nr_mix:2 * nr_mix], -1, index).squeeze(-1)
    log_scales = torch.clamp(
        torch.gather(y[:, :, 2 * nr_mix:3 * nr_mix], -1, index).squeeze(-1),
        min=log_scale_min)

    # sample from the logistic via its inverse CDF & clip to the interval;
    # we don't actually round to the nearest quantization level when sampling
    u = means.new_empty(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    return torch.clamp(x, min=-1., max=1.)
|
124 |
+
|
125 |
+
|
126 |
+
def to_one_hot(tensor, n, fill_with=1.):
    """One-hot encode integer indices along a new trailing axis.

    Args:
        tensor (Tensor): index tensor of any shape; it is passed to
            ``scatter_`` as the index, so it must be an integer (int64)
            tensor with values in ``[0, n)``.
        n (int): size of the new trailing axis (number of classes).
        fill_with (float): value written at each selected position.

    Returns:
        Tensor: float32 tensor of shape ``tensor.size() + (n,)`` on the same
        device as ``tensor``; zero everywhere except ``fill_with`` at the
        indexed positions.
    """
    # we perform one hot encoding with respect to the last axis.
    # Allocate directly on the index tensor's device: this avoids a CPU
    # allocation plus copy, and keeps scatter_ working when the indices live
    # on a CUDA device other than the current default one.
    one_hot = torch.zeros(tuple(tensor.size()) + (n,),
                          dtype=torch.float32, device=tensor.device)
    one_hot.scatter_(tensor.dim(), tensor.unsqueeze(-1), fill_with)
    return one_hot
|