PuruAI committed · verified
Commit 728e583 · 1 Parent(s): 7e00014

Update app.py

Files changed (1): app.py (+152 -39)
app.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import json
 import threading
+import re
 from typing import Dict, Any
 
 import gradio as gr
@@ -8,6 +9,7 @@ from fastapi import FastAPI, Depends, HTTPException
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 import uvicorn
 
+# Note: The 'transformers' pipeline is required for the LLM loader
 from transformers import pipeline
 
 # ===== LangChain imports (fixed for 0.3.x + community modules) =====
@@ -21,11 +23,13 @@ from langchain.agents import initialize_agent, Tool
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
 from langchain.docstore.document import Document
-from langchain.tools import PythonREPL # Corrected import
+# PythonREPL was moved out of core LangChain; in 0.3.x it is provided by langchain_experimental
+from langchain_experimental.utilities import PythonREPL
 
 # ===========================================
 # ENVIRONMENT VARIABLES
 # ===========================================
+# IMPORTANT: These environment variables must be set for the app to run outside of this environment
 HF_TOKEN = os.getenv("HF_TOKEN")
 SERPAPI_KEY = os.getenv("SERPAPI_API_KEY")
 JWT_SECRET = os.getenv("JWT_SECRET", "changeme123")
@@ -34,7 +38,9 @@ JWT_SECRET = os.getenv("JWT_SECRET", "changeme123")
 # AUTH
 # ===========================================
 security = HTTPBearer()
+
 def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """Verifies the custom JWT token (in this case, checking against JWT_SECRET)."""
     token = credentials.credentials
     if token != JWT_SECRET:
         raise HTTPException(status_code=403, detail="Invalid token")
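
Note: despite its name, verify_jwt does not decode a JWT; it compares the raw bearer token against JWT_SECRET. If real token verification is wanted later, a minimal sketch using the PyJWT package (an assumed extra dependency, not part of this commit) could look like:

    import jwt  # PyJWT (hypothetical dependency, not imported by app.py)

    def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
        """Decode and verify an HS256-signed JWT instead of comparing raw secrets."""
        try:
            return jwt.decode(credentials.credentials, JWT_SECRET, algorithms=["HS256"])
        except jwt.PyJWTError:
            raise HTTPException(status_code=403, detail="Invalid token")
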
@@ -47,10 +53,18 @@ MODEL_ID = "PuruAI/Medini_Intelligence"
 FALLBACK_MODEL = "gpt2"
 
 def load_llm():
+    """Loads the HuggingFace model pipeline, adding generation arguments for stability."""
+    # FIX 2: Added max_new_tokens for better agent reasoning and response length
+    pipeline_kwargs = {"max_new_tokens": 512, "temperature": 0.7}
     try:
-        model_pipeline = pipeline("text-generation", model=MODEL_ID, use_auth_token=HF_TOKEN)
+        # Load the specified model
+        model_pipeline = pipeline("text-generation", model=MODEL_ID, use_auth_token=HF_TOKEN, **pipeline_kwargs)
     except Exception:
-        model_pipeline = pipeline("text-generation", model=FALLBACK_MODEL)
+        # Fallback to a common model if the primary one fails
+        print(f"Warning: Failed to load {MODEL_ID}. Falling back to {FALLBACK_MODEL}.")
+        model_pipeline = pipeline("text-generation", model=FALLBACK_MODEL, **pipeline_kwargs)
+
+    # Wrap the pipeline in HuggingFacePipeline for LangChain integration
     return HuggingFacePipeline(pipeline=model_pipeline)
 
 llm = load_llm()
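
Note: recent transformers releases deprecate use_auth_token in favor of token, and temperature only takes effect once sampling is enabled. If deprecation warnings appear, an equivalent loader call under those assumptions would be:

    # Sketch only: same behavior with newer argument names; do_sample=True makes temperature effective
    pipeline_kwargs = {"max_new_tokens": 512, "temperature": 0.7, "do_sample": True}
    model_pipeline = pipeline("text-generation", model=MODEL_ID, token=HF_TOKEN, **pipeline_kwargs)
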
@@ -61,31 +75,60 @@ llm = load_llm()
 embeddings = HuggingFaceEmbeddings()
 chroma_db = Chroma(persist_directory="./medini_memory", embedding_function=embeddings)
 retriever = chroma_db.as_retriever()
-retrieval_qa = LLMChain(
-    llm=llm,
-    prompt=PromptTemplate(
-        input_variables=["question"],
-        template="{question}"
-    )
-)
+
+# FIX 1: Redefine the Retrieval QA components
+qa_prompt_template = """
+You are a question-answering system. Use the following context, which contains information retrieved from memory, to answer the user's question.
+If the context is empty or does not contain the answer, state clearly that the information is not in memory.
+
+Context:
+{context}
+
+Question: {question}
+Answer:
+"""
+QA_PROMPT = PromptTemplate(template=qa_prompt_template, input_variables=["context", "question"])
+# This LLMChain is used specifically for answering questions based on retrieved context
+qa_chain = LLMChain(llm=llm, prompt=QA_PROMPT)
+
+def retrieve_and_answer(question: str) -> str:
+    """Retrieves context from Chroma DB and passes it to the QA Chain."""
+    # 1. Use the defined retriever to find relevant documents
+    docs = retriever.get_relevant_documents(question)
+    context = "\n---\n".join([d.page_content for d in docs])
+
+    # 2. Run the QA chain with the retrieved context
+    return qa_chain.run(context=context, question=question)
+
 
 # ===========================================
 # TOOLS
 # ===========================================
 search = SerpAPIWrapper(serpapi_api_key=SERPAPI_KEY)
-python_tool = PythonREPL() # ✅ Updated
+python_tool = PythonREPL()
 
 tools = [
-    Tool(name="Knowledge Recall", func=lambda q: retrieval_qa.run({"question": q}), description="Retrieve info from Medini memory."),
-    Tool(name="Web Search", func=search.run, description="Search the web for info."),
-    Tool(name="Python REPL", func=python_tool.run, description="Execute Python code."),
+    # FIX 1: Use the new function that correctly retrieves info from the vector store
+    Tool(name="Knowledge Recall", func=retrieve_and_answer, description="Retrieve info from Medini memory (Chroma DB). Use this when the answer might be in a previously executed step or private notes."),
+    Tool(name="Web Search", func=search.run, description="Search the web for up-to-date information."),
+    Tool(name="Python REPL", func=python_tool.run, description="Execute Python code, useful for math and data manipulation."),
 ]
 
+# FIX 4: Create a map for robust tool execution lookup
+TOOL_MAP = {tool.name.lower().replace(" ", ""): tool.func for tool in tools}
+
 # ===========================================
 # AGENT
 # ===========================================
 memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-agent = initialize_agent(tools=tools, llm=llm, agent="conversational-react-description", memory=memory, verbose=False)
+# The LLM must be passed to the agent
+agent = initialize_agent(
+    tools=tools,
+    llm=llm,
+    agent="conversational-react-description",
+    memory=memory,
+    verbose=True  # Changed to True for better debugging/visibility
+)
 
 # ===========================================
 # PLANNER (Autonomous Goal)
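
Note: a quick smoke test of the new retrieval path (assuming the default HuggingFaceEmbeddings model downloads successfully) is to seed the store and query it:

    chroma_db.add_documents([Document(page_content="Medini's API listens on port 8000.")])
    print(retrieve_and_answer("Which port does the Medini API use?"))
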
@@ -93,7 +136,7 @@ agent = initialize_agent(tools=tools, llm=llm, agent="conversational-react-descr
 plan_prompt = PromptTemplate(
     input_variables=["goal"],
     template="""
-You are Medini Planner. Decompose the high-level goal into JSON steps (max 6) with: id, name, description, tool_hint.
+You are Medini Planner. Decompose the high-level goal into a JSON object containing a 'steps' array (max 6 steps). Each step must have: id (integer), name (short string), description (detailed instruction), and tool_hint (either 'recall', 'search', 'python', or 'agent').
 Return JSON only.
 Goal: {goal}
 """
@@ -101,34 +144,72 @@ Goal: {goal}
 planner_chain = LLMChain(llm=llm, prompt=plan_prompt)
 
 def create_plan(goal: str) -> Dict[str, Any]:
+    """Generates a structured plan using the planner chain."""
     raw = planner_chain.run(goal=goal)
-    import re
+
+    # FIX 3: Robust JSON Parsing - Find the JSON block and clean up common LLM formatting
     m = re.search(r"\{.*\}", raw, flags=re.DOTALL)
     if not m:
-        raise ValueError("Planner did not return JSON")
-    return json.loads(m.group(0))
+        # Fallback to entire raw output if no braces are found
+        json_str = raw
+    else:
+        json_str = m.group(0)
+
+    # Clean up common markdown code fences
+    json_str = json_str.replace("```json", "").replace("```", "").strip()
+
+    try:
+        plan = json.loads(json_str)
+        if 'steps' not in plan:
+            raise ValueError("Parsed JSON is missing the 'steps' array.")
+        return plan
+    except json.JSONDecodeError as e:
+        print(f"JSON Parsing Error: {e} in string: {json_str[:200]}...")
+        raise ValueError("Planner returned malformed JSON. Check the LLM's output format.") from e
 
 def execute_step(step: Dict[str, Any]) -> Dict[str, Any]:
+    """Executes a single step using the appropriate tool or the main agent."""
     hint = (step.get("tool_hint") or "").lower()
     input_text = step.get("description")
+
+    output = "Execution skipped."
+    status = "error"  # Default to error
+
     try:
+        tool_func = None
+        # FIX 4: Use string matching and TOOL_MAP lookup for robust execution
         if "recall" in hint:
-            output = tools[0].func(input_text)
+            tool_func = TOOL_MAP.get("knowledgerecall")
         elif "search" in hint:
-            output = tools[1].func(input_text)
+            tool_func = TOOL_MAP.get("websearch")
         elif "python" in hint:
-            output = tools[2].func(input_text)
+            tool_func = TOOL_MAP.get("pythonrepl")
+
+        if tool_func:
+            # Execute the specific tool
+            output = tool_func(input_text)
         else:
+            # Fallback to the main agent for generic reasoning/conversation
             output = agent.run(input_text)
+
         status = "ok"
+
     except Exception as e:
-        output = str(e)
+        output = f"Execution Error: {str(e)}"
         status = "error"
-    chroma_db.add_documents([Document(page_content=f"Step {step['id']} - {step['name']}: {output}")])
+
+    # Add the result of the step execution to the vector memory
+    chroma_db.add_documents([Document(page_content=f"Step {step['id']} - {step['name']} Result: {output}")])
+
     return {"id": step['id'], "name": step['name'], "status": status, "output": output}
 
 def execute_plan(goal: str) -> Dict[str, Any]:
-    plan = create_plan(goal)
+    """Creates a plan and executes all steps sequentially."""
+    try:
+        plan = create_plan(goal)
+    except ValueError as e:
+        return {"goal": goal, "error": str(e)}
+
     results = [execute_step(step) for step in plan.get("steps", [])]
     return {"goal": goal, "plan": plan, "results": results}
 
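
Note: for reference, a plan that satisfies both the planner prompt's schema and create_plan's 'steps' check would parse to something like this (illustrative values only):

    plan = {
        "steps": [
            {"id": 1, "name": "Find sources", "description": "Search the web for recent coverage of the topic.", "tool_hint": "search"},
            {"id": 2, "name": "Summarize", "description": "Condense the findings into key points.", "tool_hint": "agent"},
        ]
    }
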
@@ -139,11 +220,13 @@ app = FastAPI(title="Medini Agent API")
 
 @app.post("/chat")
 def chat_endpoint(message: str, auth: bool = Depends(verify_jwt)):
+    """API endpoint for basic conversational chat."""
     response = agent.run(message)
     return {"response": response}
 
 @app.post("/goal")
 def goal_endpoint(goal: str, auth: bool = Depends(verify_jwt)):
+    """API endpoint for executing autonomous goals."""
     report = execute_plan(goal)
     return report
 
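
Note: since message and goal are plain str parameters, FastAPI reads them from the query string. A client call (assuming the default JWT_SECRET and the port 8000 configured in the launch block below) would look like:

    import requests

    headers = {"Authorization": "Bearer changeme123"}
    resp = requests.post("http://localhost:8000/chat", params={"message": "Hello, Medini"}, headers=headers)
    print(resp.json())
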
@@ -151,29 +234,59 @@ def goal_endpoint(goal: str, auth: bool = Depends(verify_jwt)):
 # GRADIO FRONTEND
 # ===========================================
 def gradio_chat(message, history):
-    response = agent.run(message)
-    history.append((message, response))
-    return history, history
+    """Gradio function for the chat interface."""
+    # The agent manages its own history/memory
+    try:
+        response = agent.run(message)
+        history.append((message, response))
+    except Exception as e:
+        history.append((message, f"An error occurred: {str(e)}"))
+
+    return history, ""  # Return history and clear the input box
+
+def gradio_execute_plan(goal):
+    """Gradio function to execute the full autonomous plan."""
+    try:
+        report = execute_plan(goal)
+        return report
+    except Exception as e:
+        return {"error": f"Failed to execute plan: {str(e)}"}
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🤖 Medini Autonomous Agent")
     gr.Markdown("Chat or submit high-level goals. Agentic AI handles reasoning, memory, and tool use.")
-    chatbot = gr.Chatbot()
-    msg = gr.Textbox(placeholder="Type your message...", label="Chat")
-    goal_input = gr.Textbox(placeholder="Enter high-level goal...", label="Goal")
-    run_goal_btn = gr.Button("Run Goal")
-    clear_btn = gr.Button("Clear Chat")
-    goal_output = gr.JSON()
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            gr.Markdown("## Conversational Chat")
+            chatbot = gr.Chatbot(height=400)
+            msg = gr.Textbox(placeholder="Type your message...", label="Chat Input")
+            clear_btn = gr.Button("Clear Chat")
+
+            msg.submit(gradio_chat, [msg, chatbot], [chatbot, msg])
+            clear_btn.click(lambda: [], None, chatbot, queue=False)
 
-    msg.submit(gradio_chat, [msg, chatbot], [chatbot, chatbot])
-    clear_btn.click(lambda: None, None, chatbot, queue=False)
-    run_goal_btn.click(lambda g: execute_plan(g), [goal_input], goal_output)
+        with gr.Column(scale=1):
+            gr.Markdown("## Autonomous Goal Planner")
+            goal_input = gr.Textbox(placeholder="Enter high-level goal (e.g., 'Research the latest quarterly earnings of Tesla and save the key points').", label="Goal")
+            run_goal_btn = gr.Button("Run Goal", variant="primary")
+            gr.Markdown("---")
+            gr.Markdown("### Execution Report")
+            goal_output = gr.JSON(label="Plan and Results")
+
+            run_goal_btn.click(gradio_execute_plan, [goal_input], goal_output)
 
 # ===========================================
 # LAUNCH
 # ===========================================
 if __name__ == "__main__":
     def start_api():
-        uvicorn.run(app, host="0.0.0.0", port=8000)
+        """Starts the FastAPI server in a separate thread."""
+        # Use log_level="critical" to reduce noisy logs from uvicorn in the console
+        uvicorn.run(app, host="0.0.0.0", port=8000, log_level="critical")
+
+    # Start the API in the background
     threading.Thread(target=start_api, daemon=True).start()
-    demo.launch()
+
+    # Launch the Gradio interface
+    demo.launch(share=False)
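
Note: gradio_chat still uses tuple-style chat history, which newer Gradio releases deprecate in favor of type="messages". A sketch of the messages-style handler (an assumption about a future Gradio upgrade, not part of this commit):

    chatbot = gr.Chatbot(type="messages", height=400)

    def gradio_chat(message, history):
        # In messages mode, history is a list of {"role": ..., "content": ...} dicts
        try:
            response = agent.run(message)
        except Exception as e:
            response = f"An error occurred: {str(e)}"
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history, ""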