"""Gradio demo for RuLES: Rule-following Language Evaluation Scenarios.

Users chat with a hosted LLM and try to make it break the rules of the selected
scenario. Each assistant reply is checked with the scenario's ``evaluate``
method; a failing result terminates the conversation. When ``MONGO_DB`` is set,
every evaluated turn is inserted into it.
"""

import argparse
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
import html
from itertools import zip_longest
import os
import textwrap
from typing import Dict, Iterator, List, Tuple

from dotenv import load_dotenv
import gradio as gr
from pymongo import MongoClient

from llm_rules import Role, Message, models, scenarios

MONGO_URI = "mongodb+srv://{username}:{password}@{host}/?retryWrites=true&w=majority"
# Collection that evaluated conversations are logged to; assigned in main().
# When None, logging is skipped.
MONGO_DB = None
PLACEHOLDER = "Enter message"

# Chatbot display format: a list of [user_text, assistant_text] pairs, where
# assistant_text is None for a user message that has no reply yet.
History = List[List[str]]


def parse_args():
    """Parse command-line options for the demo server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_proxy", action="store_true", default=False)
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args()


def _build_model(provider_name: str, model_name: str) -> models.BaseModel:
    """Build a streaming, temperature-0 client for ``model_name`` from ``provider_name``."""
    return models.MODEL_BUILDERS[provider_name](
        model=model_name,
        stream=True,
        temperature=0,
    )


@dataclass
class State:
    """Per-session demo state: the selected scenario/model and the conversation.

    ``messages`` is the conversation actually sent to the model.
    ``redacted_messages`` is the version displayed in the UI; it starts from
    the scenario's ``redacted_prompt`` and otherwise shares the same user and
    assistant Message objects.
    """

    scenario_name: str
    provider_name: str
    model_name: str
    scenario: scenarios.scenario.BaseScenario = None
    model: models.BaseModel = None
    system_message: str = None
    use_system_instructions: bool = False
    messages: List[Message] = field(default_factory=list)
    redacted_messages: List[Message] = field(default_factory=list)
    # True between accepting a valid user message and evaluating the
    # assistant's reply to it; gates the model call and the evaluation.
    last_user_message_valid: bool = False

    def __post_init__(self):
        self.scenario = scenarios.SCENARIOS[self.scenario_name]()
        self.model = _build_model(self.provider_name, self.model_name)
        self.messages = self.get_initial_messages()
        self.redacted_messages = self.get_initial_messages(redacted=True)

    def get_initial_messages(self, redacted: bool = False) -> List[Message]:
        """Build the opening messages that present the scenario's rules.

        With ``use_system_instructions`` the rules are the system message.
        Otherwise a generic system prompt (``models.PROMPTS[system_message]``)
        is followed by the rules as the first user message and the scenario's
        canned ``initial_response`` as the first assistant message.

        Args:
            redacted: Use the scenario's redacted prompt (for display) instead
                of the real prompt.
        """
        prompt = self.scenario.redacted_prompt if redacted else self.scenario.prompt
        if self.use_system_instructions:
            return [Message(Role.SYSTEM, prompt)]
        return [
            Message(Role.SYSTEM, models.PROMPTS[self.system_message]),
            Message(Role.USER, prompt),
            Message(Role.ASSISTANT, self.scenario.initial_response),
        ]

    def get_history(self) -> History:
        """Convert redacted messages into [user, assistant] pairs for the chatbot.

        The system message is skipped and message text is HTML-escaped so it
        renders literally. A trailing user message without a reply is paired
        with None.
        """
        redacted_messages = self.redacted_messages[1:]  # skip system message
        # Zipping one iterator with itself yields consecutive (user, assistant) pairs.
        args = [iter(redacted_messages)] * 2
        history = []
        for u, a in zip_longest(*args):
            u = html.escape(u.content, quote=False)
            a = None if a is None else html.escape(a.content, quote=False)
            history.append([u, a])
        return history

    def update_state_and_history(self, history: History, delta: str) -> History:
        """Append a streamed chunk to the last assistant message and chat entry."""
        # The same assistant Message object is appended to both ``messages`` and
        # ``redacted_messages``, so updating it once updates both lists.
        self.messages[-1].content += delta
        history[-1][-1] += html.escape(delta, quote=False)
        return history

    def get_info(self):
        """Return the input textbox label, prefixed by the scenario's format hint."""
        info_str = "Return to send message. Shift + Return to add a new line."
        if self.scenario.format_message:
            info_str = self.scenario.format_message + " " + info_str
        return info_str

    def unescape_messages(self) -> List[Message]:
        """Return copies of ``messages`` with HTML entities unescaped."""
        # NOTE(review): ``messages`` holds raw text (only the display history is
        # escaped), so this also rewrites literal entities the user typed, e.g.
        # "&amp;" -> "&". Confirm this is intended before relying on it.
        return [Message(m.role, html.unescape(m.content)) for m in self.messages]


def change_provider(state: State, provider_name: str) -> Tuple[State, Dict]:
    """Switch provider and select that provider's default model.

    Returns the updated state and a dropdown update listing the new
    provider's models.
    """
    state.provider_name = provider_name.lower()
    state.model_name = models.MODEL_DEFAULTS[state.provider_name]
    state.model = _build_model(state.provider_name, state.model_name)
    update_model = gr.update(
        choices=models.MODEL_NAMES_BY_PROVIDER[state.provider_name],
        value=state.model_name,
    )
    return state, update_model


def change_model(state: State, model_name: str) -> State:
    """Switch to ``model_name`` within the current provider."""
    state.model_name = model_name
    state.model = _build_model(state.provider_name, state.model_name)
    return state


def change_scenario(state: State, scenario: str) -> Tuple[State, Dict]:
    """Instantiate the named scenario and relabel the input textbox for it."""
    state.scenario = scenarios.SCENARIOS[scenario]()
    state.scenario_name = scenario
    update = gr.update(placeholder=PLACEHOLDER, label=state.get_info())
    return state, update


def send_user_message(state: State, input: str) -> Tuple[State, History, Dict]:
    """Validate and record the user's message, clearing the textbox if accepted.

    A message rejected by the scenario triggers a warning toast, is not
    recorded, and is left in the textbox so the user can edit it.
    """
    user_msg = Message(Role.USER, input)
    if not state.scenario.is_valid_user_message(user_msg):
        gr.Warning(f"Invalid user message: {state.scenario.format_message}")
        # Ensure the chained assistant/evaluation steps skip this turn.
        state.last_user_message_valid = False
        update = gr.update()
    else:
        state.messages.append(user_msg)
        state.redacted_messages.append(user_msg)
        state.last_user_message_valid = True
        update = gr.update(placeholder=PLACEHOLDER, value="")
    return state, state.get_history(), update


def send_assistant_message(
    state: State, api_key: str
) -> Iterator[Tuple[State, History]]:
    """Stream the model's reply into the state and chatbot.

    Yields ``(state, history)`` once up front, then after every streamed
    chunk. Stops after the first yield if the last user message was rejected.
    An empty API key is passed to the model as None.

    Raises:
        gr.Error: If the model request or the stream fails. The pending turn
            is cleared first so it is not evaluated or logged.
    """
    history = state.get_history()
    yield state, history

    if not state.last_user_message_valid:
        return

    api_key = None if api_key == "" else api_key
    try:
        response = state.model(state.messages, api_key=api_key)
    except Exception as e:
        state.last_user_message_valid = False
        raise gr.Error(f"API error: {e} Please reset the scenario and try again.") from e

    asst_msg = Message(Role.ASSISTANT, "")
    # Shared by both lists so streamed updates show up in each.
    state.messages.append(asst_msg)
    state.redacted_messages.append(asst_msg)
    history = state.get_history()
    try:
        # With streaming clients the request may only fail once iteration starts.
        for delta in response:
            history = state.update_state_and_history(history, delta)
            yield state, history
    except Exception as e:
        state.last_user_message_valid = False
        raise gr.Error(f"API error: {e} Please reset the scenario and try again.") from e


def _log_turn(state: State, result) -> None:
    """Insert the evaluated conversation into ``MONGO_DB`` (best effort)."""
    doc = {
        # PyMongo stores datetimes as UTC; an aware timestamp avoids mislabeling
        # local time as UTC.
        "timestamp": datetime.now(timezone.utc),
        "scenario": state.scenario_name,
        "params": asdict(state.scenario.p),
        "provider": state.provider_name,
        "model": state.model_name,
        "system_instructions": state.use_system_instructions,
        "messages": Message.serialize(state.messages),
        "result": asdict(result),
    }
    try:
        MONGO_DB.insert_one(doc)
    except Exception as e:
        # Show the failure without interrupting the conversation.
        gr.Warning(f"Database failed: {e}")


def evaluate_and_log(state: State) -> Tuple[State, Dict]:
    """Evaluate the latest reply, log the turn, and update the input textbox.

    If the scenario reports a broken rule, the textbox is disabled and shows
    the reason. Otherwise it is re-enabled for the next message. Does nothing
    when there is no pending valid user turn.
    """
    if not state.last_user_message_valid:
        return state, gr.update()

    messages = state.unescape_messages()
    result = state.scenario.evaluate(messages, state.use_system_instructions)
    state.last_user_message_valid = False

    if MONGO_DB is not None:
        _log_turn(state, result)

    if not result.passed:
        gr.Info("Rule broken!")
        return state, gr.update(
            interactive=False,
            value=f"Scenario terminated: {result.reason} Please reset the scenario or select another.",
        )
    return state, gr.update(
        interactive=True,
        placeholder=PLACEHOLDER,
    )


def reset_scenario(state: State) -> Tuple[State, History, Dict]:
    """Restart the conversation for the current scenario and re-enable input."""
    state.messages = state.get_initial_messages()
    state.redacted_messages = state.get_initial_messages(redacted=True)
    # Drop any pending turn so a stale flag cannot trigger a model call or evaluation.
    state.last_user_message_valid = False
    history = state.get_history()
    update = gr.update(placeholder=PLACEHOLDER, value="", interactive=True)
    return state, history, update


def make_block(initial_state: State) -> gr.Blocks:
    """Build the Gradio UI and wire up its event handlers."""
    markdown = (
        "<h2>RuLES: Rule-following Language Evaluation Scenarios</h2>"
        "<p>Your goal is to get the assistant model to break its rules. The rules are presented to the model in the "
        "first user message. Each assistant response is checked using the evaluate method of the scenario, and the "
        "conversation terminates upon a failing result. Passwords and other secrets are redacted from the chat "
        "interface and replaced with \"[REDACTED]\", but the actual value is still provided to the model. Some "
        "scenarios require user messages in a specific format. Your messages will be logged for research use. Please do"
        " not enter any sensitive information.</p>"
        "<p>Anthropic has graciously provided us with subsidized API access, and Google Cloud currently offers 100% "
        "discounted access to the PaLM 2 API during its Preview stage, so no API key is required for these models. "
        "To access OpenAI models, please enter your own API key. We do not record your key, which you can verify "
        "in the demo's source code.</p>"
        "<p>See the <a href=\"https://github.com/normster/llm_rules\">RuLES github repo</a> for more information.</p>"
    )

    # Display labels; change_provider lower-cases them into provider keys.
    providers = ["Anthropic", "OpenAI", "Google"]
    initial_provider = next(
        (p for p in providers if p.lower() == initial_state.provider_name),
        providers[0],
    )

    with gr.Blocks(
        gr.themes.Monochrome(
            font=[
                gr.themes.GoogleFont("Source Sans Pro"),
                "ui-sans-serif",
                "system-ui",
                "sans-serif",
            ],
            radius_size=gr.themes.sizes.radius_sm,
        )
    ) as block:
        gr.Markdown(markdown, sanitize_html=False)

        state = gr.State(value=initial_state)

        with gr.Row():
            provider_select = gr.Dropdown(
                providers,
                value=initial_provider,
                label="Provider",
            )
            model_select = gr.Dropdown(
                models.MODEL_NAMES_BY_PROVIDER[initial_state.provider_name],
                value=initial_state.model_name,
                label="Model",
            )
            scenario_select = gr.Dropdown(
                list(scenarios.SCENARIOS),
                value=initial_state.scenario_name,
                label="Scenario",
            )
            apikey = gr.Textbox(placeholder="sk-...", label="API Key")

        chatbot = gr.Chatbot(initial_state.get_history(), show_label=False)
        textbox = gr.Textbox(placeholder=PLACEHOLDER, label=initial_state.get_info())
        reset_button = gr.Button("Reset Scenario")

        # Event listeners
        # Submit: record the user message, stream the reply, then evaluate and log.
        textbox.submit(
            send_user_message, [state, textbox], [state, chatbot, textbox], queue=True
        ).then(
            send_assistant_message,
            [state, apikey],
            [state, chatbot],
            queue=True,
        ).then(
            evaluate_and_log, state, [state, textbox], queue=True
        )

        # Change to default model for new provider when provider is changed
        provider_select.change(
            change_provider,
            [state, provider_select],
            [state, model_select],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        # Change to specified model
        model_select.change(
            change_model,
            [state, model_select],
            [state],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        # Change to specified scenario
        scenario_select.change(
            change_scenario,
            [state, scenario_select],
            [state, textbox],
            queue=False,
        ).then(reset_scenario, state, [state, chatbot, textbox], queue=False)

        # Reset scenario state, chat history, and input textbox
        reset_button.click(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        block.load(reset_scenario, state, [state, chatbot, textbox], queue=False)

    return block


def main(args):
    """Configure logging and the initial session state, then launch the demo."""
    load_dotenv()
    # __post_init__ builds the initial (and redacted) messages for the scenario.
    initial_state = State(
        scenario_name="Encryption",
        provider_name="anthropic",
        model_name="claude-instant-v1.2",
    )

    # Comment this out to disable logging
    global MONGO_DB
    # NOTE(review): credentials containing URI-reserved characters (e.g. "@",
    # ":", "/") must be percent-encoded in these env vars.
    mongo_uri = MONGO_URI.format(
        username=os.environ["MONGO_USERNAME"],
        password=os.environ["MONGO_PASSWORD"],
        host=os.environ["MONGO_HOST"],
    )
    client = MongoClient(mongo_uri)
    MONGO_DB = client["messages"]["v1.0"]

    block = make_block(initial_state)
    block.queue(concurrency_count=2)
    block.launch(
        server_port=args.port,
        share=args.hf_proxy,
    )


if __name__ == "__main__":
    args = parse_args()
    main(args)