kaza167 commited on
Commit
8ce4a44
1 Parent(s): a2d4b5a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import BertTokenizer, BertModel
3
+ import torch
4
+
5
+ TAGS_CLASSES = ['cs.CV', 'cs.LG', 'cs.AI', 'stat.ML', 'cs.CL', 'cs.NE', 'cs.IR',
6
+ 'math.OC', 'cs.RO', 'cs.LO', 'cs.SI', 'cs.DS', 'cs.IT', 'math.IT',
7
+ 'q-bio.NC', 'stat.ME', 'cs.HC', 'cs.CR', 'cs.DC', 'cs.SD', 'cs.CY',
8
+ 'stat.AP', 'cs.MM', 'math.ST', 'stat.TH', 'cs.DB', 'cs.GT', 'I.2.7',
9
+ 'physics.soc-ph', 'cs.CE', 'cs.SY', 'cs.MA', 'stat.CO', 'cs.NA',
10
+ 'q-bio.QM', 'cs.GR', 'cs.CC', 'physics.data-an', 'cs.SE', 'math.NA',
11
+ 'math.PR', 'quant-ph', 'cs.DL', 'cs.NI', 'I.2.6', 'cs.PL',
12
+ 'cond-mat.dis-nn', 'nlin.AO', 'cmp-lg', 'cs.DM', 'Other']
13
+
14
+ class BERTClf(torch.nn.Module):
15
+ def __init__(self):
16
+ super(BERTClf, self).__init__()
17
+ self.bert_model = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
18
+ self.dropout = torch.nn.Dropout(0.1)
19
+ self.linear = torch.nn.Linear(768, len(TAGS_CLASSES))
20
+ self.sigm = nn.Sigmoid()
21
+
22
+
23
+ def forward(self, input_ids, attn_mask, token_type_ids):
24
+ output = self.bert_model(
25
+ input_ids,
26
+ attention_mask=attn_mask,
27
+ token_type_ids=token_type_ids
28
+ )
29
+ output_dropout = self.dropout(output.pooler_output)
30
+ output = self.sigm(self.linear(output_dropout))
31
+ return output
32
+
33
+
34
+ MAX_LEN = 128
35
+
36
+
37
+ st.markdown("# Paper classification")
38
+ st.markdown("### Title of paper")
39
+ # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
40
+
41
+ title = st.text_area("TEXT HERE")
42
+ # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
43
+ st.markdown("### Summary of paper")
44
+ summary = st.text_area("TEXT HERE", key = "last_name")
45
+
46
+
47
+ text = 'Title: ' + title + '\nSummary: ' + summary
48
+
49
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
50
+
51
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
52
+ model = torch.load('model_5_eps', map_location=device)
53
+
54
+
55
+ encodings = tokenizer.encode_plus(
56
+ text,
57
+ None,
58
+ add_special_tokens=True,
59
+ max_length=MAX_LEN,
60
+ padding='max_length',
61
+ return_token_type_ids=True,
62
+ truncation=True,
63
+ return_attention_mask=True,
64
+ return_tensors='pt'
65
+ )
66
+ model.eval()
67
+ with torch.no_grad():
68
+ input_ids = encodings['input_ids'].to(device, dtype=torch.long)
69
+ attention_mask = encodings['attention_mask'].to(device, dtype=torch.long)
70
+ token_type_ids = encodings['token_type_ids'].to(device, dtype=torch.long)
71
+ output = model(input_ids, attention_mask, token_type_ids)
72
+ final_output = output.cpu().detach().numpy().tolist()
73
+ pred = ([(k,v) for k, v in sorted(zip(TAGS_CLASSES, final_output[0]), key=lambda item: -item[1])])# тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
74
+ probs = 0
75
+ ans = []
76
+ for k, v in pred:
77
+ if probs > 0.95:
78
+ break
79
+ probs += v
80
+ ans.append(k)
81
+
82
+ st.markdown(f"{', '.join(ans)}")