# set_theme / app.py
# seal345's picture
# Update app.py
# 5603657
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def load_model():
    """Load the classifier and its tokenizer.

    The classifier weights are read from ``'model_roberta_trained'`` and the
    tokenizer from ``'roberta-base'``. The model is switched to evaluation
    mode before being returned.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_source = 'model_roberta_trained'
    tokenizer_source = 'roberta-base'

    model = AutoModelForSequenceClassification.from_pretrained(
        model_source,
        use_auth_token=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        do_lower_case=True,
    )
    # Inference only: put the model into evaluation mode.
    model.eval()
    return model, tokenizer
def get_predictions(logits, indexes, threshold=0.95):
    """Take entries from ``indexes`` until their cumulative score reaches ``threshold``.

    Args:
        logits: Indexable collection of scores, looked up with the values of
            ``indexes``. The running total is compared against a
            probability-style cut-off, so these are presumably normalised
            probabilities rather than raw logits -- confirm with the caller.
        indexes: Iterable of positions into ``logits``, in the order in which
            they should be consumed.
        threshold: Cumulative score at which collection stops.

    Returns:
        A ``(ind, probs)`` tuple: the consumed positions and the score found
        at each of them. If the total never reaches ``threshold`` (for
        example on empty input) everything that was consumed is returned.
    """
    total = 0
    ind = []
    probs = []
    for i in indexes:
        score = logits[i]
        total += score
        ind.append(i)
        # Record the score itself; the index is already kept in ``ind``.
        probs.append(score)
        if total >= threshold:
            break
    return ind, probs
def return_pred_name(names, ind):
    """Translate each entry of ``ind`` into its label via ``names``.

    Args:
        names: Mapping (or sequence) indexed by the values in ``ind``.
        ind: Iterable of keys to look up, in output order.

    Returns:
        A list with ``names[key]`` for every key in ``ind``.
    """
    return [names[key] for key in ind]
def predict(title, summary, model, tokenizer, threshold=0.95):
    """Classify an article and return its most probable categories.

    The title and summary are joined with a ``'.'``, tokenized, and passed
    through ``model``. Categories are taken in order of decreasing softmax
    probability until their cumulative probability reaches ``threshold``.

    Args:
        title: Article title (may be an empty string).
        summary: Article summary (may be an empty string).
        model: Classifier whose first output element holds the logits.
        tokenizer: Tokenizer providing ``encode``.
        threshold: Cumulative probability at which to stop adding categories.

    Returns:
        A ``(prediction, prediction_probs)`` tuple: category names looked up
        in the module-level ``names`` mapping, and the matching confidences
        formatted as percentage strings such as ``"93.10%"``.
    """
    text = title + '.' + summary
    # Truncate to the tokenizer's configured maximum length; without this an
    # arbitrarily long summary is handed to the model untruncated.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()
    # Class ids ordered from most to least probable.
    classes = np.flip(np.argsort(probs))

    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Iterating over ``classes`` bounds the loop even if the cumulative
    # probability never reaches the threshold.
    for class_id in classes:
        if sum_probs >= threshold:
            break
        prediction.append(names[int(class_id)])
        prediction_probs.append(f"{100 * probs[class_id]:.2f}%")
        sum_probs += probs[class_id]
    return prediction, prediction_probs
def get_results(prediction, prediction_probs):
    """Build the results table shown to the user.

    Args:
        prediction: Category names.
        prediction_probs: Confidence strings, one per category.

    Returns:
        A ``pandas.DataFrame`` with ``'Category'`` and ``'Confidence'``
        columns whose index starts at 1 instead of 0.
    """
    columns = {'Category': prediction, 'Confidence': prediction_probs}
    table = pd.DataFrame(columns)
    # Number the rows 1..n.
    row_numbers = np.arange(1, len(table) + 1)
    table.index = row_numbers
    return table
# Maps the classifier's output class id to the category name shown to the user.
names = {
    3: 'cs',
    18: 'stat',
    10: 'math',
    14: 'physics',
    15: 'q-bio',
    0: 'astro-ph',
    2: 'cond-mat',
    17: 'quant-ph',
    5: 'eess',
    1: 'cmp-lg',
    8: 'hep-ph',
    6: 'gr-qc',
    9: 'hep-th',
    12: 'nlin',
    4: 'econ',
    16: 'q-fin',
    7: 'hep-ex',
    11: 'math-ph',
    13: 'nucl-th',
}
# --- Page layout -----------------------------------------------------------
st.title("Find out the topic of the article without reading!")
st.markdown(
    "<h1 style='text-align: center;'><img width=320px src = 'https://upload.wikimedia.org/wikipedia/ru/8/81/Sheldon_cooper.jpg'></h1>",
    unsafe_allow_html=True)
# ^-- you can show the user text, images and a limited subset of HTML,
#     just like in Jupyter.

title = st.text_area(label='Title',
                     value='',
                     height=30,
                     help='If you know a title type it here')
summary = st.text_area(label='Summary',
                       value='',
                       height=200,
                       help='If you have a summary enter it here')
button = st.button(label='Get the theme!')

# --- Prediction ------------------------------------------------------------
if button:
    if title == '' and summary == '':
        # Nothing was entered, so there is nothing to classify.
        st.write('There is nothing to analyze...')
        st.write('Fill at least one of the fields')
    else:
        if summary == '':
            st.write('WARNING: you have entered only the title. The accuracy of the prediction may be poor... Please enter summary to improve accuracy.')
        # The model is loaded on every button press.
        model, tokenizer = load_model()
        prediction, prediction_probs = predict(title, summary, model, tokenizer)
        ans = get_results(prediction, prediction_probs)
        st.write('Result')
        st.write(ans)