File size: 2,769 Bytes
b438028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277c0bd
b438028
 
 
 
 
 
 
 
 
 
e11691c
b438028
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# coding:utf-8
"""
Filename: app.py
Author: @DvdNss

Created on 12/18/2021
"""
import os

import gdown as gdown
import nltk
import streamlit as st
from nltk.tokenize import sent_tokenize

from source.pipeline import MultiLabelPipeline, inputs_to_dataset


def download_models(ids):
    """
    Download all models.

    Fetches the NLTK ``punkt`` sentence-tokenizer data, then downloads each
    model checkpoint from Google Drive to ``model/<name>.pt``. Checkpoints
    that already exist on disk are not downloaded again.

    :param ids: mapping of model name -> Google Drive file id
    :return: None
    """

    # Download sentence tokenizer data (used by nltk's sent_tokenize)
    nltk.download('punkt')

    # Make sure the target folder exists before the first download; otherwise
    # writing model/<name>.pt can fail on a fresh checkout without model/.
    os.makedirs("model", exist_ok=True)

    # Download model from drive if not stored locally
    for name, file_id in ids.items():
        output = f"model/{name}.pt"
        if not os.path.isfile(output):
            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url=url, output=output)


@st.cache
def load_labels():
    """
    Return the emotion label names.

    Twenty-seven emotion names followed by ``neutral`` (28 labels in total).
    NOTE(review): presumably this order matches the model's output columns;
    confirm against the checkpoint before reordering anything.

    :return: list of label strings
    """

    emotions = [
        "admiration", "amusement", "anger", "annoyance", "approval",
        "caring", "confusion", "curiosity", "desire", "disappointment",
        "disapproval", "disgust", "embarrassment", "excitement", "fear",
        "gratitude", "grief", "joy", "love", "nervousness",
        "optimism", "pride", "realization", "relief", "remorse",
        "sadness", "surprise",
    ]
    return emotions + ["neutral"]


@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build a MultiLabelPipeline for the checkpoint at ``model_path``.

    Wrapped in Streamlit's cache so repeated calls with the same path reuse
    the cached pipeline. ``allow_output_mutation=True`` tells Streamlit not
    to hash the returned object to detect mutation.

    :param model_path: path to the model checkpoint file
    :return: MultiLabelPipeline instance
    """

    return MultiLabelPipeline(model_path=model_path)


# Page config
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Variables: model name -> Google Drive file id, read from Streamlit secrets
ids = {'perceiver-go-emotions': st.secrets['model_key']}
labels = load_labels()

# Download all models from drive (checkpoints already on disk are skipped)
download_models(ids)

# Display labels
st.markdown(f"__Labels:__ {', '.join(labels)}")

# Model selection
left, right = st.columns([4, 2])
inputs = left.text_area('', max_chars=4096, placeholder='Write something here to see what happens! ')
model_path = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
split = right.checkbox('Split into sentences')
model = load_model(model_path=f"model/{model_path}.pt")
right.write(model.device)

# Only run inference when the text box contains something besides whitespace
# (str.strip() removes exactly what str.isspace() considers whitespace).
if inputs.strip():
    with st.spinner('Processing text... This may take a while.'):
        # Classify each sentence separately, or the whole text as one input.
        texts = sent_tokenize(inputs) if split else [inputs]
        left.write(model(inputs_to_dataset(texts), batch_size=1))