# PDFChat / app_api.py
from flask import Flask, request, jsonify
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
import re
app = Flask(__name__)
# Configuration variables
PDF_PATH = "https://huggingface.co/spaces/CCCDev/PDFChat/resolve/main/Data-privacy-policy.pdf"  # Static PDF to index (local path or URL)
CHUNK_SIZE = 512
CHUNK_OVERLAP = 24
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
TEMPERATURE = 0.1
MAX_TOKENS = 512
TOP_K = 20
# Load PDF document and create doc splits
def load_doc(pdf_path, chunk_size, chunk_overlap):
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    # EphemeralClient keeps the collection in memory only; it is rebuilt on every restart
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    # HuggingFaceEndpoint calls the Hugging Face Inference API and expects
    # HUGGINGFACEHUB_API_TOKEN to be set in the environment
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
# Generate collection name for vector database
def create_collection_name(filepath):
    # Chroma requires collection names of 3-63 characters that start and end
    # with an alphanumeric character, so sanitize the file stem accordingly
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name
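# A quick illustration of the sanitizer (hypothetical filename, not from this repo):
#   create_collection_name("Données privées 2024.pdf") -> "Donnees-privees-2024"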
# Initialize database and QA chain
doc_splits = load_doc(PDF_PATH, CHUNK_SIZE, CHUNK_OVERLAP)
collection_name = create_collection_name(PDF_PATH)
vector_db = create_db(doc_splits, collection_name)
qa_chain = initialize_llmchain(LLM_MODEL, TEMPERATURE, MAX_TOKENS, TOP_K, vector_db)
@app.route('/chat', methods=['POST'])
def chat():
    data = request.json
    message = data.get('message', '')
    history = data.get('history', [])
    # Flatten the [(user, bot), ...] history pairs into plain strings for the chain
    formatted_chat_history = []
    for user_message, bot_message in history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Some instruct models echo the prompt's "Helpful Answer:" marker; keep only the text after it
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    result = {
        "answer": response_answer,
        "sources": [
            # PyPDFLoader page metadata is zero-based, so add 1 for display
            {"content": doc.page_content.strip(), "page": doc.metadata["page"] + 1}
            for doc in response_sources
        ]
    }
    return jsonify(result)
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=5000)
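# Example request (a sketch; assumes the server is running locally on port 5000
# and uses a made-up question):
#
#   curl -X POST http://localhost:5000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "What data does the policy cover?", "history": []}'
#
# The response is JSON with an "answer" string and a "sources" list of
# {"content": ..., "page": ...} objects taken from the retrieved PDF chunks.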