PuruAI committed
Commit eb720d3 · verified · 1 Parent(s): 728e583

Update app.py

Files changed (1)
  1. app.py +20 -76
app.py CHANGED
@@ -23,13 +23,13 @@ from langchain.agents import initialize_agent, Tool
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
-# Use the correct import path for the Python REPL tool
-from langchain.tools.utility import PythonREPL
+
+# Correct import for Python REPL tool
+from langchain.tools import PythonREPL
 
 # ===========================================
 # ENVIRONMENT VARIABLES
 # ===========================================
-# IMPORTANT: These environment variables must be set for the app to run outside of this environment
 HF_TOKEN = os.getenv("HF_TOKEN")
 SERPAPI_KEY = os.getenv("SERPAPI_API_KEY")
 JWT_SECRET = os.getenv("JWT_SECRET", "changeme123")
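Note on this hunk: PythonREPL has moved between packages across LangChain releases, so the new path may still fail depending on the pinned version. A version-tolerant sketch, assuming langchain-experimental is installed as a fallback:

try:
    # Path used by this commit (older langchain releases)
    from langchain.tools import PythonREPL
except ImportError:
    # Newer releases moved the REPL utility to langchain-experimental
    from langchain_experimental.utilities import PythonREPL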
@@ -40,7 +40,6 @@ JWT_SECRET = os.getenv("JWT_SECRET", "changeme123")
 security = HTTPBearer()
 
 def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
-    """Verifies the custom JWT token (in this case, checking against JWT_SECRET)."""
     token = credentials.credentials
     if token != JWT_SECRET:
         raise HTTPException(status_code=403, detail="Invalid token")
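Worth noting: despite its name, verify_jwt only string-compares the bearer token against JWT_SECRET; no JWT is decoded. A sketch of actual verification, assuming the PyJWT package is available:

import jwt  # PyJWT, assumed installed

def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
    try:
        # Decode and validate the signature instead of comparing raw strings
        return jwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=403, detail="Invalid token")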
@@ -53,18 +52,12 @@ MODEL_ID = "PuruAI/Medini_Intelligence"
 FALLBACK_MODEL = "gpt2"
 
 def load_llm():
-    """Loads the HuggingFace model pipeline, adding generation arguments for stability."""
-    # FIX 2: Added max_new_tokens for better agent reasoning and response length
     pipeline_kwargs = {"max_new_tokens": 512, "temperature": 0.7}
     try:
-        # Load the specified model
         model_pipeline = pipeline("text-generation", model=MODEL_ID, use_auth_token=HF_TOKEN, **pipeline_kwargs)
     except Exception:
-        # Fallback to a common model if the primary one fails
         print(f"Warning: Failed to load {MODEL_ID}. Falling back to {FALLBACK_MODEL}.")
         model_pipeline = pipeline("text-generation", model=FALLBACK_MODEL, **pipeline_kwargs)
-
-    # Wrap the pipeline in HuggingFacePipeline for LangChain integration
     return HuggingFacePipeline(pipeline=model_pipeline)
 
 llm = load_llm()
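One compatibility caveat in the retained code: recent transformers releases deprecate use_auth_token in favor of token. A minimal sketch of the same call under that assumption:

# Assumes transformers >= 4.34, where `token` replaces `use_auth_token`
model_pipeline = pipeline(
    "text-generation",
    model=MODEL_ID,
    token=HF_TOKEN,
    max_new_tokens=512,
    temperature=0.7,
)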
@@ -76,31 +69,22 @@ embeddings = HuggingFaceEmbeddings()
 chroma_db = Chroma(persist_directory="./medini_memory", embedding_function=embeddings)
 retriever = chroma_db.as_retriever()
 
-# FIX 1: Redefine the Retrieval QA components
 qa_prompt_template = """
 You are a question-answering system. Use the following context, which contains information retrieved from memory, to answer the user's question.
 If the context is empty or does not contain the answer, state clearly that the information is not in memory.
-
 Context:
 {context}
-
 Question: {question}
 Answer:
 """
 QA_PROMPT = PromptTemplate(template=qa_prompt_template, input_variables=["context", "question"])
-# This LLMChain is used specifically for answering questions based on retrieved context
 qa_chain = LLMChain(llm=llm, prompt=QA_PROMPT)
 
 def retrieve_and_answer(question: str) -> str:
-    """Retrieves context from Chroma DB and passes it to the QA Chain."""
-    # 1. Use the defined retriever to find relevant documents
     docs = retriever.get_relevant_documents(question)
     context = "\n---\n".join([d.page_content for d in docs])
-
-    # 2. Run the QA chain with the retrieved context
     return qa_chain.run(context=context, question=question)
 
-
 # ===========================================
 # TOOLS
 # ===========================================
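retriever.get_relevant_documents and qa_chain.run still work here but are deprecated in LangChain 0.1+. If the project upgrades, an invoke-based equivalent would look roughly like this (a sketch, not part of the commit):

def retrieve_and_answer(question: str) -> str:
    docs = retriever.invoke(question)  # replaces get_relevant_documents
    context = "\n---\n".join(d.page_content for d in docs)
    # LLMChain.invoke returns a dict; the generated text is under "text"
    return qa_chain.invoke({"context": context, "question": question})["text"]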
@@ -108,26 +92,23 @@ search = SerpAPIWrapper(serpapi_api_key=SERPAPI_KEY)
 python_tool = PythonREPL()
 
 tools = [
-    # FIX 1: Use the new function that correctly retrieves info from the vector store
-    Tool(name="Knowledge Recall", func=retrieve_and_answer, description="Retrieve info from Medini memory (Chroma DB). Use this when the answer might be in a previously executed step or private notes."),
+    Tool(name="Knowledge Recall", func=retrieve_and_answer, description="Retrieve info from Medini memory."),
     Tool(name="Web Search", func=search.run, description="Search the web for up-to-date information."),
     Tool(name="Python REPL", func=python_tool.run, description="Execute Python code, useful for math and data manipulation."),
 ]
 
-# FIX 4: Create a map for robust tool execution lookup
 TOOL_MAP = {tool.name.lower().replace(" ", ""): tool.func for tool in tools}
 
 # ===========================================
 # AGENT
 # ===========================================
 memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-# The LLM must be passed to the agent
 agent = initialize_agent(
-    tools=tools,
-    llm=llm,
-    agent="conversational-react-description",
-    memory=memory,
-    verbose=True # Changed to True for better debugging/visibility
+    tools=tools,
+    llm=llm,
+    agent="conversational-react-description",
+    memory=memory,
+    verbose=True
 )
 
 # ===========================================
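The TOOL_MAP normalization is what makes the tool_hint matching in execute_step tolerant of casing and spaces. A quick illustration of the keys it produces:

# "Knowledge Recall" -> "knowledgerecall", "Web Search" -> "websearch",
# "Python REPL" -> "pythonrepl"
TOOL_MAP = {tool.name.lower().replace(" ", ""): tool.func for tool in tools}
print(sorted(TOOL_MAP))  # ['knowledgerecall', 'pythonrepl', 'websearch']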
@@ -144,40 +125,23 @@ Goal: {goal}
 planner_chain = LLMChain(llm=llm, prompt=plan_prompt)
 
 def create_plan(goal: str) -> Dict[str, Any]:
-    """Generates a structured plan using the planner chain."""
     raw = planner_chain.run(goal=goal)
-
-    # FIX 3: Robust JSON Parsing - Find the JSON block and clean up common LLM formatting
     m = re.search(r"\{.*\}", raw, flags=re.DOTALL)
-    if not m:
-        # Fallback to entire raw output if no braces are found
-        json_str = raw
-    else:
-        json_str = m.group(0)
-
-    # Clean up common markdown code fences
-    json_str = json_str.replace("```json", "").replace("```", "").strip()
-
-    try:
-        plan = json.loads(json_str)
-        if 'steps' not in plan:
-            raise ValueError("Parsed JSON is missing the 'steps' array.")
-        return plan
-    except json.JSONDecodeError as e:
-        print(f"JSON Parsing Error: {e} in string: {json_str[:200]}...")
-        raise ValueError("Planner returned malformed JSON. Check the LLM's output format.") from e
+    json_str = m.group(0) if m else raw
+    json_str = json_str.replace("```json", "").replace("```", "").strip()
+    plan = json.loads(json_str)
+    if 'steps' not in plan:
+        raise ValueError("Parsed JSON is missing the 'steps' array.")
+    return plan
 
 def execute_step(step: Dict[str, Any]) -> Dict[str, Any]:
-    """Executes a single step using the appropriate tool or the main agent."""
     hint = (step.get("tool_hint") or "").lower()
     input_text = step.get("description")
-
     output = "Execution skipped."
-    status = "error" # Default to error
+    status = "error"
 
     try:
         tool_func = None
-        # FIX 4: Use string matching and TOOL_MAP lookup for robust execution
         if "recall" in hint:
             tool_func = TOOL_MAP.get("knowledgerecall")
         elif "search" in hint:
@@ -186,30 +150,22 @@ def execute_step(step: Dict[str, Any]) -> Dict[str, Any]:
             tool_func = TOOL_MAP.get("pythonrepl")
 
         if tool_func:
-            # Execute the specific tool
             output = tool_func(input_text)
         else:
-            # Fallback to the main agent for generic reasoning/conversation
            output = agent.run(input_text)
-
         status = "ok"
-
     except Exception as e:
         output = f"Execution Error: {str(e)}"
         status = "error"
-
-    # Add the result of the step execution to the vector memory
+
     chroma_db.add_documents([Document(page_content=f"Step {step['id']} - {step['name']} Result: {output}")])
-
     return {"id": step['id'], "name": step['name'], "status": status, "output": output}
 
 def execute_plan(goal: str) -> Dict[str, Any]:
-    """Creates a plan and executes all steps sequentially."""
     try:
         plan = create_plan(goal)
     except ValueError as e:
         return {"goal": goal, "error": str(e)}
-
     results = [execute_step(step) for step in plan.get("steps", [])]
     return {"goal": goal, "plan": plan, "results": results}
 
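A subtle point that keeps the simplified create_plan safe: json.JSONDecodeError subclasses ValueError, so the except ValueError in execute_plan still catches malformed planner output even though the explicit try/except around json.loads was dropped. A quick check:

import json

try:
    json.loads("not valid json")
except ValueError as exc:
    # JSONDecodeError is a ValueError subclass, so this branch runs
    print(type(exc).__name__)  # JSONDecodeError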
@@ -220,13 +176,11 @@ app = FastAPI(title="Medini Agent API")
 
 @app.post("/chat")
 def chat_endpoint(message: str, auth: bool = Depends(verify_jwt)):
-    """API endpoint for basic conversational chat."""
     response = agent.run(message)
     return {"response": response}
 
 @app.post("/goal")
 def goal_endpoint(goal: str, auth: bool = Depends(verify_jwt)):
-    """API endpoint for executing autonomous goals."""
     report = execute_plan(goal)
     return report
 
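Since message and goal are plain str parameters, FastAPI binds them as query parameters. A hypothetical smoke test, assuming the requests package is installed and the API is up on port 8000:

import requests

resp = requests.post(
    "http://localhost:8000/chat",
    params={"message": "hello"},
    headers={"Authorization": "Bearer changeme123"},  # must equal JWT_SECRET
)
print(resp.json())  # {"response": "..."}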
@@ -234,21 +188,16 @@ def goal_endpoint(goal: str, auth: bool = Depends(verify_jwt)):
 # GRADIO FRONTEND
 # ===========================================
 def gradio_chat(message, history):
-    """Gradio function for the chat interface."""
-    # The agent manages its own history/memory
     try:
         response = agent.run(message)
         history.append((message, response))
     except Exception as e:
         history.append((message, f"An error occurred: {str(e)}"))
-
-    return history, "" # Return history and clear the input box
+    return history, ""
 
 def gradio_execute_plan(goal):
-    """Gradio function to execute the full autonomous plan."""
     try:
-        report = execute_plan(goal)
-        return report
+        return execute_plan(goal)
     except Exception as e:
         return {"error": f"Failed to execute plan: {str(e)}"}
 
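The (message, response) tuple history matches Gradio's classic Chatbot format. Newer Gradio versions prefer the messages format; a sketch under the assumption that the chatbot is created with gr.Chatbot(type="messages"):

def gradio_chat(message, history):
    try:
        response = agent.run(message)
    except Exception as e:
        response = f"An error occurred: {e}"
    # Messages format: one dict per turn instead of (user, bot) tuples
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return history, ""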
@@ -268,7 +217,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     with gr.Column(scale=1):
         gr.Markdown("## Autonomous Goal Planner")
-        goal_input = gr.Textbox(placeholder="Enter high-level goal (e.g., 'Research the latest quarterly earnings of Tesla and save the key points').", label="Goal")
+        goal_input = gr.Textbox(placeholder="Enter high-level goal.", label="Goal")
         run_goal_btn = gr.Button("Run Goal", variant="primary")
         gr.Markdown("---")
         gr.Markdown("### Execution Report")
@@ -281,12 +230,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 # ===========================================
 if __name__ == "__main__":
     def start_api():
-        """Starts the FastAPI server in a separate thread."""
-        # Use log_level="critical" to reduce noisy logs from uvicorn in the console
         uvicorn.run(app, host="0.0.0.0", port=8000, log_level="critical")
 
-    # Start the API in the background
     threading.Thread(target=start_api, daemon=True).start()
-
-    # Launch the Gradio interface
     demo.launch(share=False)