File size: 3,616 Bytes
2f935de
 
bc152f8
 
 
8be32c4
5d09640
2f935de
 
 
 
 
28eb1f5
2f935de
 
5d09640
2f935de
d88280b
2f935de
 
 
 
 
5d09640
bc152f8
ecf0f3e
bc152f8
 
 
 
 
8be32c4
 
 
 
 
 
a3e8717
 
2f935de
d88280b
 
4760da5
2f935de
 
 
 
 
 
 
 
 
 
 
 
 
a3e8717
 
2f935de
ecf0f3e
 
 
d88280b
 
 
 
8be32c4
 
 
 
 
 
 
 
 
ecf0f3e
8be32c4
 
 
ecf0f3e
 
 
 
 
a3e8717
 
 
 
 
 
ecf0f3e
a3e8717
8be32c4
 
 
 
 
 
 
e5fca1e
a3e8717
8be32c4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import requests
from PIL import Image
from io import BytesIO
from datasets import load_dataset

# Prompt template for the guard model. The text to classify is substituted for
# {prompt} via str.format (see evaluate_text). The whole instruction is wrapped
# in [INST] ... [/INST] markers; the model's continuation after "Answer:" is
# searched by evaluate_text() for the word "unsafe".
# NOTE: the literal (including its trailing newline) is the model's input
# contract — edit it only together with the model/prompt format.
TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
<START TEXT>
{prompt}
<END TEXT>

Answer: [/INST]
"""

# Model loading (cached resource: built once, then reused by every rerun).
@st.cache_resource
def load_model():
    """Load the WalledGuard-C tokenizer and causal language model.

    Decorated with ``st.cache_resource`` so the heavy objects are created once
    and shared instead of being reloaded on every Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, model)`` — tokenizer first, model second.
    """
    checkpoint = "walledai/walledguard-c"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint),
    )

# Function to load image from URL
@st.cache_data()
def load_image_from_url(url, timeout=10):
    """Download an image over HTTP and open it with Pillow.

    The result is cached by ``st.cache_data``, so the same URL is not
    re-downloaded on every rerun.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely and freeze the app.

    Returns:
        PIL.Image.Image: The decoded image.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors directly instead of letting Pillow fail with an
    # opaque "cannot identify image file" on an HTML error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))

# Load dataset
@st.cache_data
def load_example_dataset(n=10):
    """Return the first *n* prompts of the XSTest ``train`` split.

    Args:
        n: Number of example prompts to return (default 10, the original
            behaviour). Values larger than the split size return every prompt.

    Returns:
        list[str]: The ``prompt`` field of the first *n* rows.
    """
    train = load_dataset("walledai/XSTest", split="train")
    # Slice rows first, then pick the column: only n rows are materialized and
    # the result is a plain list (cheap for st.cache_data to pickle) rather
    # than a full column object that is sliced afterwards.
    return list(train[:n]["prompt"])

# Evaluation function
def evaluate_text(user_input):
    """Classify *user_input* as ``'safe'`` or ``'unsafe'`` with the guard model.

    Reads the ``(tokenizer, model)`` pair stored in
    ``st.session_state.model_and_tokenizer``, fills ``TEMPLATE`` with the text
    and greedily generates up to 20 new tokens.

    Args:
        user_input: Text to classify.

    Returns:
        ``'unsafe'`` if the generated answer contains "unsafe"
        (case-insensitive), ``'safe'`` otherwise, or ``None`` when
        *user_input* is empty, ``None`` or whitespace-only.
    """
    # Whitespace-only text has nothing to classify; treat it like empty input
    # so the caller shows its "please enter some text" warning.
    if not user_input or not user_input.strip():
        return None

    # Get model and tokenizer from session state (load_model order).
    tokenizer, model = st.session_state.model_and_tokenizer

    # Calling the tokenizer (rather than .encode) also yields an explicit
    # attention mask. Without one, generate() infers the mask from
    # pad_token_id and would mask out any prompt token whose id equals the
    # pad id (0 here), silently changing what the model sees.
    inputs = tokenizer(TEMPLATE.format(prompt=user_input), return_tensors="pt")
    input_ids = inputs["input_ids"]
    output = model.generate(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
        max_new_tokens=20,
        pad_token_id=0,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = input_ids.shape[-1]
    answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # "unsafe" contains "safe", so look for the more specific label.
    return 'unsafe' if 'unsafe' in answer.lower() else 'safe'

# Streamlit app
st.title("Text Safety Evaluator")

# Load model and tokenizer once and store in session state.
# NOTE: load_model() returns (tokenizer, model) in that order; evaluate_text
# unpacks it the same way despite the key's name.
if 'model_and_tokenizer' not in st.session_state:
    st.session_state.model_and_tokenizer = load_model()

# Load example dataset
example_prompts = load_example_dataset()

# Display example prompts. The buttons are rendered before the text area, so
# writing the text area's session-state key here is allowed and is picked up
# by the widget later in this same run.
st.subheader("Example Inputs:")
for i, prompt in enumerate(example_prompts):
    if st.button(f"Example {i+1}", key=f"example_{i}"):
        st.session_state.user_input = prompt

# User input. Keyed widget: its value lives in st.session_state.user_input, so
# clicking an example always replaces the text — even the same example again
# after the user has edited it (an un-keyed widget with value= cannot do that).
user_input = st.text_area("Enter the text you want to evaluate:",
                          height=100,
                          key="user_input")

# Placeholder for the result; it holds exactly one element.
result_container = st.empty()

if st.button("Evaluate"):
    prediction = evaluate_text(user_input)
    if prediction:
        # st.empty() keeps only its last element, so put the heading and the
        # verdict in one container instead of overwriting the heading.
        with result_container.container():
            st.subheader("Evaluation Result:")
            st.write(f"The text is evaluated as: **{prediction.upper()}**")
    else:
        result_container.warning("Please enter some text to evaluate.")

# Footer: logo and product note, centred in the middle of a 1:2:1 column grid.
# Streamlit re-executes this script on every interaction and removes elements
# that a run does not emit, so the footer must be drawn on every run; guarding
# it with a session-state "already displayed" flag makes it disappear after
# the first click. load_image_from_url is cached, so the logo is not
# re-downloaded on each rerun.
_, logo_col, _ = st.columns([1, 2, 1])
with logo_col:
    logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
    logo = load_image_from_url(logo_url)
    # use_column_width=True takes precedence over width.
    # NOTE(review): use_column_width is deprecated in newer Streamlit releases;
    # migrate once the pinned Streamlit version is confirmed.
    st.image(logo, use_column_width=True, width=500)

# Information about Walled Guard Advanced, in its own centred column.
_, info_col, _ = st.columns([1, 2, 1])
with info_col:
    st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at admin@walled.ai for more information.")