from pathlib import Path
from typing import List

from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyMuPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.indexes import SQLRecordManager, index
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document, StrOutputParser
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma

import chainlit as cl

chunk_size = 1024
chunk_overlap = 50

embeddings_model = OpenAIEmbeddings()

PDF_STORAGE_PATH = "./pdfs"

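# --- PDF ingestion & indexing -------------------------------------------------
# process_pdfs loads every PDF under pdf_storage_path, splits the pages into
# overlapping chunks, embeds the chunks into a Chroma vector store, and
# registers each chunk with a record manager so repeated runs can skip
# documents that have not changed.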
def process_pdfs(pdf_storage_path: str) -> Chroma:
    pdf_directory = Path(pdf_storage_path)
    docs: List[Document] = []
    # Use the module-level chunking constants instead of hardcoded values.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    for pdf_path in pdf_directory.glob("*.pdf"):
        loader = PyMuPDFLoader(str(pdf_path))
        documents = loader.load()
        docs += text_splitter.split_documents(documents)

    doc_search = Chroma.from_documents(docs, embeddings_model)

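    # The SQLRecordManager keeps a hash and timestamp for every indexed chunk
    # in a local SQLite file, which lets LangChain's index() tell which
    # documents are new, changed, or unchanged. With cleanup="incremental",
    # unchanged chunks are skipped on re-runs and stale chunks from a
    # re-indexed source are deleted, keyed by the "source" metadata field.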
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")

    return doc_search

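# Runs at import time: Chainlit imports this module on startup, so the PDFs
# are indexed before the first chat session is created.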
doc_search = process_pdfs(PDF_STORAGE_PATH)
model = ChatOpenAI(model_name="gpt-4", streaming=True)

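# Chainlit calls this once per new chat session. The chain is stored in the
# per-user session, so concurrent conversations each get their own instance.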
@cl.on_chat_start
async def on_chat_start():
    template = """Answer the question based only on the following context:

{context}

Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        # Concatenate the retrieved chunks into a single context string.
        return "\n\n".join(d.page_content for d in docs)

    retriever = doc_search.as_retriever()

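    # LCEL pipe syntax: the dict fans the incoming question out to two
    # branches. "context" sends it through the retriever and then format_docs
    # (plain functions are coerced to RunnableLambda), while "question" passes
    # it through unchanged. The merged dict fills the prompt, which feeds the
    # model, whose output is parsed down to a plain string.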
    runnable: Runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )

    cl.user_session.set("runnable", runnable)

@cl.on_message
async def on_message(message: cl.Message):
    runnable: Runnable = cl.user_session.get("runnable")

    # Send an empty message first, then stream tokens into it as they arrive.
    msg = cl.Message(content="")
    await msg.send()

    async for chunk in runnable.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk)

    await msg.update()
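
# To try it locally (assuming this file is saved as app.py and OPENAI_API_KEY
# is set in the environment), place some PDFs under ./pdfs and start the app:
#   chainlit run app.py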