# pdf-ingestor / app.py
from fastapi import FastAPI
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_community.chat_models import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
import chainlit as cl
import tempfile
import tiktoken
from dotenv import load_dotenv

# Read OPENAI_API_KEY (and any other secrets) from a local .env file.
load_dotenv()

app = FastAPI()
# Token-based length function so the splitter measures chunks in model
# tokens rather than characters. The encoder is created once and reused.
_ENCODER = tiktoken.encoding_for_model("gpt-3.5-turbo")


def tiktoken_len(text):
    return len(_ENCODER.encode(text))
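# Sanity check (hedged: exact counts depend on the tiktoken version, but
# "hello world" encodes to 2 tokens under the cl100k_base encoding that
# gpt-3.5-turbo uses):
#   assert tiktoken_len("hello world") == 2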
# Split the document into chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    length_function=tiktoken_len,
)
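# Because length_function is tiktoken_len, chunk_size and chunk_overlap are
# measured in tokens: each chunk is at most 500 tokens, and consecutive
# chunks overlap by up to 50 tokens.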
# Load the embeddings model and the chat model (OpenAIEmbeddings and
# ChatOpenAI are imported from langchain_community above, so the deprecated
# langchain.embeddings / langchain.chat_models re-imports are redundant).
# temperature=0 keeps the retrieval answers deterministic.
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# Prompt template for a hand-rolled RAG chain. Note: the
# RetrievalQAWithSourcesChain built in init() ships with its own default
# prompt, so this template is only exercised by the LCEL sketch below.
RAG_PROMPT = """
SYSTEM:
You are a professional personal assistant.

CONTEXT:
{context}

QUERY:
{question}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
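# For example, rag_prompt.format_messages(context="...", question="...")
# fills the template and returns a single human message (from_template on a
# plain string yields one human-message template).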
# Building blocks for composing an LCEL pipeline by hand (see the sketch
# below); the app itself answers questions with RetrievalQAWithSourcesChain.
from operator import itemgetter
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
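# A minimal LCEL sketch wiring rag_prompt to the chat model (hedged:
# `retriever` below stands in for the per-session Chroma retriever built in
# init(), so this stays commented out as a reference rather than live code):
#
#   lcel_rag_chain = (
#       {"context": retriever, "question": RunnablePassthrough()}
#       | rag_prompt
#       | openai_chat_model
#       | StrOutputParser()
#   )
#   lcel_rag_chain.invoke("What is this document about?")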
@cl.on_chat_start
async def init():
    files = None

    # Wait for the user to upload a PDF before doing anything else.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a file to start chatting!",
            accept=["application/pdf"],
        ).send()

    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # PyPDFLoader needs a path on disk, so write the upload out to a named
    # temporary file (delete=False keeps it readable after the with-block).
    # Note: this assumes a Chainlit version where the upload's bytes are on
    # `file.content`; newer releases expose a `file.path` instead.
    with tempfile.NamedTemporaryFile(delete=False) as temp:
        temp.write(file.content)
        temp_path = temp.name
    # Load the PDF with PyPDFLoader into a list of documents, one per page,
    # each carrying its page content and page-number metadata.
    loader = PyPDFLoader(temp_path)
    docs = loader.load()

    # Combine the page contents into a single string, then split that string
    # into token-sized chunks.
    text = " ".join(page.page_content for page in docs)
    texts = text_splitter.split_text(text)

    # Create one metadata entry per chunk; RetrievalQAWithSourcesChain reads
    # the "source" key when citing chunks in its answers.
    metadatas = [{"source": f"{i}-word"} for i in range(len(texts))]
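    # e.g. for three chunks:
    #   [{"source": "0-word"}, {"source": "1-word"}, {"source": "2-word"}]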
    # Build the Chroma vector store, reusing embedding_model from above.
    # Embedding every chunk is a blocking network call, so cl.make_async
    # runs it off the event loop.
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embedding_model, metadatas=metadatas
    )

    # Build a QA chain over the Chroma retriever; the "stuff" chain type
    # stuffs all retrieved chunks into a single prompt.
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        openai_chat_model,
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
    )
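    # The chain returns a dict like {"answer": "...", "sources": "0-word, ..."}
    # where "sources" is the model's comma-separated list of cited chunk ids.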
    # Save the chunk texts and metadata in the user session so sources can
    # be resolved when answering.
    cl.user_session.set("metadatas", metadatas)
    cl.user_session.set("texts", texts)

    # Let the user know that the system is ready.
    msg.content = f"`{file.name}` processed. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", chain)
@cl.on_message
async def process_response(message: cl.Message):
    chain = cl.user_session.get("chain")
    if chain is None:
        await cl.Message(
            content="The system is not initialized. Please upload a PDF file first."
        ).send()
        return

    # Run the user's question through the retrieval chain.
    response = await chain.acall({"question": message.content})
    answer = response["answer"]
    sources = response["sources"].strip()
    source_elements = []

    # Pull the chunk texts and metadata saved during init().
    metadatas = cl.user_session.get("metadatas")
    all_sources = [m["source"] for m in metadatas]
    texts = cl.user_session.get("texts")

    if sources:
        found_sources = []

        # Resolve each cited source back to its chunk text.
        for source in sources.split(","):
            source_name = source.strip().replace(".", "")
            # Get the index of the source, skipping citations that do not
            # match any known chunk.
            try:
                index = all_sources.index(source_name)
            except ValueError:
                continue
            text = texts[index]
            found_sources.append(source_name)
            # Create the text element referenced in the message.
            source_elements.append(cl.Text(content=text, name=source_name))

        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()
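# Launched as a Chainlit app, typically with `chainlit run app.py -w`
# (assuming the standard Chainlit CLI; the FastAPI `app` above is unused
# unless Chainlit is explicitly mounted onto it).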