Spaces:
Runtime error
Runtime error
update
Browse files
app.py
CHANGED
@@ -29,40 +29,35 @@ def make_prediction(text):
|
|
29 |
st.markdown("### Category prediction:")
|
30 |
print_probs(pred_logits[0])
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
map_location=torch.device("cpu")
|
44 |
-
)
|
45 |
-
)
|
46 |
-
return tokenizer, model
|
47 |
|
48 |
# MAIN
|
49 |
from PIL import Image
|
50 |
image = Image.open('logo.png')
|
51 |
|
52 |
st.image(image)
|
53 |
-
|
54 |
-
# st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg' class='center'>", unsafe_allow_html=True)
|
55 |
-
# st.markdown("# Arxiv.org category classifier")
|
56 |
st.markdown("# ")
|
57 |
-
|
58 |
st.markdown("### Article Title")
|
|
|
59 |
text1 = st.text_area("Введите название научной статьи для классификации", height=20)
|
60 |
|
61 |
st.markdown("### Article Abstract")
|
62 |
|
63 |
text2 = st.text_area("Введите описание статьи", height=200)
|
|
|
64 |
common_text = text1 + text2
|
|
|
65 |
if common_text != "":
|
66 |
-
tokenizer, model = model_init()
|
67 |
make_prediction(common_text)
|
68 |
|
|
29 |
st.markdown("### Category prediction:")
|
30 |
print_probs(pred_logits[0])
|
31 |
|
32 |
+
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
33 |
+
|
34 |
+
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
|
35 |
+
model_name = "trained_model2"
|
36 |
+
model_path = model_name + '.zip'
|
37 |
+
model.load_state_dict(
|
38 |
+
torch.load(
|
39 |
+
model_path,
|
40 |
+
map_location=torch.device("cpu")
|
41 |
+
)
|
42 |
+
)
|
|
|
|
|
|
|
|
|
43 |
|
44 |
# MAIN
|
45 |
from PIL import Image
|
46 |
image = Image.open('logo.png')
|
47 |
|
48 |
st.image(image)
|
|
|
|
|
|
|
49 |
st.markdown("# ")
|
|
|
50 |
st.markdown("### Article Title")
|
51 |
+
|
52 |
text1 = st.text_area("Введите название научной статьи для классификации", height=20)
|
53 |
|
54 |
st.markdown("### Article Abstract")
|
55 |
|
56 |
text2 = st.text_area("Введите описание статьи", height=200)
|
57 |
+
|
58 |
common_text = text1 + text2
|
59 |
+
|
60 |
if common_text != "":
|
|
|
61 |
make_prediction(common_text)
|
62 |
|
63 |
+
|