Spaces:
Running
Running
geekyrakshit
commited on
Merge pull request #2 from soumik12345/feat/guardrails-api
Browse files- .streamlit/config.toml +7 -0
- README.md +6 -0
- app.py +14 -50
- application_pages/chat_app.py +116 -0
- application_pages/evaluation_app.py +3 -0
- application_pages/intro_page.py +7 -0
- guardrails_genie/guardrails/__init__.py +7 -2
- guardrails_genie/guardrails/base.py +0 -4
- guardrails_genie/guardrails/injection/__init__.py +3 -2
- guardrails_genie/guardrails/injection/protectai_guardrail.py +34 -0
- guardrails_genie/guardrails/injection/survey_guardrail.py +8 -3
- guardrails_genie/guardrails/manager.py +20 -0
- pyproject.toml +1 -1
.streamlit/config.toml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[server]
|
2 |
+
|
3 |
+
headless = true
|
4 |
+
runOnSave = true
|
5 |
+
allowRunOnSave = true
|
6 |
+
fastReruns = true
|
7 |
+
fileWatcherType = "auto"
|
README.md
CHANGED
@@ -12,3 +12,9 @@ uv venv
|
|
12 |
uv pip install -e .
|
13 |
source .venv/bin/activate
|
14 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
uv pip install -e .
|
13 |
source .venv/bin/activate
|
14 |
```
|
15 |
+
|
16 |
+
## Run Chat App
|
17 |
+
|
18 |
+
```bash
|
19 |
+
OPENAI_API_KEY="YOUR_OPENAI_API_KEY" streamlit run app.py
|
20 |
+
```
|
app.py
CHANGED
@@ -1,52 +1,16 @@
|
|
1 |
import streamlit as st
|
2 |
-
import weave
|
3 |
-
from dotenv import load_dotenv
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
st.session_state.chat_started = True
|
20 |
-
|
21 |
-
# Display chat UI if chat has started
|
22 |
-
if st.session_state.chat_started:
|
23 |
-
st.title("Guardrails Genie")
|
24 |
-
|
25 |
-
# Initialize chat history
|
26 |
-
if "messages" not in st.session_state:
|
27 |
-
st.session_state.messages = []
|
28 |
-
|
29 |
-
llm_model = OpenAIModel(model_name=openai_model)
|
30 |
-
|
31 |
-
# Display chat messages from history on app rerun
|
32 |
-
for message in st.session_state.messages:
|
33 |
-
with st.chat_message(message["role"]):
|
34 |
-
st.markdown(message["content"])
|
35 |
-
|
36 |
-
# React to user input
|
37 |
-
if prompt := st.chat_input("What is up?"):
|
38 |
-
# Display user message in chat message container
|
39 |
-
st.chat_message("user").markdown(prompt)
|
40 |
-
# Add user message to chat history
|
41 |
-
st.session_state.messages.append({"role": "user", "content": prompt})
|
42 |
-
|
43 |
-
response, call = llm_model.predict.call(
|
44 |
-
llm_model, user_prompts=prompt, messages=st.session_state.messages
|
45 |
-
)
|
46 |
-
response = response.choices[0].message.content
|
47 |
-
|
48 |
-
# Display assistant response in chat message container
|
49 |
-
with st.chat_message("assistant"):
|
50 |
-
st.markdown(response + f"\n\n---\n[Explore in Weave]({call.ui_url})")
|
51 |
-
# Add assistant response to chat history
|
52 |
-
st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
1 |
import streamlit as st
|
|
|
|
|
2 |
|
3 |
+
intro_page = st.Page(
|
4 |
+
"application_pages/intro_page.py", title="Introduction", icon=":material/guardian:"
|
5 |
+
)
|
6 |
+
chat_page = st.Page(
|
7 |
+
"application_pages/chat_app.py", title="Chat", icon=":material/robot:"
|
8 |
+
)
|
9 |
+
evaluation_page = st.Page(
|
10 |
+
"application_pages/evaluation_app.py",
|
11 |
+
title="Evaluation",
|
12 |
+
icon=":material/monitoring:",
|
13 |
+
)
|
14 |
+
page_navigation = st.navigation([intro_page, chat_page, evaluation_page])
|
15 |
+
st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
|
16 |
+
page_navigation.run()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
application_pages/chat_app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
import weave
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
|
7 |
+
from guardrails_genie.guardrails import GuardrailManager
|
8 |
+
from guardrails_genie.llm import OpenAIModel
|
9 |
+
|
10 |
+
load_dotenv()
|
11 |
+
weave.init(project_name="guardrails-genie")
|
12 |
+
|
13 |
+
st.title(":material/robot: Guardrails Genie")
|
14 |
+
|
15 |
+
if "guardrails" not in st.session_state:
|
16 |
+
st.session_state.guardrails = []
|
17 |
+
if "guardrail_names" not in st.session_state:
|
18 |
+
st.session_state.guardrail_names = []
|
19 |
+
if "guardrails_manager" not in st.session_state:
|
20 |
+
st.session_state.guardrails_manager = None
|
21 |
+
if "chat_started" not in st.session_state:
|
22 |
+
st.session_state.chat_started = False
|
23 |
+
|
24 |
+
|
25 |
+
def initialize_guardrails():
|
26 |
+
st.session_state.guardrails = []
|
27 |
+
for guardrail_name in st.session_state.guardrail_names:
|
28 |
+
if guardrail_name == "PromptInjectionSurveyGuardrail":
|
29 |
+
survey_guardrail_model = st.sidebar.selectbox(
|
30 |
+
"Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
|
31 |
+
)
|
32 |
+
if survey_guardrail_model:
|
33 |
+
st.session_state.guardrails.append(
|
34 |
+
getattr(
|
35 |
+
importlib.import_module("guardrails_genie.guardrails"),
|
36 |
+
guardrail_name,
|
37 |
+
)(llm_model=OpenAIModel(model_name=survey_guardrail_model))
|
38 |
+
)
|
39 |
+
else:
|
40 |
+
st.session_state.guardrails.append(
|
41 |
+
getattr(
|
42 |
+
importlib.import_module("guardrails_genie.guardrails"),
|
43 |
+
guardrail_name,
|
44 |
+
)()
|
45 |
+
)
|
46 |
+
st.session_state.guardrails_manager = GuardrailManager(
|
47 |
+
guardrails=st.session_state.guardrails
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
openai_model = st.sidebar.selectbox(
|
52 |
+
"OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
|
53 |
+
)
|
54 |
+
chat_condition = openai_model != ""
|
55 |
+
|
56 |
+
guardrails = []
|
57 |
+
|
58 |
+
guardrail_names = st.sidebar.multiselect(
|
59 |
+
label="Select Guardrails",
|
60 |
+
options=[
|
61 |
+
cls_name
|
62 |
+
for cls_name, cls_obj in vars(
|
63 |
+
importlib.import_module("guardrails_genie.guardrails")
|
64 |
+
).items()
|
65 |
+
if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
|
66 |
+
],
|
67 |
+
)
|
68 |
+
st.session_state.guardrail_names = guardrail_names
|
69 |
+
|
70 |
+
if st.sidebar.button("Start Chat") and chat_condition:
|
71 |
+
st.session_state.chat_started = True
|
72 |
+
|
73 |
+
if st.session_state.chat_started:
|
74 |
+
with st.sidebar.status("Initializing Guardrails..."):
|
75 |
+
initialize_guardrails()
|
76 |
+
|
77 |
+
st.title("Guardrails Genie")
|
78 |
+
|
79 |
+
# Initialize chat history
|
80 |
+
if "messages" not in st.session_state:
|
81 |
+
st.session_state.messages = []
|
82 |
+
|
83 |
+
llm_model = OpenAIModel(model_name=openai_model)
|
84 |
+
|
85 |
+
# Display chat messages from history on app rerun
|
86 |
+
for message in st.session_state.messages:
|
87 |
+
with st.chat_message(message["role"]):
|
88 |
+
st.markdown(message["content"])
|
89 |
+
|
90 |
+
# React to user input
|
91 |
+
if prompt := st.chat_input("What is up?"):
|
92 |
+
# Display user message in chat message container
|
93 |
+
st.chat_message("user").markdown(prompt)
|
94 |
+
# Add user message to chat history
|
95 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
96 |
+
|
97 |
+
guardrails_response, call = st.session_state.guardrails_manager.guard.call(
|
98 |
+
st.session_state.guardrails_manager, prompt=prompt
|
99 |
+
)
|
100 |
+
|
101 |
+
if guardrails_response["safe"]:
|
102 |
+
response, call = llm_model.predict.call(
|
103 |
+
llm_model, user_prompts=prompt, messages=st.session_state.messages
|
104 |
+
)
|
105 |
+
response = response.choices[0].message.content
|
106 |
+
|
107 |
+
# Display assistant response in chat message container
|
108 |
+
with st.chat_message("assistant"):
|
109 |
+
st.markdown(response + f"\n\n---\n[Explore in Weave]({call.ui_url})")
|
110 |
+
# Add assistant response to chat history
|
111 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
112 |
+
else:
|
113 |
+
st.error("Guardrails detected an issue with the prompt.")
|
114 |
+
for alert in guardrails_response["alerts"]:
|
115 |
+
st.error(f"{alert['guardrail_name']}: {alert['response']}")
|
116 |
+
st.error(f"For details, explore in Weave at {call.ui_url}")
|
application_pages/evaluation_app.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
st.title(":material/monitoring: Evaluation")
|
application_pages/intro_page.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
st.title("🧞♂️ Guardrails Genie")
|
4 |
+
|
5 |
+
st.write(
|
6 |
+
"Guardrails-Genie is a tool that helps you implement guardrails in your LLM applications."
|
7 |
+
)
|
guardrails_genie/guardrails/__init__.py
CHANGED
@@ -1,3 +1,8 @@
|
|
1 |
-
from .injection import
|
|
|
2 |
|
3 |
-
__all__ = [
|
|
|
|
|
|
|
|
|
|
1 |
+
from .injection import PromptInjectionProtectAIGuardrail, PromptInjectionSurveyGuardrail
|
2 |
+
from .manager import GuardrailManager
|
3 |
|
4 |
+
__all__ = [
|
5 |
+
"PromptInjectionSurveyGuardrail",
|
6 |
+
"PromptInjectionProtectAIGuardrail",
|
7 |
+
"GuardrailManager",
|
8 |
+
]
|
guardrails_genie/guardrails/base.py
CHANGED
@@ -11,7 +11,3 @@ class Guardrail(weave.Model):
|
|
11 |
@weave.op()
|
12 |
def guard(self, prompt: str, **kwargs) -> list[str]:
|
13 |
pass
|
14 |
-
|
15 |
-
@weave.op()
|
16 |
-
def predict(self, prompt: str, **kwargs) -> list[str]:
|
17 |
-
return self.guard(prompt, **kwargs)
|
|
|
11 |
@weave.op()
|
12 |
def guard(self, prompt: str, **kwargs) -> list[str]:
|
13 |
pass
|
|
|
|
|
|
|
|
guardrails_genie/guardrails/injection/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
-
from .
|
|
|
2 |
|
3 |
-
__all__ = ["
|
|
|
1 |
+
from .protectai_guardrail import PromptInjectionProtectAIGuardrail
|
2 |
+
from .survey_guardrail import PromptInjectionSurveyGuardrail
|
3 |
|
4 |
+
__all__ = ["PromptInjectionSurveyGuardrail", "PromptInjectionProtectAIGuardrail"]
|
guardrails_genie/guardrails/injection/protectai_guardrail.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import weave
|
5 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
6 |
+
from transformers.pipelines.base import Pipeline
|
7 |
+
|
8 |
+
from ..base import Guardrail
|
9 |
+
|
10 |
+
|
11 |
+
class PromptInjectionProtectAIGuardrail(Guardrail):
|
12 |
+
model_name: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
|
13 |
+
_classifier: Optional[Pipeline] = None
|
14 |
+
|
15 |
+
def model_post_init(self, __context):
|
16 |
+
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
17 |
+
model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
|
18 |
+
self._classifier = pipeline(
|
19 |
+
"text-classification",
|
20 |
+
model=model,
|
21 |
+
tokenizer=tokenizer,
|
22 |
+
truncation=True,
|
23 |
+
max_length=512,
|
24 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
25 |
+
)
|
26 |
+
|
27 |
+
@weave.op()
|
28 |
+
def predict(self, prompt: str):
|
29 |
+
return self._classifier(prompt)
|
30 |
+
|
31 |
+
@weave.op()
|
32 |
+
def guard(self, prompt: str):
|
33 |
+
response = self.predict(prompt)
|
34 |
+
return {"safe": response[0]["label"] != "INJECTION"}
|
guardrails_genie/guardrails/injection/survey_guardrail.py
CHANGED
@@ -15,9 +15,9 @@ class SurveyGuardrailResponse(BaseModel):
|
|
15 |
explanation: Optional[str]
|
16 |
|
17 |
|
18 |
-
class
|
19 |
llm_model: OpenAIModel
|
20 |
-
|
21 |
@weave.op()
|
22 |
def load_prompt_injection_survey(self) -> str:
|
23 |
prompt_injection_survey_path = os.path.join(
|
@@ -61,7 +61,7 @@ Here are some strict instructions that you must follow:
|
|
61 |
return user_prompt, system_prompt
|
62 |
|
63 |
@weave.op()
|
64 |
-
def
|
65 |
user_prompt, system_prompt = self.format_prompts(prompt)
|
66 |
chat_completion = self.llm_model.predict(
|
67 |
user_prompts=user_prompt,
|
@@ -70,3 +70,8 @@ Here are some strict instructions that you must follow:
|
|
70 |
**kwargs,
|
71 |
)
|
72 |
return chat_completion.choices[0].message.parsed
|
|
|
|
|
|
|
|
|
|
|
|
15 |
explanation: Optional[str]
|
16 |
|
17 |
|
18 |
+
class PromptInjectionSurveyGuardrail(Guardrail):
|
19 |
llm_model: OpenAIModel
|
20 |
+
|
21 |
@weave.op()
|
22 |
def load_prompt_injection_survey(self) -> str:
|
23 |
prompt_injection_survey_path = os.path.join(
|
|
|
61 |
return user_prompt, system_prompt
|
62 |
|
63 |
@weave.op()
|
64 |
+
def predict(self, prompt: str, **kwargs) -> list[str]:
|
65 |
user_prompt, system_prompt = self.format_prompts(prompt)
|
66 |
chat_completion = self.llm_model.predict(
|
67 |
user_prompts=user_prompt,
|
|
|
70 |
**kwargs,
|
71 |
)
|
72 |
return chat_completion.choices[0].message.parsed
|
73 |
+
|
74 |
+
@weave.op()
|
75 |
+
def guard(self, prompt: str, **kwargs) -> list[str]:
|
76 |
+
response = self.predict(prompt, **kwargs)
|
77 |
+
return {"safe": not response.injection_prompt}
|
guardrails_genie/guardrails/manager.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import weave
|
2 |
+
from rich.progress import track
|
3 |
+
from weave.flow.obj import Object as WeaveObject
|
4 |
+
|
5 |
+
from .base import Guardrail
|
6 |
+
|
7 |
+
|
8 |
+
class GuardrailManager(WeaveObject):
|
9 |
+
guardrails: list[Guardrail]
|
10 |
+
|
11 |
+
@weave.op()
|
12 |
+
def guard(self, prompt: str, **kwargs) -> dict:
|
13 |
+
alerts, safe = [], True
|
14 |
+
for guardrail in track(self.guardrails, description="Running guardrails"):
|
15 |
+
response = guardrail.guard(prompt, **kwargs)
|
16 |
+
alerts.append(
|
17 |
+
{"guardrail_name": guardrail.__class__.__name__, "response": response}
|
18 |
+
)
|
19 |
+
safe = safe and response["safe"]
|
20 |
+
return {"safe": safe, "alerts": alerts}
|
pyproject.toml
CHANGED
@@ -12,7 +12,7 @@ dependencies = [
|
|
12 |
"ruff>=0.6.9",
|
13 |
"pip>=24.2",
|
14 |
"uv>=0.4.20",
|
15 |
-
"weave>=0.51.
|
16 |
"streamlit>=1.40.1",
|
17 |
"python-dotenv>=1.0.1",
|
18 |
"watchdog>=6.0.0",
|
|
|
12 |
"ruff>=0.6.9",
|
13 |
"pip>=24.2",
|
14 |
"uv>=0.4.20",
|
15 |
+
"weave>=0.51.22",
|
16 |
"streamlit>=1.40.1",
|
17 |
"python-dotenv>=1.0.1",
|
18 |
"watchdog>=6.0.0",
|