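"""Helpers for a Streamlit app that classifies a paper's title/abstract into
arXiv-style subject areas with a fine-tuned DistilBERT classifier."""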
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
import streamlit as st


def get_text(title: str, abstract: str):
    """Concatenate the abstract and title into one input string.

    Returns None when both fields are empty.
    """
    if abstract and title:
        text = abstract + ' ' + title
    elif title:
        text = title
    elif abstract:
        text = abstract
    else:
        text = None
    return text

def get_labels(text, model, tokenizer, threshold=0.95):
    """Return the smallest set of labels whose cumulative probability exceeds `threshold`."""
    tokens = tokenizer(text, return_tensors='pt', truncation=True)  # truncate to the model's max length
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**tokens)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].numpy()

    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']

    # Sort labels by descending probability and keep them until their
    # cumulative probability passes the threshold.
    result_labels = []
    cumsum = 0.0
    for prob, label in sorted(zip(probs, labels), key=lambda x: -x[0]):
        result_labels.append(label)
        cumsum += prob
        if cumsum > threshold:
            break
    return result_labels
 
@st.cache(allow_output_mutation=True)
def load_model():
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=8)
    # Load the fine-tuned weights; map to CPU so the app also runs without a GPU.
    model.load_state_dict(torch.load('weight_model', map_location=torch.device('cpu')))
    model.eval()  # disable dropout for inference
    return model, tokenizer
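

# --- Usage sketch ---
# A minimal example of how these helpers could be wired into a Streamlit page.
# The widget labels below are illustrative assumptions, not part of the original app.
model, tokenizer = load_model()

st.title('Paper subject classifier')
title = st.text_input('Title')
abstract = st.text_area('Abstract')

text = get_text(title, abstract)
if text is None:
    st.warning('Enter a title and/or an abstract.')
else:
    labels = get_labels(text, model, tokenizer)
    st.write('Most likely subject areas (covering ~95% of the probability mass):')
    for label in labels:
        st.write('- ' + label.replace('_', ' '))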