kotstantinovskii commited on
Commit
c3da1c2
1 Parent(s): 5d645ca

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from torch.nn import Softmax
4
+
5
+ from model import ArxivModel, load_model
6
+ from tokenizer import get_tokenizer
7
+
8
+ from lables import num_to_classes, taxonomy
9
+
10
+ from parser import get_text_title
11
+
12
+ model = load_model()
13
+ tokenizer = get_tokenizer()
14
+
15
+ arxiv_model = ArxivModel(model, tokenizer)
16
+ softmax = Softmax(dim=1)
17
+
18
+ st.markdown("### Classification of article topics")
19
+
20
+ col1, col2 = st.columns(2)
21
+
22
+ text = ""
23
+ with col1:
24
+ title_text = st.text_area("Write title of article", key='arxiv_title_input')
25
+
26
+ with col2:
27
+ summary_text = st.text_area("Write summary of article (optional)", key='arxiv_sum_input')
28
+ click_button_text = st.button('Submit title and summary', key=1)
29
+
30
+ if click_button_text and summary_text.strip() != "":
31
+ text = title_text.strip() + '\t' + summary_text.strip()
32
+ else:
33
+ text = title_text.strip()
34
+ text = text.strip()
35
+
36
+ id_url = st.text_input("Write article's url or id", key='arxiv_id_input').strip()
37
+ click_button_url = st.button('Submit id', key=1)
38
+
39
+ if click_button_url and id_url != "":
40
+ res = get_text_title(id_url)
41
+ if res is not None:
42
+ text = res[0].strip() + '\t' + res[1].strip()
43
+ text = text.strip()
44
+ else:
45
+ st.markdown(f'<p style="color:#FF2D00;font-size:18px">Incorrect url or id</p>', unsafe_allow_html=True)
46
+ text = ""
47
+
48
+ print(text)
49
+
50
+ if text != "":
51
+ idxs = arxiv_model.get_idx_class(text, thr=0.95)[:10]
52
+
53
+ for idx, prob in idxs:
54
+ if taxonomy.get(num_to_classes[idx], -1) != -1:
55
+ st.markdown("{} \t {}%".format(taxonomy.get(num_to_classes[idx], -1), round(prob * 100, 1)))
56
+ else:
57
+ st.markdown("")