# Source: bentrevett, commit f85c94c ("added default text").
import streamlit as st
import transformers
import matplotlib.pyplot as plt
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_pipe():
    """Build the emotion text-classification pipeline.

    The result is cached by Streamlit, so the checkpoint is loaded once and
    reused across script reruns. The spinner is disabled here because the
    caller wraps this call in its own ``st.spinner``.

    Returns:
        A ``transformers`` ``text-classification`` pipeline that reports a
        score for every label (``return_all_scores=True``) and truncates
        inputs that are too long for the model (``truncation=True``).
    """
    checkpoint = "joeddav/distilbert-base-uncased-go-emotions-student"
    classifier = transformers.AutoModelForSequenceClassification.from_pretrained(checkpoint)
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    return transformers.pipeline(
        'text-classification',
        model=classifier,
        tokenizer=tokenizer,
        return_all_scores=True,
        truncation=True,
    )
def sort_predictions(predictions):
    """Return a new list of prediction dicts ordered by descending ``'score'``.

    The input iterable is not modified. Entries with equal scores keep their
    original relative order, because Python's sort is stable even with
    ``reverse=True``.
    """
    ranked = list(predictions)
    ranked.sort(key=lambda entry: entry['score'], reverse=True)
    return ranked
# Page setup: set_page_config must be the first Streamlit command rendered.
st.set_page_config(page_title="Emotion Prediction")
st.title("Emotion Prediction")
st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")

default_text = "I really love using HuggingFace Spaces!"
text = st.text_area('Enter text here:', value=default_text)
# Clicking the button re-runs the script. Predictions are shown whenever the
# text is non-empty, including on first load with the default text. The
# original condition `(submit and non-empty) or non-empty` reduced to exactly
# this, so `submit` is not consulted below.
submit = st.button('Predict')

with st.spinner("Loading model..."):
    pipe = get_pipe()

if text.strip():
    # The pipeline returns one list of {'label', 'score'} dicts per input;
    # a single string was passed, so take the first (only) list.
    prediction = sort_predictions(pipe(text)[0])

    fig, ax = plt.subplots()
    ax.bar(x=range(len(prediction)),
           height=[p['score'] for p in prediction],
           tick_label=[p['label'] for p in prediction])
    # Rotate only the x-axis label ticks. Without axis='x' the y-axis
    # score ticks would be turned sideways too.
    ax.tick_params(axis='x', rotation=90)
    ax.set_ylim(0, 1)

    st.header('Prediction:')
    st.pyplot(fig)
    # pyplot keeps every figure created by plt.subplots() alive until it is
    # closed. Streamlit re-runs this script on each interaction, so close it
    # to avoid accumulating figures.
    plt.close(fig)

    st.header('Raw values:')
    st.json({p['label']: p['score'] for p in prediction})