# Files used:
#   conf.files.context
#   conf.files.index
#   conf.files.embeddings
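#
# A sketch of the Hydra config this script assumes (key names inferred from
# the attribute accesses below; the values are illustrative, not from the
# original project):
#
#   files:
#     context: data/context.jsonl
#     index: data/faiss_index
#     embeddings: data/embeddings.tsv
#   embeddings:
#     model: sentence-transformers/all-MiniLM-L6-v2
#     dim: 384
#     device: cuda
#   indexing:
#     index_type: hnsw  # or "hnsw_pq"
#     batch_size: 64
#   indexes:
#     hnsw:
#       hnsw_m: 32
#       ef_construction: 200
#       ef_search: 128
#     hnsw_pq:
#       hnsw_m: 32
#       pq_m: 16
#       bits: 8
#       ef_construction: 200
#       ef_search: 128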

import csv
import io
import logging

import faiss
import hydra
from datasets import IterableDataset, load_dataset
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm

log = logging.getLogger(__name__)


def get_hnsw_pq_index(conf: DictConfig) -> faiss.IndexHNSWPQ:
    # e.g. "HNSW32,PQ16x8": an HNSW graph over product-quantized vectors,
    # with pq_m sub-quantizers of `bits` bits each.
    index = faiss.index_factory(
        conf.embeddings.dim,
        f"HNSW{conf.indexes.hnsw_pq.hnsw_m},PQ{conf.indexes.hnsw_pq.pq_m}x{conf.indexes.hnsw_pq.bits}",
    )
    index.hnsw.efConstruction = conf.indexes.hnsw_pq.ef_construction
    index.hnsw.efSearch = conf.indexes.hnsw_pq.ef_search
    # Note: a PQ-backed index must be trained before vectors can be added;
    # see the sketch after get_index below.
    return index


def get_hnsw_index(conf: DictConfig) -> faiss.IndexHNSW:
    # The original read these parameters from conf.indexes.hnsw_pq, which
    # looks like a copy-paste slip; a parallel conf.indexes.hnsw block is
    # assumed here.
    index = faiss.index_factory(
        conf.embeddings.dim,
        f"HNSW{conf.indexes.hnsw.hnsw_m}",
    )
    index.hnsw.efConstruction = conf.indexes.hnsw.ef_construction
    index.hnsw.efSearch = conf.indexes.hnsw.ef_search
    return index


def get_index(conf: DictConfig) -> faiss.Index:
    if conf.indexing.index_type == "hnsw_pq":
        return get_hnsw_pq_index(conf)
    elif conf.indexing.index_type == "hnsw":
        return get_hnsw_index(conf)
    else:
        raise ValueError(f"Unknown index type: {conf.indexing.index_type}")
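

# A PQ-backed index raises if vectors are added before it is trained, and the
# main loop below adds embeddings straight away, so "hnsw_pq" would fail as
# written. A minimal sketch of one way to handle this; the helper name and the
# idea of training on a sample of pre-computed embeddings are ours, not part
# of the original script:
def train_index_if_needed(index: faiss.Index, sample_vectors) -> None:
    import numpy as np

    if not index.is_trained:
        # faiss expects a contiguous float32 array of shape (n, dim).
        index.train(np.ascontiguousarray(sample_vectors, dtype="float32"))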


def get_embedding_model(conf: DictConfig) -> HuggingFaceEmbeddings:
    return HuggingFaceEmbeddings(
        model_name=conf.embeddings.model,
        model_kwargs={"trust_remote_code": True, "device": conf.embeddings.device},
    )


def get_faiss_store(conf: DictConfig, embedding_model: HuggingFaceEmbeddings) -> FAISS:
    index = get_index(conf)
    return FAISS(
        embedding_function=embedding_model,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )


def get_dataset(conf: DictConfig) -> tuple[IterableDataset, int]:
    # Count lines up front so tqdm can show a total even though the dataset
    # itself is streamed.
    with open(conf.files.context, "r") as f:
        num_lines = sum(1 for _ in f)
    dataset = load_dataset(
        "json",
        data_files={"train": [conf.files.context]},
        streaming=True,
        split="train",
    )
    batched_dataset = dataset.batch(batch_size=conf.indexing.batch_size)
    return batched_dataset, num_lines
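

# For reference, each line of conf.files.context must be a JSON object with at
# least the fields consumed in add_batch below. An illustrative record (the
# field values are made up):
#
#   {"chunk_id": "doc42-003", "text_content": "...", "source_name": "doc42.pdf",
#    "associated_dates": ["2021-06-01"]}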


def add_batch(
    embedding_model: HuggingFaceEmbeddings,
    faiss_store: FAISS,
    csv_writer,
    f: io.TextIOWrapper,
    batch: dict,
) -> None:
    # Embed the whole batch once and pass (text, vector) pairs to the store,
    # so FAISS does not re-embed the texts.
    embeddings = embedding_model.embed_documents(batch["text_content"])
    text_and_embeddings = zip(batch["text_content"], embeddings)
    metadatas = [
        {
            "source_name": source_name,
            "associated_dates": associated_dates,
            "chunk_id": chunk_id,
        }
        for source_name, associated_dates, chunk_id in zip(
            batch["source_name"], batch["associated_dates"], batch["chunk_id"]
        )
    ]
    faiss_store.add_embeddings(
        text_embeddings=text_and_embeddings, metadatas=metadatas, ids=batch["chunk_id"]
    )
    # Mirror the raw vectors to the TSV file so they can be reused without
    # re-running the embedding model.
    csv_writer.writerows([[f"{x:.8f}" for x in row] for row in embeddings])
    f.flush()
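

# The TSV written above can be reloaded later without re-embedding; a minimal
# sketch (nothing beyond numpy and the config path is assumed):
#
#   import numpy as np
#   vectors = np.loadtxt(conf.files.embeddings, delimiter="\t", dtype="float32")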


def save_index(conf: DictConfig, faiss_store: FAISS) -> None:
    faiss_store.save_local(conf.files.index)
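

# To query the saved store later, load it back with the same embedding model.
# allow_dangerous_deserialization is required by recent langchain_community
# versions because the docstore is pickled:
#
#   store = FAISS.load_local(
#       conf.files.index, embedding_model, allow_dangerous_deserialization=True
#   )
#   hits = store.similarity_search("example query", k=5)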


# The original omitted the @hydra.main decorator even though hydra is imported
# and main() is called with no arguments, so the script could not run as
# written. config_path / config_name are assumed here; adjust to your layout.
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(conf: DictConfig) -> None:
    log.info("Config:\n%s", OmegaConf.to_yaml(conf))
    embedding_model = get_embedding_model(conf)
    faiss_store = get_faiss_store(conf, embedding_model)
    dataset, num_lines = get_dataset(conf)
    with open(conf.files.embeddings, "a", newline="") as f:
        csv_writer = csv.writer(f, delimiter="\t")
        # Ceiling division so the progress bar counts the final partial batch.
        num_batches = -(-num_lines // conf.indexing.batch_size)
        for batch in tqdm(dataset, total=num_batches):
            add_batch(embedding_model, faiss_store, csv_writer, f, batch)
    save_index(conf, faiss_store)


if __name__ == "__main__":
    main()