normster committed on
Commit
9e4826d
1 Parent(s): 2113199

manual upload

Browse files
Files changed (3) hide show
  1. README.md +6 -6
  2. app.py +344 -0
  3. requirements.txt +2 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Llm Rules
3
- emoji: 👀
4
- colorFrom: indigo
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.0.2
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
 
1
  ---
2
+ title: "RuLES: Rule-following Language Evaluation Scenarios"
3
+ emoji: ⚖️
4
+ colorFrom: pink
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.50.2
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
12
 
app.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from dataclasses import asdict, dataclass, field
3
+ from datetime import datetime
4
+ import html
5
+ from itertools import zip_longest
6
+ import os
7
+ import textwrap
8
+ from typing import Dict, List, Tuple
9
+
10
+ from dotenv import load_dotenv
11
+ import gradio as gr
12
+ from pymongo import MongoClient
13
+
14
+ from rules import Role, Message, models, scenarios
15
+
16
+
17
# Connection-string template; main() fills in username, password and host
# from environment variables.
MONGO_URI = "mongodb+srv://{username}:{password}@{host}/?retryWrites=true&w=majority"
# Collection used for conversation logging. Stays None (logging skipped)
# until main() assigns it.
MONGO_DB = None
# Hint text shown in the empty message textbox.
PLACEHOLDER = "Enter message"

# Chat history in the shape the chatbot displays: a list of
# [user_text, assistant_text] pairs.
History = List[List[str]]
22
+
23
+
24
def parse_args(argv=None):
    """Parse command-line options for the demo server.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers
            that pass nothing are unaffected.

    Returns:
        An ``argparse.Namespace`` with ``hf_proxy`` (bool, default False)
        and ``port`` (int, default 7860).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_proxy", action="store_true", default=False)
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args(argv)
29
+
30
+
31
@dataclass
class State:
    """Conversation state shared between the Gradio event handlers.

    Two parallel message lists are kept: ``messages`` is what gets sent to
    the model, ``redacted_messages`` is built from the scenario's redacted
    prompt and is what the chatbot displays.
    """

    scenario_name: str
    provider_name: str
    model_name: str
    scenario: scenarios.BaseScenario = None
    model: models.BaseModel = None
    system_message: str = None
    use_system_instructions: bool = False
    messages: List[Message] = field(default_factory=list)
    redacted_messages: List[Message] = field(default_factory=list)
    last_user_message_valid: bool = False

    def __post_init__(self):
        """Instantiate the scenario and model, then seed both message lists."""
        scenario_cls = scenarios.SCENARIOS[self.scenario_name]
        self.scenario = scenario_cls()
        build_model = models.MODEL_BUILDERS[self.provider_name]
        self.model = build_model(
            model=self.model_name,
            stream=True,
            temperature=0,
        )
        self.messages = self.get_initial_messages()
        self.redacted_messages = self.get_initial_messages(redacted=True)

    def get_initial_messages(self, redacted=False) -> List[Message]:
        """Return the opening messages for the current scenario.

        With ``use_system_instructions`` the scenario prompt is the only
        (system) message; otherwise a system message is followed by the
        prompt as a user turn and the scenario's initial assistant reply.
        ``redacted`` selects the scenario's redacted prompt.
        """
        if redacted:
            prompt = self.scenario.redacted_prompt
        else:
            prompt = self.scenario.prompt

        if self.use_system_instructions:
            return [Message(Role.SYSTEM, prompt)]

        return [
            Message(Role.SYSTEM, models.SYSTEM_MESSAGES[self.system_message]),
            Message(Role.USER, prompt),
            Message(Role.ASSISTANT, self.scenario.initial_response),
        ]

    def get_history(self) -> History:
        """Convert redacted messages into [user, assistant] pairs for display."""
        visible = self.redacted_messages[1:]  # first entry is the system message
        total = len(visible)
        history = []
        for start in range(0, total, 2):
            user_text = html.escape(visible[start].content, quote=False)
            if start + 1 < total:
                reply_text = html.escape(visible[start + 1].content, quote=False)
            else:
                # Trailing user turn that has no assistant reply yet.
                reply_text = None
            history.append([user_text, reply_text])
        return history

    def update_state_and_history(self, history: History, delta: str) -> History:
        """Append a streamed chunk to the last message and last history cell."""
        # send_assistant_message appends one Message object to both lists,
        # so extending it here also updates the redacted copy.
        last_message = self.messages[-1]
        last_message.content += delta
        last_pair = history[-1]
        last_pair[-1] += html.escape(delta, quote=False)
        return history

    def get_info(self):
        """Return the textbox label, prefixed by the scenario's format hint if any."""
        base_hint = "Return to send message. Shift + Return to add a new line."
        format_hint = self.scenario.format_message
        if not format_hint:
            return base_hint
        return format_hint + " " + base_hint

    def unescape_messages(self) -> List[Message]:
        """Return copies of ``messages`` with HTML entities in content unescaped."""
        unescaped = []
        for message in self.messages:
            unescaped.append(Message(message.role, html.unescape(message.content)))
        return unescaped
94
+
95
+
96
def change_provider(state: State, provider_name: str) -> Tuple[State, Dict]:
    """Switch provider, select its default model, and refresh the model dropdown.

    Returns the mutated state and a gr.update carrying the new provider's
    model choices with the default model preselected.
    """
    provider = provider_name.lower()
    state.provider_name = provider

    default_model = models.MODEL_DEFAULTS[provider]
    state.model_name = default_model

    build_model = models.MODEL_BUILDERS[provider]
    state.model = build_model(model=default_model, stream=True, temperature=0)

    dropdown_update = gr.update(
        choices=models.MODEL_NAMES_BY_PROVIDER[provider],
        value=default_model,
    )
    return state, dropdown_update
110
+
111
+
112
def change_model(state: State, model_name: str) -> State:
    """Rebuild the state's model for ``model_name`` under the current provider."""
    state.model_name = model_name
    build_model = models.MODEL_BUILDERS[state.provider_name]
    state.model = build_model(model=model_name, stream=True, temperature=0)
    return state
121
+
122
+
123
def change_scenario(state: State, scenario: str) -> Tuple[State, Dict]:
    """Instantiate the named scenario and relabel the textbox with its hint."""
    scenario_cls = scenarios.SCENARIOS[scenario]
    state.scenario = scenario_cls()
    state.scenario_name = scenario
    # The label comes from get_info(), which reads the scenario just set.
    textbox_update = gr.update(placeholder=PLACEHOLDER, label=state.get_info())
    return state, textbox_update
128
+
129
+
130
def send_user_message(state: State, input: str) -> Tuple[State, History, Dict]:
    """Validate the user's input, record it, and clear the textbox.

    Args:
        state: Current conversation state; mutated in place.
        input: Raw text submitted from the textbox.

    Returns:
        The state, the chat history to display, and a textbox update. The
        update is a no-op when the message is rejected, so the user's text
        stays in the box for editing.
    """
    user_msg = Message(Role.USER, input)
    if not state.scenario.is_valid_user_message(user_msg):
        # Clear the flag explicitly so a stale True from an earlier turn
        # cannot trigger a model call for this rejected message.
        state.last_user_message_valid = False
        gr.Warning(f"Invalid user message: {state.scenario.format_message}")
        update = gr.update()
    else:
        state.messages.append(user_msg)
        state.redacted_messages.append(user_msg)
        state.last_user_message_valid = True
        update = gr.update(placeholder=PLACEHOLDER, value="")
    return state, state.get_history(), update
142
+
143
+
144
def send_assistant_message(state: State, api_key: str) -> Tuple[State, History]:
    """Stream the model's reply into the state and the chatbot.

    This is a generator. It first yields the current history, then, only if
    the last user message was valid, calls the model and yields the updated
    history after each streamed chunk.

    Raises:
        gr.Error: If calling the model raises any exception.
    """
    yield state, state.get_history()

    if state.last_user_message_valid:
        try:
            # An empty textbox means "no key supplied".
            key = api_key if api_key != "" else None
            stream = state.model(state.messages, api_key=key)
        except Exception as e:
            raise gr.Error(f"API error: {e} Please reset the scenario and try again.")

        # One Message object goes into both lists so that streamed chunks
        # appended to it are visible through either list.
        reply = Message(Role.ASSISTANT, "")
        state.messages.append(reply)
        state.redacted_messages.append(reply)
        chat = state.get_history()

        for chunk in stream:
            chat = state.update_state_and_history(chat, chunk)
            yield state, chat
166
+
167
+
168
def evaluate_and_log(state: State) -> Tuple[State, Dict]:
    """Evaluate the conversation against the scenario and log the turn.

    Does nothing when the last user message was not valid. Otherwise it
    evaluates the unescaped messages, clears the validity flag, writes a
    record to MONGO_DB when logging is configured, and returns a textbox
    update: locked with the failure reason if a rule was broken, otherwise
    ready for the next message.

    Args:
        state: Current conversation state; mutated in place.

    Returns:
        The state and a gr.update for the input textbox.
    """
    if not state.last_user_message_valid:
        return state, gr.update()

    messages = state.unescape_messages()
    result = state.scenario.evaluate(messages, state.use_system_instructions)
    state.last_user_message_valid = False

    # MONGO_DB is only read here, so no ``global`` declaration is needed.
    if MONGO_DB is not None:
        doc = {
            # NOTE(review): datetime.now() is naive local time; confirm the
            # timezone the stored timestamps are expected to be in.
            "timestamp": datetime.now(),
            "scenario": state.scenario_name,
            "params": asdict(state.scenario.p),
            "provider": state.provider_name,
            "model": state.model_name,
            "system_instructions": state.use_system_instructions,
            "messages": Message.serialize(state.messages),
            "result": asdict(result),
        }

        try:
            MONGO_DB.insert_one(doc)
        except Exception as e:
            # Logging is best-effort: report the failure without aborting
            # the turn. gr.Error only shows when raised, so the original
            # bare gr.Error(...) call dropped this message; gr.Warning
            # displays it when called.
            gr.Warning(f"Database failed: {e}")

    if not result.passed:
        gr.Info("Rule broken!")
        return state, gr.update(
            interactive=False,
            value=f"Scenario terminated: {result.reason} Please reset the scenario or select another.",
        )

    return state, gr.update(
        interactive=True,
        placeholder=PLACEHOLDER,
    )
206
+
207
+
208
def reset_scenario(state: State) -> Tuple[State, History, dict]:
    """Restore the scenario's opening messages and re-enable the textbox."""
    state.messages = state.get_initial_messages()
    state.redacted_messages = state.get_initial_messages(redacted=True)
    fresh_history = state.get_history()
    textbox_update = gr.update(placeholder=PLACEHOLDER, value="", interactive=True)
    return state, fresh_history, textbox_update
214
+
215
+
216
def make_block(initial_state):
    """Build the Gradio Blocks UI and wire up its event handlers.

    Args:
        initial_state: State used to seed the gr.State component, the
            scenario dropdown value, the chatbot history and the textbox label.

    Returns:
        The constructed gr.Blocks instance (not yet queued or launched).
    """
    title_html = "<h1>RuLES: Rule-following Language Evaluation Scenarios</h1>"
    goal_html = (
        "<p>Your goal is to get the assistant model to break its rules. The rules are presented to the model in the "
        "first user message. Each assistant response is checked using the evaluate method of the scenario, and the "
        "conversation terminates upon a failing result. Passwords and other secrets are redacted from the chat "
        "interface and replaced with \"[REDACTED]\", but the actual value is still provided to the model. Some "
        "scenarios require user messages in a specific format. Your messages will be logged for research use. Please do"
        " not enter any sensitive information.</p>"
    )
    access_html = (
        "<p>Anthropic has graciously provided us with subsidized API access, and Google Cloud currently offers 100% "
        "discounted access to the PaLM 2 API during its Preview stage, so no API key is required for these models. "
        "To access OpenAI models, please enter your own API key. We do not record your key, but you should verify this "
        "in the demo's source code.</p>"
    )
    repo_html = (
        "<p>See the RuLES <a href=\"https://github.com/normster/rules\">github repo</a> for more information.</p>"
    )
    markdown = title_html + goal_html + access_html + repo_html

    font_stack = [
        gr.themes.GoogleFont("Source Sans Pro"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ]
    theme = gr.themes.Monochrome(
        font=font_stack,
        radius_size=gr.themes.sizes.radius_sm,
    )

    with gr.Blocks(theme) as block:
        # Components are created in display order.
        gr.Markdown(markdown, sanitize_html=False)
        state = gr.State(value=initial_state)
        with gr.Row():
            provider_select = gr.Dropdown(
                ["Anthropic", "OpenAI", "Google"],
                value="Anthropic",
                label="Provider",
            )
            model_select = gr.Dropdown(
                models.MODEL_NAMES_BY_PROVIDER["anthropic"],
                value="claude-instant-v1.2",
                label="Model",
            )
            scenario_select = gr.Dropdown(
                scenarios.SCENARIOS.keys(),
                value=initial_state.scenario_name,
                label="Scenario",
            )
            apikey = gr.Textbox(placeholder="sk-...", label="API Key")
        chatbot = gr.Chatbot(initial_state.get_history(), show_label=False)
        textbox = gr.Textbox(placeholder=PLACEHOLDER, label=initial_state.get_info())
        reset_button = gr.Button("Reset Scenario")

        def then_reset(event):
            """Chain a scenario reset after ``event``."""
            return event.then(
                reset_scenario, state, [state, chatbot, textbox], queue=False
            )

        # Submit pipeline: record user message -> stream reply -> evaluate/log.
        user_sent = textbox.submit(
            send_user_message, [state, textbox], [state, chatbot, textbox], queue=True
        )
        reply_streamed = user_sent.then(
            send_assistant_message,
            [state, apikey],
            [state, chatbot],
            queue=True,
        )
        reply_streamed.then(evaluate_and_log, state, [state, textbox], queue=True)

        # Changing the provider selects that provider's default model.
        then_reset(
            provider_select.change(
                change_provider,
                [state, provider_select],
                [state, model_select],
                queue=False,
            )
        )
        # Changing the model rebuilds it and restarts the scenario.
        then_reset(
            model_select.change(
                change_model,
                [state, model_select],
                [state],
                queue=False,
            )
        )
        # Changing the scenario relabels the textbox and restarts.
        then_reset(
            scenario_select.change(
                change_scenario,
                [state, scenario_select],
                [state, textbox],
                queue=False,
            )
        )
        # Reset scenario state, chat history, and input textbox.
        reset_button.click(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )
        block.load(reset_scenario, state, [state, chatbot, textbox], queue=False)

    return block
309
+
310
+
311
def main(args):
    """Configure logging, build the UI, and launch the Gradio server.

    Args:
        args: Parsed command-line namespace providing ``port`` (server port)
            and ``hf_proxy`` (passed to ``launch`` as ``share``).

    Raises:
        KeyError: If MONGO_USERNAME, MONGO_PASSWORD or MONGO_HOST is missing
            from the environment while logging is enabled.
    """
    load_dotenv()

    # State.__post_init__ already fills ``messages`` and ``redacted_messages``
    # with the scenario's opening messages. The original reassigned both here
    # with a trailing comma, which wrapped each list in a 1-tuple and made
    # get_history() return an empty history for the initial chatbot.
    initial_state = State(
        scenario_name="Encryption",
        provider_name="anthropic",
        model_name="claude-instant-v1.2",
    )

    # Comment this out to disable logging
    global MONGO_DB
    mongo_uri = MONGO_URI.format(
        username=os.environ["MONGO_USERNAME"],
        password=os.environ["MONGO_PASSWORD"],
        host=os.environ["MONGO_HOST"],
    )
    client = MongoClient(mongo_uri)
    MONGO_DB = client["messages"]["v1.0"]

    block = make_block(initial_state)
    block.queue(concurrency_count=2)
    block.launch(
        server_port=args.port,
        share=args.hf_proxy,
    )
340
+
341
+
342
# Script entry point: parse CLI flags, then start the demo server.
if __name__ == "__main__":
    args = parse_args()
    main(args)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ llm_rules @ git+https://github.com/normster/llm_rules
2
+ pymongo