rrevoid committed
Commit
954b6e7
1 Parent(s): d670e16

Update app.py

Files changed (1)
  app.py  +7 -10
app.py CHANGED
@@ -5,13 +5,9 @@ import torch.nn as nn
 from transformers import RobertaTokenizer, RobertaModel
 
 @st.cache(suppress_st_warning=True)
-def init_tokenizer():
+def init():
     tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
-    return tokenizer
 
-
-@st.cache(suppress_st_warning=True)
-def init_model():
     model = RobertaModel.from_pretrained("roberta-large-mnli")
 
     model.pooler = nn.Sequential(
@@ -24,6 +20,8 @@ def init_model():
 
     model_path = 'model.pt'
     model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+    model.eval()
+    return tokenizer, model
 
 cats = ['Computer Science', 'Economics', 'Electrical Engineering',
         'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']
@@ -38,6 +36,7 @@ def predict(outputs):
         top += percent
         st.write(f'{cat}: {round(percent, 1)}')
 
+tokenizer, model = init()
 
 st.markdown("### Title")
 
@@ -50,10 +49,8 @@ abstract = st.text_area("Enter abstract", height=200)
 if not title:
     st.warning("Please fill out so required fields")
 else:
-    tokenizer = init_tokenizer()
-    model = init_model()
-
     encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
                               max_length = 512, truncation=True)
-    outputs = model(**encoded_input).pooler_output[:, 0, :]
-    predict(outputs)
+    with torch.no_grad():
+        outputs = model(**encoded_input).pooler_output[:, 0, :]
+        predict(outputs)