Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""Anxiety_label_training_google.ipynb
|
3 |
+
|
4 |
+
Automatically generated by Colaboratory.
|
5 |
+
|
6 |
+
Original file is located at
|
7 |
+
https://colab.research.google.com/drive/17f7DEZeKdrpQTPfqFe50SWnC-kIg3G-5
|
8 |
+
|
9 |
+
#Prediction of anxiety levels through text analysis
|
10 |
+
|
11 |
+
#Transcript loading method
|
12 |
+
|
13 |
+
When considering both the interviewer and the participant, the dataset is reduced to the sessions of 186 individuals, as 3 transcripts do not contain the text corresponding to Ellie, the virtual interviewer.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import pandas as pd
|
17 |
+
import re
|
18 |
+
import glob
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
"""#Importing the required libraries"""
|
23 |
+
|
24 |
+
import glob
|
25 |
+
import pandas as pd
|
26 |
+
import numpy as np
|
27 |
+
import re
|
28 |
+
import fnmatch
|
29 |
+
import os
|
30 |
+
import keras
|
31 |
+
|
32 |
+
from keras.datasets import fashion_mnist
|
33 |
+
from keras.models import Sequential, Model
|
34 |
+
from keras.layers import Dense, Dropout, Embedding, LSTM, Input, Activation, GlobalAveragePooling1D, Flatten, Concatenate, Conv1D, MaxPooling1D
|
35 |
+
from tensorflow.keras.layers import BatchNormalization
|
36 |
+
from keras.layers import concatenate
|
37 |
+
from keras.optimizers import SGD, RMSprop, Adagrad, Adam
|
38 |
+
from keras.preprocessing.text import one_hot, text_to_word_sequence, Tokenizer
|
39 |
+
from keras_preprocessing.sequence import pad_sequences
|
40 |
+
|
41 |
+
from keras.callbacks import EarlyStopping, ModelCheckpoint
|
42 |
+
from keras.utils.vis_utils import plot_model
|
43 |
+
|
44 |
+
from nltk.corpus import stopwords
|
45 |
+
from nltk.stem import SnowballStemmer
|
46 |
+
from string import punctuation
|
47 |
+
from scipy import stats
|
48 |
+
|
49 |
+
from keras.utils.vis_utils import plot_model
|
50 |
+
|
51 |
+
import matplotlib
|
52 |
+
import matplotlib.pyplot as plt
|
53 |
+
|
54 |
+
import itertools
|
55 |
+
import gensim
|
56 |
+
import nltk
|
57 |
+
from nltk.stem import WordNetLemmatizer
|
58 |
+
|
59 |
+
nltk.download('wordnet')
|
60 |
+
nltk.download('stopwords')
|
61 |
+
wordnet_lemmatizer = WordNetLemmatizer()
|
62 |
+
|
63 |
+
labels=['none','mild','moderate','moderately severe', 'severe']
|
64 |
+
num_classes = len(labels)
|
65 |
+
|
66 |
+
def plot_acc(history, title="Model Accuracy"):
|
67 |
+
"""Imprime una gráfica mostrando la accuracy por epoch obtenida en un entrenamiento"""
|
68 |
+
plt.plot(history.history['accuracy'])
|
69 |
+
plt.plot(history.history['val_accuracy'])
|
70 |
+
plt.title(title)
|
71 |
+
plt.ylabel('Accuracy')
|
72 |
+
plt.xlabel('Epoch')
|
73 |
+
plt.legend(['Train', 'Val'], loc='upper left')
|
74 |
+
plt.show()
|
75 |
+
|
76 |
+
def plot_loss(history, title="Model Loss"):
|
77 |
+
"""Imprime una gráfica mostrando la pérdida por epoch obtenida en un entrenamiento"""
|
78 |
+
plt.plot(history.history['loss'])
|
79 |
+
plt.plot(history.history['val_loss'])
|
80 |
+
plt.title(title)
|
81 |
+
plt.ylabel('Loss')
|
82 |
+
plt.xlabel('Epoch')
|
83 |
+
plt.legend(['Train', 'Val'], loc='upper right')
|
84 |
+
plt.show()
|
85 |
+
|
86 |
+
def plot_compare_losses(history1, history2, name1="Red 1",
|
87 |
+
name2="Red 2", title="Graph title"):
|
88 |
+
"""Compara losses de dos entrenamientos con nombres name1 y name2"""
|
89 |
+
plt.plot(history1.history['loss'], color="green")
|
90 |
+
plt.plot(history1.history['val_loss'], 'r--', color="green")
|
91 |
+
plt.plot(history2.history['loss'], color="blue")
|
92 |
+
plt.plot(history2.history['val_loss'], 'r--', color="blue")
|
93 |
+
plt.title(title)
|
94 |
+
plt.ylabel('Loss')
|
95 |
+
plt.xlabel('Epoch')
|
96 |
+
plt.legend(['Train ' + name1, 'Val ' + name1,
|
97 |
+
'Train ' + name2, 'Val ' + name2],
|
98 |
+
loc='upper right')
|
99 |
+
plt.show()
|
100 |
+
|
101 |
+
def plot_compare_accs(history1, history2, name1="Red 1",
|
102 |
+
name2="Red 2", title="Graph title"):
|
103 |
+
"""Compara accuracies de dos entrenamientos con nombres name1 y name2"""
|
104 |
+
plt.plot(history1.history['acc'], color="green")
|
105 |
+
plt.plot(history1.history['val_acc'], 'r--', color="green")
|
106 |
+
plt.plot(history2.history['acc'], color="blue")
|
107 |
+
plt.plot(history2.history['val_acc'], 'r--', color="blue")
|
108 |
+
plt.title(title)
|
109 |
+
plt.ylabel('Accuracy')
|
110 |
+
plt.xlabel('Epoch')
|
111 |
+
plt.legend(['Train ' + name1, 'Val ' + name1,
|
112 |
+
'Train ' + name2, 'Val ' + name2],
|
113 |
+
loc='lower right')
|
114 |
+
plt.show()
|
115 |
+
|
116 |
+
def plot_compare_multiple_metrics(history_array, names, colors, title="Graph title", metric='acc'):
|
117 |
+
legend = []
|
118 |
+
for i in range(0, len(history_array)):
|
119 |
+
plt.plot(history_array[i].history[metric], color=colors[i])
|
120 |
+
plt.plot(history_array[i].history['val_' + metric], 'r--', color=colors[i])
|
121 |
+
legend.append('Train ' + names[i])
|
122 |
+
legend.append('Val ' + names[i])
|
123 |
+
|
124 |
+
plt.title(title)
|
125 |
+
plt.ylabel('Accuracy')
|
126 |
+
plt.xlabel('Epoch')
|
127 |
+
plt.axis
|
128 |
+
plt.legend(legend,
|
129 |
+
loc='lower right')
|
130 |
+
plt.show()
|
131 |
+
|
132 |
+
"""#Loading and preprocessing of transcripts"""
|
133 |
+
|
134 |
+
all_participants = pd.read_csv('all.csv', sep=',')
|
135 |
+
all_participants.columns = ['index','personId', 'question', 'answer']
|
136 |
+
all_participants = all_participants.astype({"index": float, "personId": float, "question": str, "answer": str })
|
137 |
+
|
138 |
+
all_participants.head()
|
139 |
+
|
140 |
+
"""#Data analysis"""
|
141 |
+
|
142 |
+
ds_len = len(all_participants)
|
143 |
+
len_answers = [len(v) for v in all_participants['answer']]
|
144 |
+
ds_max = max(len_answers)
|
145 |
+
ds_min = min(len_answers)
|
146 |
+
|
147 |
+
stats.describe(len_answers)
|
148 |
+
plt.hist(len_answers)
|
149 |
+
plt.show()
|
150 |
+
|
151 |
+
"""#Auxiliary functions for text processing
|
152 |
+
Function taken from Kaggle for text cleaning
|
153 |
+
"""
|
154 |
+
|
155 |
+
# The function "text_to_wordlist" is from
|
156 |
+
# https://www.kaggle.com/currie32/quora-question-pairs/the-importance-of-cleaning-text
|
157 |
+
def text_to_wordlist(text, remove_stopwords=True, stem_words=False):
|
158 |
+
# Clean the text, with the option to remove stopwords and to stem words.
|
159 |
+
|
160 |
+
# Convert words to lower case and split them
|
161 |
+
text = text.lower().split()
|
162 |
+
|
163 |
+
# Optionally, remove stop words
|
164 |
+
if remove_stopwords:
|
165 |
+
stops = set(stopwords.words("english"))
|
166 |
+
text = [wordnet_lemmatizer.lemmatize(w) for w in text if not w in stops ]
|
167 |
+
text = [w for w in text if w != "nan" ]
|
168 |
+
|
169 |
+
text = " ".join(text)
|
170 |
+
|
171 |
+
# Clean the text
|
172 |
+
text = re.sub(r"[^A-Za-z0-9^,!.\/'+-=]", " ", text)
|
173 |
+
text = re.sub(r"what's", "what is ", text)
|
174 |
+
text = re.sub(r"\'s", " ", text)
|
175 |
+
text = re.sub(r"\'ve", " have ", text)
|
176 |
+
text = re.sub(r"can't", "cannot ", text)
|
177 |
+
text = re.sub(r"n't", " not ", text)
|
178 |
+
text = re.sub(r"i'm", "i am ", text)
|
179 |
+
text = re.sub(r"\'re", " are ", text)
|
180 |
+
text = re.sub(r"\'d", " would ", text)
|
181 |
+
text = re.sub(r"\'ll", " will ", text)
|
182 |
+
text = re.sub(r",", " ", text)
|
183 |
+
text = re.sub(r"\.", " ", text)
|
184 |
+
text = re.sub(r"!", " ! ", text)
|
185 |
+
text = re.sub(r"\/", " ", text)
|
186 |
+
text = re.sub(r"\^", " ^ ", text)
|
187 |
+
text = re.sub(r"\+", " + ", text)
|
188 |
+
text = re.sub(r"\-", " - ", text)
|
189 |
+
text = re.sub(r"\=", " = ", text)
|
190 |
+
|
191 |
+
text = re.sub(r"\<", " ", text)
|
192 |
+
text = re.sub(r"\>", " ", text)
|
193 |
+
|
194 |
+
text = re.sub(r"'", " ", text)
|
195 |
+
text = re.sub(r"(\d+)(k)", r"\g<1>000", text)
|
196 |
+
text = re.sub(r":", " : ", text)
|
197 |
+
text = re.sub(r" e g ", " eg ", text)
|
198 |
+
text = re.sub(r" b g ", " bg ", text)
|
199 |
+
text = re.sub(r" u s ", " american ", text)
|
200 |
+
text = re.sub(r"\0s", "0", text)
|
201 |
+
text = re.sub(r" 9 11 ", "911", text)
|
202 |
+
text = re.sub(r"e - mail", "email", text)
|
203 |
+
text = re.sub(r"j k", "jk", text)
|
204 |
+
text = re.sub(r"\s{2,}", " ", text)
|
205 |
+
|
206 |
+
# Optionally, shorten words to their stems
|
207 |
+
if stem_words:
|
208 |
+
text = text.split()
|
209 |
+
stemmer = SnowballStemmer('english')
|
210 |
+
stemmed_words = [stemmer.stem(word) for word in text]
|
211 |
+
text = " ".join(stemmed_words)
|
212 |
+
|
213 |
+
# Return a list of words
|
214 |
+
return(text)
|
215 |
+
|
216 |
+
nltk.download('omw-1.4')
|
217 |
+
|
218 |
+
all_participants_mix = all_participants.copy()
|
219 |
+
all_participants_mix['answer'] = all_participants_mix.apply(lambda row: text_to_wordlist(row.answer).split(), axis=1)
|
220 |
+
|
221 |
+
words = [w for w in all_participants_mix['answer'].tolist()]
|
222 |
+
words = set(itertools.chain(*words))
|
223 |
+
vocab_size = len(words)
|
224 |
+
|
225 |
+
"""Text cleaning
|
226 |
+
|
227 |
+
Lemmatization
|
228 |
+
|
229 |
+
Separation into vectors
|
230 |
+
"""
|
231 |
+
|
232 |
+
windows_size = 10
|
233 |
+
tokenizer = Tokenizer(num_words=vocab_size)
|
234 |
+
tokenizer.fit_on_texts(all_participants_mix['answer'])
|
235 |
+
tokenizer.fit_on_sequences(all_participants_mix['answer'])
|
236 |
+
|
237 |
+
all_participants_mix['t_answer'] = tokenizer.texts_to_sequences(all_participants_mix['answer'])
|
238 |
+
|
239 |
+
|
240 |
+
word_index = tokenizer.word_index
|
241 |
+
word_size = len(word_index)
|
242 |
+
|
243 |
+
|
244 |
+
all_participants_mix.drop(columns=['question'], inplace=True)
|
245 |
+
answers = all_participants_mix.groupby('personId').agg(lambda x: x.tolist())
|
246 |
+
|
247 |
+
import itertools
|
248 |
+
|
249 |
+
# group the remaining columns by 'personId' and convert each group to a list of lists
|
250 |
+
answers = all_participants_mix.groupby('personId').agg(lambda x: x.tolist())
|
251 |
+
|
252 |
+
# flatten the list of lists in the 'answer' column
|
253 |
+
answers['answer'] = answers['answer'].apply(lambda x: list(itertools.chain.from_iterable(x)))
|
254 |
+
|
255 |
+
# flatten the list of lists in the 't_answer' column
|
256 |
+
answers['t_answer'] = answers['t_answer'].apply(lambda x: list(itertools.chain.from_iterable(x)))
|
257 |
+
|
258 |
+
answers
|
259 |
+
|
260 |
+
windows_size = 10
|
261 |
+
cont = 0
|
262 |
+
phrases_lp = pd.DataFrame(columns=['personId','answer', 't_answer'])
|
263 |
+
|
264 |
+
for p in answers.iterrows():
|
265 |
+
words = p[1]["answer"]
|
266 |
+
size = len(words)
|
267 |
+
word_tokens = p[1]["t_answer"]
|
268 |
+
|
269 |
+
for i in range(size):
|
270 |
+
sentence = words[i:min(i+windows_size,size)]
|
271 |
+
tokens = word_tokens[i:min(i+windows_size,size)]
|
272 |
+
phrases_lp.loc[cont] = [p[0], sentence, tokens]
|
273 |
+
cont = cont + 1
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
def load_avec_dataset_file(path, score_column):
|
278 |
+
ds = pd.read_csv(path, sep=',')
|
279 |
+
ds['level'] = pd.cut(ds[score_column], bins=[-1,0,5,10,15,25], labels=[0,1,2,3,4])
|
280 |
+
ds['PHQ8_Score'] = ds[score_column]
|
281 |
+
ds['cat_level'] = keras.utils.to_categorical(ds['level'], num_classes).tolist()
|
282 |
+
ds = ds[['Participant_ID', 'level', 'cat_level', 'PHQ8_Score']]
|
283 |
+
ds = ds.astype({"Participant_ID": float, "level": int, 'PHQ8_Score': int})
|
284 |
+
return ds
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
def split_by_phq_level(ds):
|
289 |
+
none_ds = ds[ds['level']==0]
|
290 |
+
mild_ds = ds[ds['level']==1]
|
291 |
+
moderate_ds = ds[ds['level']==2]
|
292 |
+
moderate_severe_ds = ds[ds['level']==3]
|
293 |
+
severe_ds = ds[ds['level']==4]
|
294 |
+
return (none_ds, mild_ds, moderate_ds, moderate_severe_ds, severe_ds)
|
295 |
+
|
296 |
+
|
297 |
+
def distribute_instances(ds):
|
298 |
+
ds_shuffled = ds.sample(frac=1)
|
299 |
+
none_ds, mild_ds, moderate_ds, moderate_severe_ds, severe_ds = split_by_phq_level(ds_shuffled)
|
300 |
+
split = [70,14,16]
|
301 |
+
eq_ds = {}
|
302 |
+
prev_none = prev_mild = prev_moderate = prev_moderate_severe = prev_severe = 0
|
303 |
+
|
304 |
+
for p in split:
|
305 |
+
last_none = min(len(none_ds), prev_none + round(len(none_ds) * p/100))
|
306 |
+
last_mild = min(len(mild_ds), prev_mild + round(len(mild_ds) * p/100))
|
307 |
+
last_moderate = min(len(moderate_ds), prev_moderate + round(len(moderate_ds) * p/100))
|
308 |
+
last_moderate_severe = min(len(moderate_severe_ds), prev_moderate_severe + round(len(moderate_severe_ds) * p/100))
|
309 |
+
last_severe = min(len(severe_ds), prev_severe + round(len(severe_ds) * p/100))
|
310 |
+
eq_ds["d"+str(p)] = pd.concat([none_ds[prev_none: last_none], mild_ds[prev_mild: last_mild], moderate_ds[prev_moderate: last_moderate], moderate_severe_ds[prev_moderate_severe: last_moderate_severe], severe_ds[prev_severe: last_severe]])
|
311 |
+
prev_none = last_none
|
312 |
+
prev_mild = last_mild
|
313 |
+
prev_moderate = last_moderate
|
314 |
+
prev_moderate_severe = last_moderate_severe
|
315 |
+
prev_severe = last_severe
|
316 |
+
return (eq_ds["d70"], eq_ds["d14"], eq_ds["d16"])
|
317 |
+
|
318 |
+
def test_model(text, model):
|
319 |
+
print(text)
|
320 |
+
word_list = text_to_wordlist(text)
|
321 |
+
sequences = tokenizer.texts_to_sequences([word_list])
|
322 |
+
sequences_input = list(itertools.chain(*sequences))
|
323 |
+
sequences_input = pad_sequences([sequences_input], value=0, padding="post", maxlen=windows_size).tolist()
|
324 |
+
input_a = np.asarray(sequences_input)
|
325 |
+
pred = model.predict(input_a, batch_size=None, verbose=0, steps=None)
|
326 |
+
print(pred)
|
327 |
+
predicted_class = np.argmax(pred)
|
328 |
+
print(labels[predicted_class])
|
329 |
+
|
330 |
+
def confusion_matrix(model, x, y):
|
331 |
+
prediction = model.predict(x, batch_size=None, verbose=0, steps=None)
|
332 |
+
labels=['none','mild','moderate','moderately severe', 'severe']
|
333 |
+
|
334 |
+
max_prediction = np.argmax(prediction, axis=1)
|
335 |
+
max_actual = np.argmax(y, axis=1)
|
336 |
+
|
337 |
+
y_pred = pd.Categorical.from_codes(max_prediction, labels)
|
338 |
+
y_actu = pd.Categorical.from_codes(max_actual, labels)
|
339 |
+
|
340 |
+
return pd.crosstab(y_actu, y_pred)
|
341 |
+
|
342 |
+
|
343 |
+
|
344 |
+
|
345 |
+
import pickle
|
346 |
+
|
347 |
+
import pickle
|
348 |
+
windows_size = 10
|
349 |
+
# Load the trained model
|
350 |
+
with open('model_google.pkl', 'rb') as f:
|
351 |
+
Mode = pickle.load(f)
|
352 |
+
|
353 |
+
def Test_model(text, Model):
|
354 |
+
word_list = text_to_wordlist(text)
|
355 |
+
sequences = tokenizer.texts_to_sequences([word_list])
|
356 |
+
sequences_input = list(itertools.chain(*sequences))
|
357 |
+
sequences_input = pad_sequences([sequences_input], value=0, padding="post", maxlen=windows_size).tolist()
|
358 |
+
input_a = np.asarray(sequences_input)
|
359 |
+
pred = Model.predict(input_a, batch_size=None, verbose=0, steps=None)
|
360 |
+
#print(pred)
|
361 |
+
predicted_class = np.argmax(pred)
|
362 |
+
#print(labels[predicted_class])
|
363 |
+
|
364 |
+
|
365 |
+
|
366 |
+
import gradio as gr
|
367 |
+
import pickle
|
368 |
+
|
369 |
+
|
370 |
+
# Load the trained model
|
371 |
+
with open('model_google.pkl', 'rb') as f:
|
372 |
+
Modell = pickle.load(f)
|
373 |
+
|
374 |
+
def predict(text):
|
375 |
+
|
376 |
+
word_list = text_to_wordlist(text)
|
377 |
+
sequences = tokenizer.texts_to_sequences([word_list])
|
378 |
+
sequences_input = list(itertools.chain(*sequences))
|
379 |
+
sequences_input = pad_sequences([sequences_input], value=0, padding="post", maxlen=windows_size).tolist()
|
380 |
+
input_a = np.asarray(sequences_input)
|
381 |
+
pred = Modell.predict(input_a, batch_size=None, verbose=0, steps=None)
|
382 |
+
|
383 |
+
predicted_class = np.argmax(pred)
|
384 |
+
return labels[predicted_class]
|
385 |
+
input_text = gr.inputs.Textbox(label="Enter a sentence")
|
386 |
+
output_text = gr.outputs.Textbox(label="Predicted label")
|
387 |
+
iface = gr.Interface(fn=predict, inputs=input_text, outputs=output_text, title="Depression Severity Analysis",
|
388 |
+
description="Enter texts to classify its depression severity.")
|
389 |
+
iface.launch()
|