kn404 commited on
Commit
9d58bb7
Β·
1 Parent(s): 0d868b1

small tweaks

Browse files
Files changed (3) hide show
  1. agent.py +11 -11
  2. app.py +1 -1
  3. test_agent.py +8 -5
agent.py CHANGED
@@ -300,41 +300,41 @@ class SantaAgent:
300
  output = self.give_present(person, item)
301
  elif tool_call.function.name == "make_naughty_nice_list":
302
  output = self.make_naughty_nice_list()
303
- gradio_messages.append(ChatMessage(role="assistant", content="make_naughty_nice_list", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
304
  elif tool_call.function.name == "check_naughty_nice_list":
305
  output = self.check_naughty_nice_list()
306
- gradio_messages.append(ChatMessage(role="assistant", content="check_naughty_nice_list", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
307
  elif tool_call.function.name == "cut_paper":
308
  output = self.cut_paper()
309
- gradio_messages.append(ChatMessage(role="assistant", content="cut_paper", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
310
  elif tool_call.function.name == "find_end_of_tape":
311
  output = self.find_end_of_tape()
312
- gradio_messages.append(ChatMessage(role="assistant", content="find_end_of_tape", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
313
  elif tool_call.function.name == "wrap_present":
314
  output = self.wrap_present()
315
- gradio_messages.append(ChatMessage(role="assistant", content="wrap_present", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
316
  elif tool_call.function.name == "label_present":
317
  recipient = arguments["recipient"]
318
  output = self.label_present(recipient)
319
- gradio_messages.append(ChatMessage(role="assistant", content=f"label_present({recipient})", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
320
  elif tool_call.function.name == "retrieve_letters":
321
  output = self.retrieve_letters()
322
- gradio_messages.append(ChatMessage(role="assistant", content="retrieve_letters", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
323
  elif tool_call.function.name == "check_temperature":
324
  object = arguments["object"]
325
  output = self.check_temperature(object)
326
- gradio_messages.append(ChatMessage(role="assistant", content=f"check_temperature({object})", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
327
  elif tool_call.function.name == "dunk_cookie":
328
  output = self.dunk_cookie()
329
- gradio_messages.append(ChatMessage(role="assistant", content="dunk_cookie", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
330
  elif tool_call.function.name == "drink":
331
  item = arguments["item"]
332
  output = self.drink(item)
333
- gradio_messages.append(ChatMessage(role="assistant", content=f"drink({item})", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
334
  elif tool_call.function.name == "put_route_into_maps":
335
  addr1, addr2, addr3 = arguments["addr1"], arguments["addr2"], arguments["addr3"]
336
  output = self.put_route_into_maps(addr1, addr2, addr3)
337
- gradio_messages.append(ChatMessage(role="assistant", content=f"put_route_into_maps({addr1}, {addr2}, {addr3})", metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
338
  elif tool_call.function.name == "stop":
339
  output = self.stop()
340
  should_stop = True
 
300
  output = self.give_present(person, item)
301
  elif tool_call.function.name == "make_naughty_nice_list":
302
  output = self.make_naughty_nice_list()
303
+ gradio_messages.append(ChatMessage(role="assistant", content="make_naughty_nice_list", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
304
  elif tool_call.function.name == "check_naughty_nice_list":
305
  output = self.check_naughty_nice_list()
306
+ gradio_messages.append(ChatMessage(role="assistant", content="check_naughty_nice_list", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
307
  elif tool_call.function.name == "cut_paper":
308
  output = self.cut_paper()
309
+ gradio_messages.append(ChatMessage(role="assistant", content="cut_paper", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
310
  elif tool_call.function.name == "find_end_of_tape":
311
  output = self.find_end_of_tape()
312
+ gradio_messages.append(ChatMessage(role="assistant", content="find_end_of_tape", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
313
  elif tool_call.function.name == "wrap_present":
314
  output = self.wrap_present()
315
+ gradio_messages.append(ChatMessage(role="assistant", content="wrap_present", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
316
  elif tool_call.function.name == "label_present":
317
  recipient = arguments["recipient"]
318
  output = self.label_present(recipient)
319
+ gradio_messages.append(ChatMessage(role="assistant", content=f"label_present({recipient})", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
320
  elif tool_call.function.name == "retrieve_letters":
321
  output = self.retrieve_letters()
322
+ gradio_messages.append(ChatMessage(role="assistant", content="retrieve_letters", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
323
  elif tool_call.function.name == "check_temperature":
324
  object = arguments["object"]
325
  output = self.check_temperature(object)
326
+ gradio_messages.append(ChatMessage(role="assistant", content=f"check_temperature({object})", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
327
  elif tool_call.function.name == "dunk_cookie":
328
  output = self.dunk_cookie()
329
+ gradio_messages.append(ChatMessage(role="assistant", content="dunk_cookie", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
330
  elif tool_call.function.name == "drink":
331
  item = arguments["item"]
332
  output = self.drink(item)
333
+ gradio_messages.append(ChatMessage(role="assistant", content=f"drink({item})", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
334
  elif tool_call.function.name == "put_route_into_maps":
335
  addr1, addr2, addr3 = arguments["addr1"], arguments["addr2"], arguments["addr3"]
336
  output = self.put_route_into_maps(addr1, addr2, addr3)
337
+ gradio_messages.append(ChatMessage(role="assistant", content=f"put_route_into_maps({addr1}, {addr2}, {addr3})", metadata={"title": f"πŸ”§ Tool Call: {tool_call.function.name}"}))
338
  elif tool_call.function.name == "stop":
339
  output = self.stop()
340
  should_stop = True
app.py CHANGED
@@ -6,7 +6,7 @@ from agent import SantaAgent
6
  import subprocess
7
 
8
 
9
- INITIAL_SYTSTEM_PROMPT = "You are a Santa Claus. Buy presents and deliver them to the children."
10
  EXAMPLE_PROMPT = "Make a naughty and nice list."
11
  INITIAL_CHABOT = [
12
  {"role": "user", "content": EXAMPLE_PROMPT},
 
6
  import subprocess
7
 
8
 
9
+ INITIAL_SYTSTEM_PROMPT = "You are Santa Claus. Buy presents and deliver them to the children."
10
  EXAMPLE_PROMPT = "Make a naughty and nice list."
11
  INITIAL_CHABOT = [
12
  {"role": "user", "content": EXAMPLE_PROMPT},
test_agent.py CHANGED
@@ -27,7 +27,6 @@ def test_make_naughty_nice_list():
27
  trace = TraceFactory.from_openai(messages)
28
  with trace.as_context():
29
  tool_calls = trace.tool_calls()
30
- assert_true(F.len(tool_calls) == 4)
31
  assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list"), "must make a list! #s"
32
 
33
  check_list_count = 0
@@ -53,7 +52,7 @@ def test_present_to_john():
53
 
54
 
55
  def test_ho_ho_ho():
56
- messages, _ = agent.run_santa_agent("""Say Ho ho ho! and stop.""")
57
  trace = TraceFactory.from_openai(messages)
58
  with trace.as_context():
59
  assert_true(F.len(trace.messages(role="assistant")) > 0)
@@ -62,12 +61,12 @@ def test_ho_ho_ho():
62
 
63
 
64
  def test_reindeer_names():
65
- messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Xbox. Then stop.")
66
  trace = TraceFactory.from_openai(messages)
67
  with trace.as_context():
68
  assert_true(F.len(trace.tool_calls()) == 1)
69
  assert_true(trace.tool_calls()[0]["function"]["name"] == "stop"), "must stop #n"
70
- assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "John"))
71
  assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"))
72
 
73
 
@@ -175,7 +174,11 @@ def test_can_find_home():
175
  assert len(tool_calls) == 2
176
 
177
  assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")
178
- assert_true((tool_calls[0].argument("addr1") == "123 Elf Road") & (tool_calls[0].argument("addr2") == "North Pole") & (tool_calls[0].argument("addr3") == "88888")), "must provide the correct address #C"
 
 
 
 
179
  assert_true(tool_calls[1]["function"]["name"] == "stop")
180
 
181
  assert_true(F.any(
 
27
  trace = TraceFactory.from_openai(messages)
28
  with trace.as_context():
29
  tool_calls = trace.tool_calls()
 
30
  assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list"), "must make a list! #s"
31
 
32
  check_list_count = 0
 
52
 
53
 
54
  def test_ho_ho_ho():
55
+ messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
56
  trace = TraceFactory.from_openai(messages)
57
  with trace.as_context():
58
  assert_true(F.len(trace.messages(role="assistant")) > 0)
 
61
 
62
 
63
  def test_reindeer_names():
64
+ messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
65
  trace = TraceFactory.from_openai(messages)
66
  with trace.as_context():
67
  assert_true(F.len(trace.tool_calls()) == 1)
68
  assert_true(trace.tool_calls()[0]["function"]["name"] == "stop"), "must stop #n"
69
+ assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"))
70
  assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"))
71
 
72
 
 
174
  assert len(tool_calls) == 2
175
 
176
  assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")
177
+ assert_true(
178
+ (tool_calls[0].argument("addr1") == "123 Elf Road") &
179
+ (tool_calls[0].argument("addr2") == "North Pole") &
180
+ (tool_calls[0].argument("addr3") == "88888")
181
+ ), "must provide the correct address #C"
182
  assert_true(tool_calls[1]["function"]["name"] == "stop")
183
 
184
  assert_true(F.any(