Pipalskill committed on
Commit
bf35478
·
verified ·
1 Parent(s): 990a697

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +371 -0
app.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers"
4
+ os.environ["HF_HOME"] = "/tmp/huggingface"
5
+ os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers"
6
+ os.environ["TORCH_HOME"] = "/tmp/torch"
7
+
8
+ import json
9
+ import asyncio
10
+ from fastapi import FastAPI, HTTPException, UploadFile, File, WebSocket, WebSocketDisconnect
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from pydantic import BaseModel
13
+ from typing import Optional, Dict, Set
14
+ import chromadb
15
+ from chromadb.config import Settings
16
+ from sentence_transformers import SentenceTransformer
17
+
18
+ # Import from autonomous agent
19
+ from agent_langchain import (
20
+ process_with_agent,
21
+ get_conversation_history,
22
+ classify_ticket,
23
+ call_routing,
24
+ get_kb_collection,
25
+ encoder,
26
+ conversations
27
+ )
28
+
app = FastAPI(title="Smart Helpdesk AI Agent - Autonomous + WebSocket")

# CORS for frontend.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable; when it is unset (or empty) every origin is accepted.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials, so cookies /
    # Authorization headers are only allowed once explicit origins are listed.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
39
+
40
+ # Request Models
class TicketRequest(BaseModel):
    """JSON body accepted by the REST endpoints.

    Only ``text`` is mandatory; the two optional fields default to ``None``.
    """

    # The user's message; endpoints pass it on as the text to process.
    text: str

    # Identifier of the conversation this message belongs to, if any.
    conversation_id: Optional[str] = None

    # E-mail address of the requesting user, if supplied.
    user_email: Optional[str] = None
45
+
46
+ # WebSocket Connection Manager
class ConnectionManager:
    """Registry of open WebSocket connections, keyed by conversation id."""

    def __init__(self):
        # conversation_id -> the WebSocket currently registered for it
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept the handshake and register the socket under ``conversation_id``."""
        await websocket.accept()
        self.active_connections[conversation_id] = websocket
        print(f"🔌 WebSocket connected: {conversation_id}")

    def disconnect(self, conversation_id: str):
        """Forget the socket registered for ``conversation_id``; no-op if unknown."""
        if conversation_id not in self.active_connections:
            return
        self.active_connections.pop(conversation_id)
        print(f"🔌 WebSocket disconnected: {conversation_id}")

    async def send_message(self, conversation_id: str, message: dict):
        """Send ``message`` as JSON to the conversation's socket.

        Nothing happens when no socket is registered. If sending fails, the
        error is printed and the socket is unregistered.
        """
        if conversation_id not in self.active_connections:
            return

        target = self.active_connections[conversation_id]
        try:
            await target.send_json(message)
        except Exception as e:
            print(f"Error sending message: {e}")
            self.disconnect(conversation_id)
68
+
# One shared registry of live WebSocket connections for the whole app.
manager = ConnectionManager()

# Persistent Chroma settings: on-disk location and collection name.
CHROMA_PATH = "/tmp/chroma"
COLLECTION_NAME = "knowledge_base"
74
+
75
+ # -------------------------------
76
+ # KB Setup Endpoint
77
+ # -------------------------------
def _collect_kb_records(data):
    """Turn raw knowledge-base items into parallel lists for indexing.

    Items that are not JSON objects, or that carry no text under the
    ``answer`` / ``text`` / ``content`` keys, are skipped with a warning.

    Returns:
        A ``(texts, ids, metadatas)`` tuple of equally long lists.
    """
    texts, ids, metadatas = [], [], []
    for i, item in enumerate(data):
        if not isinstance(item, dict):
            print(f"⚠️ Skipping item {i} - not a JSON object")
            continue

        text = item.get("answer") or item.get("text") or item.get("content") or ""
        if not text:
            print(f"⚠️ Skipping item {i} - no text content")
            continue

        # Fall back to the list position when the item has no usable id.
        item_id = str(item.get("id") or i)
        # Treat an explicit JSON null category the same as a missing one.
        category = item.get("category") or ""

        texts.append(f"Category: {category}. {text}" if category else text)
        ids.append(item_id)
        metadatas.append({"id": item_id, "category": category, "original_index": i})

    return texts, ids, metadatas


@app.post("/setup")
async def setup_kb(kb_file: UploadFile = File(...)):
    """Upload and index knowledge base.

    The uploaded file must be a JSON list of objects. Each usable item is
    embedded with the shared ``encoder`` and stored in the persistent Chroma
    collection, replacing whatever the collection held before. Validation and
    embedding happen before the old records are cleared, so a bad upload does
    not wipe the existing knowledge base.

    Returns:
        ``{"message": ..., "count": <records now in the collection>}``.

    Raises:
        HTTPException: 400 when the file is not valid JSON, is not a list, or
            contains no usable text; 500 for any other failure.
    """
    try:
        content_bytes = await kb_file.read()
        try:
            data = json.loads(content_bytes)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise HTTPException(status_code=400, detail="Invalid JSON file.")

        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail="JSON must be a list of items.")

        print(f"📘 Loaded {len(data)} items from {kb_file.filename}")

        texts, ids, metadatas = _collect_kb_records(data)
        if not texts:
            raise HTTPException(status_code=400, detail="No valid text content found in JSON.")

        print("🧠 Generating embeddings...")
        embeddings = encoder.encode(texts, show_progress_bar=True).tolist()

        chroma_client = chromadb.PersistentClient(
            path=CHROMA_PATH,
            settings=Settings(anonymized_telemetry=False, allow_reset=True),
        )
        collection = chroma_client.get_or_create_collection(COLLECTION_NAME)

        # Only now, with the new data ready, drop the previous contents.
        if collection.count() > 0:
            print(f"🧹 Clearing {collection.count()} existing records...")
            collection.delete(ids=collection.get()['ids'])

        print("💾 Adding to ChromaDB...")
        collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)

        # Update global reference
        import agent_langchain
        agent_langchain.kb_collection = collection

        print(f"✅ Successfully added {collection.count()} records")
        return {"message": "Knowledge base initialized", "count": collection.count()}

    except HTTPException:
        # Deliberate 4xx responses must not be rewrapped as a 500 below.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Setup failed: {str(e)}")
137
+
138
+ # -------------------------------
139
+ # WebSocket Endpoint (REAL-TIME BIDIRECTIONAL)
140
+ # -------------------------------
@app.websocket("/ws/{conversation_id}")
async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
    """
    WebSocket endpoint for real-time agent communication.

    Client sends: {"text": "My issue description", "user_email": "user@example.com"}
    Server streams:
    - {"type": "status", "message": "Agent is thinking..."}
    - {"type": "tool_use", "tool": "SearchKnowledgeBase", "input": "..."}
    - {"type": "response", "content": "Here's the solution..."}
    - {"type": "saved", "firestore_id": "abc123"}

    Messages without a ``text`` field (or that are not JSON objects) get an
    ``{"type": "error", ...}`` reply and the connection stays open. Any other
    failure is reported to the client, after which the connection is
    unregistered.
    """
    await manager.connect(websocket, conversation_id)

    # The worker thread needs the loop that owns this socket to schedule sends.
    loop = asyncio.get_running_loop()

    # Callback for streaming updates
    async def ws_callback(update):
        await manager.send_message(conversation_id, update)

    def schedule_update(update):
        # Runs in the worker thread: hand the send over to the event loop.
        return asyncio.run_coroutine_threadsafe(ws_callback(update), loop)

    try:
        while True:
            # Receive message from client
            data = await websocket.receive_json()

            if not isinstance(data, dict):
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "Message must be a JSON object"
                })
                continue

            user_message = data.get("text")
            user_email = data.get("user_email")

            if not user_message:
                await manager.send_message(conversation_id, {
                    "type": "error",
                    "message": "No text provided"
                })
                continue

            # Send thinking status
            await manager.send_message(conversation_id, {
                "type": "status",
                "message": "🤔 Analyzing your request..."
            })

            # Process with agent (in thread to avoid blocking)
            result = await loop.run_in_executor(
                None,
                lambda: process_with_agent(
                    user_message=user_message,
                    conversation_id=conversation_id,
                    user_email=user_email,
                    callback=schedule_update,
                ),
            )

            # Send final response
            await manager.send_message(conversation_id, {
                "type": "response",
                "conversation_id": result["conversation_id"],
                "content": result["response"],
                "status": result["status"],
                "ticket_info": result.get("ticket_info", {}),
                "message_count": result["message_count"],
                "firestore_id": result.get("firestore_id")
            })

    except WebSocketDisconnect:
        manager.disconnect(conversation_id)
        print(f"Client disconnected: {conversation_id}")
    except Exception as e:
        print(f"WebSocket error: {e}")
        import traceback
        traceback.print_exc()
        try:
            await manager.send_message(conversation_id, {
                "type": "error",
                "message": str(e)
            })
        except Exception as send_error:
            # Best effort only; never mask cancellation with a bare except.
            print(f"Could not report error to client: {send_error}")
        manager.disconnect(conversation_id)
217
+
218
+ # -------------------------------
219
+ # REST Endpoint (backward compatible)
220
+ # -------------------------------
@app.post("/orchestrate")
async def orchestrate_endpoint(ticket: TicketRequest):
    """
    REST endpoint for agent interaction (backward compatible).
    Use WebSocket for real-time experience.

    The agent call is synchronous, so it is run in the default executor to
    keep the event loop (and every open WebSocket) responsive while it works.

    Raises:
        HTTPException: 500 with the error text when the agent call fails.
    """
    try:
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: process_with_agent(
                user_message=ticket.text,
                conversation_id=ticket.conversation_id,
                user_email=ticket.user_email,
            ),
        )

        return {
            "conversation_id": result["conversation_id"],
            "response": result["response"],
            "status": result["status"],
            "ticket_info": result.get("ticket_info", {}),
            "message_count": result["message_count"],
            "reasoning_trace": result.get("reasoning_trace", []),
            "firestore_id": result.get("firestore_id"),
            "instructions": {
                "websocket": f"ws://your-domain/ws/{result['conversation_id']}",
                "continue_conversation": "Include the conversation_id in your next request"
            }
        }

    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Agent failed: {str(e)}")
252
+
253
+ # -------------------------------
254
+ # Get Conversation History
255
+ # -------------------------------
@app.get("/conversation/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Return the stored history of one conversation.

    Raises:
        HTTPException: 404 when no history exists for ``conversation_id``.
    """
    history = get_conversation_history(conversation_id)
    if history:
        messages = history["messages"]
        return {
            "conversation_id": conversation_id,
            "messages": messages,
            "ticket_info": history.get("ticket_info", {}),
            "status": history.get("status", "unknown"),
            "created_at": history["created_at"],
            "message_count": len(messages),
        }

    raise HTTPException(status_code=404, detail="Conversation not found")
271
+
272
+ # -------------------------------
273
+ # List Active Conversations
274
+ # -------------------------------
@app.get("/conversations")
async def list_conversations():
    """Summarise every conversation held in ``conversations``, newest first."""

    def summarize(conv_id, conv_data):
        # Preview is the first 100 characters of the most recent message.
        messages = conv_data["messages"]
        return {
            "conversation_id": conv_id,
            "status": conv_data.get("status", "unknown"),
            "message_count": len(messages),
            "created_at": conv_data["created_at"],
            "user_email": conv_data.get("user_email", "anonymous"),
            "last_message": messages[-1]["content"][:100] if messages else None,
        }

    summaries = [
        summarize(conv_id, conv_data)
        for conv_id, conv_data in conversations.items()
    ]
    summaries.sort(key=lambda entry: entry["created_at"], reverse=True)

    return {"total": len(summaries), "conversations": summaries}
293
+
294
+ # -------------------------------
295
+ # Individual Tool Endpoints (for testing)
296
+ # -------------------------------
@app.post("/classify")
async def classify_endpoint(ticket: TicketRequest):
    """Run only the classification step on the ticket text (for testing)."""
    return {"classification": classify_ticket(ticket.text)}
302
+
@app.post("/route")
async def route_endpoint(ticket: TicketRequest):
    """Run only the routing step on the ticket text (for testing)."""
    return {"department": call_routing(ticket.text)}
308
+
@app.post("/kb_query")
async def kb_query_endpoint(ticket: TicketRequest):
    """Look up the single closest knowledge-base entry for the ticket text.

    Returns:
        ``{"answer": <document>, "confidence": <float rounded to 3 places>}``;
        a fixed "no match" answer with confidence 0.0 when nothing comes back.

    Raises:
        HTTPException: 400 when the knowledge base is missing or empty,
            500 when the query itself fails.
    """
    kb = get_kb_collection()
    if not (kb and kb.count() != 0):
        raise HTTPException(status_code=400, detail="KB not set up. Call /setup first.")

    try:
        vector = encoder.encode([ticket.text])[0].tolist()
        hits = kb.query(
            query_embeddings=[vector],
            n_results=1,
            include=["documents", "distances", "metadatas"],
        )

        documents = hits.get('documents') if hits else None
        if not documents or len(documents[0]) == 0:
            return {"answer": "No relevant KB found.", "confidence": 0.0}

        top_document = documents[0][0]
        distances = hits.get('distances')
        top_distance = distances[0][0] if distances else 1.0
        # Map the distance onto a score, halving it and clamping at zero.
        score = max(0.0, 1.0 - (top_distance / 2.0))

        return {"answer": top_document, "confidence": round(float(score), 3)}

    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"KB query failed: {str(e)}")
337
+
338
+ # -------------------------------
339
+ # Health Check
340
+ # -------------------------------
@app.get("/health")
async def health():
    """Report service status, knowledge-base size and live connection counts."""
    collection = get_kb_collection()
    # Query the record count once so kb_status and kb_records always agree.
    kb_count = collection.count() if collection else 0
    kb_status = "initialized" if kb_count > 0 else "not initialized"

    return {
        "status": "ok",
        "kb_status": kb_status,
        "kb_records": kb_count,
        "active_conversations": len(conversations),
        "active_websockets": len(manager.active_connections),
        "agent_type": "Autonomous ReAct Agent with Gemini + WebSocket"
    }
355
+
356
+ # -------------------------------
357
+ # Root endpoint
358
+ # -------------------------------
@app.get("/")
async def root():
    """Return a short index of the API's main endpoints."""
    endpoint_index = {
        "websocket": "/ws/{conversation_id}",
        "rest": "/orchestrate",
        "setup_kb": "/setup",
        "conversations": "/conversations",
        "health": "/health",
    }
    return {
        "message": "Smart Helpdesk AI Agent API",
        "endpoints": endpoint_index,
        "documentation": "/docs",
    }