fzmushko committed
Commit b27d10d
1 Parent(s): 7f34e4d

Upload app.py

Files changed (1)
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
+ import streamlit as st
+ import torch
+ import numpy as np
+
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ st.markdown("### A demo app for classifying article topics by title and abstract.")
+ st.markdown("It can predict the following topics: Computer Science, Economics, Electrical Engineering and Systems Science, Mathematics, Quantitative Biology, Quantitative Finance, Statistics, Physics.")
+
+ @st.cache(suppress_st_warning=True)
+ def model_tokenizer():
+     # Build a DistilBERT classification head with 8 labels and load the fine-tuned weights.
+     model_name = 'distilbert-base-cased'
+     model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8, problem_type="multi_label_classification")
+     weights = torch.load('model.pt', map_location=torch.device('cpu'))
+     model.load_state_dict(weights)
+     return model
+
+ def make_prediction(model, tokenizer, text):
+     # Tokenize the input, truncating to DistilBERT's 512-token limit.
+     tokens = tokenizer.encode(text, truncation=True, max_length=512)
+     with torch.no_grad():
+         logits = model.cpu()(torch.as_tensor([tokens]))[0]
+     probs = np.array(torch.softmax(logits[-1, :], dim=-1))
+
+     # Report the most likely topics until their cumulative probability exceeds 0.95.
+     sorted_classes, sorted_probs = np.flip(np.argsort(probs)), sorted(probs, reverse=True)
+     res = []
+     probs_sum = 0
+     i = 0
+     while probs_sum <= 0.95:
+         res.append(("{:.2f}%".format(100 * sorted_probs[i]), to_category[sorted_classes[i]]))
+         probs_sum += sorted_probs[i]
+         i += 1
+     return res
+
+ model = model_tokenizer()
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
+
+ categories_full = ['Computer Science', 'Economics', 'Electrical Engineering and Systems Science', 'Mathematics', 'Quantitative Biology', 'Quantitative Finance', 'Statistics', 'Physics']
+ # Map class indices to human-readable topic names.
+ to_category = {i: category for i, category in enumerate(categories_full)}
+
+ title = st.text_area("Type the title of the article here")
+ abstract = st.text_area("Type the abstract of the article here")
+
+ if st.button('Analyse'):
+     if title or abstract:
+         text = '[TITLE] ' + title + ' [ABSTRACT] ' + abstract
+         for prob, category in make_prediction(model, tokenizer, text):
+             st.markdown(f"{prob} {category}")
+     else:
+         st.error("Please enter a title or an abstract.")
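
Note: the head above is created with problem_type="multi_label_classification", which usually corresponds to training with independent per-class sigmoid targets, while the app scores topics with a softmax over all eight classes. If the weights in model.pt were trained that way, a sigmoid-based scorer may match the training objective more closely. A minimal sketch of that alternative (hypothetical; make_multilabel_prediction and the 0.5 threshold are not part of the app), reusing model, tokenizer and to_category from app.py:

import numpy as np
import torch

# Hypothetical alternative to make_prediction: treat each topic independently
# (sigmoid per class) instead of normalising across topics with softmax.
def make_multilabel_prediction(model, tokenizer, text, threshold=0.5):
    tokens = tokenizer.encode(text, truncation=True, max_length=512)
    with torch.no_grad():
        logits = model.cpu()(torch.as_tensor([tokens]))[0]
    probs = torch.sigmoid(logits[-1, :]).numpy()
    # Keep every topic whose independent probability clears the (assumed) 0.5 cutoff,
    # listed in descending order of confidence.
    order = np.flip(np.argsort(probs))
    return [("{:.2f}%".format(100 * probs[i]), to_category[i]) for i in order if probs[i] >= threshold]

With this scoring the app would list every topic the classifier considers likely on its own, rather than stopping once 95% of the softmax mass is covered; which variant is appropriate depends on how the checkpoint was trained.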