Spaces:
Sleeping
Sleeping
Pavel Malov
commited on
Commit
Β·
28f6ce1
1
Parent(s):
d2af509
Add model
Browse files- app.py +15 -12
- inference.py +81 -0
- requirements.txt +2 -0
- resources/model.ckpt +3 -0
- resources/tag_mapping.json +172 -0
app.py
CHANGED
@@ -1,24 +1,27 @@
|
|
1 |
import streamlit as st
|
|
|
2 |
|
3 |
|
4 |
st.set_page_config(layout="wide")
|
5 |
|
6 |
-
st.
|
7 |
-
<style>
|
8 |
-
.big-font {
|
9 |
-
font-size:300px !important;
|
10 |
-
}
|
11 |
-
</style>
|
12 |
-
""", unsafe_allow_html=True)
|
13 |
-
|
14 |
-
st.title("ArxivTitlePicker")
|
15 |
st.write("This app helps define category of your scientific paper based on its name and abstract.")
|
16 |
name = st.text_input("Paste here name of your paper")
|
17 |
abstract = st.text_area("Paste here abstract of your paper")
|
18 |
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
21 |
|
22 |
if st.button("Start processing"):
|
23 |
if name == '':
|
24 |
-
st.write('<p style="font-family:sans-serif; color:Red; font-size: 21px;">Please, provide name of the paper!πββοΈ</p>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from inference import InferenceModel
|
3 |
|
4 |
|
5 |
st.set_page_config(layout="wide")
|
6 |
|
7 |
+
st.title("ArxivTopicPicker")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
st.write("This app helps define category of your scientific paper based on its name and abstract.")
|
9 |
name = st.text_input("Paste here name of your paper")
|
10 |
abstract = st.text_area("Paste here abstract of your paper")
|
11 |
|
12 |
+
model = InferenceModel()
|
13 |
+
model.inference('load')
|
14 |
+
|
15 |
+
# if name != '':
|
16 |
+
# st.text("Your paper:\n\tName: " + name + '.\n\tAbstract: ' + abstract)
|
17 |
|
18 |
if st.button("Start processing"):
|
19 |
if name == '':
|
20 |
+
st.write('<p style="font-family:sans-serif; color:Red; font-size: 21px;">Please, provide name of the paper!πββοΈ</p>', unsafe_allow_html=True)
|
21 |
+
else:
|
22 |
+
input_text = name + '. ' + abstract if abstract != '' else name + '.'
|
23 |
+
top_topics = model.inference(input_text)
|
24 |
+
if len(top_topics) == 0:
|
25 |
+
st.text("We don't know yetπ°")
|
26 |
+
else:
|
27 |
+
st.text(top_topics)
|
inference.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from typing import List, Dict, Set
|
5 |
+
from pathlib import Path
|
6 |
+
from transformers import DistilBertTokenizer, DistilBertModel
|
7 |
+
|
8 |
+
|
9 |
+
class Nnet(nn.Module):
|
10 |
+
def __init__(self) -> None:
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
self.nnet = nn.Sequential(
|
14 |
+
nn.Linear(768, 256),
|
15 |
+
nn.ReLU(),
|
16 |
+
nn.BatchNorm1d(256),
|
17 |
+
nn.Linear(256, 85)
|
18 |
+
)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return self.nnet(x)
|
22 |
+
|
23 |
+
|
24 |
+
class ClassificationHead(nn.Module):
|
25 |
+
def __init__(self) -> None:
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.nnet = Nnet()
|
29 |
+
|
30 |
+
ckpt = torch.load("resources/model.ckpt")
|
31 |
+
self.nnet.load_state_dict(ckpt['state_dict'], strict=False)
|
32 |
+
|
33 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
34 |
+
return self.nnet(x.unsqueeze(0))
|
35 |
+
|
36 |
+
class InferenceModel:
|
37 |
+
def __init__(self) -> None:
|
38 |
+
self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
|
39 |
+
self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
|
40 |
+
self.head: nn.Module = ClassificationHead()
|
41 |
+
|
42 |
+
values: Set = set(json.loads(Path('resources/tag_mapping.json').read_text()).values())
|
43 |
+
values.remove('')
|
44 |
+
self.mapping: Dict = dict()
|
45 |
+
for i, val in enumerate(values):
|
46 |
+
self.mapping[i] = val
|
47 |
+
|
48 |
+
def topp(self, probs: torch.Tensor):
|
49 |
+
# sort probs
|
50 |
+
sorted_probs, sorted_inds = torch.sort(probs, descending=True)
|
51 |
+
# accumulate probs
|
52 |
+
accum = torch.cumsum(sorted_probs, dim=0)
|
53 |
+
# get index of the first element where cumsum reached 0.95
|
54 |
+
ind = torch.nonzero(accum > 0.95)[0]
|
55 |
+
return sorted_inds[:ind]
|
56 |
+
|
57 |
+
def get_lables(self, classes: torch.Tensor) -> List[str]:
|
58 |
+
output = ""
|
59 |
+
for cls in classes.numpy():
|
60 |
+
output += self.mapping[cls] + '\n'
|
61 |
+
|
62 |
+
return output
|
63 |
+
|
64 |
+
def inference(self, x: str) -> List[str]:
|
65 |
+
self.bert.eval()
|
66 |
+
self.head.eval()
|
67 |
+
with torch.no_grad():
|
68 |
+
# tokenize: str -> Tokens
|
69 |
+
encoded_input = self.tokenizer(x, return_tensors='pt', truncation=True)
|
70 |
+
# get embedding: Tokens -> Embeddings -> MeanEmbedding
|
71 |
+
embeddings = self.bert(**encoded_input)
|
72 |
+
mean_embedding = embeddings[0].mean(dim=1)[0]
|
73 |
+
# get probs: MeanEmbedding -> Probs
|
74 |
+
probs = self.head(mean_embedding).softmax(-1)[0]
|
75 |
+
|
76 |
+
# get top_p classes: Probs -> 95% classes
|
77 |
+
topp_calsses = self.topp(probs)
|
78 |
+
print(probs)
|
79 |
+
# map classes to lables
|
80 |
+
return self.get_lables(topp_calsses)
|
81 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch==1.13
|
2 |
+
transformers
|
resources/model.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7d581cc499259712e58a5cf251c7c2d8054d8d67cad61bde6c0e936ff4e285ca
|
3 |
+
size 2643089
|
resources/tag_mapping.json
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"60g15": "Probability",
|
3 |
+
"62-07": "Statistics Theory",
|
4 |
+
"62f15": "Parametric inference",
|
5 |
+
"62g08": "Nonparametric inference",
|
6 |
+
"62h30": "Multivariate analysis",
|
7 |
+
"62m45": "Inference from stochastic processes",
|
8 |
+
"65k10": "Mathematical programming, optimization and variational techniques",
|
9 |
+
"68q32": "Theory of computing",
|
10 |
+
"68t01": "Artificial intelligence",
|
11 |
+
"68t05": "Artificial intelligence",
|
12 |
+
"68t10": "Artificial intelligence",
|
13 |
+
"68t20": "Artificial intelligence",
|
14 |
+
"68t27": "Artificial intelligence",
|
15 |
+
"68t30": "Artificial intelligence",
|
16 |
+
"68t37": "Artificial intelligence",
|
17 |
+
"68t40": "Artificial intelligence",
|
18 |
+
"68t45": "Artificial intelligence",
|
19 |
+
"68t50": "Artificial intelligence",
|
20 |
+
"68txx": "Artificial intelligence",
|
21 |
+
"68u10": "Computing methodologies and applications",
|
22 |
+
"90c25": "Mathematical programming",
|
23 |
+
"90c26": "Mathematical programming",
|
24 |
+
"90c90": "Mathematical programming",
|
25 |
+
"91f20": "Other social and behavioral sciences (mathematical treatment)",
|
26 |
+
"92b20": "Mathematical biology in general",
|
27 |
+
"94a08": "Communication, information",
|
28 |
+
"97r40": "Mathematics education",
|
29 |
+
"astro-ph.im": "Instrumentation and Methods for Astrophysics",
|
30 |
+
"c.1.3": "Distributed, Parallel, and Cluster Computing",
|
31 |
+
"c.2.4": "Distributed, Parallel, and Cluster Computing",
|
32 |
+
"cmp-lg": "Computation and Language",
|
33 |
+
"cond-mat.dis-nn": "Disordered Systems and Neural Networks",
|
34 |
+
"cond-mat.stat-mech": "",
|
35 |
+
"cs.ai": "Artificial intelligence",
|
36 |
+
"cs.ar": "Hardware Architecture",
|
37 |
+
"cs.cc": "Computational Complexity",
|
38 |
+
"cs.ce": "Computational Engineering, Finance, and Science",
|
39 |
+
"cs.cg": "Computational Geometry",
|
40 |
+
"cs.cl": "Computation and Language",
|
41 |
+
"cs.cr": "Cryptography and Security",
|
42 |
+
"cs.cv": "Computer Vision and Pattern Recognition",
|
43 |
+
"cs.cy": "Computers and Society",
|
44 |
+
"cs.db": "Databases",
|
45 |
+
"cs.dc": "Distributed, Parallel, and Cluster Computing",
|
46 |
+
"cs.dl": "Digital Libraries",
|
47 |
+
"cs.dm": "Discrete Mathematics",
|
48 |
+
"cs.ds": "Data Structures and Algorithms",
|
49 |
+
"cs.et": "Emerging Technologies",
|
50 |
+
"cs.fl": "Formal Languages and Automata Theory",
|
51 |
+
"cs.gr": "Graphics",
|
52 |
+
"cs.gt": "Computer Science and Game Theory",
|
53 |
+
"cs.hc": "Human-Computer Interaction",
|
54 |
+
"cs.ir": "Information Retrieval",
|
55 |
+
"cs.it": "Information Theory",
|
56 |
+
"cs.lg": "Machine Learning",
|
57 |
+
"cs.lo": "Logic in Computer Science",
|
58 |
+
"cs.ma": "Multiagent Systems",
|
59 |
+
"cs.mm": "Multimedia",
|
60 |
+
"cs.ms": "Mathematical Software",
|
61 |
+
"cs.na": "Numerical Analysis",
|
62 |
+
"cs.ne": "Neural and Evolutionary Computing",
|
63 |
+
"cs.ni": "Networking and Internet Architecture",
|
64 |
+
"cs.pf": "Performance",
|
65 |
+
"cs.pl": "Programming Languages",
|
66 |
+
"cs.ro": "Robotics",
|
67 |
+
"cs.sc": "Symbolic Computation",
|
68 |
+
"cs.sd": "Sound",
|
69 |
+
"cs.se": "Software Engineering",
|
70 |
+
"cs.si": "Social and Information Networks",
|
71 |
+
"cs.sy": "Systems and Control",
|
72 |
+
"d.1.3": "Distributed, Parallel, and Cluster Computing",
|
73 |
+
"d.1.6": "Programming Languages",
|
74 |
+
"d.2.2": "Software Engineering",
|
75 |
+
"d.3.1": "Programming Languages",
|
76 |
+
"d.3.2": "Programming Languages",
|
77 |
+
"d.3.3": "Programming Languages",
|
78 |
+
"e.2": "Databases",
|
79 |
+
"e.4": "Information Theory",
|
80 |
+
"eess.as": "Sound",
|
81 |
+
"eess.iv": "Computer Vision and Pattern Recognition",
|
82 |
+
"eess.sp": "Signal Processing",
|
83 |
+
"f.1.1": "Formal Languages and Automata Theory",
|
84 |
+
"f.1.3": "Computational Complexity",
|
85 |
+
"f.2": "Data Structures and Algorithms",
|
86 |
+
"f.2.2": "Data Structures and Algorithms",
|
87 |
+
"f.4.1": "Logic in Computer Science",
|
88 |
+
"f.4.2": "Logic in Computer Science",
|
89 |
+
"g.1.2": "Numerical Analysis",
|
90 |
+
"g.1.3": "Numerical Analysis",
|
91 |
+
"g.1.6": "Numerical Analysi",
|
92 |
+
"g.2.2": "Discrete Mathematics",
|
93 |
+
"g.3": "Discrete Mathematics",
|
94 |
+
"h.1.1": "Information Theory",
|
95 |
+
"h.1.2": "Human-Computer Interaction",
|
96 |
+
"h.2.4": "Databases",
|
97 |
+
"h.2.8": "Databases",
|
98 |
+
"h.3.1": "Information Retrieval",
|
99 |
+
"h.3.3": "Information Retrieval",
|
100 |
+
"h.3.4": "Information Retrieval",
|
101 |
+
"h.3.5": "Information Retrieval",
|
102 |
+
"h.5.1": "Sound",
|
103 |
+
"h.5.2": "Sound",
|
104 |
+
"h.5.3": "Sound",
|
105 |
+
"i.2": "Artificial intelligence",
|
106 |
+
"i.2.0": "Artificial intelligence",
|
107 |
+
"i.2.1": "Artificial intelligence",
|
108 |
+
"i.2.10": "Artificial intelligence",
|
109 |
+
"i.2.11": "Artificial intelligence",
|
110 |
+
"i.2.2": "Artificial intelligence",
|
111 |
+
"i.2.3": "Artificial intelligence",
|
112 |
+
"i.2.4": "Artificial intelligence",
|
113 |
+
"i.2.6": "Artificial intelligence",
|
114 |
+
"i.2.7": "Artificial intelligence",
|
115 |
+
"i.2.8": "Artificial intelligence",
|
116 |
+
"i.2.9": "Artificial intelligence",
|
117 |
+
"i.4": "Computer Vision and Pattern Recognition",
|
118 |
+
"i.4.1": "Computer Vision and Pattern Recognition",
|
119 |
+
"i.4.10": "Computer Vision and Pattern Recognition",
|
120 |
+
"i.4.3": "Computer Vision and Pattern Recognition",
|
121 |
+
"i.4.5": "Computer Vision and Pattern Recognition",
|
122 |
+
"i.4.6": "Computer Vision and Pattern Recognition",
|
123 |
+
"i.4.7": "Computer Vision and Pattern Recognition",
|
124 |
+
"i.4.8": "Computer Vision and Pattern Recognition",
|
125 |
+
"i.4.9": "Computer Vision and Pattern Recognition",
|
126 |
+
"i.5": "Computer Vision and Pattern Recognition",
|
127 |
+
"i.5.1": "Computer Vision and Pattern Recognition",
|
128 |
+
"i.5.2": "Computer Vision and Pattern Recognition",
|
129 |
+
"i.5.3": "Computer Vision and Pattern Recognition",
|
130 |
+
"i.5.4": "Computer Vision and Pattern Recognition",
|
131 |
+
"i.5.5": "Computer Vision and Pattern Recognition",
|
132 |
+
"j.2": "Computer Applications",
|
133 |
+
"j.3": "Computer Applications",
|
134 |
+
"j.4": "Computer Applications",
|
135 |
+
"j.5": "Computer Applications",
|
136 |
+
"k.3.2": "Computers and Society",
|
137 |
+
"math.ag": "Algebraic Geometry",
|
138 |
+
"math.co": "Combinatorics",
|
139 |
+
"math.ct": "Category Theory",
|
140 |
+
"math.dg": "Differential Geometry",
|
141 |
+
"math.ds": "Dynamical Systems",
|
142 |
+
"math.fa": "Functional Analysis",
|
143 |
+
"math.it": "Information Theory",
|
144 |
+
"math.lo": "Logic",
|
145 |
+
"math.na": "Numerical Analysis",
|
146 |
+
"math.oc": "Optimization and Control",
|
147 |
+
"math.pr": "Probability",
|
148 |
+
"math.st": "Statistics Theory",
|
149 |
+
"nlin.ao": "Adaptation and Self-Organizing Systems",
|
150 |
+
"nlin.cd": "Chaotic Dynamics",
|
151 |
+
"nlin.cg": "Cellular Automata and Lattice Gases",
|
152 |
+
"physics.ao-ph": "Astrophysics",
|
153 |
+
"physics.bio-ph": "Biological Physics",
|
154 |
+
"physics.chem-ph": "Chemical Physics",
|
155 |
+
"physics.comp-ph": "Computational Physics",
|
156 |
+
"physics.data-an": "Data Analysis, Statistics and Probability",
|
157 |
+
"physics.med-ph": "Medical Physics",
|
158 |
+
"physics.optics": "Optics",
|
159 |
+
"physics.soc-ph": "Physics and Society",
|
160 |
+
"q-bio.bm": "Biomolecules",
|
161 |
+
"q-bio.gn": "Genomics",
|
162 |
+
"q-bio.mn": "Molecular Networks",
|
163 |
+
"q-bio.nc": "Neurons and Cognition",
|
164 |
+
"q-bio.pe": "Populations and Evolution",
|
165 |
+
"q-bio.qm": "Quantitative Methods",
|
166 |
+
"quant-ph": "Quantum Physics",
|
167 |
+
"stat.ap": "Applications",
|
168 |
+
"stat.co": "Computation",
|
169 |
+
"stat.me": "Methodology",
|
170 |
+
"stat.ml": "Machine Learning",
|
171 |
+
"stat.th": "Statistics Theory"
|
172 |
+
}
|