Spaces:
Running
Running
geekyrakshit
committed on
Merge pull request #4 from soumik12345/feat/eval-table
Browse files- .gitignore +4 -1
- README.md +6 -1
- app.py +8 -1
- application_pages/chat_app.py +40 -27
- application_pages/evaluation_app.py +84 -11
- application_pages/train_classifier.py +61 -0
- guardrails_genie/guardrails/__init__.py +5 -2
- guardrails_genie/guardrails/injection/__init__.py +2 -2
- guardrails_genie/guardrails/injection/{protectai_guardrail.py → classifier_guardrail.py} +16 -8
- guardrails_genie/train_classifier.py +110 -0
- guardrails_genie/utils.py +46 -0
- pyproject.toml +2 -2
.gitignore
CHANGED
@@ -165,4 +165,7 @@ cursor_prompts/
|
|
165 |
uv.lock
|
166 |
test.py
|
167 |
temp.txt
|
168 |
-
**.csv
|
|
|
|
|
|
|
|
165 |
uv.lock
|
166 |
test.py
|
167 |
temp.txt
|
168 |
+
**.csv
|
169 |
+
binary-classifier/
|
170 |
+
wandb/
|
171 |
+
artifacts/
|
README.md
CHANGED
@@ -18,7 +18,12 @@ source .venv/bin/activate
|
|
18 |
## Run the App
|
19 |
|
20 |
```bash
|
21 |
-
OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
|
|
|
|
|
|
|
|
|
|
|
22 |
```
|
23 |
|
24 |
## Use the Library
|
|
|
18 |
## Run the App
|
19 |
|
20 |
```bash
|
21 |
+
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
|
22 |
+
export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
|
23 |
+
export WANDB_PROJECT_NAME="YOUR_WANDB_PROJECT_NAME"
|
24 |
+
export WANDB_ENTITY_NAME="YOUR_WANDB_ENTITY_NAME"
|
25 |
+
export WANDB_LOG_MODEL="checkpoint"
|
26 |
+
streamlit run app.py
|
27 |
```
|
28 |
|
29 |
## Use the Library
|
app.py
CHANGED
@@ -13,6 +13,13 @@ evaluation_page = st.Page(
|
|
13 |
title="Evaluation",
|
14 |
icon=":material/monitoring:",
|
15 |
)
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
|
18 |
page_navigation.run()
|
|
|
13 |
title="Evaluation",
|
14 |
icon=":material/monitoring:",
|
15 |
)
|
16 |
+
train_classifier_page = st.Page(
|
17 |
+
"application_pages/train_classifier.py",
|
18 |
+
title="Train Classifier",
|
19 |
+
icon=":material/fitness_center:",
|
20 |
+
)
|
21 |
+
page_navigation = st.navigation(
|
22 |
+
[intro_page, chat_page, evaluation_page, train_classifier_page]
|
23 |
+
)
|
24 |
st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
|
25 |
page_navigation.run()
|
application_pages/chat_app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import importlib
|
|
|
2 |
|
3 |
import streamlit as st
|
4 |
import weave
|
@@ -7,27 +8,27 @@ from dotenv import load_dotenv
|
|
7 |
from guardrails_genie.guardrails import GuardrailManager
|
8 |
from guardrails_genie.llm import OpenAIModel
|
9 |
|
10 |
-
st.title(":material/robot: Guardrails Genie Playground")
|
11 |
-
|
12 |
-
load_dotenv()
|
13 |
-
weave.init(project_name="guardrails-genie")
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
if "
|
20 |
-
|
21 |
-
if "
|
22 |
-
|
23 |
-
if "
|
24 |
-
|
25 |
-
if "
|
26 |
-
|
27 |
-
if "
|
28 |
-
|
29 |
-
if "
|
30 |
-
|
|
|
|
|
|
|
|
|
31 |
|
32 |
|
33 |
def initialize_guardrails():
|
@@ -44,18 +45,30 @@ def initialize_guardrails():
|
|
44 |
guardrail_name,
|
45 |
)(llm_model=OpenAIModel(model_name=survey_guardrail_model))
|
46 |
)
|
47 |
-
|
48 |
-
st.
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
st.session_state.guardrails_manager = GuardrailManager(
|
55 |
guardrails=st.session_state.guardrails
|
56 |
)
|
57 |
|
58 |
|
|
|
|
|
|
|
59 |
openai_model = st.sidebar.selectbox(
|
60 |
"OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
|
61 |
)
|
@@ -97,7 +110,7 @@ if st.session_state.initialize_guardrails:
|
|
97 |
|
98 |
if guardrails_response["safe"]:
|
99 |
st.markdown(
|
100 |
-
f"\n\n---\nPrompt is safe! Explore
|
101 |
)
|
102 |
|
103 |
with st.sidebar.status("Generating response from LLM..."):
|
|
|
1 |
import importlib
|
2 |
+
import os
|
3 |
|
4 |
import streamlit as st
|
5 |
import weave
|
|
|
8 |
from guardrails_genie.guardrails import GuardrailManager
|
9 |
from guardrails_genie.llm import OpenAIModel
|
10 |
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
def initialize_session_state():
    """Load environment variables, start Weave, and seed session defaults.

    Reads a ``.env`` file via ``load_dotenv``, initialises Weave with the
    project named by the ``WEAVE_PROJECT`` environment variable, and then
    creates every session-state key this page relies on, leaving keys that
    already exist untouched.
    """
    load_dotenv()
    weave.init(project_name=os.getenv("WEAVE_PROJECT"))

    # This literal is rebuilt on every call, so the list defaults are fresh
    # objects rather than shared between sessions.
    session_defaults = (
        ("guardrails", []),
        ("guardrail_names", []),
        ("guardrails_manager", None),
        ("initialize_guardrails", False),
        ("system_prompt", ""),
        ("user_prompt", ""),
        ("test_guardrails", False),
        ("llm_model", None),
    )
    for state_key, initial_value in session_defaults:
        if state_key not in st.session_state:
            st.session_state[state_key] = initial_value
|
32 |
|
33 |
|
34 |
def initialize_guardrails():
|
|
|
45 |
guardrail_name,
|
46 |
)(llm_model=OpenAIModel(model_name=survey_guardrail_model))
|
47 |
)
|
48 |
+
elif guardrail_name == "PromptInjectionClassifierGuardrail":
|
49 |
+
classifier_model_name = st.sidebar.selectbox(
|
50 |
+
"Classifier Guardrail Model",
|
51 |
+
[
|
52 |
+
"",
|
53 |
+
"ProtectAI/deberta-v3-base-prompt-injection-v2",
|
54 |
+
"wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
|
55 |
+
],
|
56 |
)
|
57 |
+
if classifier_model_name != "":
|
58 |
+
st.session_state.guardrails.append(
|
59 |
+
getattr(
|
60 |
+
importlib.import_module("guardrails_genie.guardrails"),
|
61 |
+
guardrail_name,
|
62 |
+
)(model_name=classifier_model_name)
|
63 |
+
)
|
64 |
st.session_state.guardrails_manager = GuardrailManager(
|
65 |
guardrails=st.session_state.guardrails
|
66 |
)
|
67 |
|
68 |
|
69 |
+
initialize_session_state()
|
70 |
+
st.title(":material/robot: Guardrails Genie Playground")
|
71 |
+
|
72 |
openai_model = st.sidebar.selectbox(
|
73 |
"OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
|
74 |
)
|
|
|
110 |
|
111 |
if guardrails_response["safe"]:
|
112 |
st.markdown(
|
113 |
+
f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
|
114 |
)
|
115 |
|
116 |
with st.sidebar.status("Generating response from LLM..."):
|
application_pages/evaluation_app.py
CHANGED
@@ -1,7 +1,10 @@
|
|
1 |
import asyncio
|
|
|
|
|
2 |
from importlib import import_module
|
3 |
|
4 |
import pandas as pd
|
|
|
5 |
import streamlit as st
|
6 |
import weave
|
7 |
from dotenv import load_dotenv
|
@@ -9,12 +12,11 @@ from dotenv import load_dotenv
|
|
9 |
from guardrails_genie.guardrails import GuardrailManager
|
10 |
from guardrails_genie.llm import OpenAIModel
|
11 |
from guardrails_genie.metrics import AccuracyMetric
|
12 |
-
|
13 |
-
load_dotenv()
|
14 |
-
weave.init(project_name="guardrails-genie")
|
15 |
|
16 |
|
17 |
def initialize_session_state():
|
|
|
18 |
if "uploaded_file" not in st.session_state:
|
19 |
st.session_state.uploaded_file = None
|
20 |
if "dataset_name" not in st.session_state:
|
@@ -35,6 +37,18 @@ def initialize_session_state():
|
|
35 |
st.session_state.evaluation_summary = None
|
36 |
if "guardrail_manager" not in st.session_state:
|
37 |
st.session_state.guardrail_manager = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
def initialize_guardrail():
|
@@ -51,10 +65,22 @@ def initialize_guardrail():
|
|
51 |
guardrail_name,
|
52 |
)(llm_model=OpenAIModel(model_name=survey_guardrail_model))
|
53 |
)
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
57 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
st.session_state.guardrails = guardrails
|
59 |
st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
|
60 |
|
@@ -107,6 +133,8 @@ if st.session_state.dataset_previewed:
|
|
107 |
|
108 |
if st.session_state.guardrail_names != []:
|
109 |
initialize_guardrail()
|
|
|
|
|
110 |
if st.session_state.guardrail_manager is not None:
|
111 |
if st.sidebar.button("Start Evaluation"):
|
112 |
st.session_state.start_evaluation = True
|
@@ -119,10 +147,55 @@ if st.session_state.dataset_previewed:
|
|
119 |
with st.expander("Evaluation Results", expanded=True):
|
120 |
evaluation_summary, call = asyncio.run(
|
121 |
evaluation.evaluate.call(
|
122 |
-
evaluation,
|
|
|
|
|
|
|
|
|
|
|
123 |
)
|
124 |
)
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
+
import os
|
3 |
+
import time
|
4 |
from importlib import import_module
|
5 |
|
6 |
import pandas as pd
|
7 |
+
import rich
|
8 |
import streamlit as st
|
9 |
import weave
|
10 |
from dotenv import load_dotenv
|
|
|
12 |
from guardrails_genie.guardrails import GuardrailManager
|
13 |
from guardrails_genie.llm import OpenAIModel
|
14 |
from guardrails_genie.metrics import AccuracyMetric
|
15 |
+
from guardrails_genie.utils import EvaluationCallManager
|
|
|
|
|
16 |
|
17 |
|
18 |
def initialize_session_state():
|
19 |
+
load_dotenv()
|
20 |
if "uploaded_file" not in st.session_state:
|
21 |
st.session_state.uploaded_file = None
|
22 |
if "dataset_name" not in st.session_state:
|
|
|
37 |
st.session_state.evaluation_summary = None
|
38 |
if "guardrail_manager" not in st.session_state:
|
39 |
st.session_state.guardrail_manager = None
|
40 |
+
if "evaluation_name" not in st.session_state:
|
41 |
+
st.session_state.evaluation_name = ""
|
42 |
+
if "show_result_table" not in st.session_state:
|
43 |
+
st.session_state.show_result_table = False
|
44 |
+
if "weave_client" not in st.session_state:
|
45 |
+
st.session_state.weave_client = weave.init(
|
46 |
+
project_name=os.getenv("WEAVE_PROJECT")
|
47 |
+
)
|
48 |
+
if "evaluation_call_manager" not in st.session_state:
|
49 |
+
st.session_state.evaluation_call_manager = None
|
50 |
+
if "call_id" not in st.session_state:
|
51 |
+
st.session_state.call_id = None
|
52 |
|
53 |
|
54 |
def initialize_guardrail():
|
|
|
65 |
guardrail_name,
|
66 |
)(llm_model=OpenAIModel(model_name=survey_guardrail_model))
|
67 |
)
|
68 |
+
elif guardrail_name == "PromptInjectionClassifierGuardrail":
|
69 |
+
classifier_model_name = st.sidebar.selectbox(
|
70 |
+
"Classifier Guardrail Model",
|
71 |
+
[
|
72 |
+
"",
|
73 |
+
"ProtectAI/deberta-v3-base-prompt-injection-v2",
|
74 |
+
"wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
|
75 |
+
],
|
76 |
)
|
77 |
+
if classifier_model_name:
|
78 |
+
st.session_state.guardrails.append(
|
79 |
+
getattr(
|
80 |
+
import_module("guardrails_genie.guardrails"),
|
81 |
+
guardrail_name,
|
82 |
+
)(model_name=classifier_model_name)
|
83 |
+
)
|
84 |
st.session_state.guardrails = guardrails
|
85 |
st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
|
86 |
|
|
|
133 |
|
134 |
if st.session_state.guardrail_names != []:
|
135 |
initialize_guardrail()
|
136 |
+
evaluation_name = st.sidebar.text_input("Evaluation name", value="")
|
137 |
+
st.session_state.evaluation_name = evaluation_name
|
138 |
if st.session_state.guardrail_manager is not None:
|
139 |
if st.sidebar.button("Start Evaluation"):
|
140 |
st.session_state.start_evaluation = True
|
|
|
147 |
with st.expander("Evaluation Results", expanded=True):
|
148 |
evaluation_summary, call = asyncio.run(
|
149 |
evaluation.evaluate.call(
|
150 |
+
evaluation,
|
151 |
+
st.session_state.guardrail_manager,
|
152 |
+
__weave={
|
153 |
+
"display_name": "Evaluation.evaluate:"
|
154 |
+
+ st.session_state.evaluation_name
|
155 |
+
},
|
156 |
)
|
157 |
)
|
158 |
+
x_axis = list(evaluation_summary["AccuracyMetric"].keys())
|
159 |
+
y_axis = [
|
160 |
+
evaluation_summary["AccuracyMetric"][x_axis_item]
|
161 |
+
for x_axis_item in x_axis
|
162 |
+
]
|
163 |
+
st.bar_chart(
|
164 |
+
pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
|
165 |
+
x="Metric",
|
166 |
+
y="Score",
|
167 |
+
)
|
168 |
+
st.session_state.evaluation_summary = evaluation_summary
|
169 |
+
st.session_state.call_id = call.id
|
170 |
+
st.session_state.start_evaluation = False
|
171 |
+
|
172 |
+
if not st.session_state.start_evaluation:
|
173 |
+
time.sleep(5)
|
174 |
+
st.session_state.evaluation_call_manager = (
|
175 |
+
EvaluationCallManager(
|
176 |
+
entity="geekyrakshit",
|
177 |
+
project="guardrails-genie",
|
178 |
+
call_id=st.session_state.call_id,
|
179 |
+
)
|
180 |
+
)
|
181 |
+
for guardrail_name in st.session_state.guardrail_names:
|
182 |
+
st.session_state.evaluation_call_manager.call_list.append(
|
183 |
+
{
|
184 |
+
"guardrail_name": guardrail_name,
|
185 |
+
"calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
|
186 |
+
}
|
187 |
+
)
|
188 |
+
rich.print(
|
189 |
+
st.session_state.evaluation_call_manager.call_list
|
190 |
+
)
|
191 |
+
st.dataframe(
|
192 |
+
st.session_state.evaluation_call_manager.render_calls_to_streamlit()
|
193 |
+
)
|
194 |
+
if st.session_state.evaluation_call_manager.show_warning_in_app:
|
195 |
+
st.warning(
|
196 |
+
f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
|
197 |
+
)
|
198 |
+
st.markdown(
|
199 |
+
f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
|
200 |
+
)
|
201 |
+
st.session_state.evaluation_call_manager = None
|
application_pages/train_classifier.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
|
6 |
+
from guardrails_genie.train_classifier import train_binary_classifier
|
7 |
+
|
8 |
+
|
9 |
+
def initialize_session_state():
    """Load ``.env`` and make sure the training page's session keys exist.

    Keys already present in ``st.session_state`` keep their current values;
    missing ones receive the defaults listed below.
    """
    load_dotenv()

    training_page_defaults = {
        "dataset_name": None,
        "base_model_name": None,
        "batch_size": 16,
        "should_start_training": False,
        "training_output": None,
    }
    missing_keys = [
        name for name in training_page_defaults if name not in st.session_state
    ]
    for name in missing_keys:
        st.session_state[name] = training_page_defaults[name]
|
21 |
+
|
22 |
+
|
23 |
+
# Page script: collect training options from the sidebar and, when the user
# presses "Train", fine-tune a binary classifier and show the result.
initialize_session_state()
st.title(":material/fitness_center: Train Classifier")

# Strip surrounding whitespace so a blank-looking entry is treated as empty
# instead of being handed to the dataset loader.
dataset_name = st.sidebar.text_input("Dataset Name", value="").strip()
st.session_state.dataset_name = dataset_name

base_model_name = st.sidebar.selectbox(
    "Base Model",
    options=[
        "distilbert/distilbert-base-uncased",
        "FacebookAI/roberta-base",
        "microsoft/deberta-v3-base",
    ],
)
st.session_state.base_model_name = base_model_name

batch_size = st.sidebar.slider(
    "Batch Size", min_value=4, max_value=256, value=16, step=4
)
st.session_state.batch_size = batch_size

train_button = st.sidebar.button("Train")
# bool(...) keeps the flag a real boolean; a bare `and` chain would store
# whichever operand decided the result (e.g. the model-name string).
st.session_state.should_start_training = bool(
    train_button and st.session_state.dataset_name and st.session_state.base_model_name
)

if train_button and not st.session_state.should_start_training:
    # Tell the user why nothing happened instead of silently ignoring the click.
    st.warning("Please provide a dataset name and a base model before training.")

if st.session_state.should_start_training:
    with st.expander("Training", expanded=True):
        training_output = train_binary_classifier(
            project_name=os.getenv("WANDB_PROJECT_NAME"),
            entity_name=os.getenv("WANDB_ENTITY_NAME"),
            run_name=f"{st.session_state.base_model_name}-finetuned",
            dataset_repo=st.session_state.dataset_name,
            model_name=st.session_state.base_model_name,
            batch_size=st.session_state.batch_size,
            streamlit_mode=True,
        )
        st.session_state.training_output = training_output
        st.write(training_output)
|
guardrails_genie/guardrails/__init__.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1 |
-
from .injection import
|
|
|
|
|
|
|
2 |
from .manager import GuardrailManager
|
3 |
|
4 |
__all__ = [
|
5 |
"PromptInjectionSurveyGuardrail",
|
6 |
-
"
|
7 |
"GuardrailManager",
|
8 |
]
|
|
|
1 |
+
from .injection import (
|
2 |
+
PromptInjectionClassifierGuardrail,
|
3 |
+
PromptInjectionSurveyGuardrail,
|
4 |
+
)
|
5 |
from .manager import GuardrailManager
|
6 |
|
7 |
__all__ = [
|
8 |
"PromptInjectionSurveyGuardrail",
|
9 |
+
"PromptInjectionClassifierGuardrail",
|
10 |
"GuardrailManager",
|
11 |
]
|
guardrails_genie/guardrails/injection/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from .
|
2 |
from .survey_guardrail import PromptInjectionSurveyGuardrail
|
3 |
|
4 |
-
__all__ = ["PromptInjectionSurveyGuardrail", "
|
|
|
1 |
+
from .classifier_guardrail import PromptInjectionClassifierGuardrail
|
2 |
from .survey_guardrail import PromptInjectionSurveyGuardrail
|
3 |
|
4 |
+
__all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionClassifierGuardrail"]
|
guardrails_genie/guardrails/injection/{protectai_guardrail.py → classifier_guardrail.py}
RENAMED
@@ -5,16 +5,25 @@ import weave
|
|
5 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
6 |
from transformers.pipelines.base import Pipeline
|
7 |
|
|
|
|
|
8 |
from ..base import Guardrail
|
9 |
|
10 |
|
11 |
-
class
|
12 |
model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
|
13 |
_classifier: Optional[Pipeline] = None
|
14 |
|
15 |
def model_post_init(self, __context):
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
self._classifier = pipeline(
|
19 |
"text-classification",
|
20 |
model=model,
|
@@ -28,11 +37,6 @@ class PromptInjectionProtectAIGuardrail(Guardrail):
|
|
28 |
def classify(self, prompt: str):
|
29 |
return self._classifier(prompt)
|
30 |
|
31 |
-
@weave.op()
|
32 |
-
def predict(self, prompt: str):
|
33 |
-
response = self.classify(prompt)
|
34 |
-
return {"safe": response[0]["label"] != "INJECTION"}
|
35 |
-
|
36 |
@weave.op()
|
37 |
def guard(self, prompt: str):
|
38 |
response = self.classify(prompt)
|
@@ -41,3 +45,7 @@ class PromptInjectionProtectAIGuardrail(Guardrail):
|
|
41 |
"safe": response[0]["label"] != "INJECTION",
|
42 |
"summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.",
|
43 |
}
|
|
|
|
|
|
|
|
|
|
5 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
6 |
from transformers.pipelines.base import Pipeline
|
7 |
|
8 |
+
import wandb
|
9 |
+
|
10 |
from ..base import Guardrail
|
11 |
|
12 |
|
13 |
+
class PromptInjectionClassifierGuardrail(Guardrail):
|
14 |
model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
|
15 |
_classifier: Optional[Pipeline] = None
|
16 |
|
17 |
def model_post_init(self, __context):
|
18 |
+
if self.model_name.startswith("wandb://"):
|
19 |
+
api = wandb.Api()
|
20 |
+
artifact = api.artifact(self.model_name.removeprefix("wandb://"))
|
21 |
+
artifact_dir = artifact.download()
|
22 |
+
tokenizer = AutoTokenizer.from_pretrained(artifact_dir)
|
23 |
+
model = AutoModelForSequenceClassification.from_pretrained(artifact_dir)
|
24 |
+
else:
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
26 |
+
model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
|
27 |
self._classifier = pipeline(
|
28 |
"text-classification",
|
29 |
model=model,
|
|
|
37 |
def classify(self, prompt: str):
|
38 |
return self._classifier(prompt)
|
39 |
|
|
|
|
|
|
|
|
|
|
|
40 |
@weave.op()
|
41 |
def guard(self, prompt: str):
|
42 |
response = self.classify(prompt)
|
|
|
45 |
"safe": response[0]["label"] != "INJECTION",
|
46 |
"summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.",
|
47 |
}
|
48 |
+
|
49 |
+
@weave.op()
|
50 |
+
def predict(self, prompt: str):
|
51 |
+
return self.guard(prompt)
|
guardrails_genie/train_classifier.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import evaluate
|
2 |
+
import numpy as np
|
3 |
+
import streamlit as st
|
4 |
+
from datasets import load_dataset
|
5 |
+
from transformers import (
|
6 |
+
AutoModelForSequenceClassification,
|
7 |
+
AutoTokenizer,
|
8 |
+
DataCollatorWithPadding,
|
9 |
+
Trainer,
|
10 |
+
TrainerCallback,
|
11 |
+
TrainingArguments,
|
12 |
+
)
|
13 |
+
from transformers.trainer_callback import TrainerControl, TrainerState
|
14 |
+
|
15 |
+
import wandb
|
16 |
+
|
17 |
+
|
18 |
+
class StreamlitProgressbarCallback(TrainerCallback):
    """Trainer callback that reports training progress in a Streamlit bar.

    A progress bar is created when the callback is instantiated and is
    refreshed at the beginning of every training step.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress_bar = st.progress(0, text="Training")

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Update the bar from ``state.global_step`` and ``state.max_steps``."""
        super().on_step_begin(args, state, control, **kwargs)
        # Integer percentage of steps already completed, plus one for the
        # step that is about to run.
        percent_done = (state.global_step * 100 // state.max_steps) + 1
        caption = f"Training {state.global_step} / {state.max_steps}"
        self.progress_bar.progress(percent_done, text=caption)
|
36 |
+
|
37 |
+
|
38 |
+
def train_binary_classifier(
    project_name: str,
    entity_name: str,
    run_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    learning_rate: float = 2e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
    streamlit_mode: bool = False,
):
    """Fine-tune a SAFE/INJECTION sequence classifier and log it to W&B.

    Args:
        project_name: Weights & Biases project the run is logged to.
        entity_name: Weights & Biases entity that owns the project.
        run_name: Display name of the Weights & Biases run.
        dataset_repo: Dataset identifier passed to ``load_dataset``. The
            ``"train"`` and ``"test"`` splits are used and the ``"prompt"``
            column is tokenized.
        model_name: Checkpoint used for both the tokenizer and the model.
        learning_rate: Optimizer learning rate.
        batch_size: Per-device batch size for training and evaluation.
        num_epochs: Number of training epochs.
        weight_decay: Weight decay passed to the trainer.
        streamlit_mode: When true, a link to the run and a progress bar are
            rendered in the active Streamlit app.

    Returns:
        The value returned by ``Trainer.train()``.

    Raises:
        Any exception raised while loading data/models or while training is
        propagated unchanged after the W&B run has been finished.
    """
    wandb.init(project=project_name, entity=entity_name, name=run_name)
    # Everything after `wandb.init` runs under try/finally so the run is
    # closed even when dataset/tokenizer/model loading fails (for example on
    # a mistyped dataset name), not only when `trainer.train()` fails.
    try:
        if streamlit_mode:
            st.markdown(
                f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
            )
        dataset = load_dataset(dataset_repo)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        def preprocess_function(examples):
            # Tokenize the raw prompt text; padding is deferred to the collator.
            return tokenizer(examples["prompt"], truncation=True)

        tokenized_datasets = dataset.map(preprocess_function, batched=True)
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            # Turn per-class scores into class ids before computing accuracy.
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)

        id2label = {0: "SAFE", 1: "INJECTION"}
        label2id = {"SAFE": 0, "INJECTION": 1}

        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
        )

        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir="binary-classifier",
                learning_rate=learning_rate,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                num_train_epochs=num_epochs,
                weight_decay=weight_decay,
                eval_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                # NOTE(review): presumably requires Hugging Face Hub
                # credentials on the host - confirm this is intended.
                push_to_hub=True,
                report_to="wandb",
                logging_strategy="steps",
                logging_steps=1,
            ),
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["test"],
            processing_class=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
        )
        return trainer.train()
    finally:
        wandb.finish()
|
guardrails_genie/utils.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import os
|
2 |
|
|
|
3 |
import pymupdf4llm
|
4 |
import weave
|
|
|
5 |
from firerequests import FireRequests
|
6 |
|
7 |
|
@@ -11,3 +13,47 @@ def get_markdown_from_pdf_url(url: str) -> str:
|
|
11 |
markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
|
12 |
os.remove("temp.pdf")
|
13 |
return markdown
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
|
3 |
+
import pandas as pd
|
4 |
import pymupdf4llm
|
5 |
import weave
|
6 |
+
import weave.trace
|
7 |
from firerequests import FireRequests
|
8 |
|
9 |
|
|
|
13 |
markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
|
14 |
os.remove("temp.pdf")
|
15 |
return markdown
|
16 |
+
|
17 |
+
|
18 |
+
class EvaluationCallManager:
    """Collect guardrail calls from a Weave evaluation and tabulate them.

    Args:
        entity: Weave entity that owns the project.
        project: Weave project containing the evaluation call.
        call_id: Identifier of the evaluation call to inspect.
        max_count: Maximum number of rows gathered per collection pass.
    """

    def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
        self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
        self.max_count = max_count
        # Set to True once a collection pass stops because max_count was hit.
        self.show_warning_in_app = False
        # Entries are {"guardrail_name": str, "calls": list of collected rows}.
        self.call_list = []

    def collect_guardrail_guard_calls_from_eval(self) -> list:
        """Return up to ``max_count`` rows of prompt, guard output and score.

        Walks the children of the evaluation call in order and stops at the
        first child whose op name contains ``"Evaluation.summarize"``.
        Sets ``show_warning_in_app`` when the ``max_count`` limit is reached.
        """
        guard_calls = []
        for eval_predict_and_score_call in self.base_call.children():
            if "Evaluation.summarize" in eval_predict_and_score_call._op_name:
                break
            # Fetch the children once and reuse them for both lookups.
            # NOTE(review): assumes child 0 is the guardrail predict call and
            # child 1 is the scorer call - confirm against the trace layout.
            row_children = eval_predict_and_score_call.children()
            guardrail_predict_call = row_children[0]
            score_call = row_children[1]
            guard_call = guardrail_predict_call.children()[0]
            guard_calls.append(
                {
                    "input_prompt": str(guard_call.inputs["prompt"]),
                    "outputs": dict(guard_call.output),
                    "score": dict(score_call.output),
                }
            )
            if len(guard_calls) >= self.max_count:
                self.show_warning_in_app = True
                break
        return guard_calls

    def render_calls_to_streamlit(self) -> pd.DataFrame:
        """Build a DataFrame with one row per prompt from ``call_list``.

        The ``input_prompt`` column is taken from the first entry. Every
        entry then contributes ``<guardrail_name>.safe`` and
        ``<guardrail_name>.prediction_correctness`` columns. An empty
        ``call_list`` yields an empty frame with only ``input_prompt``.
        """
        if not self.call_list:
            # Nothing collected yet: avoid IndexError on call_list[0].
            return pd.DataFrame({"input_prompt": []})

        dataframe = {
            "input_prompt": [
                call["input_prompt"] for call in self.call_list[0]["calls"]
            ]
        }
        for guardrail_call in self.call_list:
            column_prefix = guardrail_call["guardrail_name"]
            dataframe[column_prefix + ".safe"] = [
                call["outputs"]["safe"] for call in guardrail_call["calls"]
            ]
            dataframe[column_prefix + ".prediction_correctness"] = [
                call["score"]["correct"] for call in guardrail_call["calls"]
            ]
        return pd.DataFrame(dataframe)
|
pyproject.toml
CHANGED
@@ -12,7 +12,7 @@ dependencies = [
|
|
12 |
"ruff>=0.6.9",
|
13 |
"pip>=24.2",
|
14 |
"uv>=0.4.20",
|
15 |
-
"git+https://github.com/wandb/weave@feat/eval-progressbar",
|
16 |
"streamlit>=1.40.1",
|
17 |
"python-dotenv>=1.0.1",
|
18 |
"watchdog>=6.0.0",
|
@@ -23,4 +23,4 @@ dependencies = [
|
|
23 |
]
|
24 |
|
25 |
[tool.setuptools]
|
26 |
-
py-modules = ["guardrails_genie"]
|
|
|
12 |
"ruff>=0.6.9",
|
13 |
"pip>=24.2",
|
14 |
"uv>=0.4.20",
|
15 |
+
"weave @ git+https://github.com/wandb/weave@feat/eval-progressbar",
|
16 |
"streamlit>=1.40.1",
|
17 |
"python-dotenv>=1.0.1",
|
18 |
"watchdog>=6.0.0",
|
|
|
23 |
]
|
24 |
|
25 |
[tool.setuptools]
|
26 |
+
py-modules = ["guardrails_genie"]
|