feat: initial implementation of the end-to-end business flow, with a mock chat-agent-client response and validation logic.
Browse files- app/agent/chat_agent_client.py +14 -0
- app/agent/chat_agent_scheme.py +10 -0
- app/api/chat_api.py +3 -3
- app/mapper/chat_mapper.py +15 -6
- app/model/chat_model.py +4 -4
- app/repository/chat_repository.py +9 -18
- app/schema/chat_schema.py +5 -5
- app/service/chat_service.py +107 -31
- app/service/chat_validation.py +11 -0
- gradio_chatbot.py +11 -11
app/agent/chat_agent_client.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from app.agent.chat_agent_scheme import UserChatAgentRequest, AssistantChatAgentResponse
|
2 |
+
|
3 |
+
|
4 |
+
class ChatAgentClient:
    """Mock agent client that answers every request with a canned placeholder."""

    def __init__(self):
        # Name echoed into the placeholder response text below.
        self.agent_name = "ChatAgentClient"

    def process(self, user_chat_agent_request: UserChatAgentRequest) -> AssistantChatAgentResponse:
        """Return a placeholder AssistantChatAgentResponse for the given request."""
        # TODO implement the logic to process the chat
        return AssistantChatAgentResponse(
            message=f"Here is the {self.agent_name} Processed message: This is a placeholder response for the user-question",
            figure=None,  # Placeholder for any figure data if needed
        )
|
app/agent/chat_agent_scheme.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class UserChatAgentRequest(BaseModel):
    """Payload sent to the chat agent: a single user message."""

    message: str


class AssistantChatAgentResponse(BaseModel):
    """Payload returned by the chat agent: the reply text plus optional figure data."""

    message: str
    figure: dict | None = None
|
app/api/chat_api.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
from typing import Any, List, Optional
|
4 |
from fastapi import APIRouter, HTTPException, Depends, Request
|
5 |
from pydantic import BaseModel
|
6 |
-
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse,
|
7 |
from app.schema.conversation_schema import ConversationResponse, ConversationItemResponse
|
8 |
from app.service.chat_service import ChatService
|
9 |
from app.security.auth_service import AuthService
|
@@ -57,7 +57,7 @@ async def list_chat_completions(request: Request, username: str = Depends(auth_s
|
|
57 |
page: int = 1
|
58 |
limit: int = 10
|
59 |
sort: dict = {"created_date": -1}
|
60 |
-
project: dict =
|
61 |
|
62 |
try:
|
63 |
query = {"created_by": username}
|
@@ -80,7 +80,7 @@ async def retrieve_chat_completion(completion_id: str, request: Request, usernam
|
|
80 |
|
81 |
|
82 |
# get all messages for a chat completion
|
83 |
-
@router.get("/chat/completions/{completion_id}/messages", response_model=List[
|
84 |
async def list_messages(completion_id: str, request: Request, username: str = Depends(auth_service.verify_credentials)):
|
85 |
"""
|
86 |
Get all messages for a chat completion
|
|
|
3 |
from typing import Any, List, Optional
|
4 |
from fastapi import APIRouter, HTTPException, Depends, Request
|
5 |
from pydantic import BaseModel
|
6 |
+
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse, ChatMessageResponse
|
7 |
from app.schema.conversation_schema import ConversationResponse, ConversationItemResponse
|
8 |
from app.service.chat_service import ChatService
|
9 |
from app.security.auth_service import AuthService
|
|
|
57 |
page: int = 1
|
58 |
limit: int = 10
|
59 |
sort: dict = {"created_date": -1}
|
60 |
+
project: dict = {}
|
61 |
|
62 |
try:
|
63 |
query = {"created_by": username}
|
|
|
80 |
|
81 |
|
82 |
# get all messages for a chat completion
|
83 |
+
@router.get("/chat/completions/{completion_id}/messages", response_model=List[ChatMessageResponse], deprecated=True)
|
84 |
async def list_messages(completion_id: str, request: Request, username: str = Depends(auth_service.verify_credentials)):
|
85 |
"""
|
86 |
Get all messages for a chat completion
|
app/mapper/chat_mapper.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from app.mapper.base_mapper import BaseMapper
|
2 |
-
from app.model.chat_model import ChatCompletion,
|
3 |
-
from app.schema.chat_schema import ChatCompletionResponse, ChatCompletionRequest,
|
4 |
|
5 |
|
6 |
class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
@@ -15,11 +15,12 @@ class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
|
15 |
last_message = model.messages[-1] if model.messages else None
|
16 |
message_response = None
|
17 |
if last_message:
|
18 |
-
message_response =
|
19 |
message_id=last_message.message_id,
|
20 |
role=last_message.role,
|
21 |
content=last_message.content,
|
22 |
figure=last_message.figure,
|
|
|
23 |
)
|
24 |
|
25 |
# Create choice response
|
@@ -35,11 +36,19 @@ class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
|
35 |
messages = []
|
36 |
if schema.messages:
|
37 |
for msg in schema.messages:
|
38 |
-
messages.append(
|
39 |
|
40 |
return ChatCompletion(
|
|
|
41 |
completion_id=schema.completion_id,
|
42 |
-
model=schema.model
|
43 |
messages=messages,
|
44 |
-
stream=schema.stream
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
)
|
|
|
1 |
from app.mapper.base_mapper import BaseMapper
|
2 |
+
from app.model.chat_model import ChatCompletion, ChatMessageModel
|
3 |
+
from app.schema.chat_schema import ChatCompletionResponse, ChatCompletionRequest, ChatMessageResponse, ChoiceResponse
|
4 |
|
5 |
|
6 |
class ChatMapper(BaseMapper[ChatCompletion, ChatCompletionResponse]):
|
|
|
15 |
last_message = model.messages[-1] if model.messages else None
|
16 |
message_response = None
|
17 |
if last_message:
|
18 |
+
message_response = ChatMessageResponse(
|
19 |
message_id=last_message.message_id,
|
20 |
role=last_message.role,
|
21 |
content=last_message.content,
|
22 |
figure=last_message.figure,
|
23 |
+
created_date=last_message.created_date,
|
24 |
)
|
25 |
|
26 |
# Create choice response
|
|
|
36 |
messages = []
|
37 |
if schema.messages:
|
38 |
for msg in schema.messages:
|
39 |
+
messages.append(ChatMessageModel(role=msg.role, content=msg.content, figure=None, message_id=None))
|
40 |
|
41 |
return ChatCompletion(
|
42 |
+
title=schema.title,
|
43 |
completion_id=schema.completion_id,
|
44 |
+
model=schema.model,
|
45 |
messages=messages,
|
46 |
+
stream=schema.stream,
|
47 |
+
created_by=schema.created_by,
|
48 |
+
created_date=schema.created_date,
|
49 |
+
object_field=schema.object_field,
|
50 |
+
is_archived=schema.archived,
|
51 |
+
is_starred=schema.starred,
|
52 |
+
last_updated_by=schema.last_updated_by,
|
53 |
+
last_updated_date=schema.last_updated_date,
|
54 |
)
|
app/model/chat_model.py
CHANGED
@@ -51,13 +51,13 @@ from typing import List, Optional, Any
|
|
51 |
# }
|
52 |
|
53 |
|
54 |
-
class
|
55 |
"""
|
56 |
A message in a chat completion.
|
57 |
"""
|
58 |
|
59 |
-
message_id: str = Field(
|
60 |
-
role:
|
61 |
content: str = Field(..., description="The content of the message")
|
62 |
figure: Optional[dict[str, Any]] = Field(None, description="The figure data for visualization")
|
63 |
created_date: datetime = Field(default_factory=datetime.now, description="The timestamp of the message")
|
@@ -89,7 +89,7 @@ class ChatCompletion(BaseModel):
|
|
89 |
|
90 |
# openai compatible fields
|
91 |
model: Optional[str] = Field(None, description="The model used for the chat completion", examples=["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"])
|
92 |
-
messages: Optional[List[
|
93 |
|
94 |
# not implemented yet
|
95 |
# temperature: float = Field(default=0.7,ge=0.0, le=1.0, description="What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")
|
|
|
51 |
# }
|
52 |
|
53 |
|
54 |
+
class ChatMessageModel(BaseModel):
|
55 |
"""
|
56 |
A message in a chat completion.
|
57 |
"""
|
58 |
|
59 |
+
message_id: Optional[str] = Field(None, description="The unique identifier for the message")
|
60 |
+
role: str = Field(..., description="The role of the message sender", examples=["user", "assistant", "system"])
|
61 |
content: str = Field(..., description="The content of the message")
|
62 |
figure: Optional[dict[str, Any]] = Field(None, description="The figure data for visualization")
|
63 |
created_date: datetime = Field(default_factory=datetime.now, description="The timestamp of the message")
|
|
|
89 |
|
90 |
# openai compatible fields
|
91 |
model: Optional[str] = Field(None, description="The model used for the chat completion", examples=["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"])
|
92 |
+
messages: Optional[List[ChatMessageModel]] = Field(None, description="The messages in the chat completion")
|
93 |
|
94 |
# not implemented yet
|
95 |
# temperature: float = Field(default=0.7,ge=0.0, le=1.0, description="What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")
|
app/repository/chat_repository.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
from typing import Any, List, Optional
|
2 |
from app.db.factory import db_client
|
3 |
-
from app.model.chat_model import
|
4 |
from loguru import logger
|
5 |
-
import uuid
|
6 |
import pymongo
|
7 |
|
8 |
|
@@ -21,7 +20,7 @@ class ChatRepository:
|
|
21 |
self.db = db_client.db
|
22 |
self.collection = "chat_completion"
|
23 |
|
24 |
-
async def
|
25 |
"""
|
26 |
Create a new chat completion in the database.
|
27 |
|
@@ -48,7 +47,7 @@ class ChatRepository:
|
|
48 |
logger.info(f"Successfully created new chat completion with ID: {entity.completion_id}")
|
49 |
return await self.find_by_id(entity.completion_id)
|
50 |
|
51 |
-
async def
|
52 |
"""
|
53 |
Update an existing chat completion in the database.
|
54 |
|
@@ -108,9 +107,9 @@ class ChatRepository:
|
|
108 |
try:
|
109 |
result = await self.find_by_id(entity.completion_id)
|
110 |
if result:
|
111 |
-
return await self.
|
112 |
else:
|
113 |
-
return await self.
|
114 |
except Exception as e:
|
115 |
logger.error(f"Error saving chat completion: {e}")
|
116 |
raise
|
@@ -118,7 +117,7 @@ class ChatRepository:
|
|
118 |
logger.debug("END REPO: save chat completion")
|
119 |
|
120 |
async def find(
|
121 |
-
self, query: dict = {}, page: int = 1, limit: int = 10, sort: dict = {"created_date": -1}, projection: dict =
|
122 |
) -> List[ChatCompletion]:
|
123 |
"""
|
124 |
Find a chat completion by a given query. with pagination
|
@@ -147,7 +146,7 @@ class ChatRepository:
|
|
147 |
logger.debug(f"END REPO: find, returning {len(result_models)} models.")
|
148 |
return result_models
|
149 |
|
150 |
-
async def find_by_id(self, completion_id: str, projection: dict = None) -> ChatCompletion:
|
151 |
"""
|
152 |
Find a chat completion by a given id.
|
153 |
Example : completion_id = "123"
|
@@ -169,7 +168,7 @@ class ChatRepository:
|
|
169 |
logger.info(f"Chat completion with ID {completion_id} not found in DB.")
|
170 |
return None
|
171 |
|
172 |
-
async def find_messages(self, completion_id: str) -> List[
|
173 |
"""
|
174 |
Find all messages for a given chat completion id.
|
175 |
Example : completion_id = "123"
|
@@ -181,7 +180,7 @@ class ChatRepository:
|
|
181 |
if chat_doc and "messages" in chat_doc and chat_doc["messages"]:
|
182 |
try:
|
183 |
messages_list = [
|
184 |
-
|
185 |
message_id=item["message_id"],
|
186 |
role=item["role"],
|
187 |
content=item["content"],
|
@@ -215,16 +214,8 @@ class ChatRepository:
|
|
215 |
logger.error(f"Error finding plot by message id: {e}")
|
216 |
return None
|
217 |
|
218 |
-
# Mesajları Python tarafında filtreleyelim
|
219 |
if entity_doc and "messages" in entity_doc and entity_doc["messages"]:
|
220 |
try:
|
221 |
-
# İstenen message_id'ye sahip mesajı bul
|
222 |
-
# for message in entity_doc["messages"]:
|
223 |
-
# if message["message_id"] == message_id:
|
224 |
-
# figure = message.get("figure")
|
225 |
-
# logger.debug(f"REPO find figure: {figure}")
|
226 |
-
# return figure
|
227 |
-
|
228 |
match = next((message for message in entity_doc["messages"] if message["message_id"] == message_id), None)
|
229 |
if match:
|
230 |
figure = match.get("figure")
|
|
|
1 |
from typing import Any, List, Optional
|
2 |
from app.db.factory import db_client
|
3 |
+
from app.model.chat_model import ChatMessageModel, ChatCompletion
|
4 |
from loguru import logger
|
|
|
5 |
import pymongo
|
6 |
|
7 |
|
|
|
20 |
self.db = db_client.db
|
21 |
self.collection = "chat_completion"
|
22 |
|
23 |
+
async def create(self, entity: ChatCompletion) -> ChatCompletion:
|
24 |
"""
|
25 |
Create a new chat completion in the database.
|
26 |
|
|
|
47 |
logger.info(f"Successfully created new chat completion with ID: {entity.completion_id}")
|
48 |
return await self.find_by_id(entity.completion_id)
|
49 |
|
50 |
+
async def update(self, entity: ChatCompletion) -> ChatCompletion:
|
51 |
"""
|
52 |
Update an existing chat completion in the database.
|
53 |
|
|
|
107 |
try:
|
108 |
result = await self.find_by_id(entity.completion_id)
|
109 |
if result:
|
110 |
+
return await self.update(entity)
|
111 |
else:
|
112 |
+
return await self.create(entity)
|
113 |
except Exception as e:
|
114 |
logger.error(f"Error saving chat completion: {e}")
|
115 |
raise
|
|
|
117 |
logger.debug("END REPO: save chat completion")
|
118 |
|
119 |
async def find(
|
120 |
+
self, query: dict = {}, page: int = 1, limit: int = 10, sort: dict = {"created_date": -1}, projection: dict = {}
|
121 |
) -> List[ChatCompletion]:
|
122 |
"""
|
123 |
Find a chat completion by a given query. with pagination
|
|
|
146 |
logger.debug(f"END REPO: find, returning {len(result_models)} models.")
|
147 |
return result_models
|
148 |
|
149 |
+
async def find_by_id(self, completion_id: str, projection: dict = None) -> ChatCompletion | None:
|
150 |
"""
|
151 |
Find a chat completion by a given id.
|
152 |
Example : completion_id = "123"
|
|
|
168 |
logger.info(f"Chat completion with ID {completion_id} not found in DB.")
|
169 |
return None
|
170 |
|
171 |
+
async def find_messages(self, completion_id: str) -> List[ChatMessageModel]:
|
172 |
"""
|
173 |
Find all messages for a given chat completion id.
|
174 |
Example : completion_id = "123"
|
|
|
180 |
if chat_doc and "messages" in chat_doc and chat_doc["messages"]:
|
181 |
try:
|
182 |
messages_list = [
|
183 |
+
ChatMessageModel(
|
184 |
message_id=item["message_id"],
|
185 |
role=item["role"],
|
186 |
content=item["content"],
|
|
|
214 |
logger.error(f"Error finding plot by message id: {e}")
|
215 |
return None
|
216 |
|
|
|
217 |
if entity_doc and "messages" in entity_doc and entity_doc["messages"]:
|
218 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
match = next((message for message in entity_doc["messages"] if message["message_id"] == message_id), None)
|
220 |
if match:
|
221 |
figure = match.get("figure")
|
app/schema/chat_schema.py
CHANGED
@@ -3,7 +3,7 @@ from pydantic import BaseModel, Field
|
|
3 |
from datetime import datetime
|
4 |
|
5 |
|
6 |
-
class
|
7 |
"""
|
8 |
Represents a message in a chat completion.
|
9 |
"""
|
@@ -22,11 +22,11 @@ class ChatCompletionRequest(BaseModel):
|
|
22 |
description="The unique identifier for the chat completion. When starting a new chat, this will be a new UUID. When continuing a previous chat, this will be the same as the previous chat completion id.",
|
23 |
)
|
24 |
model: Optional[str] = Field(None, description="The model to use for the chat completion")
|
25 |
-
messages: Optional[List[
|
26 |
stream: Optional[bool] = Field(None, description="Whether to stream the chat completion")
|
27 |
|
28 |
|
29 |
-
class
|
30 |
"""
|
31 |
A chat completion message generated by the model.
|
32 |
"""
|
@@ -45,7 +45,7 @@ class ChoiceResponse(BaseModel):
|
|
45 |
examples=["stop", "length", "content_filter"],
|
46 |
)
|
47 |
index: Optional[int] = Field(None, description="The index of the choice in the list of choices.")
|
48 |
-
message: Optional[
|
49 |
# logprobs: str = None # not implemented yet
|
50 |
|
51 |
|
@@ -54,7 +54,7 @@ class ChatCompletionResponse(BaseModel):
|
|
54 |
Represents a chat completion response returned by model, based on the provided input.
|
55 |
"""
|
56 |
|
57 |
-
completion_id:
|
58 |
choices: Optional[List[ChoiceResponse]] = Field(None, description="A list of chat completion choices.")
|
59 |
created: Optional[int] = Field(None, description="The Unix timestamp (in seconds) of when the chat completion was created.")
|
60 |
model: Optional[str] = Field(None, description="The model used for the chat completion")
|
|
|
3 |
from datetime import datetime
|
4 |
|
5 |
|
6 |
+
class ChatMessageRequest(BaseModel):
|
7 |
"""
|
8 |
Represents a message in a chat completion.
|
9 |
"""
|
|
|
22 |
description="The unique identifier for the chat completion. When starting a new chat, this will be a new UUID. When continuing a previous chat, this will be the same as the previous chat completion id.",
|
23 |
)
|
24 |
model: Optional[str] = Field(None, description="The model to use for the chat completion")
|
25 |
+
messages: Optional[List[ChatMessageRequest]] = Field(None, description="The messages to use for the chat completion")
|
26 |
stream: Optional[bool] = Field(None, description="Whether to stream the chat completion")
|
27 |
|
28 |
|
29 |
+
class ChatMessageResponse(BaseModel):
|
30 |
"""
|
31 |
A chat completion message generated by the model.
|
32 |
"""
|
|
|
45 |
examples=["stop", "length", "content_filter"],
|
46 |
)
|
47 |
index: Optional[int] = Field(None, description="The index of the choice in the list of choices.")
|
48 |
+
message: Optional[ChatMessageResponse] = Field(None, description="The message to use for the chat completion")
|
49 |
# logprobs: str = None # not implemented yet
|
50 |
|
51 |
|
|
|
54 |
Represents a chat completion response returned by model, based on the provided input.
|
55 |
"""
|
56 |
|
57 |
+
completion_id: str = Field(None, description="The unique identifier for the chat completion")
|
58 |
choices: Optional[List[ChoiceResponse]] = Field(None, description="A list of chat completion choices.")
|
59 |
created: Optional[int] = Field(None, description="The Unix timestamp (in seconds) of when the chat completion was created.")
|
60 |
model: Optional[str] = Field(None, description="The model used for the chat completion")
|
app/service/chat_service.py
CHANGED
@@ -1,12 +1,17 @@
|
|
1 |
import datetime
|
2 |
-
from typing import Any, List
|
|
|
|
|
|
|
3 |
from app.repository.chat_repository import ChatRepository
|
4 |
-
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse,
|
5 |
from app.mapper.chat_mapper import ChatMapper
|
6 |
from app.mapper.conversation_mapper import ConversationMapper
|
7 |
import uuid
|
8 |
from loguru import logger
|
9 |
from app.schema.conversation_schema import ConversationResponse
|
|
|
|
|
10 |
|
11 |
|
12 |
class ChatService:
|
@@ -14,6 +19,8 @@ class ChatService:
|
|
14 |
self.chat_repository = ChatRepository()
|
15 |
self.chat_mapper = ChatMapper()
|
16 |
self.conversation_mapper = ConversationMapper()
|
|
|
|
|
17 |
|
18 |
async def find(self, query: dict, page: int, limit: int, sort: dict, project: dict = None) -> List[ChatCompletionResponse]:
|
19 |
logger.debug(f"BEGIN SERVICE: find for query: {query}, page: {page}, limit: {limit}, sort: {sort}, project: {project}")
|
@@ -24,12 +31,12 @@ class ChatService:
|
|
24 |
entity = await self.chat_repository.find_by_id(completion_id, project)
|
25 |
return self.chat_mapper.to_schema(entity) if entity else None
|
26 |
|
27 |
-
async def find_messages(self, completion_id: str) -> List[
|
28 |
logger.debug(f"BEGIN SERVICE: find_messages for completion_id: {completion_id}")
|
29 |
messages = await self.chat_repository.find_messages(completion_id)
|
30 |
logger.debug(f"END SERVICE: find_messages for completion_id: {completion_id}, messages: {len(messages)}")
|
31 |
messages_response = [
|
32 |
-
|
33 |
message_id=message.message_id,
|
34 |
role=message.role,
|
35 |
content=message.content,
|
@@ -41,7 +48,7 @@ class ChatService:
|
|
41 |
return messages_response
|
42 |
|
43 |
# conversation service
|
44 |
-
async def find_all_conversations(self, username: str) ->
|
45 |
"""Find all conversations for a given username."""
|
46 |
query = {"created_by": username}
|
47 |
sort = {"last_updated_date": -1} # Sort by last updated date in descending order
|
@@ -50,7 +57,8 @@ class ChatService:
|
|
50 |
result = self.conversation_mapper.to_schema_list(entities)
|
51 |
return ConversationResponse(items=result, total=len(result), limit=100, offset=0)
|
52 |
|
53 |
-
|
|
|
54 |
"""Find a conversation by its completion ID."""
|
55 |
logger.debug(f"BEGIN SERVICE: find_conversation_by_id for completion_id: {completion_id}")
|
56 |
projection = {"messages": 0, "_id": 0}
|
@@ -58,13 +66,13 @@ class ChatService:
|
|
58 |
logger.debug(f"END SERVICE: find_conversation_by_id for completion_id: {completion_id}, entity: {entity}")
|
59 |
|
60 |
if entity:
|
61 |
-
|
62 |
-
result =
|
63 |
return result
|
64 |
else:
|
65 |
return None
|
66 |
|
67 |
-
async def find_plot_by_message(self, completion_id: str, message_id: str) ->
|
68 |
logger.debug(f"BEGIN SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id}")
|
69 |
figure = await self.chat_repository.find_plot_by_message(completion_id, message_id)
|
70 |
|
@@ -77,26 +85,94 @@ class ChatService:
|
|
77 |
logger.debug(f"END SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id} with figure")
|
78 |
return result
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import datetime
|
2 |
+
from typing import Any, List
|
3 |
+
|
4 |
+
from app.agent.chat_agent_scheme import UserChatAgentRequest
|
5 |
+
from app.model.chat_model import ChatMessageModel
|
6 |
from app.repository.chat_repository import ChatRepository
|
7 |
+
from app.schema.chat_schema import ChatCompletionRequest, ChatCompletionResponse, ChatMessageResponse, ChatMessageRequest
|
8 |
from app.mapper.chat_mapper import ChatMapper
|
9 |
from app.mapper.conversation_mapper import ConversationMapper
|
10 |
import uuid
|
11 |
from loguru import logger
|
12 |
from app.schema.conversation_schema import ConversationResponse
|
13 |
+
from app.service.chat_validation import ChatValidation
|
14 |
+
from app.agent.chat_agent_client import ChatAgentClient
|
15 |
|
16 |
|
17 |
class ChatService:
|
|
|
19 |
self.chat_repository = ChatRepository()
|
20 |
self.chat_mapper = ChatMapper()
|
21 |
self.conversation_mapper = ConversationMapper()
|
22 |
+
self.chat_validation = ChatValidation()
|
23 |
+
self.chat_agent_client = ChatAgentClient()
|
24 |
|
25 |
async def find(self, query: dict, page: int, limit: int, sort: dict, project: dict = None) -> List[ChatCompletionResponse]:
|
26 |
logger.debug(f"BEGIN SERVICE: find for query: {query}, page: {page}, limit: {limit}, sort: {sort}, project: {project}")
|
|
|
31 |
entity = await self.chat_repository.find_by_id(completion_id, project)
|
32 |
return self.chat_mapper.to_schema(entity) if entity else None
|
33 |
|
34 |
+
async def find_messages(self, completion_id: str) -> List[ChatMessageResponse]:
|
35 |
logger.debug(f"BEGIN SERVICE: find_messages for completion_id: {completion_id}")
|
36 |
messages = await self.chat_repository.find_messages(completion_id)
|
37 |
logger.debug(f"END SERVICE: find_messages for completion_id: {completion_id}, messages: {len(messages)}")
|
38 |
messages_response = [
|
39 |
+
ChatMessageResponse(
|
40 |
message_id=message.message_id,
|
41 |
role=message.role,
|
42 |
content=message.content,
|
|
|
48 |
return messages_response
|
49 |
|
50 |
# conversation service
|
51 |
+
async def find_all_conversations(self, username: str) -> ConversationResponse:
|
52 |
"""Find all conversations for a given username."""
|
53 |
query = {"created_by": username}
|
54 |
sort = {"last_updated_date": -1} # Sort by last updated date in descending order
|
|
|
57 |
result = self.conversation_mapper.to_schema_list(entities)
|
58 |
return ConversationResponse(items=result, total=len(result), limit=100, offset=0)
|
59 |
|
60 |
+
# conversation service
|
61 |
+
async def find_conversation_by_id(self, completion_id: str) -> ConversationResponse | None:
|
62 |
"""Find a conversation by its completion ID."""
|
63 |
logger.debug(f"BEGIN SERVICE: find_conversation_by_id for completion_id: {completion_id}")
|
64 |
projection = {"messages": 0, "_id": 0}
|
|
|
66 |
logger.debug(f"END SERVICE: find_conversation_by_id for completion_id: {completion_id}, entity: {entity}")
|
67 |
|
68 |
if entity:
|
69 |
+
conversation_item = self.conversation_mapper.to_schema(entity)
|
70 |
+
result = ConversationResponse(items=[conversation_item], total=1, limit=1, offset=0)
|
71 |
return result
|
72 |
else:
|
73 |
return None
|
74 |
|
75 |
+
async def find_plot_by_message(self, completion_id: str, message_id: str) -> dict[str, Any]:
|
76 |
logger.debug(f"BEGIN SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id}")
|
77 |
figure = await self.chat_repository.find_plot_by_message(completion_id, message_id)
|
78 |
|
|
|
85 |
logger.debug(f"END SERVICE: find_plot_by_message for completion_id: {completion_id}, message_id: {message_id} with figure")
|
86 |
return result
|
87 |
|
88 |
+
async def _save_chat_completion(self, request: ChatCompletionRequest, username: str) -> ChatCompletionResponse:
    """
    Persist a chat completion: create it on the first message, or append the
    latest user message to an existing one.

    Args:
        request: The incoming chat completion request (its last message is the
            one being persisted).
        username: The authenticated user performing the action.

    Returns:
        ChatCompletionResponse: The saved completion mapped back to a schema.

    Raises:
        Exception: Re-raises any mapper/repository error after logging it.
    """
    logger.debug(f"BEGIN SERVICE: for request: {request}, username: {username}")
    try:
        # Convert request to model
        entity = self.chat_mapper.to_model(request)

        entity.last_updated_by = username
        entity.last_updated_date = datetime.datetime.now()
        # BUG FIX: the original condition was `if entity.completion_id:`, which
        # regenerated the id when the client had ALREADY supplied one — so
        # continuing a conversation always missed find_by_id below and forked a
        # brand-new completion. A fresh UUID must be minted only when no id was
        # provided (i.e. a new chat is being started).
        if not entity.completion_id:
            # generate a new chat completion
            entity.completion_id = str(uuid.uuid4())
        last_user_request_message = request.messages[-1]
        current_entity = await self.chat_repository.find_by_id(entity.completion_id)
        if not current_entity:
            # create new chat completion with new user request message
            entity.created_by = username
            entity.created_date = datetime.datetime.now()
            entity.last_updated_by = username
            entity.last_updated_date = datetime.datetime.now()
            # title can generate with LLM from user request message.content
            entity.title = last_user_request_message.content[:20]
            final_entity = await self.chat_repository.create(entity)
        else:
            # update existing chat completion with new user request message
            message_model = ChatMessageModel(
                message_id=str(uuid.uuid4()),
                role=last_user_request_message.role,
                content=last_user_request_message.content,
                figure=None,
                created_date=datetime.datetime.now(),
            )
            current_entity.messages.append(message_model)
            current_entity.last_updated_date = datetime.datetime.now()
            final_entity = await self.chat_repository.update(current_entity)

        # Convert model to response
        result = self.chat_mapper.to_schema(final_entity)
        logger.debug("END SERVICE")
        return result
    except Exception as e:
        logger.error(f"Error saving chat completion: {e}")
        raise
|
135 |
+
async def chat_agent_client_process(self, user_chat_completion: ChatCompletionRequest, username: str):
    """Send the most recent user message through the chat agent client and
    return the agent's response object."""
    logger.debug(f"BEGIN SERVICE: Agentic Chat AI process. username: {username}")
    latest_content = user_chat_completion.messages[-1].content
    agent_request = UserChatAgentRequest(message=latest_content)
    agent_response = self.chat_agent_client.process(agent_request)
    logger.debug("END SERVICE: Agentic Chat AI process")
    return agent_response
|
142 |
|
143 |
+
async def handle_chat_completion(self, user_chat_completion: ChatCompletionRequest, username: str) -> ChatCompletionResponse:
    """
    End-to-end chat flow: validate the request, persist the user message, run
    the agent, validate its response, persist the assistant message, and return
    the saved assistant completion.

    Args:
        user_chat_completion: The incoming request; NOTE its `messages` list is
            replaced in place with the assistant message (see review note below).
        username: The authenticated user.

    Raises:
        Exception: Re-raised from the agent step (and from validation/persistence).
    """
    # BUG FIX: the original bound the ENTIRE request to a variable named
    # `last_user_message` and logged it under that misleading label; log the
    # request under an honest name instead.
    logger.debug(f"BEGIN SERVICE: user_chat_completion: {user_chat_completion}, username: {username}")

    # validate user message
    self.chat_validation.validate_request(user_chat_completion)

    # save user message to database
    logger.info("Saving user message to database")
    repo_user_message = await self._save_chat_completion(user_chat_completion, username)
    logger.info(f"Saved user message to database with completion_id: {repo_user_message.completion_id}")

    # region agentic-ai process start #########################################################
    try:
        logger.info("Agentic Chat AI process started")
        agent_result = await self.chat_agent_client_process(user_chat_completion, username)
        assistant_message = ChatMessageRequest(role="assistant", content=agent_result.message)
        # NOTE(review): this aliases the caller's request object, so replacing
        # `.messages` below mutates `user_chat_completion` itself — consider a
        # copy if callers reuse the request after this call.
        assistant_chat_completion = user_chat_completion
        assistant_chat_completion.messages = [assistant_message]  # replace user messages with assistant message
        logger.info(f"Agentic Chat AI process completed. Part of Assistant Message...: {assistant_message.content[:50]}...")
    except Exception as e:
        logger.error(f"Error agentic-ai process: {e}")
        raise
    # endregion agentic-ai process start ######################################################

    # validate agent response
    self.chat_validation.validate_response(agent_result)

    # save assistant message to database
    repo_assistant_message = await self._save_chat_completion(assistant_chat_completion, username)

    # generate api response with user, agent, db etc... TBD
    result = repo_assistant_message

    logger.debug("END SERVICE")
    return result
|
app/service/chat_validation.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ChatValidation:
    """Stub validator for chat requests and agent responses.

    Both checks are placeholders that currently accept everything.
    """

    def __init__(self):
        pass

    def validate_request(self, completion):
        """Validate an incoming chat completion request (not yet implemented)."""
        # TODO implement request validation logic
        pass

    def validate_response(self, agent_result):
        """Validate an agent's response (not yet implemented)."""
        # TODO implement response validation logic
        pass
|
gradio_chatbot.py
CHANGED
@@ -2,7 +2,7 @@ import json
|
|
2 |
import gradio as gr
|
3 |
import environs
|
4 |
import httpx
|
5 |
-
from typing import List, Tuple, Optional
|
6 |
from dataclasses import dataclass
|
7 |
from enum import Enum
|
8 |
import os
|
@@ -125,7 +125,7 @@ class MessageStatus(Enum):
|
|
125 |
|
126 |
|
127 |
@dataclass
|
128 |
-
class
|
129 |
"""Data class for message response"""
|
130 |
|
131 |
status: MessageStatus
|
@@ -142,7 +142,7 @@ class ChatAPI:
|
|
142 |
self.api_key = api_key
|
143 |
self.endpoint = f"{base_url}/v1/chat/completions"
|
144 |
|
145 |
-
async def send_message(self, prompt: str) ->
|
146 |
"""
|
147 |
Send a message to the chat API
|
148 |
|
@@ -150,7 +150,7 @@ class ChatAPI:
|
|
150 |
prompt (str): The message to send
|
151 |
|
152 |
Returns:
|
153 |
-
|
154 |
"""
|
155 |
logger.trace(f"Calling chat API with prompt: {prompt}")
|
156 |
try:
|
@@ -169,7 +169,7 @@ class ChatAPI:
|
|
169 |
|
170 |
if response.status_code != 200:
|
171 |
logger.error(f"API Error: {response.text}")
|
172 |
-
return
|
173 |
status=MessageStatus.ERROR,
|
174 |
content="",
|
175 |
figure=None,
|
@@ -187,14 +187,14 @@ class ChatAPI:
|
|
187 |
logger.trace(f"Figure: {figure}")
|
188 |
content = message.get("content", "Content not found")
|
189 |
logger.trace(f"Last message: {content}")
|
190 |
-
return
|
191 |
status=MessageStatus.SUCCESS,
|
192 |
content=content,
|
193 |
figure=figure,
|
194 |
)
|
195 |
else:
|
196 |
logger.error("Invalid API response")
|
197 |
-
return
|
198 |
status=MessageStatus.ERROR,
|
199 |
content="",
|
200 |
error="Invalid API response",
|
@@ -202,14 +202,14 @@ class ChatAPI:
|
|
202 |
|
203 |
except httpx.TimeoutException:
|
204 |
logger.error("API request timed out")
|
205 |
-
return
|
206 |
status=MessageStatus.ERROR,
|
207 |
content="",
|
208 |
error="Request timed out. Please try again.",
|
209 |
)
|
210 |
except Exception as e:
|
211 |
logger.error(f"Error: {str(e)}")
|
212 |
-
return
|
213 |
status=MessageStatus.ERROR,
|
214 |
content="",
|
215 |
error=f"Error: {str(e)}",
|
@@ -319,13 +319,13 @@ class ChatInterface:
|
|
319 |
None,
|
320 |
)
|
321 |
|
322 |
-
def clear_history() ->
|
323 |
"""Clear chat history"""
|
324 |
return [], "", "Chat cleared.", "", None
|
325 |
|
326 |
def retry_last_message(
|
327 |
history: List[List[str]],
|
328 |
-
) ->
|
329 |
"""Retry the last message"""
|
330 |
if not history:
|
331 |
return history, "", "No message to retry.", "", None
|
|
|
2 |
import gradio as gr
|
3 |
import environs
|
4 |
import httpx
|
5 |
+
from typing import List, Tuple, Optional, Any
|
6 |
from dataclasses import dataclass
|
7 |
from enum import Enum
|
8 |
import os
|
|
|
125 |
|
126 |
|
127 |
@dataclass
|
128 |
+
class ChatMessageResponse:
|
129 |
"""Data class for message response"""
|
130 |
|
131 |
status: MessageStatus
|
|
|
142 |
self.api_key = api_key
|
143 |
self.endpoint = f"{base_url}/v1/chat/completions"
|
144 |
|
145 |
+
async def send_message(self, prompt: str) -> ChatMessageResponse:
|
146 |
"""
|
147 |
Send a message to the chat API
|
148 |
|
|
|
150 |
prompt (str): The message to send
|
151 |
|
152 |
Returns:
|
153 |
+
ChatMessageResponse: The response from the API
|
154 |
"""
|
155 |
logger.trace(f"Calling chat API with prompt: {prompt}")
|
156 |
try:
|
|
|
169 |
|
170 |
if response.status_code != 200:
|
171 |
logger.error(f"API Error: {response.text}")
|
172 |
+
return ChatMessageResponse(
|
173 |
status=MessageStatus.ERROR,
|
174 |
content="",
|
175 |
figure=None,
|
|
|
187 |
logger.trace(f"Figure: {figure}")
|
188 |
content = message.get("content", "Content not found")
|
189 |
logger.trace(f"Last message: {content}")
|
190 |
+
return ChatMessageResponse(
|
191 |
status=MessageStatus.SUCCESS,
|
192 |
content=content,
|
193 |
figure=figure,
|
194 |
)
|
195 |
else:
|
196 |
logger.error("Invalid API response")
|
197 |
+
return ChatMessageResponse(
|
198 |
status=MessageStatus.ERROR,
|
199 |
content="",
|
200 |
error="Invalid API response",
|
|
|
202 |
|
203 |
except httpx.TimeoutException:
|
204 |
logger.error("API request timed out")
|
205 |
+
return ChatMessageResponse(
|
206 |
status=MessageStatus.ERROR,
|
207 |
content="",
|
208 |
error="Request timed out. Please try again.",
|
209 |
)
|
210 |
except Exception as e:
|
211 |
logger.error(f"Error: {str(e)}")
|
212 |
+
return ChatMessageResponse(
|
213 |
status=MessageStatus.ERROR,
|
214 |
content="",
|
215 |
error=f"Error: {str(e)}",
|
|
|
319 |
None,
|
320 |
)
|
321 |
|
322 |
+
def clear_history() -> tuple[list[Any], str, str, str, None]:
|
323 |
"""Clear chat history"""
|
324 |
return [], "", "Chat cleared.", "", None
|
325 |
|
326 |
def retry_last_message(
|
327 |
history: List[List[str]],
|
328 |
+
) -> tuple[list[list[str]], str, str, str, None]:
|
329 |
"""Retry the last message"""
|
330 |
if not history:
|
331 |
return history, "", "No message to retry.", "", None
|