# set_theme / app.py
# seal345's picture
# Update app.py
# 5603657
# raw history blame
# No virus
# 3.68 kB
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def _cache_resource(func):
    """Cache *func*'s return value across Streamlit script reruns.

    Streamlit re-executes this whole script on every widget interaction, so a
    plain module-level cache would be rebuilt each time. Newer Streamlit
    releases provide ``st.cache_resource`` for long-lived objects such as
    models; older releases only have ``st.cache``, where
    ``allow_output_mutation=True`` turns off the mutation check on the
    returned (unhashable) model.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_model():
    """Load the fine-tuned sequence classifier and its tokenizer.

    The result is cached, so the weights are read once per server process
    instead of on every button click.

    Returns:
        (model, tokenizer): the classifier loaded from
        ``'model_roberta_trained'`` (switched to eval mode) and the
        ``'roberta-base'`` tokenizer.
    """
    # NOTE(review): newer transformers releases deprecate `use_auth_token`
    # in favour of `token`; kept as-is for compatibility with the pinned version.
    model = AutoModelForSequenceClassification.from_pretrained(
        'model_roberta_trained', use_auth_token=True)
    # NOTE(review): RoBERTa's byte-level BPE tokenizer does not appear to act
    # on `do_lower_case` — confirm; left unchanged to match training setup.
    tokenizer = AutoTokenizer.from_pretrained(
        'roberta-base', do_lower_case=True)
    # Inference only: disable dropout and other training-time behaviour.
    model.eval()
    return model, tokenizer
def get_predictions(logits, indexes, threshold=0.95):
    """Collect classes until their cumulative probability reaches *threshold*.

    Args:
        logits: indexable of per-class scores. They are summed and compared
            against a probability threshold, so they are presumably
            normalised probabilities (e.g. softmax output).
        indexes: class indices in the order to consider them, typically
            sorted by descending probability.
        threshold: cumulative value at which to stop (inclusive; the class
            that crosses it is included). Defaults to 0.95.

    Returns:
        (ind, probs): the visited class indices and their values from
        *logits*, in visiting order. If the threshold is never reached,
        every index in *indexes* is returned (empty lists for empty input).
    """
    ind = []
    probs = []
    total = 0.0
    for i in indexes:
        ind.append(i)
        # Record the class's own probability (previously `indexes[i]`, which
        # used a class id as a position in `indexes`).
        probs.append(logits[i])
        total += logits[i]
        if total >= threshold:
            break
    # Always return a pair, even when the threshold was never reached.
    return ind, probs
def return_pred_name(names, ind):
    """Translate class indices into their labels.

    Args:
        names: mapping (or sequence) from class index to label.
        ind: iterable of class indices.

    Returns:
        List of labels, in the same order as *ind*.
    """
    return [names[idx] for idx in ind]
def predict(title, summary, model, tokenizer, labels=None, threshold=0.95):
    """Classify an article and return its most likely categories.

    Categories are taken in order of decreasing probability until their
    cumulative probability reaches *threshold*. The category that crosses
    the threshold is included.

    Args:
        title: article title.
        summary: article summary.
        model: classifier; ``model(input_ids)[0]`` must be the logits, with
            the class dimension last.
        tokenizer: tokenizer exposing ``encode(text, truncation=...)``.
        labels: mapping from class index to category name. Defaults to the
            module-level ``names`` table.
        threshold: cumulative-probability cut-off (default 0.95).

    Returns:
        (prediction, prediction_probs): category names and their
        probabilities formatted as percentage strings such as ``'97.56%'``.
    """
    if labels is None:
        labels = names
    # Same input format as before: title and summary joined by '.'.
    text = title + '.' + summary
    # Truncate to the tokenizer's maximum length. Without it, long summaries
    # exceed the model's position limit (512 tokens for roberta-base) and
    # the forward pass fails.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    # Softmax over the last row of logits (the single input sequence).
    probs = torch.softmax(logits[-1, :], dim=-1).detach().cpu().numpy()
    # Class indices from most to least probable.
    classes = np.flip(np.argsort(probs))

    prediction = []
    prediction_probs = []
    cumulative = 0.0
    # Iterate over the classes themselves so the loop can never run past the
    # last class, even if float rounding keeps the total below threshold.
    for cls in classes:
        p = probs[cls]
        prediction.append(labels[int(cls)])
        prediction_probs.append("{:.2f}%".format(100 * p))
        cumulative += p
        if cumulative >= threshold:
            break
    return prediction, prediction_probs
def get_results(prediction, prediction_probs):
    """Build the results table shown to the user.

    Args:
        prediction: category names, best first.
        prediction_probs: matching confidence strings.

    Returns:
        DataFrame with columns ``Category`` and ``Confidence`` and a
        1-based integer row index (rank).
    """
    n_rows = len(prediction)
    columns = {'Category': prediction, 'Confidence': prediction_probs}
    # 1-based ranks read more naturally than pandas' default 0-based index.
    return pd.DataFrame(columns, index=np.arange(1, n_rows + 1))
# Class index -> category label. `predict` looks labels up here by the index
# of each softmax output. Keys cover 0-18, so the model presumably has 19
# output classes. The labels look like arXiv subject-archive identifiers.
names = {
    3: 'cs',
    18: 'stat',
    10: 'math',
    14: 'physics',
    15: 'q-bio',
    0: 'astro-ph',
    2: 'cond-mat',
    17: 'quant-ph',
    5: 'eess',
    1: 'cmp-lg',
    8: 'hep-ph',
    6: 'gr-qc',
    9: 'hep-th',
    12: 'nlin',
    4: 'econ',
    16: 'q-fin',
    7: 'hep-ex',
    11: 'math-ph',
    13: 'nucl-th',
}
st.title("Find out the topic of the article without reading!")
st.markdown("<h1 style='text-align: center;'><img width=320px src = 'https://upload.wikimedia.org/wikipedia/ru/8/81/Sheldon_cooper.jpg'></h1>",
            unsafe_allow_html=True)
# ^-- Streamlit can show the user text, images and a restricted subset of
# HTML, much like Jupyter.

# A title is a single line, so use a text input. (Newer Streamlit releases
# also reject st.text_area heights below 68 px, which the old height=30 hit.)
title = st.text_input(label='Title',
                      value='',
                      help='If you know a title type it here')
summary = st.text_area(label='Summary',
                       value='',
                       height=200,
                       help='If you have a summary enter it here')
button = st.button(label='Get the theme!')
if button:
    # Treat whitespace-only fields as empty.
    title_text = title.strip()
    summary_text = summary.strip()
    if not title_text and not summary_text:
        st.write('There is nothing to analyze...')
        st.write('Fill at least one of the fields')
    else:
        if not summary_text:
            st.write('WARNING: you have entered only the title. The accuracy of the prediction may be poor... Please enter summary to improve accuracy.')
        model, tokenizer = load_model()
        prediction, prediction_probs = predict(title_text, summary_text, model, tokenizer)
        ans = get_results(prediction, prediction_probs)
        st.write('Result')
        st.write(ans)