import random
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd


class Dedup:
    """Near-duplicate detection and sampling over a text dataset.

    Records are embedded with a SentenceTransformer model, clustered with a
    FAISS index using a distance threshold, and reduced to one representative
    per cluster.
    """

    def __init__(self, config=None):
        config = config or {}
        self.index = None
        self.xb = None
        self.clusters = None
        self.model = None
        self.th = config.get("dedup_threshold", 0.5)
        self.model_name = config.get("embeddings_model", "all-MiniLM-L6-v2")

    def copy(self):
        return Dedup(
            {"dedup_threshold": self.th,
             "embeddings_model": self.model_name}
        )

    def generate_embeddings(self, texts):
        """
        Generate embeddings for the given texts using the SentenceTransformer model.
        """
        # Load the model lazily and cache it so repeated calls don't reload it.
        if self.model is None:
            self.model = SentenceTransformer(self.model_name)
        embeddings = self.model.encode(texts, show_progress_bar=True)
        return embeddings

    def build_index(self, records):
        """
        Build the FAISS index for the given dataset.
        input: records - a pandas dataframe with a 'text' column
        output: index - the FAISS index
                embeddings - the embeddings of the dataset
        """
        # Generate embeddings for the dataset
        embeddings = self.generate_embeddings(records['text'].tolist())

        # Build a flat (exact) index. Note that IndexFlatL2 returns squared
        # L2 distances, so the dedup threshold is compared against squared values.
        embeddings_dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(embeddings_dim)
        index.add(embeddings)
        return index, embeddings

    def cluster_data(self, records):
        """
        Cluster the given dataset.
        input: records - a pandas dataframe with a 'text' column
        output: clusters - a list of clusters, where each cluster is a set of indices
        """

        if self.index is None:
            self.index, self.xb = self.build_index(records)

        distances, indices = self.index.search(self.xb, 30)  # TODO: derive k from the batch size

        clusters = []
        visited = set()

        for i in range(len(self.xb)):
            if i in visited:
                continue

            # Seed a new cluster with point i and its in-threshold neighbors.
            visited.add(i)
            neighbors = [idx for idx, distance in zip(indices[i], distances[i]) if distance <= self.th]
            new_cluster = {i}

            # Add all neighbors to the new cluster
            for neighbor in neighbors:
                if neighbor not in visited:
                    visited.add(neighbor)
                    new_cluster.add(neighbor)

            clusters.append(new_cluster)
        return clusters

    def sample(self, records: pd.DataFrame, operation_function=random.choice):
        """
        Sample one representative record per cluster from the given dataset.
        input: records - a pandas dataframe with a 'text' column
               operation_function - a function that receives a cluster (a list of
                                    row positions) and returns one of them
        output: a pandas dataframe with the sampled records
        """

        if not callable(operation_function):
            raise ValueError("The 'operation_function' must be a callable function.")

        # Clusters are cached after the first call; reuse the same `records`
        # that the index was built from.
        if self.clusters is None:
            self.clusters = self.cluster_data(records)

        # Pick one row position per cluster and return those rows in order.
        samples = [operation_function(list(cluster)) for cluster in self.clusters]
        return records.iloc[sorted(samples)]
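

# Minimal usage sketch (illustrative, not part of the original module): the
# example texts and the "keep the shortest" selector below are assumptions.
if __name__ == "__main__":
    df = pd.DataFrame({"text": [
        "How do I reset my password?",
        "How can I reset my password?",
        "What is your refund policy?",
    ]})

    dedup = Dedup({"dedup_threshold": 0.5, "embeddings_model": "all-MiniLM-L6-v2"})

    # Default behavior: keep a random member of each near-duplicate cluster.
    print(dedup.sample(df))

    # A custom operation_function encodes a preference; here, keep the
    # shortest text in each cluster (cluster members are row positions).
    def keep_shortest(cluster):
        return min(cluster, key=lambda i: len(df.iloc[i]["text"]))

    print(dedup.sample(df, operation_function=keep_shortest))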