File size: 2,781 Bytes
b438028 e11691c b438028 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
# coding:utf-8
"""
Filename: app.py
Author: @DvdNss
Created on 12/18/2021
"""
import os
import gdown as gdown
import nltk
import streamlit as st
from nltk.tokenize import sent_tokenize
from source.pipeline import MultiLabelPipeline, inputs_to_dataset
def download_models(ids):
    """
    Download the NLTK sentence tokenizer and every model not yet stored locally.
    :param ids: mapping of model name -> Google Drive file id; each model is
        saved to ``model/{name}.pt`` and skipped if that file already exists
    :return: None
    """
    # Download sentence tokenizer (needed by sent_tokenize in "split" mode)
    nltk.download('punkt')
    # Make sure the target directory exists before downloading into it,
    # otherwise the first download on a fresh checkout has nowhere to write
    os.makedirs("model", exist_ok=True)
    # Download each model from Drive only if it is not stored locally
    for name, file_id in ids.items():
        output = f"model/{name}.pt"
        if not os.path.isfile(output):
            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url=url, output=output)
@st.cache
def load_labels():
    """
    Return the emotion label names displayed by the app.
    :return: list of the 28 label strings, "neutral" last
    """
    emotions = [
        "admiration", "amusement", "anger", "annoyance", "approval",
        "caring", "confusion", "curiosity", "desire", "disappointment",
        "disapproval", "disgust", "embarrassment", "excitement", "fear",
        "gratitude", "grief", "joy", "love", "nervousness",
        "optimism", "pride", "realization", "relief", "remorse",
        "sadness", "surprise", "neutral",
    ]
    return emotions
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build the MultiLabelPipeline for a checkpoint, memoized by Streamlit.
    :param model_path: filesystem path of the model checkpoint (e.g. ``model/<name>.pt``)
    :return: the MultiLabelPipeline instance for that path
    """
    # allow_output_mutation=True: st.cache does not hash/check the returned
    # pipeline object for mutation between reruns
    return MultiLabelPipeline(model_path=model_path)
# --- Page layout, heading and description ---
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Model name -> Google Drive file id of its checkpoint
ids = {'perceiver-go-emotions': '15m-p0Pwwnh3STi7zXHkKr9HFxliGJikU'}
labels = load_labels()

# Fetch any checkpoint (and the tokenizer) missing from the local model/ folder
download_models(ids)

# Show the list of labels the classifier can predict
st.markdown("__Labels:__ " + ", ".join(labels))
# Model selection: text input on the left, options on the right
left, right = st.columns([4, 2])
inputs = left.text_area('', max_chars=4096, placeholder='Write something here to see what happens! ')
model_path = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
split = right.checkbox('Split into sentences')
model = load_model(model_path=f"model/{model_path}.pt")
# Show which device the pipeline runs on
right.write(model.device)

# Only run inference when the text contains something other than whitespace
if inputs.strip():
    with st.spinner('Processing text... This may take a while.'):
        # Classify each sentence separately, or the whole text as a single input
        texts = sent_tokenize(inputs) if split else [inputs]
        left.write(model(inputs_to_dataset(texts), batch_size=1))
|