dylanglenister commited on
Commit
89a41e2
·
1 Parent(s): 521a291

Updating session route.

Browse files

Combining chat route with it.
Splitting summarise into its own route.

src/api/routes/chat.py DELETED
@@ -1,134 +0,0 @@
1
- # api/routes/chat.py
2
-
3
- import time
4
- from datetime import datetime, timezone
5
-
6
- from fastapi import APIRouter, Depends, HTTPException
7
-
8
- from src.core.state import MedicalState, get_state
9
- from src.models.chat import ChatRequest, ChatResponse, SummariseRequest
10
- from src.services.medical_response import generate_medical_response
11
- from src.services.summariser import summarise_title_with_nvidia
12
- from src.utils.logger import logger
13
-
14
- router = APIRouter(tags=["Chat"])
15
-
16
- @router.post("/chat", response_model=ChatResponse)
17
- async def chat_endpoint(
18
- request: ChatRequest,
19
- state: MedicalState = Depends(get_state)
20
- ):
21
- """
22
- Process a chat message, generate response, and persist short-term cache + long-term Mongo.
23
- """
24
- start_time = time.time()
25
-
26
- logger().info(f"POST /chat user={request.account_id} session={request.session_id} patient={request.patient_id}")
27
- logger().info(f"Message: {request.message[:100]}...") # Log first 100 chars of message
28
-
29
- # Currently completely pointless
30
- #user_profile = state.memory_system.get_user(request.account_id)
31
- #if not user_profile:
32
- # state.memory_system.create_user()
33
-
34
- session = None
35
- session_id = request.session_id
36
-
37
- #if request.session_id != "default":
38
- if session_id:
39
- try:
40
- session = state.memory_system.get_session(session_id)
41
- except Exception as e:
42
- logger().error(f"Error retrieving session: {e}")
43
- raise HTTPException(status_code=400, detail=f"Invalid session ID ({session_id}): {str(e)}")
44
-
45
- if not session:
46
- if not request.patient_id:
47
- raise HTTPException(status_code=400, detail="patient_id required for new sessions")
48
-
49
- session = state.memory_system.create_session(
50
- request.account_id,
51
- request.patient_id,
52
- "New Chat"
53
- )
54
-
55
- logger().info(f"Created new session: {session.session_id}")
56
-
57
- session_id = session.session_id
58
-
59
- try:
60
- # Get enhanced medical context with STM + LTM semantic search + NVIDIA reasoning
61
- medical_context = await state.history_manager.get_enhanced_conversation_context(
62
- request.account_id,
63
- session_id,
64
- request.message,
65
- state.nvidia_rotator,
66
- patient_id=request.patient_id
67
- )
68
- except Exception as e:
69
- logger().error(f"Error getting medical context: {e}")
70
- logger().error(f"Request data: {request.model_dump()}")
71
- raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
72
-
73
- try:
74
- # Generate response using Gemini AI
75
- logger().info(f"Generating medical response using Gemini AI for user {request.account_id}")
76
- response = await generate_medical_response(
77
- request.message,
78
- "Medical Professional",
79
- "",
80
- state.gemini_rotator,
81
- medical_context
82
- )
83
-
84
- # Process and store the exchange
85
- try:
86
- await state.history_manager.process_medical_exchange(
87
- request.account_id,
88
- session_id,
89
- request.message,
90
- response,
91
- state.gemini_rotator,
92
- state.nvidia_rotator,
93
- patient_id=request.patient_id,
94
- doctor_id=request.account_id,
95
- session_title="New Chat"
96
- )
97
- except Exception as e:
98
- logger().warning(f"Failed to process medical exchange: {e}")
99
- # Continue without storing if there's an error
100
-
101
- # Calculate response time
102
- response_time = time.time() - start_time
103
-
104
- logger().info(f"Generated response in {response_time:.2f}s for user {request.account_id}")
105
- logger().info(f"Response length: {len(response)} characters")
106
-
107
- return ChatResponse(
108
- response=response,
109
- session_id=session_id,
110
- timestamp=datetime.now(timezone.utc).isoformat(),
111
- medical_context=medical_context if medical_context else None
112
- )
113
-
114
- except Exception as e:
115
- logger().error(f"Error in chat endpoint: {e}")
116
- logger().error(f"Request data: {request.model_dump()}")
117
- raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
118
-
119
- @router.post("/summarise")
120
- async def summarise_endpoint(
121
- req: SummariseRequest,
122
- state: MedicalState = Depends(get_state)
123
- ):
124
- """Summarise a text into a short 3-5 word title using NVIDIA if available."""
125
- try:
126
- title = await summarise_title_with_nvidia(
127
- req.text,
128
- state.nvidia_rotator,
129
- max_words=min(max(req.max_words or 5, 3), 7)
130
- )
131
- return {"title": title}
132
- except Exception as e:
133
- logger().error(f"Error summarising title: {e}")
134
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/api/routes/session.py CHANGED
@@ -1,87 +1,132 @@
1
- # api/routes/session.py
2
 
3
- from datetime import datetime
4
 
5
- from fastapi import APIRouter, Depends, HTTPException
6
 
7
- from src.core.state import MedicalState, get_state
8
- from src.data.repositories.session import delete_session, get_session_messages
9
- from src.models.chat import SessionRequest
 
10
  from src.utils.logger import logger
11
 
12
- router = APIRouter(prefix="/session", tags=["Session"])
13
 
14
- @router.post("")
 
15
  async def create_chat_session(
16
- request: SessionRequest,
17
- state: MedicalState = Depends(get_state)
18
  ):
19
- """Create a new chat session (cache + Mongo)"""
20
- try:
21
- logger().info(f"POST /session user_id={request.account_id} patient_id={request.patient_id}")
22
- session_id = state.memory_system.create_session(request.account_id, request.title or "New Chat")
23
- # Also ensure in Mongo with patient/doctor
24
- #ensure_session(session_id=session_id, patient_id=request.patient_id, doctor_id=request.doctor_id, title=request.title or "New Chat")
25
- return {"session_id": session_id, "message": "Session created successfully"}
26
- except Exception as e:
27
- logger().error(f"Error creating session: {e}")
28
- raise HTTPException(status_code=500, detail=str(e))
29
 
30
- @router.get("/{session_id}")
 
31
  async def get_chat_session(
32
  session_id: str,
33
- state: MedicalState = Depends(get_state)
34
  ):
35
- """Get session from cache (for quick preview)"""
36
- try:
37
- session = state.memory_system.get_session(session_id)
38
- if not session:
39
- raise HTTPException(status_code=404, detail="Session not found")
 
40
 
41
- return session.to_dict()
42
- except HTTPException:
43
- raise
44
- except Exception as e:
45
- logger().error(f"Error getting session: {e}")
46
- raise HTTPException(status_code=500, detail=str(e))
47
 
48
- @router.get("/{session_id}/messages")
49
- async def list_messages_for_session(session_id: str, limit: int | None = None):
50
- """List messages for a session from Mongo, verified to belong to the patient"""
51
- try:
52
- logger().info(f"GET /session/{session_id}/messages limit={limit}")
53
- msgs = get_session_messages(session_id, limit)
54
- # ensure JSON-friendly timestamps
55
- for m in msgs:
56
- if isinstance(m.get("timestamp"), datetime):
57
- m["timestamp"] = m["timestamp"].isoformat()
58
- m["_id"] = str(m["_id"]) if "_id" in m else None
59
- return {"messages": msgs}
60
- except Exception as e:
61
- logger().error(f"Error listing messages: {e}")
62
- raise HTTPException(status_code=500, detail=str(e))
63
-
64
- @router.delete("/{session_id}")
65
  async def delete_chat_session(
66
  session_id: str,
67
- state: MedicalState = Depends(get_state)
68
  ):
69
- """Delete a chat session from both memory system and MongoDB"""
70
- try:
71
- logger().info(f"DELETE /session/{session_id}")
 
 
 
 
 
72
 
73
- # Delete from memory system
74
- state.memory_system.delete_session(session_id)
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Delete from MongoDB
77
- session_deleted = delete_session(session_id)
78
 
79
- logger().info(f"Deleted session {session_id}: session={session_deleted}")
 
 
 
 
 
 
 
 
 
 
80
 
81
- return {
82
- "message": "Session deleted successfully",
83
- "session_deleted": session_deleted
84
- }
 
 
 
 
85
  except Exception as e:
86
- logger().error(f"Error deleting session: {e}")
87
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/api/routes/chat.py
2
 
3
+ from datetime import datetime, timezone
4
 
5
+ from fastapi import APIRouter, Depends, HTTPException, status
6
 
7
+ from src.core.state import AppState, get_state
8
+ from src.models.session import (ChatRequest, ChatResponse, Message, Session,
9
+ SessionCreateRequest)
10
+ from src.services.medical_response import generate_medical_response
11
  from src.utils.logger import logger
12
 
13
+ router = APIRouter(prefix="/session", tags=["Session & Chat"])
14
 
15
+
16
+ @router.post("", response_model=Session, status_code=status.HTTP_201_CREATED)
17
  async def create_chat_session(
18
+ req: SessionCreateRequest,
19
+ state: AppState = Depends(get_state)
20
  ):
21
+ """Creates a new, empty chat session."""
22
+ logger().info(f"POST /session for patient_id={req.patient_id}")
23
+ session = state.memory_manager.create_session(
24
+ user_id=req.account_id,
25
+ patient_id=req.patient_id,
26
+ title=req.title or "New Chat"
27
+ )
28
+ if not session:
29
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create session.")
30
+ return session
31
 
32
+
33
+ @router.get("/{session_id}", response_model=Session)
34
  async def get_chat_session(
35
  session_id: str,
36
+ state: AppState = Depends(get_state)
37
  ):
38
+ """Retrieves a session's metadata and all its messages."""
39
+ logger().info(f"GET /session/{session_id}")
40
+ session = state.memory_manager.get_session(session_id)
41
+ if not session:
42
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
43
+ return session
44
 
 
 
 
 
 
 
45
 
46
+ @router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  async def delete_chat_session(
48
  session_id: str,
49
+ state: AppState = Depends(get_state)
50
  ):
51
+ """Deletes a chat session permanently."""
52
+ logger().info(f"DELETE /session/{session_id}")
53
+ # UPDATED CALL
54
+ success = state.memory_manager.delete_session(session_id)
55
+ if not success:
56
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found or already deleted")
57
+ return None
58
+
59
 
60
+ @router.get("/{session_id}/messages", response_model=list[Message])
61
+ async def list_messages_for_session(
62
+ session_id: str,
63
+ limit: int | None = None,
64
+ state: AppState = Depends(get_state)
65
+ ):
66
+ """Lists all messages for a specific session from the database."""
67
+ logger().info(f"GET /session/{session_id}/messages limit={limit}")
68
+ if not state.memory_manager.get_session(session_id):
69
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
70
+
71
+ # UPDATED CALL
72
+ messages = state.memory_manager.get_session_messages(session_id, limit)
73
+ return messages
74
 
 
 
75
 
76
+ @router.post("/{session_id}/messages", response_model=ChatResponse)
77
+ async def post_chat_message(
78
+ session_id: str,
79
+ req: ChatRequest,
80
+ state: AppState = Depends(get_state)
81
+ ):
82
+ """
83
+ Posts a message to a session, gets a generated medical response,
84
+ and persists the full exchange to long-term memory.
85
+ """
86
+ logger().info(f"POST /session/{session_id}/messages")
87
 
88
+ # 1. Get Enhanced Context
89
+ try:
90
+ medical_context = await state.memory_manager.get_enhanced_context(
91
+ session_id=session_id,
92
+ patient_id=req.patient_id,
93
+ question=req.message,
94
+ nvidia_rotator=state.nvidia_rotator
95
+ )
96
  except Exception as e:
97
+ logger().error(f"Error getting medical context: {e}")
98
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to build medical context.")
99
+
100
+ # 2. Generate AI Response
101
+ try:
102
+ # In a real app, user role/specialty would come from the authenticated user
103
+ response_text = await generate_medical_response(
104
+ user_message=req.message,
105
+ user_role="Medical Professional",
106
+ user_specialty="",
107
+ rotator=state.gemini_rotator,
108
+ medical_context=medical_context
109
+ )
110
+ except Exception as e:
111
+ logger().error(f"Error generating medical response: {e}")
112
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to generate AI response.")
113
+
114
+ # 3. Process and Store the Exchange
115
+ summary = await state.memory_manager.process_medical_exchange(
116
+ session_id=session_id,
117
+ patient_id=req.patient_id,
118
+ doctor_id=req.account_id,
119
+ question=req.message,
120
+ answer=response_text,
121
+ gemini_rotator=state.gemini_rotator,
122
+ nvidia_rotator=state.nvidia_rotator
123
+ )
124
+ if not summary:
125
+ logger().warning(f"Failed to process and store medical exchange for session {session_id}")
126
+
127
+ return ChatResponse(
128
+ response=response_text,
129
+ session_id=session_id,
130
+ timestamp=datetime.now(timezone.utc),
131
+ medical_context=medical_context
132
+ )
src/api/routes/summarise.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/api/route/summarise.py
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException, status
4
+
5
+ from src.core.state import AppState, get_state
6
+ from src.models.summarise import SummariseRequest, SummariseResponse
7
+ from src.services import summariser
8
+ from src.utils.logger import logger
9
+
10
+ router = APIRouter(prefix="/summarise", tags=["Utilities"])
11
+
12
+
13
+ @router.post("", response_model=SummariseResponse)
14
+ async def summarise_text_endpoint(
15
+ req: SummariseRequest,
16
+ state: AppState = Depends(get_state)
17
+ ):
18
+ """
19
+ Summarises a given text into a short title using an available AI model.
20
+ """
21
+ logger().info(f"POST /summarise for text starting with: '{req.text[:50]}...'")
22
+ try:
23
+ title = await summariser.summarise_title_with_nvidia(
24
+ text=req.text,
25
+ rotator=state.nvidia_rotator,
26
+ max_words=req.max_words
27
+ )
28
+ if not title:
29
+ raise HTTPException(
30
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
31
+ detail="Failed to generate a summary title."
32
+ )
33
+
34
+ return SummariseResponse(title=title)
35
+ except Exception as e:
36
+ logger().error(f"Error summarising title: {e}")
37
+ raise HTTPException(
38
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
39
+ detail="An unexpected error occurred while generating the summary."
40
+ )
src/core/memory_manager.py CHANGED
@@ -7,7 +7,7 @@ from src.data.repositories import patient as patient_repo
7
  from src.data.repositories import session as session_repo
8
  from src.models.account import Account
9
  from src.models.patient import Patient
10
- from src.models.session import Session
11
  from src.services import summariser
12
  from src.services.nvidia import nvidia_chat
13
  from src.utils.embeddings import EmbeddingClient
@@ -62,6 +62,7 @@ class MemoryManager:
62
  except ActionFailed as e:
63
  logger().error(f"Failed to search accounts in MemoryManager: {e}")
64
  return []
 
65
  # --- Patient Management Facade ---
66
 
67
  def create_patient(self, **kwargs) -> str | None:
@@ -138,6 +139,22 @@ class MemoryManager:
138
  logger().error(f"Failed to get sessions for patient '{patient_id}': {e}")
139
  return []
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # --- Core Business Logic ---
142
 
143
  async def process_medical_exchange(
 
7
  from src.data.repositories import session as session_repo
8
  from src.models.account import Account
9
  from src.models.patient import Patient
10
+ from src.models.session import Message, Session
11
  from src.services import summariser
12
  from src.services.nvidia import nvidia_chat
13
  from src.utils.embeddings import EmbeddingClient
 
62
  except ActionFailed as e:
63
  logger().error(f"Failed to search accounts in MemoryManager: {e}")
64
  return []
65
+
66
  # --- Patient Management Facade ---
67
 
68
  def create_patient(self, **kwargs) -> str | None:
 
139
  logger().error(f"Failed to get sessions for patient '{patient_id}': {e}")
140
  return []
141
 
142
+ def delete_session(self, session_id: str) -> bool:
143
+ """Deletes a chat session."""
144
+ try:
145
+ return session_repo.delete_session(session_id)
146
+ except ActionFailed as e:
147
+ logger().error(f"Failed to delete session '{session_id}' in MemoryManager: {e}")
148
+ return False
149
+
150
+ def get_session_messages(self, session_id: str, limit: int | None = None) -> list[Message]:
151
+ """Gets messages from a specific chat session."""
152
+ try:
153
+ return session_repo.get_session_messages(session_id, limit)
154
+ except ActionFailed as e:
155
+ logger().error(f"Failed to get messages for session '{session_id}': {e}")
156
+ return []
157
+
158
  # --- Core Business Logic ---
159
 
160
  async def process_medical_exchange(
src/main.py CHANGED
@@ -26,11 +26,11 @@ except Exception as e:
26
  # Import project modules after trying to load environment variables
27
  from src.api.routes import account as account_route
28
  from src.api.routes import audio as audio_route
29
- from src.api.routes import chat as chat_route
30
  from src.api.routes import patient as patients_route
31
  from src.api.routes import session as session_route
32
  from src.api.routes import static as static_route
33
  from src.api.routes import system as system_route
 
34
  from src.core.state import AppState, get_state
35
  from src.data.repositories import account as account_repo
36
  from src.data.repositories import medical_memory as medical_memory_repo
@@ -125,12 +125,12 @@ app.add_middleware(
125
  app.mount("/static", StaticFiles(directory="static"), name="static")
126
 
127
  # Include routers
128
- app.include_router(chat_route.router)
129
  app.include_router(session_route.router)
130
  app.include_router(patients_route.router)
131
  app.include_router(account_route.router)
 
132
  app.include_router(system_route.router)
133
- app.include_router(static_route.router)
134
  app.include_router(audio_route.router)
135
  app.include_router(emr_route.router)
136
 
 
26
  # Import project modules after trying to load environment variables
27
  from src.api.routes import account as account_route
28
  from src.api.routes import audio as audio_route
 
29
  from src.api.routes import patient as patients_route
30
  from src.api.routes import session as session_route
31
  from src.api.routes import static as static_route
32
  from src.api.routes import system as system_route
33
+ from src.api.routes import summarise as summarise_route
34
  from src.core.state import AppState, get_state
35
  from src.data.repositories import account as account_repo
36
  from src.data.repositories import medical_memory as medical_memory_repo
 
125
  app.mount("/static", StaticFiles(directory="static"), name="static")
126
 
127
  # Include routers
128
+ app.include_router(static_route.router)
129
  app.include_router(session_route.router)
130
  app.include_router(patients_route.router)
131
  app.include_router(account_route.router)
132
+ app.include_router(summarise_route.router)
133
  app.include_router(system_route.router)
 
134
  app.include_router(audio_route.router)
135
  app.include_router(emr_route.router)
136
 
src/models/chat.py DELETED
@@ -1,25 +0,0 @@
1
- # models/chat.py
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class ChatRequest(BaseModel):
7
- account_id: str
8
- patient_id: str
9
- session_id: str | None = None
10
- message: str
11
-
12
- class ChatResponse(BaseModel):
13
- response: str
14
- session_id: str
15
- timestamp: str
16
- medical_context: str | None = None
17
-
18
- class SessionRequest(BaseModel):
19
- account_id: str
20
- patient_id: str
21
- title: str | None = "New Chat"
22
-
23
- class SummariseRequest(BaseModel):
24
- text: str
25
- max_words: int | None = 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/models/session.py CHANGED
@@ -26,3 +26,26 @@ class Session(BaseMongoModel):
26
  created_at: datetime
27
  updated_at: datetime
28
  messages: list[Message] = Field(default_factory=list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  created_at: datetime
27
  updated_at: datetime
28
  messages: list[Message] = Field(default_factory=list)
29
+
30
+ # --- API Request Models ---
31
+
32
+ class SessionCreateRequest(BaseModel):
33
+ account_id: str
34
+ patient_id: str
35
+ title: str | None = "New Chat"
36
+
37
+
38
+ class ChatRequest(BaseModel):
39
+ """Request model for sending a message to a session."""
40
+ account_id: str # For context, though session_id implies this
41
+ patient_id: str # For context, though session_id implies this
42
+ message: str
43
+
44
+ # --- API Response Models ---
45
+
46
+ class ChatResponse(BaseModel):
47
+ """Response model for a chat interaction."""
48
+ response: str
49
+ session_id: str
50
+ timestamp: datetime
51
+ medical_context: str | None = None
src/models/summarise.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/models/summarise.py
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
+ class SummariseRequest(BaseModel):
7
+ """Request model for the text summarisation endpoint."""
8
+ text: str
9
+ max_words: int = Field(default=5, ge=3, le=10) # Enforce reasonable limits
10
+
11
+
12
+ class SummariseResponse(BaseModel):
13
+ """Response model for the text summarisation endpoint."""
14
+ title: str
src/services/medical_response.py CHANGED
@@ -1,4 +1,4 @@
1
- # services/medical_response.py
2
 
3
  from src.core import prompt_builder
4
  from src.data.medical_kb import search_medical_kb
 
1
+ # src/services/medical_response.py
2
 
3
  from src.core import prompt_builder
4
  from src.data.medical_kb import search_medical_kb