Spaces:
Running
on
Zero
Running
on
Zero
grahamwhiteuk
commited on
Commit
•
5b7f169
1
Parent(s):
36223d2
fix: deployment
Browse files- .flake8 +5 -0
- requirements.txt +4 -8
- src/app.py +216 -155
- src/logger.py +5 -3
- src/model.py +57 -60
- src/utils.py +34 -27
.flake8
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[flake8]
|
2 |
+
max-line-length = 120
|
3 |
+
|
4 |
+
select = C,E,F,W,B,B950
|
5 |
+
extend-ignore = E501,E203,W503
|
requirements.txt
CHANGED
@@ -1,10 +1,6 @@
|
|
1 |
-
|
2 |
python-dotenv
|
3 |
-
tqdm
|
4 |
-
jinja2
|
5 |
-
ibm_watsonx_ai
|
6 |
transformers
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
vllm
|
|
|
1 |
+
gradio_modal
|
2 |
python-dotenv
|
|
|
|
|
|
|
3 |
transformers
|
4 |
+
accelerate
|
5 |
+
ibm_watsonx_ai
|
6 |
+
vllm
|
|
src/app.py
CHANGED
@@ -1,121 +1,150 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
from utils import get_result_description, to_title_case, to_snake_case, load_command_line_args, get_messages
|
5 |
load_command_line_args()
|
6 |
load_dotenv()
|
7 |
-
import json
|
8 |
-
from model import generate_text, get_prompt
|
9 |
-
from logger import logger
|
10 |
-
import os
|
11 |
-
from gradio_modal import Modal
|
12 |
|
13 |
catalog = {}
|
14 |
|
15 |
-
with open(
|
16 |
-
logger.debug(
|
17 |
catalog = json.load(f)
|
18 |
|
|
|
19 |
def update_selected_test_case(button_name, state: gr.State, event: gr.EventData):
|
20 |
-
target_sub_catalog_name, target_test_case_name = event.target.elem_id.split(
|
21 |
-
state[
|
22 |
-
state[
|
23 |
-
state[
|
|
|
|
|
|
|
|
|
|
|
24 |
return state
|
25 |
|
|
|
26 |
def on_test_case_click(state: gr.State):
|
27 |
-
selected_sub_catalog = state[
|
28 |
-
selected_criteria_name = state[
|
29 |
-
selected_test_case = state[
|
30 |
|
31 |
logger.debug(f'Changing to test case "{selected_criteria_name}" from catalog "{selected_sub_catalog}".')
|
32 |
|
33 |
-
is_context_iditable = selected_criteria_name ==
|
34 |
-
is_user_message_editable = selected_sub_catalog ==
|
35 |
-
is_assistant_message_editable =
|
36 |
-
|
37 |
-
|
|
|
|
|
38 |
return {
|
39 |
test_case_name: f'<h2>{to_title_case(selected_test_case["name"])}</h2>',
|
40 |
-
criteria: selected_test_case[
|
41 |
-
context:
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
visible=selected_test_case['context'] is not None,
|
48 |
-
value=selected_test_case['context'],
|
49 |
-
interactive=False,
|
50 |
-
elem_classes=['read-only', 'input-box']
|
51 |
-
),
|
52 |
-
user_message: gr.update(
|
53 |
-
value=selected_test_case['user_message'],
|
54 |
-
visible=True,
|
55 |
-
interactive=True,
|
56 |
-
elem_classes=['input-box']
|
57 |
-
) if is_user_message_editable else gr.update(
|
58 |
-
value=selected_test_case['user_message'],
|
59 |
interactive=False,
|
60 |
-
elem_classes=[
|
61 |
-
)
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
visible=True,
|
65 |
interactive=True,
|
66 |
-
elem_classes=[
|
67 |
-
)
|
68 |
-
|
69 |
-
|
|
|
|
|
70 |
interactive=False,
|
71 |
-
elem_classes=[
|
72 |
-
)
|
73 |
-
|
|
|
74 |
}
|
75 |
|
|
|
76 |
def change_button_color(event: gr.EventData):
|
77 |
-
return [
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
def on_submit(criteria, context, user_message, assistant_message, state):
|
84 |
-
criteria_name = state[
|
85 |
test_case = {
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
}
|
92 |
|
93 |
-
messages = get_messages(test_case=test_case, sub_catalog_name=state[
|
94 |
-
|
95 |
-
logger.debug(
|
96 |
-
|
97 |
-
|
|
|
|
|
98 |
|
99 |
html_str = f"<p>{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} <strong>{result_label}</strong></p>"
|
100 |
# html_str = f"{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} {result_label}"
|
101 |
return gr.update(value=html_str)
|
102 |
|
|
|
103 |
def on_show_prompt_click(criteria, context, user_message, assistant_message, state):
|
104 |
-
criteria_name = state[
|
105 |
test_case = {
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
}
|
112 |
|
113 |
-
messages = get_messages(test_case=test_case, sub_catalog_name=state[
|
114 |
prompt = get_prompt(messages, criteria_name)
|
115 |
print(prompt)
|
116 |
-
prompt = prompt.replace(
|
117 |
return gr.Markdown(prompt)
|
118 |
|
|
|
119 |
ibm_blue = gr.themes.Color(
|
120 |
name="ibm-blue",
|
121 |
c50="#eff6ff",
|
@@ -128,7 +157,7 @@ ibm_blue = gr.themes.Color(
|
|
128 |
c700="#1d4ed8",
|
129 |
c800="#1e40af",
|
130 |
c900="#1e3a8a",
|
131 |
-
c950="#1d3660"
|
132 |
)
|
133 |
|
134 |
head_style = """
|
@@ -149,107 +178,139 @@ head_style = """
|
|
149 |
"""
|
150 |
|
151 |
with gr.Blocks(
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
font=[gr.themes.GoogleFont("IBM Plex Sans"), gr.themes.GoogleFont(
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
state = gr.State(
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
170 |
with gr.Column(scale=4):
|
171 |
-
gr.HTML(
|
172 |
-
gr.HTML(
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
176 |
accordions = []
|
177 |
-
catalog_buttons: dict[str,dict[str,gr.Button]] = {}
|
178 |
for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
|
179 |
-
with gr.Accordion(
|
|
|
|
|
180 |
for test_case in sub_catalog:
|
181 |
-
elem_classes=[
|
182 |
-
elem_id=f"{sub_catalog_name}---{test_case['name']}"
|
183 |
if starting_test_case == test_case:
|
184 |
-
elem_classes.append(
|
185 |
|
186 |
-
if not
|
187 |
catalog_buttons[sub_catalog_name] = {}
|
188 |
|
189 |
-
catalog_buttons[sub_catalog_name][test_case[
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
192 |
accordions.append(accordion)
|
193 |
|
194 |
with gr.Column(visible=True, scale=1) as test_case_content:
|
195 |
-
with gr.Row(elem_classes=
|
196 |
-
test_case_name = gr.HTML(
|
197 |
-
|
|
|
|
|
198 |
|
199 |
-
criteria = gr.Textbox(
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
submit_button = gr.Button(
|
206 |
-
"Evaluate",
|
207 |
-
variant=
|
208 |
-
icon=os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
209 |
-
elem_classes=
|
210 |
-
|
|
|
211 |
# result_text = gr.HTML(label='Result', elem_classes=['result-text', 'read-only', 'input-box'], visible=False, value='')
|
212 |
result_text = gr.HTML(
|
213 |
-
label=
|
214 |
-
|
215 |
-
show_label=True,
|
216 |
-
visible=False,
|
217 |
-
value='')
|
218 |
|
219 |
-
with Modal(visible=False, elem_classes=
|
220 |
-
prompt = gr.Markdown(
|
221 |
|
222 |
-
|
223 |
### events
|
224 |
|
225 |
show_propt_button.click(
|
226 |
-
on_show_prompt_click,
|
227 |
-
inputs=[criteria, context, user_message, assistant_message, state],
|
228 |
-
outputs=prompt
|
229 |
).then(lambda: gr.update(visible=True), None, modal)
|
230 |
|
231 |
-
submit_button
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
)
|
239 |
-
|
240 |
-
|
241 |
-
button \
|
242 |
-
.click(
|
243 |
-
change_button_color,
|
244 |
-
inputs=None,
|
245 |
-
outputs=[v for c in catalog_buttons.values() for v in c.values()]) \
|
246 |
-
.then(
|
247 |
-
update_selected_test_case,
|
248 |
-
inputs=[button, state],
|
249 |
-
outputs=[state]) \
|
250 |
-
.then(
|
251 |
-
on_test_case_click,
|
252 |
-
inputs=state,
|
253 |
-
outputs={test_case_name, criteria, context, user_message, assistant_message, result_text})
|
254 |
-
|
255 |
-
demo.launch(server_name='0.0.0.0')
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
import gradio as gr
|
5 |
from dotenv import load_dotenv
|
6 |
+
from gradio_modal import Modal
|
7 |
+
|
8 |
+
from logger import logger
|
9 |
+
from model import generate_text, get_prompt
|
10 |
+
from utils import (
|
11 |
+
get_messages,
|
12 |
+
get_result_description,
|
13 |
+
load_command_line_args,
|
14 |
+
to_snake_case,
|
15 |
+
to_title_case,
|
16 |
+
)
|
17 |
|
|
|
18 |
load_command_line_args()
|
19 |
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
catalog = {}
|
22 |
|
23 |
+
with open("catalog.json") as f:
|
24 |
+
logger.debug("Loading catalog from json.")
|
25 |
catalog = json.load(f)
|
26 |
|
27 |
+
|
28 |
def update_selected_test_case(button_name, state: gr.State, event: gr.EventData):
|
29 |
+
target_sub_catalog_name, target_test_case_name = event.target.elem_id.split("---")
|
30 |
+
state["selected_sub_catalog"] = target_sub_catalog_name
|
31 |
+
state["selected_criteria_name"] = target_test_case_name
|
32 |
+
state["selected_test_case"] = [
|
33 |
+
t
|
34 |
+
for sub_catalog_name, sub_catalog in catalog.items()
|
35 |
+
for t in sub_catalog
|
36 |
+
if t["name"] == to_snake_case(button_name) and to_snake_case(sub_catalog_name) == target_sub_catalog_name
|
37 |
+
][0]
|
38 |
return state
|
39 |
|
40 |
+
|
41 |
def on_test_case_click(state: gr.State):
|
42 |
+
selected_sub_catalog = state["selected_sub_catalog"]
|
43 |
+
selected_criteria_name = state["selected_criteria_name"]
|
44 |
+
selected_test_case = state["selected_test_case"]
|
45 |
|
46 |
logger.debug(f'Changing to test case "{selected_criteria_name}" from catalog "{selected_sub_catalog}".')
|
47 |
|
48 |
+
is_context_iditable = selected_criteria_name == "context_relevance"
|
49 |
+
is_user_message_editable = selected_sub_catalog == "harmful_content_in_user_prompt"
|
50 |
+
is_assistant_message_editable = (
|
51 |
+
selected_sub_catalog == "harmful_content_in_assistant_response"
|
52 |
+
or selected_criteria_name == "groundedness"
|
53 |
+
or selected_criteria_name == "answer_relevance"
|
54 |
+
)
|
55 |
return {
|
56 |
test_case_name: f'<h2>{to_title_case(selected_test_case["name"])}</h2>',
|
57 |
+
criteria: selected_test_case["criteria"],
|
58 |
+
context: (
|
59 |
+
gr.update(value=selected_test_case["context"], interactive=True, visible=True, elem_classes=["input-box"])
|
60 |
+
if is_context_iditable
|
61 |
+
else gr.update(
|
62 |
+
visible=selected_test_case["context"] is not None,
|
63 |
+
value=selected_test_case["context"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
interactive=False,
|
65 |
+
elem_classes=["read-only", "input-box"],
|
66 |
+
)
|
67 |
+
),
|
68 |
+
user_message: (
|
69 |
+
gr.update(
|
70 |
+
value=selected_test_case["user_message"], visible=True, interactive=True, elem_classes=["input-box"]
|
71 |
+
)
|
72 |
+
if is_user_message_editable
|
73 |
+
else gr.update(
|
74 |
+
value=selected_test_case["user_message"], interactive=False, elem_classes=["read-only", "input-box"]
|
75 |
+
)
|
76 |
+
),
|
77 |
+
assistant_message: (
|
78 |
+
gr.update(
|
79 |
+
value=selected_test_case["assistant_message"],
|
80 |
visible=True,
|
81 |
interactive=True,
|
82 |
+
elem_classes=["input-box"],
|
83 |
+
)
|
84 |
+
if is_assistant_message_editable
|
85 |
+
else gr.update(
|
86 |
+
visible=selected_test_case["assistant_message"] is not None,
|
87 |
+
value=selected_test_case["assistant_message"],
|
88 |
interactive=False,
|
89 |
+
elem_classes=["read-only", "input-box"],
|
90 |
+
)
|
91 |
+
),
|
92 |
+
result_text: gr.update(visible=False, value=""),
|
93 |
}
|
94 |
|
95 |
+
|
96 |
def change_button_color(event: gr.EventData):
|
97 |
+
return [
|
98 |
+
(
|
99 |
+
gr.update(elem_classes=["catalog-button", "selected"])
|
100 |
+
if v.elem_id == event.target.elem_id
|
101 |
+
else gr.update(elem_classes=["catalog-button"])
|
102 |
+
)
|
103 |
+
for c in catalog_buttons.values()
|
104 |
+
for v in c.values()
|
105 |
+
]
|
106 |
+
|
107 |
|
108 |
def on_submit(criteria, context, user_message, assistant_message, state):
|
109 |
+
criteria_name = state["selected_criteria_name"]
|
110 |
test_case = {
|
111 |
+
"name": criteria_name,
|
112 |
+
"criteria": criteria,
|
113 |
+
"context": context,
|
114 |
+
"user_message": user_message,
|
115 |
+
"assistant_message": assistant_message,
|
116 |
}
|
117 |
|
118 |
+
messages = get_messages(test_case=test_case, sub_catalog_name=state["selected_sub_catalog"])
|
119 |
+
|
120 |
+
logger.debug(
|
121 |
+
f"Starting evaluation for subcatelog {state['selected_sub_catalog']} and criteria name {state['selected_criteria_name']}"
|
122 |
+
)
|
123 |
+
|
124 |
+
result_label = generate_text(messages=messages, criteria_name=criteria_name)["assessment"] # Yes or No
|
125 |
|
126 |
html_str = f"<p>{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} <strong>{result_label}</strong></p>"
|
127 |
# html_str = f"{get_result_description(state['selected_sub_catalog'], state['selected_criteria_name'])} {result_label}"
|
128 |
return gr.update(value=html_str)
|
129 |
|
130 |
+
|
131 |
def on_show_prompt_click(criteria, context, user_message, assistant_message, state):
|
132 |
+
criteria_name = state["selected_criteria_name"]
|
133 |
test_case = {
|
134 |
+
"name": criteria_name,
|
135 |
+
"criteria": criteria,
|
136 |
+
"context": context,
|
137 |
+
"user_message": user_message,
|
138 |
+
"assistant_message": assistant_message,
|
139 |
}
|
140 |
|
141 |
+
messages = get_messages(test_case=test_case, sub_catalog_name=state["selected_sub_catalog"])
|
142 |
prompt = get_prompt(messages, criteria_name)
|
143 |
print(prompt)
|
144 |
+
prompt = prompt.replace("<", "<").replace(">", ">").replace("\\n", "<br>")
|
145 |
return gr.Markdown(prompt)
|
146 |
|
147 |
+
|
148 |
ibm_blue = gr.themes.Color(
|
149 |
name="ibm-blue",
|
150 |
c50="#eff6ff",
|
|
|
157 |
c700="#1d4ed8",
|
158 |
c800="#1e40af",
|
159 |
c900="#1e3a8a",
|
160 |
+
c950="#1d3660",
|
161 |
)
|
162 |
|
163 |
head_style = """
|
|
|
178 |
"""
|
179 |
|
180 |
with gr.Blocks(
|
181 |
+
title="Granite Guardian",
|
182 |
+
theme=gr.themes.Soft(
|
183 |
+
primary_hue=ibm_blue,
|
184 |
+
font=[gr.themes.GoogleFont("IBM Plex Sans"), gr.themes.GoogleFont("Source Sans 3")],
|
185 |
+
),
|
186 |
+
head=head_style,
|
187 |
+
fill_width=False,
|
188 |
+
css=os.path.join(os.path.dirname(os.path.abspath(__file__)), "./styles.css"),
|
189 |
+
) as demo:
|
190 |
+
|
191 |
+
state = gr.State(
|
192 |
+
value={"selected_sub_catalog": "harmful_content_in_user_prompt", "selected_criteria_name": "general_harm"}
|
193 |
+
)
|
194 |
+
|
195 |
+
starting_test_case = [
|
196 |
+
t
|
197 |
+
for sub_catalog_name, sub_catalog in catalog.items()
|
198 |
+
for t in sub_catalog
|
199 |
+
if t["name"] == state.value["selected_criteria_name"]
|
200 |
+
and sub_catalog_name == state.value["selected_sub_catalog"]
|
201 |
+
][0]
|
202 |
+
|
203 |
+
with gr.Row(elem_classes="header-row"):
|
204 |
with gr.Column(scale=4):
|
205 |
+
gr.HTML("<h2>IBM Granite Guardian 3.0</h2>", elem_classes="title")
|
206 |
+
gr.HTML(
|
207 |
+
elem_classes="system-description",
|
208 |
+
value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-8b.</p>",
|
209 |
+
)
|
210 |
+
with gr.Row(elem_classes="column-gap"):
|
211 |
+
with gr.Column(scale=0, elem_classes="no-gap"):
|
212 |
+
title_display_left = gr.HTML("<h2>Harms & Risks</h2>", elem_classes=["subtitle", "subtitle-harms"])
|
213 |
accordions = []
|
214 |
+
catalog_buttons: dict[str, dict[str, gr.Button]] = {}
|
215 |
for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
|
216 |
+
with gr.Accordion(
|
217 |
+
to_title_case(sub_catalog_name), open=(i == 0), elem_classes="accordion"
|
218 |
+
) as accordion:
|
219 |
for test_case in sub_catalog:
|
220 |
+
elem_classes = ["catalog-button"]
|
221 |
+
elem_id = f"{sub_catalog_name}---{test_case['name']}"
|
222 |
if starting_test_case == test_case:
|
223 |
+
elem_classes.append("selected")
|
224 |
|
225 |
+
if sub_catalog_name not in catalog_buttons:
|
226 |
catalog_buttons[sub_catalog_name] = {}
|
227 |
|
228 |
+
catalog_buttons[sub_catalog_name][test_case["name"]] = gr.Button(
|
229 |
+
to_title_case(test_case["name"]),
|
230 |
+
elem_classes=elem_classes,
|
231 |
+
variant="secondary",
|
232 |
+
size="sm",
|
233 |
+
elem_id=elem_id,
|
234 |
+
)
|
235 |
+
|
236 |
accordions.append(accordion)
|
237 |
|
238 |
with gr.Column(visible=True, scale=1) as test_case_content:
|
239 |
+
with gr.Row(elem_classes="no-stretch"):
|
240 |
+
test_case_name = gr.HTML(
|
241 |
+
f'<h2>{to_title_case(starting_test_case["name"])}</h2>', elem_classes="subtitle"
|
242 |
+
)
|
243 |
+
show_propt_button = gr.Button("Show prompt", size="sm", scale=0, min_width=110)
|
244 |
|
245 |
+
criteria = gr.Textbox(
|
246 |
+
label="Evaluation Criteria",
|
247 |
+
lines=3,
|
248 |
+
interactive=False,
|
249 |
+
value=starting_test_case["criteria"],
|
250 |
+
elem_classes=["read-only", "input-box", "margin-bottom"],
|
251 |
+
)
|
252 |
+
gr.HTML(elem_classes=["block", "content-gap"])
|
253 |
+
context = gr.Textbox(
|
254 |
+
label="Context",
|
255 |
+
lines=3,
|
256 |
+
interactive=True,
|
257 |
+
value=starting_test_case["context"],
|
258 |
+
visible=False,
|
259 |
+
elem_classes=["input-box"],
|
260 |
+
)
|
261 |
+
user_message = gr.Textbox(
|
262 |
+
label="User Prompt",
|
263 |
+
lines=3,
|
264 |
+
interactive=True,
|
265 |
+
value=starting_test_case["user_message"],
|
266 |
+
elem_classes=["input-box"],
|
267 |
+
)
|
268 |
+
assistant_message = gr.Textbox(
|
269 |
+
label="Assistant Response",
|
270 |
+
lines=3,
|
271 |
+
interactive=True,
|
272 |
+
visible=False,
|
273 |
+
value=starting_test_case["assistant_message"],
|
274 |
+
elem_classes=["input-box"],
|
275 |
+
)
|
276 |
|
277 |
submit_button = gr.Button(
|
278 |
+
"Evaluate",
|
279 |
+
variant="primary",
|
280 |
+
icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), "send-white.png"),
|
281 |
+
elem_classes="submit-button",
|
282 |
+
)
|
283 |
+
|
284 |
# result_text = gr.HTML(label='Result', elem_classes=['result-text', 'read-only', 'input-box'], visible=False, value='')
|
285 |
result_text = gr.HTML(
|
286 |
+
label="Result", elem_classes=["result-root"], show_label=True, visible=False, value=""
|
287 |
+
)
|
|
|
|
|
|
|
288 |
|
289 |
+
with Modal(visible=False, elem_classes="modal") as modal:
|
290 |
+
prompt = gr.Markdown("")
|
291 |
|
|
|
292 |
### events
|
293 |
|
294 |
show_propt_button.click(
|
295 |
+
on_show_prompt_click, inputs=[criteria, context, user_message, assistant_message, state], outputs=prompt
|
|
|
|
|
296 |
).then(lambda: gr.update(visible=True), None, modal)
|
297 |
|
298 |
+
submit_button.click(lambda: gr.update(visible=True, value=""), None, result_text).then(
|
299 |
+
on_submit,
|
300 |
+
inputs=[criteria, context, user_message, assistant_message, state],
|
301 |
+
outputs=[result_text],
|
302 |
+
scroll_to_output=True,
|
303 |
+
)
|
304 |
+
|
305 |
+
for button in [
|
306 |
+
t for sub_catalog_name, sub_catalog_buttons in catalog_buttons.items() for t in sub_catalog_buttons.values()
|
307 |
+
]:
|
308 |
+
button.click(
|
309 |
+
change_button_color, inputs=None, outputs=[v for c in catalog_buttons.values() for v in c.values()]
|
310 |
+
).then(update_selected_test_case, inputs=[button, state], outputs=[state]).then(
|
311 |
+
on_test_case_click,
|
312 |
+
inputs=state,
|
313 |
+
outputs={test_case_name, criteria, context, user_message, assistant_message, result_text},
|
314 |
)
|
315 |
+
|
316 |
+
demo.launch(server_name="0.0.0.0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/logger.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
import logging
|
2 |
|
3 |
-
logger = logging.getLogger(
|
4 |
logger.setLevel(logging.DEBUG)
|
5 |
|
6 |
stream_handler = logging.StreamHandler()
|
7 |
stream_handler.setLevel(logging.DEBUG)
|
8 |
logger.addHandler(stream_handler)
|
9 |
|
10 |
-
file_handler = logging.FileHandler(
|
11 |
-
file_handler.setFormatter(
|
|
|
|
|
12 |
logger.addHandler(file_handler)
|
|
|
1 |
import logging
|
2 |
|
3 |
+
logger = logging.getLogger("demo")
|
4 |
logger.setLevel(logging.DEBUG)
|
5 |
|
6 |
stream_handler = logging.StreamHandler()
|
7 |
stream_handler.setLevel(logging.DEBUG)
|
8 |
logger.addHandler(stream_handler)
|
9 |
|
10 |
+
file_handler = logging.FileHandler("logs.txt")
|
11 |
+
file_handler.setFormatter(
|
12 |
+
logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
13 |
+
)
|
14 |
logger.addHandler(file_handler)
|
src/model.py
CHANGED
@@ -1,46 +1,44 @@
|
|
1 |
-
import os
|
2 |
-
from time import time, sleep
|
3 |
-
from logger import logger
|
4 |
import math
|
5 |
import os
|
|
|
|
|
|
|
|
|
6 |
from ibm_watsonx_ai.client import APIClient
|
7 |
from ibm_watsonx_ai.foundation_models import ModelInference
|
8 |
from transformers import AutoTokenizer
|
9 |
-
import
|
10 |
-
|
|
|
11 |
|
12 |
safe_token = "No"
|
13 |
risky_token = "Yes"
|
14 |
nlogprobs = 5
|
15 |
|
16 |
-
inference_engine = os.getenv(
|
17 |
logger.debug(f"Inference engine is: '{inference_engine}'")
|
18 |
|
19 |
-
if inference_engine ==
|
20 |
-
|
21 |
-
|
22 |
-
from transformers import AutoTokenizer
|
23 |
-
model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
|
24 |
logger.debug(f"model_path is {model_path}")
|
25 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
26 |
sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
|
27 |
model = LLM(model=model_path, tensor_parallel_size=1)
|
28 |
|
29 |
elif inference_engine == "WATSONX":
|
30 |
-
client = APIClient(
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
client.set.default_project(os.getenv(
|
35 |
hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
|
36 |
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
|
37 |
|
38 |
-
model_id = "ibm/granite-guardian-3-8b"
|
39 |
-
model = ModelInference(
|
40 |
-
|
41 |
-
|
42 |
-
)
|
43 |
-
|
44 |
def parse_output(output):
|
45 |
label, prob = None, None
|
46 |
|
@@ -60,11 +58,13 @@ def parse_output(output):
|
|
60 |
|
61 |
return label, prob_of_risk.item()
|
62 |
|
|
|
63 |
def softmax(values):
|
64 |
exp_values = [math.exp(v) for v in values]
|
65 |
total = sum(exp_values)
|
66 |
return [v / total for v in exp_values]
|
67 |
|
|
|
68 |
def get_probablities(logprobs):
|
69 |
safe_token_prob = 1e-50
|
70 |
unsafe_token_prob = 1e-50
|
@@ -76,59 +76,55 @@ def get_probablities(logprobs):
|
|
76 |
if decoded_token.strip().lower() == risky_token.lower():
|
77 |
unsafe_token_prob += math.exp(token_prob.logprob)
|
78 |
|
79 |
-
probabilities = torch.softmax(
|
80 |
-
torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
|
81 |
-
)
|
82 |
|
83 |
return probabilities
|
84 |
|
|
|
85 |
def get_probablities_watsonx(top_tokens_list):
|
86 |
safe_token_prob = 1e-50
|
87 |
risky_token_prob = 1e-50
|
88 |
for top_tokens in top_tokens_list:
|
89 |
for token in top_tokens:
|
90 |
-
if token[
|
91 |
-
safe_token_prob += math.exp(token[
|
92 |
-
if token[
|
93 |
-
risky_token_prob += math.exp(token[
|
94 |
|
95 |
probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
|
96 |
|
97 |
return probabilities
|
98 |
|
|
|
99 |
def get_prompt(messages, criteria_name):
|
100 |
-
guardian_config = {"risk_name": criteria_name if criteria_name !=
|
101 |
return tokenizer.apply_chat_template(
|
102 |
-
messages,
|
103 |
-
|
104 |
-
|
105 |
-
add_generation_prompt=True)
|
106 |
|
107 |
def generate_tokens(prompt):
|
108 |
result = model.generate(
|
109 |
prompt=[prompt],
|
110 |
params={
|
111 |
-
|
112 |
-
|
113 |
"temperature": 0,
|
114 |
-
"return_options": {
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
}
|
120 |
-
})
|
121 |
-
return result[0]['results'][0]['generated_tokens']
|
122 |
|
123 |
def parse_output_watsonx(generated_tokens_list):
|
124 |
label, prob_of_risk = None, None
|
125 |
|
126 |
if nlogprobs > 0:
|
127 |
-
top_tokens_list = [generated_tokens[
|
128 |
prob = get_probablities_watsonx(top_tokens_list)
|
129 |
prob_of_risk = prob[1]
|
130 |
|
131 |
-
res = next(iter(generated_tokens_list))[
|
132 |
|
133 |
if risky_token.lower() == res.lower():
|
134 |
label = risky_token
|
@@ -139,25 +135,26 @@ def parse_output_watsonx(generated_tokens_list):
|
|
139 |
|
140 |
return label, prob_of_risk
|
141 |
|
|
|
142 |
@spaces.GPU
|
143 |
def generate_text(messages, criteria_name):
|
144 |
-
logger.debug(f
|
145 |
-
|
146 |
start = time()
|
147 |
|
148 |
chat = get_prompt(messages, criteria_name)
|
149 |
-
logger.debug(f
|
150 |
|
151 |
-
if inference_engine=="MOCK":
|
152 |
-
logger.debug(
|
153 |
sleep(1)
|
154 |
-
label, prob_of_risk =
|
155 |
-
|
156 |
-
elif inference_engine=="WATSONX":
|
157 |
generated_tokens = generate_tokens(chat)
|
158 |
label, prob_of_risk = parse_output_watsonx(generated_tokens)
|
159 |
|
160 |
-
elif inference_engine=="VLLM":
|
161 |
with torch.no_grad():
|
162 |
output = model.generate(chat, sampling_params, use_tqdm=False)
|
163 |
|
@@ -165,11 +162,11 @@ def generate_text(messages, criteria_name):
|
|
165 |
else:
|
166 |
raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
|
167 |
|
168 |
-
logger.debug(f
|
169 |
-
logger.debug(f
|
170 |
-
|
171 |
end = time()
|
172 |
total = end - start
|
173 |
-
logger.debug(f
|
174 |
|
175 |
-
return {
|
|
|
|
|
|
|
|
|
1 |
import math
|
2 |
import os
|
3 |
+
from time import sleep, time
|
4 |
+
|
5 |
+
import spaces
|
6 |
+
import torch
|
7 |
from ibm_watsonx_ai.client import APIClient
|
8 |
from ibm_watsonx_ai.foundation_models import ModelInference
|
9 |
from transformers import AutoTokenizer
|
10 |
+
from vllm import LLM, SamplingParams
|
11 |
+
|
12 |
+
from logger import logger
|
13 |
|
14 |
safe_token = "No"
|
15 |
risky_token = "Yes"
|
16 |
nlogprobs = 5
|
17 |
|
18 |
+
inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
|
19 |
logger.debug(f"Inference engine is: '{inference_engine}'")
|
20 |
|
21 |
+
if inference_engine == "VLLM":
|
22 |
+
|
23 |
+
model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
|
|
|
|
|
24 |
logger.debug(f"model_path is {model_path}")
|
25 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
26 |
sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
|
27 |
model = LLM(model=model_path, tensor_parallel_size=1)
|
28 |
|
29 |
elif inference_engine == "WATSONX":
|
30 |
+
client = APIClient(
|
31 |
+
credentials={"api_key": os.getenv("WATSONX_API_KEY"), "url": "https://us-south.ml.cloud.ibm.com"}
|
32 |
+
)
|
33 |
+
|
34 |
+
client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
|
35 |
hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
|
36 |
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
|
37 |
|
38 |
+
model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
|
39 |
+
model = ModelInference(model_id=model_id, api_client=client)
|
40 |
+
|
41 |
+
|
|
|
|
|
42 |
def parse_output(output):
|
43 |
label, prob = None, None
|
44 |
|
|
|
58 |
|
59 |
return label, prob_of_risk.item()
|
60 |
|
61 |
+
|
62 |
def softmax(values):
|
63 |
exp_values = [math.exp(v) for v in values]
|
64 |
total = sum(exp_values)
|
65 |
return [v / total for v in exp_values]
|
66 |
|
67 |
+
|
68 |
def get_probablities(logprobs):
|
69 |
safe_token_prob = 1e-50
|
70 |
unsafe_token_prob = 1e-50
|
|
|
76 |
if decoded_token.strip().lower() == risky_token.lower():
|
77 |
unsafe_token_prob += math.exp(token_prob.logprob)
|
78 |
|
79 |
+
probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)
|
|
|
|
|
80 |
|
81 |
return probabilities
|
82 |
|
83 |
+
|
84 |
def get_probablities_watsonx(top_tokens_list):
|
85 |
safe_token_prob = 1e-50
|
86 |
risky_token_prob = 1e-50
|
87 |
for top_tokens in top_tokens_list:
|
88 |
for token in top_tokens:
|
89 |
+
if token["text"].strip().lower() == safe_token.lower():
|
90 |
+
safe_token_prob += math.exp(token["logprob"])
|
91 |
+
if token["text"].strip().lower() == risky_token.lower():
|
92 |
+
risky_token_prob += math.exp(token["logprob"])
|
93 |
|
94 |
probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
|
95 |
|
96 |
return probabilities
|
97 |
|
98 |
+
|
99 |
def get_prompt(messages, criteria_name):
|
100 |
+
guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
|
101 |
return tokenizer.apply_chat_template(
|
102 |
+
messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True
|
103 |
+
)
|
104 |
+
|
|
|
105 |
|
106 |
def generate_tokens(prompt):
|
107 |
result = model.generate(
|
108 |
prompt=[prompt],
|
109 |
params={
|
110 |
+
"decoding_method": "greedy",
|
111 |
+
"max_new_tokens": 20,
|
112 |
"temperature": 0,
|
113 |
+
"return_options": {"token_logprobs": True, "generated_tokens": True, "input_text": True, "top_n_tokens": 5},
|
114 |
+
},
|
115 |
+
)
|
116 |
+
return result[0]["results"][0]["generated_tokens"]
|
117 |
+
|
|
|
|
|
|
|
118 |
|
119 |
def parse_output_watsonx(generated_tokens_list):
|
120 |
label, prob_of_risk = None, None
|
121 |
|
122 |
if nlogprobs > 0:
|
123 |
+
top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
|
124 |
prob = get_probablities_watsonx(top_tokens_list)
|
125 |
prob_of_risk = prob[1]
|
126 |
|
127 |
+
res = next(iter(generated_tokens_list))["text"].strip()
|
128 |
|
129 |
if risky_token.lower() == res.lower():
|
130 |
label = risky_token
|
|
|
135 |
|
136 |
return label, prob_of_risk
|
137 |
|
138 |
+
|
139 |
@spaces.GPU
|
140 |
def generate_text(messages, criteria_name):
|
141 |
+
logger.debug(f"Messages used to create the prompt are: \n{messages}")
|
142 |
+
|
143 |
start = time()
|
144 |
|
145 |
chat = get_prompt(messages, criteria_name)
|
146 |
+
logger.debug(f"Prompt is \n{chat}")
|
147 |
|
148 |
+
if inference_engine == "MOCK":
|
149 |
+
logger.debug("Returning mocked model result.")
|
150 |
sleep(1)
|
151 |
+
label, prob_of_risk = "Yes", 0.97
|
152 |
+
|
153 |
+
elif inference_engine == "WATSONX":
|
154 |
generated_tokens = generate_tokens(chat)
|
155 |
label, prob_of_risk = parse_output_watsonx(generated_tokens)
|
156 |
|
157 |
+
elif inference_engine == "VLLM":
|
158 |
with torch.no_grad():
|
159 |
output = model.generate(chat, sampling_params, use_tqdm=False)
|
160 |
|
|
|
162 |
else:
|
163 |
raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
|
164 |
|
165 |
+
logger.debug(f"Model generated label: \n{label}")
|
166 |
+
logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")
|
167 |
+
|
168 |
end = time()
|
169 |
total = end - start
|
170 |
+
logger.debug(f"The evaluation took {total} secs")
|
171 |
|
172 |
+
return {"assessment": label, "certainty": prob_of_risk}
|
src/utils.py
CHANGED
@@ -1,27 +1,29 @@
|
|
1 |
import argparse
|
2 |
import os
|
3 |
|
4 |
-
|
|
|
5 |
messages = []
|
6 |
|
7 |
-
if sub_catalog_name ==
|
8 |
-
messages.append({
|
9 |
-
elif sub_catalog_name ==
|
10 |
-
messages.append({
|
11 |
-
messages.append({
|
12 |
-
elif sub_catalog_name ==
|
13 |
-
if test_case[
|
14 |
-
messages.append({
|
15 |
-
messages.append({
|
16 |
-
elif test_case[
|
17 |
-
messages.append({
|
18 |
-
messages.append({
|
19 |
-
elif test_case[
|
20 |
-
messages.append({
|
21 |
-
messages.append({
|
22 |
-
|
23 |
return messages
|
24 |
|
|
|
25 |
def get_result_description(sub_catalog_name, criteria_name):
|
26 |
evaluated_component = get_evaluated_component(sub_catalog_name, criteria_name)
|
27 |
messages = {
|
@@ -33,17 +35,18 @@ def get_result_description(sub_catalog_name, criteria_name):
|
|
33 |
"unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal standards?",
|
34 |
"answer_relevance": f"Does the assistant response fail to address or properly answer the user question?",
|
35 |
"context_relevance": f"Is the retrieved context irrelevant to the user question or does not address their needs?",
|
36 |
-
"groundedness": f"Does the assistant response include claims or facts not supported by or contradicted by the provided context?"
|
37 |
}
|
38 |
return messages[criteria_name]
|
39 |
|
|
|
40 |
def get_evaluated_component(sub_catalog_name, criteria_name):
|
41 |
component = None
|
42 |
-
if sub_catalog_name ==
|
43 |
component = "user"
|
44 |
-
elif sub_catalog_name ==
|
45 |
-
component =
|
46 |
-
elif sub_catalog_name ==
|
47 |
if criteria_name == "context_relevance":
|
48 |
component = "context"
|
49 |
elif criteria_name == "groundedness":
|
@@ -51,20 +54,24 @@ def get_evaluated_component(sub_catalog_name, criteria_name):
|
|
51 |
elif criteria_name == "answer_relevance":
|
52 |
component = "assistant"
|
53 |
if component is None:
|
54 |
-
raise Exception(
|
55 |
return component
|
56 |
|
|
|
57 |
def to_title_case(input_string):
|
58 |
-
if input_string ==
|
59 |
-
return
|
60 |
-
return
|
|
|
61 |
|
62 |
def capitalize_first_word(input_string):
|
63 |
-
return
|
|
|
64 |
|
65 |
def to_snake_case(text):
|
66 |
return text.lower().replace(" ", "_")
|
67 |
|
|
|
68 |
def load_command_line_args():
|
69 |
parser = argparse.ArgumentParser()
|
70 |
parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
|
|
|
1 |
import argparse
|
2 |
import os
|
3 |
|
4 |
+
|
5 |
+
def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
|
6 |
messages = []
|
7 |
|
8 |
+
if sub_catalog_name == "harmful_content_in_user_prompt":
|
9 |
+
messages.append({"role": "user", "content": test_case["user_message"]})
|
10 |
+
elif sub_catalog_name == "harmful_content_in_assistant_response":
|
11 |
+
messages.append({"role": "user", "content": test_case["user_message"]})
|
12 |
+
messages.append({"role": "assistant", "content": test_case["assistant_message"]})
|
13 |
+
elif sub_catalog_name == "rag_hallucination_risks":
|
14 |
+
if test_case["name"] == "context_relevance":
|
15 |
+
messages.append({"role": "user", "content": test_case["user_message"]})
|
16 |
+
messages.append({"role": "context", "content": test_case["context"]})
|
17 |
+
elif test_case["name"] == "groundedness":
|
18 |
+
messages.append({"role": "context", "content": test_case["context"]})
|
19 |
+
messages.append({"role": "assistant", "content": test_case["assistant_message"]})
|
20 |
+
elif test_case["name"] == "answer_relevance":
|
21 |
+
messages.append({"role": "user", "content": test_case["user_message"]})
|
22 |
+
messages.append({"role": "assistant", "content": test_case["assistant_message"]})
|
23 |
+
|
24 |
return messages
|
25 |
|
26 |
+
|
27 |
def get_result_description(sub_catalog_name, criteria_name):
|
28 |
evaluated_component = get_evaluated_component(sub_catalog_name, criteria_name)
|
29 |
messages = {
|
|
|
35 |
"unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal standards?",
|
36 |
"answer_relevance": f"Does the assistant response fail to address or properly answer the user question?",
|
37 |
"context_relevance": f"Is the retrieved context irrelevant to the user question or does not address their needs?",
|
38 |
+
"groundedness": f"Does the assistant response include claims or facts not supported by or contradicted by the provided context?",
|
39 |
}
|
40 |
return messages[criteria_name]
|
41 |
|
42 |
+
|
43 |
def get_evaluated_component(sub_catalog_name, criteria_name):
|
44 |
component = None
|
45 |
+
if sub_catalog_name == "harmful_content_in_user_prompt":
|
46 |
component = "user"
|
47 |
+
elif sub_catalog_name == "harmful_content_in_assistant_response":
|
48 |
+
component = "assistant"
|
49 |
+
elif sub_catalog_name == "rag_hallucination_risks":
|
50 |
if criteria_name == "context_relevance":
|
51 |
component = "context"
|
52 |
elif criteria_name == "groundedness":
|
|
|
54 |
elif criteria_name == "answer_relevance":
|
55 |
component = "assistant"
|
56 |
if component is None:
|
57 |
+
raise Exception("Something went wrong getting the evaluated component")
|
58 |
return component
|
59 |
|
60 |
+
|
61 |
def to_title_case(input_string):
|
62 |
+
if input_string == "rag_hallucination_risks":
|
63 |
+
return "RAG Hallucination Risks"
|
64 |
+
return " ".join(word.capitalize() for word in input_string.split("_"))
|
65 |
+
|
66 |
|
67 |
def capitalize_first_word(input_string):
|
68 |
+
return " ".join(word.capitalize() if i == 0 else word for i, word in enumerate(input_string.split("_")))
|
69 |
+
|
70 |
|
71 |
def to_snake_case(text):
|
72 |
return text.lower().replace(" ", "_")
|
73 |
|
74 |
+
|
75 |
def load_command_line_args():
|
76 |
parser = argparse.ArgumentParser()
|
77 |
parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
|