from abc import ABCMeta, abstractmethod
from typing import Optional, Union

from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer
from langchain.text_splitter import TokenTextSplitter
from langchain_core.documents import Document
from torch.cuda import is_available


class BaseDB(metaclass=ABCMeta):
    """Abstract base class for vector stores backed by a HuggingFace embedding model."""

    def __init__(self, embedding_name: Optional[str] = None, persist_dir=None) -> None:
        super().__init__()
        self.client = None
        # Default persistence directory when none is supplied.
        self.persist_dir = persist_dir if persist_dir else "data"
        # Default embedding model when none is supplied.
        if not embedding_name:
            embedding_name = "BAAI/bge-small-zh-v1.5"
        # Run the embedding model on GPU when CUDA is available, otherwise on CPU.
        model_kwargs = {"device": "cuda"} if is_available() else {"device": "cpu"}
        self.embedding = HuggingFaceEmbeddings(
            model_name=embedding_name, model_kwargs=model_kwargs
        )
        # Reuse the embedding model's tokenizer for token-based chunking.
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)
        self.init_db()

    @abstractmethod
    def init_db(self):
        pass

    def text_splitter(
        self, text: Union[str, Document], chunk_size=300, chunk_overlap=10
    ):
        """Split a string or a Document into token-sized chunks."""
        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if isinstance(text, Document):
            # split_documents expects an iterable of Documents, so wrap the single one.
            return splitter.split_documents([text])
        elif isinstance(text, str):
            return splitter.split_text(text)
        else:
            raise ValueError("text must be a str or Document")

    @abstractmethod
    def addStories(self, stories, metas=None):
        pass

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        pass

    @abstractmethod
    def searchBySim(self, query, n_results, metas, only_return_document=True):
        pass

    @abstractmethod
    def searchByMeta(self, metas=None):
        pass
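

# --------------------------------------------------------------------------
# Illustrative sketch only: a minimal in-memory subclass showing how the
# abstract interface above can be filled in. The class name `InMemoryDB`,
# the list-based storage, and the brute-force cosine search are assumptions
# for demonstration; a concrete backend (e.g. Chroma or Milvus) would replace
# them. Requires numpy.

import numpy as np


class InMemoryDB(BaseDB):
    def init_db(self):
        # Here the "client" is just a list of (text, meta, vector) triples.
        self.client = []

    def addStories(self, stories, metas=None):
        metas = metas if metas is not None else [{} for _ in stories]
        vectors = self.embedding.embed_documents(list(stories))
        self.client.extend(zip(stories, metas, vectors))

    def deleteStoriesByMeta(self, metas):
        # Assumed semantics: drop entries whose metadata matches every key/value in `metas`.
        self.client = [
            (t, m, v)
            for t, m, v in self.client
            if not all(m.get(k) == val for k, val in metas.items())
        ]

    def searchBySim(self, query, n_results, metas=None, only_return_document=True):
        # Brute-force cosine similarity over all stored vectors (metadata filter omitted).
        q = np.asarray(self.embedding.embed_query(query))
        ranked = sorted(
            self.client,
            key=lambda item: -float(
                np.dot(q, item[2]) / (np.linalg.norm(q) * np.linalg.norm(item[2]))
            ),
        )[:n_results]
        return [t for t, _, _ in ranked] if only_return_document else ranked

    def searchByMeta(self, metas=None):
        if not metas:
            return [t for t, _, _ in self.client]
        return [
            t for t, m, _ in self.client
            if all(m.get(k) == val for k, val in metas.items())
        ]


# Example usage (assumes the default embedding model can be downloaded):
#   db = InMemoryDB()
#   db.addStories(["Alice meets the rabbit."], metas=[{"book": "alice"}])
#   db.searchBySim("rabbit", n_results=1)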