import os
import chainlit as cl
import openai
import tiktoken
from dotenv import load_dotenv
from operator import itemgetter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig, RunnablePassthrough
from langchain_openai import ChatOpenAI
# Load environment variables from .env file
load_dotenv()

# Environment variables
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
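
# A minimal .env for local runs might look like this (placeholder key, never
# commit the real one):
# OPENAI_API_KEY=sk-...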
# Set vector store path
VECTOR_STORE_PATH = "./data/vectorstore"

# Document loader
document_loader = PyMuPDFLoader("./data/Airbnb-10k.pdf")
documents = document_loader.load()
def tiktoken_len(text):
    """Return the length of `text` in gpt-4o tokens."""
    tokens = tiktoken.encoding_for_model("gpt-4o").encode(text)
    return len(tokens)
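
# Measuring length in model tokens (rather than characters) keeps each chunk
# within a predictable token budget when it is later stuffed into the prompt;
# the splitter below uses this helper as its length_function.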
# Load embeddings
openai_embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

# Text splitter (chunk sizes measured in gpt-4o tokens via tiktoken_len)
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=600,
    chunk_overlap=100,
    length_function=tiktoken_len,
)
split_documents = text_splitter.split_documents(documents)
# Create or load vector store
if os.path.exists(os.path.join(VECTOR_STORE_PATH, "index.faiss")):
    print("Loading existing vectorstore from disk.")
    # allow_dangerous_deserialization is required because FAISS.load_local
    # unpickles the docstore; only enable it for index files this app wrote.
    vectorstore = FAISS.load_local(
        VECTOR_STORE_PATH,
        openai_embeddings,
        allow_dangerous_deserialization=True,
    )
    print("Loaded Vectorstore")
else:
    print("Indexing Files")
    os.makedirs(VECTOR_STORE_PATH, exist_ok=True)
    # Embed in batches of 32 chunks to keep each embeddings request small:
    # from_documents seeds the index, add_documents appends to it
    vectorstore = FAISS.from_documents(split_documents[:32], openai_embeddings)
    for i in range(32, len(split_documents), 32):
        vectorstore.add_documents(split_documents[i:i+32])
    vectorstore.save_local(VECTOR_STORE_PATH)
    print("Vectorstore created and documents indexed.")
# Create retriever
retriever = vectorstore.as_retriever()
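
# as_retriever() defaults to similarity search over the top-k chunks (k=4 in
# LangChain's default). A hypothetical tweak to widen retrieval would be:
# retriever = vectorstore.as_retriever(search_kwargs={"k": 6})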
# Define the prompt template (plain text; ChatOpenAI handles role formatting)
RAG_PROMPT_TEMPLATE = """\
You are a helpful assistant. You answer user questions based on provided context. \
If you can't answer the question with the provided context, say you don't know.

User Query:
{query}

Context:
{context}
"""
rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
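
# from_template infers the input variables ({query}, {context}); rendering is
# just, e.g.: rag_prompt.format(query="What was revenue?", context="<chunks>")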
# Create ChatOpenAI instance
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# Chain that returns both the model response and the retrieved context; the
# input key is "query" to match the prompt template above
retrieval_augmented_qa_chain = (
    {"context": itemgetter("query") | retriever, "query": itemgetter("query")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | llm, "context": itemgetter("context")}
)
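
# Example invocation outside Chainlit, assuming the index has been built:
# result = retrieval_augmented_qa_chain.invoke({"query": "What are Airbnb's risk factors?"})
# result["response"].content holds the answer; result["context"] the source chunks.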
# Chainlit
@cl.on_chat_start
async def start_chat():
    """
    Called at the start of every user session.
    We build our LCEL RAG chain here and store it in the user session,
    a per-session dictionary held in the server's memory.
    """
    settings = {
        "model": "gpt-4o",
        "temperature": 0,
        "max_tokens": 500,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    try:
        # Build a per-session model from the settings above
        session_llm = ChatOpenAI(
            model=settings["model"],
            temperature=settings["temperature"],
            max_tokens=settings["max_tokens"],
            top_p=settings["top_p"],
            frequency_penalty=settings["frequency_penalty"],
            presence_penalty=settings["presence_penalty"],
        )
        lcel_rag_chain = (
            {"context": itemgetter("query") | retriever, "query": itemgetter("query")}
            | rag_prompt
            | session_llm
        )
        cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
        print("Chat session started and LCEL RAG chain set.")
    except Exception as e:
        print(f"Error in start_chat: {e}")
@cl.on_message
async def main(message: cl.Message):
    """
    Called every time a message is received from a session.
    We use the LCEL RAG chain stored in the user session (unique per
    session, which is why we can access it here) to stream a response.
    """
    try:
        lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
        print(f"Received message: {message.content}")
        print("Using LCEL RAG chain to generate response...")
        msg = cl.Message(content="")
        async for chunk in lcel_rag_chain.astream(
            {"query": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            # astream yields AIMessageChunk objects; fall back to str() for
            # any non-message chunk types
            chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
            print(f"Streaming chunk: {chunk_text}")
            await msg.stream_token(chunk_text)
        print("Sending final message...")
        await msg.send()
        print("Message sent.")
    except KeyError as e:
        print(f"Session error: {e}")
        await cl.Message(content="Session error occurred. Please try again.").send()
    except Exception as e:
        print(f"Error: {e}")
        await cl.Message(content="An error occurred. Please try again.").send()