Porjaz commited on
Commit
3083011
1 Parent(s): ac354b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -2
app.py CHANGED
@@ -165,9 +165,35 @@ def return_prediction(mic, file):
165
  else:
166
  return "You must either provide a mic recording or a file"
167
 
 
 
 
 
 
 
 
 
168
  score = score.item()
169
  score = str(round(100 * score, 2)) + "%"
170
- return text_lab, score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  classifier = foreign_class(source="Porjaz/wavlm-base-emo-fi", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
@@ -184,4 +210,4 @@ gradio_app = gr.Interface(
184
 
185
 
186
  if __name__ == "__main__":
187
- gradio_app.launch(share=True, ssl_verify=False)
 
165
  else:
166
  return "You must either provide a mic recording or a file"
167
 
168
+ '''
169
+ '1' => 0
170
+ '3' => 1
171
+ '5' => 2
172
+ '4' => 3
173
+ '2' => 4
174
+ '''
175
+
176
  score = score.item()
177
  score = str(round(100 * score, 2)) + "%"
178
+ neu = round(100 * out_prob[0, 0].item(), 2)
179
+ joy = round(100 * out_prob[0, 1].item(), 2)
180
+ aff = round(100 * out_prob[0, 2].item(), 2)
181
+ ang = round(100 * out_prob[0, 3].item(), 2)
182
+ sad = round(100 * out_prob[0, 4].item(), 2)
183
+ result_dict = {
184
+ "Neutral: ": neu,
185
+ "Joy: ": joy,
186
+ "Affection: ": aff,
187
+ "Anger: ": ang,
188
+ "Sadness: ": sad,
189
+ }
190
+ # order the dict in reverse order by value
191
+ result_dict = dict(sorted(result_dict.items(), key=lambda item: item[1], reverse=True))
192
+ keys = list(result_dict.keys())
193
+ values = list(result_dict.values())
194
+ result_string = keys[0] + "\t" + str(values[0]) + "%" + "\n" + keys[1] + "\t" + str(values[1]) + "%" + "\n" + keys[2] + "\t" + str(values[2]) + "%" + "\n" + keys[3] + "\t" + str(values[3]) + "%" + "\n" + keys[4] + "\t" + str(values[4]) + "%"
195
+ # return text_lab, score
196
+ return result_string
197
 
198
 
199
  classifier = foreign_class(source="Porjaz/wavlm-base-emo-fi", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
 
210
 
211
 
212
  if __name__ == "__main__":
213
+ gradio_app.launch(share=True, ssl_verify=False)