"""Streamlit app: classify an arXiv paper into subject categories from its title + abstract.

Uses a DistilBERT sequence-classification checkpoint fine-tuned on arXiv labels
(expected in ./checkpoint-3000 alongside its config.json id2label map).
"""
import json

import streamlit as st
import streamlit.components.v1 as components
import torch
from torch.nn import functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

CHECKPOINT_DIR = 'checkpoint-3000'


@st.cache_resource
def load_model_and_labels():
    """Load the classifier, tokenizer and id->label map once per server process.

    Cached so Streamlit reruns (every widget interaction) do not reload the
    checkpoint from disk.
    """
    model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_DIR)
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    with open(f'{CHECKPOINT_DIR}/config.json', 'r') as f:
        id2label = json.load(f)['id2label']
    # JSON object keys are always strings; restore the integer class ids.
    id2label = {int(key): value for key, value in id2label.items()}
    return model, tokenizer, id2label


def classify(model, tokenizer, id2label, text):
    """Return class descriptions sorted by descending probability.

    Each entry is formatted as '<label> (prob = <p>)'.
    """
    # Truncate to the model's max length (512 for DistilBERT) so long
    # abstracts do not crash the forward pass.
    encoded = tokenizer(text, truncation=True, return_tensors='pt')
    with torch.no_grad():
        logits = model(**encoded).logits
    probs = F.softmax(logits[0], dim=-1)
    probs, ids = torch.sort(probs, descending=True)
    return [
        f'{id2label[idx.item()]} (prob = {prob.item()})'
        for idx, prob in zip(ids, probs)
    ]


if __name__ == '__main__':
    st.markdown("### Arxiv paper classifier (No guarantees provided)")
    col1, col2 = st.columns([1, 1])
    col1.image('imgs/akinator_ready.png', width=200)
    btn = col2.button('Classify!')

    model, tokenizer, id2label = load_model_and_labels()

    # Non-empty labels with collapsed visibility keep the original look while
    # avoiding Streamlit's empty-label accessibility warning; heights are in
    # pixels and must meet Streamlit's minimum.
    title = st.text_area(label='Title', placeholder='Input title...',
                         height=68, label_visibility='collapsed')
    abstract = st.text_area(label='Abstract', placeholder='Input abstract...',
                            height=200, label_visibility='collapsed')

    if btn:
        if not (title.strip() or abstract.strip()):
            st.error('Title and abstract are empty!')
        else:
            text = '\n'.join([title, abstract])
            result = classify(model, tokenizer, id2label, text)
            # NOTE(review): the original HTML markup was lost in this copy of
            # the file; this renders one label per line in a plain styled div.
            output = '<br>'.join(result)
            components.html(
                f'<div style="font-family: sans-serif; color: inherit;">'
                f'{output}'
                f'</div>',
                scrolling=True,
            )