import tempfile

import chainlit as cl
import tiktoken
from dotenv import load_dotenv
from fastapi import FastAPI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate

# Load OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# FastAPI app object; unused unless the Chainlit app is mounted onto it.
app = FastAPI()
def tiktoken_len(text):
    # Measure length in gpt-3.5-turbo tokens rather than characters.
    tokens = tiktoken.encoding_for_model("gpt-3.5-turbo").encode(text)
    return len(tokens)
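# Example (illustrative): with the cl100k_base encoding used by
# gpt-3.5-turbo, tiktoken_len("hello world") returns 2.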
# Split text into ~500-token chunks with a 50-token overlap
# (lengths are measured by tiktoken_len, i.e. in tokens, not characters).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    length_function=tiktoken_len,
)
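# Example (illustrative): every chunk the splitter produces respects the
# token budget, and consecutive chunks share up to 50 tokens of overlap:
#   chunks = text_splitter.split_text(some_long_string)
#   assert all(tiktoken_len(c) <= 500 for c in chunks)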
# Load the embeddings model (OpenAIEmbeddings is imported from
# langchain_community above).
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

# Chat model used for generation (ChatOpenAI is imported from
# langchain_community above).
openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")
RAG_PROMPT = """
SYSTEM:
You are a professional personal assistant.

CONTEXT:
{context}

QUERY:
{question}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
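# Example (illustrative): rendering the template yields the chat messages
# that would be sent to the model:
#   rag_prompt.format_messages(context="Paris is the capital of France.",
#                              question="What is the capital of France?")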
# Building blocks for an LCEL pipeline (see the sketch below); note that the
# Chainlit handlers further down use RetrievalQAWithSourcesChain instead.
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
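# A minimal sketch (commented out, not part of the running app) of how
# rag_prompt and openai_chat_model could be wired into an LCEL chain;
# `retriever` is assumed to be a vector-store retriever such as the
# docsearch.as_retriever() created in init() below:
#
#   rag_chain = (
#       {"context": itemgetter("question") | retriever,
#        "question": itemgetter("question")}
#       | rag_prompt
#       | openai_chat_model
#       | StrOutputParser()
#   )
#   answer = rag_chain.invoke({"question": "What is this document about?"})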
@cl.on_chat_start
async def init():
    files = None
    # Wait for the user to upload a PDF file.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a file to start chatting!",
            accept=["application/pdf"],
        ).send()
    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # Write the uploaded bytes to a temporary file so PyPDFLoader can read it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp:
        temp.write(file.content)
        temp_path = temp.name
    # Load the PDF into a list of per-page documents; each carries the page
    # content plus metadata with its page number.
    loader = PyPDFLoader(temp_path)
    docs = loader.load_and_split()

    # Combine the page contents into a single string, then split it into
    # token-bounded chunks.
    text = " ".join(page.page_content for page in docs)
    texts = text_splitter.split_text(text)

    # Create a source label for each chunk so answers can cite it.
    metadatas = [{"source": f"{i}-word"} for i in range(len(texts))]
    # Build a Chroma vector store from the chunks, reusing the embedding
    # model configured above; run in a thread to avoid blocking the loop.
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embedding_model, metadatas=metadatas
    )
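    # Example (illustrative): the store can also be queried directly;
    #   docsearch.similarity_search("what is this document about?", k=4)
    # returns the four chunks closest to the query embedding.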
    # Create a QA chain that answers from the Chroma vector store and
    # reports which source chunks it used.
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        ChatOpenAI(temperature=0),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
    )
    # Stash the chunks, their metadata, and the chain in the user session so
    # process_response() can use them later.
    cl.user_session.set("metadatas", metadatas)
    cl.user_session.set("texts", texts)
    cl.user_session.set("chain", chain)

    # Let the user know that the system is ready.
    msg.content = f"`{file.name}` processed. You can now ask questions!"
    await msg.update()
@cl.on_message
async def process_response(message):
    chain = cl.user_session.get("chain")
    if chain is None:
        await cl.Message(
            content="The system is not initialized. Please upload a PDF file first."
        ).send()
        return

    # Run the chain on the user's question.
    response = await chain.acall({"question": message.content})
    answer = response["answer"]
    sources = response["sources"].strip()
    source_elements = []

    # Get the chunk metadata and texts saved in the user session.
    metadatas = cl.user_session.get("metadatas")
    all_sources = [m["source"] for m in metadatas]
    texts = cl.user_session.get("texts")

    if sources:
        found_sources = []
        # Resolve each cited source name back to its chunk text.
        for source in sources.split(","):
            source_name = source.strip().replace(".", "")
            # Look up the index of the source; skip names the model invented.
            try:
                index = all_sources.index(source_name)
            except ValueError:
                continue
            text = texts[index]
            found_sources.append(source_name)
            # Create the text element referenced in the message.
            source_elements.append(cl.Text(content=text, name=source_name))

        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()
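# To try this locally (assuming the file is saved as app.py):
#   chainlit run app.py -w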