Darkhan committed on
Commit
6f6d6c2
1 Parent(s): 14893b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -2
app.py CHANGED
@@ -1,4 +1,47 @@
1
  import streamlit as st
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
# Checkpoint name handed to the Hugging Face `from_pretrained` loaders.
MODEL_NAME = 'bert-base-cased'
# NOTE(review): presumably the location of fine-tuned weights; no active code
# visible in this file reads it — confirm before removing.
MODEL_PATH = 'bert_model'

# Category names in class-index order: the position in this tuple is the
# class id produced by the classifier.
_CATEGORY_NAMES = (
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
)

# Maps class index -> human-readable category name.
ID2CLS = dict(enumerate(_CATEGORY_NAMES))
18
+
19
+
20
def classify(text, tokenizer, model, *, threshold=0.9, labels=None):
    """Return the most probable categories for *text* as display strings.

    Categories are reported from most to least probable, and the list stops
    as soon as their cumulative probability reaches ``threshold``. The best
    guess is therefore always included, even when it alone exceeds the
    threshold.

    Args:
        text: Text to classify. A falsy value short-circuits the model call.
        tokenizer: Callable invoked as ``tokenizer([text], truncation=True,
            padding=True, max_length=256, return_tensors="pt")``; its result
            is indexed with ``'input_ids'``.
        model: Callable invoked with those ids; its result must expose a
            ``logits`` tensor with one row per input text.
        threshold: Cumulative probability at which the listing stops.
        labels: Mapping from class index to category name. Defaults to the
            module-level ``ID2CLS`` table.

    Returns:
        A list of ``"<category>: <percent>%"`` strings, or ``[""]`` when
        *text* is empty.
    """
    if not text:
        return [""]
    if labels is None:
        labels = ID2CLS

    input_ids = tokenizer(
        [text], truncation=True, padding=True, max_length=256, return_tensors="pt"
    )['input_ids']

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids).logits
    probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]

    predictions = []
    covered = 0.0
    # Walk classes from most to least probable. Each class is added *before*
    # the threshold check so the class that crosses the threshold is kept,
    # and the loop stops there instead of falling through to unlikely ones.
    for idx in probabilities.argsort()[::-1]:
        share = probabilities[idx]
        predictions.append(f'{labels[idx]}: {round(share * 100, 2)}%')
        covered += share
        if covered >= threshold:
            break

    return predictions
34
+
35
+
36
# Load the tokenizer and a sequence-classification model sized to the label
# table, so the number of output classes cannot drift from ID2CLS.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=len(ID2CLS)
)
# NOTE(review): loading of fine-tuned weights is disabled. The original
# commented-out line referred to an undefined `model_path`; the constant is
# MODEL_PATH. Presumably the classification head is untrained until this is
# restored — confirm and re-enable once the weights file is available.
# model.load_state_dict(torch.load(MODEL_PATH))
model.eval()

st.markdown("## Article classifier")

title = st.text_area("title")
text = st.text_area("article")

# Join the two fields with a space so the last word of the title is not
# fused with the first word of the article; empty fields are skipped.
document = " ".join(part for part in (title, text) if part)

for prediction in classify(document, tokenizer, model):
    st.markdown(prediction)