import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_
from sklearn.cluster import KMeans

class MLPLayers(nn.Module):
    """Multi-layer perceptron defined by a list of layer widths.

    For ``layers = [in_dim, h1, ..., out_dim]`` each block is
    Dropout -> Linear (-> BatchNorm1d if ``bn``) -> activation,
    and the activation is omitted after the final Linear layer.
    """

    def __init__(self, layers, dropout=0.0, activation="relu", bn=False):
        super(MLPLayers, self).__init__()
        self.layers = layers
        self.dropout = dropout
        self.activation = activation
        self.use_bn = bn

        mlp_modules = []
        for idx, (input_size, output_size) in enumerate(
            zip(self.layers[:-1], self.layers[1:])
        ):
            mlp_modules.append(nn.Dropout(p=self.dropout))
            mlp_modules.append(nn.Linear(input_size, output_size))
            if self.use_bn:
                mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
            activation_func = activation_layer(self.activation, output_size)
            # No activation after the last Linear layer.
            if activation_func is not None and idx != (len(self.layers) - 2):
                mlp_modules.append(activation_func)

        self.mlp_layers = nn.Sequential(*mlp_modules)
        self.apply(self.init_weights)

    def init_weights(self, module):
        # Xavier-initialize Linear weights and zero the biases.
        if isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    def forward(self, input_feature):
        return self.mlp_layers(input_feature)

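# A minimal usage sketch (hypothetical sizes, not from the original code):
# project 64-d features to 16-d with one hidden layer; the final Linear has
# no activation.
#
#   mlp = MLPLayers([64, 32, 16], dropout=0.1, activation="relu", bn=False)
#   out = mlp(torch.randn(8, 64))  # -> shape (8, 16)
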
def activation_layer(activation_name="relu", emb_dim=None):
    """Build an activation module from a name or an nn.Module subclass.

    ``emb_dim`` is accepted for interface compatibility but is not used by
    any of the activations constructed here.
    """
    if activation_name is None:
        activation = None
    elif isinstance(activation_name, str):
        name = activation_name.lower()
        if name == "sigmoid":
            activation = nn.Sigmoid()
        elif name == "tanh":
            activation = nn.Tanh()
        elif name == "relu":
            activation = nn.ReLU()
        elif name == "leakyrelu":
            activation = nn.LeakyReLU()
        elif name == "none":
            activation = None
        else:
            raise NotImplementedError(
                "activation function {} is not implemented".format(activation_name)
            )
    elif issubclass(activation_name, nn.Module):
        activation = activation_name()
    else:
        raise NotImplementedError(
            "activation function {} is not implemented".format(activation_name)
        )

    return activation

def kmeans(
    samples,
    num_clusters,
    num_iters=10,
):
    """Cluster the 2-D tensor ``samples`` with scikit-learn KMeans and return
    the cluster centers on the same device and dtype as ``samples``.
    """
    dtype, device = samples.dtype, samples.device
    x = samples.cpu().detach().numpy()

    cluster = KMeans(n_clusters=num_clusters, max_iter=num_iters).fit(x)

    centers = cluster.cluster_centers_
    tensor_centers = torch.from_numpy(centers).to(device=device, dtype=dtype)

    return tensor_centers

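# Usage sketch (hypothetical shapes, not from the original code): warm-start a
# codebook of 256 centers from 10k embedding vectors.
#
#   embeddings = torch.randn(10000, 32)
#   codebook = kmeans(embeddings, num_clusters=256, num_iters=10)  # -> (256, 32)
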
@torch.no_grad()
def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations):
    """Sinkhorn-Knopp normalization of a (B, K) distance matrix.

    Converts distances into a soft assignment matrix Q whose row and column
    marginals are alternately rescaled, encouraging a balanced assignment of
    the B samples across the K clusters.
    """
    # Kernelize the distances; smaller epsilon gives sharper assignments.
    Q = torch.exp(-distances / epsilon)

    B = Q.shape[0]  # number of samples
    K = Q.shape[1]  # number of clusters

    # Normalize so that all entries of Q sum to 1.
    sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True)
    Q /= sum_Q

    for _ in range(sinkhorn_iterations):
        # Normalize rows: each sample's assignments sum to 1/B.
        Q /= torch.sum(Q, dim=1, keepdim=True)
        Q /= B

        # Normalize columns: each cluster's assignments sum to 1/K.
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= K

    # Rescale so each row sums to (approximately) 1 again.
    Q *= B
    return Q
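# Usage sketch (hypothetical shapes and hyperparameters, not from the original
# code): turn a (batch, num_clusters) distance matrix into a soft, roughly
# balanced cluster assignment.
#
#   z = torch.randn(8, 32)
#   centers = torch.randn(256, 32)
#   Q = sinkhorn_algorithm(torch.cdist(z, centers), epsilon=0.05, sinkhorn_iterations=3)
#   # each row of Q is a soft assignment over the 256 clusters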