Spaces:
Runtime error
Runtime error
update
Browse files
app.py
CHANGED
@@ -1,6 +1,44 @@
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
st.markdown("##Hello, people!")
|
4 |
st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg'>", unsafe_allow_html=True)
|
5 |
-
|
6 |
-
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from transformers import TrainingArguments, \
|
5 |
+
Trainer, AutoTokenizer, DataCollatorWithPadding, \
|
6 |
+
AutoModelForSequenceClassification
|
7 |
+
def print_probs(logits):
|
8 |
+
probs = torch.nn.functional.softmax(logits, dim=0).numpy()*100
|
9 |
+
ans = list(zip(probs,labels))
|
10 |
+
ans.sort(reverse=True)
|
11 |
+
sum = 0
|
12 |
+
i = 0
|
13 |
+
while sum <= 95:
|
14 |
+
prob, idx = ans[i]
|
15 |
+
text = categories[idx] + ": "+ str(np.round(prob,1) + "%"
|
16 |
+
st.markdown(text)
|
17 |
+
sum+=prob
|
18 |
+
i+=1
|
19 |
+
|
20 |
+
def make_prediction(text):
|
21 |
+
tokenized_text = tokenizer(text, return_tensors='pt')
|
22 |
+
with torch.no_grad():
|
23 |
+
pred_logits = model(**tokenized_text).logits
|
24 |
+
st.markdown("Predictions:")
|
25 |
+
print_probs(pred_logits[0])
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
30 |
+
|
31 |
+
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
|
32 |
+
model_name = train_model2
|
33 |
+
model_path = model_name + '.zip'
|
34 |
+
model.load_state_dict(
|
35 |
+
torch.load(
|
36 |
+
model_path
|
37 |
+
)
|
38 |
+
)
|
39 |
|
40 |
st.markdown("##Hello, people!")
|
41 |
st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg'>", unsafe_allow_html=True)
|
42 |
+
text = st.text_area("Введите описание статьи")
|
43 |
+
make_prediction(text)
|
44 |
+
|