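# NOTE: the source was captured together with what appears to be a hosting-page
# status banner (not part of the program): "Spaces: Runtime error".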
import streamlit as st | |
import torch | |
import numpy as np | |
from transformers import TrainingArguments, \ | |
Trainer, AutoTokenizer, DataCollatorWithPadding, \ | |
AutoModelForSequenceClassification | |
# Human-readable category names. The position of a name in this list is the
# class index it is looked up by when predictions are displayed.
categories = [
    'Biology',
    'Computer science',
    'Economics',
    'Electrics',
    'Finance',
    'Math',
    'Physics',
    'Statistics',
]
# Class indices 0..len(categories)-1, parallel to `categories`.
labels = list(range(len(categories)))
def print_probs(logits, threshold=95):
    """Write the most likely categories to the Streamlit page.

    The logits are converted to percentages with a softmax over dimension 0,
    ranked from most to least likely, and written one per line as
    ``"<category>: <percent>%"`` until the percentages written so far add up
    to more than ``threshold``.

    Args:
        logits: tensor of raw class scores, expected to hold one score per
            entry of ``categories`` (the caller passes a single row of the
            model output).
        threshold: cumulative percentage after which no further categories
            are written. Defaults to 95, the value that used to be
            hard-coded.
    """
    probs = torch.nn.functional.softmax(logits, dim=0).numpy() * 100
    # Pair every probability with its class index and put the most likely
    # class first.
    ranked = sorted(zip(probs, labels), reverse=True)

    cumulative = 0.0
    # Iterating over the ranked pairs (rather than indexing inside a `while`
    # loop) guarantees termination even if the running total never exceeds
    # the threshold.
    for prob, idx in ranked:
        if cumulative > threshold:
            break
        # `!s` applies str() explicitly so the rendered number is exactly
        # what `str(np.round(prob, 1))` produces.
        st.write(f"{categories[idx]}: {np.round(prob, 1)!s}%")
        cumulative += prob
# @st.cache
def make_prediction(text):
    """Classify ``text`` and render the predicted categories on the page.

    The text is tokenized, passed through the module-level ``model`` without
    gradient tracking, and the logits of the single input are handed to
    ``print_probs`` below a "Category prediction" heading.

    Args:
        text: the article text to classify.
    """
    # Truncate to the tokenizer's maximum length: the model has a fixed
    # number of position embeddings, so an over-long input would otherwise
    # fail inside the forward pass.
    tokenized_text = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        pred_logits = model(**tokenized_text).logits
    st.markdown("### Category prediction:")
    # The batch holds exactly one input, so row 0 is its logits.
    print_probs(pred_logits[0])
# Build the tokenizer and an 8-label sequence classifier from the
# "distilbert-base-uncased" checkpoint, then replace the classifier's weights
# with the state dict stored in "<model_name>.zip".
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=8,
)

model_name = "trained_model2"
model_path = f"{model_name}.zip"

# map_location puts every loaded tensor on the CPU.
model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"))
)
# MAIN
from PIL import Image

# Page header: logo followed by an empty top-level heading used as a spacer.
image = Image.open('logo.png')
st.image(image)
# st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg' class='center'>", unsafe_allow_html=True)
# st.markdown("# Arxiv.org category classifier")
st.markdown("# ")

st.markdown("### Article Title")
# NOTE(review): recent Streamlit releases reject text_area heights below
# 68 px, so the smallest accepted height is used for this one-line field.
title = st.text_area("Введите название научной статьи для классификации", height=68)
st.markdown("### Article Abstract")
abstract = st.text_area("Введите описание статьи", height=200)

# Feed both fields to the classifier. Each field is kept in its own variable
# so the title is not lost when the abstract is read.
# NOTE(review): assumes the model expects title and abstract as one
# whitespace-separated text - confirm against the training setup.
text = " ".join(part.strip() for part in (title, abstract) if part.strip())

# Streamlit runs the script on first page load, before anything is typed;
# skip the prediction until there is some input to classify.
if text:
    make_prediction(text)