File size: 3,009 Bytes
f37ceb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from MongoChainGenerator import *
from MongoEmbeddingGenerator import *
from DATABASE import *
from appConfig import LOG


class Main:
    """Registry of retrieval QA chains shared by every ``Main`` instance.

    Chains are stored in the class attribute ``qa_chains`` keyed by collection
    name, so all instances see the same chains. Instantiating the class
    creates the shared embedding generator (once) and builds a chain for each
    document found in the ``chatData.chats`` collection.
    """

    # Collection name -> QA chain; a chain is called with a dict payload and
    # its return value is indexed with "result" (see ask_question).
    qa_chains = {}
    # Shared MongoEmbeddingGenerator, created by the first instantiation.
    embedding_generator = None

    def __init__(self) -> None:
        """Instantiate the database, embedding generator and stored chains."""
        # NOTE(review): presumably this sets up ``DATABASE.client``, which
        # _load_existing_qa_chains reads right after -- confirm in DATABASE.
        DATABASE()
        self._initialize_embedding_generator()
        self._load_existing_qa_chains()

    def _initialize_embedding_generator(self):
        """Create the shared embedding generator unless one already exists."""
        if Main.embedding_generator is not None:
            return
        Main.embedding_generator = MongoEmbeddingGenerator(
            repo_id=CONST_VAR.EMBEDDING_MODEL_REPO_ID
        )
        LOG.debug("Embedding generator initialized")

    def _load_existing_qa_chains(self):
        """Build a chain for every stored chat that is not registered yet."""
        chats = DATABASE.client["chatData"]["chats"].find()
        for chat in chats:
            # Checked here as well so already-registered collections are
            # skipped silently instead of logging "already exists".
            if chat["collectionName"] not in Main.qa_chains:
                self.create_exist_chains(chat)

    def create_exist_chains(self, chat):
        """Register a QA chain for ``chat`` if its collection has none.

        Args:
            chat: Mapping that provides the keys ``"collectionName"`` and
                ``"templateContext"``.

        Raises:
            KeyError: If ``chat`` lacks one of those keys.
        """
        collection_name = chat["collectionName"]
        if collection_name in Main.qa_chains:
            LOG.debug(f"Chain already exists for collection {collection_name}")
            return

        qa_generator = MongoChainGenerator(
            embedding_model=Main.embedding_generator.embedding_model,
            db_collection_name=collection_name,
            template_context=chat["templateContext"],
        )
        Main.qa_chains[collection_name] = qa_generator.generate_retrieval_qa_chain()
        # f-string rather than "+" so a non-str collection name cannot raise
        # TypeError after the chain has already been stored.
        LOG.debug(f"Chain created for collection {collection_name}")

    def generate_embedding(self, content: str, file_name: str, collection_name: str):
        """Delegate to the shared generator's ``generate_embeddings``.

        Returns:
            Whatever ``generate_embeddings(content, file_name,
            collection_name)`` returns.
        """
        return Main.embedding_generator.generate_embeddings(content, file_name, collection_name)

    def generate_tmp_embedding_and_chain(self, contents: str, tmp_collection_name):
        """Build a chain from temporary embeddings and register it.

        Args:
            contents: Forwarded as ``pdf_bytes`` to ``generate_tmp_embeddings``.
                NOTE(review): annotated ``str`` but passed as ``pdf_bytes`` --
                confirm the expected type.
            tmp_collection_name: Key under which the chain is stored; an
                existing entry with the same key is overwritten.
        """
        tmp_embedding = Main.embedding_generator.generate_tmp_embeddings(pdf_bytes=contents)
        qa_generator = MongoChainGenerator(
            embedding_model=Main.embedding_generator.embedding_model,
            template_context=CONST_VAR.TEMPLATE_CONTEXT,
            tmp_vector_embedding=tmp_embedding,
        )
        Main.qa_chains[tmp_collection_name] = qa_generator.generate_retrieval_qa_chain()
        # f-string: the key is unannotated, so it may not be a str.
        LOG.debug(f"{tmp_collection_name} chain created")

    def ask_question(self, question: str, collection_name):
        """Answer ``question`` with the chain registered for ``collection_name``.

        Returns:
            The chain's ``"result"`` value on success; otherwise a
            human-readable message saying the chain is missing or that the
            call failed. Exceptions raised by the chain are logged and
            reported in the returned message rather than propagated.
        """
        if collection_name not in Main.qa_chains:
            not_found = "Chain for collection '{}' not found.".format(collection_name)
            LOG.warning(not_found)
            return not_found

        try:
            # Formatting (not "+") so a non-str key cannot raise TypeError
            # here and turn a valid question into an error reply.
            LOG.debug(f"{collection_name} answering")
            # NOTE(review): the extra keys look like generation settings;
            # whether the chain honours them is not visible here.
            response = Main.qa_chains[collection_name](
                {
                    "query": question,
                    "early_stopping": True,
                    "min_length": 2000,
                    "max_tokens": 5000,
                }
            )
            return response["result"]
        except Exception as e:
            LOG.error(f"An error occurred while answering question: {e}")
            return f"Retry to ask question! An error occurred: {e}"

    def check_collection_name(self, collection_name):
        """Return True if a chain is registered for ``collection_name``."""
        # Read through the class, like every other method, so an instance
        # attribute can never shadow the shared registry.
        return collection_name in Main.qa_chains