# Hugging Face Spaces status when this file was captured: Runtime error
import numpy as np
import pandas as pd
import streamlit as st
import torch
import transformers
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)
def load_tok_and_model(model_name: str = "."):
    """Load the tokenizer and the sequence-classification model.

    The original body referenced an undefined global ``model_name``, which
    raised ``NameError`` on the first call. It is now an optional parameter.

    Args:
        model_name: Path or hub identifier passed to
            ``AutoTokenizer.from_pretrained``. Defaults to ``"."``, the same
            directory the model weights are loaded from.
            NOTE(review): assumes the tokenizer files were saved next to the
            weights -- if not, pass the identifier of the base checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The fine-tuned weights are loaded from the current working directory.
    model = AutoModelForSequenceClassification.from_pretrained(".")
    return tokenizer, model
# Human-readable names of the eight topic classes.
# NOTE(review): presumably ordered to match the classifier's output indices --
# confirm against the label mapping used at training time.
CATEGORIES = [
    "Computer Science",
    "Economics",
    "Electrical Engineering",
    "Mathematics",
    "Q. Biology",
    "Q. Finances",
    "Statistics",
    "Physics",
]
def forward_pass(title, abstract, tokenizer, model):
    """Return the class probabilities for a single article.

    The title and the abstract are tokenized separately, each padded or
    truncated to a fixed length (32 and 480 token ids respectively), and the
    two id sequences are concatenated into one 512-id model input.

    Args:
        title: Article title text.
        abstract: Article abstract text.
        tokenizer: Callable tokenizer; its result must expose ``'input_ids'``.
        model: Callable model; its result must expose ``'logits'`` with a
            leading batch dimension.

    Returns:
        A NumPy array with 8 probabilities (softmax over the logits).

    Raises:
        ValueError: If the concatenated ids or the logits do not have the
            expected shape.
    """
    title_len, abstract_len, num_classes = 32, 480, 8

    def encode(text, max_length):
        # Pad/truncate to exactly ``max_length`` ids so the two parts always
        # concatenate to a fixed-size input.
        encoded = tokenizer(
            text, padding="max_length", truncation=True, max_length=max_length
        )
        return torch.tensor(encoded["input_ids"])

    input_ids = torch.cat(
        (encode(title, title_len), encode(abstract, abstract_len))
    )
    if input_ids.shape != (title_len + abstract_len,):
        raise ValueError(
            f"unexpected input shape {tuple(input_ids.shape)}, "
            f"expected ({title_len + abstract_len},)"
        )

    # NOTE(review): only the ids are passed, so padding positions are not
    # masked -- presumably this mirrors how the model was trained; confirm.
    with torch.no_grad():
        logits = model(input_ids[None])["logits"][0]  # [None] adds the batch axis
        if logits.shape != (num_classes,):
            raise ValueError(
                f"unexpected logits shape {tuple(logits.shape)}, "
                f"expected ({num_classes},)"
            )
        # ``dim`` is required by torch.softmax; the original call omitted it.
        probs = torch.softmax(logits, dim=-1)
    return probs.cpu().numpy()
# --- Streamlit page: collect a title/abstract and show the predicted topic ---
st.title("Classification of arXiv articles' main topic")
st.markdown("Please provide both summary and title when possible")

tokenizer, model = load_tok_and_model()

title = st.text_area(label='Title', height=200)
abstract = st.text_area(label='Abstract', height=200)
button = st.button('Run classifier')

if button:
    if not (title.strip() or abstract.strip()):
        # Nothing to classify: avoid running the model on pure padding.
        st.warning("Please enter a title or an abstract first.")
    else:
        probs = forward_pass(title, abstract, tokenizer, model)
        if len(probs) == len(CATEGORIES):
            # Pair each probability with its category name.
            # NOTE(review): assumes CATEGORIES is ordered like the model's
            # output indices -- confirm against the training label mapping.
            st.write(pd.DataFrame({"Category": CATEGORIES, "Probability": probs}))
        else:
            # Fall back to the raw array when the label list does not line up.
            st.write(probs)