"""Streamlit app that predicts arXiv-style topic categories for an article.

The user enters a title and/or a summary; a fine-tuned sequence classifier
scores the text and the most probable categories are listed until their
cumulative probability reaches ``CONFIDENCE_THRESHOLD``.
"""

import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Cumulative probability mass the reported categories have to cover.
CONFIDENCE_THRESHOLD = 0.95

# Upper bound on the encoded input length. Longer inputs are truncated so
# the model never receives more positions than it can embed.
# NOTE(review): 512 is the usual limit for roberta-base - confirm that the
# fine-tuned checkpoint was not trained with a different maximum length.
MAX_INPUT_TOKENS = 512

# Class id (position in the classifier output) -> category name.
names = {
    3: 'cs',
    18: 'stat',
    10: 'math',
    14: 'physics',
    15: 'q-bio',
    0: 'astro-ph',
    2: 'cond-mat',
    17: 'quant-ph',
    5: 'eess',
    1: 'cmp-lg',
    8: 'hep-ph',
    6: 'gr-qc',
    9: 'hep-th',
    12: 'nlin',
    4: 'econ',
    16: 'q-fin',
    7: 'hep-ex',
    11: 'math-ph',
    13: 'nucl-th',
}


def _cache_resource(func):
    """Wrap *func* with whichever Streamlit resource cache is available.

    Newer releases expose ``st.cache_resource``; older ones offer
    ``st.experimental_singleton`` or ``st.cache``. If none of them exists
    the function is returned unchanged, i.e. it simply is not cached.
    """
    for attr in ("cache_resource", "experimental_singleton"):
        decorator = getattr(st, attr, None)
        if decorator is not None:
            return decorator(func)
    legacy_cache = getattr(st, "cache", None)
    if legacy_cache is not None:
        # The model is a mutable object; tell the legacy cache not to hash it.
        return legacy_cache(allow_output_mutation=True)(func)
    return func


@_cache_resource
def load_model():
    """Load the classifier and its tokenizer.

    The result is cached so that the weights are not re-read from disk on
    every button click (Streamlit re-runs the whole script per interaction).

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        'model_roberta_trained', use_auth_token=True)
    tokenizer = AutoTokenizer.from_pretrained(
        'roberta-base', do_lower_case=True)
    model.eval()
    return model, tokenizer


def get_predictions(logits, indexes, threshold=CONFIDENCE_THRESHOLD):
    """Collect entries of *indexes* until their scores add up to *threshold*.

    Args:
        logits: Per-class scores, indexable by the values in *indexes*.
            They are summed against a probability threshold, so they are
            presumably already normalised probabilities - TODO confirm.
        indexes: Class ids in the order in which they should be consumed.
        threshold: Cumulative score at which collection stops.

    Returns:
        ``(ind, probs)``: the consumed class ids and their scores. If the
        threshold is never reached, every entry of *indexes* is returned.
    """
    total = 0
    ind = []
    probs = []
    for i in indexes:
        total += logits[i]
        ind.append(i)
        # Record the score of class ``i`` (not ``indexes[i]``, which would
        # be just another class id).
        probs.append(logits[i])
        if total >= threshold:
            break
    return ind, probs


def return_pred_name(names, ind):
    """Map every class id in *ind* to its entry in *names*."""
    return [names[i] for i in ind]


def predict(title, summary, model, tokenizer, threshold=CONFIDENCE_THRESHOLD):
    """Classify an article and return its most probable categories.

    Args:
        title: Article title (may be empty).
        summary: Article summary (may be empty).
        model: Sequence classifier; the first element of its output is used
            as the logits tensor (assumed shape ``(batch, num_classes)``).
        tokenizer: Tokenizer providing ``encode``.
        threshold: Cumulative probability the returned categories must reach.

    Returns:
        ``(prediction, prediction_probs)``: category names in descending
        order of probability and the matching percentages formatted as
        strings such as ``"12.34%"``.
    """
    text = title + '.' + summary
    # Truncate so that long summaries cannot exceed the model's input limit.
    tokens = tokenizer.encode(
        text, truncation=True, max_length=MAX_INPUT_TOKENS)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    probs = torch.softmax(logits[-1, :], dim=-1).detach().cpu().numpy()

    # Class ids sorted from the most to the least probable.
    classes = np.flip(np.argsort(probs))

    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Iterating over ``classes`` bounds the loop by the number of classes
    # even if rounding keeps the running sum below the threshold.
    for class_id in classes:
        if sum_probs >= threshold:
            break
        prediction.append(names[class_id])
        prediction_probs.append(f"{100 * probs[class_id]:.2f}%")
        sum_probs += probs[class_id]
    return prediction, prediction_probs


def get_results(prediction, prediction_probs):
    """Build a result table with a 1-based index.

    Args:
        prediction: Category names.
        prediction_probs: Formatted confidence for each category.

    Returns:
        A DataFrame with the columns ``Category`` and ``Confidence``.
    """
    frame = pd.DataFrame(
        {'Category': prediction, 'Confidence': prediction_probs})
    frame.index = np.arange(1, len(frame) + 1)
    return frame


st.title("Find out the topic of the article without reading!")
# NOTE(review): the markup that used to live inside this string appears to
# have been lost; only two line breaks remain. Restore the intended HTML.
st.markdown("\n\n", unsafe_allow_html=True)
# ^-- you can show the user text, images and a limited subset of HTML -
#     just like in Jupyter.

# NOTE(review): recent Streamlit releases may reject very small text_area
# heights - confirm height=30 against the deployed version.
title = st.text_area(
    label='Title', value='', height=30,
    help='If you know a title type it here')
summary = st.text_area(
    label='Summary', value='', height=200,
    help='If you have a summary enter it here')

button = st.button(label='Get the theme!')

if button:
    # Whitespace-only input carries no information, so treat it as empty.
    has_title = bool(title.strip())
    has_summary = bool(summary.strip())
    if not has_title and not has_summary:
        st.write('There is nothing to analyze...')
        st.write('Fill at least one of the fields')
    else:
        if not has_summary:
            st.write('WARNING: you have entered only the title. The accuracy of the prediction may be poor... Please enter summary to improve accuracy.')
        model, tokenizer = load_model()
        prediction, prediction_probs = predict(
            title, summary, model, tokenizer)
        ans = get_results(prediction, prediction_probs)
        st.write('Result')
        st.write(ans)