Spaces:
Running
Running
import threading | |
import chromadb | |
import posthog | |
import torch | |
import math | |
import numpy as np | |
import extensions.superboogav2.parameters as parameters | |
from chromadb.config import Settings | |
from sentence_transformers import SentenceTransformer | |
from modules.logging_colors import logger | |
from modules.text_generation import encode, decode | |
logger.debug('Intercepting all calls to posthog.') | |
posthog.capture = lambda *args, **kwargs: None | |
class Collecter(): | |
def __init__(self): | |
pass | |
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]): | |
pass | |
def get(self, search_strings: list[str], n_results: int) -> list[str]: | |
pass | |
def clear(self): | |
pass | |
class Embedder(): | |
def __init__(self): | |
pass | |
def embed(self, text: str) -> list[torch.Tensor]: | |
pass | |
class Info: | |
def __init__(self, start_index, text_with_context, distance, id): | |
self.text_with_context = text_with_context | |
self.start_index = start_index | |
self.distance = distance | |
self.id = id | |
def calculate_distance(self, other_info): | |
if parameters.get_new_dist_strategy() == parameters.DIST_MIN_STRATEGY: | |
# Min | |
return min(self.distance, other_info.distance) | |
elif parameters.get_new_dist_strategy() == parameters.DIST_HARMONIC_STRATEGY: | |
# Harmonic mean | |
return 2 * (self.distance * other_info.distance) / (self.distance + other_info.distance) | |
elif parameters.get_new_dist_strategy() == parameters.DIST_GEOMETRIC_STRATEGY: | |
# Geometric mean | |
return (self.distance * other_info.distance) ** 0.5 | |
elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY: | |
# Arithmetic mean | |
return (self.distance + other_info.distance) / 2 | |
else: # Min is default | |
return min(self.distance, other_info.distance) | |
def merge_with(self, other_info): | |
s1 = self.text_with_context | |
s2 = other_info.text_with_context | |
s1_start = self.start_index | |
s2_start = other_info.start_index | |
new_dist = self.calculate_distance(other_info) | |
if self.should_merge(s1, s2, s1_start, s2_start): | |
if s1_start <= s2_start: | |
if s1_start + len(s1) >= s2_start + len(s2): # if s1 completely covers s2 | |
return Info(s1_start, s1, new_dist, self.id) | |
else: | |
overlap = max(0, s1_start + len(s1) - s2_start) | |
return Info(s1_start, s1 + s2[overlap:], new_dist, self.id) | |
else: | |
if s2_start + len(s2) >= s1_start + len(s1): # if s2 completely covers s1 | |
return Info(s2_start, s2, new_dist, other_info.id) | |
else: | |
overlap = max(0, s2_start + len(s2) - s1_start) | |
return Info(s2_start, s2 + s1[overlap:], new_dist, other_info.id) | |
return None | |
def should_merge(s1, s2, s1_start, s2_start): | |
# Check if s1 and s2 are adjacent or overlapping | |
s1_end = s1_start + len(s1) | |
s2_end = s2_start + len(s2) | |
return not (s1_end < s2_start or s2_end < s1_start) | |
class ChromaCollector(Collecter): | |
def __init__(self, embedder: Embedder): | |
super().__init__() | |
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False)) | |
self.embedder = embedder | |
self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed) | |
self.ids = [] | |
self.id_to_info = {} | |
self.embeddings_cache = {} | |
self.lock = threading.Lock() # Locking so the server doesn't break. | |
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None): | |
with self.lock: | |
assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts" | |
if len(texts) == 0: | |
return | |
new_ids = self._get_new_ids(len(texts)) | |
(existing_texts, existing_embeddings, existing_ids, existing_metas), \ | |
(non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas) | |
# If there are any already existing texts, add them all at once. | |
if existing_texts: | |
logger.info(f'Adding {len(existing_embeddings)} cached embeddings.') | |
args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids} | |
if metadatas is not None: | |
args['metadatas'] = existing_metas | |
self.collection.add(**args) | |
# If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead. | |
if non_existing_texts: | |
non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist() | |
for text, embedding in zip(non_existing_texts, non_existing_embeddings): | |
self.embeddings_cache[text] = embedding | |
logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.') | |
args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids} | |
if metadatas is not None: | |
args['metadatas'] = non_existing_metas | |
self.collection.add(**args) | |
# Create a dictionary that maps each ID to its context and starting index | |
new_info = { | |
id_: {'text_with_context': context, 'start_index': start_index} | |
for id_, context, start_index in zip(new_ids, texts_with_context, starting_indices) | |
} | |
self.id_to_info.update(new_info) | |
self.ids.extend(new_ids) | |
def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]): | |
existing_texts, non_existing_texts = [], [] | |
existing_embeddings = [] | |
existing_ids, non_existing_ids = [], [] | |
existing_metas, non_existing_metas = [], [] | |
for i, text in enumerate(texts): | |
id_ = new_ids[i] | |
metadata = metadatas[i] if metadatas is not None else None | |
embedding = self.embeddings_cache.get(text) | |
if embedding: | |
existing_texts.append(text) | |
existing_embeddings.append(embedding) | |
existing_ids.append(id_) | |
existing_metas.append(metadata) | |
else: | |
non_existing_texts.append(text) | |
non_existing_ids.append(id_) | |
non_existing_metas.append(metadata) | |
return (existing_texts, existing_embeddings, existing_ids, existing_metas), \ | |
(non_existing_texts, non_existing_ids, non_existing_metas) | |
def _get_new_ids(self, num_new_ids: int): | |
if self.ids: | |
max_existing_id = max(int(id_) for id_ in self.ids) | |
else: | |
max_existing_id = -1 | |
return [str(i + max_existing_id + 1) for i in range(num_new_ids)] | |
def _find_min_max_start_index(self): | |
max_index, min_index = 0, float('inf') | |
for _, val in self.id_to_info.items(): | |
if val['start_index'] > max_index: | |
max_index = val['start_index'] | |
if val['start_index'] < min_index: | |
min_index = val['start_index'] | |
return min_index, max_index | |
# NB: Does not make sense to weigh excerpts from different documents. | |
# But let's say that's the user's problem. Perfect world scenario: | |
# Apply time weighing to different documents. For each document, then, add | |
# separate time weighing. | |
def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float): | |
sigmoid = lambda x: 1 / (1 + np.exp(-x)) | |
weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len)) | |
# Scale to [0,time_power] and shift it up to [1-time_power, 1] | |
weights = weights - min(weights) | |
weights = weights * (time_power / max(weights)) | |
weights = weights + (1 - time_power) | |
# Reverse the weights | |
weights = weights[::-1] | |
for info in infos: | |
index = info.start_index | |
info.distance *= weights[index] | |
def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float): | |
# Ensure there are infos to filter | |
if not infos: | |
return [] | |
# Find info with minimum distance | |
min_info = min(infos, key=lambda x: x.distance) | |
# Calculate median distance among infos | |
median_distance = np.median([inf.distance for inf in infos]) | |
# Filter out infos that have a distance significantly greater than the median | |
filtered_infos = [inf for inf in infos if inf.distance <= significant_level * median_distance] | |
# Always include the info with minimum distance | |
if min_info not in filtered_infos: | |
filtered_infos.append(min_info) | |
return filtered_infos | |
def _merge_infos(self, infos: list[Info]): | |
merged_infos = [] | |
current_info = infos[0] | |
for next_info in infos[1:]: | |
merged = current_info.merge_with(next_info) | |
if merged is not None: | |
current_info = merged | |
else: | |
merged_infos.append(current_info) | |
current_info = next_info | |
merged_infos.append(current_info) | |
return merged_infos | |
# Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering. | |
def _get_documents_ids_distances(self, search_strings: list[str], n_results: int): | |
n_results = min(len(self.ids), n_results) | |
if n_results == 0: | |
return [], [], [] | |
if isinstance(search_strings, str): | |
search_strings = [search_strings] | |
infos = [] | |
min_start_index, max_start_index = self._find_min_max_start_index() | |
for search_string in search_strings: | |
result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances']) | |
curr_infos = [Info(start_index=self.id_to_info[id]['start_index'], | |
text_with_context=self.id_to_info[id]['text_with_context'], | |
distance=distance, id=id) | |
for id, distance in zip(result['ids'][0], result['distances'][0])] | |
self._apply_sigmoid_time_weighing(infos=curr_infos, document_len=max_start_index - min_start_index + 1, time_steepness=parameters.get_time_steepness(), time_power=parameters.get_time_power()) | |
curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level()) | |
infos.extend(curr_infos) | |
infos.sort(key=lambda x: x.start_index) | |
infos = self._merge_infos(infos) | |
texts_with_context = [inf.text_with_context for inf in infos] | |
ids = [inf.id for inf in infos] | |
distances = [inf.distance for inf in infos] | |
return texts_with_context, ids, distances | |
# Get chunks by similarity | |
def get(self, search_strings: list[str], n_results: int) -> list[str]: | |
with self.lock: | |
documents, _, _ = self._get_documents_ids_distances(search_strings, n_results) | |
return documents | |
# Get ids by similarity | |
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: | |
with self.lock: | |
_, ids, _ = self._get_documents_ids_distances(search_strings, n_results) | |
return ids | |
# Cutoff token count | |
def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int): | |
# TODO: Move to caller; We add delimiters there which might go over the limit. | |
current_token_count = 0 | |
return_documents = [] | |
for doc in documents: | |
doc_tokens = encode(doc)[0] | |
doc_token_count = len(doc_tokens) | |
if current_token_count + doc_token_count > max_token_count: | |
# If adding this document would exceed the max token count, | |
# truncate the document to fit within the limit. | |
remaining_tokens = max_token_count - current_token_count | |
truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True) | |
return_documents.append(truncated_doc) | |
break | |
else: | |
return_documents.append(doc) | |
current_token_count += doc_token_count | |
return return_documents | |
# Get chunks by similarity and then sort by ids | |
def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]: | |
with self.lock: | |
documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results) | |
sorted_docs = [x for _, x in sorted(zip(ids, documents))] | |
return self._get_documents_up_to_token_count(sorted_docs, max_token_count) | |
# Get chunks by similarity and then sort by distance (lowest distance is last). | |
def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]: | |
with self.lock: | |
documents, _, distances = self._get_documents_ids_distances(search_strings, n_results) | |
sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest | |
# If a document is truncated or competely skipped, it would be with high distance. | |
return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count) | |
return_documents.reverse() # highest -> lowest | |
return return_documents | |
def delete(self, ids_to_delete: list[str], where: dict): | |
with self.lock: | |
ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids'] | |
self.collection.delete(ids=ids_to_delete, where=where) | |
# Remove the deleted ids from self.ids and self.id_to_info | |
ids_set = set(ids_to_delete) | |
self.ids = [id_ for id_ in self.ids if id_ not in ids_set] | |
for id_ in ids_to_delete: | |
self.id_to_info.pop(id_, None) | |
logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.') | |
def clear(self): | |
with self.lock: | |
self.chroma_client.reset() | |
self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed) | |
self.ids = [] | |
self.id_to_info = {} | |
logger.info('Successfully cleared all records and reset chromaDB.') | |
class SentenceTransformerEmbedder(Embedder): | |
def __init__(self) -> None: | |
logger.debug('Creating Sentence Embedder...') | |
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") | |
self.embed = self.model.encode | |
def make_collector(): | |
return ChromaCollector(SentenceTransformerEmbedder()) |