|
from math import ceil |
|
from config import settings |
|
from mongoengine import connect |
|
from llm_agent import ReactAgent |
|
from fastapi import FastAPI, Response |
|
from chat import ( |
|
ChatSession, |
|
FeedbackRequest, |
|
LikeDislikeRequest, |
|
UpdateSessionNameRequest, |
|
Message, |
|
ChatSessionSchema, |
|
Role, |
|
) |
|
|
|
|
|
# FastAPI application object; the route handlers in this module register on it.
app: FastAPI = FastAPI()

# Module-level agent instance used by the /query handler.
react_agent: ReactAgent = ReactAgent()
|
|
|
|
|
@app.on_event("startup")
async def load_scheduler_and_DB():
    """Open the default MongoEngine connection when the application starts.

    The database name and URI are read from ``settings``. A confirmation
    line is printed once ``connect`` has returned.

    NOTE(review): despite the name, nothing scheduler-related is set up in
    this function.
    """
    connect(
        db=settings.DB_NAME,
        alias="default",
        host=settings.DB_URI,
    )
    print("Database connection established!!")
|
|
|
|
|
@app.post("/query", tags=["Chat-Model"])
def handle_query(chat_schema: ChatSessionSchema):
    """Answer a user query inside an existing chat session.

    The session's history (as returned by ``get_last_messages()``) is handed
    to the agent together with the query. Afterwards both the user query and
    the agent's reply are appended to the session.

    Args:
        chat_schema: Request body; ``session_id`` and ``query`` are read
            from it.

    Returns:
        ``{"response": <agent reply>}`` on success, or a plain-text 404
        ``Response`` when ``session_id`` is missing or matches no session.
    """
    chat_session = None
    if chat_schema.session_id:
        chat_session = ChatSession.objects(id=chat_schema.session_id).first()

    # Guard clause: without it a missing session id left ``chat_session``
    # unbound, and an unknown id left it as ``None``; either way the calls
    # below crashed with an unhandled exception.
    if chat_session is None:
        return Response(
            content="Chat session not found",
            status_code=404,
            media_type="text/plain",
        )

    # History is read before the current exchange is stored.
    chat_history = chat_session.get_last_messages()
    response = react_agent.handle_query(
        session_id=chat_schema.session_id,
        query=chat_schema.query,
        chat_history=chat_history,
    )

    # Persist the exchange in order: user query first, then the model reply.
    chat_session.add_message_with_metadata(
        role=Role.USER.value, content=chat_schema.query
    )
    chat_session.add_message_with_metadata(role=Role.MODEL.value, content=response)

    return {"response": response}
|
|
|
|
|
@app.post("/temp_session", tags=["Chat-Session"])
def temp_session() -> dict:
    """Create and save a new, empty ``ChatSession`` and report its id.

    Returns:
        A dict with a confirmation ``message`` and the ``session_id``
        obtained from the saved session's ``get_id()``.
    """
    new_session = ChatSession()
    new_session.save()

    payload = {
        "message": "Session created",
        "session_id": new_session.get_id(),
    }
    return payload
|
|
|
|
|
@app.post("/feedback", tags=["Chat-Features"])
def feedback(feedback_request: FeedbackRequest):
    """Store feedback for one message of a chat session.

    Args:
        feedback_request: Request body; ``session_id``, ``message_id`` and
            ``feedback`` are read from it.

    Returns:
        A confirmation dict on success, or a plain-text 404 ``Response``
        when no session has the given id.
    """
    # ``objects.get`` raises ``DoesNotExist`` for an unknown id instead of
    # returning ``None``, so the 404 branch could never run. ``first()``
    # yields ``None`` for "no match", which makes the branch reachable.
    chat_session = ChatSession.objects(id=feedback_request.session_id).first()
    if chat_session is None:
        return Response(
            content="Chat session not found",
            status_code=404,
            media_type="text/plain",
        )

    chat_session.feedback_message(
        feedback_request.message_id, feedback_request.feedback
    )
    return {"message": "Feedback saved successfully"}
|
|
|
|
|
@app.post("/like", tags=["Chat-Features"])
def like_message(like_request: LikeDislikeRequest):
    """Mark one message of a chat session as liked.

    Args:
        like_request: Request body; ``session_id`` and ``message_id`` are
            read from it.

    Returns:
        A confirmation dict on success, or a plain-text 404 ``Response``
        when no session has the given id.
    """
    # ``objects.get`` raises ``DoesNotExist`` for an unknown id instead of
    # returning ``None``, so the 404 branch could never run. ``first()``
    # yields ``None`` for "no match", which makes the branch reachable.
    chat_session = ChatSession.objects(id=like_request.session_id).first()
    if chat_session is None:
        return Response(
            content="Chat session not found",
            status_code=404,
            media_type="text/plain",
        )

    chat_session.like_message(like_request.message_id)
    return {"message": "Message liked successfully"}
|
|
|
|
|
@app.post("/dislike", tags=["Chat-Features"])
def dislike_message(dislike_request: LikeDislikeRequest):
    """Mark one message of a chat session as disliked.

    Args:
        dislike_request: Request body; ``session_id`` and ``message_id``
            are read from it.

    Returns:
        A confirmation dict on success, or a plain-text 404 ``Response``
        when no session has the given id.
    """
    # ``objects.get`` raises ``DoesNotExist`` for an unknown id instead of
    # returning ``None``, so the 404 branch could never run. ``first()``
    # yields ``None`` for "no match", which makes the branch reachable.
    chat_session = ChatSession.objects(id=dislike_request.session_id).first()
    if chat_session is None:
        return Response(
            content="Chat session not found",
            status_code=404,
            media_type="text/plain",
        )

    chat_session.dislike_message(dislike_request.message_id)
    return {"message": "Message disliked successfully"}
|
|
|
|
|
@app.post("/session_name_change", tags=["Chat-Session"])
def update_session_name(request: UpdateSessionNameRequest):
    """Rename an existing chat session.

    Args:
        request: Request body; ``session_id`` and ``new_session_name`` are
            read from it.

    Returns:
        A confirmation dict once the new name has been saved. A plain-text
        404 ``Response`` is returned when the lookup finds nothing, and
        also when any exception is raised during lookup, assignment or save.
    """
    try:
        target = ChatSession.objects(id=request.session_id).first()
        if target:
            target.session_name = request.new_session_name
            target.save()
            return {"message": "Session name updated successfully"}

        # Lookup succeeded but matched no document.
        return Response(
            content="Session not found",
            status_code=404,
            media_type="text/plain",
        )
    except Exception:
        # Every failure in the block above is reported as "not found".
        return Response(
            content="Session not found",
            status_code=404,
            media_type="text/plain",
        )
|
|
|
|
|
@app.get(
    # FastAPI path parameters use ``{name}``; the former Flask-style
    # ``<session_id>`` was matched as literal text, which turned
    # ``session_id`` into a query parameter.
    "/chat_session/{session_id}",
    tags=["Chat-Session"],
)
def get_chat_session(
    session_id: str,
    page: int = 1,
    size: int = 20,
):
    """Return one page of a chat session's messages plus paging metadata.

    Args:
        session_id: Id of the chat session to read.
        page: 1-based page number.
        size: Maximum number of messages per page.

    Returns:
        A dict with ``total_count``, ``total_pages``, ``has_next_page``,
        ``next_page`` and ``messages`` on success; a plain-text 400
        ``Response`` when ``page`` or ``size`` is not positive; a plain-text
        404 ``Response`` when the session cannot be loaded.
    """
    # size == 0 would divide by zero below, and negative values would
    # produce nonsensical slices, so reject them up front.
    if page < 1 or size < 1:
        return Response(
            content="page and size must be positive integers",
            status_code=400,
            media_type="text/plain",
        )

    try:
        chat_session = ChatSession.objects.get(id=session_id)
    except Exception:
        # ``Exception`` rather than ``BaseException`` so that interpreter
        # shutdown signals (KeyboardInterrupt, SystemExit) are not swallowed.
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )

    session_messages = chat_session.messages
    skip = (page - 1) * size
    message_ids = [message.id for message in session_messages[skip : skip + size]]

    # An ``id__in`` query does not promise to return documents in the order
    # of the id list, so index the results and re-emit them in page order.
    found = {message.id: message for message in Message.objects(id__in=message_ids)}
    page_messages = [found[mid] for mid in message_ids if mid in found]

    serialized_messages = [
        {
            **message.to_mongo().to_dict(),
            "_id": str(message.id),
            "chat_session": (
                str(message.chat_session.id) if message.chat_session else None
            ),
        }
        for message in page_messages
    ]

    # NOTE(review): the previous code re-fetched the session and called
    # ``.count()`` on the returned document. The total is now the length of
    # the same message list that is being paged; confirm ``ChatSession`` did
    # not define a custom ``count()`` with a different meaning.
    total_count = len(session_messages)
    total_pages = ceil(total_count / size)
    has_next_page = page < total_pages
    next_page = page + 1 if has_next_page else None

    return {
        "total_count": total_count,
        "total_pages": total_pages,
        "has_next_page": has_next_page,
        "next_page": next_page,
        "messages": serialized_messages,
    }
|
|
|
|
|
@app.delete(
    "/delete_session",
    tags=["Chat-Session"],
)
def delete_session(session_id: str):
    """Delete a chat session by id.

    Args:
        session_id: Id of the chat session to delete.

    Returns:
        A confirmation dict once the session has been deleted, or a
        plain-text 404 ``Response`` when the lookup fails or finds nothing.
    """
    try:
        chat_session = ChatSession.objects(id=session_id).first()
    except Exception:
        # A failed lookup is reported like a missing session. The response
        # must be *returned*: ``Response`` is not an exception class, so the
        # former ``raise Response(...)`` itself failed with ``TypeError``.
        return Response(
            content="Chat session not found", status_code=404, media_type="text/plain"
        )

    if chat_session is None:
        return Response(
            content="Chat session not found",
            status_code=404,
            media_type="text/plain",
        )

    # Kept outside the ``try`` so a failing delete surfaces as a server
    # error instead of being reported as "not found".
    chat_session.delete()
    return {"message": "Session deleted successfully"}
|
|