Spaces:
Running
Running
Create MMD.py
Browse files
MMD.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def flexible_kernel(X, Y, X_org, Y_org, sigma, sigma0=0.1, epsilon=1e-08):
|
5 |
+
"""Flexible kernel calculation as in MMDu."""
|
6 |
+
Dxy = Pdist2(X, Y)
|
7 |
+
Dxy_org = Pdist2(X_org, Y_org)
|
8 |
+
L = 1
|
9 |
+
Kxy = (1 - epsilon) * torch.exp(
|
10 |
+
-((Dxy / sigma0) ** L) - Dxy_org / sigma
|
11 |
+
) + epsilon * torch.exp(-Dxy_org / sigma)
|
12 |
+
return Kxy
|
13 |
+
|
14 |
+
|
15 |
+
def MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon=1e-08):
|
16 |
+
"""Compute the variance of the difference statistic MMDXY - MMDXZ."""
|
17 |
+
"""Referenced from: https://github.com/eugenium/MMD/blob/master/mmd.py"""
|
18 |
+
m = Kxy.shape[0]
|
19 |
+
n = Kyy.shape[0]
|
20 |
+
r = Kzz.shape[0]
|
21 |
+
|
22 |
+
# Remove diagonal elements
|
23 |
+
Kyynd = Kyy - torch.diag(torch.diag(Kyy))
|
24 |
+
Kzznd = Kzz - torch.diag(torch.diag(Kzz))
|
25 |
+
|
26 |
+
u_yy = torch.sum(Kyynd) * (1.0 / (n * (n - 1)))
|
27 |
+
u_zz = torch.sum(Kzznd) * (1.0 / (r * (r - 1)))
|
28 |
+
u_xy = torch.sum(Kxy) / (m * n)
|
29 |
+
u_xz = torch.sum(Kxz) / (m * r)
|
30 |
+
|
31 |
+
t1 = (1.0 / n**3) * torch.sum(Kyynd.T @ Kyynd) - u_yy**2
|
32 |
+
t2 = (1.0 / (n**2 * m)) * torch.sum(Kxy.T @ Kxy) - u_xy**2
|
33 |
+
t3 = (1.0 / (n * m**2)) * torch.sum(Kxy @ Kxy.T) - u_xy**2
|
34 |
+
t4 = (1.0 / r**3) * torch.sum(Kzznd.T @ Kzznd) - u_zz**2
|
35 |
+
t5 = (1.0 / (r * m**2)) * torch.sum(Kxz @ Kxz.T) - u_xz**2
|
36 |
+
t6 = (1.0 / (r**2 * m)) * torch.sum(Kxz.T @ Kxz) - u_xz**2
|
37 |
+
t7 = (1.0 / (n**2 * m)) * torch.sum(Kyynd @ Kxy.T) - u_yy * u_xy
|
38 |
+
t8 = (1.0 / (n * m * r)) * torch.sum(Kxy.T @ Kxz) - u_xz * u_xy
|
39 |
+
t9 = (1.0 / (r**2 * m)) * torch.sum(Kzznd @ Kxz.T) - u_zz * u_xz
|
40 |
+
|
41 |
+
if type(epsilon) == torch.Tensor:
|
42 |
+
epsilon_tensor = epsilon.clone().detach()
|
43 |
+
else:
|
44 |
+
epsilon_tensor = torch.tensor(epsilon, device=Kyy.device)
|
45 |
+
zeta1 = torch.max(t1 + t2 + t3 + t4 + t5 + t6 - 2 * (t7 + t8 + t9), epsilon_tensor)
|
46 |
+
zeta2 = torch.max(
|
47 |
+
(1 / m / (m - 1)) * torch.sum((Kyynd - Kzznd - Kxy.T - Kxy + Kxz + Kxz.T) ** 2)
|
48 |
+
- (u_yy - 2 * u_xy - (u_zz - 2 * u_xz)) ** 2,
|
49 |
+
epsilon_tensor,
|
50 |
+
)
|
51 |
+
|
52 |
+
data = {
|
53 |
+
"t1": t1.item(),
|
54 |
+
"t2": t2.item(),
|
55 |
+
"t3": t3.item(),
|
56 |
+
"t4": t4.item(),
|
57 |
+
"t5": t5.item(),
|
58 |
+
"t6": t6.item(),
|
59 |
+
"t7": t7.item(),
|
60 |
+
"t8": t8.item(),
|
61 |
+
"t9": t9.item(),
|
62 |
+
"zeta1": zeta1.item(),
|
63 |
+
"zeta2": zeta2.item(),
|
64 |
+
}
|
65 |
+
|
66 |
+
Var = (4 * (m - 2) / (m * (m - 1))) * zeta1
|
67 |
+
Var_z2 = Var + (2.0 / (m * (m - 1))) * zeta2
|
68 |
+
|
69 |
+
return Var, Var_z2, data
|
70 |
+
|
71 |
+
|
72 |
+
def Pdist2(x, y):
|
73 |
+
"""compute the paired distance between x and y."""
|
74 |
+
x_norm = (x**2).sum(1).view(-1, 1)
|
75 |
+
if y is not None:
|
76 |
+
y_norm = (y**2).sum(1).view(1, -1)
|
77 |
+
else:
|
78 |
+
y = x
|
79 |
+
y_norm = x_norm.view(1, -1)
|
80 |
+
Pdist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
|
81 |
+
Pdist[Pdist < 0] = 0
|
82 |
+
return Pdist
|
83 |
+
|
84 |
+
|
85 |
+
def MMD_batch2(
|
86 |
+
Fea,
|
87 |
+
len_s,
|
88 |
+
Fea_org,
|
89 |
+
sigma,
|
90 |
+
sigma0=0.1,
|
91 |
+
epsilon=10 ** (-10),
|
92 |
+
is_var_computed=True,
|
93 |
+
use_1sample_U=True,
|
94 |
+
coeff_xy=2,
|
95 |
+
):
|
96 |
+
X = Fea[0:len_s, :]
|
97 |
+
Y = Fea[len_s:, :]
|
98 |
+
L = 1 # generalized Gaussian (if L>1)
|
99 |
+
|
100 |
+
nx = X.shape[0]
|
101 |
+
ny = Y.shape[0]
|
102 |
+
Dxx = Pdist2(X, X)
|
103 |
+
Dyy = torch.zeros(Fea.shape[0] - len_s, 1).to(Dxx.device)
|
104 |
+
# Dyy = Pdist2(Y, Y)
|
105 |
+
Dxy = Pdist2(X, Y).transpose(0, 1)
|
106 |
+
Kx = torch.exp(-Dxx / sigma0)
|
107 |
+
Ky = torch.exp(-Dyy / sigma0)
|
108 |
+
Kxy = torch.exp(-Dxy / sigma0)
|
109 |
+
|
110 |
+
nx = Kx.shape[0]
|
111 |
+
|
112 |
+
is_unbiased = False
|
113 |
+
xx = torch.div((torch.sum(Kx)), (nx * nx))
|
114 |
+
yy = Ky.reshape(-1)
|
115 |
+
xy = torch.div(torch.sum(Kxy, dim=1), (nx))
|
116 |
+
|
117 |
+
mmd2 = xx - 2 * xy + yy
|
118 |
+
return mmd2
|
119 |
+
|
120 |
+
|
121 |
+
# MMD for three samples
|
122 |
+
def MMD_3_Sample_Test(
|
123 |
+
ref_fea,
|
124 |
+
fea_y,
|
125 |
+
fea_z,
|
126 |
+
ref_fea_org,
|
127 |
+
fea_y_org,
|
128 |
+
fea_z_org,
|
129 |
+
sigma,
|
130 |
+
sigma0,
|
131 |
+
epsilon,
|
132 |
+
alpha,
|
133 |
+
):
|
134 |
+
"""Run three-sample test (TST) using deep kernel kernel."""
|
135 |
+
X = ref_fea.clone().detach()
|
136 |
+
Y = fea_y.clone().detach()
|
137 |
+
Z = fea_z.clone().detach()
|
138 |
+
X_org = ref_fea_org.clone().detach()
|
139 |
+
Y_org = fea_y_org.clone().detach()
|
140 |
+
Z_org = fea_z_org.clone().detach()
|
141 |
+
|
142 |
+
Kyy = flexible_kernel(Y, Y, Y_org, Y_org, sigma, sigma0, epsilon)
|
143 |
+
Kzz = flexible_kernel(Z, Z, Z_org, Z_org, sigma, sigma0, epsilon)
|
144 |
+
Kxy = flexible_kernel(X, Y, X_org, Y_org, sigma, sigma0, epsilon)
|
145 |
+
Kxz = flexible_kernel(X, Z, X_org, Z_org, sigma, sigma0, epsilon)
|
146 |
+
|
147 |
+
Kyynd = Kyy - torch.diag(torch.diag(Kyy))
|
148 |
+
Kzznd = Kzz - torch.diag(torch.diag(Kzz))
|
149 |
+
|
150 |
+
Diff_Var, _, _ = MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon)
|
151 |
+
|
152 |
+
u_yy = torch.sum(Kyynd) / (Y.shape[0] * (Y.shape[0] - 1))
|
153 |
+
u_zz = torch.sum(Kzznd) / (Z.shape[0] * (Z.shape[0] - 1))
|
154 |
+
u_xy = torch.sum(Kxy) / (X.shape[0] * Y.shape[0])
|
155 |
+
u_xz = torch.sum(Kxz) / (X.shape[0] * Z.shape[0])
|
156 |
+
|
157 |
+
t = u_yy - 2 * u_xy - (u_zz - 2 * u_xz)
|
158 |
+
if Diff_Var.item() <= 0:
|
159 |
+
Diff_Var = torch.max(epsilon, torch.tensor(1e-08))
|
160 |
+
p_value = torch.distributions.Normal(0, 1).cdf(-t / torch.sqrt((Diff_Var)))
|
161 |
+
t = t / torch.sqrt(Diff_Var)
|
162 |
+
|
163 |
+
if p_value > alpha:
|
164 |
+
h = 0
|
165 |
+
else:
|
166 |
+
h = 1
|
167 |
+
|
168 |
+
return h, p_value.item(), t.item()
|