"""Streamlit app: predict topic tags for a paper from its title and summary.

Loads a fine-tuned sequence-classification model from the app directory and
shows the smallest set of topics whose cumulative probability exceeds 95%.
"""
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import datasets

# Map model output indices to human-readable topic names.
label_to_tag = {
    0: 'Computer science',
    1: 'Math',
    2: 'Physics',
    3: 'Quantum biology',
    4: 'Statistic',
}


# allow_output_mutation=True: model objects are mutable, and the legacy
# st.cache would otherwise try to hash them on every rerun.
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned classifier from the app directory (cached across reruns)."""
    return AutoModelForSequenceClassification.from_pretrained('./')


@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Load the tokenizer once and reuse it across Streamlit reruns.

    Replaces the previous ``'tokenizer' not in globals()`` guard, which was
    a no-op: each Streamlit rerun starts with fresh globals, so the tokenizer
    was being reloaded on every interaction.
    """
    return AutoTokenizer.from_pretrained('distilbert-base-cased')


tokenizer = load_tokenizer()
model = load_model()
model.eval()  # inference only: disable dropout / batch-norm updates

title = st.text_area('Title')
summary = st.text_area('Summary')


def predict(title, summary):
    """Classify one (title, summary) pair.

    Returns a list of ``(topic_name, probability)`` tuples sorted by
    descending probability, truncated to the smallest prefix whose
    cumulative probability exceeds 0.95.
    """
    # Newlines in the summary would otherwise survive into the tokenizer input.
    dataset = datasets.Dataset.from_dict(
        {'title': [title], 'summary': [summary.replace("\n", " ")]}
    )
    encoded = tokenizer(
        dataset["title"],
        dataset['summary'],
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )
    # no_grad: inference needs no autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        )['logits']
    # dim=-1: softmax over the class dimension; the implicit-dim form the
    # original used is deprecated in PyTorch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()

    preds = []
    cumulative = 0.0
    for idx in probs.argsort(descending=True).tolist():
        preds.append((label_to_tag[idx], probs[idx].item()))
        cumulative += probs[idx].item()
        if cumulative > 0.95:
            break
    return preds


if title or summary:  # run only once the user has typed something
    preds = predict(title, summary)
    st.text("Top 95% of topics")
    for topic, proba in preds:
        st.text(f"{topic}: {proba*100:.0f}%")