wip
Browse files- agent.py +1 -1
- app.py +61 -29
- test_agent.py +12 -0
agent.py
CHANGED
@@ -50,7 +50,7 @@ class SantaAgent:
|
|
50 |
"type": "function",
|
51 |
"function": {
|
52 |
"name": "stop",
|
53 |
-
"description": "
|
54 |
}
|
55 |
}
|
56 |
]
|
|
|
50 |
"type": "function",
|
51 |
"function": {
|
52 |
"name": "stop",
|
53 |
+
"description": "Use this tool if you are finished and want to stop."
|
54 |
}
|
55 |
}
|
56 |
]
|
app.py
CHANGED
@@ -2,74 +2,106 @@ import json
|
|
2 |
import os
|
3 |
import re
|
4 |
import invariant.testing.functional as F
|
5 |
-
from invariant.testing import TraceFactory, assert_true
|
6 |
-
from agent import SantaAgent
|
7 |
import gradio as gr
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
messages, gradio_messages = agent.run_santa_agent(prompt)
|
15 |
-
|
16 |
-
if not invariant_api_key.startswith("inv"):
|
17 |
-
return gradio_messages, "", "Please enter a valid Invariant API key to get the score!"
|
18 |
|
19 |
-
|
20 |
|
|
|
21 |
env={
|
22 |
"INVARIANT_API_KEY": invariant_api_key,
|
23 |
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
|
24 |
"PATH": os.environ["PATH"]
|
25 |
}
|
26 |
-
# env = None
|
27 |
-
|
28 |
-
# run command invariant test test_agent.py --agent-params '{"system_prompt": "you are santa"}'
|
29 |
import subprocess
|
30 |
out = subprocess.run([
|
31 |
"invariant", "test", "test_agent.py",
|
32 |
"--agent-params", json.dumps(agent_params),
|
33 |
"--push", "--dataset_name", "santa_agent",
|
34 |
], capture_output=True, text=True, env=env)
|
35 |
-
|
36 |
url = re.search(r"https://explorer.invariantlabs.ai/[\-_a-zA-Z0-9/]+", out.stdout).group(0)
|
37 |
-
|
38 |
-
message = "Please find your results at: " + url
|
39 |
-
return gradio_messages, "", message
|
40 |
-
|
41 |
|
42 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
43 |
gr.Markdown("""
|
44 |
## Prompt the Santa Agent
|
45 |
-
* Find a system prompt that
|
46 |
""")
|
47 |
-
input = gr.Textbox(lines=1, label="""System Prompt""", value=
|
48 |
with gr.Row():
|
49 |
with gr.Column(scale=2):
|
50 |
chatbot = gr.Chatbot(
|
51 |
type="messages",
|
52 |
label="Example interaction",
|
53 |
-
value=
|
54 |
-
{"role": "user", "content": "Could you please deliver Xbox to John?"},
|
55 |
-
],
|
56 |
avatar_images=[
|
57 |
None,
|
58 |
"https://invariantlabs.ai/theme/images/logo.svg"
|
59 |
],
|
60 |
)
|
61 |
with gr.Column(scale=1):
|
62 |
-
console = gr.
|
63 |
|
64 |
invariant_api_key = gr.Textbox(lines=1, label="""Invariant API Key - you can play without it, but to obtain full score please register and get the key at https://explorer.invariantlabs.ai/settings""")
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
input.submit(lambda: gr.update(visible=False), None, [input])
|
67 |
|
68 |
-
# Submit button
|
69 |
submit = gr.Button("Submit")
|
70 |
-
submit.click(
|
|
|
71 |
submit.click(lambda: gr.update(visible=False), None, [input])
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
if __name__ == "__main__":
|
75 |
demo.launch()
|
|
|
2 |
import os
|
3 |
import re
|
4 |
import invariant.testing.functional as F
|
|
|
|
|
5 |
import gradio as gr
|
6 |
+
from agent import SantaAgent
|
7 |
|
8 |
+
INITIAL_SYTSTEM_PROMPT = "You are a Santa Claus. Buy presents and deliver them to the children."
|
9 |
+
INITIAL_CHABOT = [
|
10 |
+
{"role": "user", "content": "Could you please deliver Xbox to John?"},
|
11 |
+
]
|
12 |
+
INITIAL_STATE = ""
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
agent = SantaAgent(INITIAL_SYTSTEM_PROMPT)
|
15 |
|
16 |
+
def run_testing(agent_params, invariant_api_key):
|
17 |
env={
|
18 |
"INVARIANT_API_KEY": invariant_api_key,
|
19 |
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
|
20 |
"PATH": os.environ["PATH"]
|
21 |
}
|
|
|
|
|
|
|
22 |
import subprocess
|
23 |
out = subprocess.run([
|
24 |
"invariant", "test", "test_agent.py",
|
25 |
"--agent-params", json.dumps(agent_params),
|
26 |
"--push", "--dataset_name", "santa_agent",
|
27 |
], capture_output=True, text=True, env=env)
|
|
|
28 |
url = re.search(r"https://explorer.invariantlabs.ai/[\-_a-zA-Z0-9/]+", out.stdout).group(0)
|
29 |
+
return url
|
|
|
|
|
|
|
30 |
|
31 |
with gr.Blocks() as demo:
|
32 |
+
# Add state at the beginning of the Blocks
|
33 |
+
results_state = gr.State(INITIAL_STATE)
|
34 |
+
|
35 |
gr.Markdown("""
|
36 |
## Prompt the Santa Agent
|
37 |
+
* Find a system prompt that passes all the tests
|
38 |
""")
|
39 |
+
input = gr.Textbox(lines=1, label="""System Prompt""", value=INITIAL_SYTSTEM_PROMPT)
|
40 |
with gr.Row():
|
41 |
with gr.Column(scale=2):
|
42 |
chatbot = gr.Chatbot(
|
43 |
type="messages",
|
44 |
label="Example interaction",
|
45 |
+
value=INITIAL_CHABOT,
|
|
|
|
|
46 |
avatar_images=[
|
47 |
None,
|
48 |
"https://invariantlabs.ai/theme/images/logo.svg"
|
49 |
],
|
50 |
)
|
51 |
with gr.Column(scale=1):
|
52 |
+
console = gr.Button(value="Console Output", visible=False)
|
53 |
|
54 |
invariant_api_key = gr.Textbox(lines=1, label="""Invariant API Key - you can play without it, but to obtain full score please register and get the key at https://explorer.invariantlabs.ai/settings""")
|
55 |
+
|
56 |
+
def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_example=False):
|
57 |
+
# messages, gradio_messages = agent.run_santa_agent(prompt)
|
58 |
+
gradio_messages = [
|
59 |
+
{"role": "user", "content": "Could you please deliver Xbox to John?"},
|
60 |
+
{"role": "assistant", "content": "I'm sorry, but I can't deliver presents. I'm just a chatbot."},
|
61 |
+
]
|
62 |
+
|
63 |
+
if not invariant_api_key.startswith("inv"):
|
64 |
+
return gradio_messages, "", "Please enter a valid Invariant API key to get the score!", state
|
65 |
+
|
66 |
+
agent_params = {"system_prompt": user_prompt}
|
67 |
+
|
68 |
+
return gradio_messages, "", "Testing in progress...", [agent_params, invariant_api_key]
|
69 |
+
|
70 |
+
def update_console(state):
|
71 |
+
if type(state) == list:
|
72 |
+
agent_params, invariant_api_key = state[0], state[1]
|
73 |
+
return gr.update(value="Testing in progress...", interactive=False, visible=True), (agent_params, invariant_api_key)
|
74 |
+
if type(state) == tuple:
|
75 |
+
agent_params, invariant_api_key = state
|
76 |
+
url = run_testing(agent_params, invariant_api_key)
|
77 |
+
return gr.update(value="Open results", link=url, visible=True, interactive=True), url
|
78 |
+
if type(state) == str and state.startswith("https"):
|
79 |
+
return gr.update(value="Open results", link=state, visible=True, interactive=True), state
|
80 |
+
return gr.update(value="Testing in progress..."), state
|
81 |
+
|
82 |
+
input.submit(run_agent_with_state, [input, chatbot, invariant_api_key, results_state],
|
83 |
+
[chatbot, input, console, results_state])
|
84 |
input.submit(lambda: gr.update(visible=False), None, [input])
|
85 |
|
|
|
86 |
submit = gr.Button("Submit")
|
87 |
+
submit.click(run_agent_with_state, [input, chatbot, invariant_api_key, results_state],
|
88 |
+
[chatbot, input, console, results_state])
|
89 |
submit.click(lambda: gr.update(visible=False), None, [input])
|
90 |
|
91 |
+
reset = gr.Button("Reset")
|
92 |
+
def reset_state():
|
93 |
+
return (
|
94 |
+
gr.update(value=INITIAL_SYTSTEM_PROMPT, visible=True), # input
|
95 |
+
INITIAL_CHABOT, # chatbot
|
96 |
+
INITIAL_STATE, # results_state
|
97 |
+
gr.update(visible=False), # console
|
98 |
+
)
|
99 |
+
|
100 |
+
reset.click(reset_state, None, [input, chatbot, results_state, console])
|
101 |
+
|
102 |
+
timer = gr.Timer(value=1.0)
|
103 |
+
timer.tick(update_console, results_state, [console, results_state])
|
104 |
+
|
105 |
|
106 |
if __name__ == "__main__":
|
107 |
demo.launch()
|
test_agent.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from invariant.testing import TraceFactory, get_agent_param, assert_true
|
2 |
from agent import SantaAgent
|
3 |
|
@@ -20,5 +22,15 @@ def test_ho_ho_ho():
|
|
20 |
messages, _ = agent.run_santa_agent("""Just reply with: "Ho ho ho!" and stop""")
|
21 |
trace = TraceFactory.from_openai(messages)
|
22 |
with trace.as_context():
|
|
|
23 |
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"))
|
24 |
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import invariant.testing.functional as F
|
2 |
+
import pytest
|
3 |
from invariant.testing import TraceFactory, get_agent_param, assert_true
|
4 |
from agent import SantaAgent
|
5 |
|
|
|
22 |
messages, _ = agent.run_santa_agent("""Just reply with: "Ho ho ho!" and stop""")
|
23 |
trace = TraceFactory.from_openai(messages)
|
24 |
with trace.as_context():
|
25 |
+
assert_true(F.len(trace.messages(role="assistant")) > 0)
|
26 |
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"))
|
27 |
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
|
28 |
+
|
29 |
+
|
30 |
+
# @pytest.mark.parametrize("country", ["Finland", "Iceland"])
|
31 |
+
# def test_cities(country):
|
32 |
+
# messages, _ = agent.run_santa_agent(f"""Write a Christmas song that mentions exactly 5 cities in {country}.""")
|
33 |
+
# trace = TraceFactory.from_openai(messages)
|
34 |
+
# with trace.as_context():
|
35 |
+
# cities = trace.messages(role="assistant")[0]["content"].extract(f"cities in {country}")
|
36 |
+
# assert_true(F.len(cities) == 5)
|