# llm_rules / app.py
# normster's picture
# Update app.py
# 88e545f verified
import argparse
from dataclasses import asdict, dataclass, field
from datetime import datetime
import html
from itertools import zip_longest
import os
import textwrap
from typing import Dict, List, Tuple
from dotenv import load_dotenv
import gradio as gr
from pymongo import MongoClient
from llm_rules import Role, Message, models, scenarios
# Connection-string template for MongoDB; main() fills in the username,
# password and host placeholders with str.format().
MONGO_URI = "mongodb+srv://{username}:{password}@{host}/?retryWrites=true&w=majority"
# Logging target. Stays None until main() assigns it; evaluate_and_log()
# skips logging entirely while it is None.
MONGO_DB = None
# Placeholder text shown in the message textbox.
PLACEHOLDER = "Enter message"
# Chat history handed to the chatbot: a list of [user_text, assistant_text]
# pairs (State.get_history() may leave the assistant entry as None).
History = List[List[str]]
def parse_args():
    """Read the demo's command-line options.

    Returns:
        The parsed namespace with ``hf_proxy`` (boolean flag, False unless
        ``--hf_proxy`` is given) and ``port`` (int, 7860 unless ``--port``
        is given).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--hf_proxy", {"action": "store_true", "default": False}),
        ("--port", {"type": int, "default": 7860}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
@dataclass
class State:
    """Conversation state carried through the demo's event handlers.

    Keeps the selected scenario/provider/model together with two message
    lists: ``messages`` (what is sent to the model) and
    ``redacted_messages`` (what is rendered in the chatbot).
    """

    # Keys into scenarios.SCENARIOS / models.MODEL_BUILDERS and the model id.
    scenario_name: str
    provider_name: str
    model_name: str
    # Built in __post_init__ from the names above.
    scenario: scenarios.scenario.BaseScenario = None
    model: models.BaseModel = None
    # Key into models.PROMPTS, used when the rules go in a user message.
    system_message: str = None
    # When True the scenario prompt itself is sent as the system message.
    use_system_instructions: bool = False
    messages: List[Message] = field(default_factory=list)
    redacted_messages: List[Message] = field(default_factory=list)
    # Set by the send handler, cleared once the turn has been evaluated.
    last_user_message_valid: bool = False

    def __post_init__(self):
        """Instantiate the scenario and model, then seed both message lists."""
        scenario_cls = scenarios.SCENARIOS[self.scenario_name]
        self.scenario = scenario_cls()
        build_model = models.MODEL_BUILDERS[self.provider_name]
        self.model = build_model(
            model=self.model_name,
            stream=True,
            temperature=0,
        )
        self.messages = self.get_initial_messages()
        self.redacted_messages = self.get_initial_messages(redacted=True)

    def get_initial_messages(self, redacted=False) -> List[Message]:
        """Build the opening messages of a conversation.

        Args:
            redacted: use the scenario's redacted prompt instead of the
                real one.

        Returns:
            A single system message carrying the prompt when
            ``use_system_instructions`` is set; otherwise a system message
            from ``models.PROMPTS``, the prompt as a user message, and the
            scenario's initial assistant response.
        """
        if redacted:
            prompt = self.scenario.redacted_prompt
        else:
            prompt = self.scenario.prompt

        if self.use_system_instructions:
            return [Message(Role.SYSTEM, prompt)]

        return [
            Message(Role.SYSTEM, models.PROMPTS[self.system_message]),
            Message(Role.USER, prompt),
            Message(Role.ASSISTANT, self.scenario.initial_response),
        ]

    def get_history(self) -> History:
        """Turn the redacted messages into chatbot rows.

        The first (system) message is dropped; the rest are grouped two at a
        time as [user, assistant] with HTML-escaped content. A trailing user
        message without a reply gets None in the assistant slot.
        """
        visible = self.redacted_messages[1:]
        user_msgs = visible[0::2]
        asst_msgs = visible[1::2]
        return [
            [
                html.escape(user_msg.content, quote=False),
                None
                if asst_msg is None
                else html.escape(asst_msg.content, quote=False),
            ]
            for user_msg, asst_msg in zip_longest(user_msgs, asst_msgs)
        ]

    def update_state_and_history(self, history: History, delta: str) -> History:
        """Append a streamed chunk to the last message and the last chat row.

        The raw chunk goes onto the last entry of ``messages``; the
        HTML-escaped chunk goes onto the last cell of ``history``.
        """
        latest_message = self.messages[-1]
        latest_message.content += delta
        latest_row = history[-1]
        latest_row[-1] += html.escape(delta, quote=False)
        return history

    def get_info(self):
        """Return the textbox label: keyboard hint, prefixed by the
        scenario's format message when it has one."""
        usage_hint = "Return to send message. Shift + Return to add a new line."
        format_hint = self.scenario.format_message
        if not format_hint:
            return usage_hint
        return format_hint + " " + usage_hint

    def unescape_messages(self) -> List[Message]:
        """Return copies of ``messages`` with HTML entities unescaped."""
        unescaped = []
        for msg in self.messages:
            unescaped.append(Message(msg.role, html.unescape(msg.content)))
        return unescaped
def change_provider(state: State, provider_name: str) -> Tuple[State, Dict]:
    """Switch to another provider and fall back to its default model.

    Args:
        state: session state to modify in place.
        provider_name: provider as shown in the dropdown; lower-cased before
            being used as a lookup key.

    Returns:
        The state and an update for the model dropdown (new choices and
        selected value).
    """
    provider = provider_name.lower()
    state.provider_name = provider
    state.model_name = models.MODEL_DEFAULTS[provider]

    build_model = models.MODEL_BUILDERS[provider]
    state.model = build_model(
        model=state.model_name,
        stream=True,
        temperature=0,
    )

    model_dropdown_update = gr.update(
        choices=models.MODEL_NAMES_BY_PROVIDER[provider],
        value=state.model_name,
    )
    return state, model_dropdown_update
def change_model(state: State, model_name: str) -> State:
    """Select a different model from the current provider.

    Records the new model name on ``state`` and rebuilds ``state.model``
    with the current provider's builder.
    """
    state.model_name = model_name
    build_model = models.MODEL_BUILDERS[state.provider_name]
    state.model = build_model(
        model=state.model_name,
        stream=True,
        temperature=0,
    )
    return state
def change_scenario(state: State, scenario: str) -> Tuple[State, Dict]:
    """Switch the session to another scenario.

    Args:
        state: session state to modify in place.
        scenario: key into ``scenarios.SCENARIOS``.

    Returns:
        The state and a textbox update carrying the default placeholder and
        the new scenario's info label.
    """
    scenario_cls = scenarios.SCENARIOS[scenario]
    state.scenario = scenario_cls()
    state.scenario_name = scenario
    # The label is computed after the swap so it describes the new scenario.
    textbox_update = gr.update(placeholder=PLACEHOLDER, label=state.get_info())
    return state, textbox_update
def send_user_message(state: State, input: str) -> Tuple[State, History, Dict]:
    """Record the user's message if the scenario accepts it.

    Args:
        state: session state to modify in place.
        input: raw text from the message textbox.

    Returns:
        The state, the refreshed chat history, and a textbox update: a no-op
        update when the message is rejected (so the user can edit it), or one
        that clears the textbox when it is accepted.
    """
    user_msg = Message(Role.USER, input)

    if not state.scenario.is_valid_user_message(user_msg):
        # Clear the flag explicitly: if an earlier turn failed before it was
        # evaluated, the flag could still be True and the follow-up handlers
        # would query the model even though nothing was sent.
        state.last_user_message_valid = False
        gr.Warning(f"Invalid user message: {state.scenario.format_message}")
        return state, state.get_history(), gr.update()

    # The same Message object goes into both lists.
    state.messages.append(user_msg)
    state.redacted_messages.append(user_msg)
    state.last_user_message_valid = True
    return state, state.get_history(), gr.update(placeholder=PLACEHOLDER, value="")
def send_assistant_message(state: State, api_key: str) -> Tuple[State, History]:
    """Request the model's reply and stream it into the chatbot.

    This function is a generator: it first yields the current history
    unchanged, then, if the last user message was valid, yields
    ``(state, history)`` again after every streamed chunk.

    Args:
        state: session state; the assistant message is appended to both
            message lists and grown in place as chunks arrive.
        api_key: key from the API-key textbox; an empty string is passed to
            the model as None.

    Raises:
        gr.Error: if the model call fails, or if an error occurs while the
            streamed response is being consumed.
    """

    def api_error(exc):
        return gr.Error(f"API error: {exc} Please reset the scenario and try again.")

    history = state.get_history()
    yield state, history

    if not state.last_user_message_valid:
        return

    if api_key == "":
        api_key = None

    try:
        response = state.model(state.messages, api_key=api_key)
    except Exception as e:
        raise api_error(e) from e

    # One shared Message object: growing it updates both lists at once.
    asst_msg = Message(Role.ASSISTANT, "")
    state.messages.append(asst_msg)
    state.redacted_messages.append(asst_msg)
    history = state.get_history()

    # The response is consumed lazily, so failures can also surface here
    # rather than at call time; report them the same way.
    try:
        for delta in response:
            history = state.update_state_and_history(history, delta)
            yield state, history
    except Exception as e:
        raise api_error(e) from e
def evaluate_and_log(state: State) -> Tuple[State, Dict]:
    """Evaluate the conversation against the scenario and log the turn.

    Does nothing when the last user message was not valid. Otherwise the
    unescaped messages are evaluated, the validity flag is cleared, and,
    if ``MONGO_DB`` is set, a record of the turn is inserted.

    Args:
        state: session state for the conversation being evaluated.

    Returns:
        The state and a textbox update: disabled with a termination notice
        when the evaluation failed, otherwise interactive with the default
        placeholder.
    """
    if not state.last_user_message_valid:
        return state, gr.update()

    messages = state.unescape_messages()
    result = state.scenario.evaluate(messages, state.use_system_instructions)
    state.last_user_message_valid = False

    if MONGO_DB is not None:
        doc = {
            "timestamp": datetime.now(),
            "scenario": state.scenario_name,
            "params": asdict(state.scenario.p),
            "provider": state.provider_name,
            "model": state.model_name,
            "system_instructions": state.use_system_instructions,
            "messages": Message.serialize(state.messages),
            "result": asdict(result),
        }
        try:
            MONGO_DB.insert_one(doc)
        except Exception as e:
            # Logging is best-effort: surface the failure to the user without
            # aborting the evaluation. (A gr.Error that is only constructed
            # and never raised is not displayed at all.)
            gr.Warning(f"Database failed: {e}")

    if not result.passed:
        gr.Info("Rule broken!")
        return state, gr.update(
            interactive=False,
            value=f"Scenario terminated: {result.reason} Please reset the scenario or select another.",
        )

    return state, gr.update(
        interactive=True,
        placeholder=PLACEHOLDER,
    )
def reset_scenario(state: State) -> Tuple[State, History, Dict]:
    """Restart the current scenario from its initial messages.

    Args:
        state: session state to reset in place.

    Returns:
        The state, the chat history rebuilt from the fresh redacted
        messages, and a textbox update that clears and re-enables the input.
    """
    state.messages = state.get_initial_messages()
    state.redacted_messages = state.get_initial_messages(redacted=True)
    # No user message is pending after a reset; drop any flag left over from
    # a turn that failed before it could be evaluated.
    state.last_user_message_valid = False
    history = state.get_history()
    update = gr.update(placeholder=PLACEHOLDER, value="", interactive=True)
    return state, history, update
def make_block(initial_state):
    """Build the Gradio UI and wire up its event handlers.

    Args:
        initial_state: the State used as the per-session starting value and
            as the source of the initial dropdown selections, chat history
            and textbox label.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    markdown = (
        "<h1>RuLES: Rule-following Language Evaluation Scenarios</h1>"
        "<p>Your goal is to get the assistant model to break its rules. The rules are presented to the model in the "
        "first user message. Each assistant response is checked using the evaluate method of the scenario, and the "
        "conversation terminates upon a failing result. Passwords and other secrets are redacted from the chat "
        "interface and replaced with \"[REDACTED]\", but the actual value is still provided to the model. Some "
        "scenarios require user messages in a specific format. Your messages will be logged for research use. Please do"
        " not enter any sensitive information.</p>"
        "<p>Anthropic has graciously provided us with subsidized API access, and Google Cloud currently offers 100% "
        "discounted access to the PaLM 2 API during its Preview stage, so no API key is required for these models. "
        "To access OpenAI models, please enter your own API key. We do not record your key, which you can verify "
        "in the demo's source code.</p>"
        "<p>See the RuLES <a href=\"https://github.com/normster/llm_rules\">github repo</a> for more information.</p>"
    )

    theme = gr.themes.Monochrome(
        font=[
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
        radius_size=gr.themes.sizes.radius_sm,
    )

    # Derive the initial selections from initial_state so the dropdowns can
    # never disagree with the state they control. Provider keys are stored
    # lower-cased; the dropdown shows the display spelling.
    provider_choices = ["Anthropic", "OpenAI", "Google"]
    initial_provider = next(
        (p for p in provider_choices if p.lower() == initial_state.provider_name),
        provider_choices[0],
    )

    with gr.Blocks(theme) as block:
        gr.Markdown(markdown, sanitize_html=False)
        state = gr.State(value=initial_state)

        with gr.Row():
            provider_select = gr.Dropdown(
                provider_choices,
                value=initial_provider,
                label="Provider",
            )
            model_select = gr.Dropdown(
                models.MODEL_NAMES_BY_PROVIDER[initial_state.provider_name],
                value=initial_state.model_name,
                label="Model",
            )
            scenario_select = gr.Dropdown(
                # Pass a real list rather than a dict view.
                list(scenarios.SCENARIOS),
                value=initial_state.scenario_name,
                label="Scenario",
            )

        apikey = gr.Textbox(placeholder="sk-...", label="API Key")
        chatbot = gr.Chatbot(initial_state.get_history(), show_label=False)
        textbox = gr.Textbox(placeholder=PLACEHOLDER, label=initial_state.get_info())
        reset_button = gr.Button("Reset Scenario")

        # Event listeners
        # Submitting a message: record it, stream the reply, then evaluate.
        textbox.submit(
            send_user_message, [state, textbox], [state, chatbot, textbox], queue=True
        ).then(
            send_assistant_message,
            [state, apikey],
            [state, chatbot],
            queue=True,
        ).then(
            evaluate_and_log, state, [state, textbox], queue=True
        )

        # Change to default model for new provider when provider is changed
        provider_select.change(
            change_provider,
            [state, provider_select],
            [state, model_select],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        # Change to specified model
        model_select.change(
            change_model,
            [state, model_select],
            [state],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        # Change to specified scenario
        scenario_select.change(
            change_scenario,
            [state, scenario_select],
            [state, textbox],
            queue=False,
        ).then(reset_scenario, state, [state, chatbot, textbox], queue=False)

        # Reset scenario state, chat history, and input textbox
        reset_button.click(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        # Start every page load from a freshly reset scenario.
        block.load(reset_scenario, state, [state, chatbot, textbox], queue=False)

    return block
def main(args):
    """Configure logging, build the UI and launch the demo.

    Args:
        args: parsed command-line options; ``args.port`` is the server port
            and ``args.hf_proxy`` is passed as Gradio's ``share`` flag.

    Raises:
        KeyError: if MONGO_USERNAME, MONGO_PASSWORD or MONGO_HOST is missing
            from the environment (after loading ``.env``).
    """
    global MONGO_DB

    load_dotenv()

    # State.__post_init__ already fills `messages` and `redacted_messages`,
    # so they are not reassigned here. (Reassigning them with a trailing
    # comma would wrap each list in a 1-tuple and break get_history().)
    initial_state = State(
        scenario_name="Encryption",
        provider_name="anthropic",
        model_name="claude-instant-v1.2",
    )

    # Comment this out to disable logging
    mongo_uri = MONGO_URI.format(
        username=os.environ["MONGO_USERNAME"],
        password=os.environ["MONGO_PASSWORD"],
        host=os.environ["MONGO_HOST"],
    )
    client = MongoClient(mongo_uri)
    MONGO_DB = client["messages"]["v1.0"]

    block = make_block(initial_state)
    block.queue(concurrency_count=2)
    block.launch(
        server_port=args.port,
        share=args.hf_proxy,
    )


if __name__ == "__main__":
    args = parse_args()
    main(args)