awacke1 commited on
Commit
21b28f0
1 Parent(s): 281179f

Create new file

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gdown as gdown
4
+ import nltk
5
+ import streamlit as st
6
+ from nltk.tokenize import sent_tokenize
7
+
8
+ from source.pipeline import MultiLabelPipeline, inputs_to_dataset
9
+
10
+
11
+ def download_models(ids):
12
+ """
13
+ Download all models.
14
+ :param ids: name and links of models
15
+ :return:
16
+ """
17
+
18
+ # Download sentence tokenizer
19
+ nltk.download('punkt')
20
+
21
+ # Download model from drive if not stored locally
22
+ for key in ids:
23
+ if not os.path.isfile(f"model/{key}.pt"):
24
+ url = f"https://drive.google.com/uc?id={ids[key]}"
25
+ gdown.download(url=url, output=f"model/{key}.pt")
26
+
27
+
28
+ @st.cache
29
+ def load_labels():
30
+ """
31
+ Load model labels.
32
+ :return:
33
+ """
34
+
35
+ return [
36
+ "admiration",
37
+ "amusement",
38
+ "anger",
39
+ "annoyance",
40
+ "approval",
41
+ "caring",
42
+ "confusion",
43
+ "curiosity",
44
+ "desire",
45
+ "disappointment",
46
+ "disapproval",
47
+ "disgust",
48
+ "embarrassment",
49
+ "excitement",
50
+ "fear",
51
+ "gratitude",
52
+ "grief",
53
+ "joy",
54
+ "love",
55
+ "nervousness",
56
+ "optimism",
57
+ "pride",
58
+ "realization",
59
+ "relief",
60
+ "remorse",
61
+ "sadness",
62
+ "surprise",
63
+ "neutral"
64
+ ]
65
+
66
+
67
+ @st.cache(allow_output_mutation=True)
68
+ def load_model(model_path):
69
+ """
70
+ Load model and cache it.
71
+ :param model_path: path to model
72
+ :return:
73
+ """
74
+
75
+ model = MultiLabelPipeline(model_path=model_path)
76
+
77
+ return model
78
+
79
+
80
+ # Page config
81
+ st.set_page_config(layout="centered")
82
+ st.title("Multiclass Emotion Classification")
83
+ st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")
84
+
85
+ maintenance = False
86
+ if maintenance:
87
+ st.write("Unavailable for now (file downloads limit). ")
88
+ else:
89
+ # Variables
90
+ ids = {'perceiver-go-emotions': st.secrets['model']}
91
+ labels = load_labels()
92
+
93
+ # Download all models from drive
94
+ download_models(ids)
95
+
96
+ # Display labels
97
+ st.markdown(f"__Labels:__ {', '.join(labels)}")
98
+
99
+ # Model selection
100
+ left, right = st.columns([4, 2])
101
+ inputs = left.text_area('', max_chars=4096, value='This is a space about multiclass emotion classification. Write '
102
+ 'something here to see what happens!')
103
+ model_path = right.selectbox('', options=[k for k in ids], index=0, help='Model to use. ')
104
+ split = right.checkbox('Split into sentences', value=True)
105
+ model = load_model(model_path=f"model/{model_path}.pt")
106
+ right.write(model.device)
107
+
108
+ if split:
109
+ if not inputs.isspace() and inputs != "":
110
+ with st.spinner('Processing text... This may take a while.'):
111
+ left.write(model(inputs_to_dataset(sent_tokenize(inputs)), batch_size=1))
112
+ else:
113
+ if not inputs.isspace() and inputs != "":
114
+ with st.spinner('Processing text... This may take a while.'):
115
+ left.write(model(inputs_to_dataset([inputs]), batch_size=1))