geekyrakshit commited on
Commit
224221c
2 Parent(s): a1c5338 a645df8

Merge pull request #2 from soumik12345/feat/guardrails-api

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [server]
2
+
3
+ headless = true
4
+ runOnSave = true
5
+ allowRunOnSave = true
6
+ fastReruns = true
7
+ fileWatcherType = "auto"
README.md CHANGED
@@ -12,3 +12,9 @@ uv venv
12
  uv pip install -e .
13
  source .venv/bin/activate
14
  ```
 
 
 
 
 
 
 
12
  uv pip install -e .
13
  source .venv/bin/activate
14
  ```
15
+
16
+ ## Run Chat App
17
+
18
+ ```bash
19
+ OPENAI_API_KEY="YOUR_OPENAI_API_KEY" streamlit run app.py
20
+ ```
app.py CHANGED
@@ -1,52 +1,16 @@
1
  import streamlit as st
2
- import weave
3
- from dotenv import load_dotenv
4
 
5
- from guardrails_genie.llm import OpenAIModel
6
-
7
- load_dotenv()
8
- weave.init(project_name="guardrails-genie")
9
-
10
- openai_model = st.sidebar.selectbox("OpenAI LLM", ["", "gpt-4o-mini", "gpt-4o"])
11
- chat_condition = openai_model != ""
12
-
13
- # Use session state to track if the chat has started
14
- if "chat_started" not in st.session_state:
15
- st.session_state.chat_started = False
16
-
17
- # Start chat when button is pressed
18
- if st.sidebar.button("Start Chat") and chat_condition:
19
- st.session_state.chat_started = True
20
-
21
- # Display chat UI if chat has started
22
- if st.session_state.chat_started:
23
- st.title("Guardrails Genie")
24
-
25
- # Initialize chat history
26
- if "messages" not in st.session_state:
27
- st.session_state.messages = []
28
-
29
- llm_model = OpenAIModel(model_name=openai_model)
30
-
31
- # Display chat messages from history on app rerun
32
- for message in st.session_state.messages:
33
- with st.chat_message(message["role"]):
34
- st.markdown(message["content"])
35
-
36
- # React to user input
37
- if prompt := st.chat_input("What is up?"):
38
- # Display user message in chat message container
39
- st.chat_message("user").markdown(prompt)
40
- # Add user message to chat history
41
- st.session_state.messages.append({"role": "user", "content": prompt})
42
-
43
- response, call = llm_model.predict.call(
44
- llm_model, user_prompts=prompt, messages=st.session_state.messages
45
- )
46
- response = response.choices[0].message.content
47
-
48
- # Display assistant response in chat message container
49
- with st.chat_message("assistant"):
50
- st.markdown(response + f"\n\n---\n[Explore in Weave]({call.ui_url})")
51
- # Add assistant response to chat history
52
- st.session_state.messages.append({"role": "assistant", "content": response})
 
1
  import streamlit as st
 
 
2
 
3
+ intro_page = st.Page(
4
+ "application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
5
+ )
6
+ chat_page = st.Page(
7
+ "application_pages/chat_app.py", title="Chat", icon=":material/robot:"
8
+ )
9
+ evaluation_page = st.Page(
10
+ "application_pages/evaluation_app.py",
11
+ title="Evaluation",
12
+ icon=":material/monitoring:",
13
+ )
14
+ page_navigation = st.navigation([intro_page, chat_page, evaluation_page])
15
+ st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
16
+ page_navigation.run()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
application_pages/chat_app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import streamlit as st
4
+ import weave
5
+ from dotenv import load_dotenv
6
+
7
+ from guardrails_genie.guardrails import GuardrailManager
8
+ from guardrails_genie.llm import OpenAIModel
9
+
10
+ load_dotenv()
11
+ weave.init(project_name="guardrails-genie")
12
+
13
+ st.title(":material/robot: Guardrails Genie")
14
+
15
+ if "guardrails" not in st.session_state:
16
+ st.session_state.guardrails = []
17
+ if "guardrail_names" not in st.session_state:
18
+ st.session_state.guardrail_names = []
19
+ if "guardrails_manager" not in st.session_state:
20
+ st.session_state.guardrails_manager = None
21
+ if "chat_started" not in st.session_state:
22
+ st.session_state.chat_started = False
23
+
24
+
25
+ def initialize_guardrails():
26
+ st.session_state.guardrails = []
27
+ for guardrail_name in st.session_state.guardrail_names:
28
+ if guardrail_name == "PromptInjectionSurveyGuardrail":
29
+ survey_guardrail_model = st.sidebar.selectbox(
30
+ "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
31
+ )
32
+ if survey_guardrail_model:
33
+ st.session_state.guardrails.append(
34
+ getattr(
35
+ importlib.import_module("guardrails_genie.guardrails"),
36
+ guardrail_name,
37
+ )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
38
+ )
39
+ else:
40
+ st.session_state.guardrails.append(
41
+ getattr(
42
+ importlib.import_module("guardrails_genie.guardrails"),
43
+ guardrail_name,
44
+ )()
45
+ )
46
+ st.session_state.guardrails_manager = GuardrailManager(
47
+ guardrails=st.session_state.guardrails
48
+ )
49
+
50
+
51
+ openai_model = st.sidebar.selectbox(
52
+ "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
53
+ )
54
+ chat_condition = openai_model != ""
55
+
56
+ guardrails = []
57
+
58
+ guardrail_names = st.sidebar.multiselect(
59
+ label="Select Guardrails",
60
+ options=[
61
+ cls_name
62
+ for cls_name, cls_obj in vars(
63
+ importlib.import_module("guardrails_genie.guardrails")
64
+ ).items()
65
+ if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
66
+ ],
67
+ )
68
+ st.session_state.guardrail_names = guardrail_names
69
+
70
+ if st.sidebar.button("Start Chat") and chat_condition:
71
+ st.session_state.chat_started = True
72
+
73
+ if st.session_state.chat_started:
74
+ with st.sidebar.status("Initializing Guardrails..."):
75
+ initialize_guardrails()
76
+
77
+ st.title("Guardrails Genie")
78
+
79
+ # Initialize chat history
80
+ if "messages" not in st.session_state:
81
+ st.session_state.messages = []
82
+
83
+ llm_model = OpenAIModel(model_name=openai_model)
84
+
85
+ # Display chat messages from history on app rerun
86
+ for message in st.session_state.messages:
87
+ with st.chat_message(message["role"]):
88
+ st.markdown(message["content"])
89
+
90
+ # React to user input
91
+ if prompt := st.chat_input("What is up?"):
92
+ # Display user message in chat message container
93
+ st.chat_message("user").markdown(prompt)
94
+ # Add user message to chat history
95
+ st.session_state.messages.append({"role": "user", "content": prompt})
96
+
97
+ guardrails_response, call = st.session_state.guardrails_manager.guard.call(
98
+ st.session_state.guardrails_manager, prompt=prompt
99
+ )
100
+
101
+ if guardrails_response["safe"]:
102
+ response, call = llm_model.predict.call(
103
+ llm_model, user_prompts=prompt, messages=st.session_state.messages
104
+ )
105
+ response = response.choices[0].message.content
106
+
107
+ # Display assistant response in chat message container
108
+ with st.chat_message("assistant"):
109
+ st.markdown(response + f"\n\n---\n[Explore in Weave]({call.ui_url})")
110
+ # Add assistant response to chat history
111
+ st.session_state.messages.append({"role": "assistant", "content": response})
112
+ else:
113
+ st.error("Guardrails detected an issue with the prompt.")
114
+ for alert in guardrails_response["alerts"]:
115
+ st.error(f"{alert['guardrail_name']}: {alert['response']}")
116
+ st.error(f"For details, explore in Weave at {call.ui_url}")
application_pages/evaluation_app.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.title(":material/monitoring: Evaluation")
application_pages/intro_page.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.title("🧞‍♂️ Guardrails Genie")
4
+
5
+ st.write(
6
+ "Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
7
+ )
guardrails_genie/guardrails/__init__.py CHANGED
@@ -1,3 +1,8 @@
1
- from .injection import SurveyGuardrail
 
2
 
3
- __all__ = ["SurveyGuardrail"]
 
 
 
 
 
1
+ from .injection import PromptInjectionProtectAIGuardrail, PromptInjectionSurveyGuardrail
2
+ from .manager import GuardrailManager
3
 
4
+ __all__ = [
5
+ "PromptInjectionSurveyGuardrail",
6
+ "PromptInjectionProtectAIGuardrail",
7
+ "GuardrailManager",
8
+ ]
guardrails_genie/guardrails/base.py CHANGED
@@ -11,7 +11,3 @@ class Guardrail(weave.Model):
11
  @weave.op()
12
  def guard(self, prompt: str, **kwargs) -> list[str]:
13
  pass
14
-
15
- @weave.op()
16
- def predict(self, prompt: str, **kwargs) -> list[str]:
17
- return self.guard(prompt, **kwargs)
 
11
  @weave.op()
12
  def guard(self, prompt: str, **kwargs) -> list[str]:
13
  pass
 
 
 
 
guardrails_genie/guardrails/injection/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
- from .survey_guardrail import SurveyGuardrail
 
2
 
3
- __all__ = ["SurveyGuardrail"]
 
1
+ from .protectai_guardrail import PromptInjectionProtectAIGuardrail
2
+ from .survey_guardrail import PromptInjectionSurveyGuardrail
3
 
4
+ __all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionProtectAIGuardrail"]
guardrails_genie/guardrails/injection/protectai_guardrail.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import weave
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
+ from transformers.pipelines.base import Pipeline
7
+
8
+ from ..base import Guardrail
9
+
10
+
11
+ class PromptInjectionProtectAIGuardrail(Guardrail):
12
+ model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
13
+ _classifier: Optional[Pipeline] = None
14
+
15
+ def model_post_init(self, __context):
16
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
17
+ model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
18
+ self._classifier = pipeline(
19
+ "text-classification",
20
+ model=model,
21
+ tokenizer=tokenizer,
22
+ truncation=True,
23
+ max_length=512,
24
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
25
+ )
26
+
27
+ @weave.op()
28
+ def predict(self, prompt: str):
29
+ return self._classifier(prompt)
30
+
31
+ @weave.op()
32
+ def guard(self, prompt: str):
33
+ response = self.predict(prompt)
34
+ return {"safe": response[0]["label"] != "INJECTION"}
guardrails_genie/guardrails/injection/survey_guardrail.py CHANGED
@@ -15,9 +15,9 @@ class SurveyGuardrailResponse(BaseModel):
15
  explanation: Optional[str]
16
 
17
 
18
- class SurveyGuardrail(Guardrail):
19
  llm_model: OpenAIModel
20
-
21
  @weave.op()
22
  def load_prompt_injection_survey(self) -> str:
23
  prompt_injection_survey_path = os.path.join(
@@ -61,7 +61,7 @@ Here are some strict instructions that you must follow:
61
  return user_prompt, system_prompt
62
 
63
  @weave.op()
64
- def guard(self, prompt: str, **kwargs) -> list[str]:
65
  user_prompt, system_prompt = self.format_prompts(prompt)
66
  chat_completion = self.llm_model.predict(
67
  user_prompts=user_prompt,
@@ -70,3 +70,8 @@ Here are some strict instructions that you must follow:
70
  **kwargs,
71
  )
72
  return chat_completion.choices[0].message.parsed
 
 
 
 
 
 
15
  explanation: Optional[str]
16
 
17
 
18
+ class PromptInjectionSurveyGuardrail(Guardrail):
19
  llm_model: OpenAIModel
20
+
21
  @weave.op()
22
  def load_prompt_injection_survey(self) -> str:
23
  prompt_injection_survey_path = os.path.join(
 
61
  return user_prompt, system_prompt
62
 
63
  @weave.op()
64
+ def predict(self, prompt: str, **kwargs) -> list[str]:
65
  user_prompt, system_prompt = self.format_prompts(prompt)
66
  chat_completion = self.llm_model.predict(
67
  user_prompts=user_prompt,
 
70
  **kwargs,
71
  )
72
  return chat_completion.choices[0].message.parsed
73
+
74
+ @weave.op()
75
+ def guard(self, prompt: str, **kwargs) -> list[str]:
76
+ response = self.predict(prompt, **kwargs)
77
+ return {"safe": not response.injection_prompt}
guardrails_genie/guardrails/manager.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import weave
2
+ from rich.progress import track
3
+ from weave.flow.obj import Object as WeaveObject
4
+
5
+ from .base import Guardrail
6
+
7
+
8
+ class GuardrailManager(WeaveObject):
9
+ guardrails: list[Guardrail]
10
+
11
+ @weave.op()
12
+ def guard(self, prompt: str, **kwargs) -> dict:
13
+ alerts, safe = [], True
14
+ for guardrail in track(self.guardrails, description="Running guardrails"):
15
+ response = guardrail.guard(prompt, **kwargs)
16
+ alerts.append(
17
+ {"guardrail_name": guardrail.__class__.__name__, "response": response}
18
+ )
19
+ safe = safe and response["safe"]
20
+ return {"safe": safe, "alerts": alerts}
pyproject.toml CHANGED
@@ -12,7 +12,7 @@ dependencies = [
12
  "ruff>=0.6.9",
13
  "pip>=24.2",
14
  "uv>=0.4.20",
15
- "weave>=0.51.19",
16
  "streamlit>=1.40.1",
17
  "python-dotenv>=1.0.1",
18
  "watchdog>=6.0.0",
 
12
  "ruff>=0.6.9",
13
  "pip>=24.2",
14
  "uv>=0.4.20",
15
+ "weave>=0.51.22",
16
  "streamlit>=1.40.1",
17
  "python-dotenv>=1.0.1",
18
  "watchdog>=6.0.0",