# Spaces: Runtime error
# File size: 2,876 Bytes
# b8769be 4b59c2a dee08fe 491d5a1 2c5279b c24d1f1 b8769be c24d1f1 2c5279b e62c2f3 2c5279b 0fbdf0a 2c5279b 62096a0 2c5279b beb608f 2c5279b fd871ea 2c5279b 62096a0 8393245 af657dc f07e807 26fd4ec af657dc 62096a0 d52f486 2c5279b 5e4fa04 f6a020d 5492f24 5e4fa04 e999374 62096a0
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers
# allow_output_mutation=True: without it st.cache re-hashes the returned
# tokenizer/model on every script rerun to detect mutation, which is slow for a
# full model (and can fail on members Streamlit cannot hash). Nothing mutates
# these objects after loading, so the check buys nothing.
@st.cache(suppress_st_warning=True, allow_output_mutation=True,
          hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_model():
    """Load the tokenizer and the fine-tuned sequence classifier.

    The ``distilbert-base-cased`` checkpoint is instantiated with an 8-way
    classification head, its weights are replaced by the state dict stored in
    ``model_weights2.pt`` (mapped onto the CPU), and the model is put into
    evaluation mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
    state_dict = torch.load('model_weights2.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # inference only: disables dropout
    return tokenizer, model
# Deliberately not wrapped in st.cache: st.cache hashes every argument on each
# call, which here includes the entire model, so the cache lookup would cost
# more than the single forward pass it tries to avoid.
def predict(title, summary, tokenizer, model):
    """Classify an article into the most likely subject categories.

    The title and summary are joined with a newline, tokenized and passed
    through ``model``. Categories are then taken in order of decreasing
    probability until their cumulative probability reaches 95%.

    Args:
        title: Article title.
        summary: Article summary; may be an empty string.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Sequence-classification model returned by ``load_model``.

    Returns:
        The string ``'error'`` if the combined text is shorter than 20
        characters. Otherwise a tuple ``(prediction, prediction_probs)``:
        category names looked up in the module-level ``label_to_theme``, and
        the matching probabilities formatted as percentages (e.g. ``"87.31%"``).
    """
    text = title + "\n" + summary
    if len(text) < 20:
        return 'error'
    # Truncate to the model's positional-embedding size (512 for DistilBERT);
    # without it a long summary produces more tokens than the model accepts
    # and the forward pass raises.
    tokens = tokenizer.encode(
        text, truncation=True, max_length=model.config.max_position_embeddings
    )
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()
    classes = np.flip(np.argsort(probs))  # class indices, most probable first

    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Iterating over `classes` bounds the loop even if float rounding leaves
    # the total mass just below the 95% threshold.
    for cls in classes:
        if sum_probs >= 0.95:
            break
        prediction.append(label_to_theme[cls])
        prediction_probs.append(f"{100 * probs[cls]:.2f}%")
        sum_probs += probs[cls]
    return prediction, prediction_probs
@st.cache(suppress_st_warning=True)
def get_results(prediction, prediction_probs):
    """Tabulate predicted categories and their confidences for display.

    Args:
        prediction: Category names, as produced by ``predict``.
        prediction_probs: Confidence strings, one per category.

    Returns:
        A DataFrame with ``Category`` and ``Confidence`` columns, indexed by
        1-based rank (1, 2, ...) rather than pandas' default 0-based index.
    """
    columns = {'Category': prediction, 'Confidence': prediction_probs}
    table = pd.DataFrame(columns)
    row_count = table.shape[0]
    table.index = np.arange(1, row_count + 1)
    return table
# Maps the classifier's output index to a human-readable arXiv category name;
# the list order defines the index (position 0 -> label 0, ...).
label_to_theme = dict(enumerate([
    'Computer science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Math',
    'Quantitative biology',
    'Quantitative Finance',
    'Statistics',
    'Physics',
]))
# --- Page layout ---------------------------------------------------------
st.title("Arxiv articles classification")
st.markdown("<h1 style='text-align: center;'><img width=300px src='https://media.wired.com/photos/592700e3cfe0d93c474320f1/191:100/w_1200,h_630,c_limit/faces-icon.jpg'></h1>", unsafe_allow_html=True)
st.markdown("This is an interface that can determine the article's category based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")

tokenizer, model = load_model()

# --- Inputs --------------------------------------------------------------
title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)
button = st.button('Run')

# --- Prediction ----------------------------------------------------------
if button:
    # predict() returns the bare string 'error' for too-short input instead of
    # a (categories, probabilities) pair, so inspect the result before
    # unpacking it; unpacking 'error' would raise ValueError.
    result = predict(title, summary, tokenizer, model)
    if result == 'error':
        st.error("Your input is too short. It is probably not a real article, please try again.")
    else:
        prediction, prediction_probs = result
        ans = get_results(prediction, prediction_probs)
        st.subheader('Results:')
        st.write(ans)