from typing import List
from pathlib import Path
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document, StrOutputParser
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain.indexes import SQLRecordManager, index
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
import chainlit as cl
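
# Chunking configuration: split documents into ~1 kB chunks with a small
# overlap so retrieved passages keep local context across chunk boundaries.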
chunk_size = 1024
chunk_overlap = 50
embeddings_model = OpenAIEmbeddings()
PDF_STORAGE_PATH = "./pdfs"
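
# Ingestion pipeline: load every PDF in PDF_STORAGE_PATH, split it into
# chunks, embed the chunks into a Chroma vector store, and register them
# with a SQLRecordManager so re-runs only re-index changed documents
# (cleanup="incremental" de-duplicates by the "source" metadata key).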
def process_pdfs(pdf_storage_path: str):
    pdf_directory = Path(pdf_storage_path)
    docs: List[Document] = []
    # Use the module-level chunking settings defined above.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    for pdf_path in pdf_directory.glob("*.pdf"):
        loader = PyMuPDFLoader(str(pdf_path))
        documents = loader.load()
        docs += text_splitter.split_documents(documents)

    doc_search = Chroma.from_documents(docs, embeddings_model)

    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")

    return doc_search
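
# Ingestion and model setup run once at module import, i.e. when the
# Chainlit server starts, not per chat session.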
doc_search = process_pdfs(PDF_STORAGE_PATH)
model = ChatOpenAI(model_name="gpt-4", streaming=True)
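
# Per-session setup: build an LCEL chain (retrieve -> format -> prompt ->
# model -> parse) and stash it in the user session for reuse in on_message.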
@cl.on_chat_start
async def on_chat_start():
    template = """Answer the question based only on the following context:

{context}

Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        return "\n\n".join([d.page_content for d in docs])

    retriever = doc_search.as_retriever()

    runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)
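
# Per-message handler: stream the chain's answer token by token into a
# single Chainlit message.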
@cl.on_message
async def on_message(message: cl.Message):
    runnable: Runnable = cl.user_session.get("runnable")

    # Send an empty message first, then stream tokens into it as they arrive.
    msg = cl.Message(content="")
    await msg.send()

    async for chunk in runnable.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk)

    await msg.update()
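
# To run locally (assumes OPENAI_API_KEY is set in the environment and
# ./pdfs contains at least one PDF):
#   chainlit run app.py -w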