#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/georgepar/gmmhmm-pytorch/blob/master/gmm.py
https://github.com/ldeecke/gmm-torch
"""
import math
from sklearn import cluster
import torch
import torch.nn as nn
class GaussianMixtureModel(nn.Module):
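    """
    Base class for EM-trained Gaussian mixture models.

    Covariance handling (reset_sigma / estimate_precisions / estimate_sigma /
    log_prob) is left to subclasses such as DiagonalCovarianceGMM below.
    """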
def __init__(self,
n_mixtures: int,
n_features: int,
init: str = "random",
device: str = 'cpu',
n_iter: int = 1000,
delta: float = 1e-3,
warm_start: bool = False,
):
        super().__init__()
self.n_mixtures = n_mixtures
self.n_features = n_features
self.init = init
self.device = device
self.n_iter = n_iter
self.delta = delta
self.warm_start = warm_start
        if init not in ('kmeans', 'random'):
            raise ValueError(f"init must be 'kmeans' or 'random', got {init!r}")
self.mu = nn.Parameter(
torch.Tensor(n_mixtures, n_features),
requires_grad=False,
)
        # covariance container; the concrete parameter is created by the subclass
        self.sigma = None
        # mixing weight of each gaussian component
self.pi = nn.Parameter(
torch.Tensor(n_mixtures),
requires_grad=False
)
        self.converged_ = False
        # numerical floor to keep probabilities and variances strictly positive
        self.eps = 1e-6
    # hooks implemented by covariance-specific subclasses
    def reset_sigma(self):
        raise NotImplementedError

    def estimate_precisions(self):
        raise NotImplementedError

    def estimate_sigma(self, x, mu, pi, responsibilities):
        raise NotImplementedError

    def log_prob(self, x):
        raise NotImplementedError
def weighted_log_prob(self, x):
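        # log pi_k + log N(x | mu_k, sigma_k) for every sample, shape (n, k)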
log_prob = self.log_prob(x)
weighted_log_prob = log_prob + torch.log(self.pi)
return weighted_log_prob
def log_likelihood(self, x):
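        # total log-likelihood of the batch under the current mixture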
weighted_log_prob = self.weighted_log_prob(x)
per_sample_log_likelihood = torch.logsumexp(weighted_log_prob, dim=1)
log_likelihood = torch.sum(per_sample_log_likelihood)
return log_likelihood
def e_step(self, x):
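        # E-step: log responsibilities log p(z = k | x) for every sample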
weighted_log_prob = self.weighted_log_prob(x)
weighted_log_prob = weighted_log_prob.unsqueeze(dim=-1)
log_likelihood = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
q = weighted_log_prob - log_likelihood
return q.squeeze()
    def m_step(self, x, q):
        # M-step: re-estimate pi, mu and sigma from the log-responsibilities
        # q returned by e_step, then write them back into the parameters.
        x = x.unsqueeze(dim=1)  # (n, 1, d)
        responsibilities = torch.exp(q).unsqueeze(dim=-1)  # (n, k, 1)
        pi = self.estimate_pi(x, responsibilities)  # (1, k, 1)
        mu = self.estimate_mu(x, pi, responsibilities)  # (1, k, d)
        sigma = self.estimate_sigma(x, mu, pi, responsibilities)  # (1, k, d)
        self.update_(mu=mu.squeeze(dim=0), sigma=sigma.squeeze(dim=0), pi=pi.flatten())
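
    def fit(self, x):
        # Minimal EM driver (a sketch): the class stores n_iter, delta and
        # warm_start for exactly this loop. Assumes x is a (n, d) float
        # tensor on the same device as the model.
        if not self.warm_start:
            self.reset_parameters(x)
        self.converged_ = False
        previous_log_likelihood = -float('inf')
        for _ in range(self.n_iter):
            q = self.e_step(x)
            self.m_step(x, q)
            log_likelihood = self.log_likelihood(x).item()
            if abs(log_likelihood - previous_log_likelihood) < self.delta:
                self.converged_ = True
                break
            previous_log_likelihood = log_likelihood
        return self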
def estimate_mu(self, x, pi, responsibilities):
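        # weighted mean per component; nk is the expected count per component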
nk = pi * x.size(0)
mu = torch.sum(responsibilities * x, dim=0, keepdim=True) / nk
return mu
def estimate_pi(self, x, responsibilities):
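        # mixing weights: average responsibility per component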
pi = torch.sum(responsibilities, dim=0, keepdim=True) + self.eps
pi = pi / x.size(0)
return pi
    def update_(self, mu=None, sigma=None, pi=None):
        # copy freshly estimated statistics into the frozen parameters
        if mu is not None:
            self.mu.data.copy_(mu)
        if sigma is not None:
            self.sigma.data.copy_(sigma)
        if pi is not None:
            self.pi.data.copy_(pi)

    def reset_parameters(self, x=None):
        if self.init == 'random' or x is None:
            self.mu.normal_()
            self.reset_sigma()
            self.pi.fill_(1.0 / self.n_mixtures)
        elif self.init == 'kmeans':
            centroids = cluster.KMeans(n_clusters=self.n_mixtures, n_init=1).fit(x).cluster_centers_
            centroids = torch.tensor(centroids).to(self.device)
            self.update_(mu=centroids)
            # k-means seeds only the means; sigma and pi still need defaults
            self.reset_sigma()
            self.pi.fill_(1.0 / self.n_mixtures)
        else:
            raise NotImplementedError
class DiagonalCovarianceGMM(GaussianMixtureModel):
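    """Gaussian mixture whose components have one variance per feature."""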
def __init__(self,
n_mixtures: int,
n_features: int,
init: str = "random",
device: str = 'cpu',
n_iter: int = 1000,
delta: float = 1e-3,
warm_start: bool = False,
):
        super().__init__(
n_mixtures=n_mixtures,
n_features=n_features,
init=init,
device=device,
n_iter=n_iter,
delta=delta,
warm_start=warm_start,
)
self.sigma = nn.Parameter(
torch.Tensor(n_mixtures, n_features), requires_grad=False
)
self.reset_parameters()
self.to(self.device)
def reset_sigma(self):
self.sigma.fill_(1)
def estimate_precisions(self):
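        # sigma stores variances, so rsqrt(sigma) is 1 / std per feature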
return torch.rsqrt(self.sigma)
    def log_prob(self, x):
        precisions = self.estimate_precisions()  # 1 / std, shape (k, d)
        x = x.unsqueeze(1)  # (n, 1, d)
        mu = self.mu.unsqueeze(0)  # (1, k, d)
        precisions = precisions.unsqueeze(0)  # (1, k, d)
        # squared Mahalanobis term, expanded: sum_d (x - mu)^2 / sigma_d
        exp_term = torch.sum(
            (mu * mu + x * x - 2 * x * mu) * (precisions ** 2), dim=2, keepdim=True
        )
        log_det = torch.sum(torch.log(precisions), dim=2, keepdim=True)
        logp = -0.5 * (self.n_features * math.log(2 * math.pi) + exp_term) + log_det
        return logp.squeeze()
def estimate_sigma(self, x, mu, pi, responsibilities):
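        # diagonal covariance update: E[x^2] - 2 * mu * E[x] + mu^2, per feature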
nk = pi * x.size(0)
x2 = (responsibilities * x * x).sum(0, keepdim=True) / nk
mu2 = mu * mu
xmu = (responsibilities * mu * x).sum(0, keepdim=True) / nk
sigma = x2 - 2 * xmu + mu2 + self.eps
return sigma
def demo1():
    # Minimal smoke test (a sketch): two well-separated clusters should be
    # recovered by the diagonal-covariance EM fit defined above.
    torch.manual_seed(0)
    x = torch.cat([
        torch.randn(200, 2) * 0.5 + 3.0,
        torch.randn(200, 2) * 0.5 - 3.0,
    ], dim=0)
    gmm = DiagonalCovarianceGMM(n_mixtures=2, n_features=2)
    gmm.fit(x)
    print('converged:', gmm.converged_)
    print('pi:', gmm.pi.data)
    print('mu:', gmm.mu.data)
    print('sigma:', gmm.sigma.data)
    return
if __name__ == '__main__':
demo1()