SungBeom commited on
Commit
57b936b
โ€ข
1 Parent(s): 34d8557

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +581 -130
app.py CHANGED
@@ -1,23 +1,20 @@
1
  import os
2
  import configparser
3
 
4
- # config = configparser.ConfigParser()
5
- # config.read('./secrets.ini')
6
-
7
- # openai_api_key = config['OPENAI']['OPENAI_API_KEY']
8
- # serper_api_key = config['SERPER']['SERPER_API_KEY']
9
- # serp_api_key = config['SERPAPI']['SERPAPI_API_KEY']
10
- # os.environ.update({'OPENAI_API_KEY': openai_api_key})
11
- # os.environ.update({'SERPER_API_KEY': serper_api_key})
12
- # os.environ.update({'SERPAPI_API_KEY': serp_api_key})
13
-
14
- from typing import List, Union
15
  import re
 
 
16
  import json
 
 
 
 
17
 
18
  import pandas as pd
19
  from langchain import SerpAPIWrapper, LLMChain
20
  from langchain.agents import Tool, AgentType, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
 
21
  from langchain.chat_models import ChatOpenAI
22
  from langchain.chains import LLMChain, SimpleSequentialChain
23
  from langchain.chains.query_constructor.base import AttributeInfo
@@ -28,37 +25,58 @@ from langchain.prompts import PromptTemplate, StringPromptTemplate, load_prompt,
28
  from langchain.llms import OpenAI
29
  from langchain.retrievers.self_query.base import SelfQueryRetriever
30
  from langchain.schema import AgentAction, AgentFinish, HumanMessage
31
- from langchain.vectorstores import DocArrayInMemorySearch, Chroma
32
-
33
- stage_analyzer_inception_prompt = load_prompt("./templates/stage_analyzer_inception_prompt_template.json")
34
- llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0.0)
35
- stage_analyzer_chain = LLMChain(
36
- llm=llm,
37
- prompt=stage_analyzer_inception_prompt,
38
- verbose=False,
39
- output_key="stage_number")
40
-
41
- user_response_prompt = load_prompt("./templates/user_response_prompt.json")
42
- llm = ChatOpenAI(model='gpt-4', temperature=0.5)
43
- user_response_chain = LLMChain(
44
- llm=llm,
45
- prompt=user_response_prompt,
46
- verbose=False,
47
- output_key="user_responses"
48
- )
49
 
50
- df = pd.read_json('./data/unified_wine_data.json', encoding='utf-8', lines=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- loader =DataFrameLoader(data_frame=df, page_content_column='name')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  docs = loader.load()
54
  embeddings = OpenAIEmbeddings()
55
-
56
  metadata_field_info = [
57
  AttributeInfo(
58
  name="body",
59
  description="1-5 rating for the body of wine",
60
  type="int",
61
  ),
 
 
 
 
 
62
  AttributeInfo(
63
  name="sweetness",
64
  description="1-5 rating for the sweetness of wine",
@@ -91,57 +109,164 @@ metadata_field_info = [
91
  ),
92
  ]
93
 
94
- vectorstore = Chroma.from_documents(docs, embeddings)
95
- document_content_description = "Database of a wine"
96
  llm = OpenAI(temperature=0)
97
- retriever = SelfQueryRetriever.from_llm(
98
- llm, vectorstore, document_content_description, metadata_field_info, verbose=False
99
  ) # Added missing closing parenthesis
100
 
101
- def search_with_url(query):
102
- return SeleniumURLLoader(urls=[query]).load()
 
 
 
 
 
 
103
 
104
- index = VectorstoreIndexCreator(
105
- vectorstore_cls=DocArrayInMemorySearch
106
- ).from_loaders([loader])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
 
 
 
 
 
 
108
  search = SerpAPIWrapper()
109
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  tools = [
111
  Tool(
112
  name="Wine database",
113
- func=retriever.get_relevant_documents,
 
114
  description="""
115
- Database about the wines in wine store. You can get information such as the price of the wine, purchase URL, features, rating information, and more.
116
  You can search wines with the following attributes:
117
- - body: 1-5 rating int for the body of wine. You have to specify greater than or less than. For example, if you want to search for wines with a body rating of less than 3, enter 'body: gt 0 lt 3'
118
- - price: The price range of the wine. Please enter the price range in the form of range. For example, if you want to search for wines that cost less than 20,000 won, enter 'price: gt 0 lt20000'
119
- - rating: 1-5 rating float for the wine. You have to specify greater than or less than. For example, if you want to search for wines with a rating of less than 3, enter 'rating: gt 0 lt 3'
120
  - wine_type: The type of wine. It can be '๋ ˆ๋“œ', '๋กœ์ œ', '์ŠคํŒŒํด๋ง', 'ํ™”์ดํŠธ', '๋””์ €ํŠธ', '์ฃผ์ •๊ฐ•ํ™”'
121
- - name: The name of wine. ์ž…๋ ฅํ•  ๋•Œ๋Š” '์™€์ธ ์ด๋ฆ„์€ "๋น„๋ƒ ์กฐ์ž˜" ์ž…๋‹ˆ๋‹ค' ์ด๋Ÿฐ ์‹์œผ๋กœ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  """
123
  ),
124
- # Tool(
125
- # name = "Search specific wine with url",
126
- # func=search_with_url,
127
- # description="Search specific wine with url. Query must be url"
128
- # ),
129
  Tool(
130
- name = "Wine database 2",
131
- func=index.query,
132
- description="Database about the wines in wine store. You can use this tool if you're having trouble getting information from the wine database tool above. Query must be in String"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  ),
134
  Tool(
135
  name = "Search",
136
  func=search.run,
 
137
  description="Useful for when you need to ask with search. Search in English only."
138
  ),
 
 
 
 
 
 
139
  ]
140
-
141
  template = """
142
  Your role is a chatbot that asks customers questions about wine and makes recommendations.
143
  Never forget your name is "์ด์šฐ์„ ".
144
- Keep your responses in short length to retain the user's attention.
 
145
  Only generate one response at a time! When you are done generating, end with '<END_OF_TURN>' to give the user a chance to respond.
146
  Responses should be in Korean.
147
 
@@ -153,6 +278,7 @@ Use the following format:
153
  Thought: you should always think about what to do
154
  Action: the action to take, should be one of [{tool_names}]
155
  Action Input: the input to the action
 
156
  Observation: the result of the action
157
  ... (this Thought/Action/Action Input/Observation can repeat N times)
158
  Thought: I now know the final answer
@@ -173,14 +299,15 @@ Last user saying: {input}
173
  """
174
 
175
  conversation_stages_dict = {
176
- "1": "Start: Start the conversation by introducing yourself. Be polite and respectful while maintaining a professional tone of conversation.",
177
- "2": "Analyze: Figuring out the customer's preferences in order to make wine recommendations. Ask questions to figure out the preferences of your customer in order to make wine recommendations. Ask only one question at a time. The wine database tool is not available here.",
178
- "3": "Product Recommendation: Recommend the right wine based on the user's preferences identified. Recommendations must be limited to wines in wine database, and you can use tools to do this. After making a wine recommendation, it asks if the user likes the wine you recommended.",
179
- "4": "Sales: If the customer wants to get the wine you recommended, provides a link and image of wine and price of it. Link you provide must be open in new tab when clicked. Otherwise, it takes you back to the recommendation stage.",
180
- "5": "Close: When you're done, say goodbye to the customer.",
181
- "6": "Question and Answering: This is where you answer the customer's questions. To answer customer question, you can use the search tool or the wine database tool.",
182
- "7": "Place Recommendation: Recommend wine bar based on location. you must be ask the location the customer wants before recommend the wine bar",
183
- "8": "Not in the given steps: This step is for when none of the steps between 1 and 7 apply.",
 
184
  }
185
 
186
  # Set up a prompt template
@@ -197,8 +324,42 @@ class CustomPromptTemplate(StringPromptTemplate):
197
  # Format them in a particular way
198
  intermediate_steps = kwargs.pop("intermediate_steps")
199
  thoughts = ""
 
200
  for action, observation in intermediate_steps:
201
  thoughts += action.log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  thoughts += f"\nObservation: {observation}\nThought: "
203
  # Set the agent_scratchpad variable to that value
204
  kwargs["agent_scratchpad"] = thoughts
@@ -208,12 +369,12 @@ class CustomPromptTemplate(StringPromptTemplate):
208
  kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
209
  return self.template.format(**kwargs)
210
 
 
211
  prompt = CustomPromptTemplate(
212
  template=template,
213
  tools=tools,
214
  input_variables=["input", "intermediate_steps", "conversation_history", "stage_number"]
215
  )
216
-
217
  class CustomOutputParser(AgentOutputParser):
218
 
219
  def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
@@ -226,39 +387,263 @@ class CustomOutputParser(AgentOutputParser):
226
  log=llm_output,
227
  )
228
  # Parse out the action and action input
229
- regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
230
  match = re.search(regex, llm_output, re.DOTALL)
231
  if not match:
232
  raise ValueError(f"Could not parse LLM output: `{llm_output}`")
233
  action = match.group(1).strip()
234
  action_input = match.group(2)
 
 
235
  # Return the action and action input
236
  return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
237
 
238
  output_parser = CustomOutputParser()
239
 
240
- llm_chain = LLMChain(llm=ChatOpenAI(model='gpt-4', temperature=0.0), prompt=prompt, verbose=False,)
241
-
242
- tool_names = [tool.name for tool in tools]
243
- agent = LLMSingleActionAgent(
244
- llm_chain=llm_chain,
245
- output_parser=output_parser,
246
- stop=["\nObservation:"],
247
- allowed_tools=tool_names
248
- )
249
-
250
- agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- # user_response, stage_history, conversation_history, pre_conversation_history = "", "", """""", """"""
 
 
255
 
256
- stage_description = ""
257
- for key, value in conversation_stages_dict.items():
258
- stage_description += f"{key}.{value}\n"
259
 
260
  with gr.Blocks(css='#chatbot .overflow-y-auto{height:750px}') as demo:
261
-
262
  with gr.Row():
263
  gr.HTML("""<div style="text-align: center; max-width: 500px; margin: 0 auto;">
264
  <div>
@@ -268,52 +653,118 @@ with gr.Blocks(css='#chatbot .overflow-y-auto{height:750px}') as demo:
268
  LinkedIn <a href="https://www.linkedin.com/company/audrey-ai/about/">Audrey.ai</a>
269
  </p>
270
  </div>""")
271
-
272
  chatbot = gr.Chatbot()
273
- msg = gr.Textbox(label='User input')
274
- samples = [["์ด๋ฒˆ ์ฃผ์— ์นœ๊ตฌ๋“ค๊ณผ ๋ชจ์ž„์ด ์žˆ๋Š”๋ฐ, ํ›Œ๋ฅญํ•œ ์™€์ธ ํ•œ ๋ณ‘์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["์ž…๋ฌธ์ž์—๊ฒŒ ์ข‹์€ ์™€์ธ์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["๋ณด๋ฅด๋„์™€ ๋ถ€๋ฅด๊ณ ๋‰ด ์™€์ธ์˜ ์ฐจ์ด์ ์€ ๋ญ์•ผ?"]]
275
- user_response_examples = gr.Dataset(samples=samples, components=[msg], type="index")
276
- stage_history = gr.Textbox(value="stage history: ", interactive=False, label='stage history')
277
- submit_btn = gr.Button("์ „์†ก")
 
 
 
278
  clear_btn = gr.ClearButton([msg, chatbot])
279
- stage_info = gr.Textbox(value=stage_description, interactive=False, label='stage description')
280
-
281
- def load_example(example_id):
282
- global samples
283
- return samples[example_id][0]
284
-
285
- def answer(user_response, chat_history, stage_history):
286
- global samples
287
- chat_history = chat_history or []
288
- stage_history = stage_history or ""
289
- pre_conversation_history = ""
290
- for idx, chat in enumerate(chat_history):
291
- pre_conversation_history += f"User: {chat[0]} <END_OF_TURN>\n"
292
- pre_conversation_history += f"์ด์šฐ์„ : {chat[1]} <END_OF_TURN>\n"
293
- conversation_history = pre_conversation_history + f"User: {user_response} <END_OF_TURN>\n"
294
- stage_number = stage_analyzer_chain.run({'conversation_history': conversation_history, 'stage_history': stage_history})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  stage_number = stage_number[-1]
296
  stage_history += stage_number if stage_history == "stage history: " else ", " + stage_number
297
- response = agent_executor.run({'input':user_response, 'conversation_history': pre_conversation_history, 'stage_number': stage_number})
298
- conversation_history += "์ด์šฐ์„ : " + response + "\n"
299
- for line in conversation_history.split('\n'):
300
- print(line)
301
- response = response.split('<END_OF_TURN>')[0]
302
- chat_history.append((user_response, response))
303
- user_response_examples = []
304
- for user_response_example in user_response_chain.run({'conversation_history': conversation_history}).split('|'):
305
- user_response_examples.append([user_response_example])
306
- samples = user_response_examples
307
-
308
- return "", chat_history, stage_history, gr.Dataset.update(samples=samples)
309
-
310
- def clear(*args):
311
- global samples
312
- samples = [["์ด๋ฒˆ ์ฃผ์— ์นœ๊ตฌ๋“ค๊ณผ ๋ชจ์ž„์ด ์žˆ๋Š”๋ฐ, ํ›Œ๋ฅญํ•œ ์™€์ธ ํ•œ ๋ณ‘์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["์ž…๋ฌธ์ž์—๊ฒŒ ์ข‹์€ ์™€์ธ์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["๋ณด๋ฅด๋„์™€ ๋ถ€๋ฅด๊ณ ๋‰ด ์™€์ธ์˜ ์ฐจ์ด์ ์€ ๋ญ์•ผ?"]]
313
- return gr.Dataset.update(samples=samples), "stage history: "
314
-
315
- clear_btn.click(fn=clear, inputs=[user_response_examples, stage_history], outputs=[user_response_examples, stage_history])
316
- user_response_examples.click(load_example, inputs=[user_response_examples], outputs=[msg])
317
- submit_btn.click(answer, [msg, chatbot, stage_history], [msg, chatbot, stage_history, user_response_examples])
318
- msg.submit(answer, [msg, chatbot, stage_history], [msg, chatbot, stage_history, user_response_examples])
319
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import configparser
3
 
4
+ from typing import List, Union, Optional, Any, Dict
 
 
 
 
 
 
 
 
 
 
5
  import re
6
+ import sys
7
+ import time
8
  import json
9
+ import asyncio
10
+ import aiohttp
11
+ import requests
12
+ import threading
13
 
14
  import pandas as pd
15
  from langchain import SerpAPIWrapper, LLMChain
16
  from langchain.agents import Tool, AgentType, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
17
+ from langchain.callbacks.streaming_stdout_final_only import FinalStreamingStdOutCallbackHandler
18
  from langchain.chat_models import ChatOpenAI
19
  from langchain.chains import LLMChain, SimpleSequentialChain
20
  from langchain.chains.query_constructor.base import AttributeInfo
 
25
  from langchain.llms import OpenAI
26
  from langchain.retrievers.self_query.base import SelfQueryRetriever
27
  from langchain.schema import AgentAction, AgentFinish, HumanMessage
28
+ from langchain.vectorstores import Chroma
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ import gradio as gr
31
+
32
+ # config = configparser.ConfigParser()
33
+ # config.read('./secrets.ini')
34
+ # openai_api_key = config['OPENAI']['OPENAI_API_KEY']
35
+ # serper_api_key = config['SERPER']['SERPER_API_KEY']
36
+ # serp_api_key = config['SERPAPI']['SERPAPI_API_KEY']
37
+ # kakao_api_key = config['KAKAO_MAP']['KAKAO_API_KEY']
38
+ # huggingface_token = config['HUGGINGFACE']['HUGGINGFACE_TOKEN']
39
+
40
+
41
+ # os.environ.update({'OPENAI_API_KEY': openai_api_key})
42
+ # os.environ.update({'SERPER_API_KEY': serper_api_key})
43
+ # os.environ.update({'SERPAPI_API_KEY': serp_api_key})
44
+ # os.environ.update({'KAKAO_API_KEY': kakao_api_key})
45
+ # os.environ.update({'HUGGINGFACE_TOKEN': huggingface_token})
46
+
47
+ huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
48
+ kakao_api_key = os.getenv('KAKAO_API_KEY')
49
 
50
+ ### Load wine database json
51
+ df = pd.read_json('./data/unified_wine_data.json', encoding='utf-8', lines=True)
52
+ ### Prepare Langchain Tool
53
+ #### Tool1: Wine database 1
54
+ df['page_content'] = ''
55
+ columns = ['name', 'pairing']
56
+ for column in columns:
57
+ if column != 'page_content':
58
+ df['page_content'] += column + ':' + df[column].astype(str) + ','
59
+ columns = ['rating', 'price', 'body', 'sweetness', 'alcohol', 'acidity', 'tannin']
60
+ for idx in df.index:
61
+ for column in columns:
62
+ if type(df[column][idx]) == str:
63
+ df[column][idx] = df[column][idx].replace(',', '')
64
+ df[column][idx] = float(df[column][idx]) if df[column][idx] != '' else -1
65
+ loader =DataFrameLoader(data_frame=df, page_content_column='page_content')
66
  docs = loader.load()
67
  embeddings = OpenAIEmbeddings()
68
+ # ์•„๋ž˜๋Š” wine database1์— metadata_field Attribute์ด๋‹ค. ์•„๋ž˜๋ฅผ ๊ธฐ์ค€์œผ๋กœ ์„œ์น˜๋ฅผ ์ง„ํ–‰ํ•˜๊ฒŒ ๋œ๋‹ค.
69
  metadata_field_info = [
70
  AttributeInfo(
71
  name="body",
72
  description="1-5 rating for the body of wine",
73
  type="int",
74
  ),
75
+ AttributeInfo(
76
+ name="tannin",
77
+ description="1-5 rating for the tannin of wine",
78
+ type="int",
79
+ ),
80
  AttributeInfo(
81
  name="sweetness",
82
  description="1-5 rating for the sweetness of wine",
 
109
  ),
110
  ]
111
 
112
+ wine_vectorstore = Chroma.from_documents(docs, embeddings)
113
+ document_content_description = "A database of wines. 'name' and 'pairing' must be included in the query, and 'Body', 'Tannin', 'Sweetness', 'Alcohol', 'Price', 'Rating', 'Wine_Type', and 'Country' can be included in the filter. query and filter must be form of 'key: value'. For example, query: 'name: ๋”ํŽ˜๋ฆฌ๋‡ฝ, pairing:์œก๋ฅ˜'."
114
  llm = OpenAI(temperature=0)
115
+ wine_retriever = SelfQueryRetriever.from_llm(
116
+ llm, wine_vectorstore, document_content_description, metadata_field_info, verbose=True
117
  ) # Added missing closing parenthesis
118
 
119
+ #### Tool2: Wine bar database
120
+ df = pd.read_json('./data/wine_bar.json', encoding='utf-8', lines=True)
121
+ df['page_content'] = ''
122
+ columns = ['summary']
123
+ for column in columns:
124
+ if column != 'page_content':
125
+ df['page_content'] += df[column].astype(str) + ','
126
+ df = df.drop(columns=['review'])
127
 
128
+ loader =DataFrameLoader(data_frame=df, page_content_column='page_content')
129
+ docs = loader.load()
130
+ embeddings = OpenAIEmbeddings()
131
+ wine_bar_vectorstore = Chroma.from_documents(docs, embeddings)
132
+ wine_bar_vectorstore.similarity_search_with_score('์—ฌ์ž์นœ๊ตฌ๋ž‘ ๊ฐˆ๋งŒํ•œ ์™€์ธ๋ฐ”', k=5)
133
+ metadata_field_info = [
134
+ AttributeInfo(
135
+ name="name",
136
+ description="The name of the wine bar",
137
+ type="str",
138
+ ),
139
+ AttributeInfo(
140
+ name="rating",
141
+ description="1-5 rating for the wine bar",
142
+ type="float"
143
+ ),
144
+ AttributeInfo(
145
+ name="district",
146
+ description="The district of the wine bar.",
147
+ type="str",
148
+ ),
149
+ ]
150
 
151
+ document_content_description = "Database of a winebar"
152
+ llm = OpenAI(temperature=0)
153
+ wine_bar_retriever = SelfQueryRetriever.from_llm(
154
+ llm, wine_bar_vectorstore, document_content_description, metadata_field_info=metadata_field_info, verbose=True
155
+ ) # Added missing closing parenthesis
156
+ #### Tool3: Search in Google
157
  search = SerpAPIWrapper()
158
+ #### Tool4: Kakao Map API
159
+
160
+
161
+ class KakaoMap:
162
+ def __init__(self):
163
+ self.url = 'https://dapi.kakao.com/v2/local/search/keyword.json'
164
+ self.headers = {"Authorization": f"KakaoAK {kakao_api_key}"}
165
+
166
+ async def arun(self, query):
167
+ async with aiohttp.ClientSession() as session:
168
+ params = {'query': query,'page': 1}
169
+ async with session.get(self.url, params=params, headers=self.headers) as response:
170
+ places = await response.json()
171
+ address = places['documents'][0]['address_name']
172
+ if not address.split()[0].startswith('์„œ์šธ'):
173
+ return {'district': 'not in seoul'}
174
+ else:
175
+ return {'district': address.split()[1]}
176
+
177
+ def run(self, query):
178
+ params = {'query': query,'page': 1}
179
+ places = requests.get(self.url, params=params, headers=self.headers).json()
180
+ address = places['documents'][0]['address_name']
181
+ if not address.split()[0].startswith('์„œ์šธ'):
182
+ return {'district': 'not in seoul'}
183
+ else:
184
+ return {'district': address.split()[1]}
185
+ kakao_map = KakaoMap()
186
  tools = [
187
  Tool(
188
  name="Wine database",
189
+ func=wine_retriever.get_relevant_documents,
190
+ coroutine=wine_retriever.aget_relevant_documents,
191
  description="""
192
+ Database about the wines in wine store.
193
  You can search wines with the following attributes:
194
+ - price: The price range of the wine. You have to specify greater than and less than.
195
+ - rating: 1-5 rating float for the wine. You have to specify greater than and less than.
 
196
  - wine_type: The type of wine. It can be '๋ ˆ๋“œ', '๋กœ์ œ', '์ŠคํŒŒํด๋ง', 'ํ™”์ดํŠธ', '๋””์ €ํŠธ', '์ฃผ์ •๊ฐ•ํ™”'
197
+ - name: The name of wine.
198
+ - pairing: The food pairing of wine.
199
+ The form of Action Input must be 'key1: value1, key2: value2, ...'. For example, to search for wines with a rating of less than 3 points, a price range of 50000์› or more, and a meat pairing, enter 'rating: gt 0 lt 3, price: gt 50000, pairing: ๊ณ ๊ธฐ'.
200
+ --------------------------------------------------
201
+ You can get the following attributes:
202
+ - url: Wine purchase site URL.
203
+ - vivino_link: Vivino link of wine.
204
+ - flavor_description
205
+ - site_name: Wine purchase site name.
206
+ - name: The name of wine in korean.
207
+ - en_name: The name of wine in english.
208
+ - price: The price of wine in ์›.
209
+ - rating: 1-5 vivino rating.
210
+ - wine_type: The type of wine.
211
+ - pairing: The food pairing of wine.
212
+ - pickup_location: Offline stores where you can purchase wine
213
+ - img_url
214
+ - country
215
+ - body
216
+ - tannin
217
+ - sweetness
218
+ - acidity
219
+ - alcohol
220
+ - grape
221
+ The form of Desired Outcome must be 'key1, key2, ...'. For example to get the name and price of wine, enter 'name, price'.
222
  """
223
  ),
 
 
 
 
 
224
  Tool(
225
+ name = "Wine bar database",
226
+ func=wine_bar_retriever.get_relevant_documents,
227
+ coroutine=wine_bar_retriever.aget_relevant_documents,
228
+ description="Database about the winebars in Seoul. It should be the first thing you use when looking for information about a wine bar."
229
+ """
230
+ - query: The query of winebar. You can search wines with review data like mood or something.
231
+ - name: The name of winebar.
232
+ - price: The average price point of a wine bar.
233
+ - rating: 1-5 rating float for the wine bar.
234
+ - district: The district of wine bar. Input district must be korean. For example, if you want to search for wines in Gangnam, enter 'district: ๊ฐ•๋‚จ๊ตฌ'
235
+ The form of Action Input must be 'key1: value1, key2: value2, ...'.
236
+ --------------------------------------------------
237
+ You can get the following attributes:
238
+ - name: The name of winebar.
239
+ - url: Wine purchase site URL.
240
+ - rating: 1-5 ๋ง๊ณ ํ”Œ๋ ˆ์ดํŠธ(๋ง›์ง‘๊ฒ€์ƒ‰ ์•ฑ) rating.
241
+ - summary: Summarized information about wine bars
242
+ - address
243
+ - phone
244
+ - parking
245
+ - opening_hours
246
+ - menu
247
+ - holidays
248
+ - img_url
249
+ The form of Desired Outcome must be 'key1, key2, ...'. For example to get the name and price of wine, enter 'name, price'.
250
+ """
251
  ),
252
  Tool(
253
  name = "Search",
254
  func=search.run,
255
+ coroutine=search.arun,
256
  description="Useful for when you need to ask with search. Search in English only."
257
  ),
258
+ Tool(
259
+ name = "Map",
260
+ func=kakao_map.run,
261
+ coroutine=kakao_map.arun,
262
+ description="The tool used to draw a district for a region. When looking for wine bars, you can use this before applying filters based on location. The query must be in Korean. You can get the following attribute: district."
263
+ ),
264
  ]
 
265
  template = """
266
  Your role is a chatbot that asks customers questions about wine and makes recommendations.
267
  Never forget your name is "์ด์šฐ์„ ".
268
+ Keep your responses in short length to retain the user's attention unless you describe the wine for recommendations.
269
+ Be sure to actively empathize and respond to your users.
270
  Only generate one response at a time! When you are done generating, end with '<END_OF_TURN>' to give the user a chance to respond.
271
  Responses should be in Korean.
272
 
 
278
  Thought: you should always think about what to do
279
  Action: the action to take, should be one of [{tool_names}]
280
  Action Input: the input to the action
281
+ Desired Outcome: the desired outcome from the action (optional)
282
  Observation: the result of the action
283
  ... (this Thought/Action/Action Input/Observation can repeat N times)
284
  Thought: I now know the final answer
 
299
  """
300
 
301
  conversation_stages_dict = {
302
+ "1": "Introduction: Start the conversation by introducing yourself. Maintain politeness, respect, and a professional tone.",
303
+ "2": "Needs Analysis: Identify the customer's needs to make wine recommendations. Note that the wine database tools are not available. You ask about the occasion the customer will enjoy the wine, what they eat with it, and their desired price point. Ask only ONE question at a time.",
304
+ "3": "Checking Price Range: Asking the customer's preferred price point. Again, remember that the tool for this is not available. But if you know the customer's perferences and price range, then search for the three most suitable wines with tool and recommend them product cards in a list format with a Vivino link, price, rating, wine type, flavor description, and image.",
305
+ "4": "Wine Recommendation: Propose the three most suitable wines based on the customer's needs and price range. Before the recommendation, you should have identified the occasion the customer will enjoy the wine, what they will eat with it, and their desired price point. Each wine recommendation should form of product cards in a list format with a Vivino link, price, rating, wine type, flavor description, and image. Use only wines available in the database for recommendations. If there are no suitable wines in the database, inform the customer. After making a recommendation, inquire whether the customer likes the suggested wine.",
306
+ "5": "Sales: If the customer approves of the recommended wine, provide a detailed description. Supply a product card in a list format with a Vivino link, price, rating, wine type, flavor description, and image.",
307
+ "6": "Location Suggestions: Recommend wine bars based on the customer's location and occasion. Before making a recommendation, always use the map tool to find the district of the customer's preferred location. Then use the wine bar database tool to find a suitable wine bar. Provide form of product cards in a list format with the wine bar's name, url, rating, address, menu, opening_hours, holidays, phone, summary, and image with img_urls.",
308
+ "7": "Concluding the Conversation: Respond appropriately to the customer's comments to wrap up the conversation.",
309
+ "8": "Questions and Answers: This stage involves answering customer's inquiries. Use the search tool or wine database tool to provide specific answers where possible. Describe answer as detailed as possible",
310
+ "9": "Other Situations: Use this step when the situation does not fit into any of the steps between 1 and 8."
311
  }
312
 
313
  # Set up a prompt template
 
324
  # Format them in a particular way
325
  intermediate_steps = kwargs.pop("intermediate_steps")
326
  thoughts = ""
327
+ special_chars = "()'[]{}"
328
  for action, observation in intermediate_steps:
329
  thoughts += action.log
330
+
331
+ if ('Desired Outcome: ' in action.log) and (('Action: Wine database' in action.log) or ('Action: Wine bar database' in action.log)):
332
+ regex = r"Desired Outcome:(.*)"
333
+ match = re.search(regex, action.log, re.DOTALL)
334
+ if not match:
335
+ raise ValueError(f"Could not parse Desired Outcome: `{action.log}`")
336
+ desired_outcome_keys = [key.strip() for key in match.group(1).split(',')]
337
+
338
+ pattern = re.compile(r'metadata=\{(.*?)\}')
339
+ matches = pattern.findall(f'{observation}')
340
+ documents = ['{'+f'{match}'+'}' for match in matches]
341
+
342
+ pattern = re.compile(r"'(\w+)':\s*('[^']+'|\b[^\s,]+\b)")
343
+ output=[]
344
+
345
+ for doc in documents:
346
+ # Extract key-value pairs from the document string
347
+ matches = pattern.findall(doc)
348
+
349
+ # Convert matches to a dictionary
350
+ doc_dict = dict(matches)
351
+
352
+ # Create a new dictionary containing only the desired keys
353
+ item_dict = {}
354
+ for key in desired_outcome_keys:
355
+ value = doc_dict.get(key, "")
356
+ for c in special_chars:
357
+ value = value.replace(c, "")
358
+ item_dict[key] = value
359
+ output.append(item_dict)
360
+
361
+ observation = ','.join([str(i) for i in output])
362
+
363
  thoughts += f"\nObservation: {observation}\nThought: "
364
  # Set the agent_scratchpad variable to that value
365
  kwargs["agent_scratchpad"] = thoughts
 
369
  kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
370
  return self.template.format(**kwargs)
371
 
372
+
373
  prompt = CustomPromptTemplate(
374
  template=template,
375
  tools=tools,
376
  input_variables=["input", "intermediate_steps", "conversation_history", "stage_number"]
377
  )
 
378
  class CustomOutputParser(AgentOutputParser):
379
 
380
  def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
 
387
  log=llm_output,
388
  )
389
  # Parse out the action and action input
390
+ regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*?)\n"
391
  match = re.search(regex, llm_output, re.DOTALL)
392
  if not match:
393
  raise ValueError(f"Could not parse LLM output: `{llm_output}`")
394
  action = match.group(1).strip()
395
  action_input = match.group(2)
396
+ # desired_outcome = match.group(3).strip() if match.group(3) else None
397
+
398
  # Return the action and action input
399
  return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
400
 
401
  output_parser = CustomOutputParser()
402
 
403
+ ### Gradio
404
+
405
+ class CustomStreamingStdOutCallbackHandler(FinalStreamingStdOutCallbackHandler):
406
+ """Callback handler for streaming in agents.
407
+ Only works with agents using LLMs that support streaming.
408
+
409
+ The output will be streamed until "<END" is reached.
410
+ """
411
+ def __init__(
412
+ self,
413
+ *,
414
+ answer_prefix_tokens: Optional[List[str]] = None,
415
+ end_prefix_tokens: str = "<END",
416
+ strip_tokens: bool = True,
417
+ stream_prefix: bool = False,
418
+ sender: str
419
+ ) -> None:
420
+ """Instantiate EofStreamingStdOutCallbackHandler.
421
+
422
+ Args:
423
+ answer_prefix_tokens: Token sequence that prefixes the anwer.
424
+ Default is ["Final", "Answer", ":"]
425
+ end_of_file_token: Token that signals end of file.
426
+ Default is "END"
427
+ strip_tokens: Ignore white spaces and new lines when comparing
428
+ answer_prefix_tokens to last tokens? (to determine if answer has been
429
+ reached)
430
+ stream_prefix: Should answer prefix itself also be streamed?
431
+ """
432
+ super().__init__(answer_prefix_tokens=answer_prefix_tokens, strip_tokens=strip_tokens, stream_prefix=stream_prefix)
433
+ self.end_prefix_tokens = end_prefix_tokens
434
+ self.end_reached = False
435
+ self.sender = sender
436
+
437
+ def append_to_last_tokens(self, token: str) -> None:
438
+ self.last_tokens.append(token)
439
+ self.last_tokens_stripped.append(token.strip())
440
+ if len(self.last_tokens) > 5:
441
+ self.last_tokens.pop(0)
442
+ self.last_tokens_stripped.pop(0)
443
+
444
+ def on_llm_start(
445
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
446
+ ) -> None:
447
+ """Run when LLM starts running."""
448
+ self.answer_reached = False
449
+ self.end_reached = False
450
+
451
+ def check_if_answer_reached(self) -> bool:
452
+ if self.strip_tokens:
453
+ return ''.join(self.last_tokens_stripped) in self.answer_prefix_tokens_stripped
454
+ else:
455
+ unfied_last_tokens = ''.join(self.last_tokens)
456
+ try:
457
+ unfied_last_tokens.index(self.answer_prefix_tokens)
458
+ return True
459
+ except:
460
+ return False
461
+
462
+ def check_if_end_reached(self) -> bool:
463
+ if self.strip_tokens:
464
+ return ''.join(self.last_tokens_stripped) in self.answer_prefix_tokens_stripped
465
+ else:
466
+ unfied_last_tokens = ''.join(self.last_tokens)
467
+ try:
468
+ unfied_last_tokens.index(self.end_prefix_tokens)
469
+ self.sender[1] = True
470
+ return True
471
+ except:
472
+ # try:
473
+ # unfied_last_tokens.index('Action Input')
474
+ # self.sender[1] = False
475
+ # return False
476
+ # except:
477
+ # return False
478
+ return False
479
+
480
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
481
+ """Run on new LLM token. Only available when streaming is enabled."""
482
+ # Remember the last n tokens, where n = len(answer_prefix_tokens)
483
+ self.append_to_last_tokens(token)
484
+
485
+ # Check if the last n tokens match the answer_prefix_tokens list ...
486
+ if not self.answer_reached and self.check_if_answer_reached():
487
+ self.answer_reached = True
488
+ if self.stream_prefix:
489
+ for t in self.last_tokens:
490
+ sys.stdout.write(t)
491
+ sys.stdout.flush()
492
+ return
493
+
494
+ if not self.end_reached and self.check_if_end_reached():
495
+ self.end_reached = True
496
+
497
+ if self.end_reached:
498
+ pass
499
+ elif self.answer_reached:
500
+ if self.last_tokens[-2] == ":":
501
+ pass
502
+ else:
503
+ self.sender[0] += self.last_tokens[-2]
504
+
505
+ class UnifiedAgent:
506
+ def __init__(self):
507
+
508
+ tools = [
509
+ Tool(
510
+ name="Wine database",
511
+ func=wine_retriever.get_relevant_documents,
512
+ coroutine=wine_retriever.aget_relevant_documents,
513
+ description="""
514
+ Database about the wines in wine store.
515
+ You can search wines with the following attributes:
516
+ - price: The price range of the wine. You have to specify greater than and less than.
517
+ - rating: 1-5 rating float for the wine. You have to specify greater than and less than.
518
+ - wine_type: The type of wine. It can be '๋ ˆ๋“œ', '๋กœ์ œ', '์ŠคํŒŒํด๋ง', 'ํ™”์ดํŠธ', '๋””์ €ํŠธ', '์ฃผ์ •๊ฐ•ํ™”'
519
+ - name: The name of wine.
520
+ - pairing: The food pairing of wine.
521
+ The form of Action Input must be 'key1: value1, key2: value2, ...'. For example, to search for wines with a rating of less than 3 points, a price range of 50000์› or more, and a meat pairing, enter 'rating: gt 0 lt 3, price: gt 50000, pairing: ๊ณ ๊ธฐ'.
522
+ --------------------------------------------------
523
+ You can get the following attributes:
524
+ - url: Wine purchase site URL.
525
+ - vivino_link: Vivino link of wine.
526
+ - flavor_description
527
+ - site_name: Wine purchase site name.
528
+ - name: The name of wine in korean.
529
+ - en_name: The name of wine in english.
530
+ - price: The price of wine in ์›.
531
+ - rating: 1-5 vivino rating.
532
+ - wine_type: The type of wine.
533
+ - pairing: The food pairing of wine.
534
+ - pickup_location: Offline stores where you can purchase wine
535
+ - img_url
536
+ - country
537
+ - body
538
+ - tannin
539
+ - sweetness
540
+ - acidity
541
+ - alcohol
542
+ - grape
543
+ The form of Desired Outcome must be 'key1, key2, ...'. For example to get the name and price of wine, enter 'name, price'.
544
+ """
545
+ ),
546
+ Tool(
547
+ name = "Wine bar database",
548
+ func=wine_bar_retriever.get_relevant_documents,
549
+ coroutine=wine_bar_retriever.aget_relevant_documents,
550
+ description="Database about the winebars in Seoul. It should be the first thing you use when looking for information about a wine bar."
551
+ """
552
+ - query: The query of winebar. You can search wines with review data like mood or something.
553
+ - name: The name of winebar.
554
+ - price: The average price point of a wine bar.
555
+ - rating: 1-5 rating float for the wine bar.
556
+ - district: The district of wine bar. Input district must be korean. For example, if you want to search for wines in Gangnam, enter 'district: ๊ฐ•๋‚จ๊ตฌ'
557
+ The form of Action Input must be 'key1: value1, key2: value2, ...'.
558
+ --------------------------------------------------
559
+ You can get the following attributes:
560
+ - name: The name of winebar.
561
+ - url: Wine purchase site URL.
562
+ - rating: 1-5 ๋ง๊ณ ํ”Œ๋ ˆ์ดํŠธ(๋ง›์ง‘๊ฒ€์ƒ‰ ์•ฑ) rating.
563
+ - summary: Summarized information about wine bars
564
+ - address
565
+ - phone
566
+ - parking
567
+ - opening_hours
568
+ - menu
569
+ - holidays
570
+ - img_url
571
+ The form of Desired Outcome must be 'key1, key2, ...'. For example to get the name and price of wine, enter 'name, price'.
572
+ """
573
+ ),
574
+ Tool(
575
+ name = "Search",
576
+ func=search.run,
577
+ coroutine=search.arun,
578
+ description="Useful for when you need to ask with search. Search in English only."
579
+ ),
580
+ Tool(
581
+ name = "Map",
582
+ func=kakao_map.run,
583
+ coroutine=kakao_map.arun,
584
+ description="The tool used to draw a district for a region. When looking for wine bars, you can use this before applying filters based on location. The query must be in Korean. You can get the following attribute: district."
585
+ ),
586
+ ]
587
 
588
+ llm_chain = LLMChain(llm=ChatOpenAI(model='gpt-4', temperature=0.5, streaming=True), prompt=prompt, verbose=True,)
589
+
590
+ tool_names = [tool.name for tool in tools]
591
+ agent = LLMSingleActionAgent(
592
+ llm_chain=llm_chain,
593
+ output_parser=output_parser,
594
+ stop=["\nObservation:"],
595
+ allowed_tools=tool_names
596
+ )
597
+ agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=False)
598
+
599
+ self.agent_executor = agent_executor
600
+
601
+ async def arun(self, sender, *args, **kwargs):
602
+ resp = await self.agent_executor.arun(kwargs, callbacks=[CustomStreamingStdOutCallbackHandler(answer_prefix_tokens='์ด์šฐ์„ :', end_prefix_tokens='<END', strip_tokens=False, sender=sender)])
603
+ return resp
604
+
605
+
606
+ class UnifiedChain:
607
+ def __init__(self):
608
+ stage_analyzer_inception_prompt = load_prompt("./templates/stage_analyzer_inception_prompt_template.json")
609
+ llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0.0)
610
+ stage_analyzer_chain = LLMChain(
611
+ llm=llm,
612
+ prompt=stage_analyzer_inception_prompt,
613
+ verbose=False,
614
+ output_key="stage_number")
615
+
616
+ user_response_prompt = load_prompt("./templates/user_response_prompt.json")
617
+ # ๋žญ์ฒด์ธ ๋ชจ๋ธ ์„ ์–ธ, ๋žญ์ฒด์ธ์€ ์–ธ์–ด๋ชจ๋ธ๊ณผ ํ”„๋กฌํ”„ํŠธ๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค.
618
+ llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0.5)
619
+ user_response_chain = LLMChain(
620
+ llm=llm,
621
+ prompt=user_response_prompt,
622
+ verbose=False, # ๊ณผ์ •์„ ์ถœ๋ ฅํ• ์ง€
623
+ output_key="user_responses"
624
+ )
625
+
626
+ self.stage_analyzer_chain = stage_analyzer_chain
627
+ self.user_response_chain = user_response_chain
628
+
629
+ async def arun_stage_analyzer_chain(self, *args, **kwargs):
630
+ resp = await self.stage_analyzer_chain.arun(kwargs)
631
+ return resp
632
+
633
+ async def arun_user_response_chain(self, *args, **kwargs):
634
+ resp = await self.user_response_chain.arun(kwargs)
635
+ return resp
636
+
637
+ unified_chain = UnifiedChain()
638
+ unified_agent = UnifiedAgent()
639
 
640
+ # logging
641
+ # callback = gr.CSVLogger()
642
+ hf_writer = gr.HuggingFaceDatasetSaver(huggingface_token, "chatwine-korean")
643
 
 
 
 
644
 
645
  with gr.Blocks(css='#chatbot .overflow-y-auto{height:750px}') as demo:
646
+
647
  with gr.Row():
648
  gr.HTML("""<div style="text-align: center; max-width: 500px; margin: 0 auto;">
649
  <div>
 
653
  LinkedIn <a href="https://www.linkedin.com/company/audrey-ai/about/">Audrey.ai</a>
654
  </p>
655
  </div>""")
656
+
657
  chatbot = gr.Chatbot()
658
+
659
+ with gr.Row():
660
+ with gr.Column(scale=0.85):
661
+ msg = gr.Textbox()
662
+ with gr.Column(scale=0.15, min_width=0):
663
+ submit_btn = gr.Button("์ „์†ก")
664
+
665
+ user_response_examples = gr.Dataset(samples=[["์ด๋ฒˆ ์ฃผ์— ์นœ๊ตฌ๋“ค๊ณผ ๋ชจ์ž„์ด ์žˆ๋Š”๋ฐ, ํ›Œ๋ฅญํ•œ ์™€์ธ ํ•œ ๋ณ‘์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["์ž…๋ฌธ์ž์—๊ฒŒ ์ข‹์€ ์™€์ธ์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["์—ฐ์ธ๊ณผ ๊ฐ€๊ธฐ ์ข‹์€ ์™€์ธ๋ฐ”๋ฅผ ์•Œ๋ ค์ค˜"]], components=[msg], type="index")
666
  clear_btn = gr.ClearButton([msg, chatbot])
667
+
668
+ dev_mod = True
669
+ cur_stage = gr.Textbox(visible=dev_mod, interactive=False, label='current_stage')
670
+ stage_hist = gr.Textbox(visible=dev_mod, value="stage history: ", interactive=False, label='stage history')
671
+ chat_hist = gr.Textbox(visible=dev_mod, interactive=False, label='chatting_history')
672
+ response_examples_text = gr.Textbox(visible=dev_mod, interactive=False, value="์ด๋ฒˆ ์ฃผ์— ์นœ๊ตฌ๋“ค๊ณผ ๋ชจ์ž„์ด ์žˆ๋Š”๋ฐ, ํ›Œ๋ฅญํ•œ ์™€์ธ ํ•œ ๋ณ‘์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?|์ž…๋ฌธ์ž์—๊ฒŒ ์ข‹์€ ์™€์ธ์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?|์—ฐ์ธ๊ณผ ๊ฐ€๊ธฐ ์ข‹์€ ์™€์ธ๋ฐ”๋ฅผ ์•Œ๋ ค์ค˜", label='response_examples')
673
+ btn = gr.Button("Flag", visible=dev_mod)
674
+ hf_writer.setup(components=[chat_hist, stage_hist, response_examples_text], flagging_dir="chatwine-korean")
675
+
676
+ def click_flag_btn(*args):
677
+ hf_writer.flag(flag_data=[*args])
678
+
679
+ def clean(*args):
680
+ return gr.Dataset.update(samples=[["์ด๋ฒˆ ์ฃผ์— ์นœ๊ตฌ๋“ค๊ณผ ๋ชจ์ž„์ด ์žˆ๋Š”๋ฐ, ํ›Œ๋ฅญํ•œ ์™€์ธ ํ•œ ๋ณ‘์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["์ž…๋ฌธ์ž์—๊ฒŒ ์ข‹์€ ์™€์ธ์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?"], ["์—ฐ์ธ๊ณผ ๊ฐ€๊ธฐ ์ข‹์€ ์™€์ธ๋ฐ”๋ฅผ ์•Œ๋ ค์ค˜"]]), "", "stage history: ", "", "์ด๋ฒˆ ์ฃผ์— ์นœ๊ตฌ๋“ค๊ณผ ๋ชจ์ž„์ด ์žˆ๋Š”๋ฐ, ํ›Œ๋ฅญํ•œ ์™€์ธ ํ•œ ๋ณ‘์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?|์ž…๋ฌธ์ž์—๊ฒŒ ์ข‹์€ ์™€์ธ์„ ์ถ”์ฒœํ•ด์ค„๋ž˜?|์—ฐ์ธ๊ณผ ๊ฐ€๊ธฐ ์ข‹์€ ์™€์ธ๋ฐ”๋ฅผ ์•Œ๋ ค์ค˜"
681
+
682
+ def load_example(response_text, input_idx):
683
+ response_examples = []
684
+ for user_response_example in response_text.split('|'):
685
+ response_examples.append([user_response_example])
686
+ return response_examples[input_idx][0]
687
+
688
+ async def agent_run(agent_exec, inp, sender):
689
+ sender[0] = ""
690
+ await agent_exec.arun(inp)
691
+
692
+ def user_chat(user_message, chat_history_list, chat_history):
693
+ return (chat_history_list + [[user_message, None]], chat_history + f"User: {user_message} <END_OF_TURN>\n", [])
694
+
695
+ async def bot_stage_pred(user_response, chat_history, stage_history):
696
+ pre_chat_history = '<END_OF_TURN>'.join(chat_history.split('<END_OF_TURN>')[:-2])
697
+ if pre_chat_history != '':
698
+ pre_chat_history += '<END_OF_TURN>'
699
+ # stage_number = unified_chain.stage_analyzer_chain.run({'conversation_history': pre_chat_history, 'stage_history': stage_history.replace('stage history: ', ''), 'last_user_saying':user_response+' <END_OF_TURN>\n'})
700
+ stage_number = await unified_chain.arun_stage_analyzer_chain(conversation_history=pre_chat_history, stage_history= stage_history.replace('stage history: ', ''), last_user_saying=user_response+' <END_OF_TURN>\n')
701
  stage_number = stage_number[-1]
702
  stage_history += stage_number if stage_history == "stage history: " else ", " + stage_number
703
+
704
+ return stage_number, stage_history
705
+
706
+ async def bot_chat(user_response, chat_history, chat_history_list, current_stage): # stream output by yielding
707
+
708
+ pre_chat_history = '<END_OF_TURN>'.join(chat_history.split('<END_OF_TURN>')[:-2])
709
+ if pre_chat_history != '':
710
+ pre_chat_history += '<END_OF_TURN>'
711
+
712
+ sender = ["", False]
713
+ task = asyncio.create_task(unified_agent.arun(sender = sender, input=user_response+' <END_OF_TURN>\n', conversation_history=pre_chat_history, stage_number= current_stage))
714
+ await asyncio.sleep(0)
715
+
716
+ while(sender[1] == False):
717
+ await asyncio.sleep(0.2)
718
+ chat_history_list[-1][1] = sender[0]
719
+ yield chat_history_list, chat_history + f"์ด์šฐ์„ : {sender[0]}<END_OF_TURN>\n"
720
+
721
+ chat_history_list[-1][1] = sender[0]
722
+ yield chat_history_list, chat_history + f"์ด์šฐ์„ : {sender[0]}<END_OF_TURN>\n"
723
+
724
+ async def bot_response_pred(chat_history):
725
+ response_examples = []
726
+ pre_chat_history = '<END_OF_TURN>'.join(chat_history.split('<END_OF_TURN>')[-3:])
727
+ out = await unified_chain.arun_user_response_chain(conversation_history=pre_chat_history)
728
+ for user_response_example in out.split('|'):
729
+ response_examples.append([user_response_example])
730
+ return [response_examples, out, ""]
731
+
732
+ # btn.click(lambda *args: hf_writer.flag(args), [msg, chat_hist, stage_hist, response_examples_text], None, preprocess=False)
733
+
734
+ msg.submit(
735
+ user_chat, [msg, chatbot, chat_hist], [chatbot, chat_hist, user_response_examples], queue=False
736
+ ).then(
737
+ bot_stage_pred, [msg, chat_hist, stage_hist], [cur_stage, stage_hist], queue=False
738
+ ).then(
739
+ bot_chat, [msg, chat_hist, chatbot, cur_stage], [chatbot, chat_hist]
740
+ ).then(
741
+ bot_response_pred, chat_hist, [user_response_examples, response_examples_text, msg]
742
+ ).then(
743
+ click_flag_btn, [chat_hist, stage_hist, response_examples_text], None
744
+ )
745
+
746
+
747
+
748
+ submit_btn.click(
749
+ user_chat, [msg, chatbot, chat_hist], [chatbot, chat_hist, user_response_examples], queue=False
750
+ ).then(
751
+ bot_stage_pred, [msg, chat_hist, stage_hist], [cur_stage, stage_hist], queue=False
752
+ ).then(
753
+ bot_chat, [msg, chat_hist, chatbot, cur_stage], [chatbot, chat_hist]
754
+ ).then(
755
+ bot_response_pred, chat_hist, [user_response_examples, response_examples_text, msg]
756
+ ).then(
757
+ click_flag_btn, [chat_hist, stage_hist, response_examples_text], None
758
+ )
759
+
760
+
761
+
762
+ clear_btn.click(
763
+ clean,
764
+ inputs=[user_response_examples, cur_stage, stage_hist, chat_hist, response_examples_text],
765
+ outputs=[user_response_examples, cur_stage, stage_hist, chat_hist, response_examples_text],
766
+ queue=False)
767
+ user_response_examples.click(load_example, inputs=[response_examples_text, user_response_examples], outputs=[msg], queue=False)
768
+ btn.click(lambda *args: hf_writer.flag(args), [chat_hist, stage_hist, response_examples_text], None, preprocess=False)
769
+ demo.queue(concurrency_count=100)
770
+ demo.launch()