# ysda_homework / app.py
# Uploaded by sof ("Upload app.py", revision 78bef13).
import streamlit as st
import transformers
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Human-readable topic names, keyed by label code (these appear to be arXiv
# category codes). Used only for display when rendering predictions.
readable_labels = {'cs.AI' :'Artificial Intelligence', 'cs.AR' :'Hardware Architecture', 'cs.CC' :'Computational Complexity', 'cs.CE' :'Computational Engineering, Finance, and Science', 'cs.CG' :'Computational Geometry', 'cs.CL' :'Computation and Language', 'cs.CR' :'Cryptography and Security', 'cs.CV' :'Computer Vision and Pattern Recognition', 'cs.CY' :'Computers and Society', 'cs.DB' :'Databases', 'cs.DC' :'Distributed, Parallel, and Cluster Computing', 'cs.DL' :'Digital Libraries', 'cs.DM' :'Discrete Mathematics', 'cs.DS' :'Data Structures and Algorithms', 'cs.ET' :'Emerging Technologies', 'cs.FL' :'Formal Languages and Automata Theory', 'cs.GL' :'General Literature', 'cs.GR' :'Graphics', 'cs.GT' :'Computer Science and Game Theory', 'cs.HC' :'Human-Computer Interaction', 'cs.IR' :'Information Retrieval', 'cs.IT' :'Information Theory', 'cs.LG' :'Machine Learning', 'cs.LO' :'Logic in Computer Science', 'cs.MA' :'Multiagent Systems', 'cs.MM' :'Multimedia', 'cs.MS' :'Mathematical Software', 'cs.NA' :'Numerical Analysis', 'cs.NE' :'Neural and Evolutionary Computing', 'cs.NI' :'Networking and Internet Architecture', 'cs.OH' :'Other Computer Science', 'cs.OS' :'Operating Systems', 'cs.PF' :'Performance', 'cs.PL' :'Programming Languages', 'cs.RO' :'Robotics', 'cs.SC' :'Symbolic Computation', 'cs.SD' :'Sound', 'cs.SE' :'Software Engineering', 'cs.SI' :'Social and Information Networks', 'cs.SY' :'Systems and Control', 'econ.EM' :'Econometrics', 'econ.GN' :'General Economics', 'econ.TH' :'Theoretical Economics', 'eess.AS' :'Audio and Speech Processing', 'eess.IV' :'Image and Video Processing', 'eess.SP' :'Signal Processing', 'eess.SY' :'Systems and Control', 'math.AC' :'Commutative Algebra', 'math.AG' :'Algebraic Geometry', 'math.AP' :'Analysis of PDEs', 'math.AT' :'Algebraic Topology', 'math.CA' :'Classical Analysis and ODEs', 'math.CO' :'Combinatorics', 'math.CT' :'Category Theory', 'math.CV' :'Complex Variables', 'math.DG' :'Differential Geometry', 'math.DS' 
:'Dynamical Systems', 'math.FA' :'Functional Analysis', 'math.GM' :'General Mathematics', 'math.GN' :'General Topology', 'math.GR' :'Group Theory', 'math.GT' :'Geometric Topology', 'math.HO' :'History and Overview', 'math.IT' :'Information Theory', 'math.KT' :'K-Theory and Homology', 'math.LO' :'Logic', 'math.MG' :'Metric Geometry', 'math.MP' :'Mathematical Physics', 'math.NA' :'Numerical Analysis', 'math.NT' :'Number Theory', 'math.OA' :'Operator Algebras', 'math.OC' :'Optimization and Control', 'math.PR' :'Probability', 'math.QA' :'Quantum Algebra', 'math.RA' :'Rings and Algebras', 'math.RT' :'Representation Theory', 'math.SG' :'Symplectic Geometry', 'math.SP' :'Spectral Theory', 'math.ST' :'Statistics Theory', 'astro-ph.CO' :'Cosmology and Nongalactic Astrophysics', 'astro-ph.EP' :'Earth and Planetary Astrophysics', 'astro-ph.GA' :'Astrophysics of Galaxies', 'astro-ph.HE' :'High Energy Astrophysical Phenomena', 'astro-ph.IM' :'Instrumentation and Methods for Astrophysics', 'astro-ph.SR' :'Solar and Stellar Astrophysics', 'cond-mat.dis-nn' :'Disordered Systems and Neural Networks', 'cond-mat.mes-hall' :'Mesoscale and Nanoscale Physics', 'cond-mat.mtrl-sci' :'Materials Science', 'cond-mat.other' :'Other Condensed Matter', 'cond-mat.quant-gas' :'Quantum Gases', 'cond-mat.soft' :'Soft Condensed Matter', 'cond-mat.stat-mech' :'Statistical Mechanics', 'cond-mat.str-el' :'Strongly Correlated Electrons', 'cond-mat.supr-con' :'Superconductivity', 'gr-qc' :'General Relativity and Quantum Cosmology', 'hep-ex' :'High Energy Physics - Experiment', 'hep-lat' :'High Energy Physics - Lattice', 'hep-ph' :'High Energy Physics - Phenomenology', 'hep-th' :'High Energy Physics - Theory', 'math-ph' :'Mathematical Physics', 'nlin.AO' :'Adaptation and Self-Organizing Systems', 'nlin.CD' :'Chaotic Dynamics', 'nlin.CG' :'Cellular Automata and Lattice Gases', 'nlin.PS' :'Pattern Formation and Solitons', 'nlin.SI' :'Exactly Solvable and Integrable Systems', 'nucl-ex' :'Nuclear Experiment', 
'nucl-th' :'Nuclear Theory', 'physics.acc-ph' :'Accelerator Physics', 'physics.ao-ph' :'Atmospheric and Oceanic Physics', 'physics.app-ph' :'Applied Physics', 'physics.atm-clus' :'Atomic and Molecular Clusters', 'physics.atom-ph' :'Atomic Physics', 'physics.bio-ph' :'Biological Physics', 'physics.chem-ph' :'Chemical Physics', 'physics.class-ph' :'Classical Physics', 'physics.comp-ph' :'Computational Physics', 'physics.data-an' :'Data Analysis, Statistics and Probability', 'physics.ed-ph' :'Physics Education', 'physics.flu-dyn' :'Fluid Dynamics', 'physics.gen-ph' :'General Physics', 'physics.geo-ph' :'Geophysics', 'physics.hist-ph' :'History and Philosophy of Physics', 'physics.ins-det' :'Instrumentation and Detectors', 'physics.med-ph' :'Medical Physics', 'physics.optics' :'Optics', 'physics.plasm-ph' :'Plasma Physics', 'physics.pop-ph' :'Popular Physics', 'physics.soc-ph' :'Physics and Society', 'physics.space-ph' :'Space Physics', 'quant-ph' :'Quantum Physics', 'q-bio.BM' :'Biomolecules', 'q-bio.CB' :'Cell Behavior', 'q-bio.GN' :'Genomics', 'q-bio.MN' :'Molecular Networks', 'q-bio.NC' :'Neurons and Cognition', 'q-bio.OT' :'Other Quantitative Biology', 'q-bio.PE' :'Populations and Evolution', 'q-bio.QM' :'Quantitative Methods', 'q-bio.SC' :'Subcellular Processes', 'q-bio.TO' :'Tissues and Organs', 'q-fin.CP' :'Computational Finance', 'q-fin.EC' :'Economics', 'q-fin.GN' :'General Finance', 'q-fin.MF' :'Mathematical Finance', 'q-fin.PM' :'Portfolio Management', 'q-fin.PR' :'Pricing of Securities', 'q-fin.RM' :'Risk Management', 'q-fin.ST' :'Statistical Finance', 'q-fin.TR' :'Trading and Market Microstructure', 'stat.AP' :'Applications', 'stat.CO' :'Computation', 'stat.ME' :'Methodology', 'stat.ML' :'Machine Learning', 'stat.OT' :'Other Statistics', 'stat.TH' :'Statistics Theory'}
# Label codes in the order of the classifier's outputs: the probability at
# output index i is reported as labels[i]. This order presumably has to match
# the label order used when training the checkpoints -- do not reorder.
labels = ['cs.AI','cs.AR','cs.CC','cs.CE','cs.CG','cs.CL','cs.CR','cs.CV','cs.CY','cs.DB','cs.DC','cs.DL','cs.DM','cs.DS','cs.ET','cs.FL','cs.GL','cs.GR','cs.GT','cs.HC','cs.IR','cs.IT','cs.LG','cs.LO','cs.MA','cs.MM','cs.MS','cs.NA','cs.NE','cs.NI','cs.OH','cs.OS','cs.PF','cs.PL','cs.RO','cs.SC','cs.SD','cs.SE','cs.SI','cs.SY','econ.EM','econ.GN','econ.TH','eess.AS','eess.IV','eess.SP','eess.SY','math.AC','math.AG','math.AP','math.AT','math.CA','math.CO','math.CT','math.CV','math.DG','math.DS','math.FA','math.GM','math.GN','math.GR','math.GT','math.HO','math.IT','math.KT','math.LO','math.MG','math.MP','math.NA','math.NT','math.OA','math.OC','math.PR','math.QA','math.RA','math.RT','math.SG','math.SP','math.ST','astro-ph.CO','astro-ph.EP','astro-ph.GA','astro-ph.HE','astro-ph.IM','astro-ph.SR','cond-mat.dis-nn','cond-mat.mes-hall','cond-mat.mtrl-sci','cond-mat.other','cond-mat.quant-gas','cond-mat.soft','cond-mat.stat-mech','cond-mat.str-el','cond-mat.supr-con','gr-qc','hep-ex','hep-lat','hep-ph','hep-th','math-ph','nlin.AO','nlin.CD','nlin.CG','nlin.PS','nlin.SI','nucl-ex','nucl-th','physics.acc-ph','physics.ao-ph','physics.app-ph','physics.atm-clus','physics.atom-ph','physics.bio-ph','physics.chem-ph','physics.class-ph','physics.comp-ph','physics.data-an','physics.ed-ph','physics.flu-dyn','physics.gen-ph','physics.geo-ph','physics.hist-ph','physics.ins-det','physics.med-ph','physics.optics','physics.plasm-ph','physics.pop-ph','physics.soc-ph','physics.space-ph','quant-ph','q-bio.BM','q-bio.CB','q-bio.GN','q-bio.MN','q-bio.NC','q-bio.OT','q-bio.PE','q-bio.QM','q-bio.SC','q-bio.TO','q-fin.CP','q-fin.EC','q-fin.GN','q-fin.MF','q-fin.PM','q-fin.PR','q-fin.RM','q-fin.ST','q-fin.TR','stat.AP','stat.CO','stat.ME','stat.ML','stat.OT','stat.TH']
# The model returned by load_model() is cached and then mutated by the script
# (it calls `load_state_dict` on it every rerun). Streamlit >= 1.18 provides
# `st.cache_resource` for exactly this kind of shared, unhashable object. Older
# releases only have the legacy `st.cache`, which by default hashes the returned
# object on every call to detect mutation; `allow_output_mutation=True` disables
# that check.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Build the DistilBERT sequence classifier used by the app.

    The classification head has one output per entry of ``labels`` (the list
    that defines how output indices map to label codes). Fine-tuned weights
    are loaded into the returned model by the caller.

    Returns:
        A ``transformers`` sequence-classification model, shared via Streamlit's
        cache across reruns.
    """
    prediction_model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-cased", num_labels=len(labels)
    )
    return prediction_model
# Caching decorator for objects reused across reruns (here: the tokenizer).
# Streamlit >= 1.18 provides `st.cache_resource`; older releases only have the
# legacy `st.cache`, which must be told not to hash/mutation-check the result.
if hasattr(st, "cache_resource"):
    _cache_shared = st.cache_resource
else:
    _cache_shared = st.cache(allow_output_mutation=True)

# Pretrained tokenizer matching the classifier's base model.
TOKENIZER_NAME = 'distilbert-base-cased'
# Fine-tuned weights: one checkpoint for title-only input, one used when an
# abstract is provided.
TITLE_CHECKPOINT = './title_model'
SUMMARY_CHECKPOINT = './summary_model'
# Topics are listed, most likely first, until their cumulative probability
# reaches this value.
CUMULATIVE_PROBABILITY = 0.95


@_cache_shared
def load_tokenizer():
    """Return the tokenizer, created once and reused across Streamlit reruns."""
    return AutoTokenizer.from_pretrained(TOKENIZER_NAME)


def top_predictions(probabilities, threshold=CUMULATIVE_PROBABILITY):
    """Pick the most likely labels until their total probability reaches ``threshold``.

    Args:
        probabilities: One probability per entry of ``labels``, in the same order.
        threshold: Cumulative probability at which to stop adding labels.

    Returns:
        A list of ``(label, probability)`` pairs sorted by decreasing
        probability. Iteration stops at the end of the list, so floating-point
        round-off (probabilities summing to slightly less than ``threshold``)
        cannot cause an IndexError.
    """
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    selected = []
    cumulative = 0.0
    for label, probability in ranked:
        if cumulative >= threshold:
            break
        selected.append((label, probability))
        cumulative += probability
    return selected


def format_predictions(predictions):
    """Render ``(label, probability)`` pairs as a Markdown bullet list.

    Each item shows the human-readable topic name, followed in parentheses by
    the original label and the probability as a percentage. A bullet list is
    used because single newlines inside a Markdown paragraph are rendered as
    spaces, which would run all predictions together.
    """
    return '\n'.join(
        f"- {readable_labels.get(label, label)} ({label}, {probability:.1%})"
        for label, probability in predictions
    )


st.markdown("## Classify scientific article")
st.markdown("<img width=400px src='https://64.media.tumblr.com/27a87f0c40cf49afcb2a99158b780573/d24245cb8f6d87b6-d0/s1280x1920/717a15ccacd345d120b380315554075231fb9c56.jpg'>", unsafe_allow_html=True)
# Streamlit can show text, images and a limited subset of HTML, much like Jupyter.
st.markdown("Here we offer an application that predicts a possible theme of an article. The choice is made based on a machine learning algorithm.")
# Each text area returns the string it currently contains.
title = st.text_area("Please insert the title of the article (obligatory)")
summary = st.text_area("Please insert the abstract of the article (optional)")

if not title.strip():
    # The title is obligatory: don't run the model on empty input (which is
    # also the state on the very first page load).
    st.markdown("Please enter a title to see the predicted topics.")
else:
    use_summary = bool(summary.strip())

    #################
    ## Model loading
    #################
    prediction_model = load_model()
    checkpoint = SUMMARY_CHECKPOINT if use_summary else TITLE_CHECKPOINT
    # map_location='cpu': inference below runs on CPU, so weights saved from a
    # GPU must be remapped or torch.load fails on a CPU-only host.
    # NOTE(review): load_model() is cached, so this presumably swaps the weights
    # of one model object shared by all sessions; concurrent users needing
    # different checkpoints could interfere. Consider one cached model per
    # checkpoint.
    prediction_model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    #################
    ## Tokenizing
    #################
    tokenizer = load_tokenizer()
    # The summary checkpoint is selected because an abstract was given, so the
    # abstract has to be part of the model input as well.
    # NOTE(review): assumes the summary model was trained on "title abstract"
    # text (hinted by a former `tokenizer.encode(title + summary)`); confirm the
    # exact format/separator used during training.
    text = f"{title} {summary}" if use_summary else title
    # truncation=True caps the sequence at the tokenizer's maximum length so
    # long abstracts don't overflow the model's position embeddings.
    tokens = tokenizer.encode(text, truncation=True)

    with torch.no_grad():
        logits = prediction_model(torch.as_tensor([tokens], device='cpu'))[0]
    # Batch of one: take the single row of per-label probabilities.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()

    st.markdown("#### Probably the article fits some of the following topics:")
    st.markdown("The information in parentheses shows the original label and the predicted probability.")
    st.markdown(format_predictions(top_predictions(probabilities)))