# Spaces: Running on Zero
import random

import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer
class Dedup:
    """Cluster near-duplicate texts by embedding distance and sample one per cluster.

    The FAISS index, the embeddings and the computed clusters are cached on
    the instance after the first call, so later calls reuse them regardless of
    the ``records`` passed in. Use :meth:`copy` to get a fresh, uncached
    instance with the same configuration.

    NOTE(review): FAISS documents ``IndexFlatL2`` distances as *squared*
    Euclidean distances -- confirm ``dedup_threshold`` is tuned accordingly.
    """

    # Upper bound on the number of nearest neighbours fetched per record.
    # TODO: derive it from the batch size.
    MAX_NEIGHBORS = 30

    def __init__(self, config=None):
        """Create a deduplicator.

        config: optional dict; recognised keys are
            'dedup_threshold'  - maximum distance for two records to be
                                 treated as duplicates (default 0.5)
            'embeddings_model' - SentenceTransformer model name
                                 (default 'all-MiniLM-L6-v2')
        """
        config = config or {}
        self.index = None      # FAISS index, built lazily by cluster_data
        self.xb = None         # embeddings that were added to the index
        self.clusters = None   # cached clusters, filled lazily by sample
        self.th = config.get("dedup_threshold", 0.5)
        self.model_name = config.get("embeddings_model", 'all-MiniLM-L6-v2')

    def copy(self):
        """Return a new Dedup with the same configuration and no cached state."""
        return Dedup(
            {"dedup_threshold": self.th,
             "embeddings_model": self.model_name}
        )

    def generate_embeddings(self, texts):
        """
        Generate embeddings for the given texts using the SentenceTransformer model.
        input: texts - a list of strings
        output: the embeddings returned by the model's encode()
        """
        model = SentenceTransformer(self.model_name)
        embeddings = model.encode(texts, show_progress_bar=True)
        return embeddings

    def build_index(self, records):
        """
        Build the FAISS index for the given dataset.
        input: records - a pandas dataframe with a 'text' column
        output: index - the FAISS index
                embeddings - the embeddings of the dataset
        """
        # Generate embeddings for the dataset
        embeddings = self.generate_embeddings(records['text'].tolist())
        # Build the FAISS index
        embeddings_dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(embeddings_dim)
        index.add(embeddings)
        return index, embeddings

    def cluster_data(self, records):
        """
        Cluster the given dataset.
        input: records - a pandas dataframe with a 'text' column; only used
                         when no index has been built on this instance yet
        output: clusters - a list of clusters, where each cluster is a set of indices

        Every record ends up in exactly one cluster. An empty dataset yields
        an empty list and leaves the instance without a cached index.
        """
        if self.index is None:
            if len(records) == 0:
                # Nothing to embed; building an index would fail on shape[1].
                return []
            self.index, self.xb = self.build_index(records)

        total = len(self.xb)
        if total == 0:
            return []

        # Never ask for more neighbours than there are records, so the index
        # does not have to pad the result with placeholder entries.
        k = min(self.MAX_NEIGHBORS, total)
        distances, indices = self.index.search(self.xb, k)

        clusters = []
        visited = set()
        for i in range(total):
            if i in visited:
                continue
            # Mark the seed itself: it is not guaranteed to show up among its
            # own neighbours within the threshold, and must not be claimed
            # again by a later cluster.
            visited.add(i)
            new_cluster = {i}
            # Add all not-yet-assigned neighbors within the threshold
            for idx, distance in zip(indices[i], distances[i]):
                neighbor = int(idx)
                # Negative labels are skipped defensively (no valid record).
                if neighbor >= 0 and distance <= self.th and neighbor not in visited:
                    visited.add(neighbor)
                    new_cluster.add(neighbor)
            clusters.append(new_cluster)
        return clusters

    def sample(self, records: pd.DataFrame, operation_function=random.choice):
        """
        Sample the given dataset.
        input: records - a pandas dataframe with a 'text' column
               operation_function - a function that receives a cluster
                                    (as a list of indices) and returns an index
        output: a pandas dataframe with the sampled records, in ascending
                positional order
        raises: ValueError if operation_function is not callable
        """
        if not callable(operation_function):
            raise ValueError("The 'operation_function' must be a callable function.")

        if self.clusters is None:
            if len(records) == 0:
                # Nothing to cluster; do not cache an empty result.
                return records.iloc[[]]
            self.clusters = self.cluster_data(records)

        samples = [operation_function(list(cluster)) for cluster in self.clusters]
        return records.iloc[sorted(samples)]