jer233 committed on
Commit
377127e
·
verified ·
1 Parent(s): 2b33c6f

Create MMD.py

Browse files
Files changed (1) hide show
  1. MMD.py +168 -0
MMD.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def flexible_kernel(X, Y, X_org, Y_org, sigma, sigma0=0.1, epsilon=1e-08):
    """Evaluate the flexible kernel (as in MMDu) between two samples.

    The result is an ``epsilon``-weighted mixture of two exponential kernels::

        (1 - epsilon) * exp(-(D / sigma0) ** L - D_org / sigma)
            + epsilon * exp(-D_org / sigma)

    where ``D = Pdist2(X, Y)``, ``D_org = Pdist2(X_org, Y_org)`` and the
    exponent ``L`` is fixed to 1.

    Args:
        X, Y: Inputs whose pairwise distances are scaled by ``sigma0``.
        X_org, Y_org: Inputs whose pairwise distances are scaled by ``sigma``.
        sigma: Bandwidth dividing the ``*_org`` distances.
        sigma0: Bandwidth dividing the ``X``/``Y`` distances.
        epsilon: Mixture weight of the ``*_org``-only kernel.

    Returns:
        The kernel matrix, with the same shape as ``Pdist2(X, Y)``.
    """
    power = 1
    scaled_feature_dist = (Pdist2(X, Y) / sigma0) ** power
    scaled_org_dist = Pdist2(X_org, Y_org) / sigma

    # Product kernel using both distance matrices.
    joint_kernel = torch.exp(-scaled_feature_dist - scaled_org_dist)
    # Kernel on the ``*_org`` inputs alone.
    org_kernel = torch.exp(-scaled_org_dist)

    return (1 - epsilon) * joint_kernel + epsilon * org_kernel
13
+
14
+
15
+ def MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon=1e-08):
16
+ """Compute the variance of the difference statistic MMDXY - MMDXZ."""
17
+ """Referenced from: https://github.com/eugenium/MMD/blob/master/mmd.py"""
18
+ m = Kxy.shape[0]
19
+ n = Kyy.shape[0]
20
+ r = Kzz.shape[0]
21
+
22
+ # Remove diagonal elements
23
+ Kyynd = Kyy - torch.diag(torch.diag(Kyy))
24
+ Kzznd = Kzz - torch.diag(torch.diag(Kzz))
25
+
26
+ u_yy = torch.sum(Kyynd) * (1.0 / (n * (n - 1)))
27
+ u_zz = torch.sum(Kzznd) * (1.0 / (r * (r - 1)))
28
+ u_xy = torch.sum(Kxy) / (m * n)
29
+ u_xz = torch.sum(Kxz) / (m * r)
30
+
31
+ t1 = (1.0 / n**3) * torch.sum(Kyynd.T @ Kyynd) - u_yy**2
32
+ t2 = (1.0 / (n**2 * m)) * torch.sum(Kxy.T @ Kxy) - u_xy**2
33
+ t3 = (1.0 / (n * m**2)) * torch.sum(Kxy @ Kxy.T) - u_xy**2
34
+ t4 = (1.0 / r**3) * torch.sum(Kzznd.T @ Kzznd) - u_zz**2
35
+ t5 = (1.0 / (r * m**2)) * torch.sum(Kxz @ Kxz.T) - u_xz**2
36
+ t6 = (1.0 / (r**2 * m)) * torch.sum(Kxz.T @ Kxz) - u_xz**2
37
+ t7 = (1.0 / (n**2 * m)) * torch.sum(Kyynd @ Kxy.T) - u_yy * u_xy
38
+ t8 = (1.0 / (n * m * r)) * torch.sum(Kxy.T @ Kxz) - u_xz * u_xy
39
+ t9 = (1.0 / (r**2 * m)) * torch.sum(Kzznd @ Kxz.T) - u_zz * u_xz
40
+
41
+ if type(epsilon) == torch.Tensor:
42
+ epsilon_tensor = epsilon.clone().detach()
43
+ else:
44
+ epsilon_tensor = torch.tensor(epsilon, device=Kyy.device)
45
+ zeta1 = torch.max(t1 + t2 + t3 + t4 + t5 + t6 - 2 * (t7 + t8 + t9), epsilon_tensor)
46
+ zeta2 = torch.max(
47
+ (1 / m / (m - 1)) * torch.sum((Kyynd - Kzznd - Kxy.T - Kxy + Kxz + Kxz.T) ** 2)
48
+ - (u_yy - 2 * u_xy - (u_zz - 2 * u_xz)) ** 2,
49
+ epsilon_tensor,
50
+ )
51
+
52
+ data = {
53
+ "t1": t1.item(),
54
+ "t2": t2.item(),
55
+ "t3": t3.item(),
56
+ "t4": t4.item(),
57
+ "t5": t5.item(),
58
+ "t6": t6.item(),
59
+ "t7": t7.item(),
60
+ "t8": t8.item(),
61
+ "t9": t9.item(),
62
+ "zeta1": zeta1.item(),
63
+ "zeta2": zeta2.item(),
64
+ }
65
+
66
+ Var = (4 * (m - 2) / (m * (m - 1))) * zeta1
67
+ Var_z2 = Var + (2.0 / (m * (m - 1))) * zeta2
68
+
69
+ return Var, Var_z2, data
70
+
71
+
72
def Pdist2(x, y):
    """Return the matrix of squared Euclidean distances between rows.

    Entry ``(i, j)`` is ``|x[i]|^2 + |y[j]|^2 - 2 * <x[i], y[j]>``; any
    negative entry is replaced by zero.

    Args:
        x: 2-D tensor whose rows are the first set of points.
        y: 2-D tensor whose rows are the second set of points, or ``None``
            to compare ``x`` with itself.

    Returns:
        Tensor of shape ``(x.shape[0], y.shape[0])``.
    """
    other = x if y is None else y

    row_sq_norms = x.pow(2).sum(dim=1).unsqueeze(1)
    col_sq_norms = other.pow(2).sum(dim=1).unsqueeze(0)
    inner_products = torch.mm(x, other.t())

    sq_dist = row_sq_norms + col_sq_norms - 2.0 * inner_products
    # Zero out negative entries.
    return torch.where(sq_dist < 0, torch.zeros_like(sq_dist), sq_dist)
83
+
84
+
85
def MMD_batch2(
    Fea,
    len_s,
    Fea_org,
    sigma,
    sigma0=0.1,
    epsilon=10 ** (-10),
    is_var_computed=True,
    use_1sample_U=True,
    coeff_xy=2,
):
    """Compute a per-row MMD^2 value between a reference set and single rows.

    The first ``len_s`` rows of ``Fea`` form the reference sample X and the
    remaining rows form Y. Each row of Y is scored on its own: its
    self-distance is taken as zero, so its self-kernel value is
    ``exp(-0 / sigma0)``. All kernels are ``exp(-D / sigma0)`` with
    ``D = Pdist2(...)``.

    Args:
        Fea: 2-D tensor of stacked features; X rows first, then Y rows.
        len_s: Number of leading rows of ``Fea`` that belong to X.
        Fea_org: Accepted for interface compatibility; not used here.
        sigma: Accepted for interface compatibility; not used here.
        sigma0: Bandwidth dividing the pairwise distances.
        epsilon: Accepted for interface compatibility; not used here.
        is_var_computed: Accepted for interface compatibility; not used here.
        use_1sample_U: Accepted for interface compatibility; not used here.
        coeff_xy: Accepted for interface compatibility; not used here.

    Returns:
        1-D tensor with one entry per row of Y, equal to
        ``mean(Kx) - 2 * mean_over_X(Kxy[j]) + Ky[j]``. The X-X mean includes
        the diagonal (it divides by ``nx * nx``).
    """
    X = Fea[0:len_s, :]
    Y = Fea[len_s:, :]

    Dxx = Pdist2(X, X)
    # Transposed so that rows index Y and columns index X.
    Dxy = Pdist2(X, Y).transpose(0, 1)
    # Every Y row is compared only with itself, hence a zero distance per row.
    Dyy = torch.zeros(Fea.shape[0] - len_s, 1, device=Dxx.device)

    Kx = torch.exp(-Dxx / sigma0)
    Ky = torch.exp(-Dyy / sigma0)
    Kxy = torch.exp(-Dxy / sigma0)

    nx = Kx.shape[0]

    xx = torch.div(torch.sum(Kx), nx * nx)
    yy = Ky.reshape(-1)
    xy = torch.div(torch.sum(Kxy, dim=1), nx)

    return xx - 2 * xy + yy
119
+
120
+
121
+ # MMD for three samples
122
def MMD_3_Sample_Test(
    ref_fea,
    fea_y,
    fea_z,
    ref_fea_org,
    fea_y_org,
    fea_z_org,
    sigma,
    sigma0,
    epsilon,
    alpha,
):
    """Run the three-sample MMD test using the flexible (deep) kernel.

    The statistic is ``t = (u_yy - 2 * u_xy) - (u_zz - 2 * u_xz)``, where each
    ``u`` is a mean of ``flexible_kernel`` values (diagonals excluded for the
    within-sample terms). It is normalised by the square root of the variance
    returned by ``MMD_Diff_Var`` and compared against a standard normal.

    Args:
        ref_fea, fea_y, fea_z: Feature tensors for X (reference), Y and Z.
        ref_fea_org, fea_y_org, fea_z_org: The corresponding ``*_org`` tensors
            passed to ``flexible_kernel``.
        sigma: Bandwidth for the ``*_org`` distances.
        sigma0: Bandwidth for the feature distances.
        epsilon: Kernel mixture weight and variance floor; a Python number or
            a tensor.
        alpha: Significance level the p-value is compared against.

    Returns:
        A tuple ``(h, p_value, t)``: ``h`` is 0 if ``p_value > alpha`` and 1
        otherwise, ``p_value`` is the standard normal CDF evaluated at
        ``-t / sqrt(Var)``, and ``t`` is the normalised statistic
        ``t / sqrt(Var)``; the latter two are Python floats.
    """
    # All inputs are copied and detached before use.
    X = ref_fea.clone().detach()
    Y = fea_y.clone().detach()
    Z = fea_z.clone().detach()
    X_org = ref_fea_org.clone().detach()
    Y_org = fea_y_org.clone().detach()
    Z_org = fea_z_org.clone().detach()

    Kyy = flexible_kernel(Y, Y, Y_org, Y_org, sigma, sigma0, epsilon)
    Kzz = flexible_kernel(Z, Z, Z_org, Z_org, sigma, sigma0, epsilon)
    Kxy = flexible_kernel(X, Y, X_org, Y_org, sigma, sigma0, epsilon)
    Kxz = flexible_kernel(X, Z, X_org, Z_org, sigma, sigma0, epsilon)

    Kyynd = Kyy - torch.diag(torch.diag(Kyy))
    Kzznd = Kzz - torch.diag(torch.diag(Kzz))

    Diff_Var, _, _ = MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon)

    u_yy = torch.sum(Kyynd) / (Y.shape[0] * (Y.shape[0] - 1))
    u_zz = torch.sum(Kzznd) / (Z.shape[0] * (Z.shape[0] - 1))
    u_xy = torch.sum(Kxy) / (X.shape[0] * Y.shape[0])
    u_xz = torch.sum(Kxz) / (X.shape[0] * Z.shape[0])

    t = u_yy - 2 * u_xy - (u_zz - 2 * u_xz)
    if Diff_Var.item() <= 0:
        # Fall back to a strictly positive variance. ``torch.as_tensor``
        # accepts both Python numbers and tensors, whereas calling
        # ``torch.max`` with a Python float as first argument raises.
        Diff_Var = (
            torch.as_tensor(epsilon, dtype=t.dtype, device=t.device)
            .detach()
            .clamp(min=1e-08)
        )

    std = torch.sqrt(Diff_Var)
    p_value = torch.distributions.Normal(0, 1).cdf(-t / std)
    t = t / std

    h = 0 if p_value > alpha else 1

    return h, p_value.item(), t.item()