Spaces:
Runtime error
Runtime error
Commit
•
71411c5
1
Parent(s):
41889a9
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,24 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
7 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
|
8 |
model = model.to(device)
|
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
|
6 |
+
MODEL_PATH = 'finiteautomata/bertweet-base-sentiment-analysis'
|
7 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
8 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
|
9 |
+
model = model.to(device)
|
10 |
+
|
11 |
+
|
12 |
+
logits = outputs.logits
|
13 |
+
sigmoid = torch.nn.Sigmoid()
|
14 |
+
probs = sigmoid(logits.squeeze().cpu())
|
15 |
+
probs = probs.detach().numpy()
|
16 |
+
|
17 |
+
for i, k in enumerate(label2id.keys()):
|
18 |
+
label2id[k] = probs[i]
|
19 |
+
|
20 |
+
|
21 |
+
label2id = {k: v for k, v in sorted(label2id.items(), key=lambda item: item[1], reverse=True)}
|
22 |
+
label2id
|
23 |
+
|
24 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
25 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
|
26 |
model = model.to(device)
|