rohanphadke commited on
Commit
b8f48f7
1 Parent(s): 1677649

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -4,9 +4,17 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  pretrained = "rohanphadke/roberta-finetuned-triplebottomline"
5
  tokenizer = AutoTokenizer.from_pretrained(pretrained)
6
  model = AutoModelForSequenceClassification.from_pretrained(pretrained)
 
 
 
7
 
8
  def greet(name):
9
  return "Hello " + name + "!!"
10
 
11
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
12
  demo.launch()
 
4
  pretrained = "rohanphadke/roberta-finetuned-triplebottomline"
5
  tokenizer = AutoTokenizer.from_pretrained(pretrained)
6
  model = AutoModelForSequenceClassification.from_pretrained(pretrained)
7
+ threshold = 0.5
8
+
9
+ labels = {0: 'people', 1: 'planet', 2:'profit'}
10
 
11
  def greet(name):
12
  return "Hello " + name + "!!"
13
 
14
+ def predict_text(text):
15
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
16
+ probs = torch.sigmoid(model(**inputs).logits).detach().cpu().numpy()[0]
17
+ return {labels[i]: float(probs[i]) for i in range(len(probs)) if probs[i] >= threshold}
18
+
19
+ demo = gr.Interface(fn=predict_text, inputs="text", outputs=gr.outputs.Label(num_top_classes=3))
20
  demo.launch()