TRaw committed on
Commit
3e5089f
1 Parent(s): 4752671

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +408 -0
  3. autogen-human-input.gif +3 -0
  4. autogen.png +0 -0
  5. human.png +0 -0
  6. requirements.txt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ autogen-human-input.gif filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import threading
4
+ from itertools import chain
5
+
6
+ import anyio
7
+ import autogen
8
+ import gradio as gr
9
+ from autogen import Agent, AssistantAgent, OpenAIWrapper, UserProxyAgent
10
+ from autogen.code_utils import extract_code
11
+ from gradio import ChatInterface, Request
12
+ from gradio.helpers import special_args
13
+
14
+ LOG_LEVEL = "INFO"
15
+ TIMEOUT = 60
16
+
17
+
18
class myChatInterface(ChatInterface):
    """ChatInterface variant whose submit handler ignores fn's return value.

    The wrapped fn (here `respond`) mutates the chat history in place, so this
    override only invokes fn and re-emits the (already updated) history instead
    of appending a returned response itself.
    """

    async def _submit_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        request: Request,
        *args,
    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
        # Drop the just-submitted user turn; fn is expected to re-add it
        # when it rewrites the history in place.
        history = history_with_input[:-1]
        # Resolve gradio special parameters (e.g. Request) into the call args.
        inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)

        if self.is_async:
            await self.fn(*inputs)
        else:
            # Run the sync fn in a worker thread so the event loop stays free.
            await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)

        # history.append([message, response])
        return history, history
36
+
37
+
38
+ with gr.Blocks() as demo:
39
+
40
def flatten_chain(list_of_lists):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    flat = []
    for sub in list_of_lists:
        flat.extend(sub)
    return flat
42
+
43
class thread_with_trace(threading.Thread):
    # A Thread that can be cooperatively "killed" and whose target's return
    # value is surfaced by join(): a sys.settrace hook raises SystemExit
    # inside the running target once self.killed is set.
    # https://www.geeksforgeeks.org/python-different-ways-to-kill-a-thread/
    # https://stackoverflow.com/questions/6893968/how-to-get-the-return-value-from-a-thread
    def __init__(self, *args, **keywords):
        threading.Thread.__init__(self, *args, **keywords)
        self.killed = False  # set by kill(); polled in localtrace
        self._return = None  # target's return value, returned by join()

    def start(self):
        # Swap run() for a traced wrapper just before the thread starts.
        self.__run_backup = self.run
        self.run = self.__run
        threading.Thread.start(self)

    def __run(self):
        # Install the per-thread trace hook, then execute the real run().
        sys.settrace(self.globaltrace)
        self.__run_backup()
        self.run = self.__run_backup

    def run(self):
        # Like Thread.run, but captures the target's return value.
        if self._target is not None:
            self._return = self._target(*self._args, **self._kwargs)

    def globaltrace(self, frame, event, arg):
        # Trace only function calls; delegate line-level tracing to localtrace.
        if event == "call":
            return self.localtrace
        else:
            return None

    def localtrace(self, frame, event, arg):
        # Once killed, abort the thread at the next executed line.
        if self.killed:
            if event == "line":
                raise SystemExit()
        return self.localtrace

    def kill(self):
        # Request cooperative termination; takes effect on the next traced line.
        self.killed = True

    def join(self, timeout=0):
        # NOTE(review): default timeout=0 makes a bare join() return
        # immediately — callers here always pass an explicit timeout.
        threading.Thread.join(self, timeout)
        return self._return
83
+
84
def update_agent_history(recipient, messages, sender, config):
    """Reply-hook stub (currently unregistered — see the commented-out
    register_reply calls in initialize_agents): inspects the latest message
    but records nothing."""
    if config is None:
        config = recipient
    if messages is None:
        messages = recipient._oai_messages[sender]
    message = messages[-1]
    # NOTE(review): result of .get() is discarded — leftover scaffolding for
    # the commented-out history append below.
    message.get("content", "")
    # config.append(msg) if msg is not None else None  # config can be agent_history
    return False, None  # required to ensure the agent communication flow continues
93
+
94
def _is_termination_msg(message):
    """Check if a message is a termination message.

    Terminate when no code block is detected. Currently only detect python
    code blocks.
    """
    content = message.get("content") if isinstance(message, dict) else message
    if content is None:
        return False
    # Keep going only while the assistant is still emitting Python code.
    return not any(block[0] == "python" for block in extract_code(content))
110
+
111
def initialize_agents(config_list):
    """Build the (assistant, userproxy) agent pair used by the demo."""
    llm_settings = {
        # "seed": 42,
        "timeout": TIMEOUT,
        "config_list": config_list,
    }
    assistant = AssistantAgent(
        name="assistant",
        max_consecutive_auto_reply=5,
        llm_config=llm_settings,
    )

    exec_settings = {
        "work_dir": "coding",
        "use_docker": False,  # set to True or image name like "python:3" to use docker
    }
    userproxy = UserProxyAgent(
        name="userproxy",
        human_input_mode="NEVER",
        is_termination_msg=_is_termination_msg,
        max_consecutive_auto_reply=5,
        # code_execution_config=False,
        code_execution_config=exec_settings,
    )

    # assistant.register_reply([Agent, None], update_agent_history)
    # userproxy.register_reply([Agent, None], update_agent_history)

    return assistant, userproxy
138
+
139
def chat_to_oai_message(chat_history):
    """Convert chat history to OpenAI message format."""
    if LOG_LEVEL == "DEBUG":
        print(f"chat_to_oai_message: {chat_history}")
    messages = []
    for turn in chat_history:
        user_text = turn[0]
        # Collapse code-execution results ("exitcode: ...") to the first token.
        if user_text.startswith("exitcode"):
            user_text = user_text.split()[0]
        messages.append({"content": user_text, "role": "user"})
        messages.append({"content": turn[1], "role": "assistant"})
    return messages
153
+
154
def oai_message_to_chat(oai_messages, sender):
    """Convert OpenAI message format to chat history ([user, assistant] pairs)."""
    messages = oai_messages[sender]
    if LOG_LEVEL == "DEBUG":
        print(f"oai_message_to_chat: {messages}")
    # Pair consecutive messages; an unmatched trailing message gets "".
    return [
        [
            messages[i]["content"],
            messages[i + 1]["content"] if i + 1 < len(messages) else "",
        ]
        for i in range(0, len(messages), 2)
    ]
168
+
169
def agent_history_to_chat(agent_history):
    """Convert a flat agent history into [user, assistant] chat pairs.

    An unmatched trailing entry is paired with None.
    """
    pairs = []
    for idx in range(0, len(agent_history), 2):
        reply = agent_history[idx + 1] if idx + 1 < len(agent_history) else None
        pairs.append([agent_history[idx], reply])
    return pairs
180
+
181
def initiate_chat(config_list, user_message, chat_history):
    """Run one round of userproxy <-> assistant conversation.

    Mutates and returns chat_history (list of [user, assistant] pairs).
    Short-circuits with a help message when no API key is configured, and
    serves files from ./coding for "show file: <name>" commands.
    """
    if LOG_LEVEL == "DEBUG":
        print(f"chat_history_init: {chat_history}")
    # agent_history = flatten_chain(chat_history)
    if len(config_list[0].get("api_key", "")) < 2:
        chat_history.append(
            [
                user_message,
                # Fix: typo "boxs" -> "boxes" in the user-facing prompt.
                "Hi, nice to meet you! Please enter your API keys in below text boxes.",
            ]
        )
        return chat_history
    else:
        llm_config = {
            # "seed": 42,
            "timeout": TIMEOUT,
            "config_list": config_list,
        }
        assistant.llm_config.update(llm_config)
        assistant.client = OpenAIWrapper(**assistant.llm_config)

        if user_message.strip().lower().startswith("show file:"):
            filename = user_message.strip().lower().replace("show file:", "").strip()
            filepath = os.path.join("coding", filename)
            if os.path.exists(filepath):
                # A (filepath,) tuple renders as a file/image in gr.Chatbot.
                chat_history.append([user_message, (filepath,)])
            else:
                # Fix: the f-string had no placeholder; report the missing name.
                chat_history.append([user_message, f"File {filename} not found."])
            return chat_history

        # Replay prior turns through the system message so the freshly reset
        # assistant keeps conversational context across rounds.
        assistant.reset()
        oai_messages = chat_to_oai_message(chat_history)
        assistant._oai_system_message_origin = assistant._oai_system_message.copy()
        assistant._oai_system_message += oai_messages

        try:
            userproxy.initiate_chat(assistant, message=user_message)
            messages = userproxy.chat_messages
            chat_history += oai_message_to_chat(messages, assistant)
            # agent_history = flatten_chain(chat_history)
        except Exception as e:
            # Surface the error to the user instead of crashing the UI.
            # agent_history += [user_message, str(e)]
            # chat_history[:] = agent_history_to_chat(agent_history)
            chat_history.append([user_message, str(e)])

        # Restore the original system message for the next round.
        assistant._oai_system_message = assistant._oai_system_message_origin.copy()
        if LOG_LEVEL == "DEBUG":
            print(f"chat_history: {chat_history}")
            # print(f"agent_history: {agent_history}")
        return chat_history
231
+
232
def chatbot_reply_thread(input_text, chat_history, config_list):
    """Run initiate_chat in a killable thread, enforcing TIMEOUT seconds.

    Returns the updated chat history; on timeout or error, returns a history
    containing a single [input, error-message] pair.
    """
    thread = thread_with_trace(target=initiate_chat, args=(config_list, input_text, chat_history))
    thread.start()
    try:
        messages = thread.join(timeout=TIMEOUT)
        if thread.is_alive():
            thread.kill()
            thread.join()
            # Fix: wrap as a [user, assistant] pair so the timeout result has
            # the same chat-history shape as every other return path (the
            # flat list previously returned here corrupts the Chatbot render).
            messages = [
                [
                    input_text,
                    "Timeout Error: Please check your API keys and try again later.",
                ]
            ]
    except Exception as e:
        messages = [
            [
                input_text,
                str(e) if len(str(e)) > 0 else "Invalid Request to OpenAI, please check your API keys.",
            ]
        ]
    return messages
253
+
254
def chatbot_reply_plain(input_text, chat_history, config_list):
    """Run one chat round synchronously (no timeout-guard thread)."""
    try:
        return initiate_chat(config_list, input_text, chat_history)
    except Exception as e:
        reason = str(e) if len(str(e)) > 0 else "Invalid Request to OpenAI, please check your API keys."
        return [[input_text, reason]]
266
+
267
def chatbot_reply(input_text, chat_history, config_list):
    """Dispatch one chat turn via the timeout-guarded threaded path."""
    reply = chatbot_reply_thread(input_text, chat_history, config_list)
    return reply
270
+
271
def get_description_text():
    # Markdown header shown at the top of the demo page; returned verbatim
    # and rendered via gr.Markdown below.
    return """
# Microsoft AutoGen: Multi-Round Human Interaction Chatbot Demo

This demo shows how to build a chatbot which can handle multi-round conversations with human interactions.

#### [AutoGen](https://github.com/microsoft/autogen) [Discord](https://discord.gg/pAbnFJrkgZ) [Paper](https://arxiv.org/abs/2308.08155) [SourceCode](https://github.com/thinkall/autogen-demos)
"""
279
+
280
def update_config():
    """Resolve the model config from environment variables, falling back to
    an empty Azure OpenAI placeholder when nothing is configured."""
    model = os.environ.get("MODEL", "gpt-35-turbo")
    config_list = autogen.config_list_from_models(model_list=[model])
    if config_list:
        return config_list
    # No usable credentials in the environment: return a blank placeholder.
    return [
        {
            "api_key": "",
            "base_url": "",
            "api_type": "azure",
            "api_version": "2023-07-01-preview",
            "model": "gpt-35-turbo",
        }
    ]
296
+
297
def set_params(model, oai_key, aoai_key, aoai_base):
    """Persist the UI-supplied model name and credentials into the process
    environment, where update_config() reads them back."""
    os.environ.update(
        {
            "MODEL": model,
            "OPENAI_API_KEY": oai_key,
            "AZURE_OPENAI_API_KEY": aoai_key,
            "AZURE_OPENAI_API_BASE": aoai_base,
        }
    )
302
+
303
def respond(message, chat_history, model, oai_key, aoai_key, aoai_base):
    """Gradio submit handler: apply credentials, run one chat round in place,
    and clear the textbox by returning an empty string."""
    set_params(model, oai_key, aoai_key, aoai_base)
    cfg = update_config()
    chat_history[:] = chatbot_reply(message, chat_history, cfg)
    if LOG_LEVEL == "DEBUG":
        print(f"return chat_history: {chat_history}")
    return ""
310
+
311
# Placeholder config (blank Azure credentials); the real config is rebuilt
# from environment variables on every submit via update_config().
config_list, assistant, userproxy = (
    [
        {
            "api_key": "",
            "base_url": "",
            "api_type": "azure",
            "api_version": "2023-07-01-preview",
            "model": "gpt-35-turbo",
        }
    ],
    None,
    None,
)
# The None placeholders above are immediately replaced here.
assistant, userproxy = initialize_agents(config_list)

# Page header markdown.
description = gr.Markdown(get_description_text())

# Model selector and password-masked credential inputs.
with gr.Row() as params:
    txt_model = gr.Dropdown(
        label="Model",
        choices=[
            "gpt-4",
            "gpt-35-turbo",
            "gpt-3.5-turbo",
        ],
        allow_custom_value=True,
        value="gpt-35-turbo",
        container=True,
    )
    txt_oai_key = gr.Textbox(
        label="OpenAI API Key",
        placeholder="Enter OpenAI API Key",
        max_lines=1,
        show_label=True,
        container=True,
        type="password",
    )
    txt_aoai_key = gr.Textbox(
        label="Azure OpenAI API Key",
        placeholder="Enter Azure OpenAI API Key",
        max_lines=1,
        show_label=True,
        container=True,
        type="password",
    )
    txt_aoai_base_url = gr.Textbox(
        label="Azure OpenAI API Base",
        placeholder="Enter Azure OpenAI Base Url",
        max_lines=1,
        show_label=True,
        container=True,
        type="password",
    )

# Chat display with human/agent avatars; render=False defers rendering to
# the ChatInterface below.
chatbot = gr.Chatbot(
    [],
    elem_id="chatbot",
    bubble_full_width=False,
    avatar_images=(
        "human.png",
        (os.path.join(os.path.dirname(__file__), "autogen.png")),
    ),
    render=False,
    height=600,
)

# Message entry box, also rendered by the ChatInterface.
txt_input = gr.Textbox(
    scale=4,
    show_label=False,
    placeholder="Enter text and press enter",
    container=False,
    render=False,
    autofocus=True,
)

# Wire everything together; additional_inputs are forwarded to respond().
chatiface = myChatInterface(
    respond,
    chatbot=chatbot,
    textbox=txt_input,
    additional_inputs=[
        txt_model,
        txt_oai_key,
        txt_aoai_key,
        txt_aoai_base_url,
    ],
    examples=[
        ["write a python function to count the sum of two numbers?"],
        ["what if the production of two numbers?"],
        [
            "Plot a chart of the last year's stock prices of Microsoft, Google and Apple and save to stock_price.png."
        ],
        ["show file: stock_price.png"],
    ],
)
405
+
406
+
407
if __name__ == "__main__":
    # share=True exposes a public gradio link; bind on all interfaces so the
    # demo is reachable when run inside a container/Space.
    demo.launch(share=True, server_name="0.0.0.0")
autogen-human-input.gif ADDED

Git LFS Details

  • SHA256: 048316f37d62ef43d9bdcd66bf1629cb94622d70ecbefc7648d090401fa2a795
  • Pointer size: 132 Bytes
  • Size of remote file: 4.81 MB
autogen.png ADDED
human.png ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pyautogen==0.2.0b4
2
+ gradio>=4.0.0
3
+ yfinance