Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,13 +5,9 @@ import torch.nn as nn
|
|
5 |
from transformers import RobertaTokenizer, RobertaModel
|
6 |
|
7 |
@st.cache(suppress_st_warning=True)
|
8 |
-
def
|
9 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
10 |
-
return tokenizer
|
11 |
|
12 |
-
|
13 |
-
@st.cache(suppress_st_warning=True)
|
14 |
-
def init_model():
|
15 |
model = RobertaModel.from_pretrained("roberta-large-mnli")
|
16 |
|
17 |
model.pooler = nn.Sequential(
|
@@ -24,6 +20,8 @@ def init_model():
|
|
24 |
|
25 |
model_path = 'model.pt'
|
26 |
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
|
|
|
|
|
27 |
|
28 |
cats = ['Computer Science', 'Economics', 'Electrical Engineering',
|
29 |
'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']
|
@@ -38,6 +36,7 @@ def predict(outputs):
|
|
38 |
top += percent
|
39 |
st.write(f'{cat}: {round(percent, 1)}')
|
40 |
|
|
|
41 |
|
42 |
st.markdown("### Title")
|
43 |
|
@@ -50,10 +49,8 @@ abstract = st.text_area("Enter abstract", height=200)
|
|
50 |
if not title:
|
51 |
st.warning("Please fill out so required fields")
|
52 |
else:
|
53 |
-
tokenizer = init_tokenizer()
|
54 |
-
model = init_model()
|
55 |
-
|
56 |
encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
|
57 |
max_length = 512, truncation=True)
|
58 |
-
|
59 |
-
|
|
|
|
5 |
from transformers import RobertaTokenizer, RobertaModel
|
6 |
|
7 |
@st.cache(suppress_st_warning=True)
|
8 |
+
def init():
|
9 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
|
|
10 |
|
|
|
|
|
|
|
11 |
model = RobertaModel.from_pretrained("roberta-large-mnli")
|
12 |
|
13 |
model.pooler = nn.Sequential(
|
|
|
20 |
|
21 |
model_path = 'model.pt'
|
22 |
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
|
23 |
+
model.eval()
|
24 |
+
return tokenizer, model
|
25 |
|
26 |
cats = ['Computer Science', 'Economics', 'Electrical Engineering',
|
27 |
'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']
|
|
|
36 |
top += percent
|
37 |
st.write(f'{cat}: {round(percent, 1)}')
|
38 |
|
39 |
+
tokenizer, model = init()
|
40 |
|
41 |
st.markdown("### Title")
|
42 |
|
|
|
49 |
if not title:
|
50 |
st.warning("Please fill out so required fields")
|
51 |
else:
|
|
|
|
|
|
|
52 |
encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
|
53 |
max_length = 512, truncation=True)
|
54 |
+
with torch.no_grad():
|
55 |
+
outputs = model(**encoded_input).pooler_output[:, 0, :]
|
56 |
+
predict(outputs)
|