Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
import copy | |
import json | |
import pathlib | |
import tempfile | |
import gradio as gr | |
from huggingface_hub import CommitOperationAdd, HfApi | |
from papers import PaperList | |
# Dataset repository on the Hugging Face Hub that holds the paper metadata,
# and the JSON file inside it that this app reads and proposes edits to.
REPO_ID = "ICLR2024/ICLR2024-papers"
FILENAME = "data.json"

api = HfApi()
paper_list = PaperList()

# Download the metadata file at import time and keep the parsed records in memory.
path = api.hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="dataset")
# JSON text is UTF-8; pass the encoding explicitly so decoding does not depend
# on the platform's default locale encoding.
with pathlib.Path(path).open(encoding="utf-8") as f:
    raw_data = json.load(f)

# Index from paper ID (stringified, because lookups use the text of the ID
# textbox) to the record's position in raw_data.
paper_id_to_index = {str(paper["id"]): i for i, paper in enumerate(raw_data)}
# Search UI: two text filters and a read-only table of papers.
# NOTE(review): the original indentation was lost, so whether the Dataframe
# sits inside or outside the Group is reconstructed here -- confirm the layout.
with gr.Blocks() as demo_search:
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        search_author = gr.Textbox(label="Search author")

    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        max_height=1000,
        elem_id="table",
        wrap=True,
    )

    inputs = [search_title, search_author]

    # Both the submit handlers and the initial page load run the same search
    # with the same wiring, so the shared arguments are spelled out once.
    _search_event_kwargs = {
        "fn": paper_list.search,
        "inputs": inputs,
        "outputs": df,
        "queue": False,
        "api_name": False,
    }

    # Re-run the search when either filter box is submitted.
    gr.on(
        triggers=[search_title.submit, search_author.submit],
        **_search_event_kwargs,
    )
    # Run the search once when the page loads.
    demo_search.load(**_search_event_kwargs)
def load_data(paper_id: str) -> tuple[str, str, str, str, str, str, str, str, str]:
    """Look up a paper by ID and return its fields formatted for the edit form.

    Args:
        paper_id: Paper ID as entered in the ID textbox. Surrounding
            whitespace is ignored.

    Returns:
        A 9-tuple ``(id, title, authors, arxiv_id, project_page, github,
        spaces, models, datasets)``. The list-valued fields (authors, GitHub,
        Space, Model, Dataset) are joined with newlines, one entry per line.

    Raises:
        gr.Error: If no paper with the given ID exists in the loaded data.
    """
    # Tolerate leading/trailing whitespace (e.g. from copy-paste) so an
    # otherwise valid ID is not rejected.
    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError as e:
        error_message = f"Paper ID {paper_id} not found."
        raise gr.Error(error_message) from e

    paper = raw_data[index]
    return (
        # The index is keyed by str(paper["id"]); return the same string form
        # so the value matches the declared return type.
        str(paper["id"]),
        paper["title"],
        "\n".join(paper["authors"]),
        paper["arxiv_id"],
        paper["project_page"],
        "\n".join(paper["GitHub"]),
        "\n".join(paper["Space"]),
        "\n".join(paper["Model"]),
        "\n".join(paper["Dataset"]),
    )
def split_and_strip(s: str) -> list[str]:
    """Split *s* on newlines and return the non-blank lines, each stripped of surrounding whitespace."""
    return [stripped for line in s.split("\n") if (stripped := line.strip())]
def create_pr(
    paper_id: str,
    title: str,
    authors: str,
    arxiv_id: str,
    project_page: str,
    github_links: str,
    space_ids: str,
    model_ids: str,
    dataset_ids: str,
    oauth_token: gr.OAuthToken | None,
) -> str:
    """Open a pull request on the dataset repo that updates one paper's metadata.

    The edited record replaces the matching entry in a copy of the loaded
    data, the full list is serialized to JSON, and the result is committed to
    ``FILENAME`` in ``REPO_ID`` as a pull request.

    Args:
        paper_id: ID of the paper to update. Surrounding whitespace is ignored.
        title: New title.
        authors: Author names, one per line.
        arxiv_id: New arXiv ID.
        project_page: New project page URL.
        github_links: GitHub URLs, one per line.
        space_ids: Space IDs, one per line.
        model_ids: Model IDs, one per line.
        dataset_ids: Dataset IDs, one per line.
        oauth_token: Token of the logged-in user, or ``None`` when nobody is
            logged in. It is not listed among the click inputs in this file;
            presumably Gradio supplies it from the type annotation -- confirm.

    Returns:
        The URL of the created pull request, or the message
        ``"Please log in first."`` when ``oauth_token`` is ``None``.

    Raises:
        gr.Error: If no paper with the given ID exists in the loaded data.
    """
    if oauth_token is None:
        return "Please log in first."

    # Normalize the ID the same way a tolerant lookup would, so stray
    # whitespace in the textbox does not cause a spurious "not found".
    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError as e:
        error_message = f"Paper ID {paper_id} not found."
        raise gr.Error(error_message) from e

    # Only one record changes, so a shallow copy of the list plus a fresh dict
    # for that record is enough to leave raw_data untouched; no need to
    # deep-copy the whole dataset on every request.
    data = list(raw_data)
    data[index] = {
        **raw_data[index],
        "title": title.strip(),
        "authors": split_and_strip(authors),
        "arxiv_id": arxiv_id.strip(),
        "project_page": project_page.strip(),
        "GitHub": split_and_strip(github_links),
        "Space": split_and_strip(space_ids),
        "Model": split_and_strip(model_ids),
        "Dataset": split_and_strip(dataset_ids),
    }

    # Hand the serialized JSON to the Hub client as bytes instead of writing a
    # NamedTemporaryFile(delete=False), which was never removed and left one
    # stray file on disk per PR.
    payload = json.dumps(data, indent=2).encode("utf-8")
    commit = CommitOperationAdd(path_in_repo=FILENAME, path_or_fileobj=payload)
    res = api.create_commit(
        repo_id=REPO_ID,
        operations=[commit],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        token=oauth_token.token,
    )
    return res.pr_url
# Edit UI: load a paper by ID, edit its fields, and submit the result as a PR.
# NOTE(review): the original indentation was lost, so the exact nesting of the
# components inside the two Groups is reconstructed here -- confirm the layout.
with gr.Blocks() as demo_edit:
    with gr.Group():
        paper_id = gr.Textbox(label="ID", max_lines=1)
        load_button = gr.Button("Load")

    with gr.Group():
        title = gr.Textbox(label="Title", max_lines=1)
        authors = gr.Textbox(label="Authors", lines=5)
        arxiv_id = gr.Textbox(label="arXiv ID", max_lines=1, placeholder="2404.00000")
        project_page = gr.Textbox(
            label="Project page",
            max_lines=1,
            placeholder="https://aaa.github.io/bbb",
        )
        github_links = gr.Textbox(
            label="GitHub links",
            lines=5,
            placeholder="https://github.com/aaa/bbb\nhttps://github.com/ccc/ddd",
        )
        space_ids = gr.Textbox(
            label="Space IDs",
            lines=5,
            placeholder="org_name1/repo_name1\norg_name2/repo_name2",
        )
        model_ids = gr.Textbox(
            label="Model IDs",
            lines=5,
            placeholder="org_name1/repo_name1\norg_name2/repo_name2",
        )
        dataset_ids = gr.Textbox(
            label="Dataset IDs",
            lines=5,
            placeholder="org_name1/repo_name1\norg_name2/repo_name2",
        )

    create_pr_button = gr.Button("Create PR")
    result = gr.Textbox(label="Result", max_lines=1)

    # The same nine components, in the same order, are the outputs of
    # load_data and the inputs of create_pr.
    _form_fields = [
        paper_id,
        title,
        authors,
        arxiv_id,
        project_page,
        github_links,
        space_ids,
        model_ids,
        dataset_ids,
    ]

    # Pressing Enter in the ID box or clicking "Load" fills the form.
    gr.on(
        triggers=[paper_id.submit, load_button.click],
        fn=load_data,
        inputs=paper_id,
        outputs=list(_form_fields),
        queue=False,
        api_name=False,
    )

    # "Create PR" sends the current form contents to create_pr and shows
    # its return value in the result box.
    create_pr_button.click(
        fn=create_pr,
        inputs=list(_form_fields),
        outputs=result,
        queue=False,
        api_name=False,
    )
# Top-level app: an intro line followed by three tabs (login, search, edit)
# that embed the login button and the two sub-UIs defined above.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(
        "You can create PRs to update the JSON files in the "
        "[ICLR2024-papers repo](https://huggingface.co/datasets/ICLR2024/ICLR2024-papers) "
        "with this Space."
    )

    with gr.Tabs():
        # Step 1: log in, since create_pr refuses to run without a token.
        with gr.Tab(label="Step 1: Login"):
            gr.Markdown("To create a PR, you first need to log in. Please press the login button below.")
            gr.LoginButton()

        # Step 2: find the ID of the paper to edit.
        with gr.Tab(label="Step 2: Search for paper ID"):
            gr.Markdown("Search for the paper you would like to update and find its paper ID.")
            demo_search.render()

        # Step 3: load the paper by ID, edit it, and open the PR.
        with gr.Tab(label="Step 3: Edit and create PR"):
            gr.Markdown("Enter the paper ID in the field below and press the Load button.")
            gr.Markdown("After making the necessary changes, press the Create PR button.")
            demo_edit.render()
# Script entry point: enable the queue with its API closed, then launch the
# app with the API page hidden.
if __name__ == "__main__":
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False)