Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python | |
import gradio as gr | |
import polars as pl | |
from gradio_modal import Modal | |
from app_pr import demo as demo_pr | |
from semantic_search import semantic_search | |
from table import df_orig | |
# Markdown heading rendered at the very top of the main page.
DESCRIPTION = "# ICLR 2025"

# Markdown body of the "Tutorial" accordion: how authors claim a paper and
# how to report missing or incorrect table entries.
TUTORIAL = """\
#### Claiming Authorship for Papers on arXiv
If your ICLR 2025 paper is available on arXiv and listed in the table below, you can claim authorship by following these steps:
1. Find your paper in the table.
2. Click the link to the paper page in the table.
3. On that page, click your name.
4. Click **"Claim authorship"**.
- You'll be redirected to the *Papers* section of your Settings.
5. Confirm the request on the redirected page.
The admin team will review your request shortly.
Once confirmed, your paper page will be marked as verified, and you'll be able to add a project page and a GitHub repository.
If you need further help, check out the [guide here](https://huggingface.co/docs/hub/paper-pages).
#### Updating Missing or Incorrect Information in the Table
If you notice any missing or incorrect information in the table, feel free to submit a PR via the "Open PR" page, which you can find at the top right of this page.
"""

# TODO: remove this once https://github.com/gradio-app/gradio/issues/10916 https://github.com/gradio-app/gradio/issues/11001 https://github.com/gradio-app/gradio/issues/11002 are fixed  # noqa: TD002, FIX002
# Caveat rendered as Markdown directly above the results table.
NOTE = """\
Note: Sorting by upvotes or comments may not work correctly due to a known bug in Gradio.
"""
# Working table for the app: keep only the columns the UI needs from the
# source table, then give the user-facing ones their display headers.
# "Spaces", "Models", "Datasets", "claimed", "abstract" and "paper_id" keep
# their original names.
df_main = df_orig.select(
    [
        "title",
        "authors_str",
        "openreview_md",
        "type",
        "paper_page_md",
        "upvotes",
        "num_comments",
        "project_page_md",
        "github_md",
        "Spaces",
        "Models",
        "Datasets",
        "claimed",
        "abstract",
        "paper_id",
    ]
).rename(
    {
        "title": "Title",
        "authors_str": "Authors",
        "openreview_md": "OpenReview",
        "type": "Type",
        "paper_page_md": "Paper page",
        "upvotes": "👍",
        "num_comments": "💬",
        "project_page_md": "Project page",
        "github_md": "GitHub",
    }
)
# Display metadata for every column the table can show, keyed by header.
# Each value is (gradio datatype, column width); a width of None leaves the
# width unspecified. Key order is the left-to-right order used in the table.
COLUMN_INFO = {
    "Title": ("str", "40%"),
    "Authors": ("str", None),
    "Type": ("str", None),
    "Paper page": ("markdown", "135px"),
    "👍": ("number", "50px"),
    "💬": ("number", "50px"),
    "OpenReview": ("markdown", None),
    "Project page": ("markdown", None),
    "GitHub": ("markdown", None),
    "Spaces": ("markdown", None),
    "Models": ("markdown", None),
    "Datasets": ("markdown", None),
    "claimed": ("markdown", None),
}
COLUMN_INFO["Authors"] = ("str", "20%")

# Columns ticked by default: everything in COLUMN_INFO (same order) except
# the two that start hidden.
DEFAULT_COLUMNS = [name for name in COLUMN_INFO if name not in {"Authors", "claimed"}]
def update_num_papers(df: pl.DataFrame) -> str:
    """Build the "shown / total" counter text for the given table.

    The total is the row count of the module-level ``df_main``. When ``df``
    still carries a ``claimed`` column, the number of rows whose ``claimed``
    value contains "✅" is appended.
    """
    counts = f"{len(df)} / {len(df_main)}"
    if "claimed" not in df.columns:
        return counts
    num_claimed = df.select(pl.col("claimed").str.contains("✅").sum()).item()
    return f"{counts} ({num_claimed} claimed)"
def update_df(
    search_mode: str,
    search_query: str,
    candidate_pool_size: int,
    score_threshold: float,
    presentation_type: str,
    column_names: list[str],
    case_insensitive: bool = True,
) -> gr.Dataframe:
    """Rebuild the results table from the current search and filter settings.

    Args:
        search_mode: "Title Search" filters titles by regex; any other value
            routes the query through ``semantic_search``.
        search_query: Query text; an empty query skips searching entirely.
        candidate_pool_size: Forwarded to ``semantic_search``.
        score_threshold: Forwarded to ``semantic_search``.
        presentation_type: "(ALL)" disables the filter; otherwise only rows
            whose "Type" contains this value are kept.
        column_names: Columns to show in addition to "Title".
        case_insensitive: Prefix the title pattern with ``(?i)``.

    Returns:
        A ``gr.Dataframe`` carrying the filtered rows plus the matching
        datatypes and column widths from ``COLUMN_INFO``.

    Raises:
        gr.Error: If polars reports a ``ComputeError`` for the title pattern.
    """
    table = df_main.clone()
    requested = ["Title", *column_names]

    if search_query and search_mode == "Title Search":
        pattern = f"(?i){search_query}" if case_insensitive else search_query
        try:
            table = table.filter(pl.col("Title").str.contains(pattern))
        except pl.exceptions.ComputeError as e:
            raise gr.Error(str(e)) from e
    elif search_query:
        paper_ids, scores = semantic_search(search_query, candidate_pool_size, score_threshold)
        if paper_ids:
            # Attach scores by paper_id, order best-first, then drop the helper column.
            ranked = pl.DataFrame({"paper_id": paper_ids, "score": scores})
            table = ranked.join(table, on="paper_id", how="inner")
            table = table.sort("score", descending=True).drop("score")
        else:
            table = table.head(0)

    if presentation_type != "(ALL)":
        table = table.filter(pl.col("Type").str.contains(presentation_type))

    # Emit columns in COLUMN_INFO order regardless of the order they were ticked in.
    shown = []
    datatypes = []
    widths = []
    for name, (datatype, width) in COLUMN_INFO.items():
        if name in requested:
            shown.append(name)
            datatypes.append(datatype)
            widths.append(width)

    return gr.Dataframe(
        value=table.select(shown),
        datatype=datatypes,
        column_widths=widths,
    )
def update_search_mode(search_mode: str) -> gr.Accordion:
    """Return an accordion update that is visible only in "Semantic Search" mode."""
    is_semantic = search_mode == "Semantic Search"
    return gr.Accordion(visible=is_semantic)
def df_row_selected(
    evt: gr.SelectData,
) -> tuple[
    Modal,
    gr.Textbox,  # title
    gr.Textbox,  # abstract
]:
    """Open the abstract modal when a cell in the first table column is clicked.

    Args:
        evt: Selection event; ``evt.index[1]`` is the clicked column index and
            ``evt.row_value[0]`` the first cell of the clicked row (the title,
            since "Title" is always the first displayed column).

    Returns:
        Updates for the modal, the title textbox and the abstract textbox.
        Clicks outside the first column, or on a title that cannot be found
        in ``df_main``, return no-op updates.
    """
    # Only the Title column (index 0) acts as the trigger for the modal.
    if evt.index[1] != 0:
        return Modal(), gr.Textbox(), gr.Textbox()

    title = evt.row_value[0]
    matches = df_main.filter(pl.col("Title") == title)
    if matches.is_empty():
        # Nothing to show; leave the UI unchanged instead of raising from `.item()`.
        return Modal(), gr.Textbox(), gr.Textbox()

    # Several papers may share a title; `.item()` would raise on more than one
    # row, so take the first match explicitly.
    paper = matches.row(0, named=True)
    return (
        Modal(visible=True),
        gr.Textbox(value=paper["Title"]),  # title
        gr.Textbox(value=paper["abstract"]),  # abstract
    )
# Main page: search controls, column/type filters, the results table and the
# abstract modal, plus the event wiring between them.
# NOTE(review): layout nesting below was reconstructed from a copy of this file
# that had lost its indentation — confirm the Group/Accordion/Row grouping.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Accordion(label="Tutorial", open=True):
        gr.Markdown(TUTORIAL)
    with gr.Group():
        search_mode = gr.Radio(
            label="Search Mode",
            choices=["Semantic Search", "Title Search"],
            value="Semantic Search",
            show_label=False,
            info="Note: Semantic search consumes your ZeroGPU quota.",
        )
        search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Enter query here")
        with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
            with gr.Row():
                candidate_pool_size = gr.Slider(
                    label="Candidate Pool Size", minimum=1, maximum=1000, step=1, value=300
                )
                score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.01, value=0.5)

    presentation_type = gr.Radio(
        label="Presentation Type",
        choices=["(ALL)", "Oral", "Spotlight", "Poster"],
        value="(ALL)",
    )
    column_names = gr.CheckboxGroup(
        label="Columns",
        choices=[col for col in COLUMN_INFO if col != "Title"],
        value=[col for col in DEFAULT_COLUMNS if col != "Title"],
    )

    num_papers = gr.Textbox(label="Number of papers", value=update_num_papers(df_orig), interactive=False)

    gr.Markdown(NOTE)
    df = gr.Dataframe(
        # Show exactly the COLUMN_INFO columns, in COLUMN_INFO order, so that
        # `datatype` and `column_widths` line up one-to-one with the data.
        value=df_main.select(list(COLUMN_INFO)),
        # `datatype` expects the type strings, not the (datatype, width) tuples.
        datatype=[datatype for datatype, _ in COLUMN_INFO.values()],
        type="polars",
        row_count=(0, "dynamic"),
        show_row_numbers=True,
        interactive=False,
        max_height=1000,
        elem_id="table",
        column_widths=[width for _, width in COLUMN_INFO.values()],
    )
    with Modal(visible=False, elem_id="abstract-modal") as abstract_modal:
        title = gr.Textbox(label="Title")
        abstract = gr.Textbox(label="Abstract")

    # Advanced options only apply to semantic search; toggle them with the mode.
    search_mode.change(
        fn=update_search_mode,
        inputs=search_mode,
        outputs=advanced_search_options,
    )
    df.select(fn=df_row_selected, outputs=[abstract_modal, title, abstract])

    inputs = [
        search_mode,
        search_query,
        candidate_pool_size,
        score_threshold,
        presentation_type,
        column_names,
    ]
    # Any change to the query or filters rebuilds the table, then refreshes the counter.
    gr.on(
        triggers=[
            search_query.submit,
            presentation_type.input,
            column_names.input,
        ],
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )
    # Same pipeline on page load so the table starts with the default columns.
    demo.load(
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )

# Second page, "Open PR", rendering the Blocks imported from app_pr.
with demo.route("Open PR"):
    demo_pr.render()

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False)