# project modules
from api.llms.base import get_LLM
from api.embedding_models.base import get_embedding_model
from api.vector_index.base import get_vector_index

# llama_index core
from llama_index.core import Settings
from llama_index.core.memory import ChatMemoryBuffer

QUERY_ENGINE_MODE = "tree_summarize"  # response synthesis strategy
CHAT_ENGINE_MODE = "context"          # retrieve context first, then chat
TOP_K = 3                             # chunks retrieved per query
MEMORY_TOKEN_LIMIT = 8000             # token budget for chat history

class QueryEngine:
    """One-shot RAG query engine: retrieves the top-k chunks and synthesizes an answer."""

    def __init__(self,
                 embedding_model="BAAI/bge-m3",
                 llm="aya:8b",
                 vector_index="chroma",
                 force_new_db=False):
        self.embed_config = get_embedding_model(embedding_model)
        self.llm_config = get_LLM(llm)
        self.index = get_vector_index(vector_index, force_new_db)
        self.engine = self.index.as_query_engine(
            text_qa_template=self.llm_config.query_context_template,
            response_mode=QUERY_ENGINE_MODE,
            similarity_top_k=TOP_K,
            streaming=True
        )

    def query(self, user_input):
        return self.engine.query(user_input)

    def query_streaming(self, user_input):
        # The engine is built with streaming=True, so query() already returns
        # a StreamingResponse; this alias keeps the interface symmetric with ChatEngine.
        return self.engine.query(user_input)

class ChatEngine:
    """Conversational RAG engine: retrieved context plus chat-history memory."""

    def __init__(self,
                 embedding_model="BAAI/bge-m3",
                 llm="gpt4o_mini",
                 vector_index="chroma",
                 force_new_db=False):
        self.embed_config = get_embedding_model(embedding_model)
        self.llm_config = get_LLM(llm)
        self.index = get_vector_index(vector_index, force_new_db)
        self.engine = self.index.as_chat_engine(
            llm=Settings.llm,  # the globally configured LLM
            chat_mode=CHAT_ENGINE_MODE,
            verbose=False,
            memory=ChatMemoryBuffer.from_defaults(token_limit=MEMORY_TOKEN_LIMIT),
            system_prompt=self.llm_config.system_prompt,
            context_template=self.llm_config.chat_context_template,
            response_mode=QUERY_ENGINE_MODE,
            similarity_top_k=TOP_K,
            streaming=True
        )

    def query(self, user_input):
        return self.engine.chat(user_input)

    def query_streaming(self, user_input):
        return self.engine.stream_chat(user_input)
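
# A minimal usage sketch (the questions here are illustrative; it assumes a
# Chroma index has already been built and that the default model identifiers
# above resolve through get_LLM / get_embedding_model):
if __name__ == "__main__":
    # One-shot retrieval QA; because the engine is built with streaming=True,
    # query() returns a StreamingResponse.
    qe = QueryEngine()
    qe.query("What does this project do?").print_response_stream()

    # Multi-turn chat with memory; stream_chat() exposes a token generator.
    ce = ChatEngine()
    streamed = ce.query_streaming("Summarize the indexed documents.")
    for token in streamed.response_gen:
        print(token, end="", flush=True)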