File size: 1,454 Bytes
d71c37d
dc449cf
 
 
d71c37d
7c14be4
dc449cf
88b4df7
d71c37d
0ea3763
 
88b4df7
dc449cf
 
 
 
 
 
 
 
 
 
 
 
88b4df7
 
dc449cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import datasets

# `st.cache` is deprecated in favour of `st.cache_resource` for global,
# unserialisable resources such as ML models. Fall back to the legacy
# decorator on Streamlit releases that predate it; `allow_output_mutation=True`
# stops the legacy cache from hashing the returned object on every hit.
if hasattr(st, 'cache_resource'):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_model():
    """Load the sequence-classification model stored next to the app.

    Returns:
        The model restored by ``AutoModelForSequenceClassification`` from the
        files in ``'./'``. The result is cached, so the weights are read once
        instead of on every script rerun.
    """
    return AutoModelForSequenceClassification.from_pretrained('./')


@_cache_resource
def load_tokenizer():
    """Load the ``distilbert-base-cased`` tokenizer (cached like the model)."""
    return AutoTokenizer.from_pretrained('distilbert-base-cased')


# Streamlit re-executes this script on every interaction, so module-level
# state does not survive between runs; a `'tokenizer' not in globals()` guard
# therefore never skips the load. Caching the loader is what avoids the reload.
tokenizer = load_tokenizer()
model = load_model()

# Free-text inputs for the paper being classified. Either may be left empty;
# each widget returns its current text.
title = st.text_area(label='Title')
summary = st.text_area(label='Summary')

# Maps a class index of the model's output to the tag shown to the user.
# The position in the list is the class index.
label_to_tag = dict(enumerate([
    'Computer science',
    'Math',
    'Physics',
    'Quantum biology',
    'Statistic',
]))

def predict(title, summary, threshold=.95):
    """Return the most probable topics for a paper.

    Args:
        title: Paper title.
        summary: Paper abstract; newlines are replaced by spaces before
            tokenization.
        threshold: Cumulative probability at which the listing stops.
            Defaults to 0.95.

    Returns:
        A list of ``(tag, probability)`` tuples in descending order of
        probability: the shortest prefix whose cumulative probability exceeds
        ``threshold`` (all topics if it is never exceeded).
    """
    # Encode title and summary as one text pair in a batch of size one. A
    # single example needs no padding, so `padding="max_length"` would only
    # make the model process pad tokens that the attention mask hides anyway.
    encoded = tokenizer([title], [summary.replace("\n", " ")],
                        truncation=True, return_tensors='pt')

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask'])['logits']

    # Explicit `dim`: relying on the implicit softmax dimension is deprecated.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()
    ranking = probs.argsort(descending=True).tolist()
    values = probs.tolist()

    preds = []
    cumulative = 0.
    for i in ranking:
        preds.append((label_to_tag[i], values[i]))
        cumulative += values[i]
        if cumulative > threshold:
            break
    return preds

# Classify only once the user has typed something into at least one field.
if title or summary:
    top_topics = predict(title, summary)
    st.text("Top 95% of topics")
    # One line per topic, probability shown as a whole-number percentage.
    for topic, proba in top_topics:
        st.text(f"{topic}: {proba*100:.0f}%")