|
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
# Prefer st.cache_resource where this Streamlit build provides it; otherwise
# fall back to the legacy st.cache call this app used originally, so the app
# keeps working on older installations.
_cache_model = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_model
def load_model(model_dir: str = "./my_model"):
    """Load the sequence-classification model and its tokenizer.

    The result is cached by Streamlit, so the load from disk is not repeated
    on every script rerun.

    Args:
        model_dir: Path or identifier handed to both ``from_pretrained``
            calls. Defaults to the local ``./my_model`` directory.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    # Both objects come from the same location so they stay in sync.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    return model, tokenizer
|
|
|
# Fetch the classifier and tokenizer (served from Streamlit's cache after the
# first run of the script).
model, tokenizer = load_model()

text_input = st.text_area("Enter text here:")

if st.button("Predict"):
    if not text_input.strip():
        # Nothing meaningful to classify: ask for input rather than reporting
        # a prediction for an empty string.
        st.warning("Please enter some text before predicting.")
    else:
        # truncation=True clips over-long input to the tokenizer's configured
        # maximum length instead of letting the forward pass fail on it.
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True)

        # Inference only: no gradients are needed, so skip autograd tracking.
        with torch.no_grad():
            outputs = model(**inputs)

        # Index of the highest-scoring logit, as a plain Python int.
        prediction = outputs.logits.argmax(-1).item()
        st.write(f"Prediction: {prediction}")
|
|