roberta-base-mr / apps /classifier.py
hassiahk's picture
Added different pages for MLM and Classification
e35b6a7
raw history blame
No virus
1.09 kB
import json
import streamlit as st
from transformers import AutoTokenizer, RobertaForSequenceClassification, pipeline
# Load the app configuration once at import time. cfg["models"] maps the
# sidebar display names used in app() to the checkpoint name/path passed to
# load_model(). The relative path is resolved against the process's working
# directory. JSON text is UTF-8, so decode it explicitly rather than relying
# on the platform's default encoding.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_pipeline(model_name_or_path):
    """Build a text-classification pipeline for one checkpoint and cache it.

    The cache is keyed only on ``model_name_or_path``, so the tokenizer and
    model weights are loaded once per selected model. They are not reloaded
    for every new input text. ``allow_output_mutation=True`` is the documented
    st.cache setting for returning a heavyweight object (the pipeline) without
    Streamlit checking it for mutation.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = RobertaForSequenceClassification.from_pretrained(model_name_or_path)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)


def load_model(input_text, model_name_or_path):
    """Classify ``input_text`` with the checkpoint at ``model_name_or_path``.

    Despite its name, this returns a prediction, not a model. The underlying
    pipeline is obtained from ``_load_pipeline``, which caches it per
    checkpoint, so only inference runs on each call.

    Args:
        input_text: Text to classify.
        model_name_or_path: Hugging Face model id or local path of a RoBERTa
            sequence-classification checkpoint.

    Returns:
        The pipeline output for ``input_text``. It is a list whose first
        element is a dict with a ``"label"`` key, which is how ``app()``
        consumes it.
    """
    nlp = _load_pipeline(model_name_or_path)
    return nlp(input_text)
def app():
    """Render the Marathi text-classification page.

    The sidebar selects a classifier by display name. That name is looked up
    in ``cfg["models"]`` to get the checkpoint. When "Predict" is clicked, the
    entered text is classified and the first result's label is shown.
    """
    st.title("RoBERTa Marathi")
    classifier = st.sidebar.selectbox(
        "Select a Model", index=0, options=["Indic NLP", "iNLTK"]
    )
    model_name_or_path = cfg["models"][classifier]

    input_text = st.text_input("Text:")
    predict_button = st.button("Predict")

    if predict_button:
        # Classifying an empty or whitespace-only string yields a meaningless
        # label, so prompt for input instead of running the model.
        if not input_text.strip():
            st.warning("Please enter some text to classify.")
            return
        with st.spinner("Generating prediction..."):
            result = load_model(input_text, model_name_or_path)
        st.markdown("**Predicted label:** " + result[0]["label"])