File size: 4,827 Bytes
0528be1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
from typing import Dict, List
import numpy as np
from numpy import ndarray
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
class ClusterFeatures:
    """
    Pick representative rows of a feature matrix by clustering.

    The features are clustered with KMeans (default) or a Gaussian mixture
    (``algorithm='gmm'``); for every cluster centre the index of the nearest,
    not yet selected feature row is reported.
    """

    def __init__(
        self,
        features: ndarray,
        algorithm: str = 'kmeans',
        pca_k: int = None,
        random_state: int = 12345,
    ):
        """
        :param features: the embedding matrix created by bert parent (one row per sentence).
        :param algorithm: Which clustering algorithm to use; 'gmm' selects a Gaussian
            mixture, any other value falls back to KMeans.
        :param pca_k: If you want the features to be run through PCA, this is the number
            of components. A falsy value (None or 0) skips PCA.
        :param random_state: Random state shared by PCA and the clustering model.
        """
        if pca_k:
            # Seed PCA as well, so that the whole pipeline is reproducible and not only
            # the clustering step.
            reducer = PCA(n_components=pca_k, random_state=random_state)
            self.features = reducer.fit_transform(features)
        else:
            self.features = features
        self.algorithm = algorithm
        self.pca_k = pca_k
        self.random_state = random_state

    def __get_model(self, k: int):
        """
        Retrieve clustering model.

        :param k: amount of clusters.
        :return: Unfitted clustering model (GaussianMixture for 'gmm', KMeans otherwise).
        """
        if self.algorithm == 'gmm':
            return GaussianMixture(n_components=k, random_state=self.random_state)
        return KMeans(n_clusters=k, random_state=self.random_state)

    def __get_centroids(self, model):
        """
        Retrieve centroids of a fitted model.

        :param model: Fitted clustering model.
        :return: The mixture means for 'gmm', the cluster centers otherwise.
        """
        if self.algorithm == 'gmm':
            return model.means_
        return model.cluster_centers_

    def __find_closest_args(self, centroids: np.ndarray) -> Dict:
        """
        Find, for every centroid, the index of the closest feature row.

        Centroids are processed in order and each feature row can be assigned to at most
        one centroid. Ties are resolved in favour of the lowest row index.

        :param centroids: Centroids to find closest rows for.
        :return: Mapping of centroid position -> feature row index. The index is -1 when
            no row is left to assign (or no finite distance exists).
        """
        features = np.asarray(self.features)
        available = np.ones(len(features), dtype=bool)
        args = {}
        for j, centroid in enumerate(centroids):
            closest = -1
            if available.any():
                # One vectorised Euclidean distance computation per centroid.
                distances = np.linalg.norm(features - centroid, axis=1)
                # Already assigned rows and undefined (NaN) distances must never win.
                distances[~available | np.isnan(distances)] = np.inf
                # argmin returns the first minimum, i.e. the lowest index on ties.
                candidate = int(np.argmin(distances))
                if np.isfinite(distances[candidate]):
                    closest = candidate
                    available[candidate] = False
            args[j] = closest
        return args

    def calculate_elbow(self, k_max: int) -> List[float]:
        """
        Calculates elbow up to the provided k_max.

        A model is fitted for every k in ``1 .. min(k_max, number of features) - 1``
        (the upper bound is exclusive) and its ``inertia_`` is collected.

        NOTE(review): ``inertia_`` is read from the fitted model; GaussianMixture does not
        appear to expose it, so this is presumably only usable with KMeans - confirm.

        :param k_max: Exclusive upper bound of k to calculate elbow for.
        :return: The inertias, where position i belongs to k = i + 1.
        """
        upper = min(k_max, len(self.features))
        return [
            self.__get_model(k).fit(self.features).inertia_
            for k in range(1, upper)
        ]

    def calculate_optimal_cluster(self, k_max: int):
        """
        Calculates the optimal cluster based on Elbow.

        :param k_max: The max k to search elbow for.
        :return: The optimal cluster size; 1 when no elbow stands out.
        """
        inertias = self.calculate_elbow(k_max)
        count = len(inertias)

        # First and second order differences of the inertia curve. Positions without a
        # predecessor (respectively two predecessors) stay at 0.0.
        delta_1 = [0.0] * count
        delta_2 = [0.0] * count
        for i in range(1, count):
            delta_1[i] = inertias[i] - inertias[i - 1]
            if i > 1:
                delta_2[i] = delta_1[i] - delta_1[i - 1]

        best_k = 1
        max_strength = 0
        # The first two and the last position carry no elbow strength, so skip them.
        for j in range(2, count - 1):
            strength = delta_2[j + 1] - delta_1[j + 1]
            if strength > max_strength:
                max_strength = strength
                best_k = j + 1
        return best_k

    def cluster(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
        """
        Clusters sentences based on the ratio.

        :param ratio: Ratio to use for clustering.
        :param num_sentences: Number of sentences. Overrides ratio.
        :return: Sorted sentence indexes that qualify for summary; empty when there are
            no features or ``num_sentences`` is 0.
        """
        total = len(self.features)
        if total == 0:
            # Nothing to cluster; fitting a model on an empty matrix would only fail.
            return []

        if num_sentences is not None:
            if num_sentences == 0:
                return []
            # Never request more clusters than there are rows.
            k = min(num_sentences, total)
        else:
            # Always keep at least one sentence.
            k = max(int(total * ratio), 1)

        model = self.__get_model(k).fit(self.features)
        centroids = self.__get_centroids(model)
        cluster_args = self.__find_closest_args(centroids)
        return sorted(cluster_args.values())

    def __call__(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
        """
        Clusters sentences based on the ratio.

        :param ratio: Ratio to use for clustering.
        :param num_sentences: Number of sentences. Overrides ratio.
        :return: Sentences index that qualify for summary.
        """
        # Forward both arguments so that num_sentences is honoured.
        return self.cluster(ratio, num_sentences)