geekyrakshit commited on
Commit
af688eb
·
1 Parent(s): 96b1c8c

add: summarizaion to guardrails

Browse files
app.py CHANGED
@@ -4,7 +4,9 @@ intro_page = st.Page(
4
  "application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
- "application_pages/chat_app.py", title="Chat", icon=":material/robot:"
 
 
8
  )
9
  evaluation_page = st.Page(
10
  "application_pages/evaluation_app.py",
 
4
  "application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
5
  )
6
  chat_page = st.Page(
7
+ "application_pages/chat_app.py",
8
+ title="Playground",
9
+ icon=":material/sports_esports:",
10
  )
11
  evaluation_page = st.Page(
12
  "application_pages/evaluation_app.py",
application_pages/chat_app.py CHANGED
@@ -7,19 +7,27 @@ from dotenv import load_dotenv
7
  from guardrails_genie.guardrails import GuardrailManager
8
  from guardrails_genie.llm import OpenAIModel
9
 
 
 
10
  load_dotenv()
11
  weave.init(project_name="guardrails-genie")
12
 
13
- st.title(":material/robot: Guardrails Genie")
14
-
15
  if "guardrails" not in st.session_state:
16
  st.session_state.guardrails = []
17
  if "guardrail_names" not in st.session_state:
18
  st.session_state.guardrail_names = []
19
  if "guardrails_manager" not in st.session_state:
20
  st.session_state.guardrails_manager = None
21
- if "chat_started" not in st.session_state:
22
- st.session_state.chat_started = False
 
 
 
 
 
 
 
 
23
 
24
 
25
  def initialize_guardrails():
@@ -67,48 +75,41 @@ guardrail_names = st.sidebar.multiselect(
67
  )
68
  st.session_state.guardrail_names = guardrail_names
69
 
70
- if st.sidebar.button("Start Chat") and chat_condition:
71
- st.session_state.chat_started = True
72
 
73
- if st.session_state.chat_started:
74
  with st.sidebar.status("Initializing Guardrails..."):
75
  initialize_guardrails()
 
76
 
77
- # Initialize chat history
78
- if "messages" not in st.session_state:
79
- st.session_state.messages = []
80
-
81
- llm_model = OpenAIModel(model_name=openai_model)
82
-
83
- # Display chat messages from history on app rerun
84
- for message in st.session_state.messages:
85
- with st.chat_message(message["role"]):
86
- st.markdown(message["content"])
87
 
88
- # React to user input
89
- if prompt := st.chat_input("What is up?"):
90
- # Display user message in chat message container
91
- st.chat_message("user").markdown(prompt)
92
- # Add user message to chat history
93
- st.session_state.messages.append({"role": "user", "content": prompt})
94
 
95
- guardrails_response, call = st.session_state.guardrails_manager.guard.call(
96
- st.session_state.guardrails_manager, prompt=prompt
97
- )
 
 
98
 
99
  if guardrails_response["safe"]:
100
- response, call = llm_model.predict.call(
101
- llm_model, user_prompts=prompt, messages=st.session_state.messages
102
  )
103
- response = response.choices[0].message.content
104
 
105
- # Display assistant response in chat message container
106
- with st.chat_message("assistant"):
107
- st.markdown(response + f"\n\n---\n[Explore in Weave]({call.ui_url})")
108
- # Add assistant response to chat history
109
- st.session_state.messages.append({"role": "assistant", "content": response})
 
 
 
 
110
  else:
111
- st.error("Guardrails detected an issue with the prompt.")
112
- for alert in guardrails_response["alerts"]:
113
- st.error(f"{alert['guardrail_name']}: {alert['response']}")
114
- st.error(f"For details, explore in Weave at {call.ui_url}")
 
7
  from guardrails_genie.guardrails import GuardrailManager
8
  from guardrails_genie.llm import OpenAIModel
9
 
10
+ st.title(":material/robot: Guardrails Genie Playground")
11
+
12
  load_dotenv()
13
  weave.init(project_name="guardrails-genie")
14
 
 
 
15
  if "guardrails" not in st.session_state:
16
  st.session_state.guardrails = []
17
  if "guardrail_names" not in st.session_state:
18
  st.session_state.guardrail_names = []
19
  if "guardrails_manager" not in st.session_state:
20
  st.session_state.guardrails_manager = None
21
+ if "initialize_guardrails" not in st.session_state:
22
+ st.session_state.initialize_guardrails = False
23
+ if "system_prompt" not in st.session_state:
24
+ st.session_state.system_prompt = ""
25
+ if "user_prompt" not in st.session_state:
26
+ st.session_state.user_prompt = ""
27
+ if "test_guardrails" not in st.session_state:
28
+ st.session_state.test_guardrails = False
29
+ if "llm_model" not in st.session_state:
30
+ st.session_state.llm_model = None
31
 
32
 
33
  def initialize_guardrails():
 
75
  )
76
  st.session_state.guardrail_names = guardrail_names
77
 
78
+ if st.sidebar.button("Initialize Guardrails") and chat_condition:
79
+ st.session_state.initialize_guardrails = True
80
 
81
+ if st.session_state.initialize_guardrails:
82
  with st.sidebar.status("Initializing Guardrails..."):
83
  initialize_guardrails()
84
+ st.session_state.llm_model = OpenAIModel(model_name=openai_model)
85
 
86
+ user_prompt = st.text_area("User Prompt", value="")
87
+ st.session_state.user_prompt = user_prompt
 
 
 
 
 
 
 
 
88
 
89
+ test_guardrails_button = st.button("Test Guardrails")
90
+ st.session_state.test_guardrails = test_guardrails_button
 
 
 
 
91
 
92
+ if st.session_state.test_guardrails:
93
+ with st.sidebar.status("Running Guardrails..."):
94
+ guardrails_response, call = st.session_state.guardrails_manager.guard.call(
95
+ st.session_state.guardrails_manager, prompt=st.session_state.user_prompt
96
+ )
97
 
98
  if guardrails_response["safe"]:
99
+ st.markdown(
100
+ f"\n\n---\nPrompt is safe! Explore prompt trace on [Weave]({call.ui_url})\n\n---\n"
101
  )
 
102
 
103
+ with st.sidebar.status("Generating response from LLM..."):
104
+ response, call = st.session_state.llm_model.predict.call(
105
+ st.session_state.llm_model,
106
+ user_prompts=st.session_state.user_prompt,
107
+ )
108
+ st.markdown(
109
+ response.choices[0].message.content
110
+ + f"\n\n---\nExplore LLM generation trace on [Weave]({call.ui_url})"
111
+ )
112
  else:
113
+ st.warning("Prompt is not safe!")
114
+ st.markdown(guardrails_response["summary"])
115
+ st.markdown(f"Explore prompt trace on [Weave]({call.ui_url})")
 
guardrails_genie/guardrails/injection/protectai_guardrail.py CHANGED
@@ -35,4 +35,9 @@ class PromptInjectionProtectAIGuardrail(Guardrail):
35
 
36
  @weave.op()
37
  def guard(self, prompt: str):
38
- return self.predict(prompt)
 
 
 
 
 
 
35
 
36
  @weave.op()
37
  def guard(self, prompt: str):
38
+ response = self.classify(prompt)
39
+ confidence_percentage = round(response[0]["score"] * 100, 2)
40
+ return {
41
+ "safe": response[0]["label"] != "INJECTION",
42
+ "summary": f"Prompt is deemed {response[0]['label']} with {confidence_percentage}% confidence.",
43
+ }
guardrails_genie/guardrails/injection/survey_guardrail.py CHANGED
@@ -70,8 +70,17 @@ Here are some strict instructions that you must follow:
70
  **kwargs,
71
  )
72
  response = chat_completion.choices[0].message.parsed
73
- return {"safe": not response.injection_prompt}
74
 
75
  @weave.op()
76
  def guard(self, prompt: str, **kwargs) -> list[str]:
77
- return self.predict(prompt, **kwargs)
 
 
 
 
 
 
 
 
 
 
70
  **kwargs,
71
  )
72
  response = chat_completion.choices[0].message.parsed
73
+ return response
74
 
75
  @weave.op()
76
  def guard(self, prompt: str, **kwargs) -> list[str]:
77
+ response = self.predict(prompt, **kwargs)
78
+ summary = (
79
+ f"Prompt is deemed safe. {response.explanation}"
80
+ if not response.injection_prompt
81
+ else f"Prompt is deemed a {'direct attack' if response.is_direct_attack else 'indirect attack'} of type {response.attack_type}. {response.explanation}"
82
+ )
83
+ return {
84
+ "safe": not response.injection_prompt,
85
+ "summary": summary,
86
+ }
guardrails_genie/guardrails/manager.py CHANGED
@@ -9,7 +9,7 @@ class GuardrailManager(weave.Model):
9
 
10
  @weave.op()
11
  def guard(self, prompt: str, progress_bar: bool = True, **kwargs) -> dict:
12
- alerts, safe = [], True
13
  iterable = (
14
  track(self.guardrails, description="Running guardrails")
15
  if progress_bar
@@ -21,7 +21,10 @@ class GuardrailManager(weave.Model):
21
  {"guardrail_name": guardrail.__class__.__name__, "response": response}
22
  )
23
  safe = safe and response["safe"]
24
- return {"safe": safe, "alerts": alerts}
 
 
 
25
 
26
  @weave.op()
27
  def predict(self, prompt: str, **kwargs) -> dict:
 
9
 
10
  @weave.op()
11
  def guard(self, prompt: str, progress_bar: bool = True, **kwargs) -> dict:
12
+ alerts, summaries, safe = [], "", True
13
  iterable = (
14
  track(self.guardrails, description="Running guardrails")
15
  if progress_bar
 
21
  {"guardrail_name": guardrail.__class__.__name__, "response": response}
22
  )
23
  safe = safe and response["safe"]
24
+ summaries += (
25
+ f"**{guardrail.__class__.__name__}**: {response['summary']}\n\n---\n\n"
26
+ )
27
+ return {"safe": safe, "alerts": alerts, "summary": summaries}
28
 
29
  @weave.op()
30
  def predict(self, prompt: str, **kwargs) -> dict: