Spaces: Runtime error
app.py CHANGED
@@ -53,7 +53,7 @@ def sigmoid(x):
     return 1/(1 + np.exp(-x))

 def get_top_predictions(predictions):
-    probs = (sigmoid(predictions) > 0).astype(float)
+    probs = (sigmoid(predictions) > 0.5).astype(float)
     probs = probs / np.sum(probs)

     res = {}
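Since sigmoid maps every input into the open interval (0, 1), the old comparison sigmoid(predictions) > 0 was true for every class, so every label was marked positive before normalization; thresholding at 0.5 keeps only the classes whose raw logit is positive. A minimal sketch of the difference, using made-up logits rather than real model output:

    import numpy as np

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # made-up logits, for illustration only
    logits = np.array([2.3, -0.7, 0.1, -1.5])

    old_mask = (sigmoid(logits) > 0).astype(float)    # [1., 1., 1., 1.] -- sigmoid is always positive
    new_mask = (sigmoid(logits) > 0.5).astype(float)  # [1., 0., 1., 0.] -- same as logits > 0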
@@ -81,16 +81,96 @@ if not paper_title and not paper_summary:
     st.markdown(f"Must have non-empty title or summary")
 else:
     with torch.no_grad():
-
+import streamlit as st
+import numpy as np
+import torch
+from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification
+
+my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"
+
+arxiv_code_to_topic = {
+    'cs' : 'computer science',
+
+    'q-bio' : 'biology',

+    'q-fin' : 'finance',
+
+    'astro-ph' : 'physics',
+    'cond-mat' : 'physics',
+    'gr-qc' : 'physics',
+    'hep-ex' : 'physics',
+    'hep-lat' : 'physics',
+    'hep-ph' : 'physics',
+    'hep-th' : 'physics',
+    'math-ph' : 'physics',
+    'nlin' : 'physics',
+    'nucl-ex' : 'physics',
+    'nucl-th' : 'physics',
+    'quant-ph' : 'physics',
+    'physics' : 'physics',
+
+    'eess' : 'electrical engineering',
+
+    'econ' : 'economics',
+
+    'math' : 'mathematics',
+
+    'stat' : 'statistics',
+}
+
+sorted_arxiv_topics = sorted(set(arxiv_code_to_topic.values()))
+
+NUM_LABELS = len(sorted_arxiv_topics)
+
+@st.cache(allow_output_mutation=True)
+def load_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(my_model_name)
+    return tokenizer
+
+@st.cache(allow_output_mutation=True)
+def load_model():
+    model = DistilBertForSequenceClassification.from_pretrained(my_model_name)
+    return model
+
+def sigmoid(x):
+    return 1/(1 + np.exp(-x))
+
+def get_top_predictions(predictions):
+    probs = (sigmoid(predictions) > 0.5).astype(float)
+    probs = probs / np.sum(probs)
+
+    res = {}
+    total_prob = 0
+    for topic, prob in zip(sorted_arxiv_topics, probs):
+        total_prob += prob
+        res[topic] = prob
+        if total_prob > 0.95:
+            break
+    return res
+
+tokenizer = load_tokenizer()
+model = load_model()
+
+st.markdown("# Scientific paper classificator")
+st.markdown(
+    "Fill in paper summary and / or title below:",
+    unsafe_allow_html=False
+)
+
+paper_title = st.text_area("Paper title")
+paper_summary = st.text_area("Paper abstract")
+
+if not paper_title and not paper_summary:
+    st.markdown(f"Must have non-empty title or summary")
+else:
+    with torch.no_grad():
         raw_predictions = model(
             **tokenizer(
                 [paper_title + "." + paper_summary],
                 padding=True, truncation=True, return_tensors="pt"
             )
-        )
-        dir_pred = dir(raw_predictions)
-        st.markdown(f"{dir_pred}")
-        st.markdown(f"{raw_predictions}")
+        )
         results = get_top_predictions(raw_predictions[0][0].numpy())
-        st.markdown(
+        st.markdown("The following are probabilities:")
+        for topic, prob in sorted(results.items(), lambda item: item[1], reverse=True):
+            st.markdown(f"{topic}: {prob}")
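With the 0.5 threshold, get_top_predictions builds a uniform distribution over the topics whose logit is positive, then walks the alphabetically sorted topic list and collects entries until 95% of that mass is covered. A minimal self-contained sketch with made-up logits (not real model output); note that it passes key= to the final sorted call, since in Python 3 sorted() accepts key only as a keyword argument, so the committed form sorted(results.items(), lambda item: item[1], reverse=True) would raise a TypeError when the results are displayed:

    import numpy as np

    # topic list as built in app.py: sorted unique values of arxiv_code_to_topic
    sorted_arxiv_topics = ['biology', 'computer science', 'economics', 'electrical engineering',
                           'finance', 'mathematics', 'physics', 'statistics']

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def get_top_predictions(predictions):
        probs = (sigmoid(predictions) > 0.5).astype(float)
        probs = probs / np.sum(probs)

        res = {}
        total_prob = 0
        for topic, prob in zip(sorted_arxiv_topics, probs):
            total_prob += prob
            res[topic] = prob
            if total_prob > 0.95:
                break
        return res

    # made-up logits: only 'biology' and 'computer science' are above 0
    logits = np.array([2.0, 3.1, -1.0, -0.5, -3.0, -1.2, -0.8, -2.4])
    results = get_top_predictions(logits)   # {'biology': 0.5, 'computer science': 0.5}

    # display loop: key must be passed as a keyword argument in Python 3
    for topic, prob in sorted(results.items(), key=lambda item: item[1], reverse=True):
        print(f"{topic}: {prob}")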