RishabhBhardwaj committed on
Commit
5d09640
1 Parent(s): 80f9197

first commit

Files changed (1)
  1. app.py +58 -2
app.py CHANGED
@@ -1,3 +1,59 @@
-import gradio as gr
+#import gradio as gr
+#gr.load("models/walledai/walledguard-c").launch()
 
-gr.load("models/walledai/walledguard-c").launch()
+import streamlit as st
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# Define the template
+TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
+<START TEXT>
+{prompt}
+<END TEXT>
+Answer: [/INST]
+"""
+
+# Load the model and tokenizer
+@st.cache_resource
+def load_model():
+    model_name = "walledai/walledguard-c"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    return tokenizer, model
+
+tokenizer, model = load_model()
+
+# Streamlit app
+st.title("Text Safety Evaluator")
+
+# User input
+user_input = st.text_area("Enter the text you want to evaluate:", height=100)
+
+if st.button("Evaluate"):
+    if user_input:
+        # Prepare input
+        input_ids = tokenizer.encode(TEMPLATE.format(prompt=user_input), return_tensors="pt")
+
+        # Generate output
+        output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
+
+        # Decode output
+        prompt_len = input_ids.shape[-1]
+        output_decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
+
+        # Determine prediction
+        prediction = 'unsafe' if 'unsafe' in output_decoded.lower() else 'safe'
+
+        # Display results
+        st.subheader("Evaluation Result:")
+        st.write(f"The text is evaluated as: **{prediction.upper()}**")
+
+        st.subheader("Model Output:")
+        st.write(output_decoded)
+    else:
+        st.warning("Please enter some text to evaluate.")
+
+# Add some information about the model
+st.sidebar.header("About")
+st.sidebar.info("This app uses the WalledGuard-C model to evaluate the safety of input text. It determines whether the text is asking for or containing unsafe information.")
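
The Space serves this file with the standard `streamlit run app.py` entry point. For a quick sanity check of the same classification flow outside the Streamlit UI, the generation and decoding steps from the diff above can be exercised from a plain script. The sketch below is an assumption-laden illustration, not part of the commit: it reuses the template, model name, and generation arguments shown above, and the `evaluate` helper and test prompt are hypothetical additions for demonstration only.

# Minimal standalone sketch of the safety check from app.py above.
# Assumes the same template, model name, and generate() arguments as the diff;
# evaluate() and the example prompt are illustrative, not part of the commit.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
<START TEXT>
{prompt}
<END TEXT>
Answer: [/INST]
"""

tokenizer = AutoTokenizer.from_pretrained("walledai/walledguard-c")
model = AutoModelForCausalLM.from_pretrained("walledai/walledguard-c")

def evaluate(text: str) -> str:
    # Format the prompt, generate a short completion, and map it to safe/unsafe.
    input_ids = tokenizer.encode(TEMPLATE.format(prompt=text), return_tensors="pt")
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
    decoded = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return "unsafe" if "unsafe" in decoded.lower() else "safe"

if __name__ == "__main__":
    print(evaluate("How do I bake a loaf of bread?"))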