keithhon committed on
Commit
fc23e1a
1 Parent(s): 3faed8a

Upload vocoder/distribution.py with huggingface_hub

Files changed (1)
  1. vocoder/distribution.py +132 -0
vocoder/distribution.py ADDED
@@ -0,0 +1,132 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+
+ def log_sum_exp(x):
+     """ numerically stable log_sum_exp implementation that prevents overflow """
+     # TF ordering
+     axis = len(x.size()) - 1
+     m, _ = torch.max(x, dim=axis)
+     m2, _ = torch.max(x, dim=axis, keepdim=True)
+     return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
+
+
+ # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
+ def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
+                                   log_scale_min=None, reduce=True):
+     if log_scale_min is None:
+         log_scale_min = float(np.log(1e-14))
+     y_hat = y_hat.permute(0,2,1)
+     assert y_hat.dim() == 3
+     assert y_hat.size(1) % 3 == 0
+     nr_mix = y_hat.size(1) // 3
+
+     # (B x T x C)
+     y_hat = y_hat.transpose(1, 2)
+
+     # unpack parameters. (B, T, num_mixtures) x 3
+     logit_probs = y_hat[:, :, :nr_mix]
+     means = y_hat[:, :, nr_mix:2 * nr_mix]
+     log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
+
+     # B x T x 1 -> B x T x num_mixtures
+     y = y.expand_as(means)
+
+     centered_y = y - means
+     inv_stdv = torch.exp(-log_scales)
+     plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
+     cdf_plus = torch.sigmoid(plus_in)
+     min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
+     cdf_min = torch.sigmoid(min_in)
+
+     # log probability for edge case of 0 (before scaling)
+     # equivalent: torch.log(F.sigmoid(plus_in))
+     log_cdf_plus = plus_in - F.softplus(plus_in)
+
+     # log probability for edge case of 255 (before scaling)
+     # equivalent: (1 - F.sigmoid(min_in)).log()
+     log_one_minus_cdf_min = -F.softplus(min_in)
+
+     # probability for all other cases
+     cdf_delta = cdf_plus - cdf_min
+
+     mid_in = inv_stdv * centered_y
+     # log probability in the center of the bin, to be used in extreme cases
+     # (not actually used in our code)
+     log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
+
+     # tf equivalent
+     """
+     log_probs = tf.where(x < -0.999, log_cdf_plus,
+                          tf.where(x > 0.999, log_one_minus_cdf_min,
+                                   tf.where(cdf_delta > 1e-5,
+                                            tf.log(tf.maximum(cdf_delta, 1e-12)),
+                                            log_pdf_mid - np.log(127.5))))
+     """
+     # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
+     # for num_classes=65536 case? 1e-7? not sure..
+     inner_inner_cond = (cdf_delta > 1e-5).float()
+
+     inner_inner_out = inner_inner_cond * \
+         torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
+         (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
+     inner_cond = (y > 0.999).float()
+     inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
+     cond = (y < -0.999).float()
+     log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
+
+     log_probs = log_probs + F.log_softmax(logit_probs, -1)
+
+     if reduce:
+         return -torch.mean(log_sum_exp(log_probs))
+     else:
+         return -log_sum_exp(log_probs).unsqueeze(-1)
+
+
+ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
+     """
+     Sample from discretized mixture of logistic distributions
+     Args:
+         y (Tensor): B x C x T
+         log_scale_min (float): Log scale minimum value
+     Returns:
+         Tensor: sample in range of [-1, 1].
+     """
+     if log_scale_min is None:
+         log_scale_min = float(np.log(1e-14))
+     assert y.size(1) % 3 == 0
+     nr_mix = y.size(1) // 3
+
+     # B x T x C
+     y = y.transpose(1, 2)
+     logit_probs = y[:, :, :nr_mix]
+
+     # sample mixture indicator from softmax
+     temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
+     temp = logit_probs.data - torch.log(- torch.log(temp))
+     _, argmax = temp.max(dim=-1)
+
+     # (B, T) -> (B, T, nr_mix)
+     one_hot = to_one_hot(argmax, nr_mix)
+     # select logistic parameters
+     means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
+     log_scales = torch.clamp(torch.sum(
+         y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
+     # sample from logistic & clip to interval
+     # we don't actually round to the nearest 8bit value when sampling
+     u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
+     x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
+
+     x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
+
+     return x
+
+
+ def to_one_hot(tensor, n, fill_with=1.):
+     # we perform one-hot encoding with respect to the last axis
+     one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
+     if tensor.is_cuda:
+         one_hot = one_hot.cuda()
+     one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
+     return one_hot
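
For reference, a minimal usage sketch (not part of the uploaded file). The sizes B, T and nr_mix are made up, and the dummy tensors stand in for a real vocoder output and target waveform; the import path assumes the repo root is on the Python path. Because of the initial permute, the loss expects y_hat as (B, T, 3 * nr_mix), with the channels split into mixture logits, means and log scales, and the target y as (B, T, 1) in [-1, 1]; sampling instead expects channels-first input (B, 3 * nr_mix, T), as its docstring states.

    import torch
    from vocoder.distribution import (discretized_mix_logistic_loss,
                                      sample_from_discretized_mix_logistic)

    B, T, nr_mix = 2, 100, 10                      # hypothetical batch, time steps, mixtures

    y_hat = torch.randn(B, T, 3 * nr_mix)          # dummy model output: logits | means | log scales
    y = torch.empty(B, T, 1).uniform_(-1.0, 1.0)   # dummy target waveform in [-1, 1]

    loss = discretized_mix_logistic_loss(y_hat, y)                    # scalar (reduce=True)
    per_step = discretized_mix_logistic_loss(y_hat, y, reduce=False)  # (B, T, 1)

    # Sampling takes channels-first input (B, C, T), hence the transpose.
    x = sample_from_discretized_mix_logistic(y_hat.transpose(1, 2))   # (B, T), values in [-1, 1]

In the sampling function, the line `temp = logit_probs.data - torch.log(- torch.log(temp))` is the Gumbel-max trick: adding Gumbel noise to the logits and taking the argmax draws a mixture index from the softmax distribution without computing the probabilities explicitly.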