seal345 commited on
Commit
ff5406d
1 Parent(s): 6fe3d5e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import transformers
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from scipy.special import softmax
8
+
9
+
10
+
11
+ def load_model():
12
+ model = AutoModelForSequenceClassification.from_pretrained('model_distilbert_trained')
13
+ tokenizer = AutoTokenizer.from_pretrained(
14
+ 'distilbert-base-cased', do_lower_case=True)
15
+ model.eval()
16
+ return model, tokenizer
17
+
18
+
19
+ def get_predictions(logits, indexes):
20
+ sum = 0
21
+ ind = []
22
+ probs = []
23
+ for i in indexes:
24
+ sum += logits[i]
25
+ ind.append(i)
26
+ probs.append(indexes[i])
27
+ if sum >= 0.95:
28
+ return ind, probs
29
+
30
+
31
+ def return_pred_name(name_dict, ind):
32
+ out = []
33
+ for i in ind:
34
+ out.append(name_dict[i])
35
+ return out
36
+
37
+
38
+ def predict(title, summary, model, tokenizer):
39
+ text = title + '.' + summary
40
+ tokens = tokenizer.encode(text)
41
+ with torch.no_grad():
42
+ logits = model(torch.as_tensor([tokens]))[0]
43
+ probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()
44
+
45
+ classes = np.flip(np.argsort(probs))
46
+ sum_probs = 0
47
+ ind = 0
48
+ prediction = []
49
+ prediction_probs = []
50
+ while sum_probs < 0.95:
51
+ prediction.append(name_dict[classes[ind]])
52
+ prediction_probs.append(str("{:.2f}".format(100 * probs[classes[ind]])) + "%")
53
+ sum_probs += probs[classes[ind]]
54
+ ind += 1
55
+
56
+ return prediction, prediction_probs
57
+
58
+
59
+ def get_results(prediction, prediction_probs):
60
+ frame = pd.DataFrame({'Category': prediction, 'Confidence': prediction_probs})
61
+ frame.index = np.arange(1, len(frame) + 1)
62
+ return frame
63
+
64
+ name_dict = {4: 'cs',
65
+ 19: 'stat',
66
+ 1: 'astro-ph',
67
+ 16: 'q-bio',
68
+ 6: 'eess',
69
+ 3: 'cond-mat',
70
+ 12: 'math',
71
+ 15: 'physics',
72
+ 18: 'quant-ph',
73
+ 17: 'q-fin',
74
+ 7: 'gr-qc',
75
+ 13: 'nlin',
76
+ 2: 'cmp-lg',
77
+ 5: 'econ',
78
+ 8: 'hep-ex',
79
+ 11: 'hep-th',
80
+ 14: 'nucl-th',
81
+ 10: 'hep-ph',
82
+ 9: 'hep-lat',
83
+ 0: 'adap-org'}
84
+
85
+
86
+
87
+ st.title("Find out the topic of the article without reading!")
88
+ st.markdown("<h1 style='text-align: center;'><img width=320px src = 'https://upload.wikimedia.org/wikipedia/ru/8/81/Sheldon_cooper.jpg'>",
89
+ unsafe_allow_html=True)
90
+ # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
91
+
92
+ title = st.text_area(label='Title',
93
+ value='',
94
+ height=30,
95
+ help='If you know a title type it here')
96
+
97
+
98
+ summary = st.text_area(label='Summary',
99
+ value='',
100
+ height=200,
101
+ help='If you have a summary enter it here')
102
+
103
+
104
+ button = st.button(label='Get the theme!')
105
+
106
+ if button:
107
+ if (title == '' and summary == ''):
108
+ st.write('There is nothing to analyze...')
109
+ st.write('Fill at list one of the fields')
110
+ else:
111
+ model, tokenizer = load_model()
112
+ prediction, prediction_probs = predict(title, summary, model, tokenizer)
113
+ ans = get_results(prediction, prediction_probs)
114
+ st.write('Result')
115
+ st.write(ans)
116
+
117
+
118
+ # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
119
+
120
+ #from transformers import pipeline
121
+ #pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
122
+ #raw_predictions = pipe(text)
123
+ # тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
124
+
125
+ #st.markdown(f"{raw_predictions}")
126
+ # выводим результаты модели в текстовое поле, на потеху пользователю