# ICLR2025 / app_pr.py
# Uploaded by hysts (HF Staff) - commit 1c00c70: "Update"
import datetime
import difflib
import json
import re
import tempfile
import gradio as gr
import polars as pl
from gradio_modal import Modal
from huggingface_hub import CommitOperationAdd, HfApi
from table import PATCH_REPO_ID, df_orig
# TODO: remove this once https://github.com/gradio-app/gradio/issues/11022 is fixed # noqa: FIX002, TD002
# Markdown warning about the Gradio table-selection bug referenced above.
NOTE = (
    "#### ⚠️ Note\n"
    "You may encounter an issue when selecting table data after using the search bar.\n"
    "This is due to a known bug in Gradio.\n"
    "The issue typically occurs when multiple rows remain after filtering.\n"
    "If only one row remains, the selection should work as expected.\n"
)
# Hub client used by open_pr() to create the pull-request commit.
api = HfApi()

# Columns displayed in the selectable table. "paper_id" is kept last because
# df_pr_row_selected() reads it from the final cell of the clicked row.
PR_VIEW_COLUMNS = [
    "title", "authors_str", "openreview_md",
    "arxiv_id", "github_md",
    "Spaces", "Models", "Datasets",
    "paper_id",
]

# Raw (editable) columns; these are the keys of the "original_data" dict.
PR_RAW_COLUMNS = [
    "paper_id",
    "title", "authors",
    "arxiv_id", "project_page", "github",
    "space_ids", "model_ids", "dataset_ids",
]

# Display frame: a leading "Fix" column holding the clickable "📝" marker,
# followed by the view columns, with null arXiv IDs shown as empty strings.
df_pr_view = (
    df_orig.with_columns(pl.lit("📝").alias("Fix"))
    .select(["Fix", *PR_VIEW_COLUMNS])
    .with_columns(pl.col("arxiv_id").fill_null(""))
)

# Raw frame used to look up the unformatted values of a selected paper.
df_pr_raw = df_orig.select(PR_RAW_COLUMNS)
def df_pr_row_selected(
    evt: gr.SelectData,
) -> tuple[
    Modal,
    gr.Textbox,  # title
    gr.Textbox,  # authors
    gr.Textbox,  # arxiv_id
    gr.Textbox,  # project_page
    gr.Textbox,  # github
    gr.Textbox,  # space_ids
    gr.Textbox,  # model_ids
    gr.Textbox,  # dataset_ids
    dict | None,  # original_data
]:
    """Handle a cell selection in the PR table.

    When the clicked cell is the "📝" marker, look up the paper's raw record
    and return a visible modal plus textboxes pre-filled with its values
    (list-valued fields are joined one entry per line) and the raw record
    itself. For any other cell, return default-constructed components and
    ``None`` for the record.

    Args:
        evt: Gradio selection event; ``evt.value`` is the clicked cell's value
            and the last element of ``evt.row_value`` is read as the paper ID.

    Returns:
        ``(modal, title, authors, arxiv_id, project_page, github, space_ids,
        model_ids, dataset_ids, original_data)``.
    """
    if evt.value != "📝":
        # Not the "Fix" marker: one Modal, eight textboxes, no record.
        untouched_fields = [gr.Textbox() for _ in range(8)]
        return (Modal(), *untouched_fields, None)

    # "paper_id" is the last view column, so it is the last cell of the row.
    selected_id = evt.row_value[-1]
    row = df_pr_raw.filter(pl.col("paper_id") == selected_id)
    original_data = row.to_dicts()[0]

    def one_per_line(key: str) -> str:
        """Join the list stored under *key* into newline-separated text."""
        return "\n".join(original_data[key])

    modal = Modal(visible=True)
    title_box = gr.Textbox(value=row["title"].item())
    authors_box = gr.Textbox(value=one_per_line("authors"))
    arxiv_box = gr.Textbox(value=row["arxiv_id"].item())
    project_page_box = gr.Textbox(value=row["project_page"].item())
    github_box = gr.Textbox(value=row["github"].item())
    spaces_box = gr.Textbox(value=one_per_line("space_ids"))
    models_box = gr.Textbox(value=one_per_line("model_ids"))
    datasets_box = gr.Textbox(value=one_per_line("dataset_ids"))
    return (
        modal,
        title_box,
        authors_box,
        arxiv_box,
        project_page_box,
        github_box,
        spaces_box,
        models_box,
        datasets_box,
        original_data,
    )
# Validation patterns for user-submitted fields. The validators below use
# ``fullmatch`` so a trailing newline (which ``$`` alone tolerates) is rejected.
URL_PATTERN = re.compile(r"^(https?://)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(:\d+)?(/.*)?$")
GITHUB_PATTERN = re.compile(r"^https://github\.com/[^/\s]+/[^/\s]+(/tree/[^/\s]+/[^/\s].*)?$")
# "namespace/name". Each part is one or more runs of [A-Za-z0-9_-] separated by
# single dots, so names such as "Llama-3.1-8B" are accepted while leading or
# trailing dots and ".." are not. A ".git" suffix is rejected as well.
REPO_ID_PATTERN = re.compile(
    r"^[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)*/[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)*(?<!\.git)$"
)
# New-style arXiv identifier: four digits, a dot, then four or five digits.
ARXIV_ID_PATTERN = re.compile(r"^\d{4}\.\d{4,5}$")


def is_valid_url(url: str) -> bool:
    """Return True if *url* is entirely a host name with optional scheme, port and path."""
    return URL_PATTERN.fullmatch(url) is not None


def is_valid_github_url(url: str) -> bool:
    """Return True if *url* is an https GitHub repository URL, optionally with a ``/tree/...`` path."""
    return GITHUB_PATTERN.fullmatch(url) is not None


def is_valid_repo_id(repo_id: str) -> bool:
    """Return True if *repo_id* has the form ``namespace/name`` (dots allowed inside each part)."""
    return REPO_ID_PATTERN.fullmatch(repo_id) is not None


def is_valid_arxiv_id(arxiv_id: str) -> bool:
    """Return True if *arxiv_id* is four digits, a dot, then four or five digits."""
    return ARXIV_ID_PATTERN.fullmatch(arxiv_id) is not None
def validate_pr_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids: list[str],
    model_ids: list[str],
    dataset_ids: list[str],
) -> None:
    """Validate the submitted fields, raising ``gr.Error`` on the first problem.

    Args:
        title_pr: Paper title; must contain non-whitespace text.
        authors_pr: Raw authors text; must contain non-whitespace text.
        arxiv_id_pr: arXiv identifier; checked only when non-empty.
        project_page_pr: Project page URL; checked only when non-empty.
        github_pr: GitHub repository URL; checked only when non-empty.
        space_ids: Space repo IDs, each expected as ``org_name/repo_name``.
        model_ids: Model repo IDs, same format.
        dataset_ids: Dataset repo IDs, same format.

    Raises:
        gr.Error: If a required field is blank or a provided field is malformed.
    """
    # Whitespace-only input is treated as empty; otherwise a blank title or an
    # empty author list would slip through.
    if not (title_pr or "").strip():
        raise gr.Error("Title cannot be empty", print_exception=False)
    if not (authors_pr or "").strip():
        raise gr.Error("Authors cannot be empty", print_exception=False)
    if arxiv_id_pr and not is_valid_arxiv_id(arxiv_id_pr):
        # arXiv identifiers are year+month then a sequence number (YYMM.NNNNN),
        # not a four-digit year.
        raise gr.Error(
            "Invalid arXiv ID format. Expected format: 'YYMM.NNNNN' (e.g., '2301.01234')", print_exception=False
        )
    if project_page_pr and not is_valid_url(project_page_pr):
        raise gr.Error("Project page must be a valid URL", print_exception=False)
    if github_pr and not is_valid_github_url(github_pr):
        raise gr.Error("GitHub must be a valid GitHub URL", print_exception=False)
    for repo_id in space_ids + model_ids + dataset_ids:
        if not is_valid_repo_id(repo_id):
            error_msg = f"Space/Model/Dataset ID must be in the format 'org_name/repo_name'. Got: {repo_id}"
            raise gr.Error(error_msg, print_exception=False)
def _split_nonblank_lines(text: str | None) -> list[str]:
    """Split *text* on newlines, trim each entry, and drop blank entries."""
    return [line.strip() for line in (text or "").split("\n") if line.strip()]


def _strip_or_none(text: str | None) -> str | None:
    """Return *text* without surrounding whitespace, or None when nothing is left."""
    return (text or "").strip() or None


def format_submitted_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
) -> dict:
    """Normalize and validate the form fields into a record dict.

    Multi-line fields (authors and the three repo-ID fields) are split into
    lists with each entry trimmed, so stray spaces or carriage returns around
    an otherwise valid entry no longer fail validation or end up in the data.
    The optional single-line fields are trimmed and become ``None`` when empty.

    Args:
        title_pr: Paper title.
        authors_pr: Authors, one per line.
        arxiv_id_pr: arXiv identifier, may be empty.
        project_page_pr: Project page URL, may be empty.
        github_pr: GitHub URL, may be empty.
        space_ids_pr: Space repo IDs, one per line.
        model_ids_pr: Model repo IDs, one per line.
        dataset_ids_pr: Dataset repo IDs, one per line.

    Returns:
        Dict with keys ``title``, ``authors``, ``arxiv_id``, ``project_page``,
        ``github``, ``space_ids``, ``model_ids`` and ``dataset_ids``.

    Raises:
        gr.Error: Propagated from ``validate_pr_data`` when a field is invalid.
    """
    space_ids = _split_nonblank_lines(space_ids_pr)
    model_ids = _split_nonblank_lines(model_ids_pr)
    dataset_ids = _split_nonblank_lines(dataset_ids_pr)
    arxiv_id = _strip_or_none(arxiv_id_pr)
    project_page = _strip_or_none(project_page_pr)
    github = _strip_or_none(github_pr)

    # Validate the trimmed values, i.e. exactly what will be stored.
    validate_pr_data(
        title_pr,
        authors_pr,
        arxiv_id or "",
        project_page or "",
        github or "",
        space_ids,
        model_ids,
        dataset_ids,
    )
    return {
        "title": title_pr,
        "authors": _split_nonblank_lines(authors_pr),
        "arxiv_id": arxiv_id,
        "project_page": project_page,
        "github": github,
        "space_ids": space_ids,
        "model_ids": model_ids,
        "dataset_ids": dataset_ids,
    }
def preview_diff(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
) -> tuple[gr.Markdown, gr.Button]:
    """Render a unified diff between the original record and the edited form.

    Args:
        title_pr: Edited title.
        authors_pr: Edited authors, one per line.
        arxiv_id_pr: Edited arXiv identifier.
        project_page_pr: Edited project page URL.
        github_pr: Edited GitHub URL.
        space_ids_pr: Edited Space IDs, one per line.
        model_ids_pr: Edited model IDs, one per line.
        dataset_ids_pr: Edited dataset IDs, one per line.
        original_data: Raw record of the selected paper (includes ``paper_id``).

    Returns:
        A Markdown update containing the diff in a fenced ``diff`` block and a
        Button update that makes the button visible.

    Raises:
        gr.Error: Propagated from ``format_submitted_data`` on invalid input.
    """
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    # paper_id is not editable; copy it over (first, matching the original's
    # key order) so it never shows up as a change.
    submitted_data = {"paper_id": original_data["paper_id"], **submitted_data}

    # ensure_ascii=False keeps non-ASCII text (e.g. accented author names)
    # readable in the preview instead of showing \uXXXX escapes.
    original_json = json.dumps(original_data, indent=2, ensure_ascii=False)
    submitted_json = json.dumps(submitted_data, indent=2, ensure_ascii=False)
    diff = difflib.unified_diff(
        original_json.splitlines(),
        submitted_json.splitlines(),
        fromfile="before",
        tofile="after",
        lineterm="",
    )
    diff_str = "\n".join(diff)
    return gr.Markdown(value=f"```diff\n{diff_str}\n```"), gr.Button(visible=True)
def open_pr(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
    oauth_token: gr.OAuthToken | None,
) -> gr.Markdown:
    """Open a pull request on the patch dataset with the changed fields.

    The uploaded JSON file contains the changed fields, the ``paper_id``, a
    unified diff of the full record and a UTC timestamp.

    Args:
        title_pr: Edited title.
        authors_pr: Edited authors, one per line.
        arxiv_id_pr: Edited arXiv identifier.
        project_page_pr: Edited project page URL.
        github_pr: Edited GitHub URL.
        space_ids_pr: Edited Space IDs, one per line.
        model_ids_pr: Edited model IDs, one per line.
        dataset_ids_pr: Edited dataset IDs, one per line.
        original_data: Raw record of the selected paper (includes ``paper_id``).
        oauth_token: Token of the logged-in user, or ``None``.

    Returns:
        A visible Markdown update holding the PR URL, or ``""`` when nothing
        was changed.

    Raises:
        gr.Error: Propagated from ``format_submitted_data`` on invalid input.
    """
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    diff_dict = {key: value for key, value in submitted_data.items() if value != original_data[key]}
    if not diff_dict:
        gr.Info("No data to submit")
        return ""

    paper_id = original_data["paper_id"]
    diff_dict["paper_id"] = paper_id

    # Add paper_id to the submitted side, as preview_diff does. Without it the
    # recorded diff always reported paper_id as a removed line.
    submitted_with_id = {"paper_id": paper_id, **submitted_data}
    original_json = json.dumps(original_data, indent=2, ensure_ascii=False)
    submitted_json = json.dumps(submitted_with_id, indent=2, ensure_ascii=False)
    diff_dict["diff"] = "\n".join(
        difflib.unified_diff(
            original_json.splitlines(),
            submitted_json.splitlines(),
            fromfile="before",
            tofile="after",
            lineterm="",
        )
    )

    timestamp = datetime.datetime.now(datetime.timezone.utc)
    diff_dict["timestamp"] = timestamp.isoformat()

    # Upload the payload from memory. The previous temp file was created with
    # delete=False and never removed, leaking one file per submitted PR.
    payload = json.dumps(diff_dict, indent=2).encode("utf-8")
    commit = CommitOperationAdd(
        path_in_repo=f"data/{paper_id}--{timestamp.strftime('%Y-%m-%d-%H-%M-%S')}.json",
        path_or_fileobj=payload,
    )
    res = api.create_commit(
        repo_id=PATCH_REPO_ID,
        operations=[commit],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        # NOTE(review): with no OAuth token this passes token=None, which
        # presumably falls back to the client's default credentials - confirm
        # that is intended.
        token=oauth_token.token if oauth_token else None,
    )
    return gr.Markdown(value=res.pr_url, visible=True)
def render_open_pr_page(profile: gr.OAuthProfile | None) -> dict:
    """Return a Column update that is visible only when *profile* is not None."""
    has_profile = profile is not None
    return gr.Column(visible=has_profile)
# Gradio UI: login button, the selectable paper table, and an edit modal with
# "Preview diff" / "Open PR" actions.
# NOTE(review): the source had lost its indentation; the nesting of the
# Column / Modal / Group contexts below is a reconstruction - confirm against
# the intended layout.
with gr.Blocks() as demo:
    gr.LoginButton()
    # Starts hidden; demo.load() below sets visibility via render_open_pr_page.
    with gr.Column(visible=False) as open_pr_col:
        gr.Markdown(NOTE)
        df_pr = gr.Dataframe(
            value=df_pr_view,
            datatype=[
                "str",  # Fix
                "str",  # Title
                "str",  # Authors
                "markdown",  # openreview
                "str",  # arxiv_id
                "markdown",  # github
                "markdown",  # spaces
                "markdown",  # models
                "markdown",  # datasets
                "str",  # paper id
            ],
            column_widths=[
                "50px",  # Fix
                "40%",  # Title
                "20%",  # Authors
                None,  # openreview
                "100px",  # arxiv_id
                None,  # github
                None,  # spaces
                None,  # models
                None,  # datasets
                None,  # paper id
            ],
            type="polars",
            row_count=(0, "dynamic"),
            interactive=False,
            max_height=1000,
            show_search="search",
        )
        # Edit form; made visible by df_pr_row_selected when a "📝" cell is clicked.
        with Modal(visible=False) as pr_modal:
            with gr.Group():
                title_pr = gr.Textbox(label="Title")
                authors_pr = gr.Textbox(label="Authors")
                arxiv_id_pr = gr.Textbox(label="arXiv ID")
                project_page_pr = gr.Textbox(label="Project page")
                github_pr = gr.Textbox(label="GitHub")
                spaces_pr = gr.Textbox(
                    label="Spaces",
                    info="Enter one space ID (e.g., 'org_name/space_name') per line.",
                )
                models_pr = gr.Textbox(
                    label="Models",
                    info="Enter one model ID (e.g., 'org_name/model_name') per line.",
                )
                datasets_pr = gr.Textbox(
                    label="Datasets",
                    info="Enter one dataset ID (e.g., 'org_name/dataset_name') per line.",
                )
            # Holds the raw record of the selected paper between events.
            original_data = gr.State()
            preview_diff_button = gr.Button("Preview diff")
            diff_view = gr.Markdown()
            # Hidden until preview_diff returns a visible Button update.
            open_pr_button = gr.Button("Open PR", visible=False)
            pr_url = gr.Markdown(visible=False)

    # On modal blur: clear the diff and hide the "Open PR" button and PR URL.
    pr_modal.blur(
        fn=lambda: (None, gr.Button(visible=False), gr.Markdown(visible=False)),
        outputs=[diff_view, open_pr_button, pr_url],
    )
    # Output order must match the tuple returned by df_pr_row_selected.
    df_pr.select(
        fn=df_pr_row_selected,
        outputs=[
            pr_modal,
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
    )
    # Input order must match preview_diff's positional parameters.
    preview_diff_button.click(
        fn=preview_diff,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=[diff_view, open_pr_button],
    )
    # open_pr also takes an oauth_token parameter that is not listed in inputs;
    # presumably Gradio injects it from the gr.OAuthToken annotation - confirm.
    open_pr_button.click(
        fn=open_pr,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=pr_url,
    )
    # On page load, show or hide the whole PR column.
    demo.load(fn=render_open_pr_page, outputs=open_pr_col)
# Script entry point: enable the queue with api_open=False, then launch with show_api=False.
if __name__ == "__main__":
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False)