Spaces:
Runtime error
Runtime error
kotstantinovskii
commited on
Commit
•
38aae15
1
Parent(s):
c19e64d
Upload model.py
Browse files
model.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from torch import nn
|
3 |
+
import numpy as np
|
4 |
+
from transformers import DistilBertForSequenceClassification
|
5 |
+
|
6 |
+
|
7 |
+
class ArxivModel:
|
8 |
+
|
9 |
+
def __init__(self, model, tokenizer):
|
10 |
+
self.model = model
|
11 |
+
self.tokenizer = tokenizer
|
12 |
+
|
13 |
+
self.model.to('cpu')
|
14 |
+
|
15 |
+
def get_logits(self, tweet_text):
|
16 |
+
text_tokens = self.tokenizer(tweet_text, return_tensors="pt").to('cpu')
|
17 |
+
softmax = nn.Softmax(dim=1)
|
18 |
+
|
19 |
+
return softmax(self.model(**text_tokens).logits.detach()).numpy()[0]
|
20 |
+
|
21 |
+
def get_idx_class(self, tweet_text, thr=-1.0):
|
22 |
+
logits = self.get_logits(tweet_text)
|
23 |
+
|
24 |
+
if thr == -1.0:
|
25 |
+
return [(np.argmax(logits), np.max(logits))]
|
26 |
+
else:
|
27 |
+
sum_probs = 0.0
|
28 |
+
idxs = []
|
29 |
+
for p in np.argsort(logits)[::-1]:
|
30 |
+
sum_probs += logits[p]
|
31 |
+
idxs.append((p, logits[p]))
|
32 |
+
if sum_probs > thr:
|
33 |
+
return idxs
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache
|
37 |
+
def load_model(path="./checkpoint-15500", num_labels=153):
|
38 |
+
return DistilBertForSequenceClassification.from_pretrained(path, num_labels=num_labels)
|