feat: add Premium feature
Browse files- src/agents_framework/integrations.py +21 -21
- src/api/routes.py +16 -11
- src/api/schemas.py +8 -3
src/agents_framework/integrations.py
CHANGED
@@ -17,8 +17,9 @@ from .tools import (
|
|
17 |
generate_response_base_context,
|
18 |
openai_client_answer
|
19 |
)
|
|
|
20 |
from .utils import get_data
|
21 |
-
from typing import List
|
22 |
|
23 |
# log = logger.get_logger(__name__)
|
24 |
|
@@ -55,27 +56,26 @@ class AgentIntegrationService:
|
|
55 |
return_source_documents=True,
|
56 |
)
|
57 |
|
58 |
-
async def answer_to_messages(self, query: List[list], chat_history: List[dict] | None) -> str:
|
59 |
-
question_type = classifer_question_execution.invoke(query).content.replace('`', "").replace("Tags", "").strip()
|
60 |
-
# log.info(f"question_type: {question_type}")
|
61 |
-
print(question_type)
|
62 |
agent_using = self._route(question_type)
|
63 |
if agent_using == self.answer_based_search:
|
64 |
-
query_reformat = openai_client_answer(
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
answer = agent_using(query_reformat)
|
69 |
else:
|
70 |
answer = agent_using(query, chat_history)
|
71 |
|
72 |
-
|
73 |
-
|
74 |
else:
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
}
|
79 |
|
80 |
|
81 |
def answer_based_retrieval(self, query, chat_history):
|
@@ -93,16 +93,16 @@ class AgentIntegrationService:
|
|
93 |
else:
|
94 |
continue
|
95 |
num_questions, questions = batch["number_of_locations"], batch["questions"]
|
96 |
-
locations_index =
|
97 |
print(questions)
|
98 |
for question in questions:
|
99 |
-
locations_index.
|
100 |
-
return locations_index
|
|
|
101 |
|
102 |
def _retrieval_places(self, query):
|
103 |
response = self.qa_baseline.invoke({"query": query})
|
104 |
location_index = response["source_documents"][0].metadata["seq_num"] - 1
|
105 |
-
print(location_index)
|
106 |
return location_index
|
107 |
|
108 |
|
@@ -114,8 +114,8 @@ class AgentIntegrationService:
|
|
114 |
date = datetime.datetime.now().isoformat()
|
115 |
print(context)
|
116 |
answer = generate_response_base_context(context.strip(), date=date, template=data['interpreter_web'])
|
117 |
-
print(answer)
|
118 |
-
return answer
|
119 |
|
120 |
|
121 |
def answer_normal(self, user_input: str, chat_history, model_name="gpt-3.5-turbo") -> str:
|
|
|
17 |
generate_response_base_context,
|
18 |
openai_client_answer
|
19 |
)
|
20 |
+
from api.schemas import OutputLocation
|
21 |
from .utils import get_data
|
22 |
+
from typing import List, Dict, Tuple, Union
|
23 |
|
24 |
# log = logger.get_logger(__name__)
|
25 |
|
|
|
56 |
return_source_documents=True,
|
57 |
)
|
58 |
|
59 |
+
async def answer_to_messages(self, query: List[list], chat_history: List[dict] | None, question_type: str) -> Union[str, OutputLocation | None]:
|
|
|
|
|
|
|
60 |
agent_using = self._route(question_type)
|
61 |
if agent_using == self.answer_based_search:
|
62 |
+
query_reformat = openai_client_answer(
|
63 |
+
query,
|
64 |
+
chat_history,
|
65 |
+
model="gpt-3.5-turbo",
|
66 |
+
template=data["rephrase_query"]
|
67 |
+
)
|
68 |
+
if query == 'not_need': query_reformat = query
|
69 |
answer = agent_using(query_reformat)
|
70 |
else:
|
71 |
answer = agent_using(query, chat_history)
|
72 |
|
73 |
+
text_message, location_message = "", None
|
74 |
+
if question_type in ["ask_personal", "search_web"]: text_message = answer
|
75 |
else:
|
76 |
+
text_message = f"Chào bạn, đây là một lịch trình tham quan thành phố Hồ Chí Minh trong một ngày mà bạn có thể tham khảo:"
|
77 |
+
location_message = OutputLocation(locations=answer, message=text_message)
|
78 |
+
return text_message, location_message
|
|
|
79 |
|
80 |
|
81 |
def answer_based_retrieval(self, query, chat_history):
|
|
|
93 |
else:
|
94 |
continue
|
95 |
num_questions, questions = batch["number_of_locations"], batch["questions"]
|
96 |
+
locations_index = set()
|
97 |
print(questions)
|
98 |
for question in questions:
|
99 |
+
locations_index.add(self._retrieval_places(question))
|
100 |
+
return list(locations_index)
|
101 |
+
|
102 |
|
103 |
def _retrieval_places(self, query):
|
104 |
response = self.qa_baseline.invoke({"query": query})
|
105 |
location_index = response["source_documents"][0].metadata["seq_num"] - 1
|
|
|
106 |
return location_index
|
107 |
|
108 |
|
|
|
114 |
date = datetime.datetime.now().isoformat()
|
115 |
print(context)
|
116 |
answer = generate_response_base_context(context.strip(), date=date, template=data['interpreter_web'])
|
117 |
+
# print(answer.content)
|
118 |
+
return answer.content
|
119 |
|
120 |
|
121 |
def answer_normal(self, user_input: str, chat_history, model_name="gpt-3.5-turbo") -> str:
|
src/api/routes.py
CHANGED
@@ -3,7 +3,8 @@ from fastapi import APIRouter, HTTPException, Depends
|
|
3 |
from fastapi.security import APIKeyHeader
|
4 |
|
5 |
from agents.models import service
|
6 |
-
from .
|
|
|
7 |
|
8 |
|
9 |
header_scheme = APIKeyHeader(name="X-API-Key", auto_error=False)
|
@@ -25,6 +26,7 @@ async def agents_root():
|
|
25 |
@router.post("/chat-agent")
|
26 |
async def chat_completion(
|
27 |
message: MessageInput,
|
|
|
28 |
api_key: APIKeyHeader = Depends(header_scheme)
|
29 |
|
30 |
) -> ChatAgentResponse:
|
@@ -53,23 +55,26 @@ async def chat_completion(
|
|
53 |
user_question = message["question"]
|
54 |
hist_messages = message["history"]
|
55 |
|
56 |
-
# log.info(f"User question: {user_question}")
|
57 |
-
# log.info(f"User message: {hist_messages}")
|
58 |
-
|
59 |
chat_history = []
|
60 |
if hist_messages and len(hist_messages[0]) > 0:
|
61 |
chat_history = hist_messages
|
62 |
-
# for mes in hist_messages:
|
63 |
-
# log.info(f"Conversation history message: {mes['question']} | {mes['answer']}")
|
64 |
-
# chat_history.append(HumanMessage(content=mes['question']))
|
65 |
-
# chat_history.append(AIMessage(content=mes['answer']))
|
66 |
|
67 |
-
|
|
|
|
|
68 |
|
69 |
-
|
70 |
# log.info(f"Agent response: {response['text']}, Sources: {response['ids_location']}")
|
71 |
|
72 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
|
75 |
|
|
|
3 |
from fastapi.security import APIKeyHeader
|
4 |
|
5 |
from agents.models import service
|
6 |
+
from agents_framework.tools import classifer_question_execution
|
7 |
+
from .schemas import MessageInput ,ChatAgentResponse, OutputLocation
|
8 |
|
9 |
|
10 |
header_scheme = APIKeyHeader(name="X-API-Key", auto_error=False)
|
|
|
26 |
@router.post("/chat-agent")
|
27 |
async def chat_completion(
|
28 |
message: MessageInput,
|
29 |
+
user_status,
|
30 |
api_key: APIKeyHeader = Depends(header_scheme)
|
31 |
|
32 |
) -> ChatAgentResponse:
|
|
|
55 |
user_question = message["question"]
|
56 |
hist_messages = message["history"]
|
57 |
|
|
|
|
|
|
|
58 |
chat_history = []
|
59 |
if hist_messages and len(hist_messages[0]) > 0:
|
60 |
chat_history = hist_messages
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
question_type = classify_question_type(user_question)
|
63 |
+
if question_type == "search_web" and user_status != "Premium":
|
64 |
+
return ChatAgentResponse(texts_message="Rất tiếc, hiện tại bạn chưa thể sử dụng tính năng này. Hãy nâng cấp lên gói Premium ngay hôm nay để có thể trải nghiệm không giới hạn và mở khóa những tính năng tuyệt vời khác cho chuyến đ cho mình!")
|
65 |
|
66 |
+
text_message, location_message = await service.answer_to_messages(query=user_question, chat_history=chat_history, question_type=question_type)
|
67 |
# log.info(f"Agent response: {response['text']}, Sources: {response['ids_location']}")
|
68 |
|
69 |
+
return ChatAgentResponse(texts_message=text_message, locations_message=location_message)
|
70 |
+
# return {"text": text, "ids_location": ids_location}
|
71 |
+
|
72 |
+
def classify_question_type(query: str):
|
73 |
+
|
74 |
+
question_type = classifer_question_execution.invoke(query).content.replace('`', "").replace("Tags", "").strip()
|
75 |
+
# log.info(f"question_type: {question_type}")
|
76 |
+
print(question_type)
|
77 |
+
return question_type
|
78 |
|
79 |
|
80 |
|
src/api/schemas.py
CHANGED
@@ -2,7 +2,7 @@ from typing import List
|
|
2 |
from pydantic import BaseModel
|
3 |
|
4 |
from langchain.docstore.document import Document
|
5 |
-
|
6 |
|
7 |
# # #########################################
|
8 |
# # Internal schemas
|
@@ -17,6 +17,11 @@ class MessageInput(BaseModel):
|
|
17 |
# # API schemas
|
18 |
# # #########################################
|
19 |
|
|
|
|
|
|
|
|
|
20 |
class ChatAgentResponse(BaseModel):
|
21 |
-
|
22 |
-
|
|
|
|
2 |
from pydantic import BaseModel
|
3 |
|
4 |
from langchain.docstore.document import Document
|
5 |
+
from typing import Optional
|
6 |
|
7 |
# # #########################################
|
8 |
# # Internal schemas
|
|
|
17 |
# # API schemas
|
18 |
# # #########################################
|
19 |
|
20 |
+
class OutputLocation(BaseModel):
|
21 |
+
locations: List[int]
|
22 |
+
message: str
|
23 |
+
|
24 |
class ChatAgentResponse(BaseModel):
|
25 |
+
texts_message: str | None
|
26 |
+
locations_message: Optional[None | OutputLocation] = None
|
27 |
+
|