"""Streamlit app: classify a paper's title and abstract into topic labels.

A sequence-classification checkpoint is loaded from the current directory and
paired with the ``distilbert-base-cased`` tokenizer. For the entered text the
app lists the most probable topics, in descending order, until their
cumulative probability exceeds 95%.
"""
from typing import List, Tuple

import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import datasets
import torch

model_name = 'distilbert-base-cased'

# Chosen before the model is loaded so the model and its inputs always end up
# on the same device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Streamlit re-executes this script from scratch on every widget interaction,
# so module globals do not survive between runs; a resource cache is what
# keeps the model in memory. Older Streamlit releases lack `cache_resource`,
# in which case we fall back to loading on every run.
_cache_resource = getattr(st, 'cache_resource', None)
if _cache_resource is None:
    def _cache_resource(func):
        """Fallback decorator that leaves *func* uncached."""
        return func


@_cache_resource
def load_model():
    """Load the tokenizer and the classification model.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is read from the current
        directory, moved to ``device`` and put in evaluation mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained('./')
    # Inputs are sent to `device` in predict(); the model has to live there
    # too, otherwise the forward pass fails on CUDA machines.
    loaded_model.to(device)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = load_model()

title = st.text_area('Title')
abstract = st.text_area('Abstract')

# Class index -> human-readable topic, and the inverse mapping.
label_to_topic_dict = dict(enumerate([
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
]))
topic_to_label_dict = {
    topic: label for label, topic in label_to_topic_dict.items()
}


def predict(title: str, abstract: str) -> List[Tuple[str, float]]:
    """Return the most probable topics for a title/abstract pair.

    Args:
        title: Paper title; may be empty.
        abstract: Paper abstract; may be empty.

    Returns:
        ``(topic, probability)`` pairs sorted by descending probability,
        cut off after the first entry that pushes the cumulative probability
        above 0.95. The list always has at least one entry.
    """
    # NOTE(review): the Dataset wrapper only turns each field into a
    # one-element list here; plain lists would likely work as well.
    batch = datasets.Dataset.from_dict(
        {'title': [title], 'abstract': [abstract]}
    )
    encoded = tokenizer(
        batch["title"],
        batch['abstract'],
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(
            input_ids=encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
        )['logits']

    # Explicit class dimension; relying on the implicit one is deprecated.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()

    preds = []
    cumulative = 0.0
    for index in probabilities.argsort(descending=True).tolist():
        probability = probabilities[index].item()
        preds.append((label_to_topic_dict[index], probability))
        cumulative += probability
        if cumulative > .95:
            break
    return preds


# Nothing to classify until the user has typed into at least one field.
if title or abstract:
    output = predict(title, abstract)
    st.text("Top 95% topics:")
    for topic, proba in output:
        st.text(f"{topic}: {proba*100:.0f}%")