|
from typing import Dict, List |
|
|
|
import numpy as np |
|
from numpy import ndarray |
|
from sklearn.cluster import KMeans |
|
from sklearn.decomposition import PCA |
|
from sklearn.mixture import GaussianMixture |
|
|
|
|
|
class ClusterFeatures(object):
    """
    Basic handling of clustering features.

    Fits a clustering model (KMeans by default, GaussianMixture when
    ``algorithm == 'gmm'``) on a feature matrix and, for every cluster
    centre, picks the index of the closest feature row.
    """

    def __init__(
        self,
        features: ndarray,
        algorithm: str = 'kmeans',
        pca_k: int = None,
        random_state: int = 12345,
    ):
        """
        :param features: the embedding matrix to cluster (iterated row by row).
        :param algorithm: Which clustering algorithm to use. 'gmm' selects
            GaussianMixture; any other value selects KMeans.
        :param pca_k: If truthy, the features are first run through PCA with
            this number of components.
        :param random_state: Random state handed to the clustering model.
        """
        if pca_k:
            self.features = PCA(n_components=pca_k).fit_transform(features)
        else:
            self.features = features

        self.algorithm = algorithm
        self.pca_k = pca_k
        self.random_state = random_state

    def __get_model(self, k: int):
        """
        Retrieve clustering model.

        :param k: amount of clusters.
        :return: Unfitted clustering model (GaussianMixture for 'gmm', KMeans otherwise).
        """
        if self.algorithm == 'gmm':
            return GaussianMixture(n_components=k, random_state=self.random_state)
        return KMeans(n_clusters=k, random_state=self.random_state)

    def __get_centroids(self, model):
        """
        Retrieve centroids of model.

        :param model: Fitted clustering model.
        :return: ``model.means_`` for 'gmm', ``model.cluster_centers_`` otherwise.
        """
        if self.algorithm == 'gmm':
            return model.means_
        return model.cluster_centers_

    def __find_closest_args(self, centroids: np.ndarray) -> Dict:
        """
        Find the closest arguments to centroid.

        Centroids are processed in order; each one claims the nearest
        (Euclidean distance) feature row that no earlier centroid has claimed.
        Ties go to the lowest row index.

        :param centroids: Centroids to find closest.
        :return: Mapping of centroid position -> feature row index. The value
            is -1 when no unclaimed row lies closer than 1e10.
        """
        args = {}
        # A set keeps the "already claimed" check O(1) inside the inner loop.
        used_idx = set()

        for j, centroid in enumerate(centroids):
            centroid_min = 1e10
            cur_arg = -1

            for i, feature in enumerate(self.features):
                if i in used_idx:
                    continue

                value = np.linalg.norm(feature - centroid)
                if value < centroid_min:
                    cur_arg = i
                    centroid_min = value

            used_idx.add(cur_arg)
            args[j] = cur_arg

        return args

    def calculate_elbow(self, k_max: int) -> List[float]:
        """
        Calculates elbow up to the provided k_max.

        A model is fitted for every k in ``range(1, min(k_max, len(features)))``,
        i.e. the upper bound itself is excluded.

        NOTE(review): this reads ``model.inertia_``, which KMeans provides; the
        'gmm' model may not expose that attribute - confirm before using the
        elbow helpers with algorithm='gmm'.

        :param k_max: K_max to calculate elbow for.
        :return: The inertias, one per fitted k, in increasing order of k.
        """
        inertias = []

        for k in range(1, min(k_max, len(self.features))):
            model = self.__get_model(k).fit(self.features)
            inertias.append(model.inertia_)

        return inertias

    def calculate_optimal_cluster(self, k_max: int):
        """
        Calculates the optimal cluster based on Elbow.

        Builds first and second differences of the inertia curve and returns
        the k whose "strength" (second difference minus first difference at
        the following position) is the largest positive value.

        :param k_max: The max k to search elbow for.
        :return: The optimal cluster size; 1 when no position has positive strength.
        """
        delta_1 = []
        delta_2 = []

        max_strength = 0
        k = 1

        inertias = self.calculate_elbow(k_max)
        count = len(inertias)

        for i in range(count):
            delta_1.append(inertias[i] - inertias[i - 1] if i > 0 else 0.0)
            delta_2.append(delta_1[i] - delta_1[i - 1] if i > 1 else 0.0)

        for j in range(count):
            # The first two positions have no second difference, and the last
            # one has no successor to look at, so they never win.
            if j <= 1 or j == count - 1:
                strength = 0
            else:
                strength = delta_2[j + 1] - delta_1[j + 1]

            if strength > max_strength:
                max_strength = strength
                k = j + 1

        return k

    def cluster(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
        """
        Clusters sentences based on the ratio.

        :param ratio: Ratio to use for clustering.
        :param num_sentences: Number of sentences. Overrides ratio.
        :return: Sorted sentence indices that qualify for summary. Empty when
            there are no features or when ``num_sentences`` is 0.
        """
        n_features = len(self.features)
        if n_features == 0:
            # Nothing to cluster; fitting a model on no samples would fail.
            return []

        if num_sentences is not None:
            if num_sentences == 0:
                return []

            k = min(num_sentences, n_features)
        else:
            # At least one cluster, but never more clusters than samples
            # (a ratio above 1.0 would otherwise be rejected by the model).
            k = min(max(int(n_features * ratio), 1), n_features)

        model = self.__get_model(k).fit(self.features)

        centroids = self.__get_centroids(model)
        cluster_args = self.__find_closest_args(centroids)

        return sorted(cluster_args.values())

    def __call__(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
        """
        Clusters sentences based on the ratio.

        Shorthand for :meth:`cluster` with the same arguments.

        :param ratio: Ratio to use for clustering.
        :param num_sentences: Number of sentences. Overrides ratio.
        :return: Sentences index that qualify for summary.
        """
        # Forward num_sentences as well so that it really overrides ratio.
        return self.cluster(ratio, num_sentences)