ysda-homework / app.py
rzaytsev's picture
Update app.py
d8c71b8
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import datasets
import torch
model_name = 'distilbert-base-cased'

# Streamlit re-executes this script in a fresh namespace on every user
# interaction, so a "'tokenizer' not in globals()" guard is always true and the
# model would be reloaded on each rerun. Use Streamlit's resource cache instead;
# fall back to the legacy decorator on versions that predate `cache_resource`.
_cache_resource = getattr(st, 'cache_resource', None)
if _cache_resource is None:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_model():
    """Load the tokenizer and the classifier once and reuse them across reruns.

    The tokenizer is fetched by ``model_name``; the classification model is
    read from the current working directory (``'./'``).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained('./')
    return loaded_tokenizer, loaded_model


tokenizer, model = load_model()
# Free-text inputs for the paper to classify. Either field may be left empty;
# the prediction below runs as soon as at least one of them is non-empty.
title = st.text_area('Title')
abstract = st.text_area('Abstract')
# Class index -> topic name. The index is the position of the class in the
# model's output, which is how `predict` looks topics up.
label_to_topic_dict = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}
# Reverse mapping: topic name -> class index.
topic_to_label_dict = {topic: label for label, topic in label_to_topic_dict.items()}
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def predict(title, abstract):
    """Return the most probable topics for a paper.

    Topics are listed from most to least probable and the list is cut as soon
    as their cumulative probability exceeds 0.95.

    Args:
        title: Paper title (may be an empty string).
        abstract: Paper abstract (may be an empty string).

    Returns:
        A list of ``(topic_name, probability)`` tuples, most probable first.
    """
    # The inputs are moved to `device` below, so the model has to live there
    # too; otherwise the forward pass fails with a device mismatch on CUDA
    # machines. `Module.to` is a no-op when the model is already in place.
    model.to(device)

    # Title and abstract are encoded as a sentence pair. A single example needs
    # no padding, so nothing is padded up to the model's maximum length.
    encoded = tokenizer([title], [abstract], truncation=True, return_tensors='pt')

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(
            input_ids=encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
        )['logits']

    # Explicit `dim`: relying on softmax's implicit dimension is deprecated.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()

    preds = []
    cumulative = 0.0
    for index in probs.argsort(descending=True).tolist():
        probability = probs[index].item()
        preds.append((label_to_topic_dict[index], probability))
        cumulative += probability
        if cumulative > .95:
            break
    return preds
# Nothing to classify until the user has typed into at least one of the fields.
if title or abstract:
    ranked_topics = predict(title, abstract)
    st.text("Top 95% topics:")
    # One line per topic, probability shown as a whole-number percentage.
    for topic, proba in ranked_topics:
        st.text(f"{topic}: {proba*100:.0f}%")