# YSDA_LAB2 / app.py
# Author: dmitrysluch
# Commit: "Update app.py" (1ef8cd1)
import streamlit as st
import json
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Full list of arXiv subject tags, in the exact order the classifier's
# output units were trained on (index i of the logits corresponds to
# all_tags[i]). Kept as a JSON literal and decoded once at import time.
all_tags = json.loads('["cs.AI","cs.AR","cs.CC","cs.CE","cs.CG","cs.CL","cs.CR","cs.CV","cs.CY","cs.DB","cs.DC","cs.DL","cs.DM","cs.DS","cs.ET","cs.FL","cs.GL","cs.GR","cs.GT","cs.HC","cs.IR","cs.IT","cs.LG","cs.LO","cs.MA","cs.MM","cs.MS","cs.NA","cs.NE","cs.NI","cs.OH","cs.OS","cs.PF","cs.PL","cs.RO","cs.SC","cs.SD","cs.SE","cs.SI","cs.SY","econ.EM","econ.GN","econ.TH","eess.AS","eess.IV","eess.SP","eess.SY","math.AC","math.AG","math.AP","math.AT","math.CA","math.CO","math.CT","math.CV","math.DG","math.DS","math.FA","math.GM","math.GN","math.GR","math.GT","math.HO","math.IT","math.KT","math.LO","math.MG","math.MP","math.NA","math.NT","math.OA","math.OC","math.PR","math.QA","math.RA","math.RT","math.SG","math.SP","math.ST","astro-ph.CO","astro-ph.EP","astro-ph.GA","astro-ph.HE","astro-ph.IM","astro-ph.SR","cond-mat.dis-nn","cond-mat.mes-hall","cond-mat.mtrl-sci","cond-mat.other","cond-mat.quant-gas","cond-mat.soft","cond-mat.stat-mech","cond-mat.str-el","cond-mat.supr-con","gr-qc","hep-ex","hep-lat","hep-ph","hep-th","math-ph","nlin.AO","nlin.CD","nlin.CG","nlin.PS","nlin.SI","nucl-ex","nucl-th","physics.acc-ph","physics.ao-ph","physics.app-ph","physics.atm-clus","physics.atom-ph","physics.bio-ph","physics.chem-ph","physics.class-ph","physics.comp-ph","physics.data-an","physics.ed-ph","physics.flu-dyn","physics.gen-ph","physics.geo-ph","physics.hist-ph","physics.ins-det","physics.med-ph","physics.optics","physics.plasm-ph","physics.pop-ph","physics.soc-ph","physics.space-ph","quant-ph","q-bio.BM","q-bio.CB","q-bio.GN","q-bio.MN","q-bio.NC","q-bio.OT","q-bio.PE","q-bio.QM","q-bio.SC","q-bio.TO","q-fin.CP","q-fin.EC","q-fin.GN","q-fin.MF","q-fin.PM","q-fin.PR","q-fin.RM","q-fin.ST","q-fin.TR","stat.AP","stat.CO","stat.ME","stat.ML","stat.OT","stat.TH"]')
# Page chrome: set_page_config must be the first Streamlit call on the page.
# NOTE(review): emoji was mojibake'd in the checked-in file (UTF-8 decoded as
# Latin-1); restored to the intended pirate flag.
st.set_page_config(
    page_title="Classify article", page_icon="🏴‍☠️")
st.title('🏴‍☠️ Classify a scientific article')
# Free-form input for the paper title and/or abstract, capped at 4096 chars.
# height is in pixels and must be >= 68 in current Streamlit; the previous
# value of 5 raises StreamlitAPIException there.
text = st.text_area("Type article title and/or summary",
                    placeholder="Probabilistic Constructions of Computable Objects and a Computable Version of Lovasz Local Lemma",
                    height=100, max_chars=4096)
@st.cache_resource  # load once per process, not on every Streamlit rerun
def get_tokenizer_and_model():
    """Build the DistilBERT tokenizer and fine-tuned multi-label classifier.

    Returns:
        (tokenizer, model): a HuggingFace tokenizer and an
        AutoModelForSequenceClassification with one output unit per tag in
        ``all_tags``, weights loaded from ``distilbert.pt`` and set to eval
        mode on CPU.

    NOTE(review): @st.cache_resource requires Streamlit >= 1.18; on older
    versions fall back to @st.cache(allow_output_mutation=True).
    """
    model_name = 'distilbert-base-uncased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=len(all_tags))
    # Fine-tuned weights were saved as a state dict; map to CPU because the
    # hosting environment has no GPU. torch.load accepts a path directly.
    state_dict = torch.load('distilbert.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for deterministic inference
    return tokenizer, model
if text:
    # Run the classifier on the submitted text and show the most likely tags.
    tokenizer, model = get_tokenizer_and_model()
    tokenized = tokenizer([text], padding=True, truncation=True,
                          return_tensors="pt")
    with torch.no_grad():  # inference only — no autograd graph needed
        logits = model(**tokenized)['logits']
    # Multi-label head: independent per-tag probabilities via sigmoid.
    # (F.sigmoid is deprecated; torch.sigmoid is the supported spelling.)
    prob = torch.sigmoid(logits)
    df = pd.DataFrame(data=prob.numpy().reshape(-1, 1),
                      index=all_tags, columns=['probability'])
    df = df.sort_values(by='probability', ascending=False)
    # Keep only the smallest prefix of tags whose probabilities sum past 0.95.
    # BUGFIX: the original loop had no `break`; since the cumulative sum is
    # monotonically non-decreasing, max_i kept being overwritten until the
    # last row, so the table was never actually truncated.
    total_prob = 0.0
    max_i = df.shape[0]
    for idx, p in enumerate(df['probability']):
        total_prob += p
        if total_prob > 0.95:
            max_i = idx + 1
            break
    st.table(df.head(max_i))