# Hugging Face Space file by sunwaee — commit 4d679c8 ("removed headers"), 2.79 kB.
import os
import gdown as gdown
import nltk
import streamlit as st
from nltk.tokenize import sent_tokenize
from source.pipeline import MultiLabelPipeline, inputs_to_dataset
def download_models(ids):
    """
    Download the NLTK sentence tokenizer and any model checkpoint missing locally.

    Each model is stored as ``model/<name>.pt``; a checkpoint that already
    exists on disk is not downloaded again.
    :param ids: mapping of model name -> Google Drive file id
    :return: None
    """
    # Download sentence tokenizer (needed by nltk's sent_tokenize)
    nltk.download('punkt')
    # gdown writes into the directory of ``output``, so that directory must
    # exist before the first download (e.g. on a fresh checkout).
    os.makedirs("model", exist_ok=True)
    # Download model from drive if not stored locally
    for name, file_id in ids.items():
        output = f"model/{name}.pt"
        if not os.path.isfile(output):
            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url=url, output=output)
@st.cache
def load_labels():
    """
    Return the model's label names in their fixed display order.

    The emotion names come first and ``"neutral"`` is always last.
    :return: a list of 28 label strings
    """
    emotions = (
        "admiration", "amusement", "anger", "annoyance", "approval",
        "caring", "confusion", "curiosity", "desire", "disappointment",
        "disapproval", "disgust", "embarrassment", "excitement", "fear",
        "gratitude", "grief", "joy", "love", "nervousness",
        "optimism", "pride", "realization", "relief", "remorse",
        "sadness", "surprise",
    )
    return list(emotions) + ["neutral"]
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build the multi-label classification pipeline for one checkpoint.

    Streamlit caches the result per ``model_path``. ``allow_output_mutation``
    disables the check that the returned object was not mutated, so the same
    pipeline instance is reused across reruns.
    :param model_path: path to the model checkpoint file
    :return: a ``MultiLabelPipeline`` built from ``model_path``
    """
    return MultiLabelPipeline(model_path=model_path)
# Page config
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Variables: model name -> Google Drive file id (the id is read from Streamlit secrets)
ids = {'perceiver-go-emotions': st.secrets['model_key']}
labels = load_labels()

# Download all models from drive
download_models(ids)

# Display labels
st.markdown(f"__Labels:__ {', '.join(labels)}")

# Model selection
left, right = st.columns([4, 2])
inputs = left.text_area('', max_chars=4096, value='This is a space about multiclass emotion classification. Write '
                                                  'something here to see what happens!')
# The selectbox yields a model *name*; the checkpoint path is derived from it below.
model_name = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
split = right.checkbox('Split into sentences')

model = load_model(model_path=f"model/{model_name}.pt")
right.write(model.device)

# Run inference only when the text area contains something other than
# whitespace (same condition as `not inputs.isspace() and inputs != ""`).
if inputs.strip():
    with st.spinner('Processing text... This may take a while.'):
        # One dataset row per sentence when splitting, otherwise the whole text as a single row.
        texts = sent_tokenize(inputs) if split else [inputs]
        left.write(model(inputs_to_dataset(texts), batch_size=1))