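# set_theme / app.py
# Last commit: 0217086 "Create app.py" by seal345; file size 3.28 kB.
# (Header captured from the hosting page; "raw", "history blame" and "No virus"
# were page controls/labels, kept here as comments so the file stays valid Python.)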
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from scipy.special import softmax
def load_model():
    """Load the fine-tuned classifier and the tokenizer used to feed it.

    Returns:
        ``(model, tokenizer)`` -- the sequence-classification model read from
        the local ``model_distilbert_trained`` directory (already switched to
        evaluation mode) and the ``distilbert-base-cased`` tokenizer.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        'model_distilbert_trained')
    # NOTE(review): lower-casing is requested on a *cased* vocabulary. Kept
    # as-is because the classifier may have been trained with this exact
    # setting -- confirm before changing.
    text_tokenizer = AutoTokenizer.from_pretrained(
        'distilbert-base-cased', do_lower_case=True)
    # Inference only: put layers such as dropout into evaluation behavior.
    classifier.eval()
    return classifier, text_tokenizer
def get_predictions(logits, indexes, threshold=0.95):
    """Collect classes until their cumulative probability reaches *threshold*.

    Args:
        logits: Per-class values indexed by class id. Despite the name, the
            values are summed and compared against *threshold*, so they are
            expected to be probabilities (e.g. softmax output).
        indexes: Class ids to consider, in the order they should be taken
            (typically sorted by descending probability).
        threshold: Cumulative probability at which to stop. Defaults to 0.95.

    Returns:
        ``(ind, probs)``: the selected class ids and their probabilities, in
        the order taken. If *threshold* is never reached, every class in
        *indexes* is returned (instead of an implicit ``None``).
    """
    ind = []
    probs = []
    total = 0  # avoid shadowing the builtin ``sum``
    for i in indexes:
        total += logits[i]
        ind.append(i)
        # Record this class's own probability; ``indexes[i]`` would be an
        # unrelated class id looked up by position.
        probs.append(logits[i])
        if total >= threshold:
            break
    return ind, probs
def return_pred_name(name_dict, ind):
    """Translate class ids into category names.

    Args:
        name_dict: Mapping from class id to category name.
        ind: Iterable of class ids.

    Returns:
        List of names in the same order as *ind*. A missing id raises
        ``KeyError``.
    """
    return [name_dict[class_id] for class_id in ind]
def predict(title, summary, model, tokenizer, label_names=None, threshold=0.95):
    """Classify an article and return its most likely categories.

    The title and summary are joined with ``'.'`` and fed to *model*. Classes
    are taken in descending probability order until their cumulative
    probability reaches *threshold*.

    Args:
        title: Article title (may be empty).
        summary: Article summary (may be empty).
        model: Sequence-classification model. It is called on a
            ``(1, seq_len)`` tensor of token ids, and element ``[0]`` of its
            output is used as the logits.
        tokenizer: Tokenizer exposing ``encode``.
        label_names: Mapping from class index to category name. Defaults to
            the module-level ``name_dict``.
        threshold: Cumulative probability at which to stop adding categories.

    Returns:
        ``(prediction, prediction_probs)``: category names, and their
        probabilities formatted as percentage strings such as ``"66.52%"``.
    """
    if label_names is None:
        label_names = name_dict
    text = title + '.' + summary
    # Truncate to the tokenizer's maximum model input length. Without it, long
    # summaries produce sequences the model cannot accept, and the forward
    # pass fails.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    # Batch of one: take the last (only) row and normalize it to probabilities.
    probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()
    classes = np.flip(np.argsort(probs))  # class ids, most probable first

    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Iterating over ``classes`` bounds the loop even if float rounding keeps
    # the running total just below the threshold.
    for cls in classes:
        if sum_probs >= threshold:
            break
        prediction.append(label_names[cls])
        prediction_probs.append("{:.2f}%".format(100 * probs[cls]))
        sum_probs += probs[cls]
    return prediction, prediction_probs
def get_results(prediction, prediction_probs):
    """Arrange predictions into a table for display.

    Args:
        prediction: Category names, best first.
        prediction_probs: Matching confidence strings.

    Returns:
        DataFrame with ``Category`` and ``Confidence`` columns, whose index
        starts at 1 so the first row reads as rank 1.
    """
    table = pd.DataFrame({'Category': prediction,
                          'Confidence': prediction_probs})
    # Replace the default 0-based positions with 1-based ranks.
    table.index = np.arange(len(table)) + 1
    return table
# Class index (position in the model's output) -> category label.
# The labels appear to be arXiv subject archives.
name_dict = {
    4: 'cs',
    19: 'stat',
    1: 'astro-ph',
    16: 'q-bio',
    6: 'eess',
    3: 'cond-mat',
    12: 'math',
    15: 'physics',
    18: 'quant-ph',
    17: 'q-fin',
    7: 'gr-qc',
    13: 'nlin',
    2: 'cmp-lg',
    5: 'econ',
    8: 'hep-ex',
    11: 'hep-th',
    14: 'nucl-th',
    10: 'hep-ph',
    9: 'hep-lat',
    0: 'adap-org',
}
st.title("Find out the topic of the article without reading!")
# Centered heading wrapper around an image. The raw HTML is rendered only
# because unsafe_allow_html=True.
st.markdown("<h1 style='text-align: center;'><img width=320px src = 'https://upload.wikimedia.org/wikipedia/ru/8/81/Sheldon_cooper.jpg'></h1>",
            unsafe_allow_html=True)
# ^-- st.markdown can show the user text, images and a limited subset of HTML,
#     much like a Jupyter notebook.

# Recent Streamlit versions reject st.text_area heights below 68 px, so the
# title box uses that minimum instead of 30.
title = st.text_area(label='Title',
                     value='',
                     height=68,
                     help='If you know a title type it here')
summary = st.text_area(label='Summary',
                       value='',
                       height=200,
                       help='If you have a summary enter it here')

button = st.button(label='Get the theme!')
if button:
    # Whitespace-only input counts as empty, so it is not sent to the model.
    if not title.strip() and not summary.strip():
        st.write('There is nothing to analyze...')
        st.write('Fill at least one of the fields')
    else:
        # NOTE(review): the model is reloaded on every click. Consider caching
        # it (e.g. st.cache_resource, if the deployed Streamlit supports it).
        model, tokenizer = load_model()
        prediction, prediction_probs = predict(title, summary, model, tokenizer)
        ans = get_results(prediction, prediction_probs)
        st.write('Result')
        st.write(ans)