geekyrakshit committed
Commit 9269c10
1 Parent(s): 1d117f2

update: app

Files changed (2)
  1. .streamlit/config.toml +7 -0
  2. chat_app.py +55 -18
.streamlit/config.toml ADDED
@@ -0,0 +1,7 @@
+ [server]
+
+ headless = true
+ runOnSave = true
+ allowRunOnSave = true
+ fastReruns = true
+ fileWatcherType = "auto"
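
All of the keys above are standard Streamlit [server] options. As a quick sanity check (a sketch, not part of this commit), the running app can echo the values it actually picked up from .streamlit/config.toml via st.get_option:

import streamlit as st

# Print the effective server settings; the keys mirror the config.toml entries above.
for key in ("server.headless", "server.runOnSave", "server.fastReruns", "server.fileWatcherType"):
    st.write(f"{key} = {st.get_option(key)}")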
chat_app.py CHANGED
@@ -1,33 +1,67 @@
+ import importlib
+
  import streamlit as st
  import weave
  from dotenv import load_dotenv

- from guardrails_genie.guardrails import GuardrailManager, PromptInjectionSurveyGuardrail
+ from guardrails_genie.guardrails import GuardrailManager
  from guardrails_genie.llm import OpenAIModel

  load_dotenv()
  weave.init(project_name="guardrails-genie")

- openai_model = st.sidebar.selectbox("OpenAI LLM", ["", "gpt-4o-mini", "gpt-4o"])
- chat_condition = openai_model != ""
-
- guardrails = []
+ if "guardrails" not in st.session_state:
+     st.session_state.guardrails = []
+ if "guardrail_names" not in st.session_state:
+     st.session_state.guardrail_names = []
+ if "guardrails_manager" not in st.session_state:
+     st.session_state.guardrails_manager = None

- with st.sidebar.expander("Switch on Prompt Injection Guardrails"):
-     is_survey_guardrail_enabled = st.toggle("Survey Guardrail")

- if is_survey_guardrail_enabled:
-     survey_guardrail_model = st.selectbox(
-         "Survey Guardrail Model", ["", "gpt-4o-mini", "gpt-4o"]
-     )
-     if survey_guardrail_model:
-         guardrails.append(
-             PromptInjectionSurveyGuardrail(
-                 llm_model=OpenAIModel(model_name=survey_guardrail_model)
-             )
-         )
+ def initialize_guardrails():
+     st.session_state.guardrails = []
+     for guardrail_name in st.session_state.guardrail_names:
+         if guardrail_name == "PromptInjectionSurveyGuardrail":
+             survey_guardrail_model = st.sidebar.selectbox(
+                 "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
+             )
+             if survey_guardrail_model:
+                 st.session_state.guardrails.append(
+                     getattr(
+                         importlib.import_module("guardrails_genie.guardrails"),
+                         guardrail_name,
+                     )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
+                 )
+         else:
+             st.session_state.guardrails.append(
+                 getattr(
+                     importlib.import_module("guardrails_genie.guardrails"),
+                     guardrail_name,
+                 )()
+             )
+     st.session_state.guardrails_manager = GuardrailManager(
+         guardrails=st.session_state.guardrails
+     )

- guardrails_manager = GuardrailManager(guardrails=guardrails)
+
+ openai_model = st.sidebar.selectbox(
+     "OpenAI LLM for Chat", ["", "gpt-4o-mini", "gpt-4o"]
+ )
+ chat_condition = openai_model != ""
+
+ guardrails = []
+
+ guardrail_names = st.sidebar.multiselect(
+     label="Select Guardrails",
+     options=[
+         cls_name
+         for cls_name, cls_obj in vars(
+             importlib.import_module("guardrails_genie.guardrails")
+         ).items()
+         if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
+     ],
+ )
+ st.session_state.guardrail_names = guardrail_names

  # Use session state to track if the chat has started
  if "chat_started" not in st.session_state:
@@ -39,6 +73,9 @@ if st.sidebar.button("Start Chat") and chat_condition:

  # Display chat UI if chat has started
  if st.session_state.chat_started:
+     with st.sidebar.status("Initializing Guardrails..."):
+         initialize_guardrails()
+
      st.title("Guardrails Genie")

      # Initialize chat history
@@ -59,8 +96,8 @@ if st.session_state.chat_started:
          # Add user message to chat history
          st.session_state.messages.append({"role": "user", "content": prompt})

-         guardrails_response, call = guardrails_manager.guard.call(
-             guardrails_manager, prompt=prompt
+         guardrails_response, call = st.session_state.guardrails_manager.guard.call(
+             st.session_state.guardrails_manager, prompt=prompt
          )

          if guardrails_response["safe"]:
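
For reference, the dynamic loading in initialize_guardrails and in the multiselect options comes down to one pattern: enumerate the classes a module exports and instantiate one by name. A minimal standalone sketch of that pattern, assuming guardrails_genie is installed (the helper names below are illustrative, not part of this commit):

import importlib

GUARDRAILS_MODULE = "guardrails_genie.guardrails"  # import path used in the diff

def available_guardrails(module_path: str = GUARDRAILS_MODULE) -> list[str]:
    # Every class the module exposes, excluding the manager itself.
    module = importlib.import_module(module_path)
    return [
        name
        for name, obj in vars(module).items()
        if isinstance(obj, type) and name != "GuardrailManager"
    ]

def build_guardrail(name: str, **kwargs):
    # Resolve the class by name and instantiate it with whatever constructor
    # arguments it needs, e.g. llm_model=... for PromptInjectionSurveyGuardrail.
    return getattr(importlib.import_module(GUARDRAILS_MODULE), name)(**kwargs)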