|
from abc import ABCMeta, abstractmethod |
|
from typing import Union |
|
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from transformers import AutoTokenizer |
|
from langchain.text_splitter import TokenTextSplitter |
|
from langchain_core.documents import Document |
|
from torch.cuda import is_available |
|
|
|
|
|
class BaseDB(metaclass=ABCMeta):
    """Abstract base class for vector-store backends.

    Loads a HuggingFace embedding model (on GPU when CUDA is available)
    together with its tokenizer, then delegates store-specific setup to
    the subclass via ``init_db``.
    """

    def __init__(
        self,
        embedding_name: Union[str, None] = None,
        persist_dir: Union[str, None] = None,
    ) -> None:
        """Set up the embedding model, tokenizer, and backing store.

        Args:
            embedding_name: HuggingFace model id used for embeddings.
                Defaults to "BAAI/bge-small-zh-v1.5" when falsy.
            persist_dir: Directory for on-disk persistence.
                Defaults to "data" when falsy.
        """
        super().__init__()

        # Concrete subclasses are expected to assign their store handle
        # in init_db(); kept None here so attribute access never fails.
        self.client = None

        self.persist_dir = persist_dir if persist_dir else "data"

        if not embedding_name:
            embedding_name = "BAAI/bge-small-zh-v1.5"

        # Run embedding inference on GPU whenever torch reports one.
        device = "cuda" if is_available() else "cpu"
        self.embedding = HuggingFaceEmbeddings(
            model_name=embedding_name, model_kwargs={"device": device}
        )
        # Same model's tokenizer so chunk sizes are measured in the
        # embedding model's own tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)

        self.init_db()

    @abstractmethod
    def init_db(self):
        """Initialize the backing vector store and set ``self.client``."""

    def text_splitter(
        self,
        text: Union[str, Document],
        chunk_size: int = 300,
        chunk_overlap: int = 10,
    ):
        """Split ``text`` into token-counted chunks via the embedding tokenizer.

        Args:
            text: A raw string or a single ``Document``.
            chunk_size: Maximum tokens per chunk.
            chunk_overlap: Tokens shared between adjacent chunks.

        Returns:
            ``list[Document]`` for Document input, ``list[str]`` for str input.

        Raises:
            ValueError: If ``text`` is neither ``str`` nor ``Document``.
        """
        # Build the splitter once instead of per-branch.
        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if isinstance(text, Document):
            # BUGFIX: split_documents expects an iterable of Documents;
            # the original passed the bare Document, which fails at runtime.
            return splitter.split_documents([text])
        if isinstance(text, str):
            return splitter.split_text(text)
        raise ValueError("text must be a str or Document")

    @abstractmethod
    def addStories(self, stories, metas=None):
        """Add stories (with optional per-story metadata) to the store."""

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        """Delete stored stories whose metadata matches ``metas``."""

    @abstractmethod
    def searchBySim(self, query, n_results, metas, only_return_document=True):
        """Similarity-search ``query``, filtered by ``metas``, top ``n_results``."""

    @abstractmethod
    def searchByMeta(self, metas=None):
        """Fetch stored stories by metadata filter."""
|
|