davanstrien (HF staff) committed
Commit
ad38c8f
1 Parent(s): 3c378cf
Files changed (3)
  1. app.py +120 -0
  2. requirements.in +6 -0
  3. requirements.txt +329 -0
app.py ADDED
@@ -0,0 +1,120 @@
+import arxiv
+import gradio as gr
+import pandas as pd
+from apscheduler.schedulers.background import BackgroundScheduler
+from cachetools import TTLCache, cached
+from setfit import SetFitModel
+from tqdm.auto import tqdm
+
+CACHE_TIME = 60 * 60 * 12
+MAX_RESULTS = 30_000
+
+
+@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
+def get_arxiv_result():
+    search = arxiv.Search(
+        query="ti:dataset AND abs:machine learning",
+        max_results=MAX_RESULTS,
+        sort_by=arxiv.SortCriterion.SubmittedDate,
+    )
+    return [
+        {
+            "title": result.title,
+            "abstract": result.summary,
+            "url": result.entry_id,
+            "category": result.primary_category,
+            "updated": result.updated,
+        }
+        for result in tqdm(search.results(), total=MAX_RESULTS)
+    ]
+
+
+def load_model():
+    return SetFitModel.from_pretrained("librarian-bots/is_new_dataset_teacher_model")
+
+
+def format_row_for_model(row):
+    return f"TITLE: {row['title']} \n\nABSTRACT: {row['abstract']}"
+
+
+int2label = {0: "new_dataset", 1: "not_new_dataset"}
+
+
+def get_predictions(data: list[dict], model=None, batch_size=32):
+    if model is None:
+        model = load_model()
+    predictions = []
+    for i in tqdm(range(0, len(data), batch_size)):
+        batch = data[i : i + batch_size]
+        text_inputs = [format_row_for_model(row) for row in batch]
+        batch_predictions = model.predict_proba(text_inputs)
+        for j, row in enumerate(batch):
+            prediction = batch_predictions[j]
+            row["prediction"] = int2label[int(prediction.argmax())]
+            row["probability"] = float(prediction.max())
+            predictions.append(row)
+    return predictions
+
+
+def create_markdown(row):
+    title = row["title"]
+    abstract = row["abstract"]
+    arxiv_id = row["arxiv_id"]
+    hub_paper_url = f"https://huggingface.co/papers/{arxiv_id}"
+    updated = row["updated"]
+    updated = updated.strftime("%Y-%m-%d")
+    broad_category = row["broad_category"]
+    category = row["category"]
+    return f""" <h1> {title} </h1> updated: {updated}
+    | category: {broad_category} | subcategory: {category} |
+    \n\n{abstract}
+    \n\n [Hugging Face Papers page]({hub_paper_url})
+    """
+
+
+@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
+def prepare_data():
+    print("Downloading arxiv results...")
+    arxiv_results = get_arxiv_result()
+    print("loading model...")
+    model = load_model()
+    print("Making predictions...")
+    predictions = get_predictions(arxiv_results, model=model)
+    df = pd.DataFrame(predictions)
+    df.loc[:, "arxiv_id"] = df["url"].str.extract(r"(\d+\.\d+)")
+    df.loc[:, "broad_category"] = df["category"].str.split(".").str[0]
+    df.loc[:, "markdown"] = df.apply(create_markdown, axis=1)
+    return df
+
+
+all_possible_arxiv_categories = prepare_data().category.unique().tolist()
+broad_categories = prepare_data().broad_category.unique().tolist()
+
+
+def create_markdown_summary(categories=broad_categories, all_categories=None):
+    df = prepare_data()
+    if categories is not None:
+        df = df[df["broad_category"].isin(categories)]
+    return "\n\n".join(df["markdown"].tolist())
+
+
+scheduler = BackgroundScheduler()
+scheduler.add_job(prepare_data, "cron", hour=3, minute=30)
+scheduler.start()
+
+with gr.Blocks() as demo:
+    gr.Markdown("## New Datasets in Machine Learning")
+    gr.Markdown(
+        "This Space attempts to show new papers on arXiv that are *likely* to be papers"
+        " introducing new datasets. \n\n"
+    )
+    broad_categories = gr.Dropdown(
+        choices=broad_categories,
+        label="Categories",
+        multiselect=True,
+        value=broad_categories,
+    )
+    results = gr.Markdown(create_markdown_summary())
+    broad_categories.change(create_markdown_summary, broad_categories, results)
+
+demo.launch()
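
For reference, a minimal sketch (not part of the commit) of how the same SetFit checkpoint, label map, and text formatting used in app.py above can be exercised on a single record outside the Gradio UI; the title and abstract are made-up placeholders.

# Minimal sketch: score one made-up record with the checkpoint and formatting
# that app.py uses above; not part of the committed code.
from setfit import SetFitModel

model = SetFitModel.from_pretrained("librarian-bots/is_new_dataset_teacher_model")
int2label = {0: "new_dataset", 1: "not_new_dataset"}

# Hypothetical title/abstract, joined the same way as format_row_for_model.
text = "TITLE: A Benchmark Dataset for Example Tasks \n\nABSTRACT: We introduce ..."
proba = model.predict_proba([text])[0]
print(int2label[int(proba.argmax())], float(proba.max()))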
requirements.in ADDED
@@ -0,0 +1,6 @@
+apscheduler
+arxiv
+cachetools
+gradio
+scikit-learn==1.2.2
+setfit
requirements.txt ADDED
@@ -0,0 +1,329 @@
+#
+# This file is autogenerated by pip-compile with Python 3.11
+# by the following command:
+#
+#    pip-compile
+#
+aiofiles==23.2.1
+    # via gradio
+aiohttp==3.8.5
+    # via
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via aiohttp
+altair==5.1.2
+    # via gradio
+annotated-types==0.5.0
+    # via pydantic
+anyio==3.7.1
+    # via
+    #   fastapi
+    #   httpcore
+    #   starlette
+apscheduler==3.10.4
+    # via -r requirements.in
+arxiv==1.4.8
+    # via -r requirements.in
+async-timeout==4.0.3
+    # via aiohttp
+attrs==23.1.0
+    # via
+    #   aiohttp
+    #   jsonschema
+    #   referencing
+cachetools==5.3.1
+    # via -r requirements.in
+certifi==2023.7.22
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
+charset-normalizer==3.3.0
+    # via
+    #   aiohttp
+    #   requests
+click==8.1.7
+    # via
+    #   nltk
+    #   uvicorn
+contourpy==1.1.1
+    # via matplotlib
+cycler==0.12.0
+    # via matplotlib
+datasets==2.14.5
+    # via
+    #   evaluate
+    #   setfit
+dill==0.3.7
+    # via
+    #   datasets
+    #   evaluate
+    #   multiprocess
+evaluate==0.4.0
+    # via setfit
+fastapi==0.103.2
+    # via gradio
+feedparser==6.0.10
+    # via arxiv
+ffmpy==0.3.1
+    # via gradio
+filelock==3.12.4
+    # via
+    #   huggingface-hub
+    #   torch
+    #   transformers
+fonttools==4.43.0
+    # via matplotlib
+frozenlist==1.4.0
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2023.6.0
+    # via
+    #   datasets
+    #   evaluate
+    #   gradio-client
+    #   huggingface-hub
+    #   torch
+gradio==3.46.1
+    # via -r requirements.in
+gradio-client==0.5.3
+    # via gradio
+h11==0.14.0
+    # via
+    #   httpcore
+    #   uvicorn
+httpcore==0.18.0
+    # via httpx
+httpx==0.25.0
+    # via
+    #   gradio
+    #   gradio-client
+huggingface-hub==0.16.4
+    # via
+    #   datasets
+    #   evaluate
+    #   gradio
+    #   gradio-client
+    #   sentence-transformers
+    #   tokenizers
+    #   transformers
+idna==3.4
+    # via
+    #   anyio
+    #   httpx
+    #   requests
+    #   yarl
+importlib-resources==6.1.0
+    # via gradio
+jinja2==3.1.2
+    # via
+    #   altair
+    #   gradio
+    #   torch
+joblib==1.3.2
+    # via
+    #   nltk
+    #   scikit-learn
+jsonschema==4.19.1
+    # via altair
+jsonschema-specifications==2023.7.1
+    # via jsonschema
+kiwisolver==1.4.5
+    # via matplotlib
+markupsafe==2.1.3
+    # via
+    #   gradio
+    #   jinja2
+matplotlib==3.8.0
+    # via gradio
+mpmath==1.3.0
+    # via sympy
+multidict==6.0.4
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.15
+    # via
+    #   datasets
+    #   evaluate
+networkx==3.1
+    # via torch
+nltk==3.8.1
+    # via sentence-transformers
+numpy==1.26.0
+    # via
+    #   altair
+    #   contourpy
+    #   datasets
+    #   evaluate
+    #   gradio
+    #   matplotlib
+    #   pandas
+    #   pyarrow
+    #   scikit-learn
+    #   scipy
+    #   sentence-transformers
+    #   torchvision
+    #   transformers
+orjson==3.9.7
+    # via gradio
+packaging==23.2
+    # via
+    #   altair
+    #   datasets
+    #   evaluate
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   matplotlib
+    #   transformers
+pandas==2.1.1
+    # via
+    #   altair
+    #   datasets
+    #   evaluate
+    #   gradio
+pillow==10.0.1
+    # via
+    #   gradio
+    #   matplotlib
+    #   torchvision
+pyarrow==13.0.0
+    # via datasets
+pydantic==2.4.2
+    # via
+    #   fastapi
+    #   gradio
+pydantic-core==2.10.1
+    # via pydantic
+pydub==0.25.1
+    # via gradio
+pyparsing==3.1.1
+    # via matplotlib
+python-dateutil==2.8.2
+    # via
+    #   matplotlib
+    #   pandas
+python-multipart==0.0.6
+    # via gradio
+pytz==2023.3.post1
+    # via
+    #   apscheduler
+    #   pandas
+pyyaml==6.0.1
+    # via
+    #   datasets
+    #   gradio
+    #   huggingface-hub
+    #   transformers
+referencing==0.30.2
+    # via
+    #   jsonschema
+    #   jsonschema-specifications
+regex==2023.10.3
+    # via
+    #   nltk
+    #   transformers
+requests==2.31.0
+    # via
+    #   datasets
+    #   evaluate
+    #   fsspec
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   responses
+    #   torchvision
+    #   transformers
+responses==0.18.0
+    # via evaluate
+rpds-py==0.10.4
+    # via
+    #   jsonschema
+    #   referencing
+safetensors==0.3.3
+    # via transformers
+scikit-learn==1.2.2
+    # via
+    #   -r requirements.in
+    #   sentence-transformers
+scipy==1.11.3
+    # via
+    #   scikit-learn
+    #   sentence-transformers
+semantic-version==2.10.0
+    # via gradio
+sentence-transformers==2.2.2
+    # via setfit
+sentencepiece==0.1.99
+    # via sentence-transformers
+setfit==0.7.0
+    # via -r requirements.in
+sgmllib3k==1.0.0
+    # via feedparser
+six==1.16.0
+    # via
+    #   apscheduler
+    #   python-dateutil
+sniffio==1.3.0
+    # via
+    #   anyio
+    #   httpcore
+    #   httpx
+starlette==0.27.0
+    # via fastapi
+sympy==1.12
+    # via torch
+threadpoolctl==3.2.0
+    # via scikit-learn
+tokenizers==0.14.0
+    # via transformers
+toolz==0.12.0
+    # via altair
+torch==2.1.0
+    # via
+    #   sentence-transformers
+    #   torchvision
+torchvision==0.16.0
+    # via sentence-transformers
+tqdm==4.66.1
+    # via
+    #   datasets
+    #   evaluate
+    #   huggingface-hub
+    #   nltk
+    #   sentence-transformers
+    #   transformers
+transformers==4.34.0
+    # via sentence-transformers
+typing-extensions==4.8.0
+    # via
+    #   fastapi
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   pydantic
+    #   pydantic-core
+    #   torch
+tzdata==2023.3
+    # via pandas
+tzlocal==5.1
+    # via apscheduler
+urllib3==2.0.6
+    # via
+    #   requests
+    #   responses
+uvicorn==0.23.2
+    # via gradio
+websockets==11.0.3
+    # via
+    #   gradio
+    #   gradio-client
+xxhash==3.4.1
+    # via
+    #   datasets
+    #   evaluate
+yarl==1.9.2
+    # via aiohttp