Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import transformers | |
| import matplotlib.pyplot as plt | |
# Streamlit re-executes this whole script on every widget interaction, so
# without a resource cache the model and tokenizer would be re-loaded from
# disk on each rerun. `st.cache_resource` keeps one shared pipeline per
# process; older Streamlit releases that lack it fall back to
# `st.experimental_singleton`, and finally to no caching so the app still runs.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)


@_cache_resource
def get_pipe():
    """Build the emotion text-classification pipeline.

    Loads the ``joeddav/distilbert-base-uncased-go-emotions-student`` model
    and its tokenizer and wraps them in a ``transformers`` text-classification
    pipeline.

    Returns:
        A pipeline configured to return a score for every label
        (``return_all_scores=True``) and to truncate over-long inputs
        (``truncation=True``).
    """
    model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): newer transformers releases deprecate `return_all_scores`
    # in favour of `top_k=None`, but the nesting of the returned list differs
    # between versions and the caller indexes `pipe(text)[0]`. Keep the legacy
    # flag until the pinned transformers version is confirmed.
    pipe = transformers.pipeline(
        'text-classification',
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
        truncation=True,
    )
    return pipe
def sort_predictions(predictions):
    """Return the predictions ordered from highest to lowest ``'score'``.

    Args:
        predictions: An iterable of mappings that each have a ``'score'`` key.

    Returns:
        A new list sorted by descending score. Entries with equal scores keep
        their original relative order, because the sort is stable. The input
        is not modified.
    """
    def score_of(entry):
        return entry['score']

    ranked = sorted(predictions, key=score_of, reverse=True)
    return ranked
st.set_page_config(page_title="Emotion Prediction")
st.title("Emotion Prediction")
st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")

default_text = "I really love using HuggingFace Spaces!"
text = st.text_area('Enter text here:', value=default_text)
# Clicking the button just triggers a Streamlit rerun. A prediction is shown
# for any non-empty text, including the default text on first load, so
# `submit` does not need to gate the branch below.
submit = st.button('Predict')

with st.spinner("Loading model..."):
    pipe = get_pipe()

if text.strip():
    # The code expects `pipe(text)` to return one list of
    # {'label': ..., 'score': ...} dicts per input; take the list for our
    # single input string and rank it by score.
    prediction = sort_predictions(pipe(text)[0])

    fig, ax = plt.subplots()
    ax.bar(
        x=list(range(len(prediction))),
        height=[p['score'] for p in prediction],
        tick_label=[p['label'] for p in prediction],
    )
    # Rotate only the x-axis label names; `tick_params` defaults to
    # axis='both', which would also turn the y-axis score ticks sideways.
    ax.tick_params(axis='x', rotation=90)
    ax.set_ylim(0, 1)
    st.header('Prediction:')
    st.pyplot(fig)
    # The figure has been rendered; close it so figures do not accumulate in
    # pyplot's global figure registry across Streamlit reruns.
    plt.close(fig)

    prediction = {p['label']: p['score'] for p in prediction}
    st.header('Raw values:')
    st.json(prediction)