set_theme / app.py
seal345's picture
Update app.py
dc1ed99
raw
history blame
3.49 kB
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
def load_model():
    """Load the fine-tuned classifier and its tokenizer.

    Returns:
        ``(model, tokenizer)``:
        - the sequence-classification model read from
          ``'model_distilbert_trained'``, switched to eval mode;
        - the ``'distilbert-base-cased'`` tokenizer, created with
          ``do_lower_case=True``.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        'model_distilbert_trained',
        use_auth_token=True,
    )
    # do_lower_case=True is kept even though the checkpoint name says "cased".
    # Presumably it matches how the classifier was trained -- TODO confirm
    # before changing it.
    text_tokenizer = AutoTokenizer.from_pretrained(
        'distilbert-base-cased',
        do_lower_case=True,
    )
    # Put the model into evaluation mode for inference (e.g. disables dropout).
    classifier.eval()
    return classifier, text_tokenizer
def get_predictions(logits, indexes, threshold=0.95):
    """Select leading classes until their cumulative score reaches a threshold.

    Walks ``indexes`` in the given order and accumulates ``logits[i]`` for
    each id. Callers presumably pass class ids sorted by descending score.
    Selection stops as soon as the running total is at least ``threshold``.

    Args:
        logits: Indexable per-class scores. They are summed and compared with
            ``threshold``, so despite the parameter name they are expected to
            be probabilities.
        indexes: Iterable of class ids into ``logits``, in selection order.
        threshold: Cumulative score at which selection stops (default 0.95).

    Returns:
        A ``(ind, probs)`` tuple: the selected class ids and their own scores,
        in selection order. If the threshold is never reached, every id from
        ``indexes`` is returned (never ``None``).
    """
    total = 0  # avoid shadowing the builtin ``sum``
    ind = []
    probs = []
    for i in indexes:
        score = logits[i]
        total += score
        ind.append(i)
        # Record the score of class ``i`` itself, not another entry of
        # ``indexes`` (which would be an unrelated class id).
        probs.append(score)
        if total >= threshold:
            break
    return ind, probs
def return_pred_name(name_dict, ind):
    """Translate class ids into their category names.

    Args:
        name_dict: Mapping from class id to category name.
        ind: Iterable of class ids.

    Returns:
        A list holding ``name_dict[i]`` for each ``i`` in ``ind``, in order.
        A missing id propagates the mapping's lookup error (``KeyError`` for
        a dict).
    """
    return [name_dict[class_id] for class_id in ind]
def predict(title, summary, model, tokenizer, threshold=0.95, labels=None):
    """Classify an article and return its most likely categories.

    The title and summary are joined with ``'.'``, tokenized, and scored by
    ``model``. Categories are taken in order of decreasing probability until
    their cumulative probability reaches ``threshold``.

    Args:
        title: Article title (may be empty).
        summary: Article summary (may be empty).
        model: Callable taking a ``(1, seq_len)`` tensor of token ids. Element
            ``[0]`` of its result must be 2-D logits; the last row is used.
        tokenizer: Object whose ``encode(text, truncation=...)`` returns a
            list of token ids.
        threshold: Cumulative probability at which to stop (default 0.95).
        labels: Mapping from class id to category name. Defaults to the
            module-level ``name_dict``.

    Returns:
        ``(prediction, prediction_probs)``: the category names, and their
        probabilities as percentage strings with two decimals (e.g.
        ``'66.52%'``). If float rounding keeps the total below
        ``threshold``, all classes are returned instead of failing.
    """
    if labels is None:
        labels = name_dict
    text = title + '.' + summary
    # Truncate to the tokenizer's maximum length. Without this, a long
    # summary yields more positions than the model can embed and inference
    # fails.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()
    classes = np.flip(np.argsort(probs))  # class ids, most probable first
    sum_probs = 0
    prediction = []
    prediction_probs = []
    # Iterating ``classes`` directly bounds the loop, so it cannot index past
    # the last class.
    for class_id in classes:
        if sum_probs >= threshold:
            break
        p = probs[class_id]
        prediction.append(labels[class_id])
        prediction_probs.append(f"{100 * p:.2f}%")
        sum_probs += p
    return prediction, prediction_probs
def get_results(prediction, prediction_probs):
    """Build the results table shown to the user.

    Args:
        prediction: Category names, in ranked order.
        prediction_probs: Matching confidence strings.

    Returns:
        A ``pandas.DataFrame`` with columns ``'Category'`` and
        ``'Confidence'``, indexed by rank 1, 2, ... instead of 0, 1, ....
    """
    columns = {
        'Category': prediction,
        'Confidence': prediction_probs,
    }
    ranks = np.arange(1, len(prediction) + 1)
    return pd.DataFrame(columns, index=ranks)
# Class id -> arXiv category label. The ids presumably follow the label
# encoding used when 'model_distilbert_trained' was trained -- verify against
# the training code before editing. Entry order is kept as originally written.
name_dict = {
    4: 'cs',
    19: 'stat',
    1: 'astro-ph',
    16: 'q-bio',
    6: 'eess',
    3: 'cond-mat',
    12: 'math',
    15: 'physics',
    18: 'quant-ph',
    17: 'q-fin',
    7: 'gr-qc',
    13: 'nlin',
    2: 'cmp-lg',
    5: 'econ',
    8: 'hep-ex',
    11: 'hep-th',
    14: 'nucl-th',
    10: 'hep-ph',
    9: 'hep-lat',
    0: 'adap-org',
}
st.title("Find out the topic of the article without reading!")
st.markdown("<h1 style='text-align: center;'><img width=320px src = 'https://upload.wikimedia.org/wikipedia/ru/8/81/Sheldon_cooper.jpg'></h1>",
            unsafe_allow_html=True)
# ^-- text, images and a limited subset of HTML can be shown to the user,
# just like in Jupyter.
title = st.text_area(label='Title',
                     value='',
                     # Streamlit rejects st.text_area heights below 68 px.
                     height=68,
                     help='If you know a title type it here')
summary = st.text_area(label='Summary',
                       value='',
                       height=200,
                       help='If you have a summary enter it here')

# Streamlit re-runs this script on every interaction. When available,
# st.cache_resource keeps one loaded model per server process instead of
# reloading it from disk on every click. Older Streamlit versions fall back
# to loading it on each click, as before.
_cache_resource = getattr(st, 'cache_resource', None)
_get_model = _cache_resource(load_model) if _cache_resource is not None else load_model

button = st.button(label='Get the theme!')
if button:
    # Whitespace-only input carries nothing to classify, so treat it as empty.
    if not title.strip() and not summary.strip():
        st.write('There is nothing to analyze...')
        st.write('Fill at least one of the fields')
    else:
        if not summary.strip():
            st.write('WARNING: you have entered only the title. The accuracy of the prediction may be poor... Please enter summary to improve accuracy.')
        model, tokenizer = _get_model()
        prediction, prediction_probs = predict(title, summary, model, tokenizer)
        ans = get_results(prediction, prediction_probs)
        st.write('Result')
        st.write(ans)