# lab2/app.py — Streamlit demo: classify a paper's topic from its title and
# summary using a fine-tuned DistilBERT checkpoint stored alongside this file.
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import datasets
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned sequence classifier from the app directory.

    Cached across Streamlit reruns so the checkpoint is read from disk only
    once. ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned torch model on every rerun — hashing a full model is slow, and
    any in-place mutation (e.g. ``.eval()``) would otherwise trigger a
    ``CachedObjectMutationWarning``.
    """
    return AutoModelForSequenceClassification.from_pretrained('./')
# Load the tokenizer matching the fine-tuned DistilBERT checkpoint.
# NOTE(review): the original `if 'tokenizer' not in globals()` guard was a
# no-op — Streamlit re-executes the whole script with fresh module globals on
# every rerun, so the condition was always true. Loading unconditionally is
# equivalent and clearer.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
model = load_model()
# Free-text inputs for the paper's title and abstract.
title = st.text_area('Title')
summary = st.text_area('Summary')

# Class index -> human-readable topic name (order must match training labels).
label_to_tag = {
    0: 'Computer science',
    1: 'Math',
    2: 'Physics',
    3: 'Quantum biology',
    4: 'Statistic',
}
def predict(title, summary):
    """Classify a paper and return the smallest top-k set covering 95% probability.

    Parameters
    ----------
    title, summary : str
        Raw user input; newlines in ``summary`` are flattened to spaces.

    Returns
    -------
    list[tuple[str, float]]
        ``(topic_name, probability)`` pairs sorted by descending probability,
        truncated as soon as the cumulative probability exceeds 0.95.
    """
    # Tokenize the (title, summary) pair directly; the original intermediate
    # `datasets.Dataset` was built only to read the same two lists back out.
    encoded = tokenizer([title], [summary.replace("\n", " ")],
                        padding="max_length", truncation=True,
                        return_tensors='pt')
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask'])['logits']
    # dim=-1: normalize over the class axis (omitting `dim` is deprecated
    # and emits a warning in modern torch).
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()
    preds = []
    cumulative = 0.
    for i in probs.argsort(descending=True).tolist():
        preds.append((label_to_tag[i], probs[i].item()))
        cumulative += probs[i].item()
        if cumulative > .95:
            break
    return preds
# Run inference as soon as the user has typed anything in either field.
# (`if title or summary` — truthiness replaces the `len(...)` calls; empty
# strings are falsy, so the semantics are identical.)
if title or summary:
    preds = predict(title, summary)
    st.text("Top 95% of topics")
    for topic, proba in preds:
        st.text(f"{topic}: {proba*100:.0f}%")