import argparse
import json
import pdb
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
from scipy.spatial.distance import cosine


class TWCClustering:
    """Cluster embeddings by thresholding their pairwise cosine-similarity matrix.

    Thresholds are expressed in standard deviations above the mean of all
    matrix entries: ``cutoff = mean + threshold * std``.
    """

    def __init__(self):
        print("In Zscore Clustering")

    def compute_matrix(self, embeddings):
        """Return the pairwise cosine-similarity matrix of ``embeddings``.

        Args:
            embeddings: array-like, presumably shaped (n, dim) with one row
                per item (TODO confirm against callers).

        Returns:
            ``np.inner`` of the L2-normalised rows, i.e. an (n, n) matrix for
            2-D input. A zero-norm row divides by zero and produces NaNs.
        """
        embeddings = np.array(embeddings)
        # Transpose so each column is one embedding, then scale every
        # column to unit L2 norm (norm taken along axis 0).
        vec_a = embeddings.T
        vec_a = vec_a / np.linalg.norm(vec_a, axis=0)
        vec_a = vec_a.T  # back to one normalised embedding per row
        return np.inner(vec_a, vec_a)

    def get_terms_above_threshold(self, matrix, embeddings, pivot_index, threshold):
        """Return indices ``j >= pivot_index`` with ``matrix[pivot_index][j] >= threshold``.

        Only indices at or after the pivot are scanned; ``len(embeddings)``
        bounds the scan.
        """
        row = matrix[pivot_index]
        return [j for j in range(pivot_index, len(embeddings)) if row[j] >= threshold]

    def update_picked_dict_arr(self, picked_dict, arr):
        """Mark every index in the list ``arr`` as picked (value 1)."""
        for index in arr:
            picked_dict[index] = 1

    def update_picked_dict(self, picked_dict, in_dict):
        """Mark every key of ``in_dict`` as picked (value 1)."""
        for key in in_dict:
            picked_dict[key] = 1

    def find_pivot_subgraph(self, pivot_index, arr, matrix, threshold, strict_cluster=True):
        """Choose the member of ``arr`` with the highest summed similarity to the others.

        Args:
            pivot_index: index the candidate set originated from; it is
                returned as ``orig_index`` and is the fallback center when no
                member qualifies.
            arr: candidate node indices.
            matrix: similarity matrix indexable as ``matrix[i][j]``.
            threshold: when ``strict_cluster`` is true, similarities below
                this value are ignored both in the score and in ``neighs``.
            strict_cluster: see ``threshold``.

        Returns:
            dict with ``pivot_index`` (chosen center), ``orig_index``
            (``pivot_index`` argument) and ``neighs``: an OrderedDict of
            node -> similarity to the center, sorted descending (ties keep
            ``arr`` order).
        """
        center_index = pivot_index
        # Start from -inf so a center is chosen even when every score is
        # <= 0; starting from 0 dropped all members in that case.
        center_score = float("-inf")
        center_dict = {}
        for node_i in arr:
            row = matrix[node_i]
            running_score = 0
            temp_dict = {}
            for node_j in arr:
                similarity = row[node_j]
                if strict_cluster and similarity < threshold:
                    continue
                running_score += similarity
                temp_dict[node_j] = similarity
            # Strict '>' keeps the first node on ties; nodes that retained no
            # neighbours cannot become the center.
            if temp_dict and running_score > center_score:
                center_index = node_i
                center_dict = temp_dict
                center_score = running_score
        sorted_d = OrderedDict(sorted(center_dict.items(), key=lambda kv: kv[1], reverse=True))
        return {"pivot_index": center_index, "orig_index": pivot_index, "neighs": sorted_d}

    def update_overlap_stats(self, overlap_dict, cluster_info):
        """Increment, per node in ``cluster_info["neighs"]``, how many clusters contain it."""
        for node in cluster_info["neighs"]:
            overlap_dict[node] = overlap_dict.get(node, 0) + 1

    def bucket_overlap(self, overlap_dict):
        """Histogram membership counts: {count: number of nodes}, sorted by node total ascending."""
        bucket_dict = {}
        for count in overlap_dict.values():
            bucket_dict[count] = bucket_dict.get(count, 0) + 1
        return OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))

    def merge_clusters(self, ref_cluster, curr_cluster):
        """Append, in place, the members of ``curr_cluster`` that ``ref_cluster`` lacked.

        Membership is checked against a snapshot of ``ref_cluster`` taken
        before appending, matching the original semantics.
        """
        existing = set(ref_cluster)
        for node in curr_cluster:
            if node not in existing:
                ref_cluster.append(node)

    def _merge_overlapping_candidates(self, candidates):
        """Merge, in place, candidate lists that share a node until all are disjoint."""
        i = 0
        while i < len(candidates):
            ref_cluster = candidates[i]
            ref_nodes = set(ref_cluster)
            for j in range(i + 1, len(candidates)):
                if any(node in ref_nodes for node in candidates[j]):
                    self.merge_clusters(ref_cluster, candidates[j])
                    del candidates[j]
                    # Re-scan from i (not 0): earlier candidates were disjoint
                    # from both merged sets, so they cannot overlap the union.
                    break
            else:
                i += 1

    def non_overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Build disjoint clusters and append them to ``cluster_dict["clusters"]``.

        Candidate sets come from ``get_terms_above_threshold`` at cutoff
        ``mean + threshold * std``. Overlapping sets are merged, then each set
        gets a center via non-strict ``find_pivot_subgraph``.

        Returns:
            An empty dict (this mode records no overlap statistics).
        """
        zscore = mean + threshold * std
        picked_dict = {}
        candidates = []
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
            candidates.append(arr)
            self.update_picked_dict_arr(picked_dict, arr)

        self._merge_overlapping_candidates(candidates)

        for arr in candidates:
            # The cutoff can exceed even self-similarity (large threshold),
            # leaving an empty candidate; arr[0] would raise IndexError.
            if not arr:
                continue
            cluster_info = self.find_pivot_subgraph(arr[0], arr, matrix, zscore, strict_cluster=False)
            cluster_dict["clusters"].append(cluster_info)
        return {}

    def overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Build possibly-overlapping clusters and append them to ``cluster_dict["clusters"]``.

        Returns:
            The ``bucket_overlap`` histogram of how many clusters each node
            landed in.
        """
        picked_dict = {}
        overlap_dict = {}
        zscore = mean + threshold * std
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
            cluster_info = self.find_pivot_subgraph(i, arr, matrix, zscore, strict_cluster=True)
            self.update_picked_dict(picked_dict, cluster_info["neighs"])
            self.update_overlap_stats(overlap_dict, cluster_info)
            cluster_dict["clusters"].append(cluster_info)
        return self.bucket_overlap(overlap_dict)

    @staticmethod
    def _zscore_table(mean, std):
        """List {threshold, cosine} for integer z-scores whose cosine cutoff is below 1."""
        table = []
        inc = 0
        value = mean
        while value < 1:
            table.append({"threshold": inc, "cosine": round(value, 2)})
            if std <= 0:
                # With zero spread the value never advances: avoid an infinite loop.
                break
            inc += 1
            value = mean + inc * std
        return table

    def cluster(self, output_file, texts, embeddings, threshold, clustering_type):
        """Cluster ``embeddings`` and return the clusters plus summary info.

        Args:
            output_file: unused in this method.
            texts: unused in this method.
            embeddings: passed to ``compute_matrix``.
            threshold: cutoff in standard deviations above the similarity mean.
            clustering_type: ``"overlapped"`` selects overlapped clustering;
                any other value selects non-overlapped clustering.

        Returns:
            dict with ``"clusters"`` (list of ``find_pivot_subgraph`` results)
            and ``"info"`` (mean, std, threshold description, z-score table,
            overlap histogram as a list of (count, nodes) pairs).
        """
        is_overlapped = clustering_type == "overlapped"
        matrix = self.compute_matrix(embeddings)
        mean = np.mean(matrix)
        std = np.std(matrix)
        zscores = self._zscore_table(mean, std)

        cluster_dict = {"clusters": []}
        if is_overlapped:
            sorted_d = self.overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        else:
            sorted_d = self.non_overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
        cluster_dict["info"] = {
            "mean": mean,
            "std": std,
            "current_threshold": curr_threshold,
            "zscores": zscores,
            "overlap": list(sorted_d.items()),
        }
        return cluster_dict