geekyrakshit commited on
Commit
b8c0bf9
·
unverified ·
2 Parent(s): af688eb 1ec997f

Merge pull request #4 from soumik12345/feat/eval-table

Browse files
.gitignore CHANGED
@@ -165,4 +165,7 @@ cursor_prompts/
165
  uv.lock
166
  test.py
167
  temp.txt
168
- **.csv
 
 
 
 
165
  uv.lock
166
  test.py
167
  temp.txt
168
+ **.csv
169
+ binary-classifier/
170
+ wandb/
171
+ artifacts/
README.md CHANGED
@@ -18,7 +18,12 @@ source .venv/bin/activate
18
  ## Run the App
19
 
20
  ```bash
21
- OPENAI_API_KEY="YOUR_OPENAI_API_KEY" streamlit run app.py
 
 
 
 
 
22
  ```
23
 
24
  ## Use the Library
 
18
  ## Run the App
19
 
20
  ```bash
21
+ export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
22
+ export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
23
+ export WANDB_PROJECT_NAME="YOUR_WANDB_PROJECT_NAME"
24
+ export WANDB_ENTITY_NAME="YOUR_WANDB_ENTITY_NAME"
25
+ export WANDB_LOG_MODEL="checkpoint"
26
+ streamlit run app.py
27
  ```
28
 
29
  ## Use the Library
app.py CHANGED
@@ -13,6 +13,13 @@ evaluation_page = st.Page(
13
  title="Evaluation",
14
  icon=":material/monitoring:",
15
  )
16
- page_navigation = st.navigation([intro_page, chat_page, evaluation_page])
 
 
 
 
 
 
 
17
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
18
  page_navigation.run()
 
13
  title="Evaluation",
14
  icon=":material/monitoring:",
15
  )
16
+ train_classifier_page = st.Page(
17
+ "application_pages/train_classifier.py",
18
+ title="Train Classifier",
19
+ icon=":material/fitness_center:",
20
+ )
21
+ page_navigation = st.navigation(
22
+ [intro_page, chat_page, evaluation_page, train_classifier_page]
23
+ )
24
  st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
25
  page_navigation.run()
application_pages/chat_app.py CHANGED
@@ -1,4 +1,5 @@
1
  import importlib
 
2
 
3
  import streamlit as st
4
  import weave
@@ -7,27 +8,27 @@ from dotenv import load_dotenv
7
  from guardrails_genie.guardrails import GuardrailManager
8
  from guardrails_genie.llm import OpenAIModel
9
 
10
- st.title(":material/robot: Guardrails Genie Playground")
11
-
12
- load_dotenv()
13
- weave.init(project_name="guardrails-genie")
14
 
15
- if "guardrails" not in st.session_state:
16
- st.session_state.guardrails = []
17
- if "guardrail_names" not in st.session_state:
18
- st.session_state.guardrail_names = []
19
- if "guardrails_manager" not in st.session_state:
20
- st.session_state.guardrails_manager = None
21
- if "initialize_guardrails" not in st.session_state:
22
- st.session_state.initialize_guardrails = False
23
- if "system_prompt" not in st.session_state:
24
- st.session_state.system_prompt = ""
25
- if "user_prompt" not in st.session_state:
26
- st.session_state.user_prompt = ""
27
- if "test_guardrails" not in st.session_state:
28
- st.session_state.test_guardrails = False
29
- if "llm_model" not in st.session_state:
30
- st.session_state.llm_model = None
 
 
 
 
31
 
32
 
33
  def initialize_guardrails():
@@ -44,18 +45,30 @@ def initialize_guardrails():
44
  guardrail_name,
45
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
46
  )
47
- else:
48
- st.session_state.guardrails.append(
49
- getattr(
50
- importlib.import_module("guardrails_genie.guardrails"),
51
- guardrail_name,
52
- )()
 
 
53
  )
 
 
 
 
 
 
 
54
  st.session_state.guardrails_manager = GuardrailManager(
55
  guardrails=st.session_state.guardrails
56
  )
57
 
58
 
 
 
 
59
  openai_model = st.sidebar.selectbox(
60
  "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
61
  )
@@ -97,7 +110,7 @@ if st.session_state.initialize_guardrails:
97
 
98
  if guardrails_response["safe"]:
99
  st.markdown(
100
- f"\n\n---\nPrompt is safe! Explore prompt trace on [Weave]({call.ui_url})\n\n---\n"
101
  )
102
 
103
  with st.sidebar.status("Generating response from LLM..."):
 
1
  import importlib
2
+ import os
3
 
4
  import streamlit as st
5
  import weave
 
8
  from guardrails_genie.guardrails import GuardrailManager
9
  from guardrails_genie.llm import OpenAIModel
10
 
 
 
 
 
11
 
12
+ def initialize_session_state():
13
+ load_dotenv()
14
+ weave.init(project_name=os.getenv("WEAVE_PROJECT"))
15
+
16
+ if "guardrails" not in st.session_state:
17
+ st.session_state.guardrails = []
18
+ if "guardrail_names" not in st.session_state:
19
+ st.session_state.guardrail_names = []
20
+ if "guardrails_manager" not in st.session_state:
21
+ st.session_state.guardrails_manager = None
22
+ if "initialize_guardrails" not in st.session_state:
23
+ st.session_state.initialize_guardrails = False
24
+ if "system_prompt" not in st.session_state:
25
+ st.session_state.system_prompt = ""
26
+ if "user_prompt" not in st.session_state:
27
+ st.session_state.user_prompt = ""
28
+ if "test_guardrails" not in st.session_state:
29
+ st.session_state.test_guardrails = False
30
+ if "llm_model" not in st.session_state:
31
+ st.session_state.llm_model = None
32
 
33
 
34
  def initialize_guardrails():
 
45
  guardrail_name,
46
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
47
  )
48
+ elif guardrail_name == "PromptInjectionClassifierGuardrail":
49
+ classifier_model_name = st.sidebar.selectbox(
50
+ "Classifier Guardrail Model",
51
+ [
52
+ "",
53
+ "ProtectAI/deberta-v3-base-prompt-injection-v2",
54
+ "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
55
+ ],
56
  )
57
+ if classifier_model_name != "":
58
+ st.session_state.guardrails.append(
59
+ getattr(
60
+ importlib.import_module("guardrails_genie.guardrails"),
61
+ guardrail_name,
62
+ )(model_name=classifier_model_name)
63
+ )
64
  st.session_state.guardrails_manager = GuardrailManager(
65
  guardrails=st.session_state.guardrails
66
  )
67
 
68
 
69
+ initialize_session_state()
70
+ st.title(":material/robot: Guardrails Genie Playground")
71
+
72
  openai_model = st.sidebar.selectbox(
73
  "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
74
  )
 
110
 
111
  if guardrails_response["safe"]:
112
  st.markdown(
113
+ f"\n\n---\nPrompt is safe! Explore guardrail trace on [Weave]({call.ui_url})\n\n---\n"
114
  )
115
 
116
  with st.sidebar.status("Generating response from LLM..."):
application_pages/evaluation_app.py CHANGED
@@ -1,7 +1,10 @@
1
  import asyncio
 
 
2
  from importlib import import_module
3
 
4
  import pandas as pd
 
5
  import streamlit as st
6
  import weave
7
  from dotenv import load_dotenv
@@ -9,12 +12,11 @@ from dotenv import load_dotenv
9
  from guardrails_genie.guardrails import GuardrailManager
10
  from guardrails_genie.llm import OpenAIModel
11
  from guardrails_genie.metrics import AccuracyMetric
12
-
13
- load_dotenv()
14
- weave.init(project_name="guardrails-genie")
15
 
16
 
17
  def initialize_session_state():
 
18
  if "uploaded_file" not in st.session_state:
19
  st.session_state.uploaded_file = None
20
  if "dataset_name" not in st.session_state:
@@ -35,6 +37,18 @@ def initialize_session_state():
35
  st.session_state.evaluation_summary = None
36
  if "guardrail_manager" not in st.session_state:
37
  st.session_state.guardrail_manager = None
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  def initialize_guardrail():
@@ -51,10 +65,22 @@ def initialize_guardrail():
51
  guardrail_name,
52
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
53
  )
54
- else:
55
- guardrails.append(
56
- getattr(import_module("guardrails_genie.guardrails"), guardrail_name)()
 
 
 
 
 
57
  )
 
 
 
 
 
 
 
58
  st.session_state.guardrails = guardrails
59
  st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
60
 
@@ -107,6 +133,8 @@ if st.session_state.dataset_previewed:
107
 
108
  if st.session_state.guardrail_names != []:
109
  initialize_guardrail()
 
 
110
  if st.session_state.guardrail_manager is not None:
111
  if st.sidebar.button("Start Evaluation"):
112
  st.session_state.start_evaluation = True
@@ -119,10 +147,55 @@ if st.session_state.dataset_previewed:
119
  with st.expander("Evaluation Results", expanded=True):
120
  evaluation_summary, call = asyncio.run(
121
  evaluation.evaluate.call(
122
- evaluation, st.session_state.guardrail_manager
 
 
 
 
 
123
  )
124
  )
125
- st.markdown(f"[Explore evaluation in Weave]({call.ui_url})")
126
- st.write(evaluation_summary)
127
- st.session_state.evaluation_summary = evaluation_summary
128
- st.session_state.start_evaluation = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import os
3
+ import time
4
  from importlib import import_module
5
 
6
  import pandas as pd
7
+ import rich
8
  import streamlit as st
9
  import weave
10
  from dotenv import load_dotenv
 
12
  from guardrails_genie.guardrails import GuardrailManager
13
  from guardrails_genie.llm import OpenAIModel
14
  from guardrails_genie.metrics import AccuracyMetric
15
+ from guardrails_genie.utils import EvaluationCallManager
 
 
16
 
17
 
18
  def initialize_session_state():
19
+ load_dotenv()
20
  if "uploaded_file" not in st.session_state:
21
  st.session_state.uploaded_file = None
22
  if "dataset_name" not in st.session_state:
 
37
  st.session_state.evaluation_summary = None
38
  if "guardrail_manager" not in st.session_state:
39
  st.session_state.guardrail_manager = None
40
+ if "evaluation_name" not in st.session_state:
41
+ st.session_state.evaluation_name = ""
42
+ if "show_result_table" not in st.session_state:
43
+ st.session_state.show_result_table = False
44
+ if "weave_client" not in st.session_state:
45
+ st.session_state.weave_client = weave.init(
46
+ project_name=os.getenv("WEAVE_PROJECT")
47
+ )
48
+ if "evaluation_call_manager" not in st.session_state:
49
+ st.session_state.evaluation_call_manager = None
50
+ if "call_id" not in st.session_state:
51
+ st.session_state.call_id = None
52
 
53
 
54
  def initialize_guardrail():
 
65
  guardrail_name,
66
  )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
67
  )
68
+ elif guardrail_name == "PromptInjectionClassifierGuardrail":
69
+ classifier_model_name = st.sidebar.selectbox(
70
+ "Classifier Guardrail Model",
71
+ [
72
+ "",
73
+ "ProtectAI/deberta-v3-base-prompt-injection-v2",
74
+ "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
75
+ ],
76
  )
77
+ if classifier_model_name:
78
+ st.session_state.guardrails.append(
79
+ getattr(
80
+ import_module("guardrails_genie.guardrails"),
81
+ guardrail_name,
82
+ )(model_name=classifier_model_name)
83
+ )
84
  st.session_state.guardrails = guardrails
85
  st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
86
 
 
133
 
134
  if st.session_state.guardrail_names != []:
135
  initialize_guardrail()
136
+ evaluation_name = st.sidebar.text_input("Evaluation name", value="")
137
+ st.session_state.evaluation_name = evaluation_name
138
  if st.session_state.guardrail_manager is not None:
139
  if st.sidebar.button("Start Evaluation"):
140
  st.session_state.start_evaluation = True
 
147
  with st.expander("Evaluation Results", expanded=True):
148
  evaluation_summary, call = asyncio.run(
149
  evaluation.evaluate.call(
150
+ evaluation,
151
+ st.session_state.guardrail_manager,
152
+ __weave={
153
+ "display_name": "Evaluation.evaluate:"
154
+ + st.session_state.evaluation_name
155
+ },
156
  )
157
  )
158
+ x_axis = list(evaluation_summary["AccuracyMetric"].keys())
159
+ y_axis = [
160
+ evaluation_summary["AccuracyMetric"][x_axis_item]
161
+ for x_axis_item in x_axis
162
+ ]
163
+ st.bar_chart(
164
+ pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
165
+ x="Metric",
166
+ y="Score",
167
+ )
168
+ st.session_state.evaluation_summary = evaluation_summary
169
+ st.session_state.call_id = call.id
170
+ st.session_state.start_evaluation = False
171
+
172
+ if not st.session_state.start_evaluation:
173
+ time.sleep(5)
174
+ st.session_state.evaluation_call_manager = (
175
+ EvaluationCallManager(
176
+ entity="geekyrakshit",
177
+ project="guardrails-genie",
178
+ call_id=st.session_state.call_id,
179
+ )
180
+ )
181
+ for guardrail_name in st.session_state.guardrail_names:
182
+ st.session_state.evaluation_call_manager.call_list.append(
183
+ {
184
+ "guardrail_name": guardrail_name,
185
+ "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
186
+ }
187
+ )
188
+ rich.print(
189
+ st.session_state.evaluation_call_manager.call_list
190
+ )
191
+ st.dataframe(
192
+ st.session_state.evaluation_call_manager.render_calls_to_streamlit()
193
+ )
194
+ if st.session_state.evaluation_call_manager.show_warning_in_app:
195
+ st.warning(
196
+ f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
197
+ )
198
+ st.markdown(
199
+ f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
200
+ )
201
+ st.session_state.evaluation_call_manager = None
application_pages/train_classifier.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import streamlit as st
4
+ from dotenv import load_dotenv
5
+
6
+ from guardrails_genie.train_classifier import train_binary_classifier
7
+
8
+
9
+ def initialize_session_state():
10
+ load_dotenv()
11
+ if "dataset_name" not in st.session_state:
12
+ st.session_state.dataset_name = None
13
+ if "base_model_name" not in st.session_state:
14
+ st.session_state.base_model_name = None
15
+ if "batch_size" not in st.session_state:
16
+ st.session_state.batch_size = 16
17
+ if "should_start_training" not in st.session_state:
18
+ st.session_state.should_start_training = False
19
+ if "training_output" not in st.session_state:
20
+ st.session_state.training_output = None
21
+
22
+
23
+ initialize_session_state()
24
+ st.title(":material/fitness_center: Train Classifier")
25
+
26
+ dataset_name = st.sidebar.text_input("Dataset Name", value="")
27
+ st.session_state.dataset_name = dataset_name
28
+
29
+ base_model_name = st.sidebar.selectbox(
30
+ "Base Model",
31
+ options=[
32
+ "distilbert/distilbert-base-uncased",
33
+ "FacebookAI/roberta-base",
34
+ "microsoft/deberta-v3-base",
35
+ ],
36
+ )
37
+ st.session_state.base_model_name = base_model_name
38
+
39
+ batch_size = st.sidebar.slider(
40
+ "Batch Size", min_value=4, max_value=256, value=16, step=4
41
+ )
42
+ st.session_state.batch_size = batch_size
43
+
44
+ train_button = st.sidebar.button("Train")
45
+ st.session_state.should_start_training = (
46
+ train_button and st.session_state.dataset_name and st.session_state.base_model_name
47
+ )
48
+
49
+ if st.session_state.should_start_training:
50
+ with st.expander("Training", expanded=True):
51
+ training_output = train_binary_classifier(
52
+ project_name=os.getenv("WANDB_PROJECT_NAME"),
53
+ entity_name=os.getenv("WANDB_ENTITY_NAME"),
54
+ run_name=f"{st.session_state.base_model_name}-finetuned",
55
+ dataset_repo=st.session_state.dataset_name,
56
+ model_name=st.session_state.base_model_name,
57
+ batch_size=st.session_state.batch_size,
58
+ streamlit_mode=True,
59
+ )
60
+ st.session_state.training_output = training_output
61
+ st.write(training_output)
guardrails_genie/guardrails/__init__.py CHANGED
@@ -1,8 +1,11 @@
1
- from .injection import PromptInjectionProtectAIGuardrail, PromptInjectionSurveyGuardrail
 
 
 
2
  from .manager import GuardrailManager
3
 
4
  __all__ = [
5
  "PromptInjectionSurveyGuardrail",
6
- "PromptInjectionProtectAIGuardrail",
7
  "GuardrailManager",
8
  ]
 
1
+ from .injection import (
2
+ PromptInjectionClassifierGuardrail,
3
+ PromptInjectionSurveyGuardrail,
4
+ )
5
  from .manager import GuardrailManager
6
 
7
  __all__ = [
8
  "PromptInjectionSurveyGuardrail",
9
+ "PromptInjectionClassifierGuardrail",
10
  "GuardrailManager",
11
  ]
guardrails_genie/guardrails/injection/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .protectai_guardrail import PromptInjectionProtectAIGuardrail
2
  from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
- __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionProtectAIGuardrail"]
 
1
+ from .classifier_guardrail import PromptInjectionClassifierGuardrail
2
  from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
+ __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionClassifierGuardrail"]
guardrails_genie/guardrails/injection/{protectai_guardrail.py → classifier_guardrail.py} RENAMED
@@ -5,16 +5,25 @@ import weave
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers.pipelines.base import Pipeline
7
 
 
 
8
  from ..base import Guardrail
9
 
10
 
11
- class PromptInjectionProtectAIGuardrail(Guardrail):
12
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
13
  _classifier: Optional[Pipeline] = None
14
 
15
  def model_post_init(self, __context):
16
- tokenizer = AutoTokenizer.from_pretrained(self.model_name)
17
- model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
 
 
 
 
 
 
 
18
  self._classifier = pipeline(
19
  "text-classification",
20
  model=model,
@@ -28,11 +37,6 @@ class PromptInjectionProtectAIGuardrail(Guardrail):
28
  def classify(self, prompt: str):
29
  return self._classifier(prompt)
30
 
31
- @weave.op()
32
- def predict(self, prompt: str):
33
- response = self.classify(prompt)
34
- return {"safe": response[0]["label"] != "INJECTION"}
35
-
36
  @weave.op()
37
  def guard(self, prompt: str):
38
  response = self.classify(prompt)
@@ -41,3 +45,7 @@ class PromptInjectionProtectAIGuardrail(Guardrail):
41
  "safe": response[0]["label"] != "INJECTION",
42
  "summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.",
43
  }
 
 
 
 
 
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers.pipelines.base import Pipeline
7
 
8
+ import wandb
9
+
10
  from ..base import Guardrail
11
 
12
 
13
+ class PromptInjectionClassifierGuardrail(Guardrail):
14
  model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
15
  _classifier: Optional[Pipeline] = None
16
 
17
  def model_post_init(self, __context):
18
+ if self.model_name.startswith("wandb://"):
19
+ api = wandb.Api()
20
+ artifact = api.artifact(self.model_name.removeprefix("wandb://"))
21
+ artifact_dir = artifact.download()
22
+ tokenizer = AutoTokenizer.from_pretrained(artifact_dir)
23
+ model = AutoModelForSequenceClassification.from_pretrained(artifact_dir)
24
+ else:
25
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
26
+ model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
27
  self._classifier = pipeline(
28
  "text-classification",
29
  model=model,
 
37
  def classify(self, prompt: str):
38
  return self._classifier(prompt)
39
 
 
 
 
 
 
40
  @weave.op()
41
  def guard(self, prompt: str):
42
  response = self.classify(prompt)
 
45
  "safe": response[0]["label"] != "INJECTION",
46
  "summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.",
47
  }
48
+
49
+ @weave.op()
50
+ def predict(self, prompt: str):
51
+ return self.guard(prompt)
guardrails_genie/train_classifier.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import numpy as np
3
+ import streamlit as st
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoModelForSequenceClassification,
7
+ AutoTokenizer,
8
+ DataCollatorWithPadding,
9
+ Trainer,
10
+ TrainerCallback,
11
+ TrainingArguments,
12
+ )
13
+ from transformers.trainer_callback import TrainerControl, TrainerState
14
+
15
+ import wandb
16
+
17
+
18
+ class StreamlitProgressbarCallback(TrainerCallback):
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+ self.progress_bar = st.progress(0, text="Training")
23
+
24
+ def on_step_begin(
25
+ self,
26
+ args: TrainingArguments,
27
+ state: TrainerState,
28
+ control: TrainerControl,
29
+ **kwargs,
30
+ ):
31
+ super().on_step_begin(args, state, control, **kwargs)
32
+ self.progress_bar.progress(
33
+ (state.global_step * 100 // state.max_steps) + 1,
34
+ text=f"Training {state.global_step} / {state.max_steps}",
35
+ )
36
+
37
+
38
+ def train_binary_classifier(
39
+ project_name: str,
40
+ entity_name: str,
41
+ run_name: str,
42
+ dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
43
+ model_name: str = "distilbert/distilbert-base-uncased",
44
+ learning_rate: float = 2e-5,
45
+ batch_size: int = 16,
46
+ num_epochs: int = 2,
47
+ weight_decay: float = 0.01,
48
+ streamlit_mode: bool = False,
49
+ ):
50
+ wandb.init(project=project_name, entity=entity_name, name=run_name)
51
+ if streamlit_mode:
52
+ st.markdown(
53
+ f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
54
+ )
55
+ dataset = load_dataset(dataset_repo)
56
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
57
+
58
+ def preprocess_function(examples):
59
+ return tokenizer(examples["prompt"], truncation=True)
60
+
61
+ tokenized_datasets = dataset.map(preprocess_function, batched=True)
62
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
63
+ accuracy = evaluate.load("accuracy")
64
+
65
+ def compute_metrics(eval_pred):
66
+ predictions, labels = eval_pred
67
+ predictions = np.argmax(predictions, axis=1)
68
+ return accuracy.compute(predictions=predictions, references=labels)
69
+
70
+ id2label = {0: "SAFE", 1: "INJECTION"}
71
+ label2id = {"SAFE": 0, "INJECTION": 1}
72
+
73
+ model = AutoModelForSequenceClassification.from_pretrained(
74
+ model_name,
75
+ num_labels=2,
76
+ id2label=id2label,
77
+ label2id=label2id,
78
+ )
79
+
80
+ trainer = Trainer(
81
+ model=model,
82
+ args=TrainingArguments(
83
+ output_dir="binary-classifier",
84
+ learning_rate=learning_rate,
85
+ per_device_train_batch_size=batch_size,
86
+ per_device_eval_batch_size=batch_size,
87
+ num_train_epochs=num_epochs,
88
+ weight_decay=weight_decay,
89
+ eval_strategy="epoch",
90
+ save_strategy="epoch",
91
+ load_best_model_at_end=True,
92
+ push_to_hub=True,
93
+ report_to="wandb",
94
+ logging_strategy="steps",
95
+ logging_steps=1,
96
+ ),
97
+ train_dataset=tokenized_datasets["train"],
98
+ eval_dataset=tokenized_datasets["test"],
99
+ processing_class=tokenizer,
100
+ data_collator=data_collator,
101
+ compute_metrics=compute_metrics,
102
+ callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
103
+ )
104
+ try:
105
+ training_output = trainer.train()
106
+ except Exception as e:
107
+ wandb.finish()
108
+ raise e
109
+ wandb.finish()
110
+ return training_output
guardrails_genie/utils.py CHANGED
@@ -1,7 +1,9 @@
1
  import os
2
 
 
3
  import pymupdf4llm
4
  import weave
 
5
  from firerequests import FireRequests
6
 
7
 
@@ -11,3 +13,47 @@ def get_markdown_from_pdf_url(url: str) -> str:
11
  markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
12
  os.remove("temp.pdf")
13
  return markdown
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
 
3
+ import pandas as pd
4
  import pymupdf4llm
5
  import weave
6
+ import weave.trace
7
  from firerequests import FireRequests
8
 
9
 
 
13
  markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
14
  os.remove("temp.pdf")
15
  return markdown
16
+
17
+
18
+ class EvaluationCallManager:
19
+ def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
20
+ self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
21
+ self.max_count = max_count
22
+ self.show_warning_in_app = False
23
+ self.call_list = []
24
+
25
+ def collect_guardrail_guard_calls_from_eval(self):
26
+ guard_calls, count = [], 0
27
+ for eval_predict_and_score_call in self.base_call.children():
28
+ if "Evaluation.summarize" in eval_predict_and_score_call._op_name:
29
+ break
30
+ guardrail_predict_call = eval_predict_and_score_call.children()[0]
31
+ guard_call = guardrail_predict_call.children()[0]
32
+ score_call = eval_predict_and_score_call.children()[1]
33
+ guard_calls.append(
34
+ {
35
+ "input_prompt": str(guard_call.inputs["prompt"]),
36
+ "outputs": dict(guard_call.output),
37
+ "score": dict(score_call.output),
38
+ }
39
+ )
40
+ count += 1
41
+ if count >= self.max_count:
42
+ self.show_warning_in_app = True
43
+ break
44
+ return guard_calls
45
+
46
+ def render_calls_to_streamlit(self):
47
+ dataframe = {
48
+ "input_prompt": [
49
+ call["input_prompt"] for call in self.call_list[0]["calls"]
50
+ ]
51
+ }
52
+ for guardrail_call in self.call_list:
53
+ dataframe[guardrail_call["guardrail_name"] + ".safe"] = [
54
+ call["outputs"]["safe"] for call in guardrail_call["calls"]
55
+ ]
56
+ dataframe[guardrail_call["guardrail_name"] + ".prediction_correctness"] = [
57
+ call["score"]["correct"] for call in guardrail_call["calls"]
58
+ ]
59
+ return pd.DataFrame(dataframe)
pyproject.toml CHANGED
@@ -12,7 +12,7 @@ dependencies = [
12
  "ruff>=0.6.9",
13
  "pip>=24.2",
14
  "uv>=0.4.20",
15
- "git+https://github.com/wandb/weave@feat/eval-progressbar",
16
  "streamlit>=1.40.1",
17
  "python-dotenv>=1.0.1",
18
  "watchdog>=6.0.0",
@@ -23,4 +23,4 @@ dependencies = [
23
  ]
24
 
25
  [tool.setuptools]
26
- py-modules = ["guardrails_genie"]
 
12
  "ruff>=0.6.9",
13
  "pip>=24.2",
14
  "uv>=0.4.20",
15
+ "weave @ git+https://github.com/wandb/weave@feat/eval-progressbar",
16
  "streamlit>=1.40.1",
17
  "python-dotenv>=1.0.1",
18
  "watchdog>=6.0.0",
 
23
  ]
24
 
25
  [tool.setuptools]
26
+ py-modules = ["guardrails_genie"]