Basti8499's picture
Adds all necessary files
579ab0b verified
raw
history blame
No virus
4.36 kB
import chainlit as cl
from helper import HelperMethods
from pydantic.v1.error_wrappers import ValidationError
from cohere.error import CohereAPIError
# Name of the vector-DB collection opened in on_chat_start via helper.get_index_vector_db().
COLLECTION_NAME = "ISO_27001_Collection"
@cl.on_chat_start
async def on_chat_start():
    """
    Is called when a new chat session is created. Adds an initial message and sets important objects into session state.

    On success, a dict with the keys "llm", "max_context_size" and "vectordb" is stored
    in the user session under "state_ISO". If creating the LLM raises a pydantic
    ValidationError, an error message is shown and no state is stored.
    """
    author = "ISO 27001 - Assistant"
    await cl.sleep(1)
    await cl.Message(author=author, content="Hello, do you have questions on ISO 27001? Feel free to ask me.").send()

    helper = HelperMethods()
    try:
        llm, max_context_size = await helper.get_LLM()
    except ValidationError:
        # NOTE(review): the message assumes this is raised for a missing/invalid
        # OpenAI key — confirm against HelperMethods.get_LLM.
        await cl.ErrorMessage(
            author=author,
            content="A validation error occurred. Please ensure the OpenAI API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    state = {
        "llm": llm,
        "max_context_size": max_context_size,
        "vectordb": helper.get_index_vector_db(COLLECTION_NAME),
    }
    cl.user_session.set("state_ISO", state)
@cl.on_message
async def on_message(message: cl.Message):
    """
    Is called when a new message is sent by the user. Executes the RAG pipeline (check english, retrieve contexts, check relevancy, check context size, prompt LLM)

    Each failed check stops the pipeline and sends an explanatory message instead.
    On success the LLM answer is streamed into a single message, followed by the
    sources and, when a template path is returned, an inline file element.
    """
    author = "ISO 27001 - Assistant"
    state = cl.user_session.get("state_ISO")
    # on_chat_start only stores the state when the LLM could be created; without it
    # there is no LLM or vector DB, and indexing None would raise a TypeError.
    if state is None:
        await cl.ErrorMessage(
            author=author,
            content="The assistant could not be initialised. Please ensure the OpenAI API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    helper = HelperMethods()
    query = message.content

    if not helper.check_if_english(query):
        await cl.Message(
            author=author, content="I am sorry. I cannot process your question, as I can only answer questions written in English."
        ).send()
        return

    try:
        docs = helper.retrieve_contexts(state["vectordb"], query)
    except CohereAPIError:
        await cl.ErrorMessage(
            author=author,
            content="A validation error occurred. Please ensure the Cohere API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    if not helper.check_if_relevant(docs):
        await cl.Message(author=author, content="I am sorry. I cannot process your question, as it is not related to ISO 27001.").send()
        return

    if not helper.is_context_size_valid(docs, query, state["max_context_size"]):
        await cl.Message(
            author=author,
            content="I am sorry. I cannot process your question, as it would exceed my token limit. Please try to reformulate your question, or ask something else.",
        ).send()
        return

    msg = cl.Message(author=author, content="")
    await msg.send()
    full_prompt, sources, template_path, template_source = helper.get_full_prompt_sources_and_template(docs, state["llm"], query)

    # astream() returns a lazy async iterator: errors surface while iterating,
    # not when the iterator is created, so the whole loop must be inside the try.
    try:
        async for part in state["llm"].astream(full_prompt):
            await msg.stream_token(part.content)
    except ValidationError:
        await cl.ErrorMessage(
            author=author,
            content="A validation error occurred. Please ensure the OpenAI API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    msg.content += "\n\nSources: \n" + sources
    # Attach the template file only when the helper returned a non-empty path.
    if template_path != "":
        msg.elements = [cl.File(name=template_source, path=template_path, display="inline")]
    await msg.update()