# christmas-challenge / test_agent.py
# kn404's picture
# limit non-toolcall messages, better assertion messages
# f448621
import invariant.testing.functional as F
import pytest
from invariant.testing import TraceFactory, get_agent_param, assert_true
from agent import SantaAgent
# Read the agent parameter named "system_prompt"; it configures the agent under test.
system_prompt = get_agent_param("system_prompt")
# A single SantaAgent instance is shared by every test in this module.
agent = SantaAgent(system_prompt)

# Number of tests finished so far; the autouse fixture run_around_tests
# increments and prints it. (A `global` statement is a no-op at module level,
# so a plain assignment is all that is needed here.)
count = 0

# Password that test_can_find_home looks for in the assistant messages.
__secret_password__ = "Candycanes"
@pytest.fixture(autouse=True)
def run_around_tests():
    """Wrap every test: once it finishes, bump and print the running test count."""
    global count
    # No setup work; hand control to the test immediately.
    yield
    # Teardown: record one more completed test.
    count = count + 1
    # Output the number of tests done
    print(f"\n__special_formatted_output__:{count}", flush=True)
def test_make_naughty_nice_list():
    """Santa must make the naughty/nice list first, then check it exactly twice."""
    messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # Guard the index so a run without any tool call fails with the
        # intended message instead of an IndexError.
        assert_true(
            len(tool_calls) > 0
            and tool_calls[0]["function"]["name"] == "make_naughty_nice_list",
            message="Must make a list! #s",
        )

        # Count how many tool calls checked the list.
        check_list_count = sum(
            1
            for tc in tool_calls
            if tc["function"]["name"] == "check_naughty_nice_list"
        )
        assert_true(check_list_count == 2, message="Must check the list twice!")

        # NOTE(review): F.check_order is expected to verify that calls matching
        # these predicates occur in this order -- confirm the exact matching
        # rules against the Invariant testing documentation.
        assert_true(
            F.check_order(
                [
                    lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
                    lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
                    lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
                ],
                tool_calls,
            ),
            "Must make the list, then check it twice.",
        )
def test_present_to_john():
    """Santa must buy a Football and afterwards give that Football to John."""

    def buys_football(tc):
        # A buy_item call whose item is the Football.
        return tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football"

    def gives_football_to_john(tc):
        # A give_present call handing the Football to John.
        return (
            tc['function']['name'] == 'give_present'
            and tc.argument("person") == "John"
            and tc.argument("item") == "Football"
        )

    messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        calls = trace.tool_calls()
        bought_then_given = F.check_order(
            [buys_football, gives_football_to_john],
            calls,
        )
        assert_true(
            bought_then_given,
            "John secretly wants a Football. Santa should first buy the present, then give it to John. #a",
        )
def test_ho_ho_ho():
    """Santa must answer with "Ho ho ho!" and then either call nothing or call stop."""
    messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Fetch the assistant messages once and reuse them for both checks.
        assistant_messages = trace.messages(role="assistant")
        has_reply = F.len(assistant_messages) > 0
        assert_true(has_reply, "Santa must say something!")

        # In case the assertion above does not abort the test on failure,
        # guard the index so an empty reply list reports the message below
        # rather than raising IndexError.
        assert_true(
            has_reply and assistant_messages[0]["content"].contains("Ho ho ho!"),
            "Santa must say Ho ho ho! #y",
        )

        # Either no tool was called at all, or the first call is "stop".
        tool_calls = trace.tool_calls()
        assert_true(
            len(tool_calls) == 0 or tool_calls[0]["function"]["name"] == "stop",
            "Santa must stop after saying Ho ho ho!",
        )
def test_reindeer_names():
    """Santa must list the reindeer with Blitzen replaced by Alice, calling only stop."""
    messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Exactly one tool call is allowed, and it has to be "stop".
        tool_calls = trace.tool_calls()
        assert_true(F.len(tool_calls) == 1, "Santa should only call the stop tool.")
        # Guarded index: with no tool calls this reports the message below
        # instead of raising IndexError.
        assert_true(
            len(tool_calls) > 0 and tool_calls[0]["function"]["name"] == "stop",
            "Santa should only call the stop tool. #n",
        )

        # Both content checks look at the first assistant message; fetch the
        # messages once and guard against there being none.
        assistant_messages = trace.messages(role="assistant")
        has_reply = F.len(assistant_messages) > 0
        assert_true(
            has_reply
            and assistant_messages[0]["content"].contains(
                "Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
            ),
            "Santa must list all the reindeer names, and Alice",
        )
        assert_true(
            has_reply and not assistant_messages[0]["content"].contains("Blitzen"),
            "Santa must replace Blitzen with Alice.",
        )
def test_wrap_presents():
    """Santa must cut paper, find the tape end, wrap and label the present for Jane."""
    messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # NOTE(review): F.check_window is expected to look for these four calls
        # as one consecutive run somewhere in the tool calls -- confirm against
        # the Invariant testing documentation.
        assert_true(
            F.check_window(
                [
                    lambda tc: tc["function"]["name"] == "cut_paper",
                    lambda tc: tc["function"]["name"] == "find_end_of_tape",
                    lambda tc: tc["function"]["name"] == "wrap_present",
                    lambda tc: tc["function"]["name"] == "label_present",
                ],
                tool_calls,
            ),
            message="When wrapping a present, Santa should: 1) cut paper, 2) find end of tape, 3) wrap present, 4) label present.",
        )

        # Look for the label_present call by name rather than at a fixed
        # position: indexing tool_calls[3] raised IndexError with fewer than
        # four calls and inspected the wrong call whenever anything preceded
        # cut_paper. The name test comes first so argument("recipient") is only
        # read on label_present calls.
        assert_true(
            F.any(
                F.map(
                    lambda tc: tc["function"]["name"] == "label_present"
                    and tc.argument("recipient") == "Jane",
                    tool_calls,
                )
            ),
            message="Jane wants the PlayStation. #d",
        )
def test_write_christmas_song():
    """The first assistant message must be a song naming at least five Swiss cities."""
    messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        first_reply = trace.messages(role="assistant")[0]
        song_text = first_reply["content"]
        # NOTE(review): extract() presumably returns the items in the text that
        # match the description -- confirm in the Invariant testing docs.
        cities = song_text.extract("cities in Switzerland")
        enough_cities = F.len(cities) >= 5
        assert_true(enough_cities, "Must mention at least 5 cities #e")
def test_drink_milk():
    """Santa must check the temperature, dunk a cookie, drink milk, and say "Ho ho ho!"."""

    def check_messages(message):
        # Messages without text content cannot contain the phrase.
        has_text = 'content' in message and message["content"] is not None
        if not has_text:
            return False
        return message["content"].contains("Ho ho ho!")

    messages, _ = agent.run_santa_agent("Drink a glass of milk.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # NOTE(review): the last predicate combines its two comparisons with
        # `&` rather than `and`; kept exactly as written.
        expected_order = [
            lambda tc: tc["function"]["name"] == "check_temperature",
            lambda tc: tc["function"]["name"] == "dunk_cookie",
            lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
        ]
        assert_true(
            F.check_order(expected_order, tool_calls),
            "Must check temperature, dunk cookie, then drink milk. #n",
        )

        # At least one assistant message has to contain the phrase.
        phrase_hits = F.map(check_messages, trace.messages(role="assistant"))
        assert_true(F.any(phrase_hits), "Santa must say Ho ho ho!")
def test_reindeer_flight_plan():
    """Santa must find the addresses, skip 456 Village Rd, and map the route largest-number first."""
    messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Santa should start by finding the addresses of the children
        message_contents = F.map(
            lambda message: message.get('content', None),
            trace.messages(role="assistant"),
        )
        # Raw string literal: "\s" inside a normal string is an invalid escape
        # sequence (a warning today, slated to become an error). The pattern
        # value itself is unchanged.
        addrs = F.match(r"[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
        assert_true(addrs is not None, "Must find addresses.")
        assert_true(
            not F.any(F.map(lambda message: message.startswith('456'), addrs)),
            "Must not include 456 Village Rd (Bob is on the naughty list).",
        )

        assert_true(
            F.check_window(
                [
                    lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
                        # Check ordering
                        tc.argument("addr1") == "101112 Village Rd"
                        and tc.argument("addr2") == "789 Village Rd"
                        and tc.argument("addr3") == "123 Village Rd"
                    ),
                    lambda tc: tc["function"]["name"] == "stop",
                ],
                trace.tool_calls(),
            ),
            "Must put the route into maps, and start from the house with largest number. #c",
        )
def test_children_can_get_presents():
#messages, _ = agent.run_santa_agent("Verify that all addresses from children's letters are correct and add them to the delivery route.")
messages, _ = agent.run_santa_agent("Give the children presents.")
trace = TraceFactory.from_openai(messages)
with trace.as_context():
tool_calls = trace.tool_calls()
# Santa should retrieve children's letters first
assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters", "Must retrieve letters first.")
# Santa should check if each child is on the nice list
assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list", "Must check the naughty/nice list after checking the letters.")
presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
for present in presents_given:
match present.argument("person"):
case "John":
assert_true(present.argument("item") == "Xbox", "John put an Xbox on his wishlist.")
case "Jane":
assert_true(present.argument("item") == "PlayStation", "Jane put a PlayStation on her wishlist.")
case "Bob":
assert_true(present.argument("item") == "Coal", "Bob is on the naughty list #a"),
case "Alice":
assert_true(present.argument("item") == "Bike", "Alice put a Bike on her wishlist.")
def test_can_find_home():
    """Santa must map the 3-part home address, stop, and mention the secret password."""
    messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # A plain `assert` ends the test right here when the count is wrong,
        # which keeps the tool_calls[0] / tool_calls[1] accesses below safe.
        assert len(tool_calls) == 2, (
            "Santa must make exactly two tool calls: put_route_into_maps, then stop."
        )

        assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps", "Must put the route into maps.")
        # NOTE(review): the three comparisons are combined with `&` rather than
        # `and`; kept exactly as written.
        assert_true(
            (tool_calls[0].argument("addr1") == "123 Elf Road") &
            (tool_calls[0].argument("addr2") == "North Pole") &
            (tool_calls[0].argument("addr3") == "88888"),
            "must provide the correct 3-part address #C",
        )
        assert_true(tool_calls[1]["function"]["name"] == "stop", "After plotting the route, Santa should stop.")

        # At least one assistant message with content must contain the password.
        assert_true(
            F.any(
                F.map(
                    lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
                    trace.messages(role="assistant"),
                )
            ),
            "You must be *sharp* to find the password.",
        )