Spaces:
Runtime error
Runtime error
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
import torch | |
import transformers | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from datasets import load_dataset | |
def load_model():
    """Load the fine-tuned classifier together with its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is read from
        ``'model_distilbert_trained'`` and put into evaluation mode; the
        tokenizer is the ``'distilbert-base-cased'`` one, created with
        ``do_lower_case=True``.
    """
    # NOTE(review): a cased checkpoint combined with do_lower_case=True looks
    # unusual -- confirm it matches how the model was trained.
    # NOTE(review): this runs on every call; consider Streamlit resource
    # caching if reloading the weights per click turns out to be slow.
    classifier = AutoModelForSequenceClassification.from_pretrained(
        'model_distilbert_trained',
        use_auth_token=True,
    )
    text_tokenizer = AutoTokenizer.from_pretrained(
        'distilbert-base-cased',
        do_lower_case=True,
    )
    classifier.eval()
    return classifier, text_tokenizer
def get_predictions(logits, indexes, threshold=0.95):
    """Pick the leading classes whose cumulative score reaches ``threshold``.

    Args:
        logits: Per-class scores, indexable by class id. They are summed and
            compared against ``threshold``, so they are presumably
            probabilities rather than raw logits -- confirm with callers.
        indexes: Class ids in the order they should be considered
            (presumably sorted by descending score -- confirm with callers).
        threshold: Cumulative score at which the selection stops.

    Returns:
        A ``(ind, probs)`` tuple: the selected class ids and the score of
        each selected class, in the order they were visited. When the
        cumulative score never reaches ``threshold`` (including empty
        input), every class from ``indexes`` is returned instead of ``None``.
    """
    ind = []
    probs = []
    total = 0  # named ``total`` so the builtin ``sum`` is not shadowed
    for i in indexes:
        score = logits[i]
        total += score
        ind.append(i)
        # Store the class score itself; the previous code stored
        # ``indexes[i]``, which is another class id, not a score.
        probs.append(score)
        if total >= threshold:
            break
    return ind, probs
def return_pred_name(name_dict, ind):
    """Translate class ids into their names.

    Args:
        name_dict: Mapping from class id to name.
        ind: Iterable of class ids to look up.

    Returns:
        A list with ``name_dict[i]`` for every ``i`` in ``ind``, in order.

    Raises:
        KeyError: If an id from ``ind`` is missing in ``name_dict``.
    """
    return [name_dict[class_id] for class_id in ind]
def predict(title, summary, model, tokenizer, threshold=0.95):
    """Classify an article and return its most probable categories.

    The title and summary are joined with a ``'.'``, tokenized and passed
    through ``model`` as a batch of one sequence. Categories are then listed
    in order of decreasing probability until their cumulative probability
    reaches ``threshold``.

    Args:
        title: Article title (may be an empty string).
        summary: Article summary (may be an empty string).
        model: Callable taking a tensor of token ids; the first element of
            its output is treated as the logits tensor.
        tokenizer: Object whose ``encode`` method turns text into token ids.
        threshold: Cumulative probability at which the listing stops.

    Returns:
        A ``(prediction, prediction_probs)`` tuple: category names looked up
        in the module-level ``name_dict``, and the matching probabilities
        formatted as percentage strings with two decimals (e.g. ``'97.31%'``).
    """
    text = title + '.' + summary
    # Ask the tokenizer to truncate to its configured maximum length so a
    # very long summary cannot overflow the model's input size.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    probs = torch.softmax(logits[-1, :], dim=-1).detach().cpu().numpy()
    # Class ids ordered from most to least probable.
    classes = np.flip(np.argsort(probs))

    prediction = []
    prediction_probs = []
    sum_probs = 0.0
    # Iterating over ``classes`` bounds the loop even if rounding keeps the
    # running total below the threshold.
    for class_id in classes:
        if sum_probs >= threshold:
            break
        prob = probs[class_id]
        prediction.append(name_dict[int(class_id)])
        prediction_probs.append(f"{100 * prob:.2f}%")
        sum_probs += prob
    return prediction, prediction_probs
def get_results(prediction, prediction_probs):
    """Build the result table shown to the user.

    Args:
        prediction: Category names, one per row.
        prediction_probs: Confidence strings matching ``prediction``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``'Category'`` and
        ``'Confidence'`` whose rows are numbered starting from 1.
    """
    columns = {'Category': prediction, 'Confidence': prediction_probs}
    table = pd.DataFrame(columns)
    row_count = len(table)
    # Number the rows 1..N instead of the default 0..N-1.
    table.index = np.arange(1, row_count + 1)
    return table
# Lookup table from the classifier's output class index to the category label
# shown to the user (the labels look like arXiv archive identifiers).
name_dict = {
    4: 'cs',
    19: 'stat',
    1: 'astro-ph',
    16: 'q-bio',
    6: 'eess',
    3: 'cond-mat',
    12: 'math',
    15: 'physics',
    18: 'quant-ph',
    17: 'q-fin',
    7: 'gr-qc',
    13: 'nlin',
    2: 'cmp-lg',
    5: 'econ',
    8: 'hep-ex',
    11: 'hep-th',
    14: 'nucl-th',
    10: 'hep-ph',
    9: 'hep-lat',
    0: 'adap-org',
}
# ---- Page layout -----------------------------------------------------------
st.title("Find out the topic of the article without reading!")
# Header image, passed as raw HTML (hence unsafe_allow_html=True); the <h1>
# wrapper is now closed properly.
st.markdown(
    "<h1 style='text-align: center;'><img width=320px src = 'https://upload.wikimedia.org/wikipedia/ru/8/81/Sheldon_cooper.jpg'></h1>",
    unsafe_allow_html=True,
)
# ^-- you can show the user text, images and a limited subset of HTML, just like in Jupyter

# Recent Streamlit releases reject text_area heights below 68 px, so the
# title box uses that minimum instead of the former 30 px.
title = st.text_area(
    label='Title',
    value='',
    height=68,
    help='If you know a title type it here',
)
summary = st.text_area(
    label='Summary',
    value='',
    height=200,
    help='If you have a summary enter it here',
)
button = st.button(label='Get the theme!')

# ---- Classification on demand ----------------------------------------------
if button:
    # Whitespace-only input carries nothing to classify, so treat it as empty.
    title_text = title.strip()
    summary_text = summary.strip()
    if not title_text and not summary_text:
        st.write('There is nothing to analyze...')
        st.write('Fill at least one of the fields')
    else:
        if not summary_text:
            st.write('WARNING: you have entered only the title. The accuracy of the prediction may be poor... Please enter summary to improve accuracy.')
        model, tokenizer = load_model()
        prediction, prediction_probs = predict(title_text, summary_text, model, tokenizer)
        ans = get_results(prediction, prediction_probs)
        st.write('Result')
        st.write(ans)