istassiy commited on
Commit
5409e3f
·
1 Parent(s): e1fd88d

commit from

Browse files
Files changed (1) hide show
  1. app.py +87 -7
app.py CHANGED
@@ -53,7 +53,7 @@ def sigmoid(x):
53
  return 1/(1 + np.exp(-x))
54
 
55
  def get_top_predictions(predictions):
56
- probs = (sigmoid(predictions) > 0).astype(int)
57
  probs = probs / np.sum(probs)
58
 
59
  res = {}
@@ -81,16 +81,96 @@ if not paper_title and not paper_summary:
81
  st.markdown(f"Must have non-empty title or summary")
82
  else:
83
  with torch.no_grad():
84
- st.markdown(f"{model}", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  raw_predictions = model(
87
  **tokenizer(
88
  [paper_title + "." + paper_summary],
89
  padding=True, truncation=True, return_tensors="pt"
90
  )
91
- )
92
- dir_pred = dir(raw_predictions)
93
- st.markdown(f"{dir_pred}")
94
- st.markdown(f"{raw_predictions}")
95
  results = get_top_predictions(raw_predictions[0][0].numpy())
96
- st.markdown(f"{results}")
 
 
 
53
  return 1/(1 + np.exp(-x))
54
 
55
  def get_top_predictions(predictions):
56
+ probs = (sigmoid(predictions) > 0.5).astype(float)
57
  probs = probs / np.sum(probs)
58
 
59
  res = {}
 
81
  st.markdown(f"Must have non-empty title or summary")
82
  else:
83
  with torch.no_grad():
84
+ import streamlit as st
85
+ import numpy as np
86
+ import torch
87
+ from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification
88
+
89
+ my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"
90
+
91
+ arxiv_code_to_topic = {
92
+ 'cs' : 'computer science',
93
+
94
+ 'q-bio' : 'biology',
95
 
96
+ 'q-fin' : 'finance',
97
+
98
+ 'astro-ph' : 'physics',
99
+ 'cond-mat' : 'physics',
100
+ 'gr-qc' : 'physics',
101
+ 'hep-ex' : 'physics',
102
+ 'hep-lat' : 'physics',
103
+ 'hep-ph' : 'physics',
104
+ 'hep-th' : 'physics',
105
+ 'math-ph' : 'physics',
106
+ 'nlin' : 'physics',
107
+ 'nucl-ex' : 'physics',
108
+ 'nucl-th' : 'physics',
109
+ 'quant-ph' : 'physics',
110
+ 'physics' : 'physics',
111
+
112
+ 'eess' : 'electrical engineering',
113
+
114
+ 'econ' : 'economics',
115
+
116
+ 'math' : 'mathematics',
117
+
118
+ 'stat' : 'statistics',
119
+ }
120
+
121
+ sorted_arxiv_topics = sorted(set(arxiv_code_to_topic.values()))
122
+
123
+ NUM_LABELS = len(sorted_arxiv_topics)
124
+
125
+ @st.cache(allow_output_mutation=True)
126
+ def load_tokenizer():
127
+ tokenizer = AutoTokenizer.from_pretrained(my_model_name)
128
+ return tokenizer
129
+
130
+ @st.cache(allow_output_mutation=True)
131
+ def load_model():
132
+ model = DistilBertForSequenceClassification.from_pretrained(my_model_name)
133
+ return model
134
+
135
+ def sigmoid(x):
136
+ return 1/(1 + np.exp(-x))
137
+
138
+ def get_top_predictions(predictions):
139
+ probs = (sigmoid(predictions) > 0.5).astype(float)
140
+ probs = probs / np.sum(probs)
141
+
142
+ res = {}
143
+ total_prob = 0
144
+ for topic, prob in zip(sorted_arxiv_topics, probs):
145
+ total_prob += prob
146
+ res[topic] = prob
147
+ if total_prob > 0.95:
148
+ break
149
+ return res
150
+
151
+ tokenizer = load_tokenizer()
152
+ model = load_model()
153
+
154
+ st.markdown("# Scientific paper classificator")
155
+ st.markdown(
156
+ "Fill in paper summary and / or title below:",
157
+ unsafe_allow_html=False
158
+ )
159
+
160
+ paper_title = st.text_area("Paper title")
161
+ paper_summary = st.text_area("Paper abstract")
162
+
163
+ if not paper_title and not paper_summary:
164
+ st.markdown(f"Must have non-empty title or summary")
165
+ else:
166
+ with torch.no_grad():
167
  raw_predictions = model(
168
  **tokenizer(
169
  [paper_title + "." + paper_summary],
170
  padding=True, truncation=True, return_tensors="pt"
171
  )
172
+ )
 
 
 
173
  results = get_top_predictions(raw_predictions[0][0].numpy())
174
+ st.markdown("The following are probabilities:")
175
+ for topic, prob in sorted(results.items(), lambda item: item[1], reverse=True):
176
+ st.markdown(f"{topic}: {prob}")