RishabhBhardwaj commited on
Commit
8be32c4
1 Parent(s): a2bd494

add examples

Browse files
Files changed (1) hide show
  1. app.py +31 -13
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import requests
4
  from PIL import Image
5
  from io import BytesIO
 
6
 
7
  # Define the template
8
  TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
@@ -28,6 +29,12 @@ def load_image_from_url(url):
28
  img = Image.open(BytesIO(response.content))
29
  return img
30
 
 
 
 
 
 
 
31
  # Evaluation function
32
  def evaluate_text(user_input):
33
  if user_input:
@@ -57,8 +64,19 @@ st.title("Text Safety Evaluator")
57
  if 'model_and_tokenizer' not in st.session_state:
58
  st.session_state.model_and_tokenizer = load_model()
59
 
 
 
 
 
 
 
 
 
 
60
  # User input
61
- user_input = st.text_area("Enter the text you want to evaluate:", height=100)
 
 
62
 
63
  # Create an empty container for the result
64
  result_container = st.empty()
@@ -72,17 +90,17 @@ if st.button("Evaluate"):
72
  result_container.warning("Please enter some text to evaluate.")
73
 
74
  # Add logo at the bottom center (only once)
75
- #if 'logo_displayed' not in st.session_state:
76
- col1, col2, col3 = st.columns([1,2,1])
77
- with col2:
78
- logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
79
- logo = load_image_from_url(logo_url)
80
- st.image(logo, use_column_width=True, width=500) # Adjust the width as needed
81
- #st.session_state.logo_displayed = True
82
 
83
  # Add information about Walled Guard Advanced (only once)
84
- #if 'info_displayed' not in st.session_state:
85
- col1, col2, col3 = st.columns([1,2,1])
86
- with col2:
87
- st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at admin@walled.ai for more information.")
88
- #st.session_state.info_displayed = True
 
3
  import requests
4
  from PIL import Image
5
  from io import BytesIO
6
+ from datasets import load_dataset
7
 
8
  # Define the template
9
  TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
 
29
  img = Image.open(BytesIO(response.content))
30
  return img
31
 
32
+ # Load dataset
33
+ @st.cache_data
34
+ def load_example_dataset():
35
+ ds = load_dataset("walledai/XSTest")
36
+ return ds['train']['prompt'][:10] # Get first 10 examples
37
+
38
  # Evaluation function
39
  def evaluate_text(user_input):
40
  if user_input:
 
64
  if 'model_and_tokenizer' not in st.session_state:
65
  st.session_state.model_and_tokenizer = load_model()
66
 
67
+ # Load example dataset
68
+ example_prompts = load_example_dataset()
69
+
70
+ # Display example prompts
71
+ st.subheader("Example Inputs:")
72
+ for i, prompt in enumerate(example_prompts):
73
+ if st.button(f"Example {i+1}", key=f"example_{i}"):
74
+ st.session_state.user_input = prompt
75
+
76
  # User input
77
+ user_input = st.text_area("Enter the text you want to evaluate:",
78
+ height=100,
79
+ value=st.session_state.get('user_input', ''))
80
 
81
  # Create an empty container for the result
82
  result_container = st.empty()
 
90
  result_container.warning("Please enter some text to evaluate.")
91
 
92
  # Add logo at the bottom center (only once)
93
+ if 'logo_displayed' not in st.session_state:
94
+ col1, col2, col3 = st.columns([1,2,1])
95
+ with col2:
96
+ logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
97
+ logo = load_image_from_url(logo_url)
98
+ st.image(logo, use_column_width=True, width=500) # Adjust the width as needed
99
+ st.session_state.logo_displayed = True
100
 
101
  # Add information about Walled Guard Advanced (only once)
102
+ if 'info_displayed' not in st.session_state:
103
+ col1, col2, col3 = st.columns([1,2,1])
104
+ with col2:
105
+ st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at admin@walled.ai for more information.")
106
+ st.session_state.info_displayed = True