from flask import Flask, request, jsonify
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
import re
app = Flask(__name__)
# Configuration variables
PDF_PATH = "https://huggingface.co/spaces/CCCDev/PDFChat/resolve/main/Data-privacy-policy.pdf"  # Replace with your own PDF path or URL
CHUNK_SIZE = 512
CHUNK_OVERLAP = 24
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
TEMPERATURE = 0.1
MAX_TOKENS = 512
TOP_K = 20
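# Note: HuggingFaceEndpoint calls the Hugging Face Inference API and needs
# authentication, typically via the HUGGINGFACEHUB_API_TOKEN environment variable.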
# Load PDF document and create doc splits
def load_doc(pdf_path, chunk_size, chunk_overlap):
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
# Create vector database
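# (With no model_name argument, HuggingFaceEmbeddings falls back to its default
# sentence-transformers model, which is downloaded on first use.)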
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb
# Initialize langchain LLM chain
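# (ConversationalRetrievalChain condenses the chat history and the new question
# into a standalone query, retrieves the closest chunks from the vector store,
# and "stuffs" them into the LLM prompt, per chain_type="stuff".)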
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
# Generate collection name for vector database
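# (Chroma enforces collection naming rules, roughly: 3-63 characters, starting
# and ending with an alphanumeric character; the sanitization below keeps the
# name derived from the file name within those limits.)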
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name
# Initialize database and QA chain
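# create_db uses an in-memory EphemeralClient, so the PDF is re-loaded,
# re-split, and re-embedded on every process start.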
doc_splits = load_doc(PDF_PATH, CHUNK_SIZE, CHUNK_OVERLAP)
collection_name = create_collection_name(PDF_PATH)
vector_db = create_db(doc_splits, collection_name)
qa_chain = initialize_llmchain(LLM_MODEL, TEMPERATURE, MAX_TOKENS, TOP_K, vector_db)
@app.route('/chat', methods=['POST'])
def chat():
    data = request.json
    message = data.get('message', '')
    history = data.get('history', [])
    # Flatten the (user, assistant) turn pairs into plain strings for the chain
    formatted_chat_history = []
    for user_message, bot_message in history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Strip any prompt scaffolding the model echoes back in its completion
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    result = {
        "answer": response_answer,
        "sources": [
            {"content": doc.page_content.strip(), "page": doc.metadata["page"] + 1}
            for doc in response_sources
        ]
    }
    return jsonify(result)
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=5000)
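# Example request, as a sketch (assumes the server is running locally on
# port 5000; the question text is illustrative):
#   curl -X POST http://localhost:5000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "What data does this policy cover?", "history": []}'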