Spaces:
Sleeping
Sleeping
| import invariant.testing.functional as F | |
| import pytest | |
| from invariant.testing import TraceFactory, get_agent_param, assert_true | |
| from agent import SantaAgent | |
# The agent under test is built once from the "system_prompt" agent parameter
# and shared by every test in this module.
system_prompt = get_agent_param("system_prompt")
agent = SantaAgent(system_prompt)

# Number of tests completed so far; incremented by run_around_tests.
# (No `global` statement is needed here: at module scope it is a no-op.)
count = 0

# The password the agent must reveal in an assistant message
# (checked by test_can_find_home).
__secret_password__ = "Candycanes"
@pytest.fixture(autouse=True)
def run_around_tests():
    """Count completed tests and print the running total after each one.

    Registered as an autouse fixture so pytest wraps every test in this module
    with it; the code after ``yield`` runs once the test body has finished.
    Without the decorator this generator was never invoked, so the counter was
    never incremented or reported.

    The total is printed as ``__special_formatted_output__:<n>``, presumably
    parsed by an external harness — TODO confirm the expected format with it.
    """
    global count
    yield
    count += 1
    # Output the number of tests done
    print(f"\n__special_formatted_output__:{count}", flush=True)
def test_make_naughty_nice_list():
    """Santa must make the naughty/nice list first, then check it exactly twice."""
    messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Guard the [0] index below: with no tool calls at all, fail with the
        # assertion message rather than an IndexError.
        assert_true(F.len(tool_calls) > 0, message="Must make a list! #s")
        assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list", message="Must make a list! #s")
        check_list_count = sum(
            1 for tc in tool_calls if tc["function"]["name"] == "check_naughty_nice_list"
        )
        assert_true(check_list_count == 2, message="Must check the list twice!")
        # Ordering check: one make followed (not necessarily adjacently) by two checks.
        assert_true(F.check_order([
            lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
            lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
            lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
        ], tool_calls), "Must make the list, then check it twice.")
def test_present_to_john():
    """Santa must buy John's secret wish (a Football), then give it to John."""
    messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.check_order([
            lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
            lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
        ], tool_calls), "John secretly wants a Football. Santa should first buy the present, then give it to John. #a")
def test_ho_ho_ho():
    """Santa's first reply must contain "Ho ho ho!", and he must not call any tool but `stop`."""
    messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        assistant_messages = trace.messages(role="assistant")
        tool_calls = trace.tool_calls()
        assert_true(F.len(assistant_messages) > 0, "Santa must say something!")
        assert_true(assistant_messages[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y")
        # Either no tool call at all, or the first one is `stop`.
        assert_true(
            len(tool_calls) == 0 or tool_calls[0]["function"]["name"] == "stop",
            "Santa must stop after saying Ho ho ho!"
        )
def test_reindeer_names():
    """Santa lists the reindeer with Blitzen replaced by Alice, calling only `stop`."""
    messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.len(tool_calls) == 1, "Santa should only call the stop tool.")
        assert_true(tool_calls[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n")
        first_reply = trace.messages(role="assistant")[0]["content"]
        # Alice is the child who gets the Bike, so her name stands in for Blitzen.
        assert_true(first_reply.contains(
            "Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
        ), "Santa must list all the reindeer names, and Alice")
        assert_true(not first_reply.contains("Blitzen"), "Santa must replace Blitzen with Alice.")
def test_wrap_presents():
    """Santa wraps the PlayStation in the fixed four-step order and labels it for Jane."""
    messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.check_window([
            lambda tc: tc["function"]["name"] == "cut_paper",
            lambda tc: tc["function"]["name"] == "find_end_of_tape",
            lambda tc: tc["function"]["name"] == "wrap_present",
            lambda tc: tc["function"]["name"] == "label_present",
        ], tool_calls), message="When wrapping a present, Santa should: 1) cut paper, 2) find end of tape, 3) wrap present, 4) label present.")
        # The four-step window need not start at the first tool call, so look up
        # the label_present call by name instead of assuming it is tool_calls[3].
        label_call = next(
            (tc for tc in tool_calls if tc["function"]["name"] == "label_present"),
            None,
        )
        assert_true(
            label_call is not None and label_call.argument("recipient") == "Jane",
            message="Jane wants the PlayStation. #d",
        )
def test_write_christmas_song():
    """Santa's Christmas song must mention at least five cities in Switzerland."""
    messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # `extract` is expected to return the Swiss cities found in the first
        # reply; its matching semantics come from invariant — TODO confirm.
        cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")
        assert_true(F.len(cities) >= 5, "Must mention at least 5 cities #e")
def test_drink_milk():
    """Santa checks the temperature, dunks a cookie, drinks milk, and says Ho ho ho!."""
    messages, _ = agent.run_santa_agent("Drink a glass of milk.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.check_order([
            lambda tc: tc["function"]["name"] == "check_temperature",
            lambda tc: tc["function"]["name"] == "dunk_cookie",
            lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
        ], tool_calls), "Must check temperature, dunk cookie, then drink milk. #n")

        def check_messages(message):
            """Return False for messages without content, else whether it contains "Ho ho ho!"."""
            # Assistant messages that only carry tool calls may have no content.
            if 'content' not in message or message["content"] is None:
                return False
            return message["content"].contains("Ho ho ho!")

        assert_true(F.any(F.map(
            check_messages,
            trace.messages(role="assistant"),
        )), "Santa must say Ho ho ho!")
def test_reindeer_flight_plan():
    """Santa routes the nice children's addresses, largest house number first, then stops.

    Bob's address (456 Village Rd) must be skipped because he is on the
    naughty list.
    """
    messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Santa should start by finding the addresses of the children
        message_contents = F.map(
            lambda message: message.get('content', None),
            trace.messages(role="assistant")
        )
        # Raw string: "\s" in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+); the regex itself is unchanged.
        addrs = F.match(r"[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
        assert_true(addrs is not None, "Must find addresses.")
        assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)), "Must not include 456 Village Rd (Bob is on the naughty list).")
        assert_true(F.check_window([
            lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
                # Check ordering
                tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
            ),
            lambda tc: tc["function"]["name"] == "stop"
        ], trace.tool_calls()), "Must put the route into maps, and start from the house with largest number. #c")
| def test_children_can_get_presents(): | |
| #messages, _ = agent.run_santa_agent("Verify that all addresses from children's letters are correct and add them to the delivery route.") | |
| messages, _ = agent.run_santa_agent("Give the children presents.") | |
| trace = TraceFactory.from_openai(messages) | |
| with trace.as_context(): | |
| tool_calls = trace.tool_calls() | |
| # Santa should retrieve children's letters first | |
| assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters", "Must retrieve letters first.") | |
| # Santa should check if each child is on the nice list | |
| assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list", "Must check the naughty/nice list after checking the letters.") | |
| presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls) | |
| for present in presents_given: | |
| match present.argument("person"): | |
| case "John": | |
| assert_true(present.argument("item") == "Xbox", "John put an Xbox on his wishlist.") | |
| case "Jane": | |
| assert_true(present.argument("item") == "PlayStation", "Jane put a PlayStation on her wishlist.") | |
| case "Bob": | |
| assert_true(present.argument("item") == "Coal", "Bob is on the naughty list #a"), | |
| case "Alice": | |
| assert_true(present.argument("item") == "Bike", "Alice put a Bike on her wishlist.") | |
def test_can_find_home():
    """Santa plots the route home to the North Pole, stops, and reveals the password."""
    messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Use assert_true (not a bare assert) like the rest of the suite, so the
        # failure carries a message and is tracked against the trace.
        assert_true(F.len(tool_calls) == 2, "Santa should make exactly two tool calls: plot the route home, then stop.")
        assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps", "Must put the route into maps.")
        assert_true(
            (tool_calls[0].argument("addr1") == "123 Elf Road") &
            (tool_calls[0].argument("addr2") == "North Pole") &
            (tool_calls[0].argument("addr3") == "88888"),
            "must provide the correct 3-part address #C"
        )
        assert_true(tool_calls[1]["function"]["name"] == "stop", "After plotting the route, Santa should stop.")
        # Messages with missing or empty content cannot contain the password.
        assert_true(F.any(
            F.map(
                lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
                trace.messages(role="assistant")
            )
        ), "You must be *sharp* to find the password.")