"""Streamlit app that classifies a scientific paper into one of NUM_LABELS categories.

The user enters a paper title and, optionally, its abstract. A BERT sequence
classifier, whose weights are restored from a local checkpoint, scores the
text. The app then lists the most probable categories until their cumulative
probability reaches a threshold.
"""
import logging

import numpy as np
import pandas as pd
import streamlit as st
import tokenizers
import torch
import transformers

logger = logging.getLogger(__name__)

NUM_LABELS = 15
# Stop listing categories once their cumulative probability (in percent)
# reaches this value.
CUMULATIVE_PROB_THRESHOLD = 95
# The user must type more than this many characters (title + abstract,
# surrounding whitespace ignored) before a prediction is attempted.
MIN_INPUT_LENGTH = 10

labels_names = {
    0: 'Astrophysics',
    1: 'Condensed Matter',
    2: 'Computer Science',
    3: 'Economics',
    4: 'Electrical Engineering and Systems Science',
    5: 'General Relativity and Quantum Cosmology',
    6: 'High Energy Physics',
    7: 'Mathematics',
    8: 'Nonlinear Sciences',
    9: 'Nuclear Theory',
    10: 'General Physics',
    11: 'Quantitative Biology',
    12: 'Quantitative Finance',
    13: 'Quantum Physics',
    14: 'Statistics',
}


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
def get_model(model_name, model_path):
    """Build the tokenizer and classifier, then restore weights from disk.

    Args:
        model_name: Hugging Face model identifier used for both the tokenizer
            and the base architecture (e.g. ``'bert-base-cased'``).
        model_path: Path to a checkpoint holding a ``state_dict`` for a
            ``NUM_LABELS``-way sequence classifier built on ``model_name``.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=NUM_LABELS
    )
    # Remap tensors to CPU so a checkpoint saved on GPU still loads on a
    # CPU-only host. NOTE(review): torch.load unpickles the file; only load
    # trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    return model, tokenizer


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
def predict(text, tokenizer, model, temperature=1, threshold=CUMULATIVE_PROB_THRESHOLD):
    """Return the most probable categories for ``text``.

    Categories are listed in descending order of probability until their
    cumulative probability reaches ``threshold`` percent. The top category is
    always included.

    Args:
        text: Input text (title and abstract joined by the caller).
        tokenizer: Tokenizer returned by ``get_model``.
        model: Sequence classifier returned by ``get_model``.
        temperature: Divides the logits before the softmax; values > 1 flatten
            the distribution and values < 1 sharpen it.
        threshold: Cumulative probability, in percent, at which to stop.

    Returns:
        A one-column DataFrame (``'Probability'``, formatted like ``"12.3%"``)
        indexed by category name.
    """
    # Truncate to the tokenizer's maximum model length. Without this, a long
    # abstract exceeds BERT's position embeddings and the forward pass raises.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        # A batch containing one sequence, so the logits have shape (1, NUM_LABELS).
        logits = model.cpu()(torch.as_tensor([tokens]))[0]
        probs = torch.softmax(logits[-1, :] / temperature, dim=-1).cpu().numpy()

    preds = []
    pred_probs = []
    cumulative = 0
    for index in np.argsort(probs)[::-1]:
        index = int(index)
        pred_prob = 100 * probs[index]
        preds.append(labels_names[index])
        pred_probs.append(f"{pred_prob:.1f}%")
        cumulative += pred_prob
        if cumulative >= threshold:
            break

    result = pd.DataFrame({'Probability': pred_probs})
    result.index = preds
    return result


model, tokenizer = get_model('bert-base-cased', 'bert-checkpoint-14644.bin')

st.title("Yandex School of Data Analysis. ML course")
st.title("Laboratory work 2: classifier of categories of scientific papers")
st.markdown("", unsafe_allow_html=True)
st.markdown("\n")
st.markdown("Enter the title of the article and its abstract (although, if you really don't want to, you can do with just the title)")

title = st.text_area(label='Title of the article', height=100)
abstract = st.text_area(label='Abstract of the article', height=200)
button = st.button('Go')

if button:
    # Validate the raw user input, not the joined text. The ' [ABSTRACT] '
    # separator alone is 12 characters, so a length check on the joined
    # string could never fail.
    if len(title.strip() + abstract.strip()) > MIN_INPUT_LENGTH:
        text = ' [ABSTRACT] '.join([title, abstract])
        try:
            result = predict(text, tokenizer, model)
        except Exception:
            # Top-level UI boundary: show the friendly message, but keep the
            # traceback in the server log so failures can be diagnosed.
            logger.exception("Prediction failed")
            st.error("Ooooops, something went wrong. Try again please and report to me, tg: @vladyur")
        else:
            st.subheader('Bumblebee thinks, this paper related to')
            st.write(result)
    else:
        st.error("Enter some more info please")