set_theme / app.py
seal345's picture
Upload app.py
ff5406d
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from scipy.special import softmax
def load_model():
model = AutoModelForSequenceClassification.from_pretrained('model_distilbert_trained')
tokenizer = AutoTokenizer.from_pretrained(
'distilbert-base-cased', do_lower_case=True)
model.eval()
return model, tokenizer
def get_predictions(logits, indexes):
sum = 0
ind = []
probs = []
for i in indexes:
sum += logits[i]
ind.append(i)
probs.append(indexes[i])
if sum >= 0.95:
return ind, probs
def return_pred_name(name_dict, ind):
out = []
for i in ind:
out.append(name_dict[i])
return out
def predict(title, summary, model, tokenizer):
text = title + '.' + summary
tokens = tokenizer.encode(text)
with torch.no_grad():
logits = model(torch.as_tensor([tokens]))[0]
probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()
classes = np.flip(np.argsort(probs))
sum_probs = 0
ind = 0
prediction = []
prediction_probs = []
while sum_probs < 0.95:
prediction.append(name_dict[classes[ind]])
prediction_probs.append(str("{:.2f}".format(100 * probs[classes[ind]])) + "%")
sum_probs += probs[classes[ind]]
ind += 1
return prediction, prediction_probs
def get_results(prediction, prediction_probs):
frame = pd.DataFrame({'Category': prediction, 'Confidence': prediction_probs})
frame.index = np.arange(1, len(frame) + 1)
return frame
name_dict = {4: 'cs',
19: 'stat',
1: 'astro-ph',
16: 'q-bio',
6: 'eess',
3: 'cond-mat',
12: 'math',
15: 'physics',
18: 'quant-ph',
17: 'q-fin',
7: 'gr-qc',
13: 'nlin',
2: 'cmp-lg',
5: 'econ',
8: 'hep-ex',
11: 'hep-th',
14: 'nucl-th',
10: 'hep-ph',
9: 'hep-lat',
0: 'adap-org'}
st.title("Find out the topic of the article without reading!")
st.markdown("<h1 style='text-align: center;'><img width=320px src = 'https://upload.wikimedia.org/wikipedia/ru/8/81/Sheldon_cooper.jpg'>",
unsafe_allow_html=True)
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
title = st.text_area(label='Title',
value='',
height=30,
help='If you know a title type it here')
summary = st.text_area(label='Summary',
value='',
height=200,
help='If you have a summary enter it here')
button = st.button(label='Get the theme!')
if button:
if (title == '' and summary == ''):
st.write('There is nothing to analyze...')
st.write('Fill at list one of the fields')
else:
model, tokenizer = load_model()
prediction, prediction_probs = predict(title, summary, model, tokenizer)
ans = get_results(prediction, prediction_probs)
st.write('Result')
st.write(ans)
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
#from transformers import pipeline
#pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
#raw_predictions = pipe(text)
# тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
#st.markdown(f"{raw_predictions}")
# выводим результаты модели в текстовое поле, на потеху пользователю