import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_
from sklearn.cluster import KMeans


class MLPLayers(nn.Module):
    """Stack of Dropout -> Linear (-> BatchNorm) -> activation blocks built from the `layers` sizes."""

    def __init__(self, layers, dropout=0.0, activation="relu", bn=False):
        super(MLPLayers, self).__init__()
        self.layers = layers
        self.dropout = dropout
        self.activation = activation
        self.use_bn = bn

        mlp_modules = []
        for idx, (input_size, output_size) in enumerate(
            zip(self.layers[:-1], self.layers[1:])
        ):
            mlp_modules.append(nn.Dropout(p=self.dropout))
            mlp_modules.append(nn.Linear(input_size, output_size))
            if self.use_bn:
                mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
            activation_func = activation_layer(self.activation, output_size)
            # skip the activation after the last layer so the output stays linear
            if activation_func is not None and idx != (len(self.layers) - 2):
                mlp_modules.append(activation_func)

        self.mlp_layers = nn.Sequential(*mlp_modules)
        self.apply(self.init_weights)

    def init_weights(self, module):
        # We just initialize the module with normal distribution as the paper said
        if isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    def forward(self, input_feature):
        return self.mlp_layers(input_feature)


def activation_layer(activation_name="relu", emb_dim=None):
    """Return an activation module by name (or class), or None if no activation is requested."""
    if activation_name is None:
        activation = None
    elif isinstance(activation_name, str):
        if activation_name.lower() == "sigmoid":
            activation = nn.Sigmoid()
        elif activation_name.lower() == "tanh":
            activation = nn.Tanh()
        elif activation_name.lower() == "relu":
            activation = nn.ReLU()
        elif activation_name.lower() == "leakyrelu":
            activation = nn.LeakyReLU()
        elif activation_name.lower() == "none":
            activation = None
    elif issubclass(activation_name, nn.Module):
        activation = activation_name()
    else:
        raise NotImplementedError(
            "activation function {} is not implemented".format(activation_name)
        )
    return activation


def kmeans(samples, num_clusters, num_iters=10):
    """Run sklearn KMeans on a (B, dim) tensor and return the centroids on the original device."""
    device = samples.device
    x = samples.cpu().detach().numpy()
    cluster = KMeans(n_clusters=num_clusters, max_iter=num_iters).fit(x)
    centers = cluster.cluster_centers_
    tensor_centers = torch.from_numpy(centers).to(device)
    return tensor_centers


@torch.no_grad()
def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations):
    """Turn a (B, K) distance matrix into balanced soft assignments via Sinkhorn-Knopp iterations."""
    Q = torch.exp(-distances / epsilon)
    B = Q.shape[0]  # number of samples to assign
    K = Q.shape[1]  # how many centroids per block (usually set to 256)

    # make the matrix sum to 1
    sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True)
    Q /= sum_Q

    for it in range(sinkhorn_iterations):
        # normalize each row: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=1, keepdim=True)
        Q /= B

        # normalize each column: total weight per prototype must be 1/K
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= K

    Q *= B  # the rows must sum to 1 so that each row of Q is an assignment over prototypes
    return Q
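

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# The shapes and hyperparameters below (a 64 -> 32 -> 16 MLP, 256 samples,
# 8 clusters, epsilon = 0.05, 3 Sinkhorn iterations) are assumptions chosen
# purely for demonstration, not values prescribed by this module.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Project raw features with the MLP defined above.
    mlp = MLPLayers(layers=[64, 32, 16], dropout=0.1, activation="relu")
    features = torch.randn(256, 64)
    embeddings = mlp(features)  # (256, 16)

    # Initialize centroids with k-means on the embeddings.
    centers = kmeans(embeddings, num_clusters=8, num_iters=10)  # (8, 16), float64 from sklearn

    # Build a distance matrix and turn it into balanced soft assignments.
    distances = torch.cdist(embeddings, centers.float())  # (256, 8)
    assignments = sinkhorn_algorithm(distances, epsilon=0.05, sinkhorn_iterations=3)
    print(assignments.shape, assignments.sum(dim=1)[:5])  # each row sums to ~1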