hysts's picture
hysts HF Staff
Update
ae46ac4
#!/usr/bin/env python
import copy
import json
import pathlib
import tempfile
import gradio as gr
from huggingface_hub import CommitOperationAdd, HfApi
from papers import PaperList
# Dataset repo on the Hub whose JSON file this Space edits via pull requests.
REPO_ID = "ICLR2024/ICLR2024-papers"
# Path of the JSON file inside REPO_ID; also used as the target path when committing.
FILENAME = "data.json"

api = HfApi()
paper_list = PaperList()

# Download the dataset JSON once, at import time. Every load/PR in this process works from this
# snapshot (nothing in this file re-downloads it), so edits made to the repo later are not seen.
path = api.hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="dataset")
# JSON is UTF-8 by spec; be explicit so decoding does not depend on the host locale.
with pathlib.Path(path).open(encoding="utf-8") as f:
    raw_data = json.load(f)
# Map each paper's ID (stringified, to match Textbox input) to its position in raw_data.
# If an ID appeared twice, the later entry would win.
paper_id_to_index = {str(paper["id"]): i for i, paper in enumerate(raw_data)}
with gr.Blocks() as demo_search:
    # Search controls, grouped together above the results table.
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        search_author = gr.Textbox(label="Search author")
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        max_height=1000,
        elem_id="table",
        wrap=True,
    )
    inputs = [search_title, search_author]

    # One handler serves both the Enter-key triggers and the initial load, so the event
    # settings are declared once and shared (unqueued, and not exposed as a named API endpoint).
    _search_event_kwargs = {
        "fn": paper_list.search,
        "inputs": inputs,
        "outputs": df,
        "queue": False,
        "api_name": False,
    }
    gr.on(triggers=[search_title.submit, search_author.submit], **_search_event_kwargs)
    demo_search.load(**_search_event_kwargs)
def load_data(paper_id: str) -> tuple[str, str, str, str, str, str, str, str, str]:
    """Look up a paper by ID and return its fields formatted for the edit form.

    Args:
        paper_id: ID typed by the user; surrounding whitespace is ignored, so IDs
            copied from the search table with stray spaces still resolve.

    Returns:
        A 9-tuple in edit-form order: ID, title, authors, arXiv ID, project page,
        GitHub links, Space IDs, Model IDs, Dataset IDs. List-valued fields are
        joined with newlines so each item sits on its own line in the textbox.

    Raises:
        gr.Error: If the ID is not present in the loaded data.
    """
    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError as e:
        error_message = f"Paper ID {paper_id} not found."
        raise gr.Error(error_message) from e
    paper = raw_data[index]
    return (
        # The index is keyed by str(id); return the same form so the declared type holds.
        str(paper["id"]),
        paper["title"],
        "\n".join(paper["authors"]),
        paper["arxiv_id"],
        paper["project_page"],
        "\n".join(paper["GitHub"]),
        "\n".join(paper["Space"]),
        "\n".join(paper["Model"]),
        "\n".join(paper["Dataset"]),
    )
def split_and_strip(s: str) -> list[str]:
    """Split *s* on newlines, trim each line, and drop lines that are blank after trimming.

    Converts the multi-line textboxes (authors, links, repo IDs) back into lists.
    """
    trimmed = (line.strip() for line in s.split("\n"))
    return [line for line in trimmed if line]
def create_pr(
    paper_id: str,
    title: str,
    authors: str,
    arxiv_id: str,
    project_page: str,
    github_links: str,
    space_ids: str,
    model_ids: str,
    dataset_ids: str,
    oauth_token: gr.OAuthToken | None,
) -> str:
    """Open a pull request on the dataset repo that updates one paper's entry.

    Args:
        paper_id: ID of the paper to update; surrounding whitespace is ignored.
        title, arxiv_id, project_page: Single-line values, stored stripped.
        authors, github_links, space_ids, model_ids, dataset_ids: Newline-separated
            values, stored as lists with blank lines removed.
        oauth_token: Token of the logged-in user; the commit is made with it.

    Returns:
        The URL of the created pull request, or a prompt to log in if no token is present.

    Raises:
        gr.Error: If the paper ID is not present in the loaded data.
    """
    if oauth_token is None:
        return "Please log in first."

    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError as e:
        error_message = f"Paper ID {paper_id} not found."
        raise gr.Error(error_message) from e

    # Only the edited entry changes, so copy the list and that one dict instead of deep-copying
    # the whole dataset; raw_data itself stays untouched either way.
    data = list(raw_data)
    entry = dict(raw_data[index])
    entry["title"] = title.strip()
    entry["authors"] = split_and_strip(authors)
    entry["arxiv_id"] = arxiv_id.strip()
    entry["project_page"] = project_page.strip()
    entry["GitHub"] = split_and_strip(github_links)
    entry["Space"] = split_and_strip(space_ids)
    entry["Model"] = split_and_strip(model_ids)
    entry["Dataset"] = split_and_strip(dataset_ids)
    data[index] = entry

    # Upload the serialized JSON straight from memory; a temp file created with
    # delete=False would otherwise be left on disk after every PR.
    content = json.dumps(data, indent=2).encode("utf-8")
    operation = CommitOperationAdd(path_in_repo=FILENAME, path_or_fileobj=content)
    res = api.create_commit(
        repo_id=REPO_ID,
        operations=[operation],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        token=oauth_token.token,
    )
    return res.pr_url
with gr.Blocks() as demo_edit:
    # Paper ID entry and the button that pulls its current values into the form.
    with gr.Group():
        paper_id = gr.Textbox(label="ID", max_lines=1)
        load_button = gr.Button("Load")
    # Editable fields; multi-line boxes hold one item per line.
    with gr.Group():
        title = gr.Textbox(label="Title", max_lines=1)
        authors = gr.Textbox(label="Authors", lines=5)
        arxiv_id = gr.Textbox(label="arXiv ID", max_lines=1, placeholder="2404.00000")
        project_page = gr.Textbox(label="Project page", max_lines=1, placeholder="https://aaa.github.io/bbb")
        github_links = gr.Textbox(
            label="GitHub links",
            lines=5,
            placeholder="https://github.com/aaa/bbb\nhttps://github.com/ccc/ddd",
        )
        space_ids = gr.Textbox(label="Space IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2")
        model_ids = gr.Textbox(label="Model IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2")
        dataset_ids = gr.Textbox(
            label="Dataset IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2"
        )
    create_pr_button = gr.Button("Create PR")
    result = gr.Textbox(label="Result", max_lines=1)

    # Editable fields in the order load_data returns them and create_pr accepts them (after the ID).
    _paper_fields = [
        title,
        authors,
        arxiv_id,
        project_page,
        github_links,
        space_ids,
        model_ids,
        dataset_ids,
    ]

    gr.on(
        triggers=[paper_id.submit, load_button.click],
        fn=load_data,
        inputs=paper_id,
        outputs=[paper_id, *_paper_fields],
        queue=False,
        api_name=False,
    )
    # create_pr's oauth_token argument is not listed here; Gradio supplies it from the login session.
    create_pr_button.click(
        fn=create_pr,
        inputs=[paper_id, *_paper_fields],
        outputs=result,
        queue=False,
        api_name=False,
    )
# Top-level app: a three-step tabbed workflow (log in, find the paper ID, edit and open a PR)
# that embeds the demo_search and demo_edit sub-apps defined above.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(
        "You can create PRs to update the JSON files in the [ICLR2024-papers repo](https://huggingface.co/datasets/ICLR2024/ICLR2024-papers) with this Space."
    )
    with gr.Tabs():
        # create_pr refuses to run without an OAuth token, which comes from logging in here.
        with gr.Tab(label="Step 1: Login"):
            gr.Markdown("To create a PR, you first need to log in. Please press the login button below.")
            gr.LoginButton()
        with gr.Tab(label="Step 2: Search for paper ID"):
            gr.Markdown("Search for the paper you would like to update and find its paper ID.")
            # Re-render the separately built search Blocks inside this tab.
            demo_search.render()
        with gr.Tab(label="Step 3: Edit and create PR"):
            gr.Markdown("Enter the paper ID in the field below and press the Load button.")
            gr.Markdown("After making the necessary changes, press the Create PR button.")
            # Re-render the separately built edit form inside this tab.
            demo_edit.render()
if __name__ == "__main__":
    # Enable the queue with its HTTP API closed, then launch without advertising the API page.
    # Blocks.queue() returns the same Blocks instance, so launching it launches `demo`.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False)