pphuc25 commited on
Commit
22b4626
1 Parent(s): 241b3a2

feat: add Premium feature

Browse files
src/agents_framework/integrations.py CHANGED
@@ -17,8 +17,9 @@ from .tools import (
17
  generate_response_base_context,
18
  openai_client_answer
19
  )
 
20
  from .utils import get_data
21
- from typing import List
22
 
23
  # log = logger.get_logger(__name__)
24
 
@@ -55,27 +56,26 @@ class AgentIntegrationService:
55
  return_source_documents=True,
56
  )
57
 
58
- async def answer_to_messages(self, query: List[list], chat_history: List[dict] | None) -> str:
59
- question_type = classifer_question_execution.invoke(query).content.replace('`', "").replace("Tags", "").strip()
60
- # log.info(f"question_type: {question_type}")
61
- print(question_type)
62
  agent_using = self._route(question_type)
63
  if agent_using == self.answer_based_search:
64
- query_reformat = openai_client_answer(query, chat_history, model="gpt-3.5-turbo", template=data["rephrase_query"])
65
- if query == 'not_need':
66
- query_reformat = query
67
- print(query_reformat)
 
 
 
68
  answer = agent_using(query_reformat)
69
  else:
70
  answer = agent_using(query, chat_history)
71
 
72
- if question_type in ["ask_personal", "search_web"]:
73
- return {"text": answer.content, "ids_location": None}
74
  else:
75
- return {
76
- "text": f"Đây là lộ trình tôi nghĩ ra bao gồm {len(set(answer))} điểm du lịch",
77
- "ids_location": list(set(answer))
78
- }
79
 
80
 
81
  def answer_based_retrieval(self, query, chat_history):
@@ -93,16 +93,16 @@ class AgentIntegrationService:
93
  else:
94
  continue
95
  num_questions, questions = batch["number_of_locations"], batch["questions"]
96
- locations_index = []
97
  print(questions)
98
  for question in questions:
99
- locations_index.append(self._retrieval_places(question))
100
- return locations_index
 
101
 
102
  def _retrieval_places(self, query):
103
  response = self.qa_baseline.invoke({"query": query})
104
  location_index = response["source_documents"][0].metadata["seq_num"] - 1
105
- print(location_index)
106
  return location_index
107
 
108
 
@@ -114,8 +114,8 @@ class AgentIntegrationService:
114
  date = datetime.datetime.now().isoformat()
115
  print(context)
116
  answer = generate_response_base_context(context.strip(), date=date, template=data['interpreter_web'])
117
- print(answer)
118
- return answer
119
 
120
 
121
  def answer_normal(self, user_input: str, chat_history, model_name="gpt-3.5-turbo") -> str:
 
17
  generate_response_base_context,
18
  openai_client_answer
19
  )
20
+ from api.schemas import OutputLocation
21
  from .utils import get_data
22
+ from typing import List, Dict, Tuple, Union
23
 
24
  # log = logger.get_logger(__name__)
25
 
 
56
  return_source_documents=True,
57
  )
58
 
59
+ async def answer_to_messages(self, query: List[list], chat_history: List[dict] | None, question_type: str) -> Tuple[str, OutputLocation | None]:
 
 
 
60
  agent_using = self._route(question_type)
61
  if agent_using == self.answer_based_search:
62
+ query_reformat = openai_client_answer(
63
+ query,
64
+ chat_history,
65
+ model="gpt-3.5-turbo",
66
+ template=data["rephrase_query"]
67
+ )
68
+ if query_reformat == 'not_need': query_reformat = query
69
  answer = agent_using(query_reformat)
70
  else:
71
  answer = agent_using(query, chat_history)
72
 
73
+ text_message, location_message = "", None
74
+ if question_type in ["ask_personal", "search_web"]: text_message = answer
75
  else:
76
+ text_message = f"Chào bạn, đây là một lịch trình tham quan thành phố Hồ Chí Minh trong một ngày mà bạn có thể tham khảo:"
77
+ location_message = OutputLocation(locations=answer, message=text_message)
78
+ return text_message, location_message
 
79
 
80
 
81
  def answer_based_retrieval(self, query, chat_history):
 
93
  else:
94
  continue
95
  num_questions, questions = batch["number_of_locations"], batch["questions"]
96
+ locations_index = set()
97
  print(questions)
98
  for question in questions:
99
+ locations_index.add(self._retrieval_places(question))
100
+ return list(locations_index)
101
+
102
 
103
  def _retrieval_places(self, query):
104
  response = self.qa_baseline.invoke({"query": query})
105
  location_index = response["source_documents"][0].metadata["seq_num"] - 1
 
106
  return location_index
107
 
108
 
 
114
  date = datetime.datetime.now().isoformat()
115
  print(context)
116
  answer = generate_response_base_context(context.strip(), date=date, template=data['interpreter_web'])
117
+ # print(answer.content)
118
+ return answer.content
119
 
120
 
121
  def answer_normal(self, user_input: str, chat_history, model_name="gpt-3.5-turbo") -> str:
src/api/routes.py CHANGED
@@ -3,7 +3,8 @@ from fastapi import APIRouter, HTTPException, Depends
3
  from fastapi.security import APIKeyHeader
4
 
5
  from agents.models import service
6
- from .schemas import MessageInput ,ChatAgentResponse
 
7
 
8
 
9
  header_scheme = APIKeyHeader(name="X-API-Key", auto_error=False)
@@ -25,6 +26,7 @@ async def agents_root():
25
  @router.post("/chat-agent")
26
  async def chat_completion(
27
  message: MessageInput,
 
28
  api_key: APIKeyHeader = Depends(header_scheme)
29
 
30
  ) -> ChatAgentResponse:
@@ -53,23 +55,26 @@ async def chat_completion(
53
  user_question = message["question"]
54
  hist_messages = message["history"]
55
 
56
- # log.info(f"User question: {user_question}")
57
- # log.info(f"User message: {hist_messages}")
58
-
59
  chat_history = []
60
  if hist_messages and len(hist_messages[0]) > 0:
61
  chat_history = hist_messages
62
- # for mes in hist_messages:
63
- # log.info(f"Conversation history message: {mes['question']} | {mes['answer']}")
64
- # chat_history.append(HumanMessage(content=mes['question']))
65
- # chat_history.append(AIMessage(content=mes['answer']))
66
 
67
- # log.info(f"User Question: {user_question}")
 
 
68
 
69
- response = await service.answer_to_messages(query=user_question, chat_history=chat_history)
70
  # log.info(f"Agent response: {response['text']}, Sources: {response['ids_location']}")
71
 
72
- return response
 
 
 
 
 
 
 
 
73
 
74
 
75
 
 
3
  from fastapi.security import APIKeyHeader
4
 
5
  from agents.models import service
6
+ from agents_framework.tools import classifer_question_execution
7
+ from .schemas import MessageInput, ChatAgentResponse, OutputLocation
8
 
9
 
10
  header_scheme = APIKeyHeader(name="X-API-Key", auto_error=False)
 
26
  @router.post("/chat-agent")
27
  async def chat_completion(
28
  message: MessageInput,
29
+ user_status,
30
  api_key: APIKeyHeader = Depends(header_scheme)
31
 
32
  ) -> ChatAgentResponse:
 
55
  user_question = message["question"]
56
  hist_messages = message["history"]
57
 
 
 
 
58
  chat_history = []
59
  if hist_messages and len(hist_messages[0]) > 0:
60
  chat_history = hist_messages
 
 
 
 
61
 
62
+ question_type = classify_question_type(user_question)
63
+ if question_type == "search_web" and user_status != "Premium":
64
+ return ChatAgentResponse(texts_message="Rất tiếc, hiện tại bạn chưa thể sử dụng tính năng này. Hãy nâng cấp lên gói Premium ngay hôm nay để có thể trải nghiệm không giới hạn và mở khóa những tính năng tuyệt vời khác cho chuyến đi của mình!")
65
 
66
+ text_message, location_message = await service.answer_to_messages(query=user_question, chat_history=chat_history, question_type=question_type)
67
  # log.info(f"Agent response: {response['text']}, Sources: {response['ids_location']}")
68
 
69
+ return ChatAgentResponse(texts_message=text_message, locations_message=location_message)
70
+ # return {"text": text, "ids_location": ids_location}
71
+
72
+ def classify_question_type(query: str):
73
+
74
+ question_type = classifer_question_execution.invoke(query).content.replace('`', "").replace("Tags", "").strip()
75
+ # log.info(f"question_type: {question_type}")
76
+ print(question_type)
77
+ return question_type
78
 
79
 
80
 
src/api/schemas.py CHANGED
@@ -2,7 +2,7 @@ from typing import List
2
  from pydantic import BaseModel
3
 
4
  from langchain.docstore.document import Document
5
-
6
 
7
  # # #########################################
8
  # # Internal schemas
@@ -17,6 +17,11 @@ class MessageInput(BaseModel):
17
  # # API schemas
18
  # # #########################################
19
 
 
 
 
 
20
  class ChatAgentResponse(BaseModel):
21
- text: str | None
22
- ids_location: List[int] | None
 
 
2
  from pydantic import BaseModel
3
 
4
  from langchain.docstore.document import Document
5
+ from typing import Optional
6
 
7
  # # #########################################
8
  # # Internal schemas
 
17
  # # API schemas
18
  # # #########################################
19
 
20
+ class OutputLocation(BaseModel):
21
+ locations: List[int]
22
+ message: str
23
+
24
  class ChatAgentResponse(BaseModel):
25
+ texts_message: str | None
26
+ locations_message: Optional[OutputLocation] = None
27
+