|
import torch |
|
import numpy as np |
|
|
|
class GaussianMixtureModel:
    """Gaussian mixture model with closed-form forward diffusion.

    Stores K components in dimension d:

    * ``mu``    -- component means, shape (K, d, 1)
    * ``Sigma`` -- component covariances, shape (K, d, d)
    * ``pi``    -- mixture weights, shape (K, 1, 1), summing to 1

    ``forward_diffusion`` pushes the mixture through the linear SDE
    ``dx = -x/2 dt + dW``; a GMM stays a GMM under this process, so the score
    of every intermediate marginal is available in closed form. That score
    drives an explicit-Euler integrator of the probability-flow ODE
    ``dx/dt = -x/2 - score(x)/2`` (see ``probability_flow_ode``).
    """

    def __init__(self, mu, Sigma, pi):
        """Validate and store the mixture parameters.

        Args:
            mu: float32 tensor of means, shape (K, d) or (K, d, 1).
            Sigma: float32 tensor of covariances, shape (K, d, d).
            pi: float32 tensor of weights, shape (K,), (K, 1) or (K, 1, 1);
                must sum to 1.

        Raises:
            AssertionError: if a type, dtype, shape or normalisation check
                fails. (Validation uses ``assert`` for backward compatibility,
                so these checks are skipped under ``python -O``.)
        """
        assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
        assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
        assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."

        assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
        assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
        assert pi.dtype == torch.float32, "pi must have dtype torch.float32."

        self.K, self.d = mu.shape[:2]

        # Accept (K, d) means and promote them to column vectors.
        if mu.shape == (self.K, self.d):
            mu = mu.unsqueeze(-1)
        assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."

        assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."

        # Accept flat or (K, 1) weights; reshape (unlike view) also works on
        # non-contiguous tensors.
        if pi.shape == (self.K,):
            pi = pi.reshape(self.K, 1, 1)
        elif pi.shape == (self.K, 1):
            pi = pi.unsqueeze(-1)
        assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."

        # Compare against a 1.0 on pi's own device so non-CPU weights work.
        one = torch.ones((), dtype=pi.dtype, device=pi.device)
        assert torch.isclose(torch.sum(pi), one), "Mixture weights must sum to 1."

        self.mu = mu
        self.Sigma = Sigma
        self.pi = pi

    def sample(self, n_samples):
        """Draw ``n_samples`` i.i.d. points from the mixture.

        Args:
            n_samples: number of points to draw.

        Returns:
            Tensor of shape (n_samples, d); ``n_samples == 0`` gives an empty
            (0, d) tensor.
        """
        if n_samples == 0:
            return self.mu.new_empty((0, self.d))
        # One independent component draw per sample (with replacement).
        ks = torch.multinomial(self.pi.reshape(self.K), n_samples, replacement=True)
        # Reparameterise: x = mu_k + L_k z, where Sigma_k = L_k L_k^T, z ~ N(0, I).
        chol = torch.linalg.cholesky(self.Sigma)  # (K, d, d)
        z = torch.randn(n_samples, self.d, 1, dtype=self.mu.dtype, device=self.mu.device)
        # Indexing with ks keeps the trailing dims, so d == 1 is handled too.
        return (self.mu[ks] + chol[ks] @ z).squeeze(-1)

    def _log_joint(self, x):
        """Compute per-component log joint densities for a batch of points.

        Args:
            x: tensor reshapeable to (B, d); cast to the parameters' dtype.

        Returns:
            log_joint: (B, K) tensor of ``log pi_k + log N(x_b | mu_k, Sigma_k)``.
            prec_diff: (B, K, d) tensor of ``Sigma_k^{-1} (x_b - mu_k)``.
        """
        x = x.reshape(-1, 1, self.d, 1).to(self.mu.dtype)
        diff = x - self.mu.unsqueeze(0)                            # (B, K, d, 1)
        # matmul broadcasts the (K,) batch of inverses against (B, K).
        prec_diff = torch.inverse(self.Sigma) @ diff               # (B, K, d, 1)
        mahalanobis = (diff * prec_diff).sum(dim=(-2, -1))        # (B, K)
        # logdet comes from the factorisation, so unlike log(det(.)) it does
        # not overflow/underflow for large d or very small/large variances.
        log_norm = 0.5 * torch.logdet(2 * torch.pi * self.Sigma)   # (K,)
        log_weights = torch.log(self.pi.reshape(self.K))           # (K,)
        log_joint = log_weights - 0.5 * mahalanobis - log_norm
        return log_joint, prec_diff.squeeze(-1)

    def log_prob(self, x):
        """Return the mixture log density at a single point.

        Args:
            x: tensor holding exactly d values (e.g. shape (d,) or (1, d)).

        Returns:
            0-dim tensor ``log p(x)``.
        """
        log_joint, _ = self._log_joint(x.reshape(1, self.d))
        return torch.logsumexp(log_joint[0], dim=0)

    def score(self, x):
        """Return the score ``grad_x log p(x)`` for a batch of points.

        Args:
            x: tensor of shape (B, d) (or anything reshapeable to it).

        Returns:
            Tensor of shape (B, d).
        """
        batch = x.shape[0]
        log_joint, prec_diff = self._log_joint(x.reshape(batch, self.d))
        # Posterior responsibilities via a log-space softmax: exponentiating
        # raw log densities underflows to 0/0 = NaN far from all components.
        resp = torch.softmax(log_joint, dim=1)                     # (B, K)
        # grad log p(x) = -sum_k r_k(x) Sigma_k^{-1} (x - mu_k)
        return -(resp.unsqueeze(-1) * prec_diff).sum(dim=1)

    def forward_diffusion(self, t):
        """Return the marginal mixture at time ``t`` of ``dx = -x/2 dt + dW``.

        Each component maps to ``N(e^{-t/2} mu_k, e^{-t} Sigma_k + (1 - e^{-t}) I)``;
        the weights are unchanged.

        Args:
            t: diffusion time (float, t >= 0).

        Returns:
            A new ``GaussianMixtureModel``.
        """
        decay = float(np.exp(-t))
        mu_t = self.mu * float(np.exp(-0.5 * t))
        eye = torch.eye(self.d, dtype=self.Sigma.dtype, device=self.Sigma.device)
        Sigma_t = self.Sigma * decay + eye * (1.0 - decay)
        return GaussianMixtureModel(mu_t, Sigma_t, self.pi)

    def flow(self, x_t, t, dt, num_steps):
        """Integrate the probability-flow ODE with ``num_steps`` Euler steps.

        Args:
            x_t: (B, d) starting points.
            t: starting time.
            dt: signed step size (negative integrates backwards in time).
            num_steps: number of Euler steps.

        Returns:
            (B, d) points at time ``t + num_steps * dt``.
        """
        # Compute each step's time directly instead of accumulating t += dt,
        # which would drift when dt is not exactly representable.
        for step in range(num_steps):
            x_t = self.probability_flow_ode(x_t, t + step * dt, dt)
        return x_t

    def flow_gmm_to_normal(self, x_0, T=5, N=32):
        """Transport points from the data GMM (t=0) towards N(0, I) (t=T) in N steps."""
        dt = T / N
        return self.flow(x_0, 0, dt, N)

    def flow_normal_to_gmm(self, x_T, T=5, N=32):
        """Transport points from t=T back to the data GMM (t=0) in N steps."""
        dt = -T / N
        return self.flow(x_T, T, dt, N)

    def probability_flow_ode(self, x_t, t, dt):
        """Take one explicit-Euler step of the probability-flow ODE.

        For ``dx = f dt + g dW`` with ``f = -x/2`` and ``g = 1``, the ODE is
        ``dx/dt = f - g^2/2 * score = -x/2 - score/2``, where the score is
        taken from the closed-form diffused mixture at time ``t``.

        Args:
            x_t: (B, d) points at time ``t``.
            t: current time.
            dt: signed step size.

        Returns:
            (B, d) points at time ``t + dt``.
        """
        gmm_t = self.forward_diffusion(t)
        score = gmm_t.score(x_t)
        drift = -0.5 * x_t - 0.5 * score
        return x_t + drift * dt
|
|
|
|
|
if __name__ == "__main__":
    # Demo: two equally weighted, unit-covariance components in 2-D.
    means = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float32)
    covariances = torch.eye(2, dtype=torch.float32).repeat(2, 1, 1)
    weights = torch.full((2,), 0.5, dtype=torch.float32)
    base_model = GaussianMixtureModel(means, covariances, weights)

    # Diffuse the mixture to t = 1 and inspect the resulting marginal.
    diffused = base_model.forward_diffusion(1.0)
    print("Samples after forward diffusion:\n", diffused.sample(10))

    query = torch.tensor([0.5, 0.5], dtype=torch.float32)
    print("Log Probability after forward diffusion:", diffused.log_prob(query))
    print("Score after forward diffusion:", diffused.score(query.unsqueeze(0)))

    # Integrate the probability-flow ODE in each direction.
    forward_end = base_model.flow_gmm_to_normal(torch.tensor([[0.5, 0.5]], dtype=torch.float32))
    print("x_T_normal after flowing from GMM to normal:", forward_end)

    backward_end = base_model.flow_normal_to_gmm(torch.zeros(1, 2, dtype=torch.float32))
    print("x_0_gmm after flowing from normal to GMM:", backward_end)
|
|