File size: 2,793 Bytes
b438028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277c0bd
b438028
 
 
 
 
 
 
 
 
 
335a8d7
 
b438028
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os

import gdown as gdown
import nltk
import streamlit as st
from nltk.tokenize import sent_tokenize

from source.pipeline import MultiLabelPipeline, inputs_to_dataset


def download_models(ids):
    """
    Download the sentence tokenizer and every model checkpoint not yet present locally.

    :param ids: mapping of model name -> Google Drive file id
    :return: None
    """

    # Download sentence tokenizer data used by nltk.sent_tokenize
    nltk.download('punkt')

    # gdown does not create the output directory; make sure it exists
    # so a fresh checkout does not fail on the first download.
    os.makedirs("model", exist_ok=True)

    # Download model from drive if not stored locally
    for key in ids:
        if not os.path.isfile(f"model/{key}.pt"):
            url = f"https://drive.google.com/uc?id={ids[key]}"
            gdown.download(url=url, output=f"model/{key}.pt")


@st.cache
def load_labels():
    """
    Return the 28 emotion label names, cached by Streamlit.

    :return: list of label strings
    """

    # One name per whitespace-separated token; split() yields the list.
    return (
        "admiration amusement anger annoyance approval caring confusion "
        "curiosity desire disappointment disapproval disgust embarrassment "
        "excitement fear gratitude grief joy love nervousness optimism pride "
        "realization relief remorse sadness surprise neutral"
    ).split()


@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build the classification pipeline for the given checkpoint and cache it.

    :param model_path: path to model
    :return: MultiLabelPipeline instance
    """

    return MultiLabelPipeline(model_path=model_path)


# Page config
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Variables
ids = {'perceiver-go-emotions': st.secrets['model_key']}
labels = load_labels()

# Download all models from drive
download_models(ids)

# Display labels
st.markdown(f"__Labels:__ {', '.join(labels)}")

# Model selection
left, right = st.columns([4, 2])
inputs = left.text_area('', max_chars=4096, value='This is a space about multiclass emotion classification. Write '
                                                  'something here to see what happens!')
model_path = right.selectbox('', options=[k for k in ids], index=0, help='Model to use. ')
split = right.checkbox('Split into sentences')
model = load_model(model_path=f"model/{model_path}.pt")
right.write(model.device)

# Run inference only when the text area holds non-blank input; classify each
# sentence separately when the split checkbox is ticked, else the whole text.
if not inputs.isspace() and inputs != "":
    with st.spinner('Processing text... This may take a while.'):
        texts = sent_tokenize(inputs) if split else [inputs]
        left.write(model(inputs_to_dataset(texts), batch_size=1))