"""Streamlit app that classifies an article (title + abstract) into subject areas.

A pretrained RoBERTa encoder gets a custom MLP head whose fine-tuned weights
are loaded from ``model.pt``. The head's sigmoid scores are rebalanced with a
smoothed category prior, and the most likely categories are shown as a table.
"""
import logging

import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer

logger = logging.getLogger(__name__)

# Pretrained checkpoint used for both the tokenizer and the encoder.
BASE_MODEL = "roberta-large-mnli"
# Fine-tuned weights for the encoder plus the replacement head below.
MODEL_PATH = "model.pt"
# RoBERTa's position embeddings only cover 512 tokens. Longer inputs fail
# inside the model, so the tokenizer must truncate to this length.
MAX_TOKENS = 512

# Output categories, in the same order as the head's output units.
cats = ["Computer Science", "Economics", "Electrical Engineering", "Mathematics",
        "Physics", "Biology", "Finance", "Statistics"]

# Per-category counts used as a prior (presumably training-set frequencies,
# ordered like ``cats`` -- TODO confirm). They are divided by a large
# temperature before the softmax, so the resulting prior is very flat.
PRIOR_COUNTS = torch.tensor([39253., 84., 220., 2263., 1214., 909., 66., 10661.])
PRIOR_TEMPERATURE = 100000
# Keep listing categories until their cumulative percentage reaches this value.
CUMULATIVE_CUTOFF = 95


# allow_output_mutation=True stops Streamlit from re-hashing the (large)
# returned model on every rerun to check whether it was mutated.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def init_model():
    """Build the classifier and load the fine-tuned weights onto the CPU.

    The stock RoBERTa pooler is replaced with a small MLP that ends in one
    sigmoid unit per entry of ``cats``. The state dict in ``MODEL_PATH``
    must match this architecture exactly.

    Returns:
        The model in eval mode.
    """
    model = RobertaModel.from_pretrained(BASE_MODEL)
    model.pooler = nn.Sequential(
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, len(cats)),
        nn.Sigmoid(),
    )
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
    model.eval()
    return model


def predict(outputs):
    """Render the most likely categories for one article as a table.

    Only the first row of ``outputs`` (the head's scores) is used. The
    scores are divided by the smoothed prior and renormalised with a
    softmax. Categories are then listed in descending probability until
    their cumulative percentage reaches ``CUMULATIVE_CUTOFF``. If the top
    category is Computer Science, an extra remark is written.

    Args:
        outputs: tensor of head scores with one column per entry of ``cats``.
    """
    prior = nn.functional.softmax(PRIOR_COUNTS / PRIOR_TEMPERATURE, dim=0)
    probs = nn.functional.softmax(outputs / prior, dim=1).tolist()[0]

    ranked = sorted(zip(probs, cats), reverse=True)
    top_cats = []
    top_probs = []
    total = 0
    for prob, cat in ranked:
        # The running total never decreases, so once the cutoff is reached
        # no further category can be listed.
        if total >= CUMULATIVE_CUTOFF:
            break
        percent = prob * 100
        total += percent
        top_cats.append(cat)
        top_probs.append(str(round(percent, 1)))

    res = pd.DataFrame(top_probs, index=top_cats, columns=['Percent'])
    st.write(res)
    if ranked and ranked[0][1] == "Computer Science":
        st.write("Today everything is connected with Computer Science")


tokenizer = RobertaTokenizer.from_pretrained(BASE_MODEL)
model = init_model()

st.title("Article classifier")
st.markdown("", unsafe_allow_html=True)

st.markdown("### Title")
title = st.text_area("*Enter title (required)", height=20)

st.markdown("### Abstract")
abstract = st.text_area(" Enter abstract", height=200)

# A whitespace-only title is treated as missing.
if not title.strip():
    st.warning("Please fill in required fields")
else:
    st.markdown("### Result")
    try:
        encoded_input = tokenizer(
            title + ". " + abstract,
            return_tensors="pt",
            padding=True,
            max_length=MAX_TOKENS,
            truncation=True,
        )
        with torch.no_grad():
            # The replacement pooler runs on every token's hidden state.
            # Keep the first (<s>) token's scores: (batch, len(cats)).
            outputs = model(**encoded_input).pooler_output[:, 0, :]
        predict(outputs)
    except Exception:
        # Show a generic message to the user, but keep the traceback in the
        # server logs so failures can be diagnosed.
        logger.exception("Article classification failed")
        st.error("Something went wrong. Try different text or contact me. Telegram: @rrevoid")