import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification
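
# Streamlit demo: predict the main arXiv subject area of a paper from its
# title and abstract, using a fine-tuned sequence-classification model.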
@st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
def load_tok_and_model():
    # Assumes the tokenizer and the fine-tuned classifier are stored in the
    # repository root, alongside this script.
    tokenizer = AutoTokenizer.from_pretrained(".")
    model = AutoModelForSequenceClassification.from_pretrained(".")
    return tokenizer, model
CATEGORIES = ["Computer Science", "Economics", "Electrical Engineering", "Mathematics",
              "Q. Biology", "Q. Finances", "Statistics", "Physics"]
@st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
def forward_pass(title, abstract, tokenizer, model):
    # Tokenize title and abstract separately, then concatenate them into a
    # single 512-token input (32 tokens for the title, 480 for the abstract).
    title_tensor = torch.tensor(tokenizer(title, padding="max_length", truncation=True, max_length=32)['input_ids'])
    abstract_tensor = torch.tensor(tokenizer(abstract, padding="max_length", truncation=True, max_length=480)['input_ids'])
    embeddings = torch.cat((title_tensor, abstract_tensor))
    assert embeddings.shape == (512,)
    with torch.no_grad():
        logits = model(embeddings[None])['logits'][0]
    assert logits.shape == (8,)
    # Convert the eight logits into class probabilities.
    probs = torch.softmax(logits, dim=-1).data.cpu().numpy()
    return probs
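
# --- Page layout: a title, text areas for the paper's title and abstract,
# and a button that runs the classifier. ---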
st.title("Classification of arXiv articles' main topic")
st.markdown("Please provide both the title and the abstract when possible")
tokenizer, model = load_tok_and_model()
title = st.text_area(label='Title', height=200)
abstract = st.text_area(label='Abstract', height=200)
button = st.button('Run classifier')
if button:
    probs = forward_pass(title, abstract, tokenizer, model)
    st.write(probs)