"""Streamlit app: predict a paper's topic tags from its title and summary."""
import datasets
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def load_model():
    """Load the fine-tuned sequence-classification model.

    The weights are expected to live alongside this script ('./'),
    which is how HuggingFace Spaces bundles model files with the app.
    """
    return AutoModelForSequenceClassification.from_pretrained('./')
# Create the tokenizer and model only once per interpreter session; the
# globals() check is a hand-rolled guard against reloading on Streamlit
# script re-runs. NOTE(review): st.cache_resource on load_model() would be
# the idiomatic replacement — kept as-is to avoid changing behavior.
if 'tokenizer' not in globals():
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
    model = load_model()
# Free-text inputs for the paper's title and abstract/summary.
title = st.text_area('Title')
summary = st.text_area('Summary')

# Map model output class indices to human-readable topic names.
# NOTE(review): labels 3 and 4 ('Quantum biology', 'Statistic') read oddly
# for arXiv-style categories (q-bio is "Quantitative Biology", stat is
# "Statistics") — confirm against the model's training label order before
# changing; strings kept byte-identical here.
label_to_tag = {0: 'Computer science', 1: 'Math', 2: 'Physics',
                3: 'Quantum biology', 4: 'Statistic'}
def predict(title, summary):
    """Classify a paper and return its most likely topics.

    Returns the smallest list of ``(topic_name, probability)`` pairs,
    ordered by descending probability, whose cumulative probability
    exceeds 0.95.
    """
    # Newlines in the abstract are flattened so tokenization sees one line.
    dataset = datasets.Dataset.from_dict(
        {'title': [title], 'summary': [summary.replace("\n", " ")]})
    encoded = tokenizer(dataset["title"], dataset['summary'],
                        padding="max_length", truncation=True,
                        return_tensors='pt')
    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask'])['logits']
    # Fix: softmax without an explicit dim is deprecated in torch and relies
    # on an implicit axis guess; the class axis is the last dimension.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()
    preds = []
    cumulative = 0.
    # Walk classes from most to least probable until 95% mass is covered.
    for i in probs.argsort(descending=True).tolist():
        preds.append((label_to_tag[i], probs[i].item()))
        cumulative += probs[i]
        if cumulative > .95:
            break
    return preds
# Run prediction as soon as either field has content (Pythonic truthiness
# replaces the original `len(...) or len(...)` check — identical behavior).
if title or summary:
    preds = predict(title, summary)
    st.text("Top 95% of topics")
    for topic, proba in preds:
        st.text(f"{topic}: {proba*100:.0f}%")