# TestApp / app.py
# Hugging Face Space page header captured with this file (author: menikev).
# Last commit: 41f1f8b "Update app.py" (verified); security scan: no virus; size: 1.77 kB.
import streamlit as st
import torch
from prediction_sinhala import MDFEND, TokenizerFromPreTrained
# Configuration for the model checkpoint and the tokenizer.

# Trained weights (a state_dict); a relative path, so it resolves against the
# process working directory, not the location of this file.
MODEL_SAVE_PATH: str = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
# Pre-trained model id passed to both the tokenizer and MDFEND.
BERT_MODEL_NAME: str = "sinhala-nlp/sinbert-sold-si"
# Domain count the MDFEND model is constructed with; must match the checkpoint.
DOMAIN_NUM: int = 3
# Passed to the tokenizer; presumably the maximum token length (TODO confirm).
MAX_LEN: int = 160
# Not referenced by the code visible in this file.
BATCH_SIZE: int = 100
# Load model and tokenizer
@st.cache(allow_output_mutation=True)
def load_model():
# Load the tokenizer from the pre-trained model name
tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)
# Initialize and load the custom model from saved state
model = MDFEND(BERT_MODEL_NAME, DOMAIN_NUM, expert_num=18, mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400])
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=torch.device('cpu')))
model.eval() # Set the model to evaluation mode
return model, tokenizer
model, tokenizer = load_model()

# User input
text_input = st.text_area("Enter text here:")

# Prediction: st.button returns True only on the rerun triggered by the click.
if st.button("Predict"):
    # Treat whitespace-only text as empty rather than sending it to the model.
    if text_input and text_input.strip():
        # Process the input text through the custom tokenizer.
        # NOTE(review): assumes tokenize() returns token ids for a single text
        # that torch.tensor can convert -- confirm against prediction_sinhala.
        inputs = tokenizer.tokenize(text_input)

        # torch.nn.Module has no built-in `.device` attribute, so read the device
        # from the model's weights; the input batch must live on the same device.
        device = next(model.parameters()).device
        # Convert to tensor, add a batch dimension of 1, and move to the model's device.
        inputs = torch.tensor(inputs).unsqueeze(0).to(device)

        with torch.no_grad():  # inference only; no gradient bookkeeping
            output_prob = model.predict(inputs)

        # NOTE(review): assumes predict() returns a single probability of the
        # "offensive" class for this one-item batch -- confirm.
        prediction = 1 if output_prob >= 0.5 else 0
        result = "offensive" if prediction == 1 else "not offensive"
        st.write(f"Prediction: {result}")
    else:
        st.error("Please enter some text to predict.")