File size: 761 Bytes
21ac434
fdd0b8f
78e2935
fdd0b8f
21ac434
78e2935
21ac434
 
fdd0b8f
 
 
21ac434
 
 
 
78e2935
21ac434
78e2935
 
21ac434
 
78e2935
 
21ac434
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import streamlit as st
import torch  
from transformers import RobertaTokenizer, RobertaModel
from prediction_sinhala import MDFEND  

# Load model and tokenizer
# ``st.cache`` is deprecated in current Streamlit releases in favour of
# ``st.cache_resource`` for unserialisable objects such as models. Prefer the
# new decorator and fall back to the legacy one on older installs.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_model():
    """Load the classifier and its tokenizer from ``./prediction_sinhala/``.

    The function is wrapped in Streamlit's resource cache so the objects are
    built once and reused instead of being reloaded on every script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``MDFEND`` model and the
        ``RobertaTokenizer``, both loaded from the same local directory.
    """
    tokenizer = RobertaTokenizer.from_pretrained("./prediction_sinhala/")
    model = MDFEND.from_pretrained("./prediction_sinhala/")
    return model, tokenizer

model, tokenizer = load_model()

# User input
text_input = st.text_area("Enter text here:")

# Prediction
if st.button("Predict"):
    if not text_input.strip():
        # Nothing meaningful to classify: tell the user instead of running
        # the model on an empty sequence and showing a misleading label.
        st.warning("Please enter some text before predicting.")
    else:
        # truncation=True caps the sequence at the tokenizer's configured
        # maximum length, so pasting a long document cannot overflow the
        # encoder's position embeddings and crash the app.
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
        with torch.no_grad():  # Inference only; no gradients are needed
            outputs = model(**inputs)
        # Index of the highest-scoring class along the last logits dimension.
        prediction = outputs.logits.argmax(-1).item()
        st.write(f"Prediction: {prediction}")