poseg committed on
Commit
9579422
1 Parent(s): d98d1eb
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+
4
+ import streamlit as st
5
+
6
+ from transformers import BertTokenizer
7
+
8
+ st.markdown("### Из какой серии статья")
9
+ # st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
10
+
11
+ # from transformers import pipeline
12
+
13
+ # pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
14
+
15
+
16
num_classes = 8


class BERTClass(torch.nn.Module):
    """BERT encoder followed by a small fully connected classification head.

    The head takes the second element of the encoder output (768 features),
    applies dropout, a hidden linear layer of ``n_hid1`` units with ReLU,
    dropout again, and a final linear layer producing ``n_out`` raw logits.
    No activation is applied to the logits here.
    """

    def __init__(self, n_hid1=1024, n_out=num_classes, bert_path='bert-base-uncased'):
        super().__init__()
        # The attribute names l1..l6 become the state-dict keys, so they must
        # stay as they are for previously saved weights to load.
        self.l1 = transformers.BertModel.from_pretrained(bert_path)
        self.l2 = torch.nn.Dropout(0.3)
        self.l3 = torch.nn.Linear(768, n_hid1)
        self.l4 = torch.nn.ReLU()
        self.l5 = torch.nn.Dropout(0.2)
        self.l6 = torch.nn.Linear(n_hid1, n_out)

    def forward(self, ids, mask, token_type_ids):
        """Return the raw logits for a batch of encoded inputs."""
        encoded = self.l1(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 of the encoder output feeds the head.
        features = self.l2(encoded[1])
        hidden = self.l4(self.l3(features))
        return self.l6(self.l5(hidden))
36
+
37
# allow_output_mutation=True stops Streamlit from hashing the returned model
# and tokenizer on every rerun to look for mutations.
@st.cache(allow_output_mutation=True)
def load_bert():
    """Build the BERT classifier, load its saved weights and its tokenizer.

    The encoder is created from the local ``bert_pretrained`` path, the full
    state dict is read from ``bert_pretrained.pt`` and the tokenizer from
    ``bert_tokenizer``.

    Returns:
        tuple: ``(model, tokenizer)``; the model is switched to eval mode.
    """
    model = BERTClass(bert_path='bert_pretrained')
    # map_location='cpu' lets a checkpoint that was saved from a GPU load on
    # a host without CUDA.
    state_dict = torch.load('bert_pretrained.pt', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained('bert_tokenizer')

    return model, tokenizer
46
+
47
+
48
def apply_bert(text, model, tokenizer):
    """Return per-class sigmoid probabilities for ``text``.

    Args:
        text: the string to classify.
        model: callable taking ``(ids, mask, token_type_ids)`` tensors and
            returning logits.
        tokenizer: object exposing ``encode_plus`` that returns
            ``input_ids``, ``attention_mask`` and ``token_type_ids``.

    Returns:
        A flattened, detached tensor of probabilities (sigmoid of the logits).
    """
    MAX_LEN = 200
    ins = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        # Explicit padding/truncation replaces the deprecated
        # pad_to_max_length flag and the implicit truncation fallback.
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )
    # Every input gets a leading batch dimension of 1; token_type_ids
    # previously lacked it and only worked through broadcasting.
    ids = torch.tensor(ins['input_ids']).unsqueeze(0)
    mask = torch.tensor(ins['attention_mask']).unsqueeze(0)
    token_type_ids = torch.tensor(ins['token_type_ids']).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = model(ids, mask, token_type_ids)
        return torch.sigmoid(out).flatten().detach()
61
+
62
+
63
class TinyBERTClass(torch.nn.Module):
    """DistilBERT encoder followed by a small fully connected classification head.

    The head reads the hidden state at sequence position 0 (768 features),
    applies dropout, a hidden linear layer of ``n_hid1`` units with ReLU,
    dropout again, and a final linear layer producing ``n_out`` raw logits.
    """

    def __init__(self, n_hid1=1024, n_out=num_classes, path='distilbert-base-uncased'):
        super().__init__()
        # The attribute names l1..l6 become the state-dict keys, so they must
        # stay as they are for previously saved weights to load.
        self.l1 = transformers.DistilBertModel.from_pretrained(path)
        self.l2 = torch.nn.Dropout(0.3)
        self.l3 = torch.nn.Linear(768, n_hid1)
        self.l4 = torch.nn.ReLU()
        self.l5 = torch.nn.Dropout(0.2)
        self.l6 = torch.nn.Linear(n_hid1, n_out)

    def forward(self, ids, mask):
        """Return the raw logits for a batch of encoded inputs."""
        encoded = self.l1(ids, attention_mask=mask)
        # Take the hidden state at position 0 of every sequence in the batch.
        first_token = encoded.last_hidden_state[:, 0, :]
        hidden = self.l4(self.l3(self.l2(first_token)))
        return self.l6(self.l5(hidden))
82
+
83
+
84
# allow_output_mutation=True stops Streamlit from hashing the returned model
# and tokenizer on every rerun to look for mutations.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_tiny_bert():
    """Build the DistilBERT classifier, load its saved weights and its tokenizer.

    The encoder is created from the local ``tiny_bert_pretrained`` path, the
    full state dict is read from ``tiny_bert.pt`` and the tokenizer from
    ``tiny_bert_tokenizer``.

    Returns:
        tuple: ``(model, tokenizer)``; the model is switched to eval mode.
    """
    model = TinyBERTClass(path='tiny_bert_pretrained')
    # map_location='cpu' lets a checkpoint that was saved from a GPU load on
    # a host without CUDA.
    state_dict = torch.load('tiny_bert.pt', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = transformers.DistilBertTokenizer.from_pretrained('tiny_bert_tokenizer')

    return model, tokenizer
93
+
94
+
95
def apply_tiny_bert(text, model, tokenizer):
    """Return per-class sigmoid probabilities for ``text``.

    Args:
        text: the string to classify.
        model: callable taking ``(input_ids, attention_mask)`` tensors and
            returning logits.
        tokenizer: callable returning a mapping with ``input_ids`` and
            ``attention_mask`` tensors when called with ``return_tensors='pt'``.

    Returns:
        A flattened, detached tensor of probabilities (sigmoid of the logits).
    """
    # truncation=True keeps over-long input within the tokenizer's maximum
    # length instead of handing the encoder a sequence it cannot process.
    encoded_input = tokenizer(text, return_tensors='pt', truncation=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = model(encoded_input['input_ids'], encoded_input['attention_mask'])
        return torch.sigmoid(out).flatten().detach()
100
+
101
+
102
+
103
# Input widgets: article title, abstract and the trigger button, in page order.
title = st.text_area("Название статьи")
if title and not title.endswith('.'):
    # A non-empty title is terminated with a period before further use.
    title = title + '.'

summary = st.text_area("Аннотация статьи")
calc_button = st.button('Угадать тематику')

# Both classifiers are obtained on every script run, before the button check.
bert_model, bert_tokenizer = load_bert()
tiny_bert, tiny_bert_tokenizer = load_tiny_bert()

# calculate ================================================================
if calc_button:
    # Echo the submitted title to stdout.
    print('title')
    print(title)
    print('=' * 80)

    if not summary:
        # Title only: use the DistilBERT-based classifier.
        out = apply_tiny_bert(title, tiny_bert, tiny_bert_tokenizer)
    else:
        # Title plus abstract: use the BERT-based classifier.
        text = title + summary
        out = apply_bert(text, bert_model, bert_tokenizer)
126
+
127
+
128
# Subject-area names, indexed by class id; the grammatical form fits the
# phrase "статья по ...".
RU_NAMES = [
    'компьютерным наукам',
    'экономике',
    'электротехнике и системотехнике',
    'математике',
    'физике',
    'количественной биологии',
    'количественным финансам',
    'статистике',
]


def get_classes(out, bandwidth=0.5):
    """Build the user-facing verdict from per-class probabilities.

    Args:
        out: 1-D tensor of probabilities, one per entry of ``RU_NAMES``.
        bandwidth: minimum probability for a class to be reported.

    Returns:
        str: one clause per class whose probability is at least
        ``bandwidth``, or a fallback message when no class qualifies. A
        closing note is appended only when the reported probabilities sum to
        strictly more than 1.
    """
    picked = [(idx, prob) for idx, prob in enumerate(out.tolist()) if prob >= bandwidth]
    if not picked:
        return 'Не похоже на что-то научное, Вы уверены что это взято из статьи?'

    (first_idx, first_prob), *others = picked
    clauses = [f'Это статья по {RU_NAMES[first_idx]} с вероятностью {first_prob:.2f}']
    clauses.extend(
        f'также она по {RU_NAMES[idx]} с вероятностью {prob:.2f}'
        for idx, prob in others
    )
    ans = ',\n'.join(clauses)

    # The note states that the sum exceeds 1, so the comparison is strict: a
    # single class at probability 1.0 must not trigger it.
    if sum(prob for _, prob in picked) > 1.0:
        ans += '.\n(Решалась задача мультиклассификации, поэтому сумма вероятностей получилась больше 1.)'

    return ans
160
+
161
# `out` is assigned only after the button has been pressed, so the verdict is
# built and rendered only in that case.
if calc_button:
    res = get_classes(out)
    st.markdown(res)