import math
import os
import pickle
from collections import defaultdict
from typing import List, Tuple

import numpy as np
import scipy.linalg
import scipy.sparse
import scipy.sparse.linalg
import torch
import tqdm
from loguru import logger
from transformers import AutoModelForMaskedLM, AutoTokenizer

from app.config.models.configs import Config, Document
from app.utils import torch_device, split


class SpladeSparseVectorDB:
    """Sparse vector store backed by SPLADE (naver/splade-v3) embeddings."""

    def __init__(
        self,
        config: Config,
    ) -> None:
        self._config = config

        # cuda or mps or cpu
        self._device = torch_device()
        logger.info(f"Setting device to {self._device}")

        # Tokenizers are device-agnostic, so no device argument is passed here.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "naver/splade-v3", use_fast=True
        )
        self.model = AutoModelForMaskedLM.from_pretrained("naver/splade-v3")
        self.model.to(self._device)
        self.model.eval()  # inference only; disables dropout

        self._embeddings = None
        self._ids = None
        self._metadatas = None
        self._l2_norm_matrix = None
        self._labels_to_ind = defaultdict(list)
        self._chunk_size_to_ind = defaultdict(list)

        self.n_batch = config.embeddings.splade_config.n_batch

    def _get_batch_embeddings(self, docs: List[str]) -> np.ndarray:
        tokens = self.tokenizer(
            docs, return_tensors="pt", padding=True, truncation=True
        ).to(self._device)

        with torch.no_grad():
            output = self.model(**tokens)

        # SPLADE pooling: log-saturated ReLU of the MLM logits, max-pooled
        # over the sequence dimension, with padding positions masked out.
        # For a single document, squeeze() yields a 1-D vocabulary-sized vector.
        vecs = (
            torch.max(
                torch.log(1 + torch.relu(output.logits))
                * tokens.attention_mask.unsqueeze(-1),
                dim=1,
            )[0]
            .squeeze()
            .cpu()
            .numpy()
        )

        del output
        del tokens
        return vecs

    def _get_embedding_fnames(self):
        folder_name = os.path.join(self._config.embeddings.embeddings_path, "splade")
        fn_embeddings = os.path.join(folder_name, "splade_embeddings.npz")
        fn_ids = os.path.join(folder_name, "splade_ids.pickle")
        fn_metadatas = os.path.join(folder_name, "splade_metadatas.pickle")
        return folder_name, fn_embeddings, fn_ids, fn_metadatas

    def load(self) -> None:
        """Load persisted embeddings, ids and metadata, and build the
        label -> indices and chunk_size -> indices lookup tables."""
        _, fn_embeddings, fn_ids, fn_metadatas = self._get_embedding_fnames()
        try:
            self._embeddings = scipy.sparse.load_npz(fn_embeddings)
            with open(fn_ids, "rb") as fp:
                self._ids = np.array(pickle.load(fp))
            with open(fn_metadatas, "rb") as fm:
                self._metadatas = np.array(pickle.load(fm))
            self._l2_norm_matrix = scipy.sparse.linalg.norm(self._embeddings, axis=1)

            for ind, m in enumerate(self._metadatas):
                if m["label"]:
                    self._labels_to_ind[m["label"]].append(ind)
                self._chunk_size_to_ind[m["chunk_size"]].append(ind)

            logger.info(f"SPLADE: Got {len(self._labels_to_ind)} labels.")
        except FileNotFoundError:
            raise FileNotFoundError(
                "SPLADE embeddings don't exist. Run generate_embeddings() first."
            )

        logger.info(f"Loaded sparse embeddings from {fn_embeddings}")

    def generate_embeddings(
        self, docs: List[Document], persist: bool = True
    ) -> Tuple[np.ndarray, List[str], List[dict]]:
        batch_size = self.n_batch

        # Drop documents with empty content up front so that ids and
        # metadatas stay aligned with the embedding rows.
        docs = [d for d in docs if d.page_content]
        ids = [d.metadata["document_id"] for d in docs]
        metadatas = [d.metadata for d in docs]
        vecs = []

        for chunk in tqdm.tqdm(
            split(docs, chunk_size=batch_size),
            total=math.ceil(len(docs) / batch_size),
        ):
            texts = [d.page_content for d in chunk]
            vecs.append(self._get_batch_embeddings(texts))
        embeddings = np.vstack(vecs)
        if persist:
            self.persist_embeddings(embeddings, metadatas, ids)
        return embeddings, ids, metadatas

    def persist_embeddings(self, embeddings, metadatas, ids):
        folder_name, fn_embeddings, fn_ids, fn_metadatas = self._get_embedding_fnames()
        csr_embeddings = scipy.sparse.csr_matrix(embeddings)
        os.makedirs(folder_name, exist_ok=True)
        scipy.sparse.save_npz(fn_embeddings, csr_embeddings)
        self.save_list(ids, fn_ids)
        self.save_list(metadatas, fn_metadatas)
        logger.info(f"Saved embeddings to {fn_embeddings}")

    def query(
        self, search: str, chunk_size: int, n: int = 50, label: str = ""
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return ids and cosine similarities of the top-n matches for the
        given chunk size, optionally restricted to a label."""
        if self._embeddings is None or self._ids is None:
            logger.info("Loading embeddings...")
            self.load()

        # Candidate rows: all rows for the requested chunk size, optionally
        # intersected with the rows carrying the requested label.
        if label and label in self._labels_to_ind:
            indices = sorted(
                set(self._labels_to_ind[label]).intersection(
                    self._chunk_size_to_ind[chunk_size]
                )
            )
        else:
            indices = sorted(set(self._chunk_size_to_ind[chunk_size]))

        embeddings = self._embeddings[indices]
        ids = self._ids[indices]
        l2_norm_matrix = scipy.sparse.linalg.norm(embeddings, axis=1)

        embed_query = self._get_batch_embeddings(docs=[search])
        l2_norm_query = scipy.linalg.norm(embed_query)

        # Cosine similarity between the query vector and each candidate row;
        # load() raises if the embeddings are missing, so no None checks are
        # needed at this point.
        cosine_similarity = embeddings.dot(embed_query) / (
            l2_norm_matrix * l2_norm_query
        )
        top_similar_indices = np.argsort(cosine_similarity)[-n:][::-1]
        return ids[top_similar_indices], cosine_similarity[top_similar_indices]

    def save_list(self, list_: list, fname: str) -> None:
        with open(fname, "wb") as fp:
            pickle.dump(list_, fp)
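

# A minimal usage sketch, assuming a Config whose embeddings.embeddings_path and
# embeddings.splade_config.n_batch fields are set, and Document objects exposing
# page_content plus "document_id", "label" and "chunk_size" metadata (as the
# methods above expect). The config construction, the sample document, and the
# "ml" label below are illustrative only, not part of this repo's real setup.
if __name__ == "__main__":
    config = Config()  # hypothetical: real configs are presumably loaded elsewhere
    db = SpladeSparseVectorDB(config)

    docs = [
        Document(
            page_content="SPLADE represents text as sparse vocabulary-sized vectors.",
            metadata={"document_id": "doc-1", "label": "ml", "chunk_size": 512},
        ),
    ]
    # Embed and persist; query() will lazily call load() on first use.
    db.generate_embeddings(docs, persist=True)

    ids, scores = db.query("sparse retrieval", chunk_size=512, n=5, label="ml")
    for doc_id, score in zip(ids, scores):
        print(doc_id, float(score))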