"""Streamlit front end for the ``prediction_sinhala`` MDFEND classifier.

Loads the tokenizer and model from ``MODEL_DIR`` once per server process,
then classifies the text the user enters and displays the predicted class
index.
"""
import streamlit as st
import torch
from transformers import RobertaTokenizer, RobertaModel
from prediction_sinhala import MDFEND

# Directory holding both the tokenizer files and the MDFEND weights.
MODEL_DIR = "./prediction_sinhala/"

# `st.cache(allow_output_mutation=True)` is deprecated; `st.cache_resource`
# (Streamlit >= 1.18) is its replacement for shared, unhashable objects such
# as models and tokenizers. Fall back to the legacy decorator on older
# Streamlit versions that do not provide `cache_resource`.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the classifier and its tokenizer from ``MODEL_DIR``.

    Cached so the expensive load runs once instead of on every Streamlit
    rerun (Streamlit re-executes the whole script on each interaction).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = RobertaTokenizer.from_pretrained(MODEL_DIR)
    model = MDFEND.from_pretrained(MODEL_DIR)
    # Inference only: disable training-time behaviour such as dropout.
    # NOTE(review): assumes MDFEND is a torch nn.Module (it is called under
    # torch.no_grad() below) — confirm.
    model.eval()
    return model, tokenizer


def predict(text):
    """Return the index of the highest-scoring class for *text*.

    Uses the module-level ``model`` and ``tokenizer``. Input longer than the
    tokenizer's maximum length is truncated rather than passed to the model
    at a length it cannot accept.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # Ensure no gradients are computed
        outputs = model(**inputs)
    # Assumes the model output exposes `.logits` with the class scores on the
    # last axis (batch size is 1 here) — TODO confirm against MDFEND.
    return outputs.logits.argmax(-1).item()


# Load model and tokenizer (cached across reruns).
model, tokenizer = load_model()

# User input
text_input = st.text_area("Enter text here:")

# Prediction
if st.button("Predict"):
    if not text_input.strip():
        # Running the model on only special tokens would yield a meaningless label.
        st.warning("Please enter some text before predicting.")
    else:
        prediction = predict(text_input)
        st.write(f"Prediction: {prediction}")