import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import requests
from PIL import Image
from io import BytesIO
# Define the template
TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
<START TEXT>
{prompt}
<END TEXT>
Carefully check!
Answer: [/INST]
"""
# Load the model and tokenizer
@st.cache_resource
def load_model():
    model_name = "walledai/walledguard-c"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
tokenizer, model = load_model()
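# A minimal non-Streamlit sanity check of the same pipeline (kept commented out
# so it does not run inside the app; the example prompt below is illustrative):
#
#   _ids = tokenizer.encode(TEMPLATE.format(prompt="How can I pick a lock?"), return_tensors="pt")
#   _out = model.generate(input_ids=_ids, max_new_tokens=20, pad_token_id=0)
#   print(tokenizer.decode(_out[0][_ids.shape[-1]:], skip_special_tokens=True))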
# Function to load image from URL
@st.cache_data
def load_image_from_url(url):
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    return img
# Streamlit app
st.title("Text Safety Evaluator")
# User input
user_input = st.text_area("Enter the text you want to evaluate:", height=100)
if st.button("Evaluate"):
    if user_input:
        # Prepare input
        input_ids = tokenizer.encode(TEMPLATE.format(prompt=user_input), return_tensors="pt")
        # Generate output
        output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
        # Decode output
        prompt_len = input_ids.shape[-1]
        output_decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        # Determine prediction
        prediction = 'unsafe' if 'unsafe' in output_decoded.lower() else 'safe'
        # Display results
        st.subheader("Evaluation Result:")
        st.write(f"The text is evaluated as: **{prediction.upper()}**")
    else:
        st.warning("Please enter some text to evaluate.")
# Add logo at the bottom center
#st.markdown("---")
col1, col2, col3 = st.columns([1,2,1])
with col2:
    logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
    logo = load_image_from_url(logo_url)
    st.image(logo, use_column_width=True)  # fills the center column; pass width= instead for a fixed size
# Add information about Walled Guard Advanced
#st.markdown("---")
col1, col2, col3 = st.columns([1,2,1])
with col2:
    st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at admin@walled.ai for more information.")