Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch
|
2 |
import tokenizers
|
3 |
|
|
|
4 |
import streamlit as st
|
5 |
import torch.nn as nn
|
6 |
from transformers import RobertaTokenizer, RobertaModel
|
@@ -29,19 +30,26 @@ cats = ['Computer Science', 'Economics', 'Electrical Engineering',
|
|
29 |
def predict(outputs):
|
30 |
top = 0
|
31 |
probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
|
|
|
|
|
|
|
32 |
|
33 |
for prob, cat in sorted(zip(probs, cats), reverse=True):
|
34 |
if top < 95:
|
35 |
percent = prob * 100
|
36 |
top += percent
|
37 |
-
|
|
|
|
|
|
|
|
|
38 |
|
39 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
40 |
model = init_model()
|
41 |
|
42 |
st.markdown("### Title")
|
43 |
|
44 |
-
title = st.text_area("Enter title", height=20)
|
45 |
|
46 |
st.markdown("### Abstract")
|
47 |
|
@@ -50,6 +58,7 @@ abstract = st.text_area("Enter abstract", height=200)
|
|
50 |
if not title:
|
51 |
st.warning("Please fill out so required fields")
|
52 |
else:
|
|
|
53 |
encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
|
54 |
max_length = 512, truncation=True)
|
55 |
with torch.no_grad():
|
|
|
1 |
import torch
|
2 |
import tokenizers
|
3 |
|
4 |
+
import pandas as pd
|
5 |
import streamlit as st
|
6 |
import torch.nn as nn
|
7 |
from transformers import RobertaTokenizer, RobertaModel
|
|
|
30 |
def predict(outputs):
|
31 |
top = 0
|
32 |
probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
|
33 |
+
|
34 |
+
top_cats = []
|
35 |
+
top_probs = []
|
36 |
|
37 |
for prob, cat in sorted(zip(probs, cats), reverse=True):
|
38 |
if top < 95:
|
39 |
percent = prob * 100
|
40 |
top += percent
|
41 |
+
top_cats.append(cat)
|
42 |
+
top_probs.append(prob)
|
43 |
+
|
44 |
+
chart_data = pd.DataFrame(top_probs, columns=top_cats)
|
45 |
+
st.bar_chart(chart_data)
|
46 |
|
47 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
48 |
model = init_model()
|
49 |
|
50 |
st.markdown("### Title")
|
51 |
|
52 |
+
title = st.text_area("* Enter title (required)", height=20)
|
53 |
|
54 |
st.markdown("### Abstract")
|
55 |
|
|
|
58 |
if not title:
|
59 |
st.warning("Please fill out so required fields")
|
60 |
else:
|
61 |
+
st.markdown("### Result")
|
62 |
encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
|
63 |
max_length = 512, truncation=True)
|
64 |
with torch.no_grad():
|