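"""Abstract base class for vector-store backends.

Subclasses implement init_db plus the add/delete/search hooks; the base
class wires up a HuggingFace embedding model, its tokenizer, and a
token-based text splitter shared by all backends.
"""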
from abc import ABCMeta, abstractmethod
from typing import Optional, Union

from langchain.text_splitter import TokenTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from torch.cuda import is_available
from transformers import AutoTokenizer

class BaseDB(metaclass=ABCMeta):
    def __init__(
        self, embedding_name: Optional[str] = None, persist_dir: Optional[str] = None
    ) -> None:
        super().__init__()
        self.client = None
        # Where the backend persists its data; defaults to "data".
        self.persist_dir = persist_dir if persist_dir else "data"
        # Default to a small Chinese BGE embedding model.
        if not embedding_name:
            embedding_name = "BAAI/bge-small-zh-v1.5"
        # Run the embedding model on GPU when CUDA is available.
        model_kwargs = {"device": "cuda" if is_available() else "cpu"}
        self.embedding = HuggingFaceEmbeddings(
            model_name=embedding_name, model_kwargs=model_kwargs
        )
        # Reuse the embedding model's tokenizer so chunk sizes are measured
        # in the same tokens the model actually sees.
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)
        self.init_db()

    @abstractmethod
    def init_db(self):
        pass

    def text_splitter(
        self, text: Union[str, Document], chunk_size=300, chunk_overlap=10
    ):
        # Token-based splitting keeps chunk boundaries aligned with the
        # embedding model's own tokenizer.
        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if isinstance(text, Document):
            # split_documents expects an iterable of Documents, so wrap
            # the single Document in a list.
            return splitter.split_documents([text])
        elif isinstance(text, str):
            return splitter.split_text(text)
        else:
            raise ValueError("text must be a str or Document")

    @abstractmethod
    def addStories(self, stories, metas=None):
        pass

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        pass

    @abstractmethod
    def searchBySim(self, query, n_results, metas, only_return_document=True):
        pass

    @abstractmethod
    def searchByMeta(self, metas=None):
        pass
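
# ---------------------------------------------------------------------------
# A minimal sketch of a concrete backend, for illustration only; it is not
# part of the original file. It assumes the Chroma wrapper from
# langchain_community is installed; the class name ChromaDB and the simple
# metadata handling are hypothetical, and stories is assumed to be a list
# of plain strings.
# ---------------------------------------------------------------------------
from langchain_community.vectorstores import Chroma


class ChromaDB(BaseDB):
    def init_db(self):
        # Persist to the directory configured by BaseDB.__init__.
        self.client = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=self.embedding,
        )

    def addStories(self, stories, metas=None):
        # Chunk each story with the shared token splitter before indexing.
        texts, metadatas = [], []
        for i, story in enumerate(stories):
            for chunk in self.text_splitter(story):
                texts.append(chunk)
                if metas:
                    metadatas.append(metas[i])
        return self.client.add_texts(texts, metadatas=metadatas or None)

    def deleteStoriesByMeta(self, metas):
        # Look up the ids whose metadata matches, then delete them.
        hits = self.client.get(where=metas)
        if hits["ids"]:
            self.client.delete(ids=hits["ids"])

    def searchBySim(self, query, n_results, metas, only_return_document=True):
        docs = self.client.similarity_search(query, k=n_results, filter=metas)
        if only_return_document:
            return [doc.page_content for doc in docs]
        return docs

    def searchByMeta(self, metas=None):
        return self.client.get(where=metas)


# Example usage (hypothetical):
#   db = ChromaDB(persist_dir="data")
#   db.addStories(["Once upon a time ..."], metas=[{"source": "demo"}])
#   print(db.searchBySim("time", n_results=1, metas=None))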