from abc import ABCMeta, abstractmethod
from typing import Optional, Union

from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer
from langchain.text_splitter import TokenTextSplitter
from langchain_core.documents import Document
from torch.cuda import is_available


class BaseDB(metaclass=ABCMeta):
    """Abstract base class for vector-store backends sharing one embedding model and tokenizer."""

    def __init__(self, embedding_name: Optional[str] = None, persist_dir=None) -> None:
        super().__init__()

        self.client = None

        # Directory where the concrete backend persists its data.
        if persist_dir:
            self.persist_dir = persist_dir
        else:
            self.persist_dir = "data"

        # Default embedding model if none is supplied.
        if not embedding_name:
            embedding_name = "BAAI/bge-small-zh-v1.5"

        # Run the embedding model on GPU when available, otherwise fall back to CPU.
        if is_available():
            model_kwargs = {"device": "cuda"}
        else:
            model_kwargs = {"device": "cpu"}

        self.embedding = HuggingFaceEmbeddings(
            model_name=embedding_name, model_kwargs=model_kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)

        self.init_db()

    @abstractmethod
    def init_db(self):
        """Create or connect the backing store and assign its client to self.client."""

    def text_splitter(
        self, text: Union[str, Document], chunk_size=300, chunk_overlap=10
    ):
        """Split a raw string or a single Document into token-sized chunks."""
        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if isinstance(text, Document):
            # split_documents expects an iterable of Documents, so wrap the single one.
            return splitter.split_documents([text])
        elif isinstance(text, str):
            return splitter.split_text(text)
        else:
            raise ValueError("text must be a str or Document")

    @abstractmethod
    def addStories(self, stories, metas=None):
        """Embed and store a batch of story texts, with optional per-story metadata."""

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        """Delete stored stories whose metadata matches metas."""

    @abstractmethod
    def searchBySim(self, query, n_results, metas, only_return_document=True):
        """Return the n_results stories most similar to query, optionally filtered by metas."""

    @abstractmethod
    def searchByMeta(self, metas=None):
        """Return stored stories whose metadata matches metas."""
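

# Illustrative sketch only (not part of the original file): a hypothetical
# in-memory subclass showing how BaseDB is intended to be extended. The class
# name, the list-based store, and the brute-force cosine search are assumptions
# for demonstration; a real backend would wrap a vector-store client
# (e.g. Chroma or Milvus) in self.client.
class ExampleInMemoryDB(BaseDB):
    def init_db(self):
        # Stand-in for a real vector-store client: a list of (vector, text, meta) tuples.
        self.client = []

    def addStories(self, stories, metas=None):
        metas = metas or [{} for _ in stories]
        for story, meta in zip(stories, metas):
            self.client.append((self.embedding.embed_query(story), story, meta))

    def deleteStoriesByMeta(self, metas):
        self.client = [item for item in self.client if item[2] != metas]

    def searchBySim(self, query, n_results, metas, only_return_document=True):
        def cosine(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            norm = (sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5)
            return dot / norm if norm else 0.0

        query_vec = self.embedding.embed_query(query)
        candidates = [
            item for item in self.client if metas is None or item[2] == metas
        ]
        ranked = sorted(
            candidates, key=lambda item: cosine(query_vec, item[0]), reverse=True
        )
        top = ranked[:n_results]
        return [text for _, text, _ in top] if only_return_document else top

    def searchByMeta(self, metas=None):
        return [
            text for _, text, meta in self.client if metas is None or meta == metas
        ]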