ki33elev commited on
Commit
2c5279b
1 Parent(s): 0d1beec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -3
app.py CHANGED
@@ -1,14 +1,53 @@
1
  import streamlit as st
 
 
2
 
3
  @st.cache(suppress_st_warning=True)
4
- def load_data():
5
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  st.title("Arxiv articles classification")
8
  st.markdown("This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")
9
 
 
 
10
  title = st.text_area(label='Title', height=100)
11
  summary = st.text_area(label='Summary (optional)', height=250)
12
 
13
- text = title + "\n" + summary
 
14
  st.markdown(text)
 
1
import numpy as np  # used by predict() (np.argsort / np.flip); was missing entirely
import streamlit as st
import torch  # BUG FIX: the package is 'torch' — 'import pytorch' raises ModuleNotFoundError
import transformers
 
5
# allow_output_mutation=True: a torch model is mutable, and without it
# st.cache re-hashes the returned model on every rerun (slow, and warns).
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(weights_path='model_weights.pt'):
    """Load the DistilBERT tokenizer and the fine-tuned 8-class classifier.

    Parameters
    ----------
    weights_path : str
        Path to the fine-tuned state dict (defaults to the original
        hard-coded 'model_weights.pt', so existing callers are unaffected).

    Returns
    -------
    (tokenizer, model) : the model is in eval mode, mapped to CPU.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
    # map_location='cpu' so the app runs on hosts without a GPU.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model.eval()
    return tokenizer, model
14
+
15
# BUG FIX: dropped @st.cache — st.cache tries to hash every argument, and a
# torch model is not hashable, so the original decorator raised at runtime.
def predict(title, summary, tokenizer, model):
    """Classify an article from its title and (optional) summary.

    Returns a pair (prediction, prediction_probs): the most likely topic
    labels, in decreasing probability, up to 0.95 cumulative probability,
    and their individual probabilities.
    """
    text = title + "\n" + summary
    tokens = tokenizer.encode(text)
    # BUG FIX: 'device' was undefined (NameError); take it from the model,
    # which load_model() leaves on CPU.
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens], device=device))[0]
        probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()

    # Class indices sorted by decreasing probability.
    classes = np.flip(np.argsort(probs))

    # NOTE(review): 'label_to_theme' is not defined anywhere in this file —
    # presumably an index -> topic-name dict that was meant to exist at module
    # level. Fall back to the raw class index so the app does not NameError.
    try:
        mapping = label_to_theme
    except NameError:
        mapping = {}

    sum_probs = 0
    ind = 0
    prediction = []
    prediction_probs = []
    # Collect labels until 0.95 cumulative probability is covered.
    # BUG FIX: also stop when classes are exhausted (avoids IndexError if
    # the threshold is never reached due to float rounding).
    while sum_probs < 0.95 and ind < len(classes):
        cls = classes[ind]
        prediction.append(mapping.get(cls, str(cls)))
        prediction_probs.append(probs[cls])
        sum_probs += probs[cls]
        ind += 1

    return prediction, prediction_probs
35
+
36
@st.cache(suppress_st_warning=True)
def get_results(prediction, prediction_probs):
    """Format predicted topics and their confidences as a tab-separated text block.

    Parameters
    ----------
    prediction : list of topic labels.
    prediction_probs : matching list of probabilities.

    Returns
    -------
    str : header line followed by one 'topic\\t\\tprob' line per prediction.
    """
    # BUG FIX: the original used 'class' as the loop variable — a reserved
    # keyword in Python, which made this function (and the file) a SyntaxError.
    ans = "Topic:\t\tConfidence:\n"
    for topic, prob in zip(prediction, prediction_probs):
        ans += topic + "\t\t" + str(prob) + "\n"
    return ans
42
 
43
st.title("Arxiv articles classification")
st.markdown("This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")

# Loaded once per session via st.cache.
tokenizer, model = load_model()

title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)

# Only classify once the user has typed at least a title; the original ran
# predict() unconditionally on empty inputs on every rerun.
if title:
    prediction, prediction_probs = predict(title, summary, tokenizer, model)
    ans = get_results(prediction, prediction_probs)
    # BUG FIX: the original displayed 'text', which is undefined at module
    # scope (it is a local of predict()); the formatted result is 'ans'.
    st.markdown(ans)