"""Streamlit app: predict the main arXiv topic of a paper from its title and abstract."""

import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# NOTE: placeholder checkpoint name -- set this to the base checkpoint whose tokenizer
# was used during fine-tuning, so the tokenizer matches the saved classifier weights.
MODEL_NAME = "distilbert-base-cased"

CATEGORIES = [
    "Computer Science",
    "Economics",
    "Electrical Engineering",
    "Mathematics",
    "Q. Biology",
    "Q. Finance",
    "Statistics",
    "Physics",
]


@st.cache_resource
def load_tok_and_model():
    """Load the tokenizer and the fine-tuned classifier once and reuse them across reruns."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # The fine-tuned classifier weights are expected in the app's working directory.
    model = AutoModelForSequenceClassification.from_pretrained(".")
    return tokenizer, model


@st.cache_data
def forward_pass(title, abstract, _tokenizer, _model):
    """Tokenize the title and abstract, run the classifier, and return class probabilities."""
    # The leading underscores tell st.cache_data not to hash the tokenizer and model,
    # so results are cached on the (title, abstract) pair only.
    title_ids = torch.tensor(
        _tokenizer(title, padding="max_length", truncation=True, max_length=32)["input_ids"]
    )
    abstract_ids = torch.tensor(
        _tokenizer(abstract, padding="max_length", truncation=True, max_length=480)["input_ids"]
    )
    input_ids = torch.cat((title_ids, abstract_ids))
    assert input_ids.shape == (512,)

    with torch.no_grad():
        logits = _model(input_ids[None])["logits"][0]
    assert logits.shape == (len(CATEGORIES),)

    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs


st.title("Classification of arXiv articles' main topic")
st.markdown("Please provide both the title and the abstract when possible.")

tokenizer, model = load_tok_and_model()

title = st.text_area(label="Title", height=200)
abstract = st.text_area(label="Abstract", height=200)
button = st.button("Run classifier")

if button:
    probs = forward_pass(title, abstract, tokenizer, model)
    st.write(probs)
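

# --- Optional display helper (a sketch, not part of the original script) ---
# `st.write(probs)` above shows an unlabeled array. The sketch below, with the
# hypothetical helper name `show_probs`, pairs each probability with the
# CATEGORIES list, assuming the model's output order matches that list. To use
# it, call `show_probs(probs)` inside the `if button:` branch instead of
# `st.write(probs)`.
def show_probs(probs):
    """Display class probabilities as a sorted table and a bar chart."""
    df = pd.DataFrame({"category": CATEGORIES, "probability": probs})
    df = df.sort_values("probability", ascending=False).reset_index(drop=True)
    st.dataframe(df)  # sorted table, most likely topic first
    st.bar_chart(df.set_index("category")["probability"])  # one bar per category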