ChatWorld / src /DataBase /BaseDB.py
JiangYH's picture
Upload folder using huggingface_hub
4ab98db verified
from abc import ABCMeta, abstractmethod
from typing import Union
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer
from langchain.text_splitter import TokenTextSplitter
from langchain_core.documents import Document
from torch.cuda import is_available
class BaseDB(metaclass=ABCMeta):
    """Abstract base class for an embedding-backed story store.

    The constructor loads a HuggingFace embedding model together with the
    matching tokenizer, then delegates backend-specific setup to
    :meth:`init_db`. Concrete subclasses must implement every abstract
    method declared below.
    """

    def __init__(self, embedding_name: Union[str, None] = None, persist_dir=None) -> None:
        """Load the embedding model and tokenizer, then call ``init_db``.

        Args:
            embedding_name: HuggingFace model name used for both the
                embedding model and the tokenizer. Falls back to
                ``"BAAI/bge-small-zh-v1.5"`` when falsy.
            persist_dir: Value stored on ``self.persist_dir``. Falls back
                to ``"data"`` when falsy.
        """
        super().__init__()
        # Placeholder; nothing in this class assigns a real client.
        # Presumably init_db() in a subclass does -- confirm there.
        self.client = None
        self.persist_dir = persist_dir or "data"

        embedding_name = embedding_name or "BAAI/bge-small-zh-v1.5"
        # Run the embedding model on the GPU whenever torch reports CUDA.
        device = "cuda" if is_available() else "cpu"
        self.embedding = HuggingFaceEmbeddings(
            model_name=embedding_name,
            model_kwargs={"device": device},
        )
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)

        # Called last so the subclass hook can use every attribute set above.
        self.init_db()

    @abstractmethod
    def init_db(self):
        """Backend-specific initialisation hook, invoked at the end of ``__init__``."""

    def text_splitter(
        self, text: Union[str, "Document"], chunk_size=300, chunk_overlap=10
    ):
        """Split ``text`` into token-bounded chunks using ``self.tokenizer``.

        Args:
            text: Either a raw string or a single LangChain ``Document``.
            chunk_size: Forwarded to ``TokenTextSplitter`` as ``chunk_size``.
            chunk_overlap: Forwarded to ``TokenTextSplitter`` as
                ``chunk_overlap``.

        Returns:
            The result of ``split_text`` for a ``str`` input, or the result
            of ``split_documents`` for a ``Document`` input.

        Raises:
            ValueError: If ``text`` is neither a ``str`` nor a ``Document``.
        """
        # Validate first so an unsupported input fails before any splitter
        # is constructed.
        if not isinstance(text, (str, Document)):
            raise ValueError("text must be a str or Document")

        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if isinstance(text, Document):
            # split_documents() expects an iterable of Documents; passing a
            # bare Document would make it iterate over the Document itself.
            return splitter.split_documents([text])
        return splitter.split_text(text)

    @abstractmethod
    def addStories(self, stories, metas=None):
        """Add ``stories`` (with optional ``metas``); semantics are defined by the subclass."""

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        """Delete stories selected by ``metas``; semantics are defined by the subclass."""

    @abstractmethod
    def searchBySim(self, query, n_results, metas, only_return_document=True):
        """Search for ``query``; presumably a similarity search -- defined by the subclass."""

    @abstractmethod
    def searchByMeta(self, metas=None):
        """Look up stories by ``metas``; semantics are defined by the subclass."""