# Scraped page metadata (not part of the program) — preserved as comments:
# Spaces: Runtime error / Runtime error
# File size: 1,454 Bytes
# Commits: d71c37d dc449cf d71c37d 7c14be4 dc449cf 88b4df7 d71c37d 0ea3763 88b4df7 dc449cf 88b4df7 dc449cf
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import datasets
# allow_output_mutation=True is required: st.cache tries to hash the cached
# return value on later reruns, and a torch model is unhashable, which raises
# an UnhashableTypeError at runtime without this flag.
@st.cache(allow_output_mutation=True)
def load_model():
    """Load (once per session) the fine-tuned classifier from the app directory.

    Returns:
        The AutoModelForSequenceClassification loaded from './'.
    """
    return AutoModelForSequenceClassification.from_pretrained('./')
# Streamlit re-executes the whole script on every interaction, so only
# instantiate the tokenizer when it is not already present in this module.
if 'tokenizer' not in globals():
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')

model = load_model()

# Input widgets: paper title and abstract.
title = st.text_area('Title')
summary = st.text_area('Summary')

# Mapping from classifier output index to human-readable topic name.
label_to_tag = dict(enumerate([
    'Computer science',
    'Math',
    'Physics',
    'Quantum biology',
    'Statistic',
]))
def predict(title, summary):
    """Classify a paper by its title and summary.

    Tokenizes the (title, summary) pair, runs the classifier, and returns
    the topics sorted by descending probability, stopping as soon as the
    cumulative probability exceeds 0.95.

    Args:
        title: Paper title (plain string).
        summary: Paper abstract; newlines are flattened to spaces.

    Returns:
        List of (topic_name, probability) tuples.
    """
    dataset = datasets.Dataset.from_dict({'title': [title],
                                          'summary': [summary.replace("\n", " ")]})
    encoded = tokenizer(dataset["title"], dataset['summary'],
                        padding="max_length", truncation=True, return_tensors='pt')
    # Inference only: disable autograd to avoid building a graph
    # (also makes the old .detach() workaround unnecessary).
    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask'])['logits']
    # Explicit dim: softmax without `dim` is deprecated and ambiguous.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()
    preds = []
    cumulative = 0.0  # accumulate plain floats, not 0-d tensors
    for i in probs.argsort(descending=True).tolist():
        p = probs[i].item()
        preds.append((label_to_tag[i], p))
        cumulative += p
        if cumulative > .95:
            break
    return preds
# As soon as the user has entered anything in either field, classify and
# show the topics that together cover at least 95% probability mass.
if title or summary:
    predictions = predict(title, summary)
    st.text("Top 95% of topics")
    for topic, probability in predictions:
        st.text(f"{topic}: {probability*100:.0f}%")