""" |
Filename: app.py |
Author: @DvdNss |
Created on 12/18/2021 |
""" |
import os |
import gdown as gdown |
import nltk |
import streamlit as st |
from nltk.tokenize import sent_tokenize |
from source.pipeline import MultiLabelPipeline, inputs_to_dataset |
def download_models(ids):
    """
    Download every model checkpoint that is not already present locally.

    Also fetches the NLTK 'punkt' tokenizer data. Checkpoints are stored as
    ``model/<name>.pt``; a checkpoint whose file already exists is skipped.

    :param ids: mapping of model name to Google Drive file id
    :return: None
    """
    nltk.download('punkt')

    # Make sure the target folder exists before downloading into it, so a
    # fresh checkout without a "model" directory does not fail on write.
    os.makedirs("model", exist_ok=True)

    for name, file_id in ids.items():
        target = f"model/{name}.pt"
        if os.path.isfile(target):
            continue
        gdown.download(url=f"https://drive.google.com/uc?id={file_id}", output=target)
@st.cache
def load_labels():
    """
    Return the model's label names.

    :return: list of label strings, in the order they are declared below
    """
    emotions = (
        "admiration",
        "amusement",
        "anger",
        "annoyance",
        "approval",
        "caring",
        "confusion",
        "curiosity",
        "desire",
        "disappointment",
        "disapproval",
        "disgust",
        "embarrassment",
        "excitement",
        "fear",
        "gratitude",
        "grief",
        "joy",
        "love",
        "nervousness",
        "optimism",
        "pride",
        "realization",
        "relief",
        "remorse",
        "sadness",
        "surprise",
        "neutral",
    )
    # Callers receive a fresh list built from the fixed tuple above.
    return list(emotions)
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build the classification pipeline for a checkpoint (cached by Streamlit).

    :param model_path: path to the model checkpoint file
    :return: a MultiLabelPipeline created from ``model_path``
    """
    return MultiLabelPipeline(model_path=model_path)
# --- Page setup ---
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Model name -> identifier read from the app secrets; passed to download_models.
ids = {'perceiver-go-emotions': st.secrets['model_key']}
labels = load_labels()
download_models(ids)

st.markdown(f"__Labels:__ {', '.join(labels)}")

# --- Widgets: text input on the left, model options on the right ---
left, right = st.columns([4, 2])
inputs = left.text_area(
    '',
    max_chars=4096,
    value='This is a space about multiclass emotion classification. Write '
          'something here to see what happens!',
)
model_path = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
split = right.checkbox('Split into sentences')

model = load_model(model_path=f"model/{model_path}.pt")
right.write(model.device)

# --- Inference: only run when the text box holds something other than whitespace ---
has_text = inputs != "" and not inputs.isspace()
if has_text:
    with st.spinner('Processing text... This may take a while.'):
        # One entry per sentence when splitting is enabled, otherwise the whole text.
        texts = sent_tokenize(inputs) if split else [inputs]
        left.write(model(inputs_to_dataset(texts), batch_size=1))