grahamwhiteuk committed on
Commit
5b7f169
1 Parent(s): 36223d2

fix: deployment

Browse files
Files changed (6) hide show
  1. .flake8 +5 -0
  2. requirements.txt +4 -8
  3. src/app.py +216 -155
  4. src/logger.py +5 -3
  5. src/model.py +57 -60
  6. src/utils.py +34 -27
.flake8 ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [flake8]
2
+ max-line-length = 120
3
+
4
+ select = C,E,F,W,B,B950
5
+ extend-ignore = E501,E203,W503
requirements.txt CHANGED
@@ -1,10 +1,6 @@
1
- gradio>=4,<5
2
  python-dotenv
3
- tqdm
4
- jinja2
5
- ibm_watsonx_ai
6
  transformers
7
- gradio_modal
8
- spaces
9
- torch
10
- vllm
 
1
+ gradio_modal
2
  python-dotenv
 
 
 
3
  transformers
4
+ accelerate
5
+ ibm_watsonx_ai
6
+ vllm
 
src/app.py CHANGED
@@ -1,121 +1,150 @@
 
 
 
1
  import gradio as gr
2
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- from utils import get_result_description, to_title_case, to_snake_case, load_command_line_args, get_messages
5
  load_command_line_args()
6
  load_dotenv()
7
- import json
8
- from model import generate_text, get_prompt
9
- from logger import logger
10
- import os
11
- from gradio_modal import Modal
12
 
13
  catalog = {}
14
 
15
- with open('catalog.json') as f:
16
- logger.debug('Loading catalog from json.')
17
  catalog = json.load(f)
18
 
 
19
  def update_selected_test_case(button_name, state: gr.State, event: gr.EventData):
20
- target_sub_catalog_name, target_test_case_name = event.target.elem_id.split('---')
21
- state['selected_sub_catalog'] = target_sub_catalog_name
22
- state['selected_criteria_name'] = target_test_case_name
23
- state['selected_test_case'] = [t for sub_catalog_name, sub_catalog in catalog.items() for t in sub_catalog if t['name'] == to_snake_case(button_name) and to_snake_case(sub_catalog_name) == target_sub_catalog_name][0]
 
 
 
 
 
24
  return state
25
 
 
26
  def on_test_case_click(state: gr.State):
27
- selected_sub_catalog = state['selected_sub_catalog']
28
- selected_criteria_name = state['selected_criteria_name']
29
- selected_test_case = state['selected_test_case']
30
 
31
  logger.debug(f'Changing to test case "{selected_criteria_name}" from catalog "{selected_sub_catalog}".')
32
 
33
- is_context_iditable = selected_criteria_name == 'context_relevance'
34
- is_user_message_editable = selected_sub_catalog == 'harmful_content_in_user_prompt'
35
- is_assistant_message_editable = selected_sub_catalog == 'harmful_content_in_assistant_response' or \
36
- selected_criteria_name == 'groundedness' or \
37
- selected_criteria_name == 'answer_relevance'
 
 
38
  return {
39
  test_case_name: f'<h2>{to_title_case(selected_test_case["name"])}</h2>',
40
- criteria: selected_test_case['criteria'],
41
- context: gr.update(
42
- value=selected_test_case['context'],
43
- interactive=True,
44
- visible=True,
45
- elem_classes=['input-box']
46
- ) if is_context_iditable else gr.update(
47
- visible=selected_test_case['context'] is not None,
48
- value=selected_test_case['context'],
49
- interactive=False,
50
- elem_classes=['read-only', 'input-box']
51
- ),
52
- user_message: gr.update(
53
- value=selected_test_case['user_message'],
54
- visible=True,
55
- interactive=True,
56
- elem_classes=['input-box']
57
- ) if is_user_message_editable else gr.update(
58
- value=selected_test_case['user_message'],
59
  interactive=False,
60
- elem_classes=['read-only', 'input-box']
61
- ),
62
- assistant_message: gr.update(
63
- value=selected_test_case['assistant_message'],
 
 
 
 
 
 
 
 
 
 
 
64
  visible=True,
65
  interactive=True,
66
- elem_classes=['input-box']
67
- ) if is_assistant_message_editable else gr.update(
68
- visible=selected_test_case['assistant_message'] is not None,
69
- value=selected_test_case['assistant_message'],
 
 
70
  interactive=False,
71
- elem_classes=['read-only', 'input-box']
72
- ),
73
- result_text: gr.update(visible=False, value='')
 
74
  }
75
 
 
76
  def change_button_color(event: gr.EventData):
77
- return [gr.update(
78
- elem_classes=['catalog-button', 'selected']
79
- ) if v.elem_id == event.target.elem_id else gr.update(
80
- elem_classes=['catalog-button']
81
- ) for c in catalog_buttons.values() for v in c.values()]
 
 
 
 
 
82
 
83
  def on_submit(criteria, context, user_message, assistant_message, state):
84
- criteria_name = state['selected_criteria_name']
85
  test_case = {
86
- 'name': criteria_name,
87
- 'criteria': criteria,
88
- 'context': context,
89
- 'user_message': user_message,
90
- 'assistant_message': assistant_message
91
  }
92
 
93
- messages = get_messages(test_case=test_case, sub_catalog_name=state['selected_sub_catalog'])
94
-
95
- logger.debug(f"Starting evaluation for subcatelog {state['selected_sub_catalog']} and criteria name {state['selected_criteria_name']}")
96
-
97
- result_label = generate_text(messages=messages, criteria_name=criteria_name)['assessment'] # Yes or No
 
 
98
 
99
  html_str = f"<p>{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} <strong>{result_label}</strong></p>"
100
  # html_str = f"{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} {result_label}"
101
  return gr.update(value=html_str)
102
 
 
103
  def on_show_prompt_click(criteria, context, user_message, assistant_message, state):
104
- criteria_name = state['selected_criteria_name']
105
  test_case = {
106
- 'name': criteria_name,
107
- 'criteria': criteria,
108
- 'context': context,
109
- 'user_message': user_message,
110
- 'assistant_message': assistant_message
111
  }
112
 
113
- messages = get_messages(test_case=test_case, sub_catalog_name=state['selected_sub_catalog'])
114
  prompt = get_prompt(messages, criteria_name)
115
  print(prompt)
116
- prompt = prompt.replace('<', '&lt;').replace('>', '&gt;').replace('\\n', '<br>')
117
  return gr.Markdown(prompt)
118
 
 
119
  ibm_blue = gr.themes.Color(
120
  name="ibm-blue",
121
  c50="#eff6ff",
@@ -128,7 +157,7 @@ ibm_blue = gr.themes.Color(
128
  c700="#1d4ed8",
129
  c800="#1e40af",
130
  c900="#1e3a8a",
131
- c950="#1d3660"
132
  )
133
 
134
  head_style = """
@@ -149,107 +178,139 @@ head_style = """
149
  """
150
 
151
  with gr.Blocks(
152
- title='Granite Guardian',
153
- theme=gr.themes.Soft(
154
- primary_hue=ibm_blue,
155
- font=[gr.themes.GoogleFont("IBM Plex Sans"), gr.themes.GoogleFont('Source Sans 3')],
156
- ),
157
- head=head_style,
158
- fill_width=False,
159
- css=os.path.join(os.path.dirname(os.path.abspath(__file__)), './styles.css')
160
- ) as demo:
161
-
162
- state = gr.State(value={
163
- 'selected_sub_catalog': 'harmful_content_in_user_prompt',
164
- 'selected_criteria_name': 'general_harm'
165
- })
166
-
167
- starting_test_case = [t for sub_catalog_name, sub_catalog in catalog.items() for t in sub_catalog if t['name'] == state.value['selected_criteria_name'] and sub_catalog_name == state.value['selected_sub_catalog']][0]
168
-
169
- with gr.Row(elem_classes='header-row'):
 
 
 
 
 
170
  with gr.Column(scale=4):
171
- gr.HTML('<h2>IBM Granite Guardian 3.0</h2>', elem_classes='title')
172
- gr.HTML(elem_classes='system-description', value='<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-8b.</p>')
173
- with gr.Row(elem_classes='column-gap'):
174
- with gr.Column(scale=0, elem_classes='no-gap'):
175
- title_display_left = gr.HTML("<h2>Harms & Risks</h2>", elem_classes=['subtitle', 'subtitle-harms'])
 
 
 
176
  accordions = []
177
- catalog_buttons: dict[str,dict[str,gr.Button]] = {}
178
  for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
179
- with gr.Accordion(to_title_case(sub_catalog_name), open=(i==0), elem_classes='accordion') as accordion:
 
 
180
  for test_case in sub_catalog:
181
- elem_classes=['catalog-button']
182
- elem_id=f"{sub_catalog_name}---{test_case['name']}"
183
  if starting_test_case == test_case:
184
- elem_classes.append('selected')
185
 
186
- if not sub_catalog_name in catalog_buttons:
187
  catalog_buttons[sub_catalog_name] = {}
188
 
189
- catalog_buttons[sub_catalog_name][test_case['name']] = \
190
- gr.Button(to_title_case(test_case['name']), elem_classes=elem_classes, variant='secondary', size='sm', elem_id=elem_id)
191
-
 
 
 
 
 
192
  accordions.append(accordion)
193
 
194
  with gr.Column(visible=True, scale=1) as test_case_content:
195
- with gr.Row(elem_classes='no-stretch'):
196
- test_case_name = gr.HTML(f'<h2>{to_title_case(starting_test_case["name"])}</h2>', elem_classes='subtitle')
197
- show_propt_button = gr.Button('Show prompt', size='sm', scale=0, min_width=110)
 
 
198
 
199
- criteria = gr.Textbox(label="Evaluation Criteria", lines=3, interactive=False, value=starting_test_case['criteria'], elem_classes=['read-only', 'input-box', 'margin-bottom'])
200
- gr.HTML(elem_classes=['block', 'content-gap'])
201
- context = gr.Textbox(label="Context", lines=3, interactive=True, value=starting_test_case['context'], visible=False, elem_classes=['input-box'])
202
- user_message = gr.Textbox(label="User Prompt", lines=3, interactive=True, value=starting_test_case['user_message'], elem_classes=['input-box'])
203
- assistant_message = gr.Textbox(label="Assistant Response", lines=3, interactive=True, visible=False, value=starting_test_case['assistant_message'], elem_classes=['input-box'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  submit_button = gr.Button(
206
- "Evaluate",
207
- variant='primary',
208
- icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'send-white.png'),
209
- elem_classes='submit-button')
210
-
 
211
  # result_text = gr.HTML(label='Result', elem_classes=['result-text', 'read-only', 'input-box'], visible=False, value='')
212
  result_text = gr.HTML(
213
- label='Result',
214
- elem_classes=['result-root'],
215
- show_label=True,
216
- visible=False,
217
- value='')
218
 
219
- with Modal(visible=False, elem_classes='modal') as modal:
220
- prompt = gr.Markdown('')
221
 
222
-
223
  ### events
224
 
225
  show_propt_button.click(
226
- on_show_prompt_click,
227
- inputs=[criteria, context, user_message, assistant_message, state],
228
- outputs=prompt
229
  ).then(lambda: gr.update(visible=True), None, modal)
230
 
231
- submit_button \
232
- .click(lambda: gr.update(visible=True, value=''), None, result_text) \
233
- .then(
234
- on_submit,
235
- inputs=[criteria, context, user_message, assistant_message, state],
236
- outputs=[result_text],
237
- scroll_to_output=True
 
 
 
 
 
 
 
 
 
238
  )
239
-
240
- for button in [t for sub_catalog_name, sub_catalog_buttons in catalog_buttons.items() for t in sub_catalog_buttons.values()]:
241
- button \
242
- .click(
243
- change_button_color,
244
- inputs=None,
245
- outputs=[v for c in catalog_buttons.values() for v in c.values()]) \
246
- .then(
247
- update_selected_test_case,
248
- inputs=[button, state],
249
- outputs=[state]) \
250
- .then(
251
- on_test_case_click,
252
- inputs=state,
253
- outputs={test_case_name, criteria, context, user_message, assistant_message, result_text})
254
-
255
- demo.launch(server_name='0.0.0.0')
 
1
+ import json
2
+ import os
3
+
4
  import gradio as gr
5
  from dotenv import load_dotenv
6
+ from gradio_modal import Modal
7
+
8
+ from logger import logger
9
+ from model import generate_text, get_prompt
10
+ from utils import (
11
+ get_messages,
12
+ get_result_description,
13
+ load_command_line_args,
14
+ to_snake_case,
15
+ to_title_case,
16
+ )
17
 
 
18
  load_command_line_args()
19
  load_dotenv()
 
 
 
 
 
20
 
21
  catalog = {}
22
 
23
+ with open("catalog.json") as f:
24
+ logger.debug("Loading catalog from json.")
25
  catalog = json.load(f)
26
 
27
+
28
  def update_selected_test_case(button_name, state: gr.State, event: gr.EventData):
29
+ target_sub_catalog_name, target_test_case_name = event.target.elem_id.split("---")
30
+ state["selected_sub_catalog"] = target_sub_catalog_name
31
+ state["selected_criteria_name"] = target_test_case_name
32
+ state["selected_test_case"] = [
33
+ t
34
+ for sub_catalog_name, sub_catalog in catalog.items()
35
+ for t in sub_catalog
36
+ if t["name"] == to_snake_case(button_name) and to_snake_case(sub_catalog_name) == target_sub_catalog_name
37
+ ][0]
38
  return state
39
 
40
+
41
  def on_test_case_click(state: gr.State):
42
+ selected_sub_catalog = state["selected_sub_catalog"]
43
+ selected_criteria_name = state["selected_criteria_name"]
44
+ selected_test_case = state["selected_test_case"]
45
 
46
  logger.debug(f'Changing to test case "{selected_criteria_name}" from catalog "{selected_sub_catalog}".')
47
 
48
+ is_context_iditable = selected_criteria_name == "context_relevance"
49
+ is_user_message_editable = selected_sub_catalog == "harmful_content_in_user_prompt"
50
+ is_assistant_message_editable = (
51
+ selected_sub_catalog == "harmful_content_in_assistant_response"
52
+ or selected_criteria_name == "groundedness"
53
+ or selected_criteria_name == "answer_relevance"
54
+ )
55
  return {
56
  test_case_name: f'<h2>{to_title_case(selected_test_case["name"])}</h2>',
57
+ criteria: selected_test_case["criteria"],
58
+ context: (
59
+ gr.update(value=selected_test_case["context"], interactive=True, visible=True, elem_classes=["input-box"])
60
+ if is_context_iditable
61
+ else gr.update(
62
+ visible=selected_test_case["context"] is not None,
63
+ value=selected_test_case["context"],
 
 
 
 
 
 
 
 
 
 
 
 
64
  interactive=False,
65
+ elem_classes=["read-only", "input-box"],
66
+ )
67
+ ),
68
+ user_message: (
69
+ gr.update(
70
+ value=selected_test_case["user_message"], visible=True, interactive=True, elem_classes=["input-box"]
71
+ )
72
+ if is_user_message_editable
73
+ else gr.update(
74
+ value=selected_test_case["user_message"], interactive=False, elem_classes=["read-only", "input-box"]
75
+ )
76
+ ),
77
+ assistant_message: (
78
+ gr.update(
79
+ value=selected_test_case["assistant_message"],
80
  visible=True,
81
  interactive=True,
82
+ elem_classes=["input-box"],
83
+ )
84
+ if is_assistant_message_editable
85
+ else gr.update(
86
+ visible=selected_test_case["assistant_message"] is not None,
87
+ value=selected_test_case["assistant_message"],
88
  interactive=False,
89
+ elem_classes=["read-only", "input-box"],
90
+ )
91
+ ),
92
+ result_text: gr.update(visible=False, value=""),
93
  }
94
 
95
+
96
  def change_button_color(event: gr.EventData):
97
+ return [
98
+ (
99
+ gr.update(elem_classes=["catalog-button", "selected"])
100
+ if v.elem_id == event.target.elem_id
101
+ else gr.update(elem_classes=["catalog-button"])
102
+ )
103
+ for c in catalog_buttons.values()
104
+ for v in c.values()
105
+ ]
106
+
107
 
108
  def on_submit(criteria, context, user_message, assistant_message, state):
109
+ criteria_name = state["selected_criteria_name"]
110
  test_case = {
111
+ "name": criteria_name,
112
+ "criteria": criteria,
113
+ "context": context,
114
+ "user_message": user_message,
115
+ "assistant_message": assistant_message,
116
  }
117
 
118
+ messages = get_messages(test_case=test_case, sub_catalog_name=state["selected_sub_catalog"])
119
+
120
+ logger.debug(
121
+ f"Starting evaluation for subcatelog {state['selected_sub_catalog']} and criteria name {state['selected_criteria_name']}"
122
+ )
123
+
124
+ result_label = generate_text(messages=messages, criteria_name=criteria_name)["assessment"] # Yes or No
125
 
126
  html_str = f"<p>{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} <strong>{result_label}</strong></p>"
127
  # html_str = f"{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} {result_label}"
128
  return gr.update(value=html_str)
129
 
130
+
131
  def on_show_prompt_click(criteria, context, user_message, assistant_message, state):
132
+ criteria_name = state["selected_criteria_name"]
133
  test_case = {
134
+ "name": criteria_name,
135
+ "criteria": criteria,
136
+ "context": context,
137
+ "user_message": user_message,
138
+ "assistant_message": assistant_message,
139
  }
140
 
141
+ messages = get_messages(test_case=test_case, sub_catalog_name=state["selected_sub_catalog"])
142
  prompt = get_prompt(messages, criteria_name)
143
  print(prompt)
144
+ prompt = prompt.replace("<", "&lt;").replace(">", "&gt;").replace("\\n", "<br>")
145
  return gr.Markdown(prompt)
146
 
147
+
148
  ibm_blue = gr.themes.Color(
149
  name="ibm-blue",
150
  c50="#eff6ff",
 
157
  c700="#1d4ed8",
158
  c800="#1e40af",
159
  c900="#1e3a8a",
160
+ c950="#1d3660",
161
  )
162
 
163
  head_style = """
 
178
  """
179
 
180
  with gr.Blocks(
181
+ title="Granite Guardian",
182
+ theme=gr.themes.Soft(
183
+ primary_hue=ibm_blue,
184
+ font=[gr.themes.GoogleFont("IBM Plex Sans"), gr.themes.GoogleFont("Source Sans 3")],
185
+ ),
186
+ head=head_style,
187
+ fill_width=False,
188
+ css=os.path.join(os.path.dirname(os.path.abspath(__file__)), "./styles.css"),
189
+ ) as demo:
190
+
191
+ state = gr.State(
192
+ value={"selected_sub_catalog": "harmful_content_in_user_prompt", "selected_criteria_name": "general_harm"}
193
+ )
194
+
195
+ starting_test_case = [
196
+ t
197
+ for sub_catalog_name, sub_catalog in catalog.items()
198
+ for t in sub_catalog
199
+ if t["name"] == state.value["selected_criteria_name"]
200
+ and sub_catalog_name == state.value["selected_sub_catalog"]
201
+ ][0]
202
+
203
+ with gr.Row(elem_classes="header-row"):
204
  with gr.Column(scale=4):
205
+ gr.HTML("<h2>IBM Granite Guardian 3.0</h2>", elem_classes="title")
206
+ gr.HTML(
207
+ elem_classes="system-description",
208
+ value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-8b.</p>",
209
+ )
210
+ with gr.Row(elem_classes="column-gap"):
211
+ with gr.Column(scale=0, elem_classes="no-gap"):
212
+ title_display_left = gr.HTML("<h2>Harms & Risks</h2>", elem_classes=["subtitle", "subtitle-harms"])
213
  accordions = []
214
+ catalog_buttons: dict[str, dict[str, gr.Button]] = {}
215
  for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
216
+ with gr.Accordion(
217
+ to_title_case(sub_catalog_name), open=(i == 0), elem_classes="accordion"
218
+ ) as accordion:
219
  for test_case in sub_catalog:
220
+ elem_classes = ["catalog-button"]
221
+ elem_id = f"{sub_catalog_name}---{test_case['name']}"
222
  if starting_test_case == test_case:
223
+ elem_classes.append("selected")
224
 
225
+ if sub_catalog_name not in catalog_buttons:
226
  catalog_buttons[sub_catalog_name] = {}
227
 
228
+ catalog_buttons[sub_catalog_name][test_case["name"]] = gr.Button(
229
+ to_title_case(test_case["name"]),
230
+ elem_classes=elem_classes,
231
+ variant="secondary",
232
+ size="sm",
233
+ elem_id=elem_id,
234
+ )
235
+
236
  accordions.append(accordion)
237
 
238
  with gr.Column(visible=True, scale=1) as test_case_content:
239
+ with gr.Row(elem_classes="no-stretch"):
240
+ test_case_name = gr.HTML(
241
+ f'<h2>{to_title_case(starting_test_case["name"])}</h2>', elem_classes="subtitle"
242
+ )
243
+ show_propt_button = gr.Button("Show prompt", size="sm", scale=0, min_width=110)
244
 
245
+ criteria = gr.Textbox(
246
+ label="Evaluation Criteria",
247
+ lines=3,
248
+ interactive=False,
249
+ value=starting_test_case["criteria"],
250
+ elem_classes=["read-only", "input-box", "margin-bottom"],
251
+ )
252
+ gr.HTML(elem_classes=["block", "content-gap"])
253
+ context = gr.Textbox(
254
+ label="Context",
255
+ lines=3,
256
+ interactive=True,
257
+ value=starting_test_case["context"],
258
+ visible=False,
259
+ elem_classes=["input-box"],
260
+ )
261
+ user_message = gr.Textbox(
262
+ label="User Prompt",
263
+ lines=3,
264
+ interactive=True,
265
+ value=starting_test_case["user_message"],
266
+ elem_classes=["input-box"],
267
+ )
268
+ assistant_message = gr.Textbox(
269
+ label="Assistant Response",
270
+ lines=3,
271
+ interactive=True,
272
+ visible=False,
273
+ value=starting_test_case["assistant_message"],
274
+ elem_classes=["input-box"],
275
+ )
276
 
277
  submit_button = gr.Button(
278
+ "Evaluate",
279
+ variant="primary",
280
+ icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), "send-white.png"),
281
+ elem_classes="submit-button",
282
+ )
283
+
284
  # result_text = gr.HTML(label='Result', elem_classes=['result-text', 'read-only', 'input-box'], visible=False, value='')
285
  result_text = gr.HTML(
286
+ label="Result", elem_classes=["result-root"], show_label=True, visible=False, value=""
287
+ )
 
 
 
288
 
289
+ with Modal(visible=False, elem_classes="modal") as modal:
290
+ prompt = gr.Markdown("")
291
 
 
292
  ### events
293
 
294
  show_propt_button.click(
295
+ on_show_prompt_click, inputs=[criteria, context, user_message, assistant_message, state], outputs=prompt
 
 
296
  ).then(lambda: gr.update(visible=True), None, modal)
297
 
298
+ submit_button.click(lambda: gr.update(visible=True, value=""), None, result_text).then(
299
+ on_submit,
300
+ inputs=[criteria, context, user_message, assistant_message, state],
301
+ outputs=[result_text],
302
+ scroll_to_output=True,
303
+ )
304
+
305
+ for button in [
306
+ t for sub_catalog_name, sub_catalog_buttons in catalog_buttons.items() for t in sub_catalog_buttons.values()
307
+ ]:
308
+ button.click(
309
+ change_button_color, inputs=None, outputs=[v for c in catalog_buttons.values() for v in c.values()]
310
+ ).then(update_selected_test_case, inputs=[button, state], outputs=[state]).then(
311
+ on_test_case_click,
312
+ inputs=state,
313
+ outputs={test_case_name, criteria, context, user_message, assistant_message, result_text},
314
  )
315
+
316
+ demo.launch(server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/logger.py CHANGED
@@ -1,12 +1,14 @@
1
  import logging
2
 
3
- logger = logging.getLogger('demo')
4
  logger.setLevel(logging.DEBUG)
5
 
6
  stream_handler = logging.StreamHandler()
7
  stream_handler.setLevel(logging.DEBUG)
8
  logger.addHandler(stream_handler)
9
 
10
- file_handler = logging.FileHandler('logs.txt')
11
- file_handler.setFormatter(logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
 
 
12
  logger.addHandler(file_handler)
 
1
  import logging
2
 
3
+ logger = logging.getLogger("demo")
4
  logger.setLevel(logging.DEBUG)
5
 
6
  stream_handler = logging.StreamHandler()
7
  stream_handler.setLevel(logging.DEBUG)
8
  logger.addHandler(stream_handler)
9
 
10
+ file_handler = logging.FileHandler("logs.txt")
11
+ file_handler.setFormatter(
12
+ logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
13
+ )
14
  logger.addHandler(file_handler)
src/model.py CHANGED
@@ -1,46 +1,44 @@
1
- import os
2
- from time import time, sleep
3
- from logger import logger
4
  import math
5
  import os
 
 
 
 
6
  from ibm_watsonx_ai.client import APIClient
7
  from ibm_watsonx_ai.foundation_models import ModelInference
8
  from transformers import AutoTokenizer
9
- import math
10
- import spaces
 
11
 
12
  safe_token = "No"
13
  risky_token = "Yes"
14
  nlogprobs = 5
15
 
16
- inference_engine = os.getenv('INFERENCE_ENGINE', 'VLLM')
17
  logger.debug(f"Inference engine is: '{inference_engine}'")
18
 
19
- if inference_engine == 'VLLM':
20
- import torch
21
- from vllm import LLM, SamplingParams
22
- from transformers import AutoTokenizer
23
- model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
24
  logger.debug(f"model_path is {model_path}")
25
  tokenizer = AutoTokenizer.from_pretrained(model_path)
26
  sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
27
  model = LLM(model=model_path, tensor_parallel_size=1)
28
 
29
  elif inference_engine == "WATSONX":
30
- client = APIClient(credentials={
31
- 'api_key': os.getenv('WATSONX_API_KEY'),
32
- 'url': 'https://us-south.ml.cloud.ibm.com'})
33
-
34
- client.set.default_project(os.getenv('WATSONX_PROJECT_ID'))
35
  hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
36
  tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
37
 
38
- model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
39
- model = ModelInference(
40
- model_id=model_id,
41
- api_client=client
42
- )
43
-
44
  def parse_output(output):
45
  label, prob = None, None
46
 
@@ -60,11 +58,13 @@ def parse_output(output):
60
 
61
  return label, prob_of_risk.item()
62
 
 
63
  def softmax(values):
64
  exp_values = [math.exp(v) for v in values]
65
  total = sum(exp_values)
66
  return [v / total for v in exp_values]
67
 
 
68
  def get_probablities(logprobs):
69
  safe_token_prob = 1e-50
70
  unsafe_token_prob = 1e-50
@@ -76,59 +76,55 @@ def get_probablities(logprobs):
76
  if decoded_token.strip().lower() == risky_token.lower():
77
  unsafe_token_prob += math.exp(token_prob.logprob)
78
 
79
- probabilities = torch.softmax(
80
- torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
81
- )
82
 
83
  return probabilities
84
 
 
85
  def get_probablities_watsonx(top_tokens_list):
86
  safe_token_prob = 1e-50
87
  risky_token_prob = 1e-50
88
  for top_tokens in top_tokens_list:
89
  for token in top_tokens:
90
- if token['text'].strip().lower() == safe_token.lower():
91
- safe_token_prob += math.exp(token['logprob'])
92
- if token['text'].strip().lower() == risky_token.lower():
93
- risky_token_prob += math.exp(token['logprob'])
94
 
95
  probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
96
 
97
  return probabilities
98
 
 
99
  def get_prompt(messages, criteria_name):
100
- guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
101
  return tokenizer.apply_chat_template(
102
- messages,
103
- guardian_config=guardian_config,
104
- tokenize=False,
105
- add_generation_prompt=True)
106
 
107
  def generate_tokens(prompt):
108
  result = model.generate(
109
  prompt=[prompt],
110
  params={
111
- 'decoding_method':'greedy',
112
- 'max_new_tokens': 20,
113
  "temperature": 0,
114
- "return_options": {
115
- "token_logprobs": True,
116
- "generated_tokens": True,
117
- "input_text": True,
118
- "top_n_tokens": 5
119
- }
120
- })
121
- return result[0]['results'][0]['generated_tokens']
122
 
123
  def parse_output_watsonx(generated_tokens_list):
124
  label, prob_of_risk = None, None
125
 
126
  if nlogprobs > 0:
127
- top_tokens_list = [generated_tokens['top_tokens'] for generated_tokens in generated_tokens_list]
128
  prob = get_probablities_watsonx(top_tokens_list)
129
  prob_of_risk = prob[1]
130
 
131
- res = next(iter(generated_tokens_list))['text'].strip()
132
 
133
  if risky_token.lower() == res.lower():
134
  label = risky_token
@@ -139,25 +135,26 @@ def parse_output_watsonx(generated_tokens_list):
139
 
140
  return label, prob_of_risk
141
 
 
142
  @spaces.GPU
143
  def generate_text(messages, criteria_name):
144
- logger.debug(f'Messages used to create the prompt are: \n{messages}')
145
-
146
  start = time()
147
 
148
  chat = get_prompt(messages, criteria_name)
149
- logger.debug(f'Prompt is \n{chat}')
150
 
151
- if inference_engine=="MOCK":
152
- logger.debug('Returning mocked model result.')
153
  sleep(1)
154
- label, prob_of_risk = 'Yes', 0.97
155
-
156
- elif inference_engine=="WATSONX":
157
  generated_tokens = generate_tokens(chat)
158
  label, prob_of_risk = parse_output_watsonx(generated_tokens)
159
 
160
- elif inference_engine=="VLLM":
161
  with torch.no_grad():
162
  output = model.generate(chat, sampling_params, use_tqdm=False)
163
 
@@ -165,11 +162,11 @@ def generate_text(messages, criteria_name):
165
  else:
166
  raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
167
 
168
- logger.debug(f'Model generated label: \n{label}')
169
- logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
170
-
171
  end = time()
172
  total = end - start
173
- logger.debug(f'The evaluation took {total} secs')
174
 
175
- return {'assessment': label, 'certainty': prob_of_risk}
 
 
 
 
1
  import math
2
  import os
3
+ from time import sleep, time
4
+
5
+ import spaces
6
+ import torch
7
  from ibm_watsonx_ai.client import APIClient
8
  from ibm_watsonx_ai.foundation_models import ModelInference
9
  from transformers import AutoTokenizer
10
+ from vllm import LLM, SamplingParams
11
+
12
+ from logger import logger
13
 
14
  safe_token = "No"
15
  risky_token = "Yes"
16
  nlogprobs = 5
17
 
18
+ inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
19
  logger.debug(f"Inference engine is: '{inference_engine}'")
20
 
21
+ if inference_engine == "VLLM":
22
+
23
+ model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
 
 
24
  logger.debug(f"model_path is {model_path}")
25
  tokenizer = AutoTokenizer.from_pretrained(model_path)
26
  sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
27
  model = LLM(model=model_path, tensor_parallel_size=1)
28
 
29
  elif inference_engine == "WATSONX":
30
+ client = APIClient(
31
+ credentials={"api_key": os.getenv("WATSONX_API_KEY"), "url": "https://us-south.ml.cloud.ibm.com"}
32
+ )
33
+
34
+ client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
35
  hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
36
  tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
37
 
38
+ model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
39
+ model = ModelInference(model_id=model_id, api_client=client)
40
+
41
+
 
 
42
  def parse_output(output):
43
  label, prob = None, None
44
 
 
58
 
59
  return label, prob_of_risk.item()
60
 
61
+
62
  def softmax(values):
63
  exp_values = [math.exp(v) for v in values]
64
  total = sum(exp_values)
65
  return [v / total for v in exp_values]
66
 
67
+
68
  def get_probablities(logprobs):
69
  safe_token_prob = 1e-50
70
  unsafe_token_prob = 1e-50
 
76
  if decoded_token.strip().lower() == risky_token.lower():
77
  unsafe_token_prob += math.exp(token_prob.logprob)
78
 
79
+ probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)
 
 
80
 
81
  return probabilities
82
 
83
+
84
  def get_probablities_watsonx(top_tokens_list):
85
  safe_token_prob = 1e-50
86
  risky_token_prob = 1e-50
87
  for top_tokens in top_tokens_list:
88
  for token in top_tokens:
89
+ if token["text"].strip().lower() == safe_token.lower():
90
+ safe_token_prob += math.exp(token["logprob"])
91
+ if token["text"].strip().lower() == risky_token.lower():
92
+ risky_token_prob += math.exp(token["logprob"])
93
 
94
  probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
95
 
96
  return probabilities
97
 
98
+
99
  def get_prompt(messages, criteria_name):
100
+ guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
101
  return tokenizer.apply_chat_template(
102
+ messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True
103
+ )
104
+
 
105
 
106
  def generate_tokens(prompt):
107
  result = model.generate(
108
  prompt=[prompt],
109
  params={
110
+ "decoding_method": "greedy",
111
+ "max_new_tokens": 20,
112
  "temperature": 0,
113
+ "return_options": {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5},
114
+ },
115
+ )
116
+ return result[0]["results"][0]["generated_tokens"]
117
+
 
 
 
118
 
119
  def parse_output_watsonx(generated_tokens_list):
120
  label, prob_of_risk = None, None
121
 
122
  if nlogprobs > 0:
123
+ top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
124
  prob = get_probablities_watsonx(top_tokens_list)
125
  prob_of_risk = prob[1]
126
 
127
+ res = next(iter(generated_tokens_list))["text"].strip()
128
 
129
  if risky_token.lower() == res.lower():
130
  label = risky_token
 
135
 
136
  return label, prob_of_risk
137
 
138
+
139
  @spaces.GPU
140
  def generate_text(messages, criteria_name):
141
+ logger.debug(f"Messages used to create the prompt are: \n{messages}")
142
+
143
  start = time()
144
 
145
  chat = get_prompt(messages, criteria_name)
146
+ logger.debug(f"Prompt is \n{chat}")
147
 
148
+ if inference_engine == "MOCK":
149
+ logger.debug("Returning mocked model result.")
150
  sleep(1)
151
+ label, prob_of_risk = "Yes", 0.97
152
+
153
+ elif inference_engine == "WATSONX":
154
  generated_tokens = generate_tokens(chat)
155
  label, prob_of_risk = parse_output_watsonx(generated_tokens)
156
 
157
+ elif inference_engine == "VLLM":
158
  with torch.no_grad():
159
  output = model.generate(chat, sampling_params, use_tqdm=False)
160
 
 
162
  else:
163
  raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
164
 
165
+ logger.debug(f"Model generated label: \n{label}")
166
+ logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")
167
+
168
  end = time()
169
  total = end - start
170
+ logger.debug(f"The evaluation took {total} secs")
171
 
172
+ return {"assessment": label, "certainty": prob_of_risk}
src/utils.py CHANGED
@@ -1,27 +1,29 @@
1
  import argparse
2
  import os
3
 
4
- def get_messages(test_case, sub_catalog_name) -> list[dict[str,str]]:
 
5
  messages = []
6
 
7
- if sub_catalog_name == 'harmful_content_in_user_prompt':
8
- messages.append({'role': 'user', 'content': test_case['user_message']})
9
- elif sub_catalog_name == 'harmful_content_in_assistant_response':
10
- messages.append({'role': 'user', 'content': test_case['user_message']})
11
- messages.append({'role': 'assistant', 'content': test_case['assistant_message']})
12
- elif sub_catalog_name == 'rag_hallucination_risks':
13
- if test_case['name'] == "context_relevance":
14
- messages.append({'role': 'user', 'content': test_case['user_message']})
15
- messages.append({'role': 'context', 'content': test_case['context']})
16
- elif test_case['name'] == "groundedness":
17
- messages.append({'role': 'context', 'content': test_case['context']})
18
- messages.append({'role': 'assistant', 'content': test_case['assistant_message']})
19
- elif test_case['name'] == "answer_relevance":
20
- messages.append({'role': 'user', 'content': test_case['user_message']})
21
- messages.append({'role': 'assistant', 'content': test_case['assistant_message']})
22
-
23
  return messages
24
 
 
25
  def get_result_description(sub_catalog_name, criteria_name):
26
  evaluated_component = get_evaluated_component(sub_catalog_name, criteria_name)
27
  messages = {
@@ -33,17 +35,18 @@ def get_result_description(sub_catalog_name, criteria_name):
33
  "unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal standards?",
34
  "answer_relevance": f"Does the assistant response fail to address or properly answer the user question?",
35
  "context_relevance": f"Is the retrieved context irrelevant to the user question or does not address their needs?",
36
- "groundedness": f"Does the assistant response include claims or facts not supported by or contradicted by the provided context?"
37
  }
38
  return messages[criteria_name]
39
 
 
40
  def get_evaluated_component(sub_catalog_name, criteria_name):
41
  component = None
42
- if sub_catalog_name == 'harmful_content_in_user_prompt':
43
  component = "user"
44
- elif sub_catalog_name == 'harmful_content_in_assistant_response':
45
- component = 'assistant'
46
- elif sub_catalog_name == 'rag_hallucination_risks':
47
  if criteria_name == "context_relevance":
48
  component = "context"
49
  elif criteria_name == "groundedness":
@@ -51,20 +54,24 @@ def get_evaluated_component(sub_catalog_name, criteria_name):
51
  elif criteria_name == "answer_relevance":
52
  component = "assistant"
53
  if component is None:
54
- raise Exception('Something went wrong getting the evaluated component')
55
  return component
56
 
 
57
  def to_title_case(input_string):
58
- if input_string == 'rag_hallucination_risks':
59
- return 'RAG Hallucination Risks'
60
- return ' '.join(word.capitalize() for word in input_string.split('_'))
 
61
 
62
  def capitalize_first_word(input_string):
63
- return ' '.join(word.capitalize() if i == 0 else word for i, word in enumerate(input_string.split('_')))
 
64
 
65
  def to_snake_case(text):
66
  return text.lower().replace(" ", "_")
67
 
 
68
  def load_command_line_args():
69
  parser = argparse.ArgumentParser()
70
  parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
 
1
  import argparse
2
  import os
3
 
4
+
5
def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
    """Build the chat message list for a test case.

    Which fields of ``test_case`` become messages (and under which roles)
    depends on the sub-catalog; for the RAG sub-catalog it further depends
    on the test case's own ``name``. Unknown sub-catalogs or RAG criteria
    yield an empty list.
    """
    if sub_catalog_name == "harmful_content_in_user_prompt":
        plan = [("user", "user_message")]
    elif sub_catalog_name == "harmful_content_in_assistant_response":
        plan = [("user", "user_message"), ("assistant", "assistant_message")]
    elif sub_catalog_name == "rag_hallucination_risks":
        # Each RAG criterion evaluates a different pair of turns.
        plan = {
            "context_relevance": [("user", "user_message"), ("context", "context")],
            "groundedness": [("context", "context"), ("assistant", "assistant_message")],
            "answer_relevance": [("user", "user_message"), ("assistant", "assistant_message")],
        }.get(test_case["name"], [])
    else:
        plan = []

    # Fields are only read for the chosen plan, so test cases need not
    # carry every key (mirrors the original branch-by-branch access).
    return [{"role": role, "content": test_case[field]} for role, field in plan]
25
 
26
+
27
  def get_result_description(sub_catalog_name, criteria_name):
28
  evaluated_component = get_evaluated_component(sub_catalog_name, criteria_name)
29
  messages = {
 
35
  "unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal standards?",
36
  "answer_relevance": f"Does the assistant response fail to address or properly answer the user question?",
37
  "context_relevance": f"Is the retrieved context irrelevant to the user question or does not address their needs?",
38
+ "groundedness": f"Does the assistant response include claims or facts not supported by or contradicted by the provided context?",
39
  }
40
  return messages[criteria_name]
41
 
42
+
43
  def get_evaluated_component(sub_catalog_name, criteria_name):
44
  component = None
45
+ if sub_catalog_name == "harmful_content_in_user_prompt":
46
  component = "user"
47
+ elif sub_catalog_name == "harmful_content_in_assistant_response":
48
+ component = "assistant"
49
+ elif sub_catalog_name == "rag_hallucination_risks":
50
  if criteria_name == "context_relevance":
51
  component = "context"
52
  elif criteria_name == "groundedness":
 
54
  elif criteria_name == "answer_relevance":
55
  component = "assistant"
56
  if component is None:
57
+ raise Exception("Something went wrong getting the evaluated component")
58
  return component
59
 
60
+
61
def to_title_case(input_string):
    """Turn a snake_case identifier into a Title Case label.

    The RAG catalog name is special-cased so the acronym stays upper-case
    (plain ``str.capitalize`` would render it as "Rag").
    """
    if input_string == "rag_hallucination_risks":
        return "RAG Hallucination Risks"
    words = input_string.split("_")
    return " ".join(map(str.capitalize, words))
65
+
66
 
67
def capitalize_first_word(input_string):
    """Split a snake_case string on underscores, capitalize only the first
    word, and join the words with spaces (later words keep their case)."""
    words = input_string.split("_")
    if words:
        words[0] = words[0].capitalize()
    return " ".join(words)
69
+
70
 
71
def to_snake_case(text):
    """Lower-case *text* and swap each space for an underscore."""
    lowered = text.lower()
    return "_".join(lowered.split(" "))
73
 
74
+
75
  def load_command_line_args():
76
  parser = argparse.ArgumentParser()
77
  parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")