"""Embed evidence text chunks with OpenAI embeddings and store them in FAISS
vector databases."""

import argparse
import json
import shutil
from pathlib import Path

import faiss
import numpy as np
import tiktoken
from beartype import beartype
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from zsvision.zs_utils import BlockTimer

from llm_api_utils import (
    init_openai_with_api_key,
    EMBEDDING_DIMENSIONS,
    PRICE_PER_1K_TOKENS,
)
from pipeline_paths import PIPELINE_PATHS


class EmbedResults:
    """Chunk evidence documents, embed the chunks with OpenAI embeddings, and
    store them in a FAISS vector database."""

    def __init__(
        self,
        embedding_model="ada",
        limit=0,
        refresh=False,
        refresh_faiss_db=False,
        text_embedding_chunk_size=500,
        filter_str="",
    ):
        self.embedding_model = embedding_model
        self.limit = limit
        self.refresh = refresh
        self.refresh_faiss_db = refresh_faiss_db
        self.text_embedding_chunk_size = text_embedding_chunk_size
        self.filter_str = filter_str

    @beartype
    def compute_embeddings_from_chunks(
        self, embedding_function: OpenAIEmbeddings, metadatas: list, faiss_db
    ):
        """Embed the text stored under each metadata's `doc_chunk` key and add the
        results to `faiss_db`, returning the updated database."""
        doc_chunks = []
        metadatas_without_chunks = []
        for metadata in metadatas:
            doc_chunk = metadata.pop("doc_chunk")
            doc_chunks.append(doc_chunk)
            metadatas_without_chunks.append(metadata)

        with BlockTimer(f"Embedding {len(metadatas)} fragments"):
            embeddings = embedding_function.embed_documents(doc_chunks)

        # Use the public add_embeddings() API rather than the name-mangled private
        # FAISS.__add() method, which is brittle across langchain versions.
        faiss_db.add_embeddings(
            text_embeddings=list(zip(doc_chunks, embeddings)),
            metadatas=metadatas_without_chunks,
        )

        return faiss_db

    @beartype
    def parse_date_of_fetching(self, data: dict) -> str:
        """Return the date on which the evidence in `data` was fetched."""
        evidence_keys = {
            "search_results_fetched",
            "results_fetched_from_wikipedia_1M_with_cohere-22-12",
        }
        for key in evidence_keys:
            if key in data["dates"]:
                return data["dates"][key]
        raise ValueError(f"Could not find evidence fetched date in {data['dates']}")

    def embed_for_uuid(self, srcs):
        """Embed the evidence in `srcs` into a fresh, in-memory FAISS database."""
        init_openai_with_api_key()

        embedding_function = OpenAIEmbeddings()

        index = faiss.IndexFlatL2(EMBEDDING_DIMENSIONS[self.embedding_model])
        docstore = InMemoryDocstore({})
        index_to_docstore_id = {}
        # Pass the Embeddings object itself (passing a bare callable such as
        # embed_query is deprecated in recent langchain releases).
        faiss_db = FAISS(
            embedding_function=embedding_function,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
        )

        # The database was freshly initialised above, so this set is empty; it is
        # kept for symmetry with embed(), which can resume from a persisted database.
        already_embedded_chunks = {
            doc.metadata["chunk_tag"] for doc in faiss_db.docstore._dict.values()
        }

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.text_embedding_chunk_size,
            chunk_overlap=0,
        )

        seen_links = set()
        metadatas = []
        total_chunks = 0
        chunks_to_embed = 0
        chunks_to_skip = 0

        for data in srcs:
            evidence_fetched_date = self.parse_date_of_fetching(data)

            for document in data["documents"]:
                for search_result in document["search_results"]:
                    # The same link can appear as evidence for multiple claims, so
                    # only chunk each page once.
                    if search_result["link"] in seen_links:
                        continue
                    seen_links.add(search_result["link"])

                    doc_chunks = [
                        doc.page_content
                        for doc in splitter.create_documents([search_result["text"]])
                    ]
                    chunk_tags = [
                        f"{search_result['link']}-chunk-{idx}-chunk_sz-{self.text_embedding_chunk_size}"
                        for idx in range(len(doc_chunks))
                    ]
                    for doc_chunk, chunk_tag in zip(doc_chunks, chunk_tags):
                        if chunk_tag not in already_embedded_chunks:
                            metadatas.append(
                                {
                                    "doc_chunk": doc_chunk,
                                    "link": search_result["link"],
                                    "chunk_tag": chunk_tag,
                                    "date_accessed": evidence_fetched_date,
                                    "query": document["claim"],
                                }
                            )
                            chunks_to_embed += 1
                        else:
                            chunks_to_skip += 1
                    total_chunks += len(doc_chunks)

        encoding = tiktoken.encoding_for_model(self.embedding_model)
        doc_chunks = [x["doc_chunk"] for x in metadatas]
        num_words = len(" ".join(doc_chunks).split())
        num_tokens = len(encoding.encode("".join(doc_chunks)))

        print(
            f"Created {total_chunks} chunks of text to answer from {len(seen_links)} websites"
        )
        print(
            f"Embedding {chunks_to_embed} (skipping {chunks_to_skip}) chunks of text from {len(seen_links)} websites"
        )
        print(
            f"Embedding {num_tokens} tokens ({num_words} words) from {len(doc_chunks)} chunks"
        )
        print(
            f"Step5: Estimated cost: {num_tokens * PRICE_PER_1K_TOKENS[self.embedding_model]['embed'] / 1000:.2f} USD"
        )

        if metadatas:
            self.compute_embeddings_from_chunks(
                embedding_function=embedding_function,
                faiss_db=faiss_db,
                metadatas=metadatas,
            )
            return faiss_db
        return None

    def embed(self):
        """Embed all evidence found under the pipeline evidence directories into a
        persistent FAISS database, skipping chunks that were already embedded."""
        init_openai_with_api_key()
        src_paths = []
        for evidence_key in (
            "google_search_results_evidence",
            "cohere_wikipedia_evidence",
        ):
            evidence_paths = list(PIPELINE_PATHS[evidence_key].glob("**/*.json"))
            src_paths.extend(evidence_paths)

        if self.filter_str:
            num_paths = len(src_paths)
            src_paths = [
                src_path for src_path in src_paths if self.filter_str in src_path.name
            ]
            print(
                f"Filtering for {self.filter_str} (from {num_paths} to {len(src_paths)} paths)"
            )

        print(f"Found {len(src_paths)} collections of evidence")
        src_paths = sorted(src_paths)

        embedding_function = OpenAIEmbeddings()
        faiss_persist_dir = (
            PIPELINE_PATHS["faiss_db_embeddings_for_evidence"]
            / f"{self.embedding_model}_chunk_size_{self.text_embedding_chunk_size}"
        )

        if faiss_persist_dir.exists() and self.refresh_faiss_db:
            print(f"Deleting existing database at {faiss_persist_dir}")
            shutil.rmtree(faiss_persist_dir)

        if faiss_persist_dir.exists():
            faiss_db = FAISS.load_local(
                folder_path=str(faiss_persist_dir),
                embeddings=embedding_function,
                # Recent langchain_community releases require explicitly opting in
                # to pickle deserialization when loading a local index.
                allow_dangerous_deserialization=True,
            )
            print(f"Found existing database at {faiss_persist_dir}, using...")
        else:
            index = faiss.IndexFlatL2(EMBEDDING_DIMENSIONS[self.embedding_model])
            docstore = InMemoryDocstore({})
            index_to_docstore_id = {}
            faiss_db = FAISS(
                embedding_function=embedding_function,
                index=index,
                docstore=docstore,
                index_to_docstore_id=index_to_docstore_id,
            )
            print(f"Persisting initialised database to {faiss_persist_dir}")
            faiss_db.save_local(folder_path=str(faiss_persist_dir))

        already_embedded_chunks = {
            doc.metadata["chunk_tag"] for doc in faiss_db.docstore._dict.values()
        }

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.text_embedding_chunk_size,
            chunk_overlap=0,
        )

        seen_links = set()
        metadatas = []
        total_chunks = 0
        chunks_to_embed = 0
        chunks_to_skip = 0

        for src_path in src_paths:
            with open(src_path, "r") as f:
                data = json.load(f)

            evidence_fetched_date = self.parse_date_of_fetching(data)

            for document in data["documents"]:
                for search_result in document["search_results"]:
                    if search_result["link"] in seen_links:
                        continue
                    seen_links.add(search_result["link"])

                    doc_chunks = [
                        doc.page_content
                        for doc in splitter.create_documents([search_result["text"]])
                    ]
                    chunk_tags = [
                        f"{search_result['link']}-chunk-{idx}-chunk_sz-{self.text_embedding_chunk_size}"
                        for idx in range(len(doc_chunks))
                    ]
                    for doc_chunk, chunk_tag in zip(doc_chunks, chunk_tags):
                        if chunk_tag not in already_embedded_chunks:
                            metadatas.append(
                                {
                                    "doc_chunk": doc_chunk,
                                    "link": search_result["link"],
                                    "chunk_tag": chunk_tag,
                                    "date_accessed": evidence_fetched_date,
                                    "query": document["claim"],
                                }
                            )
                            chunks_to_embed += 1
                        else:
                            chunks_to_skip += 1
                    total_chunks += len(doc_chunks)

        encoding = tiktoken.encoding_for_model(self.embedding_model)
        doc_chunks = [x["doc_chunk"] for x in metadatas]
        num_words = len(" ".join(doc_chunks).split())
        num_tokens = len(encoding.encode("".join(doc_chunks)))

        print(
            f"Created {total_chunks} chunks of text to answer from {len(seen_links)} websites"
        )
        print(
            f"Embedding {chunks_to_embed} (skipping {chunks_to_skip}) chunks of text from {len(seen_links)} websites"
        )
        print(
            f"Embedding {num_tokens} tokens ({num_words} words) from {len(doc_chunks)} chunks"
        )
        print(
            f"Estimated cost: {num_tokens * PRICE_PER_1K_TOKENS[self.embedding_model]['embed'] / 1000:.2f} USD"
        )

        if metadatas:
            # The original call passed faiss_persist_dir, which the method does not
            # accept; pass the database itself, then persist the updated index so
            # subsequent runs can skip the chunks embedded here.
            self.compute_embeddings_from_chunks(
                embedding_function=embedding_function,
                faiss_db=faiss_db,
                metadatas=metadatas,
            )
            faiss_db.save_local(folder_path=str(faiss_persist_dir))
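

# The module imports argparse but no entry point appears in this section. The
# block below is a minimal sketch of a CLI wrapper, assuming flag names that
# mirror the EmbedResults constructor arguments (a hypothetical interface, not
# necessarily the original one).
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Embed evidence into a FAISS database")
    parser.add_argument("--embedding_model", default="ada")
    parser.add_argument("--limit", type=int, default=0)
    parser.add_argument("--refresh", action="store_true")
    parser.add_argument("--refresh_faiss_db", action="store_true")
    parser.add_argument("--text_embedding_chunk_size", type=int, default=500)
    parser.add_argument("--filter_str", default="")
    args = parser.parse_args()

    embedder = EmbedResults(
        embedding_model=args.embedding_model,
        limit=args.limit,
        refresh=args.refresh,
        refresh_faiss_db=args.refresh_faiss_db,
        text_embedding_chunk_size=args.text_embedding_chunk_size,
        filter_str=args.filter_str,
    )
    embedder.embed()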