import threading
import chromadb
import posthog
import torch
import math
import numpy as np
import extensions.superboogav2.parameters as parameters
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from modules.logging_colors import logger
from modules.text_generation import encode, decode
logger.debug('Intercepting all calls to posthog.')
posthog.capture = lambda *args, **kwargs: None


class Collecter():
    def __init__(self):
        pass

    def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
        pass

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        pass

    def clear(self):
        pass


class Embedder():
    def __init__(self):
        pass

    def embed(self, text: str) -> list[torch.Tensor]:
        pass


class Info:
    def __init__(self, start_index, text_with_context, distance, id):
        self.text_with_context = text_with_context
        self.start_index = start_index
        self.distance = distance
        self.id = id

    def calculate_distance(self, other_info):
        if parameters.get_new_dist_strategy() == parameters.DIST_MIN_STRATEGY:
            # Min
            return min(self.distance, other_info.distance)
        elif parameters.get_new_dist_strategy() == parameters.DIST_HARMONIC_STRATEGY:
            # Harmonic mean
            return 2 * (self.distance * other_info.distance) / (self.distance + other_info.distance)
        elif parameters.get_new_dist_strategy() == parameters.DIST_GEOMETRIC_STRATEGY:
            # Geometric mean
            return (self.distance * other_info.distance) ** 0.5
        elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY:
            # Arithmetic mean
            return (self.distance + other_info.distance) / 2
        else:  # Min is default
            return min(self.distance, other_info.distance)
    def merge_with(self, other_info):
        s1 = self.text_with_context
        s2 = other_info.text_with_context
        s1_start = self.start_index
        s2_start = other_info.start_index

        new_dist = self.calculate_distance(other_info)

        if self.should_merge(s1, s2, s1_start, s2_start):
            if s1_start <= s2_start:
                if s1_start + len(s1) >= s2_start + len(s2):  # if s1 completely covers s2
                    return Info(s1_start, s1, new_dist, self.id)
                else:
                    overlap = max(0, s1_start + len(s1) - s2_start)
                    return Info(s1_start, s1 + s2[overlap:], new_dist, self.id)
            else:
                if s2_start + len(s2) >= s1_start + len(s1):  # if s2 completely covers s1
                    return Info(s2_start, s2, new_dist, other_info.id)
                else:
                    overlap = max(0, s2_start + len(s2) - s1_start)
                    return Info(s2_start, s2 + s1[overlap:], new_dist, other_info.id)

        return None

    @staticmethod
    def should_merge(s1, s2, s1_start, s2_start):
        # Check if s1 and s2 are adjacent or overlapping
        s1_end = s1_start + len(s1)
        s2_end = s2_start + len(s2)

        return not (s1_end < s2_start or s2_end < s1_start)
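    # Illustrative example (comment only, not executed): with self covering "abcdef"
    # at start index 0 and other_info covering "defgh" at start index 3, the two
    # excerpts overlap by 3 characters, so merge_with() returns
    # Info(0, "abcdefgh", <combined distance>, self.id).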


class ChromaCollector(Collecter):
    def __init__(self, embedder: Embedder):
        super().__init__()

        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.embedder = embedder
        self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)

        self.ids = []
        self.id_to_info = {}
        self.embeddings_cache = {}
        self.lock = threading.Lock()  # Locking so the server doesn't break.
    def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
        with self.lock:
            assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts"
            if len(texts) == 0:
                return

            new_ids = self._get_new_ids(len(texts))

            (existing_texts, existing_embeddings, existing_ids, existing_metas), \
                (non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)

            # If there are any already existing texts, add them all at once.
            if existing_texts:
                logger.info(f'Adding {len(existing_embeddings)} cached embeddings.')
                args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids}
                if metadatas is not None:
                    args['metadatas'] = existing_metas
                self.collection.add(**args)

            # If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
            if non_existing_texts:
                non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
                for text, embedding in zip(non_existing_texts, non_existing_embeddings):
                    self.embeddings_cache[text] = embedding

                logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.')
                args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids}
                if metadatas is not None:
                    args['metadatas'] = non_existing_metas
                self.collection.add(**args)

            # Create a dictionary that maps each ID to its context and starting index
            new_info = {
                id_: {'text_with_context': context, 'start_index': start_index}
                for id_, context, start_index in zip(new_ids, texts_with_context, starting_indices)
            }

            self.id_to_info.update(new_info)
            self.ids.extend(new_ids)
    def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
        existing_texts, non_existing_texts = [], []
        existing_embeddings = []
        existing_ids, non_existing_ids = [], []
        existing_metas, non_existing_metas = [], []

        for i, text in enumerate(texts):
            id_ = new_ids[i]
            metadata = metadatas[i] if metadatas is not None else None
            embedding = self.embeddings_cache.get(text)
            if embedding:
                existing_texts.append(text)
                existing_embeddings.append(embedding)
                existing_ids.append(id_)
                existing_metas.append(metadata)
            else:
                non_existing_texts.append(text)
                non_existing_ids.append(id_)
                non_existing_metas.append(metadata)

        return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
            (non_existing_texts, non_existing_ids, non_existing_metas)
    def _get_new_ids(self, num_new_ids: int):
        if self.ids:
            max_existing_id = max(int(id_) for id_ in self.ids)
        else:
            max_existing_id = -1

        return [str(i + max_existing_id + 1) for i in range(num_new_ids)]
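    # For example, with self.ids == ['0', '1', '2'], _get_new_ids(2) returns ['3', '4'].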
    def _find_min_max_start_index(self):
        max_index, min_index = 0, float('inf')
        for _, val in self.id_to_info.items():
            if val['start_index'] > max_index:
                max_index = val['start_index']
            if val['start_index'] < min_index:
                min_index = val['start_index']

        return min_index, max_index
    # NB: Time weighting only makes sense within a single document; weighting excerpts
    # from different documents against each other is left as the user's problem.
    # Ideally, each document would get its own separate time weighting.
    def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
        sigmoid = lambda x: 1 / (1 + np.exp(-x))

        weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))

        # Scale to [0, time_power] and shift it up to [1 - time_power, 1]
        weights = weights - min(weights)
        weights = weights * (time_power / max(weights))
        weights = weights + (1 - time_power)

        # Reverse the weights
        weights = weights[::-1]

        for info in infos:
            index = info.start_index
            info.distance *= weights[index]
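    # After the reversal above (assuming a positive time_steepness and time_power in [0, 1]),
    # the weights run from 1 at the smallest start index down to (1 - time_power) at the
    # largest, so chunks that appear later in the document get their distance scaled down
    # and therefore rank as more relevant. Note that indexing the weights by the raw
    # start_index assumes start indices begin at 0.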
    def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
        # Ensure there are infos to filter
        if not infos:
            return []

        # Find info with minimum distance
        min_info = min(infos, key=lambda x: x.distance)

        # Calculate median distance among infos
        median_distance = np.median([inf.distance for inf in infos])

        # Filter out infos that have a distance significantly greater than the median
        filtered_infos = [inf for inf in infos if inf.distance <= significant_level * median_distance]

        # Always include the info with minimum distance
        if min_info not in filtered_infos:
            filtered_infos.append(min_info)

        return filtered_infos
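    # Worked example (comment only): with distances [0.2, 0.25, 0.3, 0.9] and
    # significant_level = 2.0, the median is 0.275 and the cutoff is 0.55, so the
    # 0.9 outlier is dropped while the closest match (0.2) is always kept.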
    def _merge_infos(self, infos: list[Info]):
        merged_infos = []
        current_info = infos[0]

        for next_info in infos[1:]:
            merged = current_info.merge_with(next_info)
            if merged is not None:
                current_info = merged
            else:
                merged_infos.append(current_info)
                current_info = next_info

        merged_infos.append(current_info)
        return merged_infos
    # Main function for retrieving chunks by distance. It performs merging, time weighting, and median-based outlier filtering.
    def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
        n_results = min(len(self.ids), n_results)

        if n_results == 0:
            return [], [], []

        if isinstance(search_strings, str):
            search_strings = [search_strings]

        infos = []
        min_start_index, max_start_index = self._find_min_max_start_index()

        for search_string in search_strings:
            result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances'])
            curr_infos = [Info(start_index=self.id_to_info[id]['start_index'],
                               text_with_context=self.id_to_info[id]['text_with_context'],
                               distance=distance, id=id)
                          for id, distance in zip(result['ids'][0], result['distances'][0])]

            self._apply_sigmoid_time_weighing(infos=curr_infos, document_len=max_start_index - min_start_index + 1, time_steepness=parameters.get_time_steepness(), time_power=parameters.get_time_power())
            curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level())
            infos.extend(curr_infos)

        infos.sort(key=lambda x: x.start_index)
        infos = self._merge_infos(infos)

        texts_with_context = [inf.text_with_context for inf in infos]
        ids = [inf.id for inf in infos]
        distances = [inf.distance for inf in infos]

        return texts_with_context, ids, distances
    # Get chunks by similarity
    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        with self.lock:
            documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
            return documents

    # Get ids by similarity
    def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
        with self.lock:
            _, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
            return ids
    # Cutoff token count
    def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
        # TODO: Move to caller; We add delimiters there which might go over the limit.
        current_token_count = 0
        return_documents = []

        for doc in documents:
            doc_tokens = encode(doc)[0]
            doc_token_count = len(doc_tokens)
            if current_token_count + doc_token_count > max_token_count:
                # If adding this document would exceed the max token count,
                # truncate the document to fit within the limit.
                remaining_tokens = max_token_count - current_token_count
                truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True)
                return_documents.append(truncated_doc)
                break
            else:
                return_documents.append(doc)
                current_token_count += doc_token_count

        return return_documents
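    # For example, with max_token_count = 100 and documents of 60 and 70 tokens, the
    # first document is kept whole, the second is truncated to its first 40 tokens,
    # and iteration stops there.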
    # Get chunks by similarity and then sort by ids
    def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
        with self.lock:
            documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
            # The ids are stringified integers, so sort them numerically rather than lexicographically.
            sorted_docs = [x for _, x in sorted(zip(ids, documents), key=lambda pair: int(pair[0]))]

            return self._get_documents_up_to_token_count(sorted_docs, max_token_count)
    # Get chunks by similarity and then sort by distance (lowest distance is last).
    def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
        with self.lock:
            documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
            sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])]  # sorted lowest -> highest

            # If a document ends up truncated or completely skipped, it will be one with a high distance.
            return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
            return_documents.reverse()  # highest -> lowest

            return return_documents
    def delete(self, ids_to_delete: list[str], where: dict):
        with self.lock:
            ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
            self.collection.delete(ids=ids_to_delete, where=where)

            # Remove the deleted ids from self.ids and self.id_to_info
            ids_set = set(ids_to_delete)
            self.ids = [id_ for id_ in self.ids if id_ not in ids_set]
            for id_ in ids_to_delete:
                self.id_to_info.pop(id_, None)

            logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')
    def clear(self):
        with self.lock:
            self.chroma_client.reset()
            self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
            self.ids = []
            self.id_to_info = {}

            logger.info('Successfully cleared all records and reset chromaDB.')


class SentenceTransformerEmbedder(Embedder):
    def __init__(self) -> None:
        logger.debug('Creating Sentence Embedder...')
        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
        self.embed = self.model.encode


def make_collector():
    return ChromaCollector(SentenceTransformerEmbedder())
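

# Example usage (a sketch, not executed here): this module expects the text-generation-webui
# runtime, since encode()/decode() come from modules.text_generation. The chunk texts and
# starting indices below are made up for illustration.
#
#   collector = make_collector()
#   collector.add(
#       texts=["first chunk", "second chunk"],
#       texts_with_context=["...first chunk...", "...second chunk..."],
#       starting_indices=[0, 100],
#   )
#   best_chunks = collector.get_sorted_by_dist(["some query"], n_results=2, max_token_count=512)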