Poe Dator committed on
Commit
b339b00
1 Parent(s): 90cb2de

inference code, init version

Browse files
Files changed (1) hide show
  1. app.py +66 -1
app.py CHANGED
@@ -1,9 +1,74 @@
1
  import streamlit as st
 
 
 
 
 
2
 
3
# Page header and banner image (raw HTML, hence unsafe_allow_html).
st.markdown("### Privet, mir!")
st.markdown(
    "<img width=200px src='https://i.pinimg.com/736x/11/33/19/113319f0ffe91f4bb0f468914b9916da.jpg'>",
    unsafe_allow_html=True,
)

# Free-text input box; its content is echoed back upper-cased.
text = st.text_area("ENTER TEXT HERE")
t2 = text.upper()
st.markdown(f"{t2}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from torch import nn
4
+ from transformers import BertModel, AutoTokenizer, AutoModel, pipeline
5
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
+ device = 'cpu'
7
 
8
# Page header and banner image (raw HTML, hence unsafe_allow_html).
st.markdown("### Privet, mir!")
st.markdown(
    "<img width=200px src='https://i.pinimg.com/736x/11/33/19/113319f0ffe91f4bb0f468914b9916da.jpg'>",
    unsafe_allow_html=True,
)

# Free-text input box; its content is echoed back upper-cased.
text = st.text_area("ENTER TEXT HERE")
t2 = text.upper()
st.markdown(f"{t2}")
14
+
15
# Lookup tables for encoding / decoding labels.
# `labels` maps a category code to an integer id. `inference` pairs the model
# outputs with these keys in insertion order, so the order must not change.
labels = {
    'cs.NE': 0,
    'cs.CL': 1,
    'cs.AI': 2,
    'stat.ML': 3,
    'cs.CV': 4,
    'cs.LG': 5,
}
# `labels_decoder` maps the same category codes to human-readable names.
labels_decoder = {
    'cs.NE': 'Neural and Evolutionary Computing',
    'cs.CL': 'Computation and Language',
    'cs.AI': 'Artificial Intelligence',
    'stat.ML': 'Machine Learning (stat)',
    'cs.CV': 'Computer Vision',
    'cs.LG': 'Machine Learning',
}

# Tokenizer for the checkpoint named below.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+
23
class BertClassifier(nn.Module):
    """BERT encoder followed by a dropout -> linear -> ReLU classification head.

    Args:
        n_classes: Number of output classes (width of the linear head).
        dropout: Dropout probability applied to BERT's pooled output.
        model_name: Checkpoint name passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, n_classes, dropout=0.5, model_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded checkpoint instead of hard-coding 768,
        # so a non-base `model_name` does not fail with a shape mismatch.
        # For 'bert-base-uncased' the hidden size is still 768, so existing
        # saved weights keep loading unchanged.
        self.linear = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Return per-class scores for the given token ids.

        Args:
            input_id: Token id tensor, forwarded to BERT as ``input_ids``.
            mask: Attention mask tensor, forwarded as ``attention_mask``.

        Returns:
            Tensor of non-negative scores (ReLU output) whose last dimension
            has ``n_classes`` entries.
        """
        # With return_dict=False BERT returns a tuple; the second element is
        # the pooled output, the first (per-token states) is not needed here.
        _, pooled_output = self.bert(
            input_ids=input_id, attention_mask=mask, return_dict=False
        )
        dropped = self.dropout(pooled_output)
        return self.relu(self.linear(dropped))
38
+
39
# Build the classifier and restore the fine-tuned weights.
model = BertClassifier(n_classes=len(labels))
# map_location remaps tensors that were saved from a GPU session onto `device`;
# without it torch.load fails on a CPU-only host when the checkpoint holds
# CUDA tensors.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load('model_weights_1.pt', map_location=device))
model.to(device)
model.eval()  # evaluation mode: disables dropout
42
+
43
def inference(txt, mode=None):
    """Score `txt` against the topic labels using the module-level model.

    Args:
        txt: Raw input text.
        mode: ``'print'`` prints every label scoring at least 1 (highest
            first) and returns ``None``; ``'debug'`` returns ``(out, res)``;
            any other value returns ``res``.

    Returns:
        ``res`` is a list of ``(label, score)`` tuples in ``labels`` order and
        ``out`` is the 1-D NumPy array holding the same scores. Scores are the
        model outputs rescaled to sum to 100; if every output is zero the
        scores are left at zero instead of becoming NaN.
    """
    # NOTE(review): newlines are removed rather than replaced by a space, so
    # words on adjacent lines get joined. Presumably this mirrors the training
    # preprocessing - confirm before changing.
    txt = txt.lower().replace('\n', '')

    encoded = tokenizer(
        txt,
        padding='max_length',
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )

    # With return_tensors="pt" both tensors already carry a batch dimension,
    # i.e. the (batch, seq_len) shape BertModel documents, so the mask is
    # passed as-is instead of being unsqueezed to 3-D.
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # No gradients are needed at inference time; skip building the graph.
    with torch.no_grad():
        out = model(input_ids, attention_mask)
    out = out.cpu().detach().numpy().reshape(-1)

    # The model ends in a ReLU, so all outputs can be exactly zero; dividing
    # by that sum would turn every score into NaN.
    total = out.sum()
    if total > 0:
        out = out / total * 100
    res = list(zip(labels, out.tolist()))

    if mode == 'print':
        ranked = sorted(res, key=lambda pair: pair[1], reverse=True)
        for lbl, score in ranked:
            if score >= 1:
                print(f"[{lbl:<7}] {labels_decoder[lbl]:<35} {score:.1f}%")
        return None

    if mode == 'debug':
        return out, res

    return res
72
+
73
# Classify only once the user has typed something: an empty box would run the
# model on padding alone and render meaningless scores.
res = inference(text, mode=None) if text.strip() else []
if res:
    st.markdown(f"{res}")