sbenel commited on
Commit
7217924
1 Parent(s): 3e38711

fix in app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import torch
3
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
4
 
@@ -10,10 +11,12 @@ def translate(text):
10
  input = tokenizer(text, return_tensors="pt")
11
  labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
12
  output = model(**input, labels=labels)
 
 
 
 
13
  # output = model.generate(input["input_ids"], max_length=40, num_beams=4, early_stopping=True)
14
 
15
- return tokenizer.decode(output[0], skip_special_tokens=True)
16
-
17
  title = "Text Emotion Classification"
18
  inputs = gr.inputs.Textbox(lines=1, label="Text")
19
  outputs = [gr.outputs.Textbox(label="Emotions")]
 
1
  import gradio as gr
2
+ import torch.nn.functional as F
3
  import torch
4
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
5
 
 
11
  input = tokenizer(text, return_tensors="pt")
12
  labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
13
  output = model(**input, labels=labels)
14
+ logits = outputs.logits
15
+ prediction = F.softmax(logits, dim=1)
16
+ y_pred = torch.argmax(prediction).numpy()
17
+ return y_pred
18
  # output = model.generate(input["input_ids"], max_length=40, num_beams=4, early_stopping=True)
19
 
 
 
20
  title = "Text Emotion Classification"
21
  inputs = gr.inputs.Textbox(lines=1, label="Text")
22
  outputs = [gr.outputs.Textbox(label="Emotions")]