File size: 4,356 Bytes
579ab0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import chainlit as cl
from helper import HelperMethods
from pydantic.v1.error_wrappers import ValidationError
from cohere.error import CohereAPIError

# Collection name handed to HelperMethods.get_index_vector_db when a chat session starts.
COLLECTION_NAME = "ISO_27001_Collection"


@cl.on_chat_start
async def on_chat_start():
    """
    Runs when a new chat session is created. Sends the greeting message, then stores the LLM, its maximum
    context size and the vector DB obtained from HelperMethods in the user session under "state_ISO".

    If obtaining the LLM raises a ValidationError, an error message is sent instead and nothing is stored.
    """

    assistant = "ISO 27001 - Assistant"

    await cl.sleep(1)

    await cl.Message(author=assistant, content="Hello, do you have questions on ISO 27001? Feel free to ask me.").send()

    helper = HelperMethods()

    try:
        llm, max_context_size = await helper.get_LLM()
    except ValidationError:
        await cl.ErrorMessage(
            author=assistant,
            content="A validation error occurred. Please ensure the Open API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    # The vector DB is only requested once the LLM is available, as before.
    vectordb = helper.get_index_vector_db(COLLECTION_NAME)
    cl.user_session.set(
        "state_ISO",
        {"llm": llm, "max_context_size": max_context_size, "vectordb": vectordb},
    )


@cl.on_message
async def on_message(message: cl.Message):
    """
    Is called when a new message is sent by the user. Executes the RAG pipeline (check english, retrieve contexts,
    check relevancy, check context size, prompt LLM).

    Each failed check sends an explanatory message to the user and stops the pipeline. On success the LLM answer
    is streamed into a message, the sources are appended and, if a template path was returned, the template file
    is attached.

    Args:
        message: The incoming user message; its content is used as the query.
    """

    assistant = "ISO 27001 - Assistant"

    state = cl.user_session.get("state_ISO")
    if state is None:
        # on_chat_start returns without storing the state when the LLM could not be obtained;
        # without this guard the lookups below would fail with a TypeError.
        await cl.ErrorMessage(
            author=assistant,
            content="The assistant is not initialized. Please ensure the Open API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    helper = HelperMethods()
    query = message.content

    if not helper.check_if_english(query):
        await cl.Message(
            author=assistant, content="I am sorry. I cannot process your question, as I can only answer questions written in English."
        ).send()
        return

    try:
        docs = helper.retrieve_contexts(state["vectordb"], query)
    except CohereAPIError:
        await cl.ErrorMessage(
            author=assistant,
            content="A validation error occurred. Please ensure the Cohere API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    if not helper.check_if_relevant(docs):
        await cl.Message(author=assistant, content="I am sorry. I cannot process your question, as it is not related to ISO 27001.").send()
        return

    if not helper.is_context_size_valid(docs, query, state["max_context_size"]):
        await cl.Message(
            author=assistant,
            content="I am sorry. I cannot process your question, as it would exceed my token limit. Please try to reformulate your question, or ask something else.",
        ).send()
        return

    # Send an empty message first so the answer can be streamed into it token by token.
    msg = cl.Message(author=assistant, content="")
    await msg.send()

    full_prompt, sources, template_path, template_source = helper.get_full_prompt_sources_and_template(docs, state["llm"], query)

    try:
        # astream() only creates the async iterator; the request runs while iterating,
        # so the loop itself has to be inside the try for the handler to take effect.
        async for part in state["llm"].astream(full_prompt):
            await msg.stream_token(part.content)
    except ValidationError:
        await cl.ErrorMessage(
            author=assistant,
            content="A validation error occurred. Please ensure the Open API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    msg.content += "\n\nSources: \n" + sources
    if template_path != "":
        # A template accompanies the answer: attach it as an inline file element.
        msg.elements = [cl.File(name=template_source, path=template_path, display="inline")]
    await msg.update()