CCCDev committed on
Commit
b83debb
·
verified ·
1 Parent(s): d88d0f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import os
3
+ from langchain_community.document_loaders import PyPDFLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_community.vectorstores import Chroma
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain_community.llms import HuggingFaceEndpoint
9
+ from langchain.memory import ConversationBufferMemory
10
+ from pathlib import Path
11
+ import chromadb
12
+ from unidecode import unidecode
13
+ import re
14
+
15
# Flask application object; the /chat route below is registered on it.
app = Flask(import_name=__name__)
16
+
17
# Configuration variables
# PDF_PATH may be supplied through the PDF_PATH environment variable so the
# source does not have to be edited per deployment; the placeholder default
# is unchanged.
PDF_PATH = os.environ.get("PDF_PATH", "path/to/your/static.pdf")
CHUNK_SIZE = 512  # chunk_size handed to the text splitter in load_doc
CHUNK_OVERLAP = 24  # chunk_overlap handed to the text splitter in load_doc
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"  # repo_id for HuggingFaceEndpoint
TEMPERATURE = 0.1  # temperature for HuggingFaceEndpoint
MAX_TOKENS = 512  # forwarded as max_new_tokens to HuggingFaceEndpoint
TOP_K = 20  # forwarded as top_k to HuggingFaceEndpoint (not to the retriever)
25
+
26
+ # Load PDF document and create doc splits
27
def load_doc(pdf_path, chunk_size, chunk_overlap):
    """Load a PDF and split its pages into text chunks.

    Args:
        pdf_path: Path handed to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` given to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` given to the same splitter.

    Returns:
        Whatever ``split_documents`` returns for the loaded pages.
    """
    # Load first, then build the splitter, so failures surface in the same
    # order as before.
    loaded_pages = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(loaded_pages)
33
+
34
+ # Create vector database
35
def create_db(splits, collection_name):
    """Build a Chroma vector store from document chunks.

    Args:
        splits: Document chunks to index (as produced by ``load_doc``).
        collection_name: Name of the Chroma collection to create.

    Returns:
        The ``Chroma`` store returned by ``Chroma.from_documents``.
    """
    # Keyword arguments are evaluated left to right, so the embedding model
    # is still constructed before the ephemeral Chroma client.
    return Chroma.from_documents(
        documents=splits,
        embedding=HuggingFaceEmbeddings(),
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )
45
+
46
+ # Initialize langchain LLM chain
47
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Assemble the conversational retrieval chain used to answer questions.

    Args:
        llm_model: ``repo_id`` passed to ``HuggingFaceEndpoint``.
        temperature: Sampling temperature passed to the endpoint.
        max_tokens: Passed to the endpoint as ``max_new_tokens``.
        top_k: Passed to the endpoint as ``top_k``.
        vector_db: Vector store; its ``as_retriever()`` feeds the chain.

    Returns:
        The chain built by ``ConversationalRetrievalChain.from_llm``.
    """
    endpoint_llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )

    # The memory stores turns under "chat_history" and records the chain's
    # "answer" output.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )

    return ConversationalRetrievalChain.from_llm(
        endpoint_llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
70
+
71
+ # Generate collection name for vector database
72
def create_collection_name(filepath):
    """Derive a vector-database collection name from a file path.

    The file stem is transliterated with ``unidecode``, every run of
    characters outside ``A-Za-z0-9`` becomes a single ``-``, and the result
    is cut to 50 characters. Names shorter than 3 characters get ``'xyz'``
    appended, and a non-alphanumeric first or last character is replaced by
    ``'A'`` or ``'Z'`` respectively.

    Args:
        filepath: Path whose stem seeds the name.

    Returns:
        The sanitized collection name.
    """
    stem = Path(filepath).stem.replace(" ", "-")
    name = re.sub('[^A-Za-z0-9]+', '-', unidecode(stem))[:50]

    if len(name) < 3:
        name += 'xyz'
    if not name[0].isalnum():
        name = 'A' + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + 'Z'
    return name
85
+
86
# Build the retrieval pipeline once, at import time:
# PDF -> chunks -> vector store -> QA chain. Each step consumes the
# previous step's result, so the order is fixed.
doc_splits = load_doc(
    pdf_path=PDF_PATH,
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)
collection_name = create_collection_name(filepath=PDF_PATH)
vector_db = create_db(splits=doc_splits, collection_name=collection_name)
qa_chain = initialize_llmchain(
    llm_model=LLM_MODEL,
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
    top_k=TOP_K,
    vector_db=vector_db,
)
91
+
92
@app.route('/chat', methods=['POST'])
def chat():
    """Answer a question about the indexed PDF.

    Expects a JSON object body with:
        message: The user's question (non-empty string).
        history: Optional list of ``[user_message, bot_message]`` pairs.

    Returns:
        A JSON response with ``answer`` (text after the last
        "Helpful Answer:" marker, if present) and ``sources`` (chunk text
        plus 1-based page number). Malformed input yields HTTP 400 with an
        ``error`` field instead of an unhandled exception.
    """
    # silent=True returns None for a missing or invalid JSON body instead of
    # raising, so one explicit check covers every bad-body case.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    message = data.get('message', '')
    if not isinstance(message, str) or not message.strip():
        return jsonify({"error": "'message' must be a non-empty string."}), 400

    # Treat a missing or null history as an empty conversation.
    history = data.get('history') or []
    if not isinstance(history, list):
        return jsonify({"error": "'history' must be a list of [user, assistant] pairs."}), 400

    formatted_chat_history = []
    try:
        for user_message, bot_message in history:
            formatted_chat_history.append(f"User: {user_message}")
            formatted_chat_history.append(f"Assistant: {bot_message}")
    except (TypeError, ValueError):
        # An entry was not a two-item sequence.
        return jsonify({"error": "'history' must be a list of [user, assistant] pairs."}), 400

    # NOTE(review): the chain is built with its own ConversationBufferMemory
    # under the same "chat_history" key -- confirm whether the client-supplied
    # history passed here is actually used by the chain.
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})

    # split(...)[-1] returns the whole string when the marker is absent, so
    # no separate membership check is needed.
    response_answer = response["answer"].split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]

    result = {
        "answer": response_answer,
        "sources": [
            # The stored page index is shifted by one for display.
            {"content": doc.page_content.strip(), "page": doc.metadata["page"] + 1}
            for doc in response_sources
        ]
    }
    return jsonify(result)
117
+
118
if __name__ == '__main__':
    # The server binds to all interfaces, so debug mode must not be on by
    # default: the Werkzeug interactive debugger lets anyone who can reach
    # the port execute arbitrary code. Enable it explicitly with
    # FLASK_DEBUG=1 (or true/yes/on) for local development only.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(debug=debug_enabled, host='0.0.0.0', port=5000)