Update app.py
Browse files
app.py
CHANGED
@@ -37,14 +37,27 @@ def LoadModel():
|
|
37 |
|
38 |
model = LoadModel()
|
39 |
|
|
|
|
|
40 |
def process(title, summary):
|
41 |
text = title + summary
|
|
|
|
|
42 |
model.eval()
|
43 |
lines = [text]
|
44 |
X = tokenizer(lines, padding=True, truncation=True, return_tensors="pt")
|
45 |
out = model(X)
|
46 |
probs = torch.exp(out[0])
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
title = st.text_area("Title", height=30)
|
50 |
|
|
|
37 |
|
38 |
model = LoadModel()
|
39 |
|
40 |
+
classes = ['Computer Science', 'Mathematics', 'Physics', 'Quantitative Biology', 'Statistics']
|
41 |
+
|
42 |
def process(title, summary):
|
43 |
text = title + summary
|
44 |
+
if not text.strip():
|
45 |
+
return ''
|
46 |
model.eval()
|
47 |
lines = [text]
|
48 |
X = tokenizer(lines, padding=True, truncation=True, return_tensors="pt")
|
49 |
out = model(X)
|
50 |
probs = torch.exp(out[0])
|
51 |
+
sorted_indexes = torch.argsort(probs, descending=True)
|
52 |
+
probs_sum = idx = 0
|
53 |
+
str = ''
|
54 |
+
while probs_sum < 0.95:
|
55 |
+
prob_idx = sorted_indexes[idx]
|
56 |
+
prob = probs[prob_idx]
|
57 |
+
str += f'{classes[prob_idx]}: {prob:.3f}\n'
|
58 |
+
idx += 1
|
59 |
+
probs_sum += prob
|
60 |
+
return str
|
61 |
|
62 |
title = st.text_area("Title", height=30)
|
63 |
|