# You can find the code for Chainlit python streaming here: https://docs.chainlit.io/concepts/streaming/python
# OpenAI Chat completion
import os
import openai
from openai import AsyncOpenAI  # importing openai for API usage
import chainlit as cl  # importing chainlit for our app
from chainlit.prompt import Prompt, PromptMessage  # importing prompt tools
from chainlit.playground.providers import ChatOpenAI  # importing ChatOpenAI tools
from dotenv import load_dotenv
from aimakerspace.text_utils import TextFileLoader, CharacterTextSplitter
from aimakerspace.vectordatabase import VectorDatabase
import asyncio
from getpass import getpass

# load_dotenv()
# openai.api_key = getpass("OpenAI API Key: ")
# os.environ["OPENAI_API_KEY"] = openai.api_key
# wandb_key = getpass("Weights and Biases API Key: ")
# os.environ["WANDB_API_KEY"] = wandb_key

# load and chunk the source document, then build the in-memory vector database
text_loader = TextFileLoader("data/KingLear.txt")
documents = text_loader.load_documents()
text_splitter = CharacterTextSplitter()
split_documents = text_splitter.split_texts(documents)

vector_db = VectorDatabase()
vector_db = asyncio.run(vector_db.abuild_from_list(split_documents))
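# For illustration, the vector DB could be queried directly with the same retriever
# call the pipeline uses below (the query string here is hypothetical; each result's
# first element is the chunk text, the second is presumably a similarity score):
#
#     results = vector_db.search_by_text("Who is Cordelia?", k=3)
#     for chunk, score in results:
#         print(score, chunk[:80])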
#
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.chatmodel import ChatOpenAI  # shadows the chainlit playground ChatOpenAI imported above

chat_openai = ChatOpenAI()

RAQA_PROMPT_TEMPLATE = """
Use the provided context to answer the user's query.
You may not answer the user's query unless there is specific context in the following text.
If you do not know the answer, or cannot answer, please respond with "I don't know".
Context:
{context}
"""

raqa_prompt = SystemRolePrompt(RAQA_PROMPT_TEMPLATE)

USER_PROMPT_TEMPLATE = """
User Query:
{user_query}
"""

user_prompt = UserRolePrompt(USER_PROMPT_TEMPLATE)
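# For illustration (an assumption about the aimakerspace prompt helpers, inferred from
# how they are used in the pipeline below): create_message fills the template's
# placeholders and returns a role-tagged chat message, e.g.
#
#     raqa_prompt.create_message(context="Lear divides his kingdom ...")
#     user_prompt.create_message(user_query="Who is Cordelia?")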
import wandb

os.environ["WANDB_NOTEBOOK_NAME"] = "app.py"
wandb.init(project="Visibility Example")
@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    import datetime
    from wandb.sdk.data_types.trace_tree import Trace
    class RetrievalAugmentedQAPipeline:
        def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase, wandb_project=None) -> None:
            self.llm = llm
            self.vector_db_retriever = vector_db_retriever
            self.wandb_project = wandb_project

        def run_pipeline(self, user_query: str) -> str:
            # retrieve the top-k most similar chunks and concatenate them into a single context block
            context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

            context_prompt = ""
            for context in context_list:
                context_prompt += context[0] + "\n"

            formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
            formatted_user_prompt = user_prompt.create_message(user_query=user_query)

            start_time = datetime.datetime.now().timestamp() * 1000

            try:
                openai_response = self.llm.run([formatted_system_prompt, formatted_user_prompt], text_only=False)
                end_time = datetime.datetime.now().timestamp() * 1000
                status = "success"
                status_message = None
                response_text = openai_response.choices[0].message.content
                token_usage = dict(openai_response.usage)
                model = openai_response.model
            except Exception as e:
                end_time = datetime.datetime.now().timestamp() * 1000
                status = "error"
                status_message = str(e)
                response_text = ""
                token_usage = {}
                model = ""

            if self.wandb_project:
                # log the call as a W&B trace span (inputs, outputs, latency, status, token usage)
                root_span = Trace(
                    name="root_span",
                    kind="llm",
                    status_code=status,
                    status_message=status_message,
                    start_time_ms=start_time,
                    end_time_ms=end_time,
                    metadata={
                        "token_usage": token_usage,
                        "model_name": model
                    },
                    inputs={"system_prompt": formatted_system_prompt, "user_prompt": formatted_user_prompt},
                    outputs={"response": response_text}
                )
                root_span.log(name="openai_trace")

            return response_text if response_text else "We ran into an error. Please try again later. Full Error Message: " + status_message

    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=vector_db,
        llm=chat_openai,
        wandb_project="LLM Visibility Example"
    )
    await cl.Message(
        content=retrieval_augmented_qa_pipeline.run_pipeline(message.content),
    ).send()
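# Note: the Chainlit streaming docs linked at the top describe token-by-token output
# via cl.Message.stream_token. This app sends the complete pipeline response in one
# message; a streaming variant would need the LLM call to yield tokens (assumption:
# it is not known whether the aimakerspace ChatOpenAI wrapper exposes a token stream), e.g.:
#
#     msg = cl.Message(content="")
#     async for token in token_stream:  # hypothetical async token iterator
#         await msg.stream_token(token)
#     await msg.send()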