# NOTE(review): the original lines here read "Spaces:" / "Sleeping" / "Sleeping" —
# this looks like Hugging Face Spaces page-status text captured by a scrape, not
# source code. Kept only as this note; confirm against the original repository.
import chainlit as cl
from helper import HelperMethods
from pydantic.v1.error_wrappers import ValidationError
from cohere.error import CohereAPIError

COLLECTION_NAME = "ISO_27001_Collection"
@cl.on_chat_start
async def on_chat_start():
    """
    Is called when a new chat session is created.

    Sends the initial greeting, builds the LLM and the vector DB handle, and
    stores them in the user session under the key "state_ISO" so that
    on_message can reuse them without re-initializing per message.

    On a pydantic ValidationError from LLM construction (typically a missing or
    malformed API key), an error message is shown to the user and the session
    state is NOT set.
    """
    # Brief pause before greeting; presumably gives the UI time to settle —
    # TODO(review): confirm this sleep is still needed.
    await cl.sleep(1)
    msg = cl.Message(author="ISO 27001 - Assistant", content="Hello, do you have questions on ISO 27001? Feel free to ask me.")
    await msg.send()

    helper = HelperMethods()
    try:
        llm, MAX_CONTEXT_SIZE = await helper.get_LLM()
    except ValidationError:
        # LLM setup failed — tell the user how to fix the key and bail out
        # without populating session state.
        error_message = cl.ErrorMessage(
            author="ISO 27001 - Assistant",
            content="A validation error occurred. Please ensure the Open API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        )
        await error_message.send()
        return

    # Everything on_message needs: the LLM, its context budget, and the
    # ISO 27001 vector index.
    state = {"llm": llm, "max_context_size": MAX_CONTEXT_SIZE, "vectordb": helper.get_index_vector_db(COLLECTION_NAME)}
    cl.user_session.set("state_ISO", state)
@cl.on_message
async def on_message(message: cl.Message):
    """
    Is called when a new message is sent by the user.

    Executes the RAG pipeline as a sequence of guard clauses:
    1. reject non-English queries,
    2. retrieve contexts from the vector DB (Cohere errors -> user-facing error),
    3. reject queries with no relevant contexts,
    4. reject queries whose prompt would exceed the model's context budget,
    5. stream the LLM answer, then append sources (and a template file
       attachment, when one exists) to the streamed message.
    """
    state = cl.user_session.get("state_ISO")
    helper = HelperMethods()
    query = message.content

    # Guard 1: English-only pipeline.
    if not helper.check_if_english(query):
        await cl.Message(
            author="ISO 27001 - Assistant", content="I am sorry. I cannot process your question, as I can only answer questions written in English."
        ).send()
        return

    # Guard 2: retrieval; a Cohere API failure is reported to the user.
    try:
        docs = helper.retrieve_contexts(state["vectordb"], query)
    except CohereAPIError:
        # NOTE(review): the text says "validation error" although this is a
        # Cohere API error — kept verbatim; confirm the intended wording.
        error_message = cl.ErrorMessage(
            author="ISO 27001 - Assistant",
            content="A validation error occurred. Please ensure the Cohere API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        )
        await error_message.send()
        return

    # Guard 3: question must be on-topic for ISO 27001.
    if not helper.check_if_relevant(docs):
        await cl.Message(author="ISO 27001 - Assistant", content="I am sorry. I cannot process your question, as it is not related to ISO 27001.").send()
        return

    # Guard 4: prompt + contexts must fit the model's context window.
    if not helper.is_context_size_valid(docs, query, state["max_context_size"]):
        await cl.Message(
            author="ISO 27001 - Assistant",
            content="I am sorry. I cannot process your question, as it would exceed my token limit. Please try to reformulate your question, or ask something else.",
        ).send()
        return

    # Send an empty message first so tokens can be streamed into it.
    msg = cl.Message(author="ISO 27001 - Assistant", content="")
    await msg.send()

    full_prompt, sources, template_path, template_source = helper.get_full_prompt_sources_and_template(docs, state["llm"], query)
    try:
        stream = state["llm"].astream(full_prompt)
    except ValidationError:
        error_message = cl.ErrorMessage(
            author="ISO 27001 - Assistant",
            content="A validation error occurred. Please ensure the Open API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        )
        await error_message.send()
        return

    async for part in stream:
        await msg.stream_token(part.content)

    # Append sources; attach the template file only when a path was returned.
    msg.content += "\n\nSources: \n" + sources
    if template_path != "":
        msg.elements = [cl.File(name=template_source, path=template_path, display="inline")]
    await msg.update()