|
import os |
|
|
|
import gdown as gdown |
|
import nltk |
|
import streamlit as st |
|
from nltk.tokenize import sent_tokenize |
|
|
|
from source.pipeline import MultiLabelPipeline, inputs_to_dataset |
|
|
|
|
|
def download_models(ids, model_dir="model"):
    """
    Fetch the NLTK sentence-tokenizer data and any missing model checkpoints.

    Each checkpoint is stored as ``<model_dir>/<name>.pt``. Checkpoints that
    already exist on disk are not downloaded again.

    :param ids: mapping of model name -> Google Drive file id
    :param model_dir: directory the checkpoints are written to; defaults to
        ``"model"``, which is where the app loads them from
    :return: None
    :raises RuntimeError: if a checkpoint is still missing after its download
    """
    # Tokenizer data used by nltk.tokenize.sent_tokenize ("Split into sentences").
    nltk.download('punkt')

    # Create the target directory up front so the download does not depend on
    # it already existing (e.g. on a fresh deployment).
    os.makedirs(model_dir, exist_ok=True)

    for name, file_id in ids.items():
        output = os.path.join(model_dir, f"{name}.pt")
        if os.path.isfile(output):
            continue  # already downloaded on a previous run

        url = f"https://drive.google.com/uc?id={file_id}"
        gdown.download(url=url, output=output)

        # Fail loudly here instead of letting load_model crash later on a
        # missing file with a less helpful error.
        if not os.path.isfile(output):
            raise RuntimeError(f"Failed to download model '{name}' to {output}")
|
|
|
|
|
@st.cache
def load_labels():
    """
    Return the emotion label names displayed by the app.

    NOTE(review): presumably this order matches the model's output units —
    confirm against the training label mapping before reordering.

    :return: list of 28 label name strings
    """
    emotion_names = (
        "admiration", "amusement", "anger", "annoyance", "approval",
        "caring", "confusion", "curiosity", "desire", "disappointment",
        "disapproval", "disgust", "embarrassment", "excitement", "fear",
        "gratitude", "grief", "joy", "love", "nervousness",
        "optimism", "pride", "realization", "relief", "remorse",
        "sadness", "surprise", "neutral",
    )
    # Hand back a list, as callers have always received.
    return list(emotion_names)
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build a ``MultiLabelPipeline`` for a checkpoint and memoize it per path.

    ``allow_output_mutation=True`` makes ``st.cache`` skip its check on the
    returned object, which would otherwise hash the pipeline on every rerun.

    :param model_path: filesystem path to the ``.pt`` checkpoint
    :return: the (cached) pipeline instance
    """
    return MultiLabelPipeline(model_path=model_path)
|
|
|
|
|
|
|
# Page configuration must be the first Streamlit command executed by the script.
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Model name -> Google Drive file id (the id itself is kept in Streamlit secrets).
ids = {'perceiver-go-emotions': st.secrets['model_key']}
labels = load_labels()

download_models(ids)

st.markdown(f"__Labels:__ {', '.join(labels)}")

left, right = st.columns([4, 2])
inputs = left.text_area('', max_chars=4096, value='This is a space about multiclass emotion classification. Write '
                                                  'something here to see what happens!')
model_path = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
split = right.checkbox('Split into sentences')

model = load_model(model_path=f"model/{model_path}.pt")
right.write(model.device)

# Only run inference on non-blank input. str.strip() removes exactly the
# characters str.isspace() considers whitespace, so this matches the old
# "not empty and not all-whitespace" check.
if inputs.strip():
    with st.spinner('Processing text... This may take a while.'):
        # Classify each sentence separately, or the whole text as one item.
        texts = sent_tokenize(inputs) if split else [inputs]
        left.write(model(inputs_to_dataset(texts), batch_size=1))
|
|