menikev committed
Commit 78e2935
1 Parent(s): cbf5c64

Update app.py

Files changed (1)
  1. app.py +10 -4
app.py CHANGED
@@ -1,17 +1,23 @@
 import streamlit as st
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import RobertaTokenizer, RobertaModel
+from your_model_file import MDFEND  # Ensure you import your model correctly
 
+# Load model and tokenizer
 @st.cache(allow_output_mutation=True)
 def load_model():
-    tokenizer = AutoTokenizer.from_pretrained("./prediction_sinhala.ipynb")
-    model = AutoModelForSequenceClassification.from_pretrained("./prediction_sinhala.ipynb")
+    tokenizer = RobertaTokenizer.from_pretrained("prediction_sinhala.ipynb")
+    model = MDFEND.from_pretrained("prediction_sinhala.ipynb")
     return model, tokenizer
 
 model, tokenizer = load_model()
 
+# User input
 text_input = st.text_area("Enter text here:")
+
+# Prediction
 if st.button("Predict"):
     inputs = tokenizer(text_input, return_tensors="pt")
-    outputs = model(**inputs)
+    with torch.no_grad():  # Ensure no gradients are computed
+        outputs = model(**inputs)
     prediction = outputs.logits.argmax(-1).item()
     st.write(f"Prediction: {prediction}")