Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import numpy as np | |
from transformers import TrainingArguments, \ | |
Trainer, AutoTokenizer, DataCollatorWithPadding, \ | |
AutoModelForSequenceClassification | |
# Display names for the classifier's output classes. print_probs looks a name
# up by class index, so position i is presumably the label of output i —
# confirm this matches the label order used when training the model.
categories = [
    'biology',
    'computer science',
    'economics',
    'electrics',
    'finance',
    'math',
    'physics',
    'statistics',
]
def print_probs(logits, threshold=95):
    """Render the most likely categories for one example with Streamlit.

    Converts ``logits`` to percentages with a softmax. It then writes
    ``"<category>: <percent>"`` lines, most likely first, and stops once the
    cumulative percentage already shown exceeds ``threshold``.

    Args:
        logits: 1-D tensor of raw class scores. Position i is assumed to be
            the score for ``categories[i]`` — confirm against training.
        threshold: cumulative percentage after which no further categories
            are rendered. Defaults to 95, the original hard-coded cut-off.
    """
    # detach()/cpu() let .numpy() also work on a grad-tracking or GPU tensor.
    probs = (
        torch.nn.functional.softmax(logits, dim=0).detach().cpu().numpy() * 100
    )
    # Pair every probability with its own class index. Reverse-sorting the
    # (prob, idx) tuples puts the most likely class first.
    ranked = sorted(zip(probs, range(len(probs))), reverse=True)
    shown = 0.0  # cumulative percentage rendered so far
    for prob, idx in ranked:
        # Same stop rule as "keep going while total <= threshold". Iterating
        # with `for` cannot index past the end of `ranked`.
        if shown > threshold:
            break
        st.markdown(categories[idx] + ": " + str(np.round(prob, 1)))
        shown += prob
def make_prediction(text):
    """Classify ``text`` with the loaded model and render the predictions.

    Empty or whitespace-only input (such as the text area's initial empty
    value) is ignored and nothing is rendered. Otherwise a "Predictions:"
    line is written, followed by ``print_probs`` output for the example.

    Args:
        text: the article description entered by the user.
    """
    if not text or not text.strip():
        return
    # truncation=True clips the input to the tokenizer's maximum length, so a
    # long description cannot exceed what the model accepts (presumably
    # 512 tokens for distilbert-base-uncased).
    tokenized_text = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        pred_logits = model(**tokenized_text).logits
    st.markdown("Predictions:")
    # A single string was tokenized, so the batch holds one example: row 0.
    print_probs(pred_logits[0])
# Base DistilBERT tokenizer and a sequence-classification model with one
# output per category. The head size is derived from `categories` so it
# cannot drift from the names print_probs looks up by class index.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=len(categories)
)

model_name = "trained_model2"
# Presumably a state dict written with torch.save (its default on-disk format
# is a zip archive) — confirm how this checkpoint was produced.
model_path = model_name + '.zip'
# map_location forces the checkpoint's tensors onto the CPU regardless of the
# device they were saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# NOTE(review): Streamlit re-runs the script on each interaction, so this load
# repeats every time. Consider caching it (e.g. st.cache_resource) if the
# installed Streamlit version provides it.
model.load_state_dict(
    torch.load(
        model_path,
        map_location=torch.device("cpu"),
    )
)
# Page header. A Markdown ATX heading needs a space after "##"; without it the
# text is rendered literally instead of as a heading.
st.markdown("## Hello, people!")
st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg'>", unsafe_allow_html=True)
# Input box; the Russian label means "Enter the article description".
text = st.text_area("Введите описание статьи")
make_prediction(text)