codewithharsha committed on
Commit
84301bc
·
verified ·
1 Parent(s): e2535b2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +91 -87
main.py CHANGED
@@ -1,39 +1,45 @@
1
  import os
2
  import time
3
- import json # Import JSON for parsing
4
  from flask import Flask, request, jsonify, render_template
5
- from flask_cors import CORS
 
 
6
  from langchain_groq import ChatGroq
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain.chains.combine_documents import create_stuff_documents_chain
9
- # from langchain.chains import create_stuff_documents_chain
10
  from langchain_core.prompts import ChatPromptTemplate
11
  from langchain.chains import create_retrieval_chain
12
  from langchain_community.vectorstores import FAISS
13
  from langchain_community.document_loaders import PyPDFDirectoryLoader
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
- from dotenv import load_dotenv
16
 
17
- # Load environment variables
18
- load_dotenv()
19
 
20
- # --- LLM and API Key Setup ---
 
 
 
21
  groq_api_key = os.getenv("GROQ_API_KEY")
22
 
23
  if not groq_api_key:
24
- raise ValueError("GROQ_API_KEY not found. Please set it in your .env file or as an environment variable.")
25
 
 
 
 
26
  llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.1-8b-instant")
27
 
28
-
 
 
29
  def load_retrieval_chain():
30
  """
31
- Loads the vector database and creates the retrieval chain.
32
- This function runs once when the server starts.
33
  """
34
- print("Loading vector database... This may take a moment.")
35
-
36
- # --- PROMPT TEMPLATE - Reverted to simple stateless version ---
37
  prompt_template = """
38
  You are a friendly and helpful hotel assistant.
39
  Your role is to provide clear, welcoming, and professional responses to guest questions.
@@ -57,108 +63,106 @@ Question: {input}
57
  Your JSON Response:
58
  """
59
  prompt = ChatPromptTemplate.from_template(prompt_template)
60
-
 
61
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
62
- loader = PyPDFDirectoryLoader("data")
63
- docs = loader.load()
64
-
65
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
66
- final_documents = text_splitter.split_documents(docs[:50])
67
-
68
- vectors = FAISS.from_documents(final_documents, embeddings)
69
-
70
- print("Vector database loaded successfully.")
71
-
72
- document_chain = create_stuff_documents_chain(llm, prompt)
73
-
74
- # 1. Create the retriever from the vector store
 
 
 
 
 
 
 
 
 
 
75
  retriever = vectors.as_retriever()
76
- # 2. Create the retrieval chain
77
  retrieval_chain = create_retrieval_chain(retriever, document_chain)
78
-
 
79
  return retrieval_chain
80
 
81
- # --- Flask App Initialization ---
 
 
82
  app = Flask(__name__)
83
- CORS(app) # Enable CORS for all routes in your app
84
 
85
- # Load the retrieval chain ONCE when the app starts
86
- try:
87
- retrieval_chain = load_retrieval_chain()
88
- except Exception as e:
89
- print(f"Failed to load vector database on startup: {e}")
90
- retrieval_chain = None
91
 
92
- # --- NEW ROUTE TO SERVE YOUR WEBPAGE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  @app.route("/")
94
  def index():
95
- """
96
- Serves the index.html file from the 'templates' folder.
97
- """
98
- return render_template('index.html')
99
 
100
  @app.route("/chat", methods=["POST"])
101
  def chat():
102
- """
103
- The main chat endpoint.
104
- Receives a JSON with "query" and returns a JSON with "intent" and "response".
105
- """
106
  if retrieval_chain is None:
107
- return jsonify({"error": "Vector database is not initialized. Check server logs."}), 500
108
 
109
  try:
 
 
110
  data = request.json
111
  user_query = data.get("query")
112
-
113
  if not user_query:
114
  return jsonify({"error": "No query provided"}), 400
115
 
116
- print(f"Received query: {user_query}")
117
-
118
  start = time.process_time()
119
-
120
- # Invoke the chain with the user's query
121
  response = retrieval_chain.invoke({'input': user_query})
122
-
123
- response_time = time.process_time() - start
124
- print(f"Response time: {response_time:.4f} seconds")
125
 
126
- # --- Parse the JSON response from the LLM ---
127
  try:
128
- # The LLM's answer is in the 'answer' field
129
  llm_output_str = response['answer']
130
- # The LLM output itself is a JSON string, so we parse it.
131
- parsed_response = json.loads(llm_output_str)
132
-
133
- # We can also add the RAG context for debugging
134
- parsed_response["context"] = [doc.page_content for doc in response['context']]
135
-
136
- print(f"LLM Response: {parsed_response}")
137
-
138
- return jsonify(parsed_response)
139
-
140
  except json.JSONDecodeError:
141
- print(f"Error: LLM did not return valid JSON. Response was: {llm_output_str}")
142
  return jsonify({"intent": "qa", "response": "I'm sorry, I had a small glitch. Could you rephrase that?"})
143
- except Exception as e:
144
- print(f"Error parsing LLM response: {e}")
145
- return jsonify({"intent": "qa", "response": "I'm sorry, I'm having trouble processing that request."})
146
-
147
  except Exception as e:
148
- print(f"Error processing request: {e}")
149
  return jsonify({"error": str(e)}), 500
150
 
151
- # --- /book ENDPOINT REMOVED ---
152
-
153
- # --- Run the Flask Server ---
154
  if __name__ == "__main__":
155
- # Ensure a 'data' directory exists
156
- if not os.path.exists("data"):
157
- os.makedirs("data")
158
- print("Created 'data' directory. Please add your PDF files here and restart the server.")
159
-
160
- # init_db() call removed
161
-
162
- print("Starting Flask server...")
163
- # Running on 0.0.0.0 makes it accessible on your network, ready for EC2
164
- app.run(debug=True, host="0.0.0.0", port=7860)
 
1
  import os
2
  import time
3
+ import json
4
  from flask import Flask, request, jsonify, render_template
5
+ from flask_cors import CORS
6
+ from dotenv import load_dotenv
7
+ import logging
8
  from langchain_groq import ChatGroq
9
  from langchain_text_splitters import RecursiveCharacterTextSplitter
10
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
11
  from langchain_core.prompts import ChatPromptTemplate
12
  from langchain.chains import create_retrieval_chain
13
  from langchain_community.vectorstores import FAISS
14
  from langchain_community.document_loaders import PyPDFDirectoryLoader
15
  from langchain_huggingface import HuggingFaceEmbeddings
 
16
 
17
logging.basicConfig(level=logging.DEBUG)

# ----------------------------------------------------------
# Environment
# ----------------------------------------------------------
# Run python-dotenv's loader first, then read the Groq key from the
# process environment.
load_dotenv()
groq_api_key = os.environ.get("GROQ_API_KEY")

# Refuse to start without a key: the LLM client below is built from it.
if not groq_api_key:
    raise ValueError(
        "❌ GROQ_API_KEY not found. Please set it in your .env file or as an environment variable."
    )

# ----------------------------------------------------------
# LLM client
# ----------------------------------------------------------
llm = ChatGroq(model_name="llama-3.1-8b-instant", groq_api_key=groq_api_key)
32
 
33
+ # ==========================================================
34
+ # Function: Load / Build Retrieval Chain
35
+ # ==========================================================
36
  def load_retrieval_chain():
37
  """
38
+ Loads or builds the FAISS vector index and creates a retrieval chain.
39
+ This is now lazy-loaded to prevent Gunicorn worker boot crashes.
40
  """
41
+ print("πŸ”„ Initializing retrieval chain...")
42
+
 
43
  prompt_template = """
44
  You are a friendly and helpful hotel assistant.
45
  Your role is to provide clear, welcoming, and professional responses to guest questions.
 
63
  Your JSON Response:
64
  """
65
  prompt = ChatPromptTemplate.from_template(prompt_template)
66
+
67
+ # --- Load Embeddings ---
68
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
69
+
70
+ # --- Create or Load FAISS Vectorstore ---
71
+ if not os.path.exists("data"):
72
+ os.makedirs("data")
73
+ print("⚠️ 'data' folder created. Please add your PDFs and restart.")
74
+ raise ValueError("No PDFs found in 'data' folder.")
75
+
76
+ if os.path.exists("faiss_index"):
77
+ print("βœ… Loading existing FAISS index...")
78
+ vectors = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
79
+ else:
80
+ print("πŸ“„ Loading PDFs and building FAISS index (first-time setup)...")
81
+ loader = PyPDFDirectoryLoader("data")
82
+ docs = loader.load()
83
+ if not docs:
84
+ raise ValueError("No PDF documents found in 'data' folder.")
85
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
86
+ final_docs = text_splitter.split_documents(docs[:50])
87
+ vectors = FAISS.from_documents(final_docs, embeddings)
88
+ vectors.save_local("faiss_index")
89
+ print("πŸ’Ύ FAISS index saved to 'faiss_index' for future runs.")
90
+
91
+ # --- Create Chains ---
92
  retriever = vectors.as_retriever()
93
+ document_chain = create_stuff_documents_chain(llm, prompt)
94
  retrieval_chain = create_retrieval_chain(retriever, document_chain)
95
+
96
+ print("βœ… Retrieval chain initialized successfully.")
97
  return retrieval_chain
98
 
99
# ----------------------------------------------------------
# Flask application setup
# ----------------------------------------------------------
# The retrieval chain is not built at import time; it starts out as None
# and is filled in later (lazy loading).
retrieval_chain = None

app = Flask(__name__)
CORS(app)
 
 
 
 
 
106
 
107
@app.before_request
def init_retrieval():
    """Populate the module-level ``retrieval_chain`` if it is still unset.

    Registered through ``app.before_request``. When the global already
    holds a chain this is a no-op; otherwise ``load_retrieval_chain()`` is
    called and its result stored. If that call raises, the error is printed
    and the global is left as ``None``.
    """
    global retrieval_chain

    # Guard clause: nothing to do once a chain exists.
    if retrieval_chain is not None:
        return

    try:
        retrieval_chain = load_retrieval_chain()
    except Exception as e:
        print(f"❌ Failed to initialize retrieval chain: {e}")
        retrieval_chain = None
117
+
118
# ----------------------------------------------------------
# Routes
# ----------------------------------------------------------
@app.route("/")
def index():
    """Render the ``index.html`` template and return it for the root URL."""
    landing_page = render_template("index.html")
    return landing_page
 
 
125
 
126
  @app.route("/chat", methods=["POST"])
127
  def chat():
128
+ """Main chat endpoint."""
129
+ global retrieval_chain
130
+
 
131
  if retrieval_chain is None:
132
+ return jsonify({"error": "Vector database not initialized. Try again in a few seconds."}), 500
133
 
134
  try:
135
+ user_input = request.json.get("message")
136
+ app.logger.info(f"Received user input: {user_input}")
137
  data = request.json
138
  user_query = data.get("query")
 
139
  if not user_query:
140
  return jsonify({"error": "No query provided"}), 400
141
 
142
+ print(f"πŸ’¬ Received query: {user_query}")
 
143
  start = time.process_time()
144
+
145
+ # Run retrieval chain
146
  response = retrieval_chain.invoke({'input': user_query})
147
+ elapsed = time.process_time() - start
148
+ print(f"⏱️ Response time: {elapsed:.3f} sec")
 
149
 
150
+ # Parse LLM JSON
151
  try:
 
152
  llm_output_str = response['answer']
153
+ parsed = json.loads(llm_output_str)
154
+ parsed["context"] = [doc.page_content for doc in response['context']]
155
+ return jsonify(parsed)
 
 
 
 
 
 
 
156
  except json.JSONDecodeError:
157
+ print(f"⚠️ Invalid JSON from LLM: {response.get('answer', '')}")
158
  return jsonify({"intent": "qa", "response": "I'm sorry, I had a small glitch. Could you rephrase that?"})
 
 
 
 
159
  except Exception as e:
160
+ print(f"❌ Error during chat request: {e}")
161
  return jsonify({"error": str(e)}), 500
162
 
163
# ----------------------------------------------------------
# App runner (local development only)
# ----------------------------------------------------------
if __name__ == "__main__":
    # The server binds to every interface (0.0.0.0), so the interactive
    # debugger must not be on by default: it allows arbitrary code execution
    # for anyone who can reach the port. Opt in with FLASK_DEBUG=1.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}

    print("🚀 Starting Flask development server...")
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)