import pandas as pd
from typing import List, Dict, Optional
from constants import (
    DATASET_ARXIV_SCAN_PAPERS,
    DATASET_CONFERENCE_PAPERS,
    DATASET_COMMUNITY_SCIENCE,
    NEURIPS_ICO,
    DATASET_PAPER_CENTRAL,
    COLM_ICO,
    DEFAULT_ICO,
    MICCAI24ICO,
    CORL_ICO,
)
import gradio as gr
from utils import load_and_process
import numpy as np
from datetime import datetime, timedelta
import re


class PaperCentral:
    """
    A class to manage and process paper data for display in a Gradio Dataframe component.
    """

    CONFERENCES_ICONS = {
        "NeurIPS2024 D&B": NEURIPS_ICO,
        "NeurIPS2024": NEURIPS_ICO,
        "EMNLP2024": 'https://aclanthology.org/aclicon.ico',
        "CoRL2024": CORL_ICO,
        "ACMMM2024": "https://2024.acmmm.org/favicon.ico",
        "MICCAI2024": MICCAI24ICO,
        "COLM2024": COLM_ICO,
        "COLING2024": 'https://aclanthology.org/aclicon.ico',
        "CVPR2024": "https://openaccess.thecvf.com/favicon.ico",
        "ACL2024": 'https://aclanthology.org/aclicon.ico',
        "ACL2023": 'https://aclanthology.org/aclicon.ico',
        "CVPR2023": "https://openaccess.thecvf.com/favicon.ico",
        "ECCV2024": "https://openaccess.thecvf.com/favicon.ico",
        "EMNLP2023": 'https://aclanthology.org/aclicon.ico',
        "NAACL2023": 'https://aclanthology.org/aclicon.ico',
        "NeurIPS2023": NEURIPS_ICO,
        "NeurIPS2023 D&B": NEURIPS_ICO,
    }

    CONFERENCES = list(CONFERENCES_ICONS.keys())

    # Class-level constants defining columns and their data types
    COLUMNS_START_PAPER_PAGE: List[str] = [
        'date',
        'arxiv_id',
        'paper_page',
        'title',
    ]

    COLUMNS_ORDER_PAPER_PAGE: List[str] = [
        'chat_with_paper',
        'date',
        'arxiv_id',
        'paper_page',
        'num_models',
        'num_datasets',
        'num_spaces',
        'upvotes',
        'num_comments',
        'github',
        'github_stars',
        'project_page',
        'conference_name',
        'id',
        'type',
        'proceedings',
        'title',
        'authors',
    ]

    DATATYPES: Dict[str, str] = {
        'date': 'str',
        'arxiv_id': 'markdown',
        'paper_page': 'markdown',
        'upvotes': 'number',
        'num_comments': 'number',
        'num_models': 'markdown',
        'num_datasets': 'markdown',
        'num_spaces': 'markdown',
        'github': 'markdown',
        'title': 'str',
        'proceedings': 'markdown',
        'conference_name': 'str',
        'id': 'str',
        'type': 'str',
        'authors': 'str',
        'github_stars': 'number',
        'project_page': 'markdown',
        'chat_with_paper': 'markdown',
    }

    # Mapping for renaming columns for display purposes
    COLUMN_RENAME_MAP: Dict[str, str] = {
        'num_models': 'models',
        'num_spaces': 'spaces',
        'num_datasets': 'datasets',
        'github': 'GitHub',
        'github_stars': 'GitHub⭐',
        'num_comments': '💬',
        'upvotes': '👍',
        'chat_with_paper': 'Chat',
    }

    # (count column, Hugging Face Hub listing path) pairs that `prettify` turns into
    # "[N](https://huggingface.co/<path>?other=arxiv:<id>)" markdown links.
    _HUB_ARTIFACT_COLUMNS = (
        ('num_models', 'models'),
        ('num_datasets', 'datasets'),
        ('num_spaces', 'spaces'),
    )

    # Look-back window, in days, for each relative date-range option of `filter`.
    # Any other option (e.g. "All time") applies no date filtering.
    _DATE_RANGE_DAYS: Dict[str, int] = {
        "This week": 7,
        "This month": 30,
        "This year": 365,
    }

    def __init__(self):
        """
        Initialize the PaperCentral class by loading and processing the datasets.
        """
        self.df_raw: pd.DataFrame = self.get_df()
        self.df_prettified: pd.DataFrame = self.prettify(self.df_raw)

    @staticmethod
    def get_columns_order(columns: List[str]) -> List[str]:
        """
        Get columns ordered according to COLUMNS_ORDER_PAPER_PAGE.

        Args:
            columns (List[str]): List of column names to order.

        Returns:
            List[str]: Ordered list of column names (unknown columns are dropped).
        """
        return [c for c in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if c in columns]

    @staticmethod
    def get_columns_datatypes(columns: List[str]) -> List[str]:
        """
        Get data types for the specified columns.

        Args:
            columns (List[str]): List of column names.

        Returns:
            List[str]: List of data types corresponding to the columns.

        Raises:
            KeyError: If a column is not listed in DATATYPES.
        """
        return [PaperCentral.DATATYPES[c] for c in columns]

    @staticmethod
    def _has_value(value) -> bool:
        """
        Return True when a scalar cell is neither missing (None/NaN) nor falsy (e.g. "" or 0).

        Plain truthiness is not enough because NaN is truthy.
        """
        return bool(pd.notna(value) and value)

    @staticmethod
    def get_df() -> pd.DataFrame:
        """
        Load the paper-central dataset and normalise its 'date' column.

        Dates falling on a Saturday or Sunday are moved forward to the following
        Monday, then the column is stored as 'YYYY-MM-DD' strings.

        Returns:
            pd.DataFrame: The processed DataFrame.
        """
        # Load datasets. The explicit copy keeps the column writes below from
        # targeting a slice of the loaded frame (SettingWithCopyWarning).
        paper_central_df: pd.DataFrame = load_and_process(DATASET_PAPER_CENTRAL)[
            ['arxiv_id', 'categories', 'primary_category', 'date',
             'upvotes', 'num_comments', 'github', 'num_models',
             'num_datasets', 'num_spaces', 'id', 'proceedings', 'type',
             'conference_name', 'title', 'paper_page', 'authors', 'github_stars', 'project_page']
        ].copy()

        # If arxiv published_date is weekend, switch to Monday
        def adjust_date(dt):
            if dt.weekday() == 5:  # Saturday
                return dt + pd.Timedelta(days=2)
            elif dt.weekday() == 6:  # Sunday
                return dt + pd.Timedelta(days=1)
            else:
                return dt

        dates = pd.to_datetime(paper_central_df['date'], format='%Y-%m-%d')
        paper_central_df['date'] = dates.apply(adjust_date).dt.strftime('%Y-%m-%d')

        return paper_central_df

    @staticmethod
    def format_df_date(df: pd.DataFrame, date_column: str = "date") -> pd.DataFrame:
        """
        Format the date column in the DataFrame to 'YYYY-MM-DD'.

        Note: the column is overwritten in place on *df*, which is also returned.

        Args:
            df (pd.DataFrame): The DataFrame to format.
            date_column (str): The name of the date column.

        Returns:
            pd.DataFrame: The DataFrame with the formatted date column.
        """
        df.loc[:, date_column] = pd.to_datetime(df[date_column]).dt.strftime('%Y-%m-%d')
        return df

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of *df* with link/markdown formatting applied row by row.

        - 'num_models' / 'num_datasets' / 'num_spaces': positive counts on rows with
          an arXiv id become "[N](https://huggingface.co/<kind>?other=arxiv:<id>)".
        - 'proceedings', 'paper_page', 'arxiv_id', 'github' and 'project_page' are
          rewritten when they hold a value.

        Args:
            df (pd.DataFrame): The DataFrame to prettify (not modified).

        Returns:
            pd.DataFrame: The prettified DataFrame.
        """
        has_value = PaperCentral._has_value

        def update_row(row: pd.Series) -> pd.Series:
            """
            Rewrite the link-bearing cells of a single row.

            Args:
                row (pd.Series): A row from the DataFrame.

            Returns:
                pd.Series: The updated row.
            """
            # Read the raw arXiv id before 'arxiv_id' itself is rewritten below.
            arxiv_id = row.get('arxiv_id')
            has_arxiv_id = has_value(arxiv_id)

            # Hub artifact counts -> search links (only meaningful with an arXiv id;
            # NaN ids are skipped instead of producing "arxiv:nan" links).
            for column, hub_kind in PaperCentral._HUB_ARTIFACT_COLUMNS:
                if column not in row or not has_arxiv_id:
                    continue
                count = row[column]
                if pd.notna(count) and float(count) > 0:
                    row[column] = (
                        f"[{int(float(count))}](https://huggingface.co/{hub_kind}?other=arxiv:{arxiv_id})"
                    )

            if 'proceedings' in row and has_value(row['proceedings']):
                # NOTE(review): image_url and style are computed but never used, and
                # the f-strings below contain no placeholders; it looks like icon/link
                # markup may have been lost here — confirm against version history.
                image_url = PaperCentral.CONFERENCES_ICONS.get(row["conference_name"], DEFAULT_ICO)
                style = "display:inline-block; vertical-align:middle; width: 16px; height:16px"
                row['proceedings'] = (
                    f""
                    f"proc_page"
                )

            ####
            ### This should be processed last :)
            ####
            # Add markdown link to 'paper_page' if it exists
            if 'paper_page' in row and has_value(row['paper_page']):
                row['paper_page'] = f"🤗[paper_page](https://huggingface.co/papers/{row['paper_page']})"

            # Add image and link to 'arxiv_id' if it exists
            if 'arxiv_id' in row and has_value(row['arxiv_id']):
                # NOTE(review): image_url/style unused, see the 'proceedings' note.
                image_url = "https://arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png"
                style = "display:inline-block; vertical-align:middle;"
                row['arxiv_id'] = (
                    f""
                    f"arxiv_page"
                )

            # Add image and link to 'github' if it exists
            if 'github' in row and has_value(row['github']):
                # NOTE(review): image_url/style unused, see the 'proceedings' note.
                image_url = "https://github.githubassets.com/favicons/favicon.png"
                style = "display:inline-block; vertical-align:middle;width:16px;"
                row['github'] = (
                    f""
                    f"github"
                )

            if 'project_page' in row and has_value(row['project_page']):
                row['project_page'] = (
                    f"{row['project_page']}"
                )

            return row

        # Apply the update_row function to each row of a private copy
        prettified_df: pd.DataFrame = df.copy().apply(update_row, axis=1)
        return prettified_df

    def rename_columns_for_display(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Rename columns in the DataFrame according to COLUMN_RENAME_MAP for display purposes.

        Args:
            df (pd.DataFrame): The DataFrame whose columns need to be renamed.

        Returns:
            pd.DataFrame: The DataFrame with renamed columns.
        """
        return df.rename(columns=self.COLUMN_RENAME_MAP)

    def filter(
            self,
            selected_date: Optional[str] = None,
            cat_options: Optional[List[str]] = None,
            hf_options: Optional[List[str]] = None,
            conference_options: Optional[List[str]] = None,
            author_search_input: Optional[str] = None,
            title_search_input: Optional[str] = None,
            date_range_option: Optional[str] = None,
    ) -> gr.update:
        """
        Filter the DataFrame based on selected date and options, and prepare it for display.

        Args:
            selected_date: Exact date to keep; only used when neither
                conference_options nor date_range_option is set.
            cat_options: Category prefixes (a trailing ".*" is ignored) matched
                case-insensitively against 'primary_category'; "(ALL)" disables it.
            hf_options: Any of "🤗 artifacts", "datasets", "models", "spaces",
                "github", "project page"; each keeps only rows that have it.
            conference_options: Conference names to keep ("ALL" keeps every row
                with a conference). Disables date filtering.
            author_search_input: Case-insensitive substring matched against authors.
            title_search_input: Case-insensitive substring matched against the title.
            date_range_option: "This week", "This month", "This year" or "All time".

        Returns:
            A gr.update carrying the display DataFrame and its column datatypes.
        """
        filtered_df: pd.DataFrame = self.df_raw.copy()

        # Start with the initial columns to display
        columns_to_show: List[str] = PaperCentral.COLUMNS_START_PAPER_PAGE.copy()

        def show(*columns: str) -> None:
            """Add each column to the displayed set unless it is already there."""
            for column in columns:
                if column not in columns_to_show:
                    columns_to_show.append(column)

        # Handle title search
        if title_search_input:
            show('title')
            title_needle = title_search_input.lower()

            def title_match(title) -> bool:
                return isinstance(title, str) and title_needle in title.lower()

            filtered_df = filtered_df[filtered_df['title'].apply(title_match)]

        # Handle author search
        if author_search_input:
            show('authors')
            author_needle = author_search_input.lower()

            def author_matches(authors_list) -> bool:
                # Type checks come first: a missing value may be NaN (a float),
                # on which len() would raise.
                if isinstance(authors_list, str):
                    return author_needle in authors_list.lower()
                if isinstance(authors_list, (list, tuple, pd.Series, np.ndarray)):
                    return any(
                        isinstance(author, str) and author_needle in author.lower()
                        for author in authors_list
                    )
                return False

            filtered_df = filtered_df[filtered_df['authors'].apply(author_matches)]

        # Handle category options ("(ALL)" means no category filtering)
        if cat_options and "(ALL)" not in cat_options:
            prefixes = [o.replace(".*", "") for o in cat_options]
            primary = filtered_df['primary_category']
            present = primary.notna()
            # Missing categories become "" so the .str accessor always applies.
            primary_text = primary.where(present, "").astype(str)
            category_filter = pd.Series(False, index=filtered_df.index)
            for prefix in prefixes:
                # Literal match: the "." in ids such as "cs.AI" must not act as a regex wildcard.
                category_filter |= present & primary_text.str.contains(prefix, case=False, regex=False)
            filtered_df = filtered_df[category_filter]

        # Handle date filtering (conference views ignore dates)
        if not conference_options:
            if date_range_option:
                window_days = self._DATE_RANGE_DAYS.get(date_range_option)
                if window_days is not None:
                    today = datetime.now()
                    start_date = (today - timedelta(days=window_days)).strftime('%Y-%m-%d')
                    end_date = today.strftime('%Y-%m-%d')
                    # 'date' holds 'YYYY-MM-DD' strings, so lexicographic comparison is chronological.
                    filtered_df = filtered_df[
                        (filtered_df['date'] >= start_date) & (filtered_df['date'] <= end_date)
                    ]
                # else: "All time" / unknown option -> no date filtering
            elif selected_date:
                selected_date = pd.to_datetime(selected_date).strftime('%Y-%m-%d')
                filtered_df = filtered_df[filtered_df['date'] == selected_date]

        # Handle Hugging Face options
        if hf_options:
            # Convert artifact counts to ints (non-numeric / missing -> 0). assign()
            # returns a new frame, avoiding writes into a boolean-indexed slice.
            filtered_df = filtered_df.assign(**{
                column: pd.to_numeric(filtered_df[column], errors='coerce').fillna(0).astype(int)
                for column in ('num_datasets', 'num_models', 'num_spaces')
            })

            if "🤗 artifacts" in hf_options:
                filtered_df = filtered_df[
                    (filtered_df['paper_page'] != "") & (filtered_df['paper_page'].notna())
                ]
                show('upvotes', 'num_comments', 'num_models', 'num_datasets', 'num_spaces')
                filtered_df = filtered_df[
                    (filtered_df['num_datasets'] > 0)
                    | (filtered_df['num_models'] > 0)
                    | (filtered_df['num_spaces'] > 0)
                ]

            if "datasets" in hf_options:
                show('num_datasets')
                filtered_df = filtered_df[filtered_df['num_datasets'] != 0]

            if "models" in hf_options:
                show('num_models')
                filtered_df = filtered_df[filtered_df['num_models'] != 0]

            if "spaces" in hf_options:
                show('num_spaces')
                filtered_df = filtered_df[filtered_df['num_spaces'] != 0]

            if "github" in hf_options:
                show('github', 'github_stars')
                filtered_df = filtered_df[(filtered_df['github'] != "") & (filtered_df['github'].notnull())]

            if "project page" in hf_options:
                show('project_page')
                filtered_df = filtered_df[
                    (filtered_df['project_page'] != "") & (filtered_df['project_page'].notnull())
                ]

        # Apply conference filtering
        if conference_options:
            # Conference views replace the arXiv-centric columns with proceedings details.
            columns_to_show[:] = [col for col in columns_to_show if col not in ("date", "arxiv_id")]
            show('conference_name', 'proceedings', 'type', 'id')

            names = filtered_df['conference_name']
            if "ALL" in conference_options:
                filtered_df = filtered_df[names.notna() & (names != "")]
                names = filtered_df['conference_name']

            other_conferences = [conf for conf in conference_options if conf != "ALL"]
            if other_conferences:
                wanted = {conference.lower() for conference in other_conferences}
                present = names.notna()
                lowered = names.where(present, "").astype(str).str.lower()
                filtered_df = filtered_df[present & lowered.isin(wanted)]

            if any(conf in ("NeurIPS2024 D&B", "NeurIPS2024") for conf in conference_options):
                def create_chat_link(proceedings) -> str:
                    # Missing proceedings (NaN/None) would make re.search raise.
                    if not isinstance(proceedings, str):
                        return ""
                    match = re.search(r'id=([^&]+)', proceedings)
                    if not match:
                        return ""
                    # NOTE(review): the extracted id is never used — the returned text
                    # has no placeholder, so link markup may have been lost; confirm.
                    neurips_id = match.group(1)
                    return f'✨ Chat with paper'

                # Built from the column directly: df.apply(axis=1) returns a DataFrame
                # (not a Series) on an empty frame, which cannot be stored as one column.
                filtered_df = filtered_df.assign(
                    chat_with_paper=[create_chat_link(p) for p in filtered_df['proceedings']]
                )
                show('chat_with_paper')

        # Prettify the DataFrame
        filtered_df = self.prettify(filtered_df)

        # Ensure columns are ordered according to COLUMNS_ORDER_PAPER_PAGE
        columns_in_order: List[str] = [col for col in PaperCentral.COLUMNS_ORDER_PAPER_PAGE if
                                       col in columns_to_show]
        filtered_df = filtered_df[columns_in_order]

        # Rename columns for display
        filtered_df = self.rename_columns_for_display(filtered_df)

        # Get the corresponding data types for the columns
        new_datatypes: List[str] = [
            PaperCentral.DATATYPES.get(self._get_original_column_name(col), 'str') for col in filtered_df.columns
        ]

        # Show entries with a 'paper_page' first; a stable sort keeps the existing
        # order within each group (the default quicksort does not).
        if 'paper_page' in filtered_df.columns:
            has_paper_page = filtered_df['paper_page'].notna() & (filtered_df['paper_page'] != "")
            filtered_df = (
                filtered_df.assign(has_paper_page=has_paper_page)
                .sort_values(by='has_paper_page', ascending=False, kind='stable')
                .drop(columns='has_paper_page')
            )

        # Return an update object to modify the Dataframe component
        return gr.update(value=filtered_df, datatype=new_datatypes)

    def _get_original_column_name(self, display_column_name: str) -> str:
        """
        Retrieve the original column name given a display column name.

        Args:
            display_column_name (str): The display name of the column.

        Returns:
            str: The original name of the column (unchanged if it was not renamed).
        """
        inverse_map = {v: k for k, v in self.COLUMN_RENAME_MAP.items()}
        return inverse_map.get(display_column_name, display_column_name)