Spaces:
Sleeping
Sleeping
File size: 4,356 Bytes
579ab0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import chainlit as cl
from helper import HelperMethods
from pydantic.v1.error_wrappers import ValidationError
from cohere.error import CohereAPIError
# Collection name handed to HelperMethods.get_index_vector_db() in on_chat_start
# to obtain the vector DB stored in the session state.
COLLECTION_NAME = "ISO_27001_Collection"
@cl.on_chat_start
async def on_chat_start():
    """Initialise a newly created chat session.

    After a one-second pause the user is greeted. Then ``HelperMethods`` is
    asked for the LLM and its maximum context size. Those are stored together
    with the vector DB for ``COLLECTION_NAME`` in the user session under the
    key ``"state_ISO"``.

    If ``get_LLM`` raises a pydantic ``ValidationError``, an error message is
    shown and the function returns without storing any session state.
    """
    await cl.sleep(1)
    greeting = cl.Message(
        author="ISO 27001 - Assistant",
        content="Hello, do you have questions on ISO 27001? Feel free to ask me.",
    )
    await greeting.send()

    helpers = HelperMethods()
    try:
        llm, max_context_size = await helpers.get_LLM()
    except ValidationError:
        # NOTE(review): presumably raised when the LLM API key is missing or
        # malformed - confirm against HelperMethods.get_LLM.
        await cl.ErrorMessage(
            author="ISO 27001 - Assistant",
            content="A validation error occurred. Please ensure the Open API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
        ).send()
        return

    vector_db = helpers.get_index_vector_db(COLLECTION_NAME)
    cl.user_session.set(
        "state_ISO",
        {"llm": llm, "max_context_size": max_context_size, "vectordb": vector_db},
    )
async def _send_assistant_reply(content: str, *, error: bool = False) -> None:
    """Send a one-off assistant message, or an error message if *error* is True."""
    message_cls = cl.ErrorMessage if error else cl.Message
    await message_cls(author="ISO 27001 - Assistant", content=content).send()


@cl.on_message
async def on_message(message: cl.Message):
    """
    Is called when a new message is sent by the user. Executes the RAG pipeline (check english, retrieve contexts, check relevancy, check context size, prompt LLM)

    Each failed check sends an explanatory reply and stops the pipeline. On
    success the LLM answer is streamed into a single message. A "Sources"
    section is then appended, and when the helper returns a non-empty
    template path, that file is attached inline.
    """
    state = cl.user_session.get("state_ISO")
    if state is None:
        # on_chat_start returns before storing "state_ISO" when get_LLM fails;
        # without this guard, state["vectordb"] below would raise TypeError.
        await _send_assistant_reply(
            "The assistant could not be initialised. Please ensure the API keys are correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
            error=True,
        )
        return

    helper = HelperMethods()
    query = message.content

    if not helper.check_if_english(query):
        await _send_assistant_reply(
            "I am sorry. I cannot process your question, as I can only answer questions written in English."
        )
        return

    try:
        docs = helper.retrieve_contexts(state["vectordb"], query)
    except CohereAPIError:
        await _send_assistant_reply(
            "A validation error occurred. Please ensure the Cohere API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
            error=True,
        )
        return

    if not helper.check_if_relevant(docs):
        await _send_assistant_reply(
            "I am sorry. I cannot process your question, as it is not related to ISO 27001."
        )
        return

    if not helper.is_context_size_valid(docs, query, state["max_context_size"]):
        await _send_assistant_reply(
            "I am sorry. I cannot process your question, as it would exceed my token limit. Please try to reformulate your question, or ask something else."
        )
        return

    # Empty placeholder message that the streamed tokens are appended to.
    msg = cl.Message(author="ISO 27001 - Assistant", content="")
    await msg.send()
    full_prompt, sources, template_path, template_source = helper.get_full_prompt_sources_and_template(
        docs, state["llm"], query
    )

    try:
        # astream() returns a lazy async iterator: errors are raised while
        # iterating, so the loop itself must sit inside the try block.
        async for part in state["llm"].astream(full_prompt):
            await msg.stream_token(part.content)
    except ValidationError:
        await _send_assistant_reply(
            "A validation error occurred. Please ensure the Open API_KEY is correctly set. You can navigate to the profile icon and then reset the keys. After that reload the page and try to ask the question again.",
            error=True,
        )
        return

    msg.content += "\n\nSources: \n" + sources
    if template_path != "":
        # A template file accompanies the answer; show it inline.
        msg.elements = [cl.File(name=template_source, path=template_path, display="inline")]
    await msg.update()
|