|
import os |
|
|
|
import arxiv |
|
import gradio as gr |
|
import pandas as pd |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
from cachetools import TTLCache, cached |
|
from setfit import SetFitModel |
|
from tqdm.auto import tqdm |
|
import stamina |
|
from arxiv import UnexpectedEmptyPageError, ArxivError |
|
|
|
# Set the Hugging Face Hub transfer flag before any model download happens.
# NOTE(review): presumably switches Hub downloads to the `hf_transfer` backend - confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# TTL handed to the result caches below (60 * 60 * 12 = 43200, i.e. 12 hours in seconds).
CACHE_TIME = 60 * 60 * 12
# Upper bound on the number of arXiv search results requested per refresh.
MAX_RESULTS = 300
|
|
|
|
|
# Shared arXiv API client used by every search in this module.
client = arxiv.Client(
    page_size=50,
    delay_seconds=3,
    num_retries=2,
)
|
|
|
|
|
@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
def get_arxiv_result():
    """Return the arXiv search results, memoised in a TTL cache.

    Thin caching wrapper around ``_get_arxiv_result``; the cache is created
    with ``ttl=CACHE_TIME``, so repeated calls within that window reuse the
    same result instead of querying arXiv again.

    Returns:
        The list of paper dicts produced by ``_get_arxiv_result``.
    """
    return _get_arxiv_result()
|
|
|
|
|
@stamina.retry(
    on=(ValueError, UnexpectedEmptyPageError, ArxivError), attempts=10, wait_max=60 * 15
)
def _get_arxiv_result():
    """Fetch the most recently submitted arXiv papers matching ``ti:dataset``.

    Runs the search through the module-level ``client`` and flattens every
    result into a plain dict.

    Returns:
        A non-empty list of dicts with the keys ``title``, ``abstract``,
        ``url``, ``category`` and ``updated``.

    Raises:
        ValueError: If the search yields no results at all. ``ValueError`` is
            listed in the retry decorator's ``on`` tuple, so an empty response
            is retried (up to ``attempts=10``) before the error propagates.
    """
    search = arxiv.Search(
        query="ti:dataset",
        max_results=MAX_RESULTS,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    results = [
        {
            "title": result.title,
            "abstract": result.summary,
            "url": result.entry_id,
            "category": result.primary_category,
            "updated": result.updated,
        }
        # The progress bar is sized to the requested maximum; fewer results
        # may actually come back.
        for result in tqdm(client.results(search), total=MAX_RESULTS)
    ]
    # Only a genuinely empty response is an error; a single hit is still a
    # valid result set (the previous `> 1` check rejected it).
    if not results:
        raise ValueError("No results found")
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Load the ``librarian-bots/is_new_dataset_teacher_model`` SetFit model.

    Returns:
        The object returned by ``SetFitModel.from_pretrained``.
    """
    return SetFitModel.from_pretrained("librarian-bots/is_new_dataset_teacher_model")
|
|
|
|
|
def format_row_for_model(row):
    """Build the text input for the classifier from one paper record.

    Args:
        row: Mapping with at least the keys ``"title"`` and ``"abstract"``.

    Returns:
        A single string of the form ``"TITLE: <title> \\n\\nABSTRACT: <abstract>"``.
    """
    return f"TITLE: {row['title']} \n\nABSTRACT: {row['abstract']}"
|
|
|
|
|
# Maps the index of the highest-probability class to its label name.
int2label = {0: "new_dataset", 1: "not_new_dataset"}
|
|
|
|
|
def get_predictions(data: list[dict], model=None, batch_size=128):
    """Classify paper records in batches and attach the predicted label.

    Args:
        data: Paper records; each must carry the ``"title"`` and
            ``"abstract"`` keys read by ``format_row_for_model``.
        model: Object exposing ``predict_proba(list_of_texts)``. When ``None``
            the model is obtained from ``load_model()``.
        batch_size: Number of records sent to the model per call.

    Returns:
        A new list with one dict per input record, containing the original
        keys plus ``"prediction"`` (a label from ``int2label``) and
        ``"probability"`` (the highest class probability as a float). The
        input dicts are left untouched, so records shared with a cache are
        not modified behind the cache's back.
    """
    if model is None:
        model = load_model()
    predictions = []
    for start in tqdm(range(0, len(data), batch_size)):
        batch = data[start : start + batch_size]
        text_inputs = [format_row_for_model(row) for row in batch]
        # assumes predict_proba returns one probability vector per input text,
        # indexable by position and offering argmax()/max() - TODO confirm
        batch_predictions = model.predict_proba(text_inputs)
        for position, row in enumerate(batch):
            prediction = batch_predictions[position]
            predictions.append(
                {
                    **row,
                    "prediction": int2label[int(prediction.argmax())],
                    "probability": float(prediction.max()),
                }
            )
    return predictions
|
|
|
|
|
def create_markdown(row):
    """Render one paper record as a Markdown/HTML snippet.

    Args:
        row: Mapping with the keys ``"title"``, ``"abstract"``, ``"arxiv_id"``,
            ``"updated"`` (an object with ``strftime``), ``"broad_category"``
            and ``"category"``.

    Returns:
        A string holding an ``<h2>`` title, the update date formatted as
        ``YYYY-MM-DD``, the category line, the abstract and a link to
        ``https://huggingface.co/papers/<arxiv_id>``.
    """
    title = row["title"]
    abstract = row["abstract"]
    hub_paper_url = f"https://huggingface.co/papers/{row['arxiv_id']}"
    updated = row["updated"].strftime("%Y-%m-%d")
    broad_category = row["broad_category"]
    category = row["category"]
    # Spelled out line by line so the exact whitespace of the template is
    # visible and no stray characters can leak into the rendered output.
    return (
        f" <h2> {title} </h2> Updated: {updated}\n"
        f"| Category: {broad_category} | Subcategory: {category} |\n"
        f"\n\n{abstract}\n"
        f"\n\n [Hugging Face Papers page]({hub_paper_url})\n"
    )
|
|
|
|
|
@cached(cache=TTLCache(maxsize=10, ttl=CACHE_TIME))
def prepare_data():
    """Download, classify and format the arXiv results as a DataFrame.

    The result is memoised in a TTL cache created with ``ttl=CACHE_TIME``.

    Returns:
        A ``pandas.DataFrame`` with one row per paper: the fields produced by
        ``get_predictions`` plus ``arxiv_id`` (first ``digits.digits`` match
        in ``url``), ``broad_category`` (the part of ``category`` before the
        first ``"."``) and ``markdown`` (output of ``create_markdown``).
    """
    print("Downloading arxiv results...")
    arxiv_results = get_arxiv_result()
    print("loading model...")
    model = load_model()
    print("Making predictions...")
    predictions = get_predictions(arxiv_results, model=model)
    df = pd.DataFrame(predictions)
    # expand=False yields a Series, which is what a single new column needs;
    # the default would hand back a one-column DataFrame.
    df["arxiv_id"] = df["url"].str.extract(r"(\d+\.\d+)", expand=False)
    df["broad_category"] = df["category"].str.split(".").str[0]
    # Must come last: create_markdown reads the two columns added above.
    df["markdown"] = df.apply(create_markdown, axis=1)
    return df
|
|
|
|
|
# Dropdown choices, computed once at import time from the prepared data.
# prepare_data() is cached, so the second call reuses the first call's frame.
all_possible_arxiv_categories = sorted(prepare_data().category.unique().tolist())
broad_categories = sorted(prepare_data().broad_category.unique().tolist())
|
|
|
|
|
|
|
def create_markdown_summary(categories=None, new_only=True, narrow_categories=None):
    """Build the Markdown page listing the papers that match the filters.

    Args:
        categories: Broad category names matched against the
            ``broad_category`` column. ``None`` disables this filter; an empty
            list matches nothing. Ignored when ``narrow_categories`` is
            non-empty.
        new_only: When true, keep only rows whose ``prediction`` is
            ``"new_dataset"``.
        narrow_categories: Category names matched against the ``category``
            column. A non-empty value takes precedence over ``categories``;
            ``None`` or an empty list means "no narrowing".

    Returns:
        A string with a centered heading, the number of matching rows and the
        per-paper ``markdown`` snippets joined by ``"\\n\\n<br>"``.
    """
    df = prepare_data()
    if new_only:
        df = df[df["prediction"] == "new_dataset"]
    if narrow_categories:
        # Narrow selection wins over the broad one.
        df = df[df["category"].isin(narrow_categories)]
    elif categories is not None:
        df = df[df["broad_category"].isin(categories)]
    number_of_results = len(df)
    results = (
        "<h1 style='text-align: center'> arXiv papers related to datasets</h1> \n\n"
    )
    results += f"Number of results: {number_of_results}\n\n"
    results += "\n\n<br>".join(df["markdown"].tolist())
    return results
|
|
|
|
|
# Background job that calls prepare_data on a cron trigger (hour=3, minute=30).
# NOTE(review): prepare_data is wrapped in a TTL cache, so a scheduled call made
# while the cached entry is still live presumably returns the cached frame
# without refreshing anything - confirm this is the intended behaviour.
scheduler = BackgroundScheduler()
scheduler.add_job(prepare_data, "cron", hour=3, minute=30)
scheduler.start()
|
|
|
# Introductory text rendered at the top of the page. Keep the search
# description in sync with the query used to fetch results (`ti:dataset`).
description = """This Space shows recent papers on arXiv that are *likely* to be papers introducing new datasets related to machine learning. \n\n
The Space works by:
- searching for papers on arXiv with the term `dataset` in the title
- passing the abstract and title of the papers to a machine learning model that predicts if the paper is introducing a new dataset or not

This Space is a work in progress. The model is not perfect, and the search query is not perfect. If you have suggestions for how to improve this Space, please open a Discussion.\n\n"""
|
|
|
|
|
# Gradio UI: a broad-category selector, an optional narrow-category selector
# inside a collapsed accordion, a "new datasets only" toggle, and the rendered
# results. Every control re-renders the results through create_markdown_summary.
with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align: center'> ✨New Datasets in Machine Learning "
        " ✨ </h1>"
    )
    gr.Markdown(description)
    # Default broad selection, shared by the dropdown and the initial render so
    # the first page shown matches the state of the controls.
    default_broad_categories = ["cs"]
    with gr.Row():
        # Rebinds the module-level name from the list of choices to the widget;
        # the list is read as `choices` before the assignment takes effect.
        broad_categories = gr.Dropdown(
            choices=broad_categories,
            label="Broad arXiv Category",
            multiselect=True,
            # A multiselect dropdown holds a list of values.
            value=default_broad_categories,
        )
    with gr.Accordion("Advanced Options", open=False):
        gr.Markdown(
            "Narrow by arXiv categories. **Note** this will take precedence over the"
            " broad category selection."
        )
        narrow_categories = gr.Dropdown(
            choices=all_possible_arxiv_categories,
            value=None,
            multiselect=True,
            label="Narrow arXiv Category",
        )
        gr.ClearButton(narrow_categories, "Clear Narrow Categories", size="sm")
    with gr.Row():
        new_only = gr.Checkbox(True, label="New Datasets Only", interactive=True)
    results = gr.Markdown(
        create_markdown_summary(categories=default_broad_categories)
    )
    broad_categories.change(
        create_markdown_summary,
        inputs=[broad_categories, new_only, narrow_categories],
        outputs=results,
    )
    narrow_categories.change(
        create_markdown_summary,
        inputs=[broad_categories, new_only, narrow_categories],
        outputs=results,
    )
    new_only.change(
        create_markdown_summary,
        [broad_categories, new_only, narrow_categories],
        results,
    )

demo.launch()
|
|