pphuc25 commited on
Commit
9121b91
1 Parent(s): 519f5e0

chore: add handling for multiple users

Browse files
src/agents_framework/integrations.py CHANGED
@@ -12,10 +12,11 @@ import datetime
12
 
13
  from .tools import (
14
  load_retrieval_tool,
15
- classifer_question_execution,
16
  search_searxng_engine,
17
  generate_response_base_context,
18
- openai_client_answer
 
 
19
  )
20
  from api.schemas import OutputLocation
21
  from .utils import get_data
@@ -26,18 +27,6 @@ from typing import List, Dict, Tuple, Union
26
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
27
 
28
 
29
- paths = {
30
- "question_type_classification": "data/prompts/question_type_classification.txt",
31
- "questions_generation": "data/prompts/questions_generation.txt",
32
- "normal_conversation": "data/prompts/normal_conversation.txt",
33
- "rephrase_query": "data/prompts/rephrase_query.txt",
34
- "interpreter_web": "data/prompts/interpreter_web.txt",
35
-
36
- }
37
-
38
- # Load the data
39
- data = get_data(paths)
40
-
41
 
42
 
43
  class AgentIntegrationService:
@@ -63,7 +52,7 @@ class AgentIntegrationService:
63
  query,
64
  chat_history,
65
  model="gpt-3.5-turbo",
66
- template=data["rephrase_query"]
67
  )
68
  if query == 'not_need': query_reformat = query
69
  answer = agent_using(query_reformat)
@@ -114,34 +103,37 @@ class AgentIntegrationService:
114
  context += result['snippet'] + "\n"
115
  date = datetime.datetime.now().isoformat()
116
  print(context)
117
- answer = generate_response_base_context(context.strip(), date=date, template=data['interpreter_web'])
118
  # print(answer.content)
119
  return answer.content
120
 
121
 
122
  def answer_normal(self, user_input: str, chat_history, model_name="gpt-3.5-turbo") -> str:
123
  """Extract the token name or name from the user input"""
124
- set_conversation = [("system", data["normal_conversation"])]
 
125
  if chat_history and len(chat_history) > 0:
126
- for history in chat_history:
127
- set_conversation.append(("human", history['question']))
128
- set_conversation.append(("system", history['answer']))
129
- set_conversation.append(("human", "{input}"))
130
- final_prompt = ChatPromptTemplate.from_messages(set_conversation) | ChatOpenAI(model=model_name, temperature=0)
131
- return final_prompt.invoke(user_input)
 
 
132
 
133
 
134
  def create_questions(self, user_input: str, chat_history) -> str:
135
  """Create questions from the user input"""
136
- set_conversation = [("system", data['questions_generation'])]
137
  if chat_history:
138
  for mes in chat_history:
139
- set_conversation.append(("human", mes['question']))
140
- set_conversation.append(("ai", mes['answer']))
141
- set_conversation.append(("human", "{input}"))
 
 
142
 
143
- create_questions_execution = ChatPromptTemplate.from_messages(set_conversation) | ChatOpenAI(model="gpt-4o", temperature=0)
144
- answer = create_questions_execution.invoke({"input": user_input}).content
145
  print(answer)
146
  data_loaded = yaml.safe_load(answer.replace("```", "").replace("yaml", ""))
147
  return data_loaded
 
12
 
13
  from .tools import (
14
  load_retrieval_tool,
 
15
  search_searxng_engine,
16
  generate_response_base_context,
17
+ openai_client_answer,
18
+ data_prompts,
19
+ create_execution_w_custom_message
20
  )
21
  from api.schemas import OutputLocation
22
  from .utils import get_data
 
27
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  class AgentIntegrationService:
 
52
  query,
53
  chat_history,
54
  model="gpt-3.5-turbo",
55
+ template=data_prompts["rephrase_query"]
56
  )
57
  if query == 'not_need': query_reformat = query
58
  answer = agent_using(query_reformat)
 
103
  context += result['snippet'] + "\n"
104
  date = datetime.datetime.now().isoformat()
105
  print(context)
106
+ answer = generate_response_base_context(context.strip(), date=date, template=data_prompts['interpreter_web'])
107
  # print(answer.content)
108
  return answer.content
109
 
110
 
111
  def answer_normal(self, user_input: str, chat_history, model_name="gpt-3.5-turbo") -> str:
112
  """Extract the token name or name from the user input"""
113
+
114
+ messages = [{"role": "system", "content": data_prompts['normal_conversation']}]
115
  if chat_history and len(chat_history) > 0:
116
+ for mes in chat_history:
117
+ messages.append({"role": "user", "content": mes['question']})
118
+ messages.append({"role": "assistant", "content": mes['answer']})
119
+ messages.append({"role": "user", "content": user_input})
120
+
121
+ answer = create_execution_w_custom_message(messages).content
122
+
123
+ return answer
124
 
125
 
126
  def create_questions(self, user_input: str, chat_history) -> str:
127
  """Create questions from the user input"""
128
+ messages = [{"role": "system", "content": data_prompts['questions_generation']}]
129
  if chat_history:
130
  for mes in chat_history:
131
+ messages.append({"role": "user", "content": mes['question']})
132
+ messages.append({"role": "assistant", "content": mes['answer']})
133
+ messages.append({"role": "user", "content": user_input})
134
+
135
+ answer = create_execution_w_custom_message(messages).content
136
 
 
 
137
  print(answer)
138
  data_loaded = yaml.safe_load(answer.replace("```", "").replace("yaml", ""))
139
  return data_loaded
src/agents_framework/tools.py CHANGED
@@ -13,6 +13,12 @@ from langchain.retrievers import ParentDocumentRetriever
13
  from langchain_core.prompts import ChatPromptTemplate
14
  from langchain_community.utilities import SearxSearchWrapper
15
 
 
 
 
 
 
 
16
  from .utils import get_data
17
 
18
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -33,18 +39,31 @@ PROJECT_ROOT = "./" # insert your project root directory name here
33
 
34
 
35
  paths = {
 
36
  "question_type_classification": "data/prompts/question_type_classification.txt",
37
  "questions_generation": "data/prompts/questions_generation.txt",
38
- "interpreter_web": "data/prompts/interpreter_web.txt"
 
39
  }
40
 
41
  # Load the data
42
- data = get_data(paths)
43
  current_date = datetime.datetime.now().isoformat()
44
 
45
  # Initial search api
46
  searxng_api = SearxSearchWrapper(searx_host=SEARXNG_PORT, k = 20)
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  def load_retrieval_tool(embedding_model):
49
  parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
50
  child_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=0)
@@ -68,6 +87,7 @@ def load_retrieval_tool(embedding_model):
68
  return retriever_tool
69
 
70
 
 
71
  def search_searxng_engine(query, engines: list=["google"], enabled_engines: list=["google", "apple_maps"]):
72
  results = searxng_api.results(
73
  query,
@@ -80,6 +100,7 @@ def search_searxng_engine(query, engines: list=["google"], enabled_engines: list
80
  return results
81
 
82
 
 
83
  def openai_client_answer(query, history=None, model="gpt-3.5-turbo", template=""):
84
  chat_completion = openai_client.chat.completions.create(
85
  messages=[
@@ -93,6 +114,7 @@ def openai_client_answer(query, history=None, model="gpt-3.5-turbo", template=""
93
  return chat_completion.choices[0].message.content
94
 
95
 
 
96
  def generate_response_base_context(context, date, template, model="gpt-4o"):
97
  chat_completion = openai_client.chat.completions.create(
98
  messages=[
@@ -106,20 +128,42 @@ def generate_response_base_context(context, date, template, model="gpt-4o"):
106
  return chat_completion.choices[0].message
107
 
108
 
109
- classifer_question_execution = ChatPromptTemplate.from_messages(
110
- [
111
- ("system", data['question_type_classification']),
112
- ("human", "{input}"),
113
- ]
114
- ) | ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
- create_questions_execution = ChatPromptTemplate.from_messages(
118
- [
119
- ("system", data['questions_generation']),
120
- ("human", "{input}"),
121
- ]
122
- ) | ChatOpenAI(model="gpt-4o", temperature=0)
123
 
124
 
125
  if __name__ == "__main__":
@@ -134,7 +178,7 @@ if __name__ == "__main__":
134
 
135
 
136
  def generate_response(context: str) -> str:
137
- prompt = data["interpreter_web"].format(context=context, date=current_date)
138
  print(prompt)
139
  location_answer_execution = ChatPromptTemplate.from_messages(
140
  [
 
13
  from langchain_core.prompts import ChatPromptTemplate
14
  from langchain_community.utilities import SearxSearchWrapper
15
 
16
+ from tenacity import (
17
+ retry,
18
+ stop_after_attempt,
19
+ wait_random_exponential,
20
+ ) # for exponential backoff
21
+
22
  from .utils import get_data
23
 
24
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
39
 
40
 
41
  paths = {
42
+ "normal_conversation": "data/prompts/normal_conversation.txt",
43
  "question_type_classification": "data/prompts/question_type_classification.txt",
44
  "questions_generation": "data/prompts/questions_generation.txt",
45
+ "interpreter_web": "data/prompts/interpreter_web.txt",
46
+ "rephrase_query": "data/prompts/rephrase_query.txt",
47
  }
48
 
49
  # Load the data
50
+ data_prompts = get_data(paths)
51
  current_date = datetime.datetime.now().isoformat()
52
 
53
  # Initial search api
54
  searxng_api = SearxSearchWrapper(searx_host=SEARXNG_PORT, k = 20)
55
 
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
  def load_retrieval_tool(embedding_model):
68
  parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
69
  child_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=0)
 
87
  return retriever_tool
88
 
89
 
90
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
91
  def search_searxng_engine(query, engines: list=["google"], enabled_engines: list=["google", "apple_maps"]):
92
  results = searxng_api.results(
93
  query,
 
100
  return results
101
 
102
 
103
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
104
  def openai_client_answer(query, history=None, model="gpt-3.5-turbo", template=""):
105
  chat_completion = openai_client.chat.completions.create(
106
  messages=[
 
114
  return chat_completion.choices[0].message.content
115
 
116
 
117
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
118
  def generate_response_base_context(context, date, template, model="gpt-4o"):
119
  chat_completion = openai_client.chat.completions.create(
120
  messages=[
 
128
  return chat_completion.choices[0].message
129
 
130
 
131
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
132
+ def create_execution_wout_edit_template(input_text, template, model="gpt-3.5-turbo"):
133
+ chat_completion = openai_client.chat.completions.create(
134
+ messages=[
135
+ {"role": "system", "content": template},
136
+ {"role": "user", "content": input_text}
137
+ ],
138
+ model=model,
139
+ temperature=0
140
+ )
141
+ return chat_completion.choices[0].message
142
+
143
+
144
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
145
+ def create_execution_w_custom_message(messages: list, model="gpt-3.5-turbo"):
146
+ chat_completion = openai_client.chat.completions.create(
147
+ messages=messages,
148
+ model=model,
149
+ temperature=0
150
+ )
151
+ return chat_completion.choices[0].message
152
+
153
+ # classifer_question_execution = ChatPromptTemplate.from_messages(
154
+ # [
155
+ # ("system", data['question_type_classification']),
156
+ # ("human", "{input}"),
157
+ # ]
158
+ # ) | ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
159
 
160
 
161
+ # create_questions_execution = ChatPromptTemplate.from_messages(
162
+ # [
163
+ # ("system", data_prompts['questions_generation']),
164
+ # ("human", "{input}"),
165
+ # ]
166
+ # ) | ChatOpenAI(model="gpt-4o", temperature=0)
167
 
168
 
169
  if __name__ == "__main__":
 
178
 
179
 
180
  def generate_response(context: str) -> str:
181
+ prompt = data_prompts["interpreter_web"].format(context=context, date=current_date)
182
  print(prompt)
183
  location_answer_execution = ChatPromptTemplate.from_messages(
184
  [
src/api/routes.py CHANGED
@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException, Depends
3
  from fastapi.security import APIKeyHeader
4
 
5
  from agents.models import service
6
- from agents_framework.tools import classifer_question_execution
7
  from .schemas import MessageInput ,ChatAgentResponse, OutputLocation
8
 
9
 
@@ -70,9 +70,8 @@ async def chat_completion(
70
  # return {"text": text, "ids_location": ids_location}
71
 
72
  def classify_question_type(query: str):
73
-
74
- question_type = classifer_question_execution.invoke(query).content.replace('`', "").replace("Tags", "").strip()
75
- # log.info(f"question_type: {question_type}")
76
  print(question_type)
77
  return question_type
78
 
 
3
  from fastapi.security import APIKeyHeader
4
 
5
  from agents.models import service
6
+ from agents_framework.tools import create_execution_wout_edit_template, data_prompts
7
  from .schemas import MessageInput ,ChatAgentResponse, OutputLocation
8
 
9
 
 
70
  # return {"text": text, "ids_location": ids_location}
71
 
72
  def classify_question_type(query: str):
73
+ question_type = create_execution_wout_edit_template(query, data_prompts["question_type_classification"]).content.replace('`', "").replace("Tags", "").strip()
74
+ # question_type = classifer_question_execution.invoke(query).content.replace('`', "").replace("Tags", "").strip()
 
75
  print(question_type)
76
  return question_type
77