RishabhBhardwaj commited on
Commit
d88280b
1 Parent(s): 7b8a95d

prevent model loads

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -14,7 +14,7 @@ Answer: [/INST]
14
  """
15
 
16
  # Load the model and tokenizer
17
- @st.cache_data(persist="disk")
18
  def load_model():
19
  model_name = "walledai/walledguard-c"
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -32,8 +32,8 @@ def load_image_from_url(url):
32
  @st.experimental_fragment
33
  def evaluate_text(user_input, result_container):
34
  if user_input:
35
- # Load model and tokenizer
36
- tokenizer, model = load_model()
37
 
38
  # Prepare input
39
  input_ids = tokenizer.encode(TEMPLATE.format(prompt=user_input), return_tensors="pt")
@@ -59,6 +59,10 @@ def evaluate_text(user_input, result_container):
59
  # Streamlit app
60
  st.title("Text Safety Evaluator")
61
 
 
 
 
 
62
  # User input
63
  user_input = st.text_area("Enter the text you want to evaluate:", height=100)
64
 
@@ -73,9 +77,4 @@ col1, col2, col3 = st.columns([1,2,1])
73
  with col2:
74
  logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
75
  logo = load_image_from_url(logo_url)
76
- st.image(logo, use_column_width=True, width=500) # Adjust the width as needed
77
-
78
- # Add information about Walled Guard Advanced
79
- col1, col2, col3 = st.columns([1,2,1])
80
- with col2:
81
- st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at admin@walled.ai for more information.")
 
14
  """
15
 
16
  # Load the model and tokenizer
17
+ @st.cache_resource
18
  def load_model():
19
  model_name = "walledai/walledguard-c"
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
32
  @st.experimental_fragment
33
  def evaluate_text(user_input, result_container):
34
  if user_input:
35
+ # Get model and tokenizer from session state
36
+ tokenizer, model = st.session_state.model_and_tokenizer
37
 
38
  # Prepare input
39
  input_ids = tokenizer.encode(TEMPLATE.format(prompt=user_input), return_tensors="pt")
 
59
  # Streamlit app
60
  st.title("Text Safety Evaluator")
61
 
62
+ # Load model and tokenizer once and store in session state
63
+ if 'model_and_tokenizer' not in st.session_state:
64
+ st.session_state.model_and_tokenizer = load_model()
65
+
66
  # User input
67
  user_input = st.text_area("Enter the text you want to evaluate:", height=100)
68
 
 
77
  with col2:
78
  logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
79
  logo = load_image_from_url(logo_url)
80
+ st.image(logo, use_column_