yhn112's picture
Add application file
9a72fe1
import torch
import pandas as pd
import streamlit as st
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel, PretrainedConfig
@st.cache_resource
def init_model():
    """Build the classifier network and load its trained weights.

    Only the *configuration* is taken from ``roberta-large-mnli``; the
    encoder is constructed from that config and every weight (encoder and
    head) is then read from the local ``model.pt`` checkpoint.

    The stock pooler is swapped for a small head,
    ``Linear(1024, 256) -> LayerNorm -> ReLU -> Linear(256, 8) -> Sigmoid``,
    so ``pooler_output`` carries 8 scores in the last dimension.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    object instead of rebuilding it.

    Returns:
        The ``RobertaModel`` instance, on CPU and switched to eval mode.
    """
    encoder_config = PretrainedConfig().from_pretrained("roberta-large-mnli")
    classifier = RobertaModel(config=encoder_config)

    # Replacement head: 1024-d hidden state -> 8 sigmoid scores.
    head_layers = [
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, 8),
        nn.Sigmoid(),
    ]
    classifier.pooler = nn.Sequential(*head_layers)

    # NOTE(review): torch.load unpickles the file, so only ship a trusted
    # checkpoint next to the app.
    checkpoint = torch.load("model.pt", map_location=torch.device("cpu"))
    classifier.load_state_dict(checkpoint)

    # eval() returns the module itself after disabling training-only behaviour.
    return classifier.eval()
# Category labels. predict() pairs them positionally with the 8 scores in the
# model output, so the order here matters.
# NOTE(review): presumably this is the label order used in training; confirm.
cats = [
    "Computer Science",
    "Economics",
    "Electrical Engineering",
    "Mathematics",
    "Physics",
    "Biology",
    "Finance",
    "Statistics",
]
def predict(outputs):
    """Render the most probable categories for one model output.

    The raw scores are divided element-wise by a fixed, temperature-smoothed
    prior and pushed through a softmax. Categories are then listed from most
    to least probable until the categories already listed account for at
    least 95 percent, and the resulting table is written to the Streamlit
    page. If the single most probable category is "Computer Science", an
    extra remark is written below the table.

    Args:
        outputs: Tensor of scores; only the first row is used. The last
            dimension is zipped with ``cats``.
            (assumes shape (1, 8) -- TODO confirm against the caller)

    Returns:
        None. The result is emitted via ``st.write``.
    """
    temperature = 100000
    # NOTE(review): these look like per-category frequencies used as a prior;
    # confirm their origin.
    raw_prior = torch.tensor([39253., 84., 220., 2263., 1214., 909., 66., 10661.])
    prior = torch.nn.functional.softmax(raw_prior / temperature, dim=0)
    first_row = nn.functional.softmax(outputs / prior, dim=1).tolist()[0]

    # Highest probability first; ties fall back to comparing the label text.
    ranked = sorted(zip(first_row, cats), reverse=True)
    leader_is_cs = bool(ranked) and ranked[0][1] == "Computer Science"

    covered = 0
    labels = []
    shown = []
    for share, label in ranked:
        # Written as "not <" so the test matches the original for any value.
        if not covered < 95:
            break
        as_percent = share * 100
        covered += as_percent
        labels.append(label)
        shown.append(str(round(as_percent, 1)))

    table = pd.DataFrame(data=shown, index=labels, columns=['Percent'])
    st.write(table)
    if leader_is_cs:
        st.write("Today everything is connected with Computer Science")
# --- Streamlit page: collect a title/abstract and show the predicted categories ---

tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
model = init_model()

# RoBERTa's position embeddings cover 512 tokens (max_position_embeddings=514,
# two slots are reserved). Truncating to anything larger lets long abstracts
# reach the model untrimmed, where the embedding lookup raises and the user
# only sees the generic error below.
MAX_INPUT_TOKENS = 512

st.title("Article classifier")
st.markdown("### Title")
title = st.text_input("*Enter title (required)")
st.markdown("### Abstract")
abstract = st.text_area(" Enter abstract", height=200)

if not title:
    # The title is the only mandatory field; the abstract may stay empty.
    st.warning("Please fill in required fields")
else:
    try:
        st.markdown("### Result")
        encoded_input = tokenizer(
            title + ". " + abstract,
            return_tensors="pt",
            padding=True,
            max_length=MAX_INPUT_TOKENS,
            truncation=True,
        )
        # Inference only: no gradients needed.
        with torch.no_grad():
            # pooler_output is indexed as a rank-3 tensor here; position 0 on
            # the second axis is kept (presumably the leading <s> token --
            # confirm), leaving one row of category scores.
            outputs = model(**encoded_input).pooler_output[:, 0, :]
        predict(outputs)
    except Exception:
        # Top-level UI boundary: show a friendly message instead of a traceback.
        st.error("Something went wrong. Try different text")