Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from transformers import TrainingArguments, \
|
|
7 |
categories = ['Biology', 'Computer science', 'Economics', 'Electrics', 'Finance',
|
8 |
'Math', 'Physics', 'Statistics']
|
9 |
labels = [i for i in range(len(categories))]
|
|
|
10 |
def print_probs(logits):
|
11 |
probs = torch.nn.functional.softmax(logits, dim=0).numpy()*100
|
12 |
ans = list(zip(probs,labels))
|
@@ -19,7 +20,8 @@ def print_probs(logits):
|
|
19 |
st.write(text)
|
20 |
sum+=prob
|
21 |
i+=1
|
22 |
-
|
|
|
23 |
def make_prediction(text):
|
24 |
tokenized_text = tokenizer(text, return_tensors='pt')
|
25 |
with torch.no_grad():
|
@@ -54,6 +56,6 @@ text = st.text_area("Введите название статьи", height=50)
|
|
54 |
st.markdown("### Article Abstract")
|
55 |
text = st.text_area("Введите описание статьи", height=400)
|
56 |
|
57 |
-
|
58 |
make_prediction(text)
|
59 |
|
|
|
7 |
categories = ['Biology', 'Computer science', 'Economics', 'Electrics', 'Finance',
|
8 |
'Math', 'Physics', 'Statistics']
|
9 |
labels = [i for i in range(len(categories))]
|
10 |
+
|
11 |
def print_probs(logits):
|
12 |
probs = torch.nn.functional.softmax(logits, dim=0).numpy()*100
|
13 |
ans = list(zip(probs,labels))
|
|
|
20 |
st.write(text)
|
21 |
sum+=prob
|
22 |
i+=1
|
23 |
+
|
24 |
+
@st.cache
|
25 |
def make_prediction(text):
|
26 |
tokenized_text = tokenizer(text, return_tensors='pt')
|
27 |
with torch.no_grad():
|
|
|
56 |
st.markdown("### Article Abstract")
|
57 |
text = st.text_area("Введите описание статьи", height=400)
|
58 |
|
59 |
+
|
60 |
make_prediction(text)
|
61 |
|