Spaces:
Sleeping
Sleeping
import json | |
import os | |
import re | |
import gradio as gr | |
from agent import SantaAgent | |
import subprocess | |
# Default system prompt shown in the "System Prompt" textbox and used to build `agent`.
INITIAL_SYTSTEM_PROMPT = "You are Santa Claus. Buy presents and deliver them to the children."
# User message the example interaction is run with.
EXAMPLE_PROMPT = "Make a naughty and nice list."
# Initial contents of the "Example interaction" chatbot (also restored on reset).
INITIAL_CHABOT = [
    {"role": "user", "content": EXAMPLE_PROMPT},
]
# Initial (empty) Explorer results URL.
INITIAL_STATE = ""
# Total number of tests reported in the progress label; presumably matches the
# number of tests in test_agent.py — keep the two in sync.
TOTAL_TESTS = 10

agent = SantaAgent(INITIAL_SYTSTEM_PROMPT)

# Load css from styling.css located next to this file, so the app works no
# matter which directory it is launched from.
_CSS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "styling.css")
with open(_CSS_PATH, "r", encoding="utf-8") as f:
    css = f.read()
# Define helper functions
def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_example=False):
    """Run the Santa Agent on EXAMPLE_PROMPT and return the chat transcript.

    Used as a Gradio handler whose single output is the example chatbot.

    Args:
        user_prompt: Contents of the System Prompt textbox. NOTE(review):
            currently unused. The module-level ``agent`` was built with
            INITIAL_SYTSTEM_PROMPT, so the example transcript does not reflect
            an edited prompt. Confirm whether that is intended.
        history: Current chatbot value (unused).
        invariant_api_key: Invariant API key (unused here; run_testing is the
            handler that validates it).
        state: Current results-URL state (unused).
        is_example: Unused; kept for backward compatibility.

    Returns:
        The ``gradio_messages`` value returned by ``agent.run_santa_agent``.
    """
    _messages, gradio_messages = agent.run_santa_agent(EXAMPLE_PROMPT)
    return gradio_messages
def update_json_url(url):
    """Wrap *url* in the ``{"url": ...}`` payload for the hidden JSON component.

    Args:
        url: Explorer results URL, or an empty string when there is none.

    Returns:
        A ``gr.update`` whose value is the JSON-encoded ``{"url": url}`` string.
    """
    payload = {"url": url}
    encoded = json.dumps(payload)
    return gr.update(value=encoded)
# Explorer results link printed in the test output once results are pushed.
EXPLORER_URL_RE = re.compile(r"https://explorer\.invariantlabs\.ai/[\-_a-zA-Z0-9/]+")
# Progress marker lines look like "__special_formatted_output__:<test number>".
PROGRESS_PREFIX = "__special_formatted_output__:"


def run_testing(user_prompt, invariant_api_key):
    """Run the test suite for *user_prompt* in a subprocess and stream progress.

    This is a generator used as a Gradio event handler. Every yielded value is
    a 3-tuple ``(button_label, results_url, button_css_class)``.

    Args:
        user_prompt: System prompt to test. It is passed to
            ``invariant test test_agent.py`` as
            ``--agent-params '{"system_prompt": ...}'``.
        invariant_api_key: Invariant API key. It must be non-empty and start
            with ``"inv"``.

    Yields:
        - A warning state if the key is invalid.
        - Otherwise, one progress state per test.
        - The Explorer URL once it appears in the output.
        - A failure state if the run ends without producing a URL.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` or ``PATH`` is not set in the environment.
    """
    if not invariant_api_key or not invariant_api_key.startswith("inv"):
        gr.Warning("Please enter a valid Invariant API key to run all the tests!")
        # This is a generator, so a `return <value>` would be discarded by
        # Gradio; the message has to be yielded to reach the UI.
        yield "Please enter a valid Invariant API key to get the score!", '', 'toggled-off-button'
        return

    agent_params = {"system_prompt": user_prompt}
    yield f'Running Test 1 of {TOTAL_TESTS}. Please wait.', '', 'button-loading'

    env = {
        "INVARIANT_API_KEY": invariant_api_key,
        "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
        "PATH": os.environ["PATH"],
        # A Python child block-buffers stdout when it is a pipe. Disable that so
        # progress lines arrive as they are printed. The CLI is presumably
        # Python (it runs pytest); the variable is harmless otherwise.
        "PYTHONUNBUFFERED": "1",
    }
    cmd = [
        "invariant", "test", "test_agent.py",
        "--agent-params", json.dumps(agent_params),
        "--push", "--dataset_name", "santa_agent", '-s',
    ]
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=0,
        env=env,
    )

    found_url = False
    completed = False
    try:
        # Iterate over the output lines as they are produced.
        for line in process.stdout:
            print(line, end="")
            if line.startswith(PROGRESS_PREFIX):
                try:
                    current_test = int(line.split(":")[1].strip())
                except ValueError:
                    # Malformed progress marker: skip it rather than abort the run.
                    pass
                else:
                    yield f'Running Test {current_test} of {TOTAL_TESTS}. Please wait.', '', 'button-loading'
            if url := EXPLORER_URL_RE.search(line):
                found_url = True
                yield 'Open Results', url.group(0), 'toggled-on-button'
        completed = True
    finally:
        if not completed:
            # The generator was closed early (e.g. the client went away) or
            # raised. Stop the test run instead of leaving it unattended.
            process.kill()
        process.stdout.close()
        return_code = process.wait()
        print(f"Pytest finished with return code {return_code}")

    if not found_url:
        # Without this, the button would keep showing "Please wait" forever.
        yield (
            f'Tests finished without results (exit code {return_code}). Please try again.',
            '',
            'toggled-off-button',
        )
def reset_state():
    """Restore the UI to its initial state.

    Returns:
        A 4-tuple for the outputs ``[input, chatbot, test_url_state, run_button]``:
        - the initial system prompt (made visible);
        - the initial example chat;
        - the initial empty results URL;
        - the results button, back to its placeholder label and off styling.
    """
    prompt_update = gr.update(value=INITIAL_SYTSTEM_PROMPT, visible=True)
    button_update = gr.update(
        value="Click 'Submit' to see results here",
        elem_classes='toggled-off-button',
    )
    return prompt_update, INITIAL_CHABOT, INITIAL_STATE, button_update
# Main interface
with gr.Blocks(
    css=css,
    title="Santa Agent",
    theme=gr.themes.Soft(font="Arial"),
) as demo:
    # State variables
    # Page opened by the "Get API Key" button.
    invariant_link = gr.State('https://explorer.invariantlabs.ai/settings')
    # Receive run_testing's streamed outputs (button label, results URL,
    # button CSS class); the .change handlers below copy them onto the UI.
    test_progress_state = gr.State("")
    test_url_state = gr.State(INITIAL_STATE)
    test_button_class_state = gr.State("toggled-off-button")
    # Have to store URL as JSON instead of state as states cannot
    # reliably be passed to the frontend on updates: https://github.com/gradio-app/gradio/issues/3525
    current_invariant_url = gr.JSON("""{"url": ""}""", visible=False)

    gr.HTML("""
        <div class="home-banner-wrapper">
            <div class="home-banner-content">
                <h1>Prompt the Santa Agent</h1>
                <p>Find a prompt that passes all tests.</p>
            </div>
            <div class="home-banner-buttons">
                <a href="https://explorer.invariantlabs.ai/" target="_blank">
                    Invariant Explorer β
                </a>
            </div>
        </div>
        """, elem_classes="home-banner")

    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            with gr.Accordion("π Instructions on getting an API Key", open=False):
                gr.Markdown("""
                    ## Get an API Key
                    * [Create an account here](https://explorer.invariantlabs.ai/settings) by clicking 'Sign In', and then the GitHub icon.
                    * Click on `Get API Key` below to get your Invariant API key.
                    * Paste the API key in the text box below.
                    """
                )
        with gr.Column(scale=3):
            with gr.Accordion("π Task Description", open=False):
                gr.Markdown("""
                    ## Prompt the Santa Agent
                    The Invariant Santa Agent is tasked with delivering presents to children and performing various Santa tasks.
                    Your job is to provide the system prompt that will guide the Santa Agent to complete its tasks:\n
                    * Change the `System Prompt` to modify the behavior of the Santa Agent.
                    * Click `Submit` to test the Santa Agent with the new system prompt.
                    * Find a system prompt that passes all the tests.
                    * When the tests are done running, view your results in the Invariant Explorer by clicking `Open Results`.
                    * Click `Reset` to start over.
                    """
                )

    # Main input interface
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                type="messages",
                label="Example interaction",
                value=INITIAL_CHABOT,
                avatar_images=[
                    None,
                    "https://invariantlabs.ai/theme/images/logo.svg"
                ],
                max_height=700
            )
        with gr.Column(scale=2):
            # NOTE: shadows the builtin `input` at module level; name kept for compatibility.
            input = gr.Textbox(lines=25, label="""System Prompt""", value=INITIAL_SYTSTEM_PROMPT, interactive=True, placeholder="Enter a System prompt here...")

    # API key input and submit/reset/status
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            with gr.Row(equal_height=True):
                get_key_button = gr.Button("Get API Key", elem_id='get-key-button', min_width=0)
                get_key_button.click(  # Open Invariant API key link in new tab
                    fn=None,
                    inputs=invariant_link,
                    js="(invariant_link) => { window.open(invariant_link, '_blank') }",
                )
                invariant_api_key = gr.Textbox(lines=1, max_lines=1, elem_id='key_input', min_width=600, label='inv_key', show_label=False, interactive=True, placeholder="Paste your Invariant API key here...")
        with gr.Column(scale=2, min_width=200):
            with gr.Row(equal_height=True):
                submit_button = gr.Button("Submit", min_width=0, elem_id='submit-button')
                reset_button = gr.Button("Reset", min_width=0, elem_id="reset-button")
            run_button = gr.Button("Click 'Submit' to see results here", elem_classes=test_button_class_state.value, min_width=320)

    submit_button.click(
        fn=run_testing,
        inputs=[input, invariant_api_key],
        outputs=[test_progress_state, test_url_state, test_button_class_state],
    )
    # Opens the stored Explorer URL (if any) in a new tab, entirely client-side.
    run_button.click(
        fn=None,
        inputs=current_invariant_url,
        js="""
        (current_invariant_url) => {
            if (current_invariant_url['url'] !== '' && current_invariant_url['url']) {
                window.open(current_invariant_url['url'], '_blank');
            } else {
                console.log("No URL to open");
            }
        }
        """,
    )
    reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
    # Also reset the hidden label/class states to what run_button now shows.
    # Otherwise a later update equal to a pre-reset value (e.g. a repeated
    # warning) would not register as a change and would never reach run_button.
    reset_button.click(
        lambda: ("Click 'Submit' to see results here", "toggled-off-button"),
        None,
        [test_progress_state, test_button_class_state],
    )
    submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot])
    test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
    test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
    test_url_state.change(update_json_url, test_url_state, current_invariant_url)
    input.submit(lambda: gr.update(visible=True), None, [input])

if __name__ == "__main__":
    demo.launch()